"""Gaussian tools."""
import torch
from cutcutcodec.core.opti.cache.basic import basic_cache
@basic_cache
def gauss2d(sigma: float, dtype: torch.dtype) -> torch.Tensor:
    """Compute a normalized 2d gaussian window.

    The window is sampled at integer offsets in ``[-radius, radius]`` along both axes,
    with ``radius = int(3.5 * sigma + 0.5)``, and is divided by its sum.

    Parameters
    ----------
    sigma : float
        Standard deviation of the gaussian, in samples. Must be strictly positive
        (``sigma == 0`` would otherwise produce a NaN window).
    dtype : torch.dtype
        Floating point type of the window: float16, float32 or float64.

    Returns
    -------
    torch.Tensor
        Square window of shape ``(2*radius + 1, 2*radius + 1)``, whose elements sum to 1
        (up to floating point rounding).

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.signal.gauss import gauss2d
    >>> gauss2d(1.5, dtype=torch.float32).round(decimals=3)
    tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0010, 0.0020, 0.0020, 0.0020, 0.0010, 0.0000,
             0.0000, 0.0000],
            [0.0000, 0.0000, 0.0010, 0.0040, 0.0080, 0.0100, 0.0080, 0.0040, 0.0010,
             0.0000, 0.0000],
            [0.0000, 0.0010, 0.0040, 0.0120, 0.0230, 0.0290, 0.0230, 0.0120, 0.0040,
             0.0010, 0.0000],
            [0.0000, 0.0020, 0.0080, 0.0230, 0.0450, 0.0570, 0.0450, 0.0230, 0.0080,
             0.0020, 0.0000],
            [0.0000, 0.0020, 0.0100, 0.0290, 0.0570, 0.0710, 0.0570, 0.0290, 0.0100,
             0.0020, 0.0000],
            [0.0000, 0.0020, 0.0080, 0.0230, 0.0450, 0.0570, 0.0450, 0.0230, 0.0080,
             0.0020, 0.0000],
            [0.0000, 0.0010, 0.0040, 0.0120, 0.0230, 0.0290, 0.0230, 0.0120, 0.0040,
             0.0010, 0.0000],
            [0.0000, 0.0000, 0.0010, 0.0040, 0.0080, 0.0100, 0.0080, 0.0040, 0.0010,
             0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0010, 0.0020, 0.0020, 0.0020, 0.0010, 0.0000,
             0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000, 0.0000]])
    >>>
    """
    assert isinstance(sigma, float), sigma.__class__.__name__
    # rejects 0, negatives and NaN: all of them would give a NaN or degenerate window
    assert sigma > 0.0, f"sigma must be strictly positive, not {sigma}"
    assert isinstance(dtype, torch.dtype), dtype.__class__.__name__
    assert dtype in {torch.float64, torch.float32, torch.float16}, dtype
    radius = int(3.5 * sigma + 0.5)  # same as skimage.metrics.structural_similarity
    offsets = torch.arange(-radius, radius + 1, dtype=dtype)
    gauss_1d = torch.exp(-offsets**2 / (2.0 * sigma**2))
    # the 2d gaussian is separable: g(i, j) = g(i) * g(j)
    gauss = torch.outer(gauss_1d, gauss_1d)
    gauss /= gauss.sum()
    return gauss
@basic_cache
def gauss2d_fft(sigma: float, height: int, width: int, dtype: torch.dtype) -> torch.Tensor:
    """Compute the fourier transform of a 2d gaussian window.

    The window returned by ``gauss2d(sigma, dtype)`` is zero padded to a
    ``height`` x ``width`` frame, with the odd extra row/column going to the
    bottom/right. The frame is then transformed with ``torch.fft.rfft2`` over
    ``dim=(1, 0)``.

    Parameters
    ----------
    sigma : float
        Standard deviation, forwarded to ``gauss2d``.
    height, width : int
        Spatial size of the padded frame before the transform.
    dtype : torch.dtype
        Real dtype, forwarded to ``gauss2d``.

    Returns
    -------
    torch.Tensor
        Complex spectrum of shape ``(height//2 + 1, width, 1)``. Per the
        ``rfft2`` docs, the one-sided axis is the last one listed in ``dim``,
        which is axis 0 here. A trailing singleton axis is appended.
    """
    kernel = gauss2d(sigma, dtype)
    extra_rows = height - kernel.shape[0]
    extra_cols = width - kernel.shape[1]
    top = extra_rows // 2
    left = extra_cols // 2
    # NOTE(review): when height/width is smaller than the kernel, these amounts are
    # negative and F.pad crops instead of padding. Confirm this is intended.
    # F.pad expects (left, right, top, bottom): the last axis comes first.
    borders = (left, extra_cols - left, top, extra_rows - top)
    framed = torch.nn.functional.pad(kernel, borders, value=0.0)
    # NOTE(review): the kernel sits near the frame center rather than at index (0, 0),
    # so the spectrum carries the matching linear phase. Confirm callers expect it.
    spectrum = torch.fft.rfft2(framed, dim=(1, 0))
    return spectrum.unsqueeze(-1)