"""Compute a differentiable, batched torch VIF (Visual Information Fidelity).
Sources
=======
The original paper is: https://ieeexplore.ieee.org/document/1576816
Codes:
* https://github.com/photosynthesis-team/piq/blob/master/piq/vif.py
* https://github.com/alvitrioliks/VMAF-torch/blob/master/vmaf_torch/vif.py
* https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/functional/image/vif.py
* https://github.com/andrewekhalel/sewar/blob/ac76e7bc75732fde40bb0d3908f4b6863400cc27/sewar/full_ref.py
"""
import torch
from cutcutcodec.core.signal.gauss import gauss2d
def vif_conv_torch(dis: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Pure torch implementation of :py:func:`cutcutcodec.core.analysis.video.quality.vif`.

    It is based on a native torch convolution.
    Every masking step is performed out-of-place, so the score can be back-propagated
    through with torch autograd (for example to be used as a loss).

    Parameters
    ----------
    dis, ref : torch.Tensor
        The distorted and reference images of shape (batch, height, width).
        Assumed to be only the luminosity in range [0, 1].

    Returns
    -------
    vif : torch.Tensor
        The VIF Index of similarity between two images. Usually in [0, 1], of shape (batch,).
        Can be bigger than 1 for distorted images with higher contrast than the reference one.

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.analysis.video.quality.vif_torch import vif_conv_torch
    >>> _ = torch.manual_seed(0)
    >>> ref = torch.rand(4, 720, 1080)
    >>> dis = 0.8 * ref + 0.2 * torch.rand(4, 720, 1080)
    >>> vif_conv_torch(dis, ref)
    tensor([0.6440, 0.6444, 0.6441, 0.6436])
    >>>
    >>> import torchmetrics
    >>> torchmetrics.functional.image.visual_information_fidelity(
    ...     dis[:, None, :, :], ref[:, None, :, :], reduction="none",
    ... )
    tensor([0.6440, 0.6444, 0.6441, 0.6437])
    >>>
    >>> # the score is differentiable with respect to the inputs
    >>> ref = torch.rand(2, 256, 256)
    >>> dis = torch.rand(2, 256, 256, requires_grad=True)
    >>> vif_conv_torch(dis, ref).sum().backward()
    >>> dis.grad.shape
    torch.Size([2, 256, 256])
    >>>
    """
    assert isinstance(dis, torch.Tensor), dis.__class__.__name__
    assert isinstance(ref, torch.Tensor), ref.__class__.__name__
    assert dis.ndim == ref.ndim == 3, (dis.shape, ref.shape)
    assert dis.shape == ref.shape, (dis.shape, ref.shape)

    # constants
    EPS = 1e-10  # for numerical stability
    SIGMA_N_SQ = 2.0  # HVS model parameter (variance of the visual noise)

    # add a channel axis for conv2d, "* 1.0" promotes integer images to floating point
    dis, ref = dis[:, None, :, :] * 1.0, ref[:, None, :, :] * 1.0  # (n, 1, h, w)

    # progressively downsample images and accumulate the VIF terms of each scale
    dis_vif, ref_vif = 0, 0
    for scale in range(4):
        sigma = (2**(4-scale) + 1) / 5.0
        gauss = gauss2d(sigma, ref.dtype).to(ref.device)[None, None, :, :]  # (1, 1, hk, wk)

        # low-pass filter and decimate by 2 (equivalent to conv2d(...)[:, :, ::2, ::2])
        if scale > 0:
            dis = torch.nn.functional.conv2d(dis, gauss, stride=2)
            ref = torch.nn.functional.conv2d(ref, gauss, stride=2)

        # gaussian weighted local means, variances and covariance
        mu_d = torch.nn.functional.conv2d(dis, gauss)
        mu_r = torch.nn.functional.conv2d(ref, gauss)
        mu_dd, mu_dr, mu_rr = mu_d * mu_d, mu_d * mu_r, mu_r * mu_r
        del mu_d, mu_r
        # relu removes the small negative variances caused by rounding errors
        sigma_dd = torch.relu(torch.nn.functional.conv2d(dis * dis, gauss) - mu_dd)
        sigma_dr = torch.nn.functional.conv2d(dis * ref, gauss) - mu_dr
        sigma_rr = torch.relu(torch.nn.functional.conv2d(ref * ref, gauss) - mu_rr)
        del mu_dd, mu_dr, mu_rr

        # gain g and noise variance sigma_v^2 of the distortion model dis = g * ref + v
        gain = sigma_dr / (sigma_rr + EPS)  # (n, 1, h, w)
        sigma_v_sq = sigma_dd - gain * sigma_dr
        # All the corrections below are out-of-place on purpose: relu keeps its output and
        # mul keeps its operands for the backward pass, so in-place writes break autograd.
        flat_ref = sigma_rr < EPS  # no reference signal: all the distorted variance is noise
        gain = gain.masked_fill(flat_ref, 0.0)
        sigma_v_sq = torch.where(flat_ref, sigma_dd, sigma_v_sq)
        sigma_rr = sigma_rr.masked_fill(flat_ref, 0.0)
        flat_dis = sigma_dd < EPS  # no distorted signal: nothing is transmitted
        gain = gain.masked_fill(flat_dis, 0.0)
        sigma_v_sq = torch.where(flat_dis, sigma_dd, sigma_v_sq)
        negative = gain < 0  # an anti-correlated patch is treated as pure noise
        sigma_v_sq = torch.where(negative, sigma_dd, sigma_v_sq)
        gain = torch.relu(gain)
        sigma_v_sq = sigma_v_sq.masked_fill(sigma_v_sq < EPS, EPS)

        # information extracted from the distorted (numerator) and reference (denominator) images
        dis_vif_scale = torch.log10(1.0 + (gain ** 2.) * sigma_rr / (sigma_v_sq + SIGMA_N_SQ))
        dis_vif = dis_vif + torch.sum(dis_vif_scale, dim=[1, 2, 3])
        ref_vif = ref_vif + torch.sum(torch.log10(1.0 + sigma_rr / SIGMA_N_SQ), dim=[1, 2, 3])

    score: torch.Tensor = (dis_vif + EPS) / (ref_vif + EPS)
    return score