# Source code for cutcutcodec.core.analysis.video.metrics

"""Gathers all video metrics."""

import itertools
import logging
import os
import pathlib
import threading
from fractions import Fraction

import torch
import tqdm

from cutcutcodec.config.config import Config
from cutcutcodec.core.analysis.stream.rate_video import optimal_rate_video
from cutcutcodec.core.analysis.stream.shape import optimal_shape_video
from cutcutcodec.core.analysis.video.properties import get_duration_video
from cutcutcodec.core.classes.colorspace import Colorspace
from cutcutcodec.core.classes.frame_video import FrameVideo
from cutcutcodec.core.classes.stream_video import StreamVideo
from cutcutcodec.core.exceptions import OutOfTimeRange
from cutcutcodec.core.filter.video.colorspace import FilterVideoColorspace
from cutcutcodec.core.io.read_ffmpeg import ContainerInputFFMPEG
from cutcutcodec.core.opti.parallel.buffer import _FuncEvalThread, starmap
from cutcutcodec.utils import mround

# only ``video_metrics`` is exported by ``from ... import *``
__all__ = ["video_metrics"]


# module-wide lock held by the _lpips_alex, _lpips_vgg and _uvq wrappers
# while they call their underlying metric function
TORCH_LOCK = threading.Lock()


def _fill_missing_shape_and_rate(
    needs: dict[str, list],
    dis: dict[str, StreamVideo],
    ref: dict[str, StreamVideo],
):
    """Help for _yield_frames, replace the undefined shapes and rates of ``needs`` inplace.

    Parameters
    ----------
    needs : dict[str, list]
        For each metric name, the list ``[func, comparative, space, shape, nbr, rate]``.
        The items ``shape`` (index 3) and ``rate`` (index 5) equal to None are overwritten.
    dis : dict[str, StreamVideo]
        The distorted streams, indexed by colour space name.
    ref : dict[str, StreamVideo]
        The reference streams, indexed by colour space name.
        It is only indexed for the comparative metrics.
    """
    for need in needs.values():
        _, comparative, space, shape, _, rate = need
        if shape is not None and rate is not None:
            continue  # nothing to guess for this metric
        # comparative metrics are aligned on the reference, the others on the distorted video
        stream = (ref if comparative else dis)[space]
        if shape is None:
            need[3] = optimal_shape_video(stream) or (720, 1080)
        if rate is None:
            # NOTE(review): 3000/1001 is about 3 fps, it looks like a typo for
            # 30000/1001 (29.97 fps) -- confirm before changing, _yield_frames uses the same value
            need[5] = optimal_rate_video(stream) or Fraction(3000, 1001)


def _lpips_alex(dis: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Call ``lpips`` with the "alex" network while holding ``TORCH_LOCK``.

    Parameters
    ----------
    dis : torch.Tensor
        The distorted batch, forwarded as is to ``lpips``.
    ref : torch.Tensor
        The reference batch, forwarded as is to ``lpips``.

    Returns
    -------
    torch.Tensor
        The value returned by ``lpips``.
    """
    from .quality import lpips as lpips_metric

    # the lock prevents two lock-guarded metrics from running at the same time
    with TORCH_LOCK:
        scores = lpips_metric(dis, ref, net="alex", threads=os.cpu_count())
    return scores


def _lpips_vgg(dis: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Call ``lpips`` with the "vgg" network while holding ``TORCH_LOCK``.

    Parameters
    ----------
    dis : torch.Tensor
        The distorted batch, forwarded as is to ``lpips``.
    ref : torch.Tensor
        The reference batch, forwarded as is to ``lpips``.

    Returns
    -------
    torch.Tensor
        The value returned by ``lpips``.
    """
    from .quality import lpips as lpips_metric

    # the lock prevents two lock-guarded metrics from running at the same time
    with TORCH_LOCK:
        scores = lpips_metric(dis, ref, net="vgg", threads=os.cpu_count())
    return scores


def _psnr(dis: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Call ``psnr`` on one thread with the fixed (6, 1, 1) weights.

    Parameters
    ----------
    dis : torch.Tensor
        The distorted batch, forwarded as is to ``psnr``.
    ref : torch.Tensor
        The reference batch, forwarded as is to ``psnr``.

    Returns
    -------
    torch.Tensor
        The value returned by ``psnr``.
    """
    from .quality import psnr as psnr_metric

    # the factors come from https://github.com/fraunhoferhhi/vvenc/wiki/Encoder-Performance
    # and https://compression.ru/video/codec_comparison/2022/10_bit_report.html
    # (presumably one weight per y', pb and pr channel -- to be confirmed in ``psnr``)
    channel_weights = (6, 1, 1)
    return psnr_metric(dis, ref, weights=channel_weights, threads=1)


def _read_paths(func: callable):
    """Decorate _yield_frames, so that it accepts file paths rather than video streams.

    Parameters
    ----------
    func : callable
        A generator function ``func(dis_stream, ref_stream, needs)``,
        where ``ref_stream`` can be None.

    Returns
    -------
    callable
        A generator function ``(dis_path, ref_path, needs)`` that opens the files,
        and delegates to ``func`` while the containers are open.
    """
    def _decorated_yield_frames(
        dis: pathlib.Path, ref: pathlib.Path | None, needs: dict[str, list],
    ) -> tuple[str, FrameVideo, FrameVideo] | tuple[str, FrameVideo]:
        """Open the video files, then yield everything ``func`` yields.

        The containers stay open as long as the generator is being consumed.
        The first stream returned by ``out_select("video")`` is used for each file.
        """
        if ref is None:  # no-reference case, a single file to open
            with ContainerInputFFMPEG(dis) as dis_cont:
                dis_stream = dis_cont.out_select("video")[0]
                yield from func(dis_stream, None, needs)
        else:
            with ContainerInputFFMPEG(dis) as dis_cont, ContainerInputFFMPEG(ref) as ref_cont:
                dis_stream = dis_cont.out_select("video")[0]
                ref_stream = ref_cont.out_select("video")[0]
                yield from func(dis_stream, ref_stream, needs)
    return _decorated_yield_frames


def _ssim(dis: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Call ``ssim`` on one thread with the fixed (6, 1, 1) weights and a data range of 1.

    Parameters
    ----------
    dis : torch.Tensor
        The distorted batch, forwarded as is to ``ssim``.
    ref : torch.Tensor
        The reference batch, forwarded as is to ``ssim``.

    Returns
    -------
    torch.Tensor
        The value returned by ``ssim``.
    """
    from .quality import ssim as ssim_metric

    # the factors come from https://github.com/fraunhoferhhi/vvenc/wiki/Encoder-Performance
    # and https://compression.ru/video/codec_comparison/2022/10_bit_report.html
    # (presumably one weight per y', pb and pr channel -- to be confirmed in ``ssim``)
    channel_weights = (6, 1, 1)
    return ssim_metric(dis, ref, weights=channel_weights, threads=1, data_range=1.0)


def _uvq(dis: torch.Tensor) -> torch.Tensor:
    """Call ``uvq`` while holding ``TORCH_LOCK``, on half of the cpus (at least 2 threads).

    Parameters
    ----------
    dis : torch.Tensor
        The distorted batch, forwarded as is to ``uvq``.

    Returns
    -------
    torch.Tensor
        The value returned by ``uvq``.
    """
    from .quality import uvq

    # os.cpu_count() returns None when the number of cpus is undetermined,
    # which would make the floor division raise a TypeError
    threads = max(2, (os.cpu_count() or 1) // 2)
    with TORCH_LOCK:
        return uvq(dis, threads=threads)


@_read_paths
def _yield_frames(
    dis: StreamVideo, ref: StreamVideo | None, needs: dict[str, list],
) -> tuple[str, FrameVideo, FrameVideo] | tuple[str, FrameVideo]:
    """Yield, timestamp after timestamp, the frames required by each metric.

    Because of the ``_read_paths`` decorator, the callers give file paths
    in place of the ``dis`` and ``ref`` streams.

    Parameters
    ----------
    dis : StreamVideo
        The distorted video stream.
    ref : StreamVideo or None
        The reference video stream. Its colorspace is only read for the comparative
        metrics, so it can be None when no metric of ``needs`` is comparative.
    needs : dict[str, list]
        For each metric name, the list ``[func, comparative, space, shape, nbr, rate]``.
        The undefined ``shape`` and ``rate`` (None) are filled inplace.

    Yields
    ------
    metric : str
        The name of the metric these frames are intended for.
    dis_frames : list[FrameVideo]
        The ``nbr`` consecutive snapshots of the distorted stream.
    ref_frames : list[FrameVideo]
        The ``nbr`` consecutive snapshots of the reference stream,
        only present for the comparative metrics.
    """
    # define the correct colour space for the reference
    # one converted stream per colour space requested by at least one comparative metric,
    # the missing primaries and transfer function are taken from the global configuration
    ref_streams = {
        space: FilterVideoColorspace(
            [ref],
            Colorspace(
                space,
                ref.colorspace.primaries or Config().target_prim,
                ref.colorspace.transfer or Config().target_trc,
            ),
        ).out_streams[0]
        for space in {space for _, comparative, space, _, _, _ in needs.values() if comparative}
    }
    # when the target colorspace equals the source one, keep the original stream
    ref_streams = {  # optional squeeze no change colorspace
        s: ref if ref.colorspace == r.colorspace else r for s, r in ref_streams.items()
    }

    # define the correct colour space for the distorted
    # if a reference stream exists for this space, the distorted stream is aligned
    # on its primaries and transfer function, otherwise it falls back on its own values
    dis_streams = {
        space: FilterVideoColorspace(
            [dis],
            Colorspace(
                space,
                (
                    ref_streams[space].colorspace.primaries
                    if space in ref_streams else
                    (dis.colorspace.primaries or Config().target_prim)
                ),
                (
                    ref_streams[space].colorspace.transfer
                    if space in ref_streams else
                    (dis.colorspace.transfer or Config().target_trc)
                ),
            ),
        ).out_streams[0]
        for space in {space for _, _, space, _, _, _ in needs.values()}
    }
    # when the target colorspace equals the source one, keep the original stream
    dis_streams = {  # optional squeeze no change colorspace
        s: dis if dis.colorspace == d.colorspace else d for s, d in dis_streams.items()
    }

    # find missing shape and rate
    # the helper receives the dictionaries of streams (indexed by colour space), not single streams
    _fill_missing_shape_and_rate(needs, dis_streams, ref_streams)

    # retrieves all scenarios, sorted ensures repeatability to enable linear prediction
    # the set removes the duplicates, so that metrics sharing the same
    # (space, shape, nbr, rate) share the same snapshots
    config = sorted(
        {
            ("ref", space, shape, nbr, rate)
            for _, comparative, space, shape, nbr, rate in needs.values() if comparative
        } | {
            ("dis", space, shape, nbr, rate)
            for _, _, space, shape, nbr, rate in needs.values()
        },
    )

    # iterate until all frames are exhausted
    # NOTE(review): 3000/1001 is about 3 fps, it looks like a typo for
    # 30000/1001 (29.97 fps) -- confirm, the same fallback is used in _fill_missing_shape_and_rate
    rate = optimal_rate_video(ref or dis) or Fraction(3000, 1001)
    # the timestamps start at half a frame period, then progress by one period,
    # the arguments of ``count`` are evaluated once, so the reuse of the name ``rate``
    # by the inner loops below does not change the sampling step
    for timestamp in itertools.count(1/(2*rate), 1/rate):
        # get all patches
        patches: dict[tuple, list[FrameVideo]] = {}
        # here ``rate`` is the rate of the scenario, it shadows the global sampling rate
        for cat, space, shape, nbr, rate in config:
            try:
                patches[(cat, space, shape, nbr, rate)] = [
                    {"ref": ref_streams, "dis": dis_streams}[cat][space]
                    .snapshot(timestamp + i/rate, shape)
                    for i in range(nbr)
                ]
            except OutOfTimeRange:
                # at least one of the ``nbr`` frames is not available,
                # this scenario is skipped for this timestamp
                pass
        if not patches:  # if there is nothing left, it means we have reached the end
            break

        # combine patches
        for metric, (_, comparative, space, shape, nbr, rate) in needs.items():
            try:
                yield (
                    metric,
                    patches[("dis", space, shape, nbr, rate)],
                    *((patches[("ref", space, shape, nbr, rate)],) if comparative else ()),
                )
            except KeyError:
                # the patches of this metric were skipped above (out of time range)
                pass


def _yield_frames_batch(
    dis: pathlib.Path, ref: pathlib.Path | None, needs: dict[str, list],
) -> tuple[str, Fraction, torch.Tensor, torch.Tensor] | tuple[str, Fraction, torch.Tensor]:
    """Group the frames of ``_yield_frames`` in batches of about 128 MB per metric.

    Each yielded item is ``(metric, time, dis_batch)`` or ``(metric, time, dis_batch, ref_batch)``,
    where ``time`` is the ``time`` attribute of the first distorted frame of the batch.
    The batches have the shape (batch, nbr, height, width, channels).
    A batch is emitted as soon as its size reaches 128_000_000 bytes,
    the incomplete batches are emitted once all the frames have been read.
    """
    threshold = 128_000_000  # size in bytes from which a batch is emitted

    def nbytes(samples: list[tuple]) -> int:
        """Return the memory size of all the frames, assuming they all look like the first."""
        first = samples[0][0][0]
        height, width, channels = first.shape
        nbr_frames = len(samples) * len(samples[0]) * len(samples[0][0])
        bytes_per_value = torch.finfo(first.dtype).bits // 8
        return nbr_frames * height * width * channels * bytes_per_value

    def merge(samples: list[tuple]) -> list[torch.Tensor]:
        """Concatenate the samples into one 5d tensor for dis, and one for ref if any."""
        merged = []
        for idx in range(len(samples[0])):  # idx 0 for dis, idx 1 for ref
            stacked = [
                # (nbr, height, width, channels) -> (1, nbr, height, width, channels)
                torch.cat([frame.unsqueeze(0) for frame in sample[idx]], dim=0).unsqueeze(0)
                for sample in samples
            ]
            merged.append(torch.cat(stacked, dim=0))
        return merged

    pending: dict[str] = {}  # the frames waiting to be emitted, for each metric
    for metric, *dis_ref in _yield_frames(dis, ref, needs):
        samples = pending.setdefault(metric, [])
        samples.append(dis_ref)
        if nbytes(samples) < threshold:
            continue
        yield (metric, samples[0][0][0].time, *merge(samples))
        del pending[metric]

    # flush the remaining incomplete batches
    for metric, samples in pending.items():
        yield (metric, samples[0][0][0].time, *merge(samples))


def _add_metric(
    metric: str, ref: pathlib.Path | None,
) -> list[callable, bool, str, tuple | None, int, Fraction]:
    """Help ``video_metrics``.

    Return [func, comparative, space, shape, nbr, rate]
    """
    match metric:
        case "lpips_alex":
            assert ref is not None, \
                "the lpips_alex comparative metric requires a reference video 'ref=...'"
            need = [_lpips_alex, True, "r'g'b'", None, 1, None]
        case "lpips_vgg":
            assert ref is not None, \
                "the lpips_vgg comparative metric requires a reference video 'ref=...'"
            need = [_lpips_vgg, True, "r'g'b'", None, 1, None]
        case "psnr":
            assert ref is not None, \
                "the psnr comparative metric requires a reference video 'ref=...'"
            need = [_psnr, True, "y'pbpr", None, 1, None]
        case "rms_sobel":
            from .complexity import rms_sobel
            need = [rms_sobel, False, "y'pbpr", None, 1, None]
        case "rms_time_diff":
            from .complexity import rms_time_diff
            need = [rms_time_diff, False, "y'pbpr", None, 2, None]
        case "spatial_dct":
            from .complexity import spatial_dct
            need = [spatial_dct, False, "y'pbpr", None, 1, None]
        case "ssim":
            assert ref is not None, \
                "the ssim comparative metric requires a reference video 'ref=...'"
            need = [_ssim, True, "y'pbpr", None, 1, None]
        case "temporal_dct":
            from .complexity import temporal_dct
            need = [temporal_dct, False, "y'pbpr", None, 2, None]
        case "uvq":
            from .quality import uvq
            need = [uvq, False, "r'g'b'", (720, 1080), 5, Fraction(5)]
        case "vif":
            assert ref is not None, \
                "the vmaf comparative metric requires a reference video 'ref=...'"
            from .quality import vif
            need = [vif, True, "y'pbpr", None, 1, None]
        case "vmaf":
            assert ref is not None, \
                "the vmaf comparative metric requires a reference video 'ref=...'"
            from .quality import vmaf
            need = [vmaf, True, "y'pbpr", None, 1, None]
        case _:
            logging.warning("the %s metric is unknown and ignored", metric)
            need = None  # default value
    return need


[docs] def video_metrics( dis: pathlib.Path | str | bytes, ref: pathlib.Path | str | bytes | None = None, **metrics, ) -> dict[str, list[float]]: """Simultaneously calculate multiple video metrics, comparative and no-reference. .. note:: The distorted video is "aligned" on the reference video. This means that the **rate**, **resolution** and **colorspace** are automatically managed in a consistent manner. Parameters ---------- dis : pathlike The distorted video file. ref : pathlike, optional The reference video file, used for comparative metrics only. lpips_alex : boolean, default=False Trigger the spatial comparative quality LPIPS metric with medium alex network. Call :py:func:`cutcutcodec.core.analysis.video.quality.lpips` on every frame. lpips_vgg : boolean, default=False Trigger the spatial comparative quality LPIPS metric with big vgg network. Call :py:func:`cutcutcodec.core.analysis.video.quality.lpips` on every frame. psnr : boolean, default=False Trigger the spatial comparative quality PSNR metric. Call :py:func:`cutcutcodec.core.analysis.video.quality.psnr` on every frame. rms_sobel : boolean, default=False Trigger the spatial root mean square sobel gradient complexity. Call :py:func:`cutcutcodec.core.analysis.video.complexity.rms_sobel` on every frame. rms_time_diff : boolean, default=False Trigger the temporal root mean square time difference complexity. Call :py:func:`cutcutcodec.core.analysis.video.complexity.rms_time_diff` on every frame. spatial_dct : boolean, default=False Call :py:func:`cutcutcodec.core.analysis.video.complexity.dct.spatial_dct` on every frame. ssim : boolean, default=False Trigger the spatial comparative quality SSIM metric. Call :py:func:`cutcutcodec.core.analysis.video.quality.ssim` on every frame. temporal_dct : boolean, default=False Call :py:func:`cutcutcodec.core.analysis.video.complexity.dct.temporal_dct` on every frame. uvq : boolean, default=False Trigger the spatial and temporal no-reference quality UVQ metric. 
Call :py:func:`cutcutcodec.core.analysis.video.quality.uvq`. vif : boolean, default=False Trigger the spatial comparative quality VIF metric. Call :py:func:`cutcutcodec.core.analysis.video.quality.vif` on every frame. vmaf : boolean, default=False Trigger the spatial comparative quality VMAF metric. Call :py:func:`cutcutcodec.core.analysis.video.quality.vmaf.vmaf` on every frame. Returns ------- metrics : dict[str, list[float]] Associate the corresponding scalar values with each metric. Examples -------- >>> import pprint >>> from cutcutcodec.core.analysis.video.metrics import video_metrics >>> from cutcutcodec.utils import get_project_root >>> video = get_project_root() / "media" / "video" / "intro.webm" >>> res = video_metrics(video, video, psnr=True, ssim=True, rms_sobel=True) >>> pprint.pprint(res) # doctest: +ELLIPSIS {'psnr': [100.0, 100.0, ..., 100.0, 100.0], 'rms_sobel': [0.0036468505859375, 0.0036468505859375, ..., 0.0033111572265625, 0.000560760498046875], 'ssim': [1.0, 1.0, ..., 1.0, 1.0]} >>> SeeAlso ------- * `MSU Video Quality Measurement Tool <https://pypi.org/project/msu-vqmt/>`_. * `SITI <https://github.com/Telecommunication-Telemedia-Assessment/SITI>`_. * `VCA Video Complexity Analyser <https://github.com/cd-athena/VCA>`_. * `VSHIP Fast Metric Computation <https://github.com/Line-fr/Vship>`_. 
""" dis = pathlib.Path(dis).expanduser() assert dis.exists(), f"the path {dis} doese not exists, it has to" if ref is not None: ref = pathlib.Path(ref).expanduser() assert ref.exists(), f"the path {ref} doese not exists, it has to" # predict video duration for progress bar duration = _FuncEvalThread(func=get_duration_video, arg=(dis,), daemon=True) # assess needs needs: dict[str, list] = {} # metric: (func, comparative, space, shape, nbr, rate) for metric, value in metrics.items(): assert isinstance(value, bool), f"the {metric} arg must be a boolean, not {value}" if not value: continue if (need := _add_metric(metric, ref)) is not None: needs[metric] = need # calculate metrics in parallel on each batch metrics: dict[str, list[float]] = {} with tqdm.tqdm( desc="Video metrics", dynamic_ncols=True, leave=False, smoothing=1e-6, unit="sec_video", ) as progress_bar: for metric, timestamp, values in starmap( lambda metric, timestamp, *batches: ( metric, round(float(timestamp), 2), needs[metric][0](*batches), ), _yield_frames_batch(dis, ref, needs), maxsize=os.cpu_count(), ): metrics[metric] = metrics.get(metric, []) metrics[metric].extend(map(mround, values.ravel().tolist())) # progress bar progress_bar.total = max( progress_bar.total or round(float(duration.get()), 2), timestamp, ) progress_bar.update(timestamp - progress_bar.n) progress_bar.update(progress_bar.total or round(float(duration.get()), 2) - progress_bar.n) return metrics