Source code for cutcutcodec.core.analysis.video.quality.utils

"""Helper for metrics."""

import functools
import typing

import numpy as np
import torch


[docs] def batched_comparative_frames(func: callable) -> callable: """Decorate to vectorize the metrics. The signature of the metric has to be: ``metric(dis: torch.Tensor, ref: torch.Tensor, *args, **kwargs) -> torch.Tensor`` With ref.shape == dis.shape == (batch, height, width, channels). The returned type is based on ``dis`` parameter. """ @functools.wraps(func) def batched_metric( dis: torch.Tensor | np.ndarray | typing.Iterable[torch.Tensor | np.ndarray], ref: torch.Tensor | np.ndarray | typing.Iterable[torch.Tensor | np.ndarray], *args, **kwargs, ) -> torch.Tensor | np.ndarray | typing.Iterable: # cast to torch tensor and back to homogeneous returned type ref = torch.asarray(ref) match dis: case np.ndarray(): return batched_metric(ref, torch.from_numpy(dis), *args, **kwargs).numpy(force=True) case list(): return batched_metric(ref, torch.asarray(dis), *args, **kwargs).tolist() case tuple() | set() | frozenset(): return dis.__class__(batched_metric(ref, list(dis), *args, **kwargs)) case _: assert isinstance(dis, torch.Tensor), dis.__class__.__name__ # set shape while ref.ndim < 3: ref = ref.unsqueeze(-1) while dis.ndim < 3: dis = dis.unsqueeze(-1) ref, dis = torch.broadcast_tensors(ref, dis) *batch, height, width, channels = dis.shape ref = ref.reshape(-1, height, width, channels) dis = dis.reshape(-1, height, width, channels) # apply func res = func(ref, dis, *args, **kwargs) # back to unfolded shape res = res.reshape(batch) return res return batched_metric
[docs] def batched_single_frames(func: callable) -> callable: """Decorate to vectorize the metrics. The signature of the metric has to be: ``metric(dis: torch.Tensor, *args, **kwargs) -> torch.Tensor`` With dis.shape == (batch, 5, height, width, 3). The returned type is based on ``dis`` parameter. """ @functools.wraps(func) def batched_metric( dis: torch.Tensor | np.ndarray | typing.Iterable[torch.Tensor | np.ndarray], *args, **kwargs, ) -> torch.Tensor | np.ndarray | typing.Iterable: # cast to torch tensor and back to homogeneous returned type match dis: case np.ndarray(): return batched_metric(torch.from_numpy(dis), *args, **kwargs).numpy(force=True) case list(): return batched_metric(torch.asarray(dis), *args, **kwargs).tolist() case tuple() | set() | frozenset(): return dis.__class__(batched_metric(list(dis), *args, **kwargs)) case _: assert isinstance(dis, torch.Tensor), dis.__class__.__name__ # set shape assert dis.ndim >= 4, dis.shape *batch, fps, height, width, channels = dis.shape assert fps == 5, dis.shape assert channels == 3, dis.shape *batch, fps, height, width, channels = dis.shape dis = dis.reshape(-1, fps, height, width, channels) # apply func res = func(dis, *args, **kwargs) # back to unfolded shape res = res.reshape(batch) return res return batched_metric