# Source code for cutcutcodec.core.analysis.video.quality.vif_torch

"""Compute a differentiable batched torch VIF (Visual Information Fidelity).

Sources
=======
The original paper is: https://ieeexplore.ieee.org/document/1576816
Codes:
    * https://github.com/photosynthesis-team/piq/blob/master/piq/vif.py
    * https://github.com/alvitrioliks/VMAF-torch/blob/master/vmaf_torch/vif.py
    * https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/functional/image/vif.py
    * https://github.com/andrewekhalel/sewar/blob/ac76e7bc75732fde40bb0d3908f4b6863400cc27/sewar/full_ref.py
"""

import torch

from cutcutcodec.core.signal.gauss import gauss2d


def vif_conv_torch(dis: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Pure torch implementation of :py:func:`cutcutcodec.core.analysis.video.quality.vif`.

    It is based on a native torch convolution.

    Parameters
    ----------
    dis, ref : torch.Tensor
        The distorted and reference images of shape (batch, height, width).
        Assumed to be only the luminosity in range [0, 1].

    Returns
    -------
    vif : torch.Tensor
        The VIF Index of similarity between two images.
        Usually in [0, 1], of shape (batch,).
        Can be bigger than 1 for predicted :math:`x` images
        with higher contrast than original one.

    Notes
    -----
    All the masking steps are performed out-of-place (``masked_fill`` / ``where``),
    so that no tensor saved by autograd is overwritten and the score
    can be backpropagated to ``dis`` and ``ref``.

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.analysis.video.quality.vif_torch import vif_conv_torch
    >>> _ = torch.manual_seed(0)
    >>> ref = torch.rand(4, 720, 1080)
    >>> dis = 0.8 * ref + 0.2 * torch.rand(4, 720, 1080)
    >>> vif_conv_torch(dis, ref)
    tensor([0.6440, 0.6444, 0.6441, 0.6436])
    >>>
    >>> import torchmetrics
    >>> torchmetrics.functional.image.visual_information_fidelity(
    ...     dis[:, None, :, :], ref[:, None, :, :], reduction="none",
    ... )
    tensor([0.6440, 0.6444, 0.6441, 0.6437])
    >>>
    """
    assert isinstance(dis, torch.Tensor), dis.__class__.__name__
    assert isinstance(ref, torch.Tensor), ref.__class__.__name__
    assert dis.ndim == ref.ndim == 3, (dis.shape, ref.shape)
    assert dis.shape == ref.shape, (dis.shape, ref.shape)

    # constants
    EPS = 1e-10  # for numerical stability
    SIGMA_N_SQ = 2.0  # HVS model parameter (variance of the visual noise)

    conv2d = torch.nn.functional.conv2d

    # add the channel dimension, the product by 1.0 promotes integer inputs to float
    dis, ref = dis[:, None, :, :] * 1.0, ref[:, None, :, :] * 1.0  # (n, 1, h, w)

    # progressively downsample images and accumulate the information on each scale
    dis_vif = dis.new_zeros(dis.shape[0])  # numerator, (n,)
    ref_vif = ref.new_zeros(ref.shape[0])  # denominator, (n,)
    for scale in range(4):
        # the gaussian window gets narrower at each scale
        sigma = (2**(4-scale) + 1) / 5.0
        gauss = gauss2d(sigma, ref.dtype).to(ref.device)
        gauss = gauss[None, None, :, :]  # (1, 1, hk, wk)

        # convolve and downsample
        if scale > 0:
            dis = conv2d(dis, gauss, stride=2)
            ref = conv2d(ref, gauss, stride=2)

        # compute statistics for all patches: local (co)variances E[xy] - E[x]E[y]
        mu_dis, mu_ref = conv2d(dis, gauss), conv2d(ref, gauss)
        # the variances are clipped because rounding errors can make them slightly negative
        sdd = torch.relu(conv2d(dis * dis, gauss) - mu_dis * mu_dis)
        sdr = conv2d(dis * ref, gauss) - mu_dis * mu_ref
        srr = torch.relu(conv2d(ref * ref, gauss) - mu_ref * mu_ref)
        del mu_dis, mu_ref  # release memory as soon as possible

        # gain and additive noise variance of the distortion model
        g = sdr / (srr + EPS)  # (n, 1, h, w)
        sigma_v_sq = sdd - g * sdr

        # flat reference patches carry no information
        flat_ref = srr < EPS
        g = g.masked_fill(flat_ref, 0.0)
        sigma_v_sq = torch.where(flat_ref, sdd, sigma_v_sq)
        srr = srr.masked_fill(flat_ref, 0.0)

        # flat distorted patches
        flat_dis = sdd < EPS
        g = g.masked_fill(flat_dis, 0.0)
        sigma_v_sq = torch.where(flat_dis, sdd, sigma_v_sq)

        # a negative gain is replaced by zero, all the distorted variance becomes noise
        sigma_v_sq = torch.where(g < 0, sdd, sigma_v_sq)
        g = torch.relu(g)
        sigma_v_sq = sigma_v_sq.masked_fill(sigma_v_sq < EPS, EPS)

        # compute vif
        dis_vif_scale = torch.log10(1.0 + (g ** 2.) * srr / (sigma_v_sq + SIGMA_N_SQ))
        dis_vif = dis_vif + torch.sum(dis_vif_scale, dim=[1, 2, 3])
        ref_vif = ref_vif + torch.sum(torch.log10(1.0 + srr / SIGMA_N_SQ), dim=[1, 2, 3])

    score: torch.Tensor = (dis_vif + EPS) / (ref_vif + EPS)
    return score