Source code for cutcutcodec.core.signal.psd

#!/usr/bin/env python3

"""Tools for the Power Spectral Density (PSD) estimation."""

import numbers
import typing
import math

import torch

from .window import alpha_to_band, att_to_alpha, band_to_alpha, dpss


def _compute_psd(
    signal_1: torch.Tensor, signal_2: torch.Tensor, win: torch.Tensor, shift: int, return_std: bool
) -> torch.Tensor:
    """Help welch."""
    assert win.ndim == 1, win.shape
    assert signal_1.shape[-1] >= len(win), "signal to short or window to tall"
    is_autocorr = signal_1 is signal_2  # optimisation to avoid redondant operations
    signal_1 = signal_1.contiguous()
    signal_1 = signal_1.as_strided(  # shape (..., o, m), big ram usage!
        (
            *signal_1.shape[:-1],
            (signal_1.shape[-1] - len(win)) // shift + 1,  # number of slices
            len(win),
        ),
        (*signal_1.stride()[:-1], shift, 1),
    )
    signal_1 = signal_1 * win  # not inplace because blocs was not contiguous
    signal_1 = torch.fft.rfft(signal_1, norm="ortho", dim=-1)  # norm ortho for perceval theorem
    if not is_autocorr:
        signal_2 = signal_2.contiguous()
        signal_2 = signal_2.as_strided(  # shape (..., o, m), big ram usage!
            (
                *signal_2.shape[:-1],
                (signal_2.shape[-1] - len(win)) // shift + 1,  # number of slices
                len(win),
            ),
            (*signal_2.stride()[:-1], shift, 1),
        )
        signal_2 = signal_2 * win
        signal_2 = torch.fft.rfft(signal_2, norm="ortho", dim=-1)
        psd = signal_1 * signal_2.conj()
    else:
        psd = signal_1.real*signal_1.real + signal_1.imag*signal_1.imag
    win_power = (win**2).mean()
    if return_std:
        std, psd = torch.std_mean(psd, dim=-2)  # shape (..., m)
        std /= win_power
        psd /= win_power
        return abs(std), psd
    psd = torch.mean(psd, dim=-2)  # shape (..., m)
    psd /= win_power
    return psd


def _find_len(s_x: float, sigma_max: float, psd_max: float, freq_res: float) -> float:
    s_w_min, s_w_max = 3.0, 65536.0
    for _ in range(16):  # dichotomy, resol = 2**-n
        s_w = 0.5*(s_w_min + s_w_max)
        eta = sigma_max * math.sqrt(s_w/s_x) / psd_max
        att = max(20.0, -20.0*math.log10(eta))  # 20 is a minimal acceptable attenuation
        value = s_w*freq_res - 2.0 * alpha_to_band(att_to_alpha(att)) - 1.0
        if value <= 0:
            s_w_min = s_w
        else:
            s_w_max = s_w
    return s_w


[docs] def welch( signal_1: torch.Tensor, signal_2: typing.Optional[torch.Tensor] = None, freq_res: typing.Optional[numbers.Real] = None, ) -> torch.Tensor: r"""Estimate the power spectral density (PSD) ie intercorrelation with the Welch method. Terminology ----------- * :math:`x_1` and :math:`x_2` are the signal whose we want to estimate the intercorrelation. * :math:`s_x` is the number of samples in the signal. * :math:`s_w` is the number of samples in the dpss window. * :math:`r` is the normalize frequency resolution in Hz, for a sample rate of 1. * :math:`\Sigma(f)` is an estimation of the standard deviation of the psd. * :math:`\Gamma_{x_1,x_2}(f) = psd(f)` is the power spectral density or the intercorrelation. * :math:`n_{psd}(f)` is the psd noise, ignoring gibbs effects. * :math:`n_{win}` is the additive noise created by the gibbs effect of the window. * :math:`\eta` is the maximum amplitude of the largest of the window's secondary lobs. * :math:`\alpha` is the window theorical standardized half bandwidth. * :math:`\beta` is the window experimental standardized half bandwidth. Equations --------- There, the equation to find the optimal windows size: .. math:: \begin{cases} n_{psd}(f) = \Sigma(f) . \sqrt{\frac{s_w}{s_x}} & \text{std of mean estimator} \\ n_{win} = \max\left(\Gamma_{x_1,x_2}(f)\right) . \eta & \text{because convolution}\\ \eta = g(\alpha) \\ \beta = \alpha + \epsilon = h(\alpha) \\ r = \frac{1}{s_w} + 2 . \frac{\beta}{s_w} & \text{convolution main lob} \\ \end{cases} To avoid having a too big gibbs noise, we want :math:`n_{win} < n_{psd}(f)`. .. math:: \Leftrightarrow \begin{cases} psd_{max} . \eta < \Sigma_{max} . \sqrt{\frac{s_w}{s_x}} \\ r . s_w = 1 + 2.h(g^{-1}(\eta)) \\ \end{cases} \Leftrightarrow s_w . r - 2 . h\left( g^{-1}\left(\frac{\Sigma_{max} . \sqrt{\frac{s_w}{s_x}}}{psd_{max}}\right) \right) - 1 > 0 Parameters ---------- signal_1 : torch.Tensor The stationary signal $x_1$ on witch we evaluate the PSD. The tensor can be batched, so the shape is (..., n). signal_2 : torch.Tensor, default=signal_1 If you prefer to compute an intercorrelation rather an autocorrelation, please provide it (:math:`x_2`), otherwise, :math:`x_2 = x_1`. freq_res : float The normlised frequency resolution in Hz :math:`\left(r \in \left]0, \frac{1}{2}\right[\right)`, for a sample rate of 1. Higher it is, better is the frequency resolution but noiser it is. Returns ------- psd : torch.Tensor An estimation of the power spectral density :math:`\Gamma_{x_1,x_2}(f)`, of shape (..., m). Ie an estimation of :math:`\mathbb{E}\left[X_1(f)X_2^x(f)\right]`. In the case of autocorrelation, psd is returned as float type, overwise (in case of intercorrelation), it is returned as complex. Notes ----- The complexity is :math:`O\left(\frac{s_x}{r}\right)`. Examples -------- >>> import torch >>> from cutcutcodec.core.signal.psd import welch >>> sr, delta_t = 8000, 60 # sample rate in (Hz) and duration in (s) >>> t = torch.arange(0, delta_t, 1/sr) # timae smaple in (s) >>> signal = torch.randn((16, 2, len(t))) + torch.sin(2*torch.pi*440*t) # 16 stereo signals >>> psd = welch(signal, freq_res=20.0/sr) # band of 20 Hz >>> >>> # we should observe a background at 1 and a pic a 440 Hz. >>> # import matplotlib.pyplot as plt >>> # freq = torch.fft.rfftfreq(2*psd.shape[-1]-1, 1/sr) >>> # _ = plt.plot(freq, psd[0].T) >>> # _ = plt.xlabel("freq (Hz)") >>> # _ = plt.ylabel("psd") >>> # plt.show() >>> """ assert isinstance(signal_1, torch.Tensor), signal_1.__class__.__name__ assert signal_1.ndim >= 1 assert signal_1.shape[-1] >= 32, "the signal is to short, please provide more samples" if signal_2 is not None: assert isinstance(signal_2, torch.Tensor), signal_2.__class__.__name__ assert signal_2.ndim >= 1 assert signal_2.shape[-1] == signal_1.shape[-1], (signal_2.shape, signal_1.shape) signal_2 = signal_2.to(dtype=signal_1.dtype, device=signal_1.device) else: signal_2 = signal_1 # The first step consists in having a fast and inaccurate estimation of the psd. win_len = min(4096, signal_1.shape[-1]//3) win = torch.hann_window(win_len, periodic=False, dtype=signal_1.dtype, device=signal_1.device) std, psd = _compute_psd(signal_1, signal_2, win, win_len//2, return_std=True) # Get a frequency resolution. if freq_res is None: freq_res = max(1.0/win_len, 10.0/48000) # simple heuristic else: assert isinstance(freq_res, numbers.Real), freq_res.__class__.__name__ assert 0 < freq_res < 0.5, freq_res freq_res = float(freq_res) # The second step constists in finding the best windows size. # According to the inequality specified in the docstring. win_len = math.ceil( _find_len(signal_1.shape[-1], float(std.max()), float(abs(psd).max()), freq_res) ) win_len = min(win_len, signal_1.shape[-1]) # The fird step consists in deducing alpha # according the frequency accuracy expression in the doctring. alpha = band_to_alpha(0.5 * (freq_res * win_len - 1)) # compute the psd win = dpss(win_len, alpha, dtype=signal_1.dtype) psd = _compute_psd(signal_1, signal_2, win, win_len//4, return_std=False) # //4 for overlapp return psd