"""Compute a differenciable batched torch ssim."""
import numbers
import typing
import torch
from cutcutcodec.core.opti.parallel.threading import TorchThreads
from cutcutcodec.core.signal.gauss import gauss2d as _gauss
from cutcutcodec.core.signal.gauss import gauss2d_fft as _gauss_fft
[docs]
def ssim_conv_torch(
    im1: torch.Tensor,
    im2: torch.Tensor,
    data_range: numbers.Real = 1.0,
    weights: typing.Optional[typing.Iterable[float]] = None,
    sigma: numbers.Real = 1.5,
    **kwargs,
) -> torch.Tensor:
    """Pure torch implementation of :py:func:`cutcutcodec.core.analysis.video.quality.ssim`.

    It is based on a native torch convolution.

    Parameters
    ----------
    im1, im2 : torch.Tensor
        The two images to compare. They must have the same shape
        (..., height, width, channels), with at least 3 dimensions.
    data_range : numbers.Real, default=1.0
        Strictly positive value that scales the two stabilisation constants
        ``(0.01 * data_range)**2`` and ``(0.03 * data_range)**2`` of the ssim formula.
    weights : iterable of float, optional
        One weight per channel (last dimension). The weights are divided by their sum
        before the channels are averaged. By default, all channels have the same weight.
    sigma : numbers.Real, default=1.5
        Strictly positive standard deviation given to the gaussian window builder.
        The window radius ``int(3.5 * sigma + 0.5)`` must fit inside the image.
    **kwargs
        ``threads`` (default 0) is forwarded to ``TorchThreads``,
        ``stride`` (default 1) is forwarded to ``torch.nn.functional.conv2d``.

    Returns
    -------
    ssim : torch.Tensor
        The weighted average over the channels of the mean ssim of each channel.
        Its shape is the shape of the inputs without the last 3 dimensions.

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.analysis.video.quality.ssim_torch import ssim_conv_torch
    >>> _ = torch.manual_seed(0)
    >>> im1 = torch.rand(2, 4, 720, 1080, 3)
    >>> im2 = 0.8 * im1 + 0.2 * torch.rand(2, 4, 720, 1080, 3)
    >>> ssim_conv_torch(im1[0, 0], im1[0, 0])
    tensor(1.)
    >>> ssim_conv_torch(im1, im2)
    tensor([[0.9511, 0.9512, 0.9511, 0.9511],
            [0.9512, 0.9512, 0.9511, 0.9512]])
    >>>
    """
    assert isinstance(im1, torch.Tensor), im1.__class__.__name__
    assert isinstance(im2, torch.Tensor), im2.__class__.__name__
    assert im1.ndim == im2.ndim >= 3, (im1.shape, im2.shape)
    assert im1.shape == im2.shape, (im1.shape, im2.shape)
    assert isinstance(data_range, numbers.Real), data_range.__class__.__name__
    data_range = float(data_range)
    assert data_range > 0, data_range
    assert isinstance(sigma, numbers.Real), sigma.__class__.__name__
    sigma = float(sigma)
    assert sigma > 0, sigma
    radius = int(3.5 * sigma + 0.5)  # same as skimage.metrics.structural_similarity
    assert 2*radius + 1 <= im1.shape[-3] and 2*radius + 1 <= im1.shape[-2], \
        "sigma is to big for the image size"

    # cast and normalise weights
    if weights is None:
        weights = [1.0] * im1.shape[-1]
    else:
        weights = [float(w) for w in weights]
        assert len(weights) == im1.shape[-1], (len(weights), im1.shape)
    weights = torch.asarray(weights, dtype=im1.dtype, device=im1.device)
    weights /= weights.sum()

    with TorchThreads(kwargs.get("threads", 0)):
        # convolution kernel, reshaped as a single in/out channel conv2d weight
        gauss = _gauss(sigma, im1.dtype).to(im1.device)
        gauss = gauss[None, None, :, :]  # (1, 1, hk, wk)

        # every channel of every image becomes an independent single channel image
        shape = im1.shape  # (..., h, w, n)
        im1 = im1.movedim(-1, -3).reshape((-1, 1, shape[-3], shape[-2]))  # (... * n, 1, h, w)
        im2 = im2.movedim(-1, -3).reshape((-1, 1, shape[-3], shape[-2]))  # (... * n, 1, h, w)
        stride = kwargs.get("stride", 1)

        # compute statistics for all patches
        stats = {
            "mu1": torch.nn.functional.conv2d(im1, gauss, stride=stride),
            "mu2": torch.nn.functional.conv2d(im2, gauss, stride=stride),
        }
        stats |= {
            "mu11": stats["mu1"] * stats["mu1"],
            "mu22": stats["mu2"] * stats["mu2"],
            "mu12": stats["mu1"] * stats["mu2"],
        }
        del stats["mu1"], stats["mu2"]  # only the products are used below
        # windowed (co)variances: E[xy] - E[x]E[y]
        stats |= {
            "s11": torch.nn.functional.conv2d(im1 * im1, gauss, stride=stride) - stats["mu11"],
            "s22": torch.nn.functional.conv2d(im2 * im2, gauss, stride=stride) - stats["mu22"],
            "s12": torch.nn.functional.conv2d(im1 * im2, gauss, stride=stride) - stats["mu12"],
        }

        # ssim formula
        cst = [(0.01 * data_range)**2, (0.03 * data_range)**2]
        ssim = (
            (2.0*stats["mu12"] + cst[0]) * (2.0*stats["s12"] + cst[1])
        ) / (
            (stats["mu11"] + stats["mu22"] + cst[0]) * (stats["s11"] + stats["s22"] + cst[1])
        )
        ssim = ssim.mean(dim=(-1, -2))  # mean over all the patches of each layer

        # weighted average of the channels
        ssim = ssim.reshape((*shape[:-3], shape[-1]))  # (..., n)
        ssim = (ssim * weights).sum(dim=-1)
    return ssim
[docs]
def ssim_fft_torch(
    im1: torch.Tensor,
    im2: torch.Tensor,
    data_range: numbers.Real = 1.0,
    weights: typing.Optional[typing.Iterable[float]] = None,
    sigma: numbers.Real = 1.5,
    **kwargs,
) -> torch.Tensor:
    """Pure torch implementation of :py:func:`cutcutcodec.core.analysis.video.quality.ssim`.

    It is based on fast fft convolution.

    Parameters
    ----------
    im1, im2 : torch.Tensor
        The two images to compare. They must have the same shape
        (..., height, width, channels), with at least 3 dimensions.
    data_range : numbers.Real, default=1.0
        Strictly positive value that scales the two stabilisation constants
        ``(0.01 * data_range)**2`` and ``(0.03 * data_range)**2`` of the ssim formula.
    weights : iterable of float, optional
        One weight per channel (last dimension). The weights are divided by their sum
        before the channels are averaged. By default, all channels have the same weight.
    sigma : numbers.Real, default=1.5
        Strictly positive standard deviation given to the gaussian window builder.
        The window radius ``int(3.5 * sigma + 0.5)`` must fit inside the image.
    **kwargs
        ``threads`` (default 0) is forwarded to ``TorchThreads``.

    Returns
    -------
    ssim : torch.Tensor
        The weighted average over the channels of the mean ssim of each channel.
        Its shape is the shape of the inputs without the last 3 dimensions.

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.analysis.video.quality.ssim_torch import ssim_fft_torch
    >>> _ = torch.manual_seed(0)
    >>> im1 = torch.rand(2, 4, 720, 1080, 3)
    >>> im2 = 0.8 * im1 + 0.2 * torch.rand(2, 4, 720, 1080, 3)
    >>> ssim_fft_torch(im1[0, 0], im1[0, 0])
    tensor(1.)
    >>> ssim_fft_torch(im1, im2)
    tensor([[0.9511, 0.9512, 0.9511, 0.9511],
            [0.9512, 0.9512, 0.9511, 0.9512]])
    >>>
    """
    assert isinstance(im1, torch.Tensor), im1.__class__.__name__
    assert isinstance(im2, torch.Tensor), im2.__class__.__name__
    assert im1.ndim == im2.ndim >= 3, (im1.shape, im2.shape)
    assert im1.shape == im2.shape, (im1.shape, im2.shape)
    assert isinstance(data_range, numbers.Real), data_range.__class__.__name__
    data_range = float(data_range)
    assert data_range > 0, data_range
    assert isinstance(sigma, numbers.Real), sigma.__class__.__name__
    sigma = float(sigma)
    assert sigma > 0, sigma
    radius = int(3.5 * sigma + 0.5)  # same as skimage.metrics.structural_similarity
    height, width = im1.shape[-3], im1.shape[-2]
    assert 2*radius + 1 <= height and 2*radius + 1 <= width, \
        "sigma is to big for the image size"

    # cast and normalise weights
    if weights is None:
        weights = [1.0] * im1.shape[-1]
    else:
        weights = [float(w) for w in weights]
        assert len(weights) == im1.shape[-1], (len(weights), im1.shape)
    weights = torch.asarray(weights, dtype=im1.dtype, device=im1.device)
    weights /= weights.sum()

    # get gaussian kernel
    # NOTE(review): g_fft is presumably the spectrum of the gaussian window, laid out so
    # that it broadcasts with rfft2(..., dim=(-2, -3)) -- confirm against gauss2d_fft.
    g_fft = _gauss_fft(sigma, height, width, im1.dtype).to(im1.device)

    def _blur(img: torch.Tensor) -> torch.Tensor:
        """Filter ``img`` by ``g_fft`` in the frequency domain and crop the borders."""
        spectrum = g_fft * torch.fft.rfft2(img, dim=(-2, -3))
        # The output size is given explicitly: without ``s``, irfft2 assumes an even
        # length on the half-spectrum axis and returns h-1 rows when the height is odd.
        full = torch.fft.irfft2(spectrum, s=(width, height), dim=(-2, -3))
        # Explicit end indices: ``radius:-radius`` would be the empty slice ``0:0``
        # when the radius is 0 (very small sigma).
        return full[..., radius:height-radius, radius:width-radius, :]  # crop patches

    with TorchThreads(kwargs.get("threads", 0)):
        # statistic convolutions and crop patches
        stats = {"mu1": _blur(im1), "mu2": _blur(im2)}
        stats |= {
            "mu11": stats["mu1"] * stats["mu1"],
            "mu22": stats["mu2"] * stats["mu2"],
            "mu12": stats["mu1"] * stats["mu2"],
        }
        del stats["mu1"], stats["mu2"]  # only the products are used below
        # windowed (co)variances: E[xy] - E[x]E[y]
        stats |= {
            "s11": _blur(im1*im1) - stats["mu11"],
            "s22": _blur(im2*im2) - stats["mu22"],
            "s12": _blur(im1*im2) - stats["mu12"],
        }

        # ssim formula
        cst = [(0.01 * data_range)**2, (0.03 * data_range)**2]
        ssim = (
            (2.0*stats["mu12"] + cst[0]) * (2.0*stats["s12"] + cst[1])
        ) / (
            (stats["mu11"] + stats["mu22"] + cst[0]) * (stats["s11"] + stats["s22"] + cst[1])
        )
        ssim = ssim.mean(dim=(-2, -3))  # mean over all the patches of each layer

        # weighted average of the channels
        ssim = (ssim * weights).sum(dim=-1)
    return ssim