# Source code for tme.filters.whitening

"""
Implements class LinearWhiteningFilter

Copyright (c) 2024 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from typing import Tuple, Dict
from dataclasses import dataclass

import numpy as np
from scipy.ndimage import mean as ndimean
from scipy.ndimage import map_coordinates

from ._utils import fftfreqn
from ..types import BackendArray
from ..analyzer.peaks import batchify
from ..backends import backend as be
from .compose import ComposableFilter


# Names exported by ``from <module> import *``.
__all__ = ["LinearWhiteningFilter"]


@dataclass
class LinearWhiteningFilter(ComposableFilter):
    """
    Generate Fourier whitening filters.

    The filter is built from the inverse of the radially averaged amplitude
    spectrum of the input, normalized so that its largest value is one.

    References
    ----------
    .. [1] de Teresa-Trueba, I.; Goetz, S. K.; Mattausch, A.; Stojanovska, F.;
        Zimmerli, C. E.; Toro-Nahuelpan, M.; Cheng, D. W. C.; Tollervey, F.;
        Pape, C.; Beck, M.; Diz-Munoz, A.; Kreshuk, A.; Mahamid, J.;
        Zaugg, J. B. Nat. Methods 2023, 20, 284–294.
    .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
        R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24, 13375 (2023)
    """

    @staticmethod
    def _compute_spectrum(
        data_rfft: BackendArray, n_bins: int = None
    ) -> Tuple[BackendArray, BackendArray]:
        """
        Compute the normalized inverse radial amplitude spectrum of the input.

        Parameters
        ----------
        data_rfft : BackendArray
            The Fourier transform of the input data.
        n_bins : int, optional
            The number of bins for computing the spectrum, defaults to None,
            in which case the largest supported number of bins is used.
            Values above that maximum are clipped to it.

        Returns
        -------
        bins : BackendArray
            Array containing the bin index assigned to each Fourier
            coefficient.
        radial_averages : BackendArray
            One value per bin: the reciprocal of the root of the mean squared
            amplitude in that bin, divided by the largest such reciprocal.
            Bins whose average is zero or not finite (e.g. bins that received
            no coefficient) are set to zero.
        """
        shape = data_rfft.shape

        # Equivalent to max(max(shape[:-1]) // 2 + 1, shape[-1]) but also
        # defined for one-dimensional input, where shape[:-1] is empty.
        max_bins = max([size // 2 + 1 for size in shape[:-1]] + [shape[-1]])
        n_bins = max_bins if n_bins is None else n_bins
        n_bins = int(min(n_bins, max_bins))

        # NOTE(review): fftfreqn is assumed to return the per-element norm of
        # the frequency grid; confirm against its definition.
        bins = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=True,
            compute_euclidean_norm=True,
            fftshift=False,
        )
        bins = be.to_numpy_array(bins)
        # Round to the nearest integer bin index.
        bins = np.floor(bins * (n_bins - 1) + 0.5).astype(int)

        # Power spectrum, squared in place to avoid an extra allocation.
        fourier_spectrum = np.abs(data_rfft)
        fourier_spectrum = np.square(fourier_spectrum, out=fourier_spectrum)

        # Labels outside [0, n_bins) are ignored; a bin without any member
        # yields NaN, which is handled explicitly below.
        with np.errstate(invalid="ignore", divide="ignore"):
            radial_averages = ndimean(
                fourier_spectrum, labels=bins, index=np.arange(n_bins)
            )
        radial_averages = np.sqrt(radial_averages, out=radial_averages)

        # Invert only well-defined bins. This avoids the divide-by-zero
        # warning of an eager 1 / x and keeps a single NaN bin from turning
        # the whole filter into NaN through the normalization step.
        invertible = np.isfinite(radial_averages) & (radial_averages != 0)
        inverse = np.zeros_like(radial_averages)
        np.divide(1, radial_averages, out=inverse, where=invertible)

        norm_factor = inverse.max()
        if norm_factor != 0:
            inverse = np.divide(inverse, norm_factor)

        return bins, inverse

    @staticmethod
    def _interpolate_spectrum(
        spectrum: BackendArray,
        shape: Tuple[int, ...],
        shape_is_real_fourier: bool = True,
        order: int = 1,
    ) -> BackendArray:
        """
        Expand a one-dimensional radial spectrum onto a Fourier grid.

        Parameters
        ----------
        spectrum : BackendArray
            One-dimensional array of per-bin values.
        shape : tuple of ints
            Shape passed to the frequency grid computation and shape of the
            returned array.
        shape_is_real_fourier : bool, optional
            Forwarded to the frequency grid computation, defaults to True.
        order : int, optional
            Spline order forwarded to
            :py:func:`scipy.ndimage.map_coordinates`, defaults to 1.

        Returns
        -------
        BackendArray
            Array with the shape of the frequency grid, holding the spectrum
            value looked up at each grid position.
        """
        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
            fftshift=False,
        )
        grid = be.to_numpy_array(grid)
        # Same rounding rule as used when the spectrum was binned.
        grid = np.floor(np.multiply(grid, spectrum.shape[0] - 1) + 0.5)

        # map_coordinates uses mode="constant" with cval=0.0 by default, so
        # positions beyond the last bin evaluate to zero.
        values = map_coordinates(spectrum, grid.reshape(1, -1), order=order)
        return values.reshape(grid.shape)

    def _evaluate(
        self,
        shape: Tuple[int, ...],
        data_rfft: BackendArray,
        axes: Tuple[int] = (),
        order: int = 1,
        **kwargs: Dict,
    ) -> Dict:
        """
        Compute a whitening filter from the spectrum of the input data.

        Parameters
        ----------
        shape : tuple of ints
            Shape of the returned whitening filter.
        data_rfft : BackendArray
            The Fourier transform of the input data.
        axes : int or tuple of ints, optional
            Axes to compute spectrum for independently. None is treated like
            an empty tuple.
        order : int, optional
            Interpolation order used when expanding the radial spectrum onto
            the output grid, defaults to 1.
        **kwargs : Dict
            Additional keyword arguments.

        Returns
        -------
        Dict
            Dictionary with the filter under ``data`` and the requested shape
            under ``shape``.
        """
        if axes is None:
            axes = ()
        elif isinstance(axes, (int, np.integer)):
            axes = (int(axes),)

        data_rfft = be.to_numpy_array(data_rfft)

        stack = []
        for subset, _ in batchify(data_rfft.shape, axes):
            _, radial_avg = self._compute_spectrum(np.squeeze(data_rfft[subset]))
            stack.append(
                self._interpolate_spectrum(
                    spectrum=radial_avg,
                    shape=shape,
                    shape_is_real_fourier=False,
                    order=order,
                )
            )

        ret = np.array(stack)
        if not len(axes):
            # NOTE(review): np.squeeze drops every length-one axis, not only
            # the leading stacking axis; confirm that is intended when shape
            # itself contains singleton dimensions.
            ret = np.squeeze(ret)

        return {"data": be.to_backend_array(ret), "shape": shape}