Source code for cutcutcodec.core.analysis.video.quality.lpips_torch

"""Compute a differentiable batched torch LPIPS."""

import functools
import os
import threading

import torch

from cutcutcodec.core.opti.parallel.threading import TorchThreads

LOCK = threading.Lock()


@functools.cache
def _get_lpips_model(net: str):
    # pip install torchmetrics
    from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
    return LearnedPerceptualImagePatchSimilarity(net_type=net, reduction="none", normalize=True)


def lpips_torch(
    im1: torch.Tensor,
    im2: torch.Tensor,
    net: str | torch.nn.Module = "alex",
    threads: int = 0,
) -> torch.Tensor:
    """Pure torch implementation of :py:func:`cutcutcodec.core.analysis.video.quality.lpips`.

    Parameters
    ----------
    im1, im2 : torch.Tensor
        The two image batches to compare, of identical shape (..., height, width, 3).
        The values are clamped into [0, 1] before being given to the model.
    net : str or torch.nn.Module, default="alex"
        Either a network name, forwarded as ``net_type`` to the cached torchmetrics
        ``LearnedPerceptualImagePatchSimilarity`` metric, or a module that is
        called as ``net(im1, im2)`` on tensors of shape (n, 3, height, width).
    threads : int, default=0
        The value given to ``TorchThreads`` around the model call.
        If it is not strictly positive, ``max(2, cpu_count // 2)`` is used,
        falling back to 2 when the cpu count can not be determined.

    Returns
    -------
    loss : torch.Tensor
        The model output, reshaped to the leading batch dimensions of the inputs
        (a 0-dimensional tensor when the inputs are single images).
        The model is called under ``torch.no_grad()`` unless at least
        one of the two inputs requires a gradient.

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.analysis.video.quality.lpips_torch import lpips_torch
    >>> _ = torch.manual_seed(0)
    >>> im1 = torch.rand(2, 4, 720, 1080, 3)
    >>> im2 = 0.8 * im1 + 0.2 * torch.rand(2, 4, 720, 1080, 3)
    >>> lpips_torch(im1[0, 0], im1[0, 0])
    tensor(0.)
    >>> lpips_torch(im1, im2)
    tensor([[0.0492, 0.0443, 0.0468, 0.0469],
            [0.0444, 0.0445, 0.0470, 0.0446]])
    >>>
    """
    assert isinstance(im1, torch.Tensor), im1.__class__.__name__
    assert isinstance(im2, torch.Tensor), im2.__class__.__name__
    assert im1.ndim == im2.ndim >= 3, (im1.shape, im2.shape)
    assert im1.shape == im2.shape, (im1.shape, im2.shape)
    assert im1.shape[-1] == 3, im1.shape
    assert isinstance(net, str | torch.nn.Module), net.__class__.__name__

    # a name selects the shared cached metric, a module is used as provided
    model = _get_lpips_model(net) if isinstance(net, str) else net

    # flatten all the leading dimensions and put the channels first
    *batch, height, width, _ = im1.shape
    im1, im2 = im1.reshape(-1, height, width, 3), im2.reshape(-1, height, width, 3)  # (n, h, w, 3)
    im1, im2 = im1.movedim(3, 1), im2.movedim(3, 1)  # (n, 3, h, w)

    if threads <= 0:
        # os.cpu_count() returns None when the number of cpus is undetermined
        # original note: nested threads are created because they are locked
        threads = max(2, (os.cpu_count() or 1) // 2)

    # LOCK serializes the model call between the python threads
    # NOTE(review): TorchThreads presumably sets the torch thread count for this scope -- confirm
    with TorchThreads(threads), LOCK:
        # strict range in [0, 1] because normalize=True
        im1, im2 = im1.clamp(0.0, 1.0), im2.clamp(0.0, 1.0)
        if im1.requires_grad or im2.requires_grad:
            loss = model(im1, im2)
        else:
            with torch.no_grad():  # by default, a gradient is appended
                loss = model(im1, im2)

    # restore the leading batch dimensions of the inputs
    return loss.reshape(batch)