Source code for cutcutcodec.core.nn.model.enhancement.train

"""Utils to train and create a dataset."""

import itertools
import json
import os
import pathlib
import re
import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm

from cutcutcodec.core.analysis.video.quality import ssim as compute_ssim
from cutcutcodec.core.classes.colorspace import Colorspace
from cutcutcodec.core.io.framecaster import to_rgb
from cutcutcodec.core.nn.dataaug.chain import ChainDataaug
from cutcutcodec.core.nn.dataaug.video import Transcoder, interlace
from cutcutcodec.core.nn.dataset.video import VideoDataset
from cutcutcodec.core.nn.start import save
from cutcutcodec.core.opti.parallel import map as threaded_map

from .cnn import CNN


def plot(log: pathlib.Path | str | bytes | None = None):
    """Draw the training curves.

    Reads the ``ssim_train_epoch_<n>.json`` and ``ssim_val_epoch_<n>.json`` files
    (json lists of ssim values, one file per epoch) and plots, for each epoch,
    the mean ssim with its standard deviation as error bars.

    Parameters
    ----------
    log : pathlike, optional
        The directory containing the json logs.
        If not provided, the same default folder as ``train`` is used:
        ``<tempfile.gettempdir()>/log_train``.
    """
    if log is None:
        log = pathlib.Path(tempfile.gettempdir()) / "log_train"
    else:
        # pathlib.Path does not accept bytes, os.fsdecode handles str, bytes and PathLike
        log = pathlib.Path(os.fsdecode(log)).expanduser()

    # only the ssim json logs, other files (png snapshots...) in the folder are ignored
    pattern = re.compile(r"ssim_(train|val)_epoch_(\d+)\.json")
    curves: dict[str, dict[int, list[float]]] = {"train": {}, "val": {}}
    for file in log.iterdir():
        if (match := pattern.fullmatch(file.name)) is None:
            continue
        with open(file, encoding="utf-8") as raw:
            curves[match.group(1)][int(match.group(2))] = json.load(raw)

    for label, values in curves.items():
        epochs = sorted(values)
        plt.errorbar(
            epochs,
            [np.mean(values[e]) for e in epochs],
            yerr=[np.std(values[e]) for e in epochs],
            capsize=5.0,
            label=label,
        )
    plt.xlabel("epoch")
    plt.ylabel("ssim")
    plt.legend()
    plt.show()
def _stack_planes(planes) -> torch.Tensor:
    """Concatenate the channel planes returned by a colorspace function along a new dim 2."""
    # presumably each plane is a (height, width) tensor, giving a (height, width, n) result
    return torch.cat([plane[:, :, None] for plane in planes], dim=2)


def _rgb_to_yuv(converter, frame: torch.Tensor) -> torch.Tensor:
    """Apply the r, g, b -> y, u, v ``converter`` to the channels 0, 1, 2 (dim 2) of ``frame``."""
    return _stack_planes(converter(r=frame[:, :, 0], g=frame[:, :, 1], b=frame[:, :, 2]))


def _save_yuv_png(converter, frame: torch.Tensor, filename: pathlib.Path):
    """Convert a 3 channels slice of the y, u, v ``frame`` to rgb and write it in ``filename``.

    The slice starts at channel ``3 * (channels // 6)``, presumably the central frame
    when several 3 channels frames are stacked along dim 2 (TODO confirm).
    """
    start = 3 * (frame.shape[2] // 6)
    frame = frame[:, :, start:start + 3].detach()
    rgb = _stack_planes(converter(y=frame[:, :, 0], u=frame[:, :, 1], v=frame[:, :, 2]))
    cv2.imwrite(str(filename), to_rgb(cv2.cvtColor(rgb.numpy(force=True), cv2.COLOR_RGB2BGR)))


def train(
    dataset: VideoDataset | pathlib.Path | str | bytes,
    model: CNN | None = None,
    log: pathlib.Path | str | bytes | None = None,
    *,
    overfit: bool = True,
):
    """Train the model.

    The training runs epoch after epoch until it is interrupted.
    Epochs whose ``ssim_train_epoch_<n>.json`` log already exists are skipped,
    which allows a previous run to be resumed.

    Parameters
    ----------
    dataset : VideoDataset or pathlike
        The dataset containing the videos and dataaug or the folder.
    model : CNN, optional
        Use the default constructor if not provided.
    log : pathlike, optional
        The directory to store the logs (json ssim curves and png snapshots).
        Defaults to ``<tempfile.gettempdir()>/log_train``.
    overfit : boolean, default=True
        If True, the validation is performed on the training samples
        (sanity check that the model is able to overfit). If False, the validation
        is performed on the held-out 20% split of the dataset.

    Examples
    --------
    >>> from cutcutcodec.core.nn.model.enhancement.train import train
    >>> # train("~/dataset/video_xiph")
    >>>
    """
    # preparation
    if not isinstance(dataset, VideoDataset):
        dataset = VideoDataset(
            dataset,
            dataaug=ChainDataaug([Transcoder("libx264"), interlace], [1, 0.1]),
            size=100_000_000,  # targeted batch data size in bytes
            shape=(480, 720),
            max_len=1,  # nbr of samples in database
            # n_min=1,  # minimum nbr of frames per video slice
        )
    if model is None:
        model = CNN()
    else:
        assert isinstance(model, CNN), model.__class__.__name__
    if log is None:
        log = pathlib.Path(tempfile.gettempdir()) / "log_train"
    else:
        log = pathlib.Path(os.fsdecode(log)).expanduser()  # pathlib rejects raw bytes
    log.mkdir(parents=True, exist_ok=True)

    optim = torch.optim.RAdam(model.parameters(), lr=1e-4, weight_decay=1e-7)
    train_subset, val_subset = torch.utils.data.random_split(dataset, [0.8, 0.2])
    # samples are materialized once, so every epoch sees the same (frozen) data,
    # the held-out split is only loaded when it is really used
    train_samples = list(train_subset)
    val_samples = train_samples if overfit else list(val_subset)

    # colorspace conversion functions, built once instead of for each validation sample
    rgb2yuv = Colorspace.from_default_working().to_function(Colorspace.from_default_target())
    yuv2rgb = Colorspace.from_default_target().to_function(Colorspace.from_default_target_rgb())

    for epoch in itertools.count():
        train_log = log / f"ssim_train_epoch_{epoch:03d}.json"
        if train_log.exists():  # epoch already done by a previous run
            continue

        # train
        model.train()
        ssims = []
        progress = tqdm.tqdm(
            threaded_map(train_samples.__getitem__, range(len(train_samples)), maxsize=2),
            desc=f"epoch {epoch}",
            total=len(train_samples),
        )
        for distorded, reference in progress:
            optim.zero_grad()
            restored = model(distorded)
            ssim = compute_ssim(reference, restored, weights=(6, 1, 1))
            score = float(ssim)
            ssims.append(score)
            progress.set_postfix(ssim=f"{score:.3f}")  # a print would break the progress bar
            (1.0 - ssim).sum().backward()  # maximize the ssim
            optim.step()
        save(model)
        with open(train_log, "w", encoding="utf-8") as file:
            json.dump(ssims, file)

        # validation, the metric is computed in the target (y, u, v) colorspace
        model.eval()
        ssims = []
        with torch.no_grad():
            for i, (distorded, reference) in enumerate(tqdm.tqdm(val_samples)):
                restored = model(distorded)
                distorded, reference, restored = (
                    _rgb_to_yuv(rgb2yuv, frame) for frame in (distorded, reference, restored)
                )
                ssims.append(float(compute_ssim(reference, restored, weights=(6, 1, 1))))
                for frame, tag in ((reference, "brut"), (distorded, "dis"), (restored, "res")):
                    _save_yuv_png(yuv2rgb, frame, log / f"epoch_{epoch:03d}_{i}_{tag}.png")
        # zero padded like the train log
        with open(log / f"ssim_val_epoch_{epoch:03d}.json", "w", encoding="utf-8") as file:
            json.dump(ssims, file)