Source code for cutcutcodec.core.analysis.video.quality.uvq_google.aggregationnet

"""A modified version of the Google UVQ source file.

As the original file is under the Apache license,
I should mention that this is a modified version of the source file:

https://github.com/google/uvq/blob/main/uvq_pytorch/utils/aggregationnet.py
"""

import numpy as np
import torch

from cutcutcodec.core.nn.start import load

# Channels contributed by each subnet feature map; AggregationNet sizes the
# input of its first convolution as ``len(subnets) * NUM_CHANNELS_PER_SUBNET``.
NUM_CHANNELS_PER_SUBNET = 100
# Output channels of the convolution, and therefore the width of the batch
# normalisation and the input width of the final linear layer.
NUM_FILTERS = 256
# Kernel size of the convolution: (1, 1) mixes channels only.
CONV2D_KERNEL_SIZE = (1, 1)
# Kernel size of the max pooling applied after the ReLU.
MAXPOOL2D_KERNEL_SIZE = (16, 16)

# ``eps`` and ``momentum`` handed to torch.nn.BatchNorm2d.
# NOTE(review): with PyTorch's momentum convention, 1 makes the running
# statistics follow only the latest batch in training mode -- confirm this is
# what the converted checkpoint expects.
BN_DEFAULT_EPS = 0.001
BN_DEFAULT_MOMENTUM = 1
# Drop probability ``p`` handed to torch.nn.Dropout.
DROPOUT_RATE = 0.2


class AggregationNet(torch.nn.Module):
    """Regress one score per sample from the feature maps of some subnets.

    The selected feature maps are concatenated along the channel axis, then
    pass through a convolution, a batch normalisation, a ReLU, a max pooling,
    a dropout and a final linear layer with a single output.

    Parameters
    ----------
    subnets : list[str]
        The keys of the feature maps to read in ``forward``, in concatenation
        order. Each one is expected to provide ``NUM_CHANNELS_PER_SUBNET``
        channels.
    """

    def __init__(self, subnets: list[str]):
        super().__init__()
        self.subnets = subnets
        self.conv1 = torch.nn.Conv2d(
            len(subnets) * NUM_CHANNELS_PER_SUBNET,
            NUM_FILTERS,
            kernel_size=CONV2D_KERNEL_SIZE,
            bias=True,
        )
        self.bn1 = torch.nn.BatchNorm2d(
            NUM_FILTERS,
            eps=BN_DEFAULT_EPS,
            momentum=BN_DEFAULT_MOMENTUM,
        )
        self.relu1 = torch.nn.ReLU()
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=MAXPOOL2D_KERNEL_SIZE)
        self.dropout1 = torch.nn.Dropout(p=DROPOUT_RATE)
        self.linear1 = torch.nn.Linear(
            in_features=NUM_FILTERS, out_features=1, bias=True
        )

    def forward(self, features: dict[str, torch.Tensor]) -> torch.Tensor:
        """Mix all features.

        Parameters
        ----------
        features : dict[str, torch.Tensor]
            One 4-D feature map per subnet name, channels on dimension 1.
            Keys that are not listed in ``self.subnets`` are ignored.
            NOTE(review): the linear layer only accepts the flattened result
            if the pooled map is 1x1 spatially -- confirm with the callers.

        Returns
        -------
        torch.Tensor
            The output of the final linear layer, one column per sample row.
        """
        if len(self.subnets) == 1:
            # Nothing to concatenate, use the single map directly (no copy).
            x = features[self.subnets[0]]
        else:
            x = torch.cat([features[name] for name in self.subnets], dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.dropout1(x)
        # Same result as torch.nn.Flatten() (start_dim=1), without building a
        # new module object on every call.
        x = torch.flatten(x, start_dim=1)
        return self.linear1(x)
class AggregationNetInference(torch.nn.Module):
    """Average the scores of several ``AggregationNet`` models.

    Parameters
    ----------
    eval_mode : bool, default=True
        If True, the module is switched to evaluation mode with ``eval()``.
    **kwargs
        Only the ``weights`` entry is read; it is forwarded as is to
        ``cutcutcodec.core.nn.start.load`` (None when missing).
    """

    # Feature maps consumed by every model, in channel-concatenation order.
    _SUBNETS = ("compression", "content", "distortion")
    # Number of independent models whose outputs are averaged.
    _NUM_MODELS = 5

    def __init__(self, eval_mode=True, **kwargs):
        super().__init__()
        # Same architecture repeated; registered as ``models.0`` .. ``models.4``.
        self.models = torch.nn.ModuleList(
            [AggregationNet(list(self._SUBNETS)) for _ in range(self._NUM_MODELS)],
        )
        if eval_mode:
            self.eval()
        # NOTE(review): the trailing hash is presumably the checksum of the
        # expected weights file -- confirm against ``load``.
        load(self, kwargs.get("weights"))  # f65da816bf4ca0bb91a24e3ba62fb02b

    def forward(
        self,
        compression_features: np.ndarray,
        content_features: np.ndarray,
        distortion_features: np.ndarray,
    ) -> torch.Tensor:
        """Compute the final score.

        Parameters
        ----------
        compression_features, content_features, distortion_features : np.ndarray
            4-D arrays with the channels on the last axis; that axis is moved
            to position 1 before the arrays are given to the models.

        Returns
        -------
        torch.Tensor
            The mean over the models of their outputs, one value per sample
            (first axis of the inputs). No gradient is tracked.
        """
        # The conversion does not depend on the model: do it once rather than
        # once per model. The models only read these tensors.
        features = {
            "compression": torch.Tensor(compression_features.transpose(0, 3, 1, 2)),
            "content": torch.Tensor(content_features.transpose(0, 3, 1, 2)),
            "distortion": torch.Tensor(distortion_features.transpose(0, 3, 1, 2)),
        }
        with torch.no_grad():
            feature_results = [model(features) for model in self.models]
        # One column per model, averaged to a single score per sample.
        return torch.cat(feature_results, dim=1).mean(dim=1)

    def predict(self, *args, **kwargs) -> torch.Tensor:
        """Compute the final score, this is an alias of ``forward``."""
        return self.forward(*args, **kwargs)