Source code for cutcutcodec.core.signal.gauss

"""Gaussian tools."""

import torch

from cutcutcodec.core.opti.cache.basic import basic_cache


@basic_cache
def gauss2d(sigma: float, dtype: torch.dtype) -> torch.Tensor:
    """Compute a normalized 2d gaussian window.

    The window is square, of side ``2 * radius + 1`` with
    ``radius = int(3.5 * sigma + 0.5)`` (the truncation used by
    ``skimage.metrics.structural_similarity``), and its coefficients sum to 1.

    Parameters
    ----------
    sigma : float
        Standard deviation of the gaussian, in samples.
        Must be a strictly positive and finite python float.
    dtype : torch.dtype
        Floating point dtype of the returned window, one of
        ``torch.float64``, ``torch.float32`` or ``torch.float16``.

    Returns
    -------
    gauss : torch.Tensor
        The separable gaussian window of shape ``(2 * radius + 1, 2 * radius + 1)``.
        The function is wrapped by ``basic_cache``; presumably the same tensor
        object is shared between calls with equal arguments, so avoid modifying
        it in place (NOTE(review): confirm ``basic_cache`` semantics).

    Raises
    ------
    AssertionError
        If ``sigma`` is not a strictly positive finite float,
        or if ``dtype`` is not one of the supported floating dtypes.

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.signal.gauss import gauss2d
    >>> gauss2d(1.5, dtype=torch.float32).round(decimals=3)  # doctest: +NORMALIZE_WHITESPACE
    tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0010, 0.0020, 0.0020, 0.0020, 0.0010, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0010, 0.0040, 0.0080, 0.0100, 0.0080, 0.0040, 0.0010, 0.0000, 0.0000],
            [0.0000, 0.0010, 0.0040, 0.0120, 0.0230, 0.0290, 0.0230, 0.0120, 0.0040, 0.0010, 0.0000],
            [0.0000, 0.0020, 0.0080, 0.0230, 0.0450, 0.0570, 0.0450, 0.0230, 0.0080, 0.0020, 0.0000],
            [0.0000, 0.0020, 0.0100, 0.0290, 0.0570, 0.0710, 0.0570, 0.0290, 0.0100, 0.0020, 0.0000],
            [0.0000, 0.0020, 0.0080, 0.0230, 0.0450, 0.0570, 0.0450, 0.0230, 0.0080, 0.0020, 0.0000],
            [0.0000, 0.0010, 0.0040, 0.0120, 0.0230, 0.0290, 0.0230, 0.0120, 0.0040, 0.0010, 0.0000],
            [0.0000, 0.0000, 0.0010, 0.0040, 0.0080, 0.0100, 0.0080, 0.0040, 0.0010, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0010, 0.0020, 0.0020, 0.0020, 0.0010, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
    >>>
    """
    assert isinstance(sigma, float), sigma.__class__.__name__
    # Without this guard, sigma == 0 returns a nan window (0 / 0), a slightly
    # negative sigma silently returns a 1x1 delta, a more negative one makes
    # ``torch.arange`` raise, and nan / inf make ``int()`` raise.
    assert 0.0 < sigma < float("inf"), sigma
    assert isinstance(dtype, torch.dtype), dtype.__class__.__name__
    assert dtype in {torch.float64, torch.float32, torch.float16}, dtype

    radius = int(3.5 * sigma + 0.5)  # same as skimage.metrics.structural_similarity
    gauss = torch.arange(-radius, radius + 1, dtype=dtype)
    gauss = torch.exp(-gauss**2 / (2.0 * sigma**2))  # unnormalized 1d profile
    # The 2d gaussian is separable: outer product of the 1d profile with itself
    # (same values as meshgrid(..., indexing="ij") followed by an element-wise product).
    gauss = gauss[:, None] * gauss[None, :]
    gauss /= gauss.sum()  # coefficients sum to 1
    return gauss


@basic_cache
def gauss2d_fft(sigma: float, height: int, width: int, dtype: torch.dtype) -> torch.Tensor:
    """Compute the fourier transform of a 2d gaussian window.

    The window from ``gauss2d(sigma, dtype)`` is zero-padded into a
    ``(height, width)`` grid before the real 2d FFT is taken.

    Parameters
    ----------
    sigma : float
        Standard deviation, forwarded unchanged to ``gauss2d``.
    height : int
        Size of axis 0 of the grid the window is embedded into.
    width : int
        Size of axis 1 of the grid the window is embedded into.
    dtype : torch.dtype
        Real floating dtype, forwarded unchanged to ``gauss2d``.

    Returns
    -------
    gauss_fft : torch.Tensor
        ``torch.fft.rfft2`` of the padded window taken over ``dim=(1, 0)``
        (per torch's rfftn semantics the one-sided axis is the last one listed
        in ``dim``, here axis 0), with a trailing singleton axis appended.

    Notes
    -----
    The window is placed at the center of the grid, not at the origin; when the
    padding is odd, the extra row / column goes to the bottom / right.
    NOTE(review): if ``height`` or ``width`` is smaller than the window, the
    padding amounts are negative and ``torch.nn.functional.pad`` presumably
    crops the window instead; confirm this is acceptable for callers.
    """
    kernel = gauss2d(sigma, dtype)
    extra_rows = height - kernel.shape[0]
    extra_cols = width - kernel.shape[1]
    top = extra_rows // 2
    left = extra_cols // 2
    # ``pad`` lists the last axis first: (left, right, top, bottom).
    padding = (left, extra_cols - left, top, extra_rows - top)
    padded = torch.nn.functional.pad(kernel, padding, value=0.0)
    spectrum = torch.fft.rfft2(padded, dim=(1, 0))
    return spectrum.unsqueeze(-1)