Source code for cutcutcodec.core.nn.dataaug.chain
"""Merge several data augmentation together."""
import numbers
import random
import typing
import torch
[docs]
class ChainDataaug:
    """Random application of a chain of data augmentations.

    The data augmentations are tried in the given order. Each one is applied,
    independently of the others, with its own probability; the output of an
    applied dataaug is the input of the next one.

    Attributes
    ----------
    proba : list[float]
        The probability list (readonly, a copy is returned).
    """

    def __init__(
        self,
        dataaugs: typing.Iterable[typing.Callable[[torch.Tensor], torch.Tensor]],
        probas: typing.Optional[typing.Iterable[numbers.Real]] = None,
    ):
        """Initialise the selector.

        Parameters
        ----------
        dataaugs : list[callable]
            The dataaugs chain, tried in this order.
        probas : list[float], optional
            The probabilities, each in [0, 1], for the dataaugs to be applied.
            Must have the same length as ``dataaugs``.
            If None (default), every dataaug is always applied.
        """
        assert hasattr(dataaugs, "__iter__"), dataaugs.__class__.__name__
        dataaugs = list(dataaugs)
        assert all(callable(d) for d in dataaugs), dataaugs
        if probas is None:
            probas = [1.0 for _ in range(len(dataaugs))]
        else:
            assert hasattr(probas, "__iter__"), probas.__class__.__name__
            probas = list(probas)
            assert all(isinstance(p, numbers.Real) for p in probas), probas
            assert all(0.0 <= p <= 1.0 for p in probas), probas
        assert len(dataaugs) == len(probas), (dataaugs, probas)
        self.dataaugs = dataaugs
        self._probas = probas

    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        """Apply each dataaug of the chain with its own probability.

        Parameters
        ----------
        data : torch.Tensor
            The input given to the first applied dataaug.

        Returns
        -------
        torch.Tensor
            The result of the chain (``data`` itself if no dataaug was drawn).
        """
        for dataaug, proba in zip(self.dataaugs, self._probas):
            # random.random() lies in [0, 1): the strict comparison guarantees
            # that proba=0 never applies and proba=1 always applies.
            if random.random() < proba:
                data = dataaug(data)
        return data

    @property
    def proba(self) -> list[float]:
        """Return a copy of the probability list, so it cannot be mutated."""
        return list(self._probas)