Source code for cutcutcodec.core.analysis.video.complexity.utils
"""Helper for metrics."""
import functools
import typing
import numpy as np
import torch
[docs]
def batched_frames(func: callable) -> callable:
"""Decorate to vectorize the metrics.
The signature of the metric has to be:
``metric(img: torch.Tensor, *args, **kwargs) -> torch.Tensor``
With img.shape == (batch, fps, height, width, 3).
The returned type is based on ``img`` parameter.
"""
@functools.wraps(func)
def batched_complexity(
img: torch.Tensor | np.ndarray | typing.Iterable[torch.Tensor | np.ndarray],
*args,
**kwargs,
) -> torch.Tensor | np.ndarray | typing.Iterable:
# cast to torch tensor and back to homogeneous returned type
match img:
case np.ndarray():
return batched_complexity(torch.from_numpy(img), *args, **kwargs).numpy(force=True)
case list():
return batched_complexity(torch.asarray(img), *args, **kwargs).tolist()
case tuple() | set() | frozenset():
return img.__class__(batched_complexity(list(img), *args, **kwargs))
case _:
assert isinstance(img, torch.Tensor), img.__class__.__name__
# set shape
assert img.ndim >= 2, \
f"the image requires at least 2 dimensions (height, width) vs {img.shape}"
if img.ndim == 2:
img = img[:, :, None].repeat(1, 1, 3) # assume y single chanel
if img.ndim == 3:
img = img[None, :, :, :] # only one frame is spatial complexity only, (no temporal)
*batch, fps, height, width, channels = img.shape
img = img.reshape(-1, fps, height, width, channels)
if channels == 1:
img = img.expand(-1, -1, -1, -1, 3)
else:
assert channels == 3, f"the image requires 1 or 3 channels, {img.shape}"
# apply func
res = func(img, *args, **kwargs)
# back to unfolded shape
res = res.reshape((*batch, *res.shape[1:]))
return res
return batched_complexity