Source code for cutcutcodec.core.classes.frame
"""Defines the structure of a base frame, inerit from torch array."""
import abc
import logging
import typing
import numpy as np
import torch
[docs]
class Frame(torch.Tensor):
"""A General Frame.
Attributes
----------
context : object
Any information to throw during the transformations.
"""
def __new__(
cls,
data: torch.Tensor | np.ndarray | typing.Container,
context: object = None,
**kwargs,
):
"""Initialise and create the class.
Parameters
----------
context : object
Any value to throw between the tensor operations.
data : arraylike
The data to use for this array. Do not copy if it is possible
**kwargs : dict
Transmitted to the `torch.Tensor` initialisator.
"""
if isinstance(data, torch.Tensor):
frame = super().__new__(cls, data, **kwargs) # no copy
frame.context = context
return frame
if isinstance(data, np.ndarray):
return Frame.__new__(cls, torch.from_numpy(data), context=context, **kwargs) # no copy
logging.warning("please only intitialize a frame from a torch tensor or a numpy ndarray")
return Frame.__new__(cls, torch.tensor(data), context=context, **kwargs) # copy
def __repr__(self):
"""Allow to add context to the display.
Examples
--------
>>> from cutcutcodec.core.classes.frame import Frame
>>> Frame([0.0, 1.0, 2.0], context="context_value")
Frame([0., 1., 2.], context='context_value')
>>>
"""
base = super().__repr__()
return f"{base[:-1]}, context={self.context!r})"
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
"""Enable to throw `context` into the new generations.
Examples
--------
>>> import torch
>>> from cutcutcodec.core.classes.frame import Frame
>>> class Frame_(Frame):
... def check_state(self):
... assert self.item() # just for the example
...
>>>
>>> # transmission context
>>> (frame := Frame_([.5], context="context_value"))
Frame_([0.5000], context='context_value')
>>> frame.clone() # deep copy
Frame_([0.5000], context='context_value')
>>> torch.sin(frame) # external call
Frame_([0.4794], context='context_value')
>>> frame / 2 # internal method
Frame_([0.2500], context='context_value')
>>> frame.numpy() # cast in an other type
array([0.5], dtype=float32)
>>> frame *= 2 # inplace
>>> frame
Frame_([1.], context='context_value')
>>>
>>> # cast if state not correct
>>> torch.concatenate([frame, frame], axis=0) #
tensor([1., 1.])
>>> frame * 0 # no correct because has to be != 0
tensor([0.])
>>> frame *= 0
>>> frame
tensor([0.])
>>>
"""
if kwargs is None:
kwargs = {}
result = super().__torch_function__(func, types, args, kwargs)
if isinstance(result, cls):
if isinstance(args[0], cls): # args[0] is self
result.context = args[0].context # args[0] is self
try:
result.check_state()
except AssertionError:
return torch.Tensor(result)
else:
return torch.Tensor(result)
return result
[docs]
@abc.abstractmethod
def check_state(self) -> None:
"""Apply verifications.
Raises
------
AssertionError
If something wrong in this frame.
"""
raise NotImplementedError
@property
def shape(self) -> tuple[int, ...]:
"""Solve pylint error E1136: Value 'self.shape' is unsubscriptable."""
return tuple(super().shape)