"""Gathers all video metrics."""
import itertools
import logging
import os
import pathlib
import threading
from collections.abc import Iterator
from fractions import Fraction

import torch
import tqdm

from cutcutcodec.config.config import Config
from cutcutcodec.core.analysis.stream.rate_video import optimal_rate_video
from cutcutcodec.core.analysis.stream.shape import optimal_shape_video
from cutcutcodec.core.analysis.video.properties import get_duration_video
from cutcutcodec.core.classes.colorspace import Colorspace
from cutcutcodec.core.classes.frame_video import FrameVideo
from cutcutcodec.core.classes.stream_video import StreamVideo
from cutcutcodec.core.exceptions import OutOfTimeRange
from cutcutcodec.core.filter.video.colorspace import FilterVideoColorspace
from cutcutcodec.core.io.read_ffmpeg import ContainerInputFFMPEG
from cutcutcodec.core.opti.parallel.buffer import _FuncEvalThread, starmap
from cutcutcodec.utils import mround
# only ``video_metrics`` is public, every other name of this module is a private helper
__all__ = ["video_metrics"]
# Acquired by the torch-based metric wrappers (_lpips_alex, _lpips_vgg, _uvq) so that only one
# of them runs at a time. Presumably because each one already asks for several threads and the
# metric functions are evaluated concurrently through ``starmap`` (to be confirmed).
TORCH_LOCK = threading.Lock()
def _fill_missing_shape_and_rate(
    needs: dict[str, list],
    dis: dict[str, StreamVideo],
    ref: dict[str, StreamVideo],
):
    """Help for _yield_frames, fill the missing shapes and rates of ``needs`` inplace.

    Parameters
    ----------
    needs : dict[str, list]
        Maps each metric name to ``[func, comparative, space, shape, nbr, rate]``.
        A ``None`` shape (index 3) or rate (index 5) is replaced by the optimal value of the
        stream the metric reads: the reference stream for a comparative metric, the distorted
        stream otherwise. When no optimal value is found, the fallbacks are ``(720, 1080)``
        for the shape and ``3000/1001`` for the rate.
    dis : dict[str, StreamVideo]
        The distorted streams, keyed by colorspace name.
    ref : dict[str, StreamVideo]
        The reference streams, keyed by colorspace name (comparative colorspaces only).
    """
    for need in needs.values():  # ``need`` is the very list stored in ``needs``
        _, comparative, space, shape, _, rate = need
        streams = ref if comparative else dis
        # the stream is only looked up when something is missing, as before
        if shape is None:
            need[3] = optimal_shape_video(streams[space]) or (720, 1080)
        if rate is None:
            need[5] = optimal_rate_video(streams[space]) or Fraction(3000, 1001)
def _lpips_alex(dis: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Compute the LPIPS metric with the alex network, serialized by ``TORCH_LOCK``.

    The computation is given as many threads as ``os.cpu_count`` reports.
    """
    from .quality import lpips as lpips_func
    nb_threads = os.cpu_count()
    with TORCH_LOCK:
        score = lpips_func(dis, ref, net="alex", threads=nb_threads)
    return score
def _lpips_vgg(dis: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Compute the LPIPS metric with the vgg network, serialized by ``TORCH_LOCK``.

    The computation is given as many threads as ``os.cpu_count`` reports.
    """
    from .quality import lpips
    options = {"net": "vgg", "threads": os.cpu_count()}
    with TORCH_LOCK:
        return lpips(dis, ref, **options)
def _psnr(dis: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Compute the single-threaded PSNR with the channels weighted 6:1:1.

    The weighting factors come from https://github.com/fraunhoferhhi/vvenc/wiki/Encoder-Performance
    and https://compression.ru/video/codec_comparison/2022/10_bit_report.html
    """
    from .quality import psnr as psnr_func
    channel_weights = (6, 1, 1)
    return psnr_func(dis, ref, weights=channel_weights, threads=1)
def _read_paths(func: callable):
"""Decorate _yield_frames."""
def _decorated_yield_frames(
dis: pathlib.Path, ref: pathlib.Path or None, needs: dict[str],
) -> tuple[str, FrameVideo, FrameVideo] | tuple[str, FrameVideo]:
if ref is None:
with ContainerInputFFMPEG(dis) as dis_cont:
dis_stream = dis_cont.out_select("video")[0]
yield from func(dis_stream, None, needs)
else:
with ContainerInputFFMPEG(dis) as dis_cont, ContainerInputFFMPEG(ref) as ref_cont:
dis_stream = dis_cont.out_select("video")[0]
ref_stream = ref_cont.out_select("video")[0]
yield from func(dis_stream, ref_stream, needs)
return _decorated_yield_frames
def _ssim(dis: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Compute the single-threaded SSIM with the channels weighted 6:1:1 and a unit data range.

    The weighting factors come from https://github.com/fraunhoferhhi/vvenc/wiki/Encoder-Performance
    and https://compression.ru/video/codec_comparison/2022/10_bit_report.html
    """
    from .quality import ssim as ssim_func
    channel_weights = (6, 1, 1)
    return ssim_func(dis, ref, weights=channel_weights, threads=1, data_range=1.0)
def _uvq(dis: torch.Tensor) -> torch.Tensor:
    """Compute the no-reference UVQ metric, serialized by ``TORCH_LOCK``.

    Half of the logical cores are used, with a minimum of 2 threads.
    ``os.cpu_count`` returns None when the count is undetermined; one core is assumed then,
    instead of crashing on ``None // 2``.
    """
    from .quality import uvq
    with TORCH_LOCK:
        return uvq(dis, threads=max(2, (os.cpu_count() or 1) // 2))
@_read_paths
def _yield_frames(
    dis: StreamVideo, ref: StreamVideo | None, needs: dict[str, list],
) -> Iterator[tuple[str, list[FrameVideo], list[FrameVideo]] | tuple[str, list[FrameVideo]]]:
    """Read the streams to yield the well-formatted frame patches of each metric.

    Because of the ``_read_paths`` decorator, callers give file paths and this body receives the
    first video stream of each file (``ref`` is None when there is no reference file).

    Parameters
    ----------
    dis : StreamVideo
        The distorted video stream.
    ref : StreamVideo, optional
        The reference video stream, only accessed when at least one metric is comparative.
    needs : dict[str, list]
        Maps each metric name to ``[func, comparative, space, shape, nbr, rate]``.
        The missing shapes and rates (``None``) are filled inplace.

    Yields
    ------
    metric : str
        The name of the metric the following patches are for.
    dis_patch : list[FrameVideo]
        ``nbr`` distorted frames, one every ``1/rate`` second, in the metric colorspace and shape.
    ref_patch : list[FrameVideo]
        The matching reference frames, only yielded for the comparative metrics.
    """
    # define the correct colour space for the reference
    # (one converted stream per colorspace required by a comparative metric; missing primaries
    # and transfer fall back on the configuration targets)
    ref_streams = {
        space: FilterVideoColorspace(
            [ref],
            Colorspace(
                space,
                ref.colorspace.primaries or Config().target_prim,
                ref.colorspace.transfer or Config().target_trc,
            ),
        ).out_streams[0]
        for space in {space for _, comparative, space, _, _, _ in needs.values() if comparative}
    }
    ref_streams = {  # bypass the filter when it does not change the colorspace
        s: ref if ref.colorspace == r.colorspace else r for s, r in ref_streams.items()
    }
    # define the correct color space for the distorded
    # (primaries and transfer are aligned on the reference stream of the same space if any)
    dis_streams = {
        space: FilterVideoColorspace(
            [dis],
            Colorspace(
                space,
                (
                    ref_streams[space].colorspace.primaries
                    if space in ref_streams else
                    (dis.colorspace.primaries or Config().target_prim)
                ),
                (
                    ref_streams[space].colorspace.transfer
                    if space in ref_streams else
                    (dis.colorspace.transfer or Config().target_trc)
                ),
            ),
        ).out_streams[0]
        for space in {space for _, _, space, _, _, _ in needs.values()}
    }
    dis_streams = {  # bypass the filter when it does not change the colorspace
        s: dis if dis.colorspace == d.colorspace else d for s, d in dis_streams.items()
    }
    # find missing shape and rate (``needs`` is modified inplace)
    _fill_missing_shape_and_rate(needs, dis_streams, ref_streams)
    # retrieves all scenarios, sorted ensures repeatability to enable linear prediction
    # (a scenario is shared by every metric with the same category, space, shape, nbr and rate)
    config = sorted(
        {
            ("ref", space, shape, nbr, rate)
            for _, comparative, space, shape, nbr, rate in needs.values() if comparative
        } | {
            ("dis", space, shape, nbr, rate)
            for _, _, space, shape, nbr, rate in needs.values()
        },
    )
    # iterate until all frames are exhausted
    # the sampling rate is the one of the reference (the distorted one without reference), the
    # first sample is half a period after 0.
    # NOTE: ``rate`` is shadowed by the loops below, this is harmless because the step of
    # ``itertools.count`` is evaluated once, right here.
    rate = optimal_rate_video(ref or dis) or Fraction(3000, 1001)
    for timestamp in itertools.count(1/(2*rate), 1/rate):
        # get all patches, each scenario uses its own rate to space its ``nbr`` frames
        patches: dict[tuple, list[FrameVideo]] = {}
        for cat, space, shape, nbr, rate in config:
            try:
                patches[(cat, space, shape, nbr, rate)] = [
                    {"ref": ref_streams, "dis": dis_streams}[cat][space]
                    .snapshot(timestamp + i/rate, shape)
                    for i in range(nbr)
                ]
            except OutOfTimeRange:  # a frame is outside the stream, skip this scenario for now
                pass
        if not patches:  # if there is nothing left, it means we have reached the end
            break
        # combine patches
        for metric, (_, comparative, space, shape, nbr, rate) in needs.items():
            try:
                yield (
                    metric,
                    patches[("dis", space, shape, nbr, rate)],
                    *((patches[("ref", space, shape, nbr, rate)],) if comparative else ()),
                )
            except KeyError:  # a scenario of this metric is out of range at this timestamp
                pass
def _yield_frames_batch(
    dis: pathlib.Path, ref: pathlib.Path | None, needs: dict[str, list],
) -> Iterator[
    tuple[str, Fraction, torch.Tensor, torch.Tensor] | tuple[str, Fraction, torch.Tensor]
]:
    """Gather the frames of each metric in batches of about 128 Mo.

    Parameters
    ----------
    dis, ref, needs
        Transmitted as is to ``_yield_frames``.

    Yields
    ------
    metric : str
        The name of the metric.
    time
        The ``time`` attribute of the first distorted frame of the batch.
    *batches : torch.Tensor
        The distorted batch, followed by the reference batch for the comparative metrics,
        each one of shape (batch, nbr, height, width, channels).
    """
    max_bytes = 128_000_000

    def size(frames: list[list]) -> int:
        """Estimate the memory footprint of the gathered frames in bytes, from the first frame."""
        frame = frames[0][0][0]
        nbr = len(frames) * len(frames[0]) * len(frames[0][0])
        height, width, channels = frame.shape
        # unlike ``torch.finfo``, ``element_size`` also accepts the integer and bool dtypes
        depth = torch.empty((), dtype=frame.dtype).element_size()
        return nbr * height * width * channels * depth

    def cat(frames: list[list]) -> list[torch.Tensor]:
        """Stack the frames into one (batch, nbr, height, width, channels) tensor per role."""
        return [
            torch.cat(
                [
                    torch.cat([f.unsqueeze(0) for f in fs[i]], dim=0).unsqueeze(0)
                    for fs in frames
                ],
                dim=0,
            )
            for i in range(len(frames[0]))  # role 0 is dis, role 1 (if any) is ref
        ]

    batches: dict[str, list] = {}
    for metric, *dis_ref in _yield_frames(dis, ref, needs):
        frames = batches.setdefault(metric, [])
        frames.append(dis_ref)
        if size(frames) >= max_bytes:
            yield (metric, frames[0][0][0].time, *cat(frames))
            del batches[metric]
    # flush the incomplete batches
    for metric, frames in batches.items():
        yield (metric, frames[0][0][0].time, *cat(frames))
def _add_metric(
metric: str, ref: pathlib.Path | None,
) -> list[callable, bool, str, tuple | None, int, Fraction]:
"""Help ``video_metrics``.
Return [func, comparative, space, shape, nbr, rate]
"""
match metric:
case "lpips_alex":
assert ref is not None, \
"the lpips_alex comparative metric requires a reference video 'ref=...'"
need = [_lpips_alex, True, "r'g'b'", None, 1, None]
case "lpips_vgg":
assert ref is not None, \
"the lpips_vgg comparative metric requires a reference video 'ref=...'"
need = [_lpips_vgg, True, "r'g'b'", None, 1, None]
case "psnr":
assert ref is not None, \
"the psnr comparative metric requires a reference video 'ref=...'"
need = [_psnr, True, "y'pbpr", None, 1, None]
case "rms_sobel":
from .complexity import rms_sobel
need = [rms_sobel, False, "y'pbpr", None, 1, None]
case "rms_time_diff":
from .complexity import rms_time_diff
need = [rms_time_diff, False, "y'pbpr", None, 2, None]
case "spatial_dct":
from .complexity import spatial_dct
need = [spatial_dct, False, "y'pbpr", None, 1, None]
case "ssim":
assert ref is not None, \
"the ssim comparative metric requires a reference video 'ref=...'"
need = [_ssim, True, "y'pbpr", None, 1, None]
case "temporal_dct":
from .complexity import temporal_dct
need = [temporal_dct, False, "y'pbpr", None, 2, None]
case "uvq":
from .quality import uvq
need = [uvq, False, "r'g'b'", (720, 1080), 5, Fraction(5)]
case "vif":
assert ref is not None, \
"the vmaf comparative metric requires a reference video 'ref=...'"
from .quality import vif
need = [vif, True, "y'pbpr", None, 1, None]
case "vmaf":
assert ref is not None, \
"the vmaf comparative metric requires a reference video 'ref=...'"
from .quality import vmaf
need = [vmaf, True, "y'pbpr", None, 1, None]
case _:
logging.warning("the %s metric is unknown and ignored", metric)
need = None # default value
return need
[docs]
def video_metrics(
dis: pathlib.Path | str | bytes,
ref: pathlib.Path | str | bytes | None = None,
**metrics,
) -> dict[str, list[float]]:
"""Simultaneously calculate multiple video metrics, comparative and no-reference.
.. note::
The distorted video is "aligned" on the reference video.
This means that the **rate**, **resolution** and **colorspace**
are automatically managed in a consistent manner.
Parameters
----------
dis : pathlike
The distorted video file.
ref : pathlike, optional
The reference video file, used for comparative metrics only.
lpips_alex : boolean, default=False
Trigger the spatial comparative quality LPIPS metric with medium alex network.
Call :py:func:`cutcutcodec.core.analysis.video.quality.lpips` on every frame.
lpips_vgg : boolean, default=False
Trigger the spatial comparative quality LPIPS metric with big vgg network.
Call :py:func:`cutcutcodec.core.analysis.video.quality.lpips` on every frame.
psnr : boolean, default=False
Trigger the spatial comparative quality PSNR metric.
Call :py:func:`cutcutcodec.core.analysis.video.quality.psnr` on every frame.
rms_sobel : boolean, default=False
Trigger the spatial root mean square sobel gradient complexity.
Call :py:func:`cutcutcodec.core.analysis.video.complexity.rms_sobel` on every frame.
rms_time_diff : boolean, default=False
Trigger the temporal root mean square time difference complexity.
Call :py:func:`cutcutcodec.core.analysis.video.complexity.rms_time_diff` on every frame.
spatial_dct : boolean, default=False
Call :py:func:`cutcutcodec.core.analysis.video.complexity.dct.spatial_dct` on every frame.
ssim : boolean, default=False
Trigger the spatial comparative quality SSIM metric.
Call :py:func:`cutcutcodec.core.analysis.video.quality.ssim` on every frame.
temporal_dct : boolean, default=False
Call :py:func:`cutcutcodec.core.analysis.video.complexity.dct.temporal_dct` on every frame.
uvq : boolean, default=False
Trigger the spatial and temporal no-reference quality UVQ metric.
Call :py:func:`cutcutcodec.core.analysis.video.quality.uvq`.
vif : boolean, default=False
Trigger the spatial comparative quality VIF metric.
Call :py:func:`cutcutcodec.core.analysis.video.quality.vif` on every frame.
vmaf : boolean, default=False
Trigger the spatial comparative quality VMAF metric.
Call :py:func:`cutcutcodec.core.analysis.video.quality.vmaf.vmaf` on every frame.
Returns
-------
metrics : dict[str, list[float]]
Associate the corresponding scalar values with each metric.
Examples
--------
>>> import pprint
>>> from cutcutcodec.core.analysis.video.metrics import video_metrics
>>> from cutcutcodec.utils import get_project_root
>>> video = get_project_root() / "media" / "video" / "intro.webm"
>>> res = video_metrics(video, video, psnr=True, ssim=True, rms_sobel=True)
>>> pprint.pprint(res) # doctest: +ELLIPSIS
{'psnr': [100.0,
100.0,
...,
100.0,
100.0],
'rms_sobel': [0.0036468505859375,
0.0036468505859375,
...,
0.0033111572265625,
0.000560760498046875],
'ssim': [1.0,
1.0,
...,
1.0,
1.0]}
>>>
SeeAlso
-------
* `MSU Video Quality Measurement Tool <https://pypi.org/project/msu-vqmt/>`_.
* `SITI <https://github.com/Telecommunication-Telemedia-Assessment/SITI>`_.
* `VCA Video Complexity Analyser <https://github.com/cd-athena/VCA>`_.
* `VSHIP Fast Metric Computation <https://github.com/Line-fr/Vship>`_.
"""
dis = pathlib.Path(dis).expanduser()
assert dis.exists(), f"the path {dis} doese not exists, it has to"
if ref is not None:
ref = pathlib.Path(ref).expanduser()
assert ref.exists(), f"the path {ref} doese not exists, it has to"
# predict video duration for progress bar
duration = _FuncEvalThread(func=get_duration_video, arg=(dis,), daemon=True)
# assess needs
needs: dict[str, list] = {} # metric: (func, comparative, space, shape, nbr, rate)
for metric, value in metrics.items():
assert isinstance(value, bool), f"the {metric} arg must be a boolean, not {value}"
if not value:
continue
if (need := _add_metric(metric, ref)) is not None:
needs[metric] = need
# calculate metrics in parallel on each batch
metrics: dict[str, list[float]] = {}
with tqdm.tqdm(
desc="Video metrics",
dynamic_ncols=True,
leave=False,
smoothing=1e-6,
unit="sec_video",
) as progress_bar:
for metric, timestamp, values in starmap(
lambda metric, timestamp, *batches: (
metric, round(float(timestamp), 2), needs[metric][0](*batches),
),
_yield_frames_batch(dis, ref, needs),
maxsize=os.cpu_count(),
):
metrics[metric] = metrics.get(metric, [])
metrics[metric].extend(map(mround, values.ravel().tolist()))
# progress bar
progress_bar.total = max(
progress_bar.total or round(float(duration.get()), 2), timestamp,
)
progress_bar.update(timestamp - progress_bar.n)
progress_bar.update(progress_bar.total or round(float(duration.get()), 2) - progress_bar.n)
return metrics