"""Utils to train and create a dataset."""
import itertools
import json
import os
import pathlib
import re
import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from cutcutcodec.core.analysis.video.quality import ssim as compute_ssim
from cutcutcodec.core.classes.colorspace import Colorspace
from cutcutcodec.core.io.framecaster import to_rgb
from cutcutcodec.core.nn.dataaug.chain import ChainDataaug
from cutcutcodec.core.nn.dataaug.video import Transcoder, interlace
from cutcutcodec.core.nn.dataset.video import VideoDataset
from cutcutcodec.core.nn.start import save
from cutcutcodec.core.opti.parallel import map as threaded_map

from .cnn import CNN
[docs]
def plot(log: pathlib.Path | str | bytes | None = None):
    """Draw the training curves.

    For each epoch, the mean of the logged ssim values is drawn
    with the standard deviation as error bar, for both the
    training and the validation logs.

    Parameters
    ----------
    log : pathlike, optional
        The directory containing the json logs written by ``train``.
        If not provided, the default log directory of ``train`` is used
        (the ``log_train`` folder of the system temporary directory).

    Raises
    ------
    FileNotFoundError
        If the log directory does not exist.
    """
    def read(filename: pathlib.Path) -> object:
        with open(filename, encoding="utf-8") as raw:
            return json.load(raw)

    def collect(tag: str) -> dict[int, object]:
        """Map the epoch number to the content of each log file whose name contains ``tag``."""
        curves = {}
        for file in files:
            if tag not in file.name or not file.is_file():
                continue
            number = re.search(r"\d+", file.stem)
            if number is None:  # no epoch index in the name, it can not be placed on the x axis
                continue
            curves[int(number.group())] = read(file)
        return curves

    if log is None:
        # same fallback as ``train``, pathlib.Path(None) would raise a TypeError
        log = pathlib.Path(tempfile.gettempdir()) / "log_train"
    else:
        # os.fsdecode is required because pathlib.Path rejects bytes
        log = pathlib.Path(os.fsdecode(log)).expanduser()
    files = list(log.iterdir())

    for label in ("train", "val"):
        curves = collect(label)
        epochs = sorted(curves)
        plt.errorbar(
            epochs,
            [np.mean(curves[e]) for e in epochs],
            yerr=[np.std(curves[e]) for e in epochs],
            capsize=5.0,
            label=label,
        )
    plt.xlabel("epoch")
    plt.ylabel("ssim")
    plt.legend()
    plt.show()
[docs]
def train(
    dataset: VideoDataset | pathlib.Path | str | bytes,
    model: CNN | None = None,
    log: pathlib.Path | str | bytes | None = None,
):
    """Train the model.

    The function loops over the epochs without end, it has to be interrupted.
    For each epoch, the model is optimised on the training samples, saved,
    then evaluated on the validation samples. The ssim of every sample is dumped
    in ``ssim_train_epoch_XXX.json`` and ``ssim_val_epoch_XXX.json``,
    and the reference (``brut``), distorded (``dis``) and restored (``res``)
    validation images are written as png, all in the log directory.
    An epoch whose training log file already exists is skipped.

    Parameters
    ----------
    dataset : VideoDataset or pathlike
        The dataset containing the videos and dataaug or the folder.
    model : CNN, optional
        Use the default constructor if not provided.
    log : pathlike
        The directory to store the logs.
        By default, the ``log_train`` folder of the system temporary directory.
        It is created if it does not exist.

    Examples
    --------
    >>> from cutcutcodec.core.nn.model.enhancement.train import train
    >>> # train("~/dataset/video_xiph")
    >>>
    """
    # preparation
    if not isinstance(dataset, VideoDataset):
        dataset = VideoDataset(
            dataset,
            dataaug=ChainDataaug([Transcoder("libx264"), interlace], [1, 0.1]),
            size=100_000_000,  # targeted batch data size in bytes
            shape=(480, 720),
            max_len=1,  # nbr of samples in database
            # n_min=1,  # minimum nbr of frames per video slice
        )
    if model is None:
        model = CNN()
    else:
        assert isinstance(model, CNN), model.__class__.__name__
    if log is None:
        log = pathlib.Path(tempfile.gettempdir()) / "log_train"
    else:
        # os.fsdecode is required because pathlib.Path rejects bytes
        log = pathlib.Path(os.fsdecode(log)).expanduser()
    log.mkdir(parents=True, exist_ok=True)

    # train
    optim = torch.optim.RAdam(model.parameters(), lr=1e-4, weight_decay=1e-7)
    train_loader, val_loader = torch.utils.data.random_split(dataset, [0.8, 0.2])
    val_loader = list(val_loader)  # freeze the validation set
    # NOTE(review): the overfitting test is still enabled, the training and the validation
    # run on the same frozen samples, so the "val" curve is not an independent measure.
    train_loader = val_loader = list(train_loader)  # for the overfitting test

    for epoch in itertools.count():
        train_log = log / f"ssim_train_epoch_{epoch:03d}.json"
        if train_log.exists():  # epoch already logged
            continue

        # train
        model.train()
        ssims = []
        for distorded, reference in tqdm.tqdm(
            threaded_map(train_loader.__getitem__, range(len(train_loader)), maxsize=2),
            desc=f"epoch {epoch}",
            total=len(train_loader),
        ):
            optim.zero_grad()
            restored = model(distorded)
            # the loss is computed directly on the model output,
            # without the colorspace conversion applied in the validation
            ssim = compute_ssim(reference, restored, weights=(6, 1, 1))
            score = float(ssim)
            print(f"ssim: {score:.3f}")
            ssims.append(score)
            (1.0 - ssim).sum().backward()  # maximising the ssim
            optim.step()
        save(model)
        with open(train_log, "w", encoding="utf-8") as file:
            json.dump(ssims, file)

        # validation
        model.eval()
        ssims = []
        with torch.no_grad():  # no gradient is required, neither for the score nor for the images
            for i, (distorded, reference) in enumerate(tqdm.tqdm(val_loader)):
                # compute loss
                restored = model(distorded)
                distorded = _working_to_target(distorded)
                reference = _working_to_target(reference)
                restored = _working_to_target(restored)
                ssims.append(float(compute_ssim(reference, restored, weights=(6, 1, 1))))
                # write images
                for suffix, frames in (("brut", reference), ("dis", distorded), ("res", restored)):
                    _write_target_png(log / f"epoch_{epoch:03d}_{i}_{suffix}.png", frames)
        # zero padded epoch, like the training log file
        with open(log / f"ssim_val_epoch_{epoch:03d}.json", "w", encoding="utf-8") as file:
            json.dump(ssims, file)


def _working_to_target(frames: torch.Tensor) -> torch.Tensor:
    """Convert the first 3 channels from the default working to the default target colorspace.

    It corresponds to the "lin rgb to yuv" conversion.
    The channels are read and stacked along the dimension 2.
    """
    convert = Colorspace.from_default_working().to_function(Colorspace.from_default_target())
    planes = convert(r=frames[:, :, 0], g=frames[:, :, 1], b=frames[:, :, 2])
    return torch.cat([plane[:, :, None] for plane in planes], dim=2)


def _write_target_png(path: pathlib.Path, frames: torch.Tensor) -> None:
    """Write as png the 3 channels starting at ``3*(nbr_channels//6)`` of the yuv frames.

    The channels are converted from the default target to the default target rgb colorspace.
    """
    start = 3 * (frames.shape[2] // 6)  # presumably the central frame of the slice, to confirm
    frame = frames[:, :, start:start+3].detach()
    convert = Colorspace.from_default_target().to_function(Colorspace.from_default_target_rgb())
    planes = convert(y=frame[:, :, 0], u=frame[:, :, 1], v=frame[:, :, 2])
    rgb = torch.cat([plane[:, :, None] for plane in planes], dim=2).numpy(force=True)
    cv2.imwrite(str(path), to_rgb(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)))