Source code for cutcutcodec.core.analysis.video.quality.ssim_torch

"""Compute a differentiable batched torch ssim."""

import numbers
import typing

import torch

from cutcutcodec.core.opti.cache.basic import basic_cache
from cutcutcodec.core.opti.parallel.threading import TorchThreads
from cutcutcodec.core.signal.gauss import gauss2d as _gauss, gauss2d_fft as _gauss_fft


def ssim_conv_torch(
    im1: torch.Tensor,
    im2: torch.Tensor,
    data_range: numbers.Real = 1.0,
    weights: typing.Optional[typing.Iterable[float]] = None,
    sigma: numbers.Real = 1.5,
    **kwargs,
) -> torch.Tensor:
    """Pure torch implementation of :py:func:`cutcutcodec.core.analysis.video.quality.ssim`.

    It is based on a native torch convolution.

    Parameters
    ----------
    im1, im2 : torch.Tensor
        The two images (or batches of images) to compare.
        They must have the same shape (..., height, width, channels).
    data_range : numbers.Real, default=1.0
        Strictly positive value used to scale the two stabilisation constants
        ``(0.01 * data_range)**2`` and ``(0.03 * data_range)**2``.
    weights : iterable[float], optional
        One relative weight per channel. They are normalised so that their sum is 1.
        By default, all the channels have the same weight.
    sigma : numbers.Real, default=1.5
        Strictly positive standard deviation of the gaussian window.
        The window of size ``2*int(3.5*sigma + 0.5) + 1`` must fit in the image.
    **kwargs : dict
        ``threads`` (default 0) is forwarded to ``TorchThreads``.
        ``stride`` (default 1) is forwarded to ``torch.nn.functional.conv2d``.

    Returns
    -------
    ssim : torch.Tensor
        The weighted channel average of the mean ssim of each image, of shape (...).

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.analysis.video.quality.ssim_torch import ssim_conv_torch
    >>> _ = torch.manual_seed(0)
    >>> im1 = torch.rand(2, 4, 720, 1080, 3)
    >>> im2 = 0.8 * im1 + 0.2 * torch.rand(2, 4, 720, 1080, 3)
    >>> ssim_conv_torch(im1[0, 0], im1[0, 0])
    tensor(1.)
    >>> ssim_conv_torch(im1, im2)
    tensor([[0.9511, 0.9512, 0.9511, 0.9511],
            [0.9512, 0.9512, 0.9511, 0.9512]])
    >>>
    """
    assert isinstance(im1, torch.Tensor), im1.__class__.__name__
    assert isinstance(im2, torch.Tensor), im2.__class__.__name__
    assert im1.ndim == im2.ndim >= 3, (im1.shape, im2.shape)
    assert im1.shape == im2.shape, (im1.shape, im2.shape)
    assert isinstance(data_range, numbers.Real), data_range.__class__.__name__
    data_range = float(data_range)
    assert data_range > 0, data_range
    assert isinstance(sigma, numbers.Real), sigma.__class__.__name__
    sigma = float(sigma)
    assert sigma > 0, sigma
    radius = int(3.5 * sigma + 0.5)  # same as skimage.metrics.structural_similarity
    assert 2*radius + 1 <= im1.shape[-3] and 2*radius + 1 <= im1.shape[-2], \
        "sigma is to big for the image size"

    # cast and normalise weights
    if weights is None:
        weights = [1.0 for _ in range(im1.shape[-1])]
    else:
        weights = [float(w) for w in weights]
        assert len(weights) == im1.shape[-1], (len(weights), im1.shape)
    weights = torch.asarray(weights, dtype=im1.dtype, device=im1.device)
    weights /= weights.sum()

    with TorchThreads(kwargs.get("threads", 0)):
        # convolution kernel, reshaped as a conv2d weight (1 output and 1 input channel)
        gauss = _gauss(sigma, im1.dtype).to(im1.device)
        gauss = gauss[None, None, :, :]  # (1, 1, hk, wk)

        # every channel of every image becomes an independent single-channel image
        shape = im1.shape  # (..., h, w, n)
        im1 = im1.movedim(-1, -3).reshape((-1, 1, shape[-3], shape[-2]))  # (... * n, 1, h, w)
        im2 = im2.movedim(-1, -3).reshape((-1, 1, shape[-3], shape[-2]))  # (... * n, 1, h, w)

        # compute statistics for all patches
        stride = kwargs.get("stride", 1)
        stats = {
            "mu1": torch.nn.functional.conv2d(im1, gauss, stride=stride),
            "mu2": torch.nn.functional.conv2d(im2, gauss, stride=stride),
        }
        stats |= {
            "mu11": stats["mu1"] * stats["mu1"],
            "mu22": stats["mu2"] * stats["mu2"],
            "mu12": stats["mu1"] * stats["mu2"],
        }
        del stats["mu1"], stats["mu2"]  # only the products are needed from here
        stats |= {
            "s11": torch.nn.functional.conv2d(im1 * im1, gauss, stride=stride) - stats["mu11"],
            "s22": torch.nn.functional.conv2d(im2 * im2, gauss, stride=stride) - stats["mu22"],
            "s12": torch.nn.functional.conv2d(im1 * im2, gauss, stride=stride) - stats["mu12"],
        }

        # ssim formula
        cst = [(0.01 * data_range)**2, (0.03 * data_range)**2]
        ssim = (
            (2.0*stats["mu12"] + cst[0]) * (2.0*stats["s12"] + cst[1])
        ) / (
            (stats["mu11"] + stats["mu22"] + cst[0]) * (stats["s11"] + stats["s22"] + cst[1])
        )
        ssim = ssim.mean(dim=(-1, -2))  # mean of each layers

        # average
        ssim = ssim.reshape((*shape[:-3], shape[-1]))  # (..., n)
        ssim = (ssim * weights).sum(dim=-1)

    return ssim
def _fft_blur_valid(img: torch.Tensor, g_fft: torch.Tensor, radius: int) -> torch.Tensor:
    """Filter ``img`` (..., h, w, n) in the Fourier domain and crop the borders.

    ``g_fft`` is multiplied with ``torch.fft.rfft2(img, dim=(-2, -3))``,
    so it has to broadcast against that spectrum.
    A margin of ``radius`` pixels is removed on each side of the two spatial axes.
    """
    height, width = img.shape[-3], img.shape[-2]
    blurred = torch.fft.irfft2(
        g_fft * torch.fft.rfft2(img, dim=(-2, -3)),
        # the halved axis is the height: without an explicit size,
        # an odd height would be rebuilt with height - 1 samples
        s=(width, height),
        dim=(-2, -3),
    )
    # explicit stop indices: "radius:-radius" is an empty slice when radius == 0
    return blurred[..., radius:height-radius, radius:width-radius, :]


def ssim_fft_torch(
    im1: torch.Tensor,
    im2: torch.Tensor,
    data_range: numbers.Real = 1.0,
    weights: typing.Optional[typing.Iterable[float]] = None,
    sigma: numbers.Real = 1.5,
    **kwargs,
) -> torch.Tensor:
    """Pure torch implementation of :py:func:`cutcutcodec.core.analysis.video.quality.ssim`.

    It is based on fast fft convolution.

    Parameters
    ----------
    im1, im2 : torch.Tensor
        The two images (or batches of images) to compare.
        They must have the same shape (..., height, width, channels).
    data_range : numbers.Real, default=1.0
        Strictly positive value used to scale the two stabilisation constants
        ``(0.01 * data_range)**2`` and ``(0.03 * data_range)**2``.
    weights : iterable[float], optional
        One relative weight per channel. They are normalised so that their sum is 1.
        By default, all the channels have the same weight.
    sigma : numbers.Real, default=1.5
        Strictly positive standard deviation of the gaussian window.
        The window of size ``2*int(3.5*sigma + 0.5) + 1`` must fit in the image.
    **kwargs : dict
        ``threads`` (default 0) is forwarded to ``TorchThreads``.
        The other keys are not read by this function.

    Returns
    -------
    ssim : torch.Tensor
        The weighted channel average of the mean ssim of each image, of shape (...).

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.analysis.video.quality.ssim_torch import ssim_fft_torch
    >>> _ = torch.manual_seed(0)
    >>> im1 = torch.rand(2, 4, 720, 1080, 3)
    >>> im2 = 0.8 * im1 + 0.2 * torch.rand(2, 4, 720, 1080, 3)
    >>> ssim_fft_torch(im1[0, 0], im1[0, 0])
    tensor(1.)
    >>> ssim_fft_torch(im1, im2)
    tensor([[0.9511, 0.9512, 0.9511, 0.9511],
            [0.9512, 0.9512, 0.9511, 0.9512]])
    >>>
    """
    assert isinstance(im1, torch.Tensor), im1.__class__.__name__
    assert isinstance(im2, torch.Tensor), im2.__class__.__name__
    assert im1.ndim == im2.ndim >= 3, (im1.shape, im2.shape)
    assert im1.shape == im2.shape, (im1.shape, im2.shape)
    assert isinstance(data_range, numbers.Real), data_range.__class__.__name__
    data_range = float(data_range)
    assert data_range > 0, data_range
    assert isinstance(sigma, numbers.Real), sigma.__class__.__name__
    sigma = float(sigma)
    assert sigma > 0, sigma
    radius = int(3.5 * sigma + 0.5)  # same as skimage.metrics.structural_similarity
    assert 2*radius + 1 <= im1.shape[-3] and 2*radius + 1 <= im1.shape[-2], \
        "sigma is to big for the image size"

    # cast and normalise weights
    if weights is None:
        weights = [1.0 for _ in range(im1.shape[-1])]
    else:
        weights = [float(w) for w in weights]
        assert len(weights) == im1.shape[-1], (len(weights), im1.shape)
    weights = torch.asarray(weights, dtype=im1.dtype, device=im1.device)
    weights /= weights.sum()

    # get gaussian kernel
    g_fft = _gauss_fft(sigma, im1.shape[-3], im1.shape[-2], im1.dtype).to(im1.device)

    with TorchThreads(kwargs.get("threads", 0)):
        # statistic convolutions and crop patches
        stats = {
            "mu1": _fft_blur_valid(im1, g_fft, radius),
            "mu2": _fft_blur_valid(im2, g_fft, radius),
        }
        stats |= {
            "mu11": stats["mu1"] * stats["mu1"],
            "mu22": stats["mu2"] * stats["mu2"],
            "mu12": stats["mu1"] * stats["mu2"],
        }
        del stats["mu1"], stats["mu2"]  # only the products are needed from here
        stats |= {
            "s11": _fft_blur_valid(im1*im1, g_fft, radius) - stats["mu11"],
            "s22": _fft_blur_valid(im2*im2, g_fft, radius) - stats["mu22"],
            "s12": _fft_blur_valid(im1*im2, g_fft, radius) - stats["mu12"],
        }

        # ssim formula
        cst = [(0.01 * data_range)**2, (0.03 * data_range)**2]
        ssim = (
            (2.0*stats["mu12"] + cst[0]) * (2.0*stats["s12"] + cst[1])
        ) / (
            (stats["mu11"] + stats["mu22"] + cst[0]) * (stats["s11"] + stats["s22"] + cst[1])
        )
        ssim = ssim.mean(dim=(-2, -3))  # mean of each layers

        # average
        ssim = (ssim * weights).sum(dim=-1)

    return ssim