Source code for cutcutcodec.core.nn.dataset.base

"""Basic generic dataloader."""

import logging
import numbers
import os
import pathlib
import typing

import torch


[docs] class Dataset(torch.utils.data.Dataset): """Select files managing the probability. Examples -------- >>> from cutcutcodec.core.nn.dataset.base import Dataset >>> from cutcutcodec.utils import get_project_root >>> def selector(path) -> bool: ... return path.suffix == ".py" ... >>> dataset = Dataset(get_project_root(), selector, max_len=128) >>> len(dataset) 128 >>> dataset[0].relative_to(get_project_root()) PosixPath('__init__.py') >>> dataset[1].relative_to(get_project_root()) PosixPath('__main__.py') >>> dataset[2].relative_to(get_project_root()) PosixPath('doc.py') >>> dataset[3].relative_to(get_project_root()) PosixPath('utils.py') >>> dataset[4].relative_to(get_project_root()) PosixPath('cli/__init__.py') >>> dataset[-1].relative_to(get_project_root()) PosixPath('utils.py') >>> """ def __init__( self, root: pathlib.Path | str | bytes, selector: typing.Callable[[pathlib.Path], bool], **kwargs, ): """Initialise and create the class. Parameters ---------- root : pathlike The root folder containing all the files of the dataset. selector : callable Function that take a file pathlib.Path and return True to keep it or False to reject. follow_symlinks : bool, default=False Follow the symbolink links if set to True. max_len : int, optional The maximum number of files contained in the dataset. decision_depth : int, default=1 The threshold level before to flatten the tree. If 0, all the file have the same proba to be drawn. If 1, the decision tree has only one root node If n, the decision tree has a maximum of n decks. 
""" root = pathlib.Path(root).expanduser().resolve() assert root.is_dir(), root assert callable(selector), selector.__class__.__name__ assert isinstance(kwargs.get("follow_symlinks", False), bool), \ kwargs["follow_symlinks"].__class__.__name__ if kwargs.get("max_len") is not None: assert isinstance(kwargs["max_len"], numbers.Integral), \ kwargs["max_len"].__class__.__name__ assert kwargs["max_len"] > 0, kwargs["max_len"] assert isinstance(kwargs.get("decision_depth", 1), numbers.Integral), \ kwargs["decision_depth"].__class__.__name__ assert kwargs.get("decision_depth", 1) >= 0, kwargs["decision_depth"] self._root = root self._selector = selector self._follow_symlinks = kwargs.get("follow_symlinks", False) self._max_len = None if kwargs.get("max_len") is None else int(kwargs["max_len"]) self._decision_depth = int(kwargs.get("decision_depth", 1)) self._tree: list[pathlib.Path | list] = self.scan() def __getitem__(self, idx: int, *, _tree=None) -> pathlib.Path: """Pick out a file from the dataset. Parameters ---------- idx : int The index of the file, has to be in [0, len(self)[. Returns ------- file : pathlib.Path The absolute path of the file. Notes ----- This method should be overwritten. """ assert isinstance(idx, int), idx.__class__.__name__ tree = _tree or self._tree files = [f for f in tree if isinstance(f, pathlib.Path)] dirs_len = len(tree) - len(files) # assume sorted files then dirs if not dirs_len: file = files[idx % len(files)] logging.info("the file %s if yield twice", file) return file if idx < len(files): return files[idx] idx, dir_idx = divmod(idx-len(files), dirs_len) return Dataset.__getitem__(self, idx, _tree=tree[dir_idx+len(files)]) def __len__(self, *, _tree=None) -> int: """Return the number of images contained in the dataset.""" tree = _tree or self._tree size = sum(1 if isinstance(e, pathlib.Path) else self.__len__(_tree=e) for e in tree) if self._max_len: size = min(self._max_len, size) return size
[docs] def scan(self, *, _root=None, _depth=0) -> list[pathlib.Path | list]: """Rescan the dataset to update the properties.""" if _root is None: self._tree = [] tree = self._tree # reference root = _root or self._root else: tree = [] root = _root # scan items = sorted(root.iterdir()) tree.extend(f for f in items if f.is_file() and self._selector(f)) dirs = [ self.scan(_root=d, _depth=_depth+1) for d in items if d.is_dir() or (self._follow_symlinks and d.is_symlink()) ] # filter and flatten if _depth >= self._decision_depth: tree.extend(f for d in dirs if d for f in d) else: tree.extend(d for d in dirs if d) return tree