# Source code for cutcutcodec.core.analysis.video.complexity.utils

"""Helper for metrics."""

import functools
import typing

import numpy as np
import torch


[docs] def batched_frames(func: callable) -> callable: """Decorate to vectorize the metrics. The signature of the metric has to be: ``metric(img: torch.Tensor, *args, **kwargs) -> torch.Tensor`` With img.shape == (batch, fps, height, width, 3). The returned type is based on ``img`` parameter. """ @functools.wraps(func) def batched_complexity( img: torch.Tensor | np.ndarray | typing.Iterable[torch.Tensor | np.ndarray], *args, **kwargs, ) -> torch.Tensor | np.ndarray | typing.Iterable: # cast to torch tensor and back to homogeneous returned type match img: case np.ndarray(): return batched_complexity(torch.from_numpy(img), *args, **kwargs).numpy(force=True) case list(): return batched_complexity(torch.asarray(img), *args, **kwargs).tolist() case tuple() | set() | frozenset(): return img.__class__(batched_complexity(list(img), *args, **kwargs)) case _: assert isinstance(img, torch.Tensor), img.__class__.__name__ # set shape assert img.ndim >= 2, \ f"the image requires at least 2 dimensions (height, width) vs {img.shape}" if img.ndim == 2: img = img[:, :, None].repeat(1, 1, 3) # assume y single chanel if img.ndim == 3: img = img[None, :, :, :] # only one frame is spatial complexity only, (no temporal) *batch, fps, height, width, channels = img.shape img = img.reshape(-1, fps, height, width, channels) if channels == 1: img = img.expand(-1, -1, -1, -1, 3) else: assert channels == 3, f"the image requires 1 or 3 channels, {img.shape}" # apply func res = func(img, *args, **kwargs) # back to unfolded shape res = res.reshape((*batch, *res.shape[1:])) return res return batched_complexity