cutcutcodec.core.nn.model.enhancement.cnn.CNN

class cutcutcodec.core.nn.model.enhancement.cnn.CNN(dropout: float = 0.05, **kwargs)

Improve RGB image quality while keeping the resolution unchanged.

Initialise the layers.

Parameters

dropout : float, default=0.05

The dropout rate applied after all layers.
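A minimal sketch of what the dropout rate means in practice, assuming (this page does not confirm it) that CNN is a torch.nn.Module built on standard torch.nn.Dropout layers. Dropout is active only in training mode, so calling .eval() before inference makes the forward pass deterministic. The example uses a bare torch.nn.Dropout with the same default rate rather than the CNN itself.

```python
import torch

# Stand-in for the CNN's dropout (assumption: it uses torch.nn.Dropout internally).
drop = torch.nn.Dropout(p=0.05)  # same rate as the CNN default
x = torch.ones(1000)

drop.train()  # training mode: ~5 % of entries are zeroed, the rest scaled by 1 / (1 - p)
y_train = drop(x)

drop.eval()  # inference mode: dropout is the identity
y_eval = drop(x)
```

With the real model, the equivalent would be calling `.eval()` on the CNN instance (and usually wrapping inference in `torch.no_grad()`).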

forward(video: Tensor) → Tensor

Improve the quality of the middle frame of 5 consecutive RGB frames.
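Since forward enhances only the middle frame of each 5-frame group, enhancing a whole clip presumably means sliding a window over it. The sketch below builds one window per frame. `sliding_windows` is a hypothetical helper, and both the frame-major channel order and the replication of border frames are assumptions, not documented behaviour.

```python
import torch


def sliding_windows(clip: torch.Tensor, f: int = 5) -> torch.Tensor:
    """Hypothetical helper: (t, h, w, 3) clip -> (t, h, w, 3*f) windows, one per frame.

    Window i is centred on frame i. Border frames are replicated (assumption).
    """
    if f % 2 == 0:
        raise ValueError("f must be odd so that each window has a middle frame")
    half = f // 2
    t = clip.shape[0]
    # Neighbour indices i-half..i+half, clamped to [0, t-1] to replicate border frames.
    offsets = torch.arange(-half, half + 1)
    idx = (torch.arange(t)[:, None] + offsets[None, :]).clamp(0, t - 1)  # (t, f)
    windows = clip[idx]  # (t, f, h, w, 3)
    # Put the frame axis next to the channel axis, then merge them (frame-major order).
    return windows.permute(0, 2, 3, 1, 4).reshape(t, clip.shape[1], clip.shape[2], 3 * f)
```

Under these assumptions, `CNN()(sliding_windows(clip))` would process every frame of the clip in a single batch, memory permitting.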

Parameters

video : torch.Tensor

The concatenation of 5 or more batched video frames in standard linear RGB format, of shape (n, h, w, 3*f), where f is the number of frames.
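A sketch of how the input tensor could be assembled from separate frames. It assumes the frames arrive sRGB-encoded in [0, 1] and that "standard linear RGB" means sRGB primaries with the transfer function removed (the inverse sRGB curve from IEC 61966-2-1). The helper names `srgb_to_linear` and `pack_frames` and the frame-major channel order are assumptions for illustration.

```python
import torch


def srgb_to_linear(x: torch.Tensor) -> torch.Tensor:
    """Inverse sRGB transfer function (IEC 61966-2-1) for values in [0, 1]."""
    return torch.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)


def pack_frames(frames: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper: f frames of shape (n, h, w, 3) -> one tensor of shape (n, h, w, 3*f)."""
    # Frame-major channel order [r0, g0, b0, r1, g1, b1, ...] (assumed layout).
    return torch.cat([srgb_to_linear(frame) for frame in frames], dim=-1)
```

With this layout, `video.reshape(n, h, w, f, 3)` recovers the individual frames.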

Returns

middle_frame : torch.Tensor

The enhanced third frame of the sequence, of shape (n, h, w, 3*f).
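Assuming the output keeps the same frame-major channel layout as the input (this page does not state it), the RGB channels of the enhanced third frame could be sliced out with a hypothetical helper like this:

```python
import torch


def middle_frame(video: torch.Tensor, f: int = 5) -> torch.Tensor:
    """Hypothetical helper: (n, h, w, 3*f) -> (n, h, w, 3) RGB channels of the middle frame.

    Assumes the frame-major channel layout [r0, g0, b0, r1, g1, b1, ...].
    """
    if video.shape[-1] != 3 * f:
        raise ValueError(f"expected {3 * f} channels, got {video.shape[-1]}")
    mid = f // 2  # index 2, i.e. the third frame, when f == 5
    return video[..., 3 * mid : 3 * mid + 3]
```

Under that assumption, `middle_frame(CNN()(video))` would return a tensor of shape (n, h, w, 3).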

Examples

>>> import torch
>>> from cutcutcodec.core.nn.model.enhancement.cnn import CNN
>>> CNN()(torch.rand((2, 128, 256, 15))).shape
torch.Size([2, 128, 256, 15])