Source code for cutcutcodec.core.nn.loader

#!/usr/bin/env python3

"""Implement some data-loader."""

import logging
import numbers
import pathlib
import typing

import torch

from cutcutcodec.core.classes.frame_video import FrameVideo
from cutcutcodec.core.filter.mix.video_cast import to_bgr
from cutcutcodec.core.filter.video.resize import resize
from cutcutcodec.core.io import IMAGE_SUFFIXES
from cutcutcodec.core.io.read_image import read_image


[docs] class Dataset(torch.utils.data.Dataset): """Select files managing the probability. Examples -------- >>> from cutcutcodec.core.nn.loader import Dataset >>> from cutcutcodec.utils import get_project_root >>> def selector(path): ... return path.suffix == ".py" ... >>> dataset = Dataset(get_project_root(), selector, max_size=128) >>> len(dataset) 128 >>> dataset[0].relative_to(get_project_root()) PosixPath('__init__.py') >>> dataset[1].relative_to(get_project_root()) PosixPath('__main__.py') >>> dataset[2].relative_to(get_project_root()) PosixPath('utils.py') >>> dataset[3].relative_to(get_project_root()) PosixPath('config/__init__.py') >>> dataset[4].relative_to(get_project_root()) PosixPath('core/__init__.py') >>> dataset[5].relative_to(get_project_root()) PosixPath('testing/__init__.py') >>> dataset[6].relative_to(get_project_root()) PosixPath('config/config.py') >>> """ def __init__( self, root: typing.Union[str, bytes, pathlib.Path], selector: typing.Callable[[pathlib.Path], bool], **kwargs, ): """Initialise and create the class. Parameters ---------- root : pathlike The root folder containing all the files of the dataset. selector : callable Function that take a file pathlib.Path and return True to keep it or False to reject. follow_symlinks : bool, default=False Follow the symbolink links if set to True. max_size : int, optional The maximum number of files contained in the dataset. decision_depth : int, default=1 The thresold level befor to flatten the tree. If 0, all the file have the same proba to be drawn. If 1, the decision tree has only one root node If n, the decision tree has a maximum of n decks. """ root = pathlib.Path(root).expanduser().resolve() assert root.is_dir(), root assert callable(selector), selector.__class__.__name__ assert isinstance(kwargs.get("follow_symlinks", False), bool), \ kwargs["follow_symlinks"].__class__.__name__ if kwargs.get("max_size", None) is not None: assert isinstance(kwargs["max_size"], numbers.Integral), \ kwargs["max_size"].__class__.__name__ assert kwargs["max_size"] > 0, kwargs["max_size"] assert isinstance(kwargs.get("decision_depth", 1), numbers.Integral), \ kwargs["decision_depth"].__class__.__name__ assert kwargs.get("decision_depth", 1) >= 0, kwargs["decision_depth"] self._root = root self._selector = selector self._follow_symlinks = kwargs.get("follow_symlinks", False) self._max_size = None if kwargs.get("max_size", None) is None else int(kwargs["max_size"]) self._decision_depth = int(kwargs.get("decision_depth", 1)) self._tree: list[typing.Union[pathlib.Path, list]] = self.scan() def __getitem__(self, idx: int, *, _tree=None) -> pathlib.Path: """Pick out a file from the dataset. Parameters ---------- idx : int The index of the file, has to be in [0, len(self)[. Returns ------- file : pathlib.Path The absolute path of the file. Notes ----- This method should be overwritten. """ assert isinstance(idx, int), idx.__class__.__name__ tree = _tree or self._tree files = [f for f in tree if isinstance(f, pathlib.Path)] dirs_len = len(tree) - len(files) # assume sorted files then dirs if not dirs_len: file = files[idx % len(files)] logging.info("the file %s if yield twice", file) return file if idx < len(files): return files[idx] idx, dir_idx = divmod(idx-len(files), dirs_len) return Dataset.__getitem__(self, idx, _tree=tree[dir_idx+len(files)]) def __len__(self, *, _tree=None) -> int: """Return the number of images contained in the dataset.""" tree = _tree or self._tree size = sum(1 if isinstance(e, pathlib.Path) else self.__len__(_tree=e) for e in tree) if self._max_size: size = min(self._max_size, size) return size
[docs] def scan(self, *, _root=None, _depth=0) -> list[typing.Union[pathlib.Path, list]]: """Rescan the dataset to update the properties.""" if _root is None: self._tree = [] tree = self._tree # reference root = _root or self._root else: tree = [] root = _root # scan items = sorted(root.iterdir()) tree.extend(f for f in items if f.is_file() and self._selector(f)) dirs = [ self.scan(_root=d, _depth=_depth+1) for d in items if d.is_dir() or (self._follow_symlinks and d.is_symlink()) ] # filter and flatten if _depth >= self._decision_depth: tree.extend(f for d in dirs if d for f in d) else: tree.extend(d for d in dirs if d) return tree
[docs] class ImageDataset(Dataset): """A specific dataset for managing images.""" def __init__( self, root: typing.Union[str, bytes, pathlib.Path], shape: typing.Union[tuple[numbers.Integral, numbers.Integral], list[numbers.Integral]], *, dataaug: typing.Optional[typing.Callable[[FrameVideo], FrameVideo]] = None, **kwargs, ): """Initialise and create the class. Parameters ---------- root : pathlike Transmitted to ``Dataset`` initialisator. shape : int and int The pixel dimensions of the returned image. The image will be random reshaped and random cropped to reach this final shape. The convention adopted is the numpy convention (height, width). dataaug : callable, optional If provided, the function is called for each brut readed image before normalization. **kwargs : dict Transmitted to ``Datset`` initialisator. """ assert isinstance(shape, (tuple, list)), shape.__class__.__name__ assert len(shape) == 2, len(shape) assert all(isinstance(s, numbers.Integral) and s >= 1 for s in shape), shape assert dataaug is None or callable(dataaug), dataaug.__class__.__name__ def _selector(file: pathlib.Path) -> bool: if file.suffix.lower() not in IMAGE_SUFFIXES: return False return kwargs.get("selector", lambda p: True)(file) super().__init__(root, **kwargs, selector=_selector) self.shape = (int(shape[0]), int(shape[1])) self.dataaug = dataaug def __getitem__(self, idx: int) -> torch.Tensor: """Read the image of index ``idx``. Parameters ---------- idx : int Transmitted to ``Datset.__getitem__``. Returns ------- image : torch.Tensor The readed augmented and converted throw the method ``ImageDataset.normalize``. """ file = super().__getitem__(idx) img = FrameVideo(0, read_image(file)) if self.dataaug is not None: img = self.dataaug(img) img = self.normalize(img) return img
[docs] def normalize(self, image: FrameVideo) -> torch.Tensor: """Pipeline to normalize any image for batching. The normalization consists in: * Resize with deformation to reach the final size. * Convertion into torch float32 with good dynamic. * Convert into torch tensor. * Convert into BGR. * Depth channel first: (H, W, C) -> (C, H, W) Parameters ---------- image : cutcutcodec.core.classes.frame_video.FrameVideo The input brut image of any shape, channels and dtype. Returns ------- normlized_image : torch.Tensor The normalized float32 image of shape (3, final_height, final_width). """ assert isinstance(image, FrameVideo), image.__class__.__name__ img = torch.as_tensor(image) resize_after = self.shape[0]*self.shape[1] > img.shape[0]*img.shape[1] bgr_after = image.shape[2] > 3 # shape normalization if not resize_after: img = resize(img, self.shape, copy=False) if not bgr_after: img = to_bgr(img) if resize_after: img = resize(img, self.shape, copy=False) if bgr_after: img = to_bgr(img) # convert into float32 if img.dtype == torch.uint8: img = img.to(torch.float32) img /= 255.0 elif img.dtype != torch.float32: img = img.to(torch.float32) # convert into bgr (C, H, W) img = torch.moveaxis(img, 2, 0) return img