Source code for cutcutcodec.core.signal.psd

#!/usr/bin/env python3

"""Tools for the Power Spectral Density (PSD) estimation."""

import numbers
import typing
import math

import torch

from .window import alpha_to_band, att_to_alpha, band_to_alpha, dpss


def _compute_psd(
    signal: torch.Tensor, win: torch.Tensor, shift: int, return_std: bool
) -> torch.Tensor:
    """Help welch."""
    assert signal.ndim >= 1, signal.shape
    assert signal.shape[-1] >= len(win), "signal to short or window to tall"
    assert win.ndim == 1, win.shape
    signal = signal.contiguous()
    psd = signal.as_strided(  # shape (..., o, m)
        (
            *signal.shape[:-1],
            (signal.shape[-1] - len(win)) // shift + 1,  # number of slices
            len(win),
        ),
        (*signal.stride()[:-1], shift, 1),
    )
    psd = psd * win  # not inplace because blocs was not contiguous
    psd = torch.fft.rfft(psd, norm="ortho", dim=-1)  # norm ortho for perceval theorem
    psd = psd.real**2 + psd.imag**2
    inv_win_mass = len(win) / torch.sum(win)
    if return_std:
        std, psd = torch.std_mean(psd, dim=-2)  # shape (..., m)
        std *= inv_win_mass
        psd *= inv_win_mass
        return std, psd
    psd = torch.mean(psd, dim=-2)  # shape (..., m)
    psd *= inv_win_mass
    return psd


def _find_len(s_x: float, sigma_max: float, psd_max: float, freq_res: float) -> float:
    """Search by bisection the window length at which the welch criterion changes sign.

    Parameters
    ----------
    s_x : float
        The number of samples in the signal.
    sigma_max : float
        The maximum of the estimated std of the psd.
    psd_max : float
        The maximum of the estimated psd.
    freq_res : float
        The requested frequency resolution, for a sample rate of 1.

    Returns
    -------
    s_w : float
        The last midpoint tested in the bracket [3, 65536]; it is not rounded.
    """
    lower, upper = 3.0, 65536.0
    # 16 halvings shrink the initial bracket (width 65533) to about one sample.
    for _ in range(16):
        middle = 0.5*(lower + upper)
        # Relative level of the window side lobes that matches the psd noise.
        # NOTE(review): psd_max == 0 raises ZeroDivisionError and gamma == 0 makes
        # math.log10 raise ValueError; callers are presumably expected to avoid both.
        gamma = sigma_max * math.sqrt(middle/s_x) / psd_max
        att = max(20.0, -20.0*math.log10(gamma))  # 20 is a minimal acceptable attenuation
        half_band = alpha_to_band(att_to_alpha(att))
        if middle*freq_res - 2.0 * half_band - 1.0 <= 0:
            lower = middle  # criterion not met: the window has to be longer
        else:
            upper = middle  # criterion met: try a shorter window
    return middle


def welch(signal: torch.Tensor, freq_res: typing.Optional[numbers.Real] = None) -> torch.Tensor:
    r"""Estimate the power spectral density (PSD) with the Welch method.

    Terminology
    -----------
    * :math:`s_x` is the number of samples in the signal.
    * :math:`s_w` is the number of samples in the dpss window.
    * :math:`r` is the frequency resolution in Hz, for a sample rate of 1.
    * :math:`\sigma(f)` is an estimation of the std of the psd.
    * :math:`psd(f)` is the power spectral density.
    * :math:`n_{psd}(f)` is the psd noise, ignoring gibbs effects.
    * :math:`n_{win}` is the additive noise created by the gibbs effect of the window.
    * :math:`\gamma` is the maximum amplitude of the largest of the window's secondary lobes.
    * :math:`\alpha` is the window theoretical standardized half bandwidth.
    * :math:`\beta` is the window experimental standardized half bandwidth.

    Equations
    ---------
    Here are the equations used to find the optimal window size:

    .. math::

        \begin{cases}
            n_{psd}(f) = \sigma(f) . \sqrt{\frac{s_w}{s_x}} & \text{std of mean estimator} \\
            n_{win} = \max\left(psd(f)\right) . \gamma & \text{because convolution}\\
            \gamma = g(\alpha) \\
            \beta = \alpha + \epsilon = h(\alpha) \\
            r = \frac{1}{s_w} + 2 . \frac{\beta}{s_w} & \text{convolution main lobe} \\
        \end{cases}

    To avoid having a too big gibbs noise, we want :math:`n_{win} < n_{psd}(f)`.

    .. math::

        \Leftrightarrow \begin{cases}
            psd_{max} . \gamma < \sigma_{max} . \sqrt{\frac{s_w}{s_x}} \\
            r . s_w = 1 + 2.h(g^{-1}(\gamma)) \\
        \end{cases}
        \Leftrightarrow
        s_w . r - 2 . h\left(
            g^{-1}\left(\frac{\sigma_{max} . \sqrt{\frac{s_w}{s_x}}}{psd_{max}}\right)
        \right) - 1 > 0

    Parameters
    ----------
    signal : torch.Tensor
        The stationary signal on which we evaluate the PSD.
        The tensor can be batched, so the shape is (..., n), with n >= 32.
    freq_res : float, optional
        The normalised frequency resolution in Hz (:math:`r \in \left]0, \frac{1}{2}\right[`),
        for a sample rate of 1. According to the equations above, it is the width of the
        window main lobe, so a smaller value should give a finer frequency resolution
        but a noisier estimation (NOTE(review): deduced from the equations, to be confirmed).
        If not provided, :math:`\max\left(\frac{1}{w}, \frac{10}{48000}\right)` is used,
        with :math:`w = \min(4096, n // 3)` the length of the first pass window.

    Returns
    -------
    psd : torch.Tensor
        An estimation of the power spectral density, of shape (..., m).

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.signal.psd import welch
    >>> signal = torch.randn((16, 2, 768000))  # 16 stereos signals
    >>> psd = welch(signal, 1/4800)
    >>>
    >>> # freq = torch.fft.rfftfreq(2*psd.shape[-1]-1, 1/48000)
    >>> # import matplotlib.pyplot as plt
    >>> # _ = plt.plot(freq, psd[0].T)
    >>> # plt.show()
    >>>
    """
    assert isinstance(signal, torch.Tensor), signal.__class__.__name__
    assert signal.shape[-1] >= 32, "the signal is to short, please provide more samples"

    # Get a frequency resolution.
    # The arguments are checked before any heavy computation, so a bad call fails fast.
    win_len = min(4096, signal.shape[-1]//3)
    if freq_res is None:
        freq_res = max(1.0 / win_len, 10.0 / 48000)
    else:
        assert isinstance(freq_res, numbers.Real), freq_res.__class__.__name__
        assert 0 < freq_res < 0.5, freq_res
        freq_res = float(freq_res)

    # The first step consists in having a fast and inaccurate estimation of the psd.
    win = torch.hann_window(win_len, periodic=False, dtype=signal.dtype, device=signal.device)
    std, psd = _compute_psd(signal, win, win_len//2, return_std=True)

    # The second step consists in finding the best window size,
    # according to the inequality specified in the docstring.
    win_len = math.ceil(_find_len(signal.shape[-1], float(std.max()), float(psd.max()), freq_res))
    win_len = min(win_len, signal.shape[-1])

    # The third step consists in deducing alpha,
    # according to the frequency accuracy expression in the docstring.
    alpha = band_to_alpha(0.5 * (freq_res * win_len - 1))

    # Compute the psd with the final window.
    # The window is moved to the device of the signal (no-op if it is already there),
    # so that the product in _compute_psd does not mix devices.
    win = dpss(win_len, alpha, dtype=signal.dtype).to(signal.device)
    psd = _compute_psd(signal, win, win_len//4, return_std=False)  # //4 is an arbitrary overlap
    return psd