Source code for cutcutcodec.core.signal.psd

#!/usr/bin/env python3

"""Tools for the Power Spectral Density (PSD) estimation."""

from fractions import Fraction
import numbers

import torch

from .window import dpss


def welch(signal: torch.Tensor, freq_resol: numbers.Real = None) -> torch.Tensor:
    """Estimate the power spectral density (PSD) with the Welch method.

    The signal is cut into overlapping segments along the last dimension,
    each segment is weighted by a DPSS window, and the squared magnitudes
    of the segment spectra are averaged.

    Parameters
    ----------
    signal : torch.Tensor
        The stationary signal on which we evaluate the PSD.
        The tensor can be batched, so the shape is (..., n).
    freq_resol : numbers.Real, optional
        The normalised frequency resolution in Hz, assuming a sample rate of 1.
        It is the lowest denoised frequency as well.
        It must be at least ``1/n``. If not provided, ``2/n`` is used.
        The analysis window is about ``1/freq_resol`` samples long,
        bounded by 2048 samples and by the length of the signal.
        A smaller value gives a longer window, so a finer frequency resolution,
        but fewer segments are averaged, so the estimation is noisier.

    Returns
    -------
    psd : torch.Tensor
        An estimation of the power spectral density, of shape (..., m),
        where ``m = win_len // 2 + 1`` is the number of one-sided frequency bins.

    Examples
    --------
    >>> import torch
    >>> from cutcutcodec.core.signal.psd import welch
    >>> signal = torch.randn((32, 2, 768000))
    >>> psd = welch(signal)
    >>>
    >>> # freq = torch.fft.rfftfreq(2*(psd.shape[-1]-1), 1/48000)
    >>> # import matplotlib.pyplot as plt
    >>> # _ = plt.plot(freq, psd[0].T)
    >>> # plt.show()
    >>>
    """
    assert isinstance(signal, torch.Tensor), signal.__class__.__name__
    nb_samples = signal.shape[-1]
    freq_resol_min = Fraction(1, nb_samples)  # theoretical minimum frequency resolution
    if freq_resol is not None:
        assert isinstance(freq_resol, numbers.Real), freq_resol.__class__.__name__
        assert freq_resol >= freq_resol_min, f"you must provide freq_resol >= {freq_resol_min}"
    else:
        freq_resol = 2 * freq_resol_min

    # An even window of about 1/freq_resol samples, never longer than 2048.
    win_len = min(2048, 2 * round(0.5 / freq_resol))
    # The rounding above can give 0 (freq_resol >= 1) or n + 1 (odd n with freq_resol == 1/n).
    # Keep at least 2 samples when the signal allows it, and never more than the signal,
    # otherwise there would be no segment to average and the result would be NaN.
    win_len = min(max(win_len, 2), nb_samples)
    # NOTE(review): after clamping, win_len can be odd or equal to 1 for tiny signals,
    # this assumes dpss accepts such lengths -- TODO confirm.
    win = dpss(win_len, 5.0, dtype=signal.dtype)
    # 75% overlap between consecutive segments; at least 1 so that short windows
    # (win_len < 4) do not lead to a zero hop and a division by zero below.
    shift = max(1, win_len // 4)

    # The strided view below relies on a unit stride along the last dimension.
    signal = signal.contiguous()
    nb_slices = (nb_samples - win_len) // shift + 1  # trailing incomplete segment is ignored
    psd = signal.as_strided(  # overlapping segments without copy, shape (..., o, win_len)
        (*signal.shape[:-1], nb_slices, win_len),
        (*signal.stride()[:-1], shift, 1),
    )
    psd = psd * win  # not inplace because the segments share memory
    psd = torch.fft.rfft(psd, norm="ortho", dim=-1)
    psd = psd.real**2 + psd.imag**2  # squared magnitude of each segment spectrum
    psd = torch.mean(psd, dim=-2)  # average over the segments, shape (..., m)
    # NOTE(review): normalised by the sum of the window taps; Welch estimators are usually
    # normalised by the sum of the squared taps -- confirm against the scaling used by dpss.
    psd /= torch.sum(win)
    return psd