# Source code for cutcutcodec.core.nn.dataset.image

"""An image dataset."""

import numbers
import pathlib
import typing

import torch

from cutcutcodec.core.analysis.stream import optimal_shape_video
from cutcutcodec.core.classes.frame_video import FrameVideo
from cutcutcodec.core.io import read
from cutcutcodec.core.io.cst import IMAGE_SUFFIXES

from .base import Dataset


class ImageDataset(Dataset):
    """A specific dataset for managing images.

    Each item is one image file read from disk as a snapshot at a fixed pixel
    shape, optionally passed through a user-supplied augmentation callable.
    """

    def __init__(
        self,
        root: pathlib.Path | str | bytes,
        shape: tuple[numbers.Integral, numbers.Integral] | list[numbers.Integral],
        *,
        dataaug: typing.Callable[[FrameVideo], FrameVideo] | None = None,
        **kwargs,
    ):
        """Initialise and create the class.

        Parameters
        ----------
        root : pathlike
            Transmitted to the ``Dataset`` initialiser.
        shape : int and int
            The pixel dimensions of the returned image, following the numpy
            convention (height, width). It is given to the video stream's
            ``snapshot`` method each time an image is read.
            NOTE(review): a previous version of this docstring promised a random
            reshape and random crop; nothing in this class does that, so any
            such augmentation presumably belongs in ``dataaug`` — confirm.
        dataaug : callable, optional
            If provided, the function is called on each raw image right after
            it has been read, and its result is returned instead.
        **kwargs : dict
            Transmitted to the ``Dataset`` initialiser. An optional ``selector``
            callable (taking a ``pathlib.Path``) is combined with the
            image-suffix filter: a file is kept only if its suffix is in
            ``IMAGE_SUFFIXES`` and the user selector accepts it.
            ``selector=None`` means no extra filtering.
        """
        assert isinstance(shape, (tuple, list)), shape.__class__.__name__
        assert len(shape) == 2, len(shape)
        assert all(isinstance(s, numbers.Integral) and s >= 1 for s in shape), shape
        assert dataaug is None or callable(dataaug), dataaug.__class__.__name__

        # Remove the user selector from ``kwargs``: leaving it there while also
        # passing ``selector=_selector`` below would make the parent call fail
        # with "got multiple values for keyword argument 'selector'".
        user_selector = kwargs.pop("selector", None)
        assert user_selector is None or callable(user_selector), (
            user_selector.__class__.__name__
        )

        def _selector(file: pathlib.Path) -> bool:
            """Keep files with an image suffix that the user selector accepts."""
            if file.suffix.lower() not in IMAGE_SUFFIXES:
                return False
            return user_selector is None or user_selector(file)

        super().__init__(root, selector=_selector, **kwargs)
        self.shape = (int(shape[0]), int(shape[1]))
        self.dataaug = dataaug

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Read the image of index ``idx``.

        Parameters
        ----------
        idx : int
            Transmitted to ``Dataset.__getitem__``, whose result is the file
            opened with ``read``.

        Returns
        -------
        image : torch.Tensor
            The snapshot at position ``0`` of the first video stream of the
            file, at ``self.shape``, then passed through ``dataaug`` if set.
        """
        file = super().__getitem__(idx)
        with read(file) as container:
            stream = container.out_select("video")[0]
            # ``__init__`` always sets ``self.shape`` to a valid (height, width),
            # so the fallback only applies if the attribute is later cleared.
            img = stream.snapshot(0, self.shape or optimal_shape_video(stream))
        if self.dataaug is not None:
            img = self.dataaug(img)
        return img