# Source code for cutcutcodec.core.nn.dataset.video

"""A video dataset."""

import math
import numbers
import pathlib
import random
import typing

import torch

from cutcutcodec.core.analysis.stream import optimal_rate_video, optimal_shape_video
from cutcutcodec.core.io import read
from cutcutcodec.core.io.cst import VIDEO_SUFFIXES

from .base import Dataset


class VideoDataset(Dataset):
    """A specific dataset to manage sub videos.

    Each item is a random temporal window, randomly cropped in space, read from
    one of the video files kept by the parent ``Dataset``.
    """

    def __init__(
        self,
        root: pathlib.Path | str | bytes,
        shape: tuple[numbers.Integral, numbers.Integral] | None = None,
        *,
        dataaug: typing.Callable[[torch.Tensor], torch.Tensor] | None = None,
        **kwargs,
    ):
        """Initialise and create the class.

        Parameters
        ----------
        root : pathlike
            Transmitted to the ``Dataset`` initialiser.
        shape : tuple[int, int], optional
            If given, the video is truncated and reshaped to the shape (height, width).
            Both dimensions must be strictly positive.
            If not provided, the returned shape is the shape of the original video.
        dataaug : callable, optional
            If provided, the function is called on each raw batch of read frames (h, w, 3*n).
        n_min : int, default = 1
            The minimum number of frames.
        n_max : int, default = 1024
            The maximum number of frames.
        size : int, default = 104857600
            The targeted video size in bytes. By default 100 MiB.
        selector : callable, optional
            An additional file filter, applied only to the files
            whose suffix is a known video suffix.
        **kwargs : dict
            Transmitted to the ``Dataset`` initialiser.
        """
        assert dataaug is None or callable(dataaug), dataaug.__class__.__name__
        if shape is not None:
            shape = tuple(shape)
            assert len(shape) == 2, shape
            assert all(isinstance(dim, numbers.Integral) for dim in shape), shape
            # Check each dimension on its own: a tuple comparison such as
            # ``shape > (0, 0)`` is lexicographic and would accept (1, -5).
            assert shape[0] > 0 and shape[1] > 0, shape

        # Remove the user selector from kwargs so that it is not forwarded
        # a second time next to our own ``selector`` keyword (TypeError).
        user_selector = kwargs.pop("selector", None)

        def _selector(file: pathlib.Path) -> bool:
            if file.suffix.lower() not in VIDEO_SUFFIXES:
                return False
            return True if user_selector is None else user_selector(file)

        self.shape = shape
        self.dataaug = dataaug

        self.n_min, self.n_max = kwargs.pop("n_min", 1), kwargs.pop("n_max", 1024)
        assert isinstance(self.n_min, numbers.Integral), self.n_min.__class__.__name__
        assert isinstance(self.n_max, numbers.Integral), self.n_max.__class__.__name__
        assert 0 < self.n_min <= self.n_max, (self.n_min, self.n_max)

        self.size = kwargs.pop("size", 104857600)
        assert isinstance(self.size, numbers.Integral), self.size.__class__.__name__
        assert self.size >= 0, self.size

        # Per-file cache of (rate, duration, output shape, decoding shape).
        self.metadata: dict[pathlib.Path, tuple] = {}

        super().__init__(root, **kwargs, selector=_selector)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Read a random sequence in the video.

        Parameters
        ----------
        idx : int
            The index of the file, transmitted to ``Dataset.__getitem__``.

        Returns
        -------
        dis : torch.Tensor
            The result of ``dataaug`` applied to ``ref``, or ``ref`` itself
            if no data augmentation function was provided.
        ref : torch.Tensor
            The frames concatenated along the last dimension.

        Examples
        --------
        >>> from cutcutcodec.core.nn.dataset.video import VideoDataset
        >>> from cutcutcodec.utils import get_project_root
        >>> dataset = VideoDataset(get_project_root() / "media")
        >>> len(dataset)
        2
        >>> dis, ref = dataset[0]
        >>> dis.shape, ref.shape
        (torch.Size([64, 64, 3]), torch.Size([64, 64, 3]))
        >>>
        """
        file = super().__getitem__(idx)
        with read(file) as container:
            stream = container.out_select("video")[0]
            if file not in self.metadata:
                native_shape = optimal_shape_video(stream)
                shape = self.shape or native_shape
                self.metadata[file] = (
                    optimal_rate_video(stream),
                    stream.duration,
                    shape,
                    # Per-axis maximum (not the lexicographic tuple ``max``), so the
                    # decoded frame is never smaller than the crop on any axis.
                    (max(shape[0], native_shape[0]), max(shape[1], native_shape[1])),
                )
            rate, duration, shape, shape_m = self.metadata[file]

            # 12 bytes per pixel: presumably 3 channels of 4 bytes each -- TODO confirm.
            n_frames = round(self.size / (12 * shape[0] * shape[1]))
            n_frames = min(math.floor(duration * rate), self.n_max, max(self.n_min, n_frames))
            # Always read at least one frame, otherwise the padding below has nothing to repeat.
            n_frames = max(1, n_frames)
            # Clamp so that a clip shorter than the window starts at t = 0.
            t_start = random.random() * max(0, duration - n_frames / rate)
            top = random.randint(0, shape_m[0] - shape[0])
            left = random.randint(0, shape_m[1] - shape[1])
            ref = [
                stream.snapshot(t_start + i / rate, shape_m).convert(3)[
                    top:top + shape[0], left:left + shape[1], :
                ]
                for i in range(n_frames)
            ]

        # Too short video: repeat the last frame until the minimum length is reached.
        ref.extend([ref[-1]] * (self.n_min - len(ref)))
        ref = torch.cat(ref, dim=2)
        dis = ref if self.dataaug is None else self.dataaug(ref)
        return dis, ref