Source code for tme.preprocessing.frequency_filters

""" Defines Fourier frequency filters.

    Copyright (c) 2024 European Molecular Biology Laboratory

    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
from math import log, sqrt
from typing import Tuple, Dict

import numpy as np
from scipy.ndimage import mean as ndimean
from scipy.ndimage import map_coordinates

from ..types import BackendArray
from ..backends import backend as be
from ._utils import fftfreqn, crop_real_fourier, shift_fourier, compute_fourier_shape


class BandPassFilter:
    """
    Generate bandpass filters in Fourier space.

    Parameters
    ----------
    lowpass : float, optional
        The lowpass cutoff, defaults to None.
    highpass : float, optional
        The highpass cutoff, defaults to None.
    sampling_rate : Tuple[float], optional
        The sampling rate in Fourier space, defaults to 1.
    use_gaussian : bool, optional
        Whether to use Gaussian bandpass filter, defaults to True.
    return_real_fourier : bool, optional
        Whether to return only the real Fourier space, defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape represents the real Fourier space, defaults to False.
    """

    def __init__(
        self,
        lowpass: float = None,
        highpass: float = None,
        sampling_rate: Tuple[float] = 1,
        use_gaussian: bool = True,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
    ):
        self.lowpass = lowpass
        self.highpass = highpass
        self.use_gaussian = use_gaussian
        self.return_real_fourier = return_real_fourier
        self.shape_is_real_fourier = shape_is_real_fourier
        self.sampling_rate = sampling_rate

    @staticmethod
    def discrete_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: Tuple[float],
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> BackendArray:
        """
        Generate a bandpass filter using discrete frequency cutoffs.

        Parameters
        ----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate. None disables it.
        highpass : float
            The highpass cutoff in units of sampling rate. None disables it.
        sampling_rate : float
            The sampling rate in Fourier space.
        return_real_fourier : bool, optional
            Whether to return only the real Fourier space, defaults to False.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments (ignored).

        Returns
        -------
        BackendArray
            The binary (0/1) bandpass filter in Fourier space.
        """
        # A shape that is already a half spectrum leaves nothing to crop.
        if shape_is_real_fourier:
            return_real_fourier = False

        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = be.to_backend_array(grid)
        sampling_rate = be.to_backend_array(sampling_rate)

        # Cutoffs are converted to the grid's frequency units as
        # 2 * sampling_rate / cutoff; presumably fftfreqn(sampling_rate=0.5)
        # normalizes Nyquist to 1 -- confirm against fftfreqn.
        highcut = grid.max()
        if lowpass is not None:
            highcut = be.max(2 * sampling_rate / lowpass)

        lowcut = 0
        if highpass is not None:
            lowcut = be.max(2 * sampling_rate / highpass)

        bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
        bandpass_filter = shift_fourier(
            data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
        )

        if return_real_fourier:
            bandpass_filter = crop_real_fourier(bandpass_filter)

        return bandpass_filter

    @staticmethod
    def gaussian_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: float,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> BackendArray:
        """
        Generate a bandpass filter using Gaussians.

        Parameters
        ----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate. None disables it.
        highpass : float
            The highpass cutoff in units of sampling rate. None disables it.
        sampling_rate : float
            The sampling rate in Fourier space.
        return_real_fourier : bool, optional
            Whether to return only the real Fourier space, defaults to False.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments (ignored).

        Returns
        -------
        BackendArray
            The bandpass filter in Fourier space.
        """
        # A shape that is already a half spectrum leaves nothing to crop.
        if shape_is_real_fourier:
            return_real_fourier = False

        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = be.to_backend_array(grid)
        grid = -be.square(grid)

        lowpass_filter, highpass_filter = 1, 1
        # Scale factor relating the cutoff to the Gaussian width
        # (sqrt(2 ln 2), the half-width-at-half-maximum factor).
        norm = float(sqrt(2 * log(2)))
        upper_sampling = float(
            be.max(be.multiply(2, be.to_backend_array(sampling_rate)))
        )

        # Clamp cutoffs away from zero to avoid division by zero below.
        if lowpass is not None:
            lowpass = float(lowpass)
            lowpass = be.maximum(lowpass, be.eps(be._float_dtype))
        if highpass is not None:
            highpass = float(highpass)
            highpass = be.maximum(highpass, be.eps(be._float_dtype))

        if lowpass is not None:
            lowpass = upper_sampling / (lowpass * norm)
            lowpass = be.multiply(2, be.square(lowpass))
            lowpass_filter = be.exp(be.divide(grid, lowpass))
        if highpass is not None:
            highpass = upper_sampling / (highpass * norm)
            highpass = be.multiply(2, be.square(highpass))
            highpass_filter = 1 - be.exp(be.divide(grid, highpass))

        bandpass_filter = be.multiply(lowpass_filter, highpass_filter)
        bandpass_filter = shift_fourier(
            data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
        )

        if return_real_fourier:
            bandpass_filter = crop_real_fourier(bandpass_filter)

        return bandpass_filter

    def __call__(self, **kwargs):
        """
        Build the bandpass filter from the stored parameters.

        Parameters
        ----------
        **kwargs : dict
            Values overriding the instance attributes for this call only.
            ``shape`` must be supplied here because neither filter
            generator has a default for it.

        Returns
        -------
        dict
            ``data`` (the filter as a backend array), ``sampling_rate`` and
            ``is_multiplicative_filter`` (always True).
        """
        # Merge into a fresh dict: ``vars(self)`` *is* the instance __dict__,
        # so updating it in place would permanently overwrite the instance
        # attributes with per-call arguments.
        func_args = {**vars(self), **kwargs}

        if func_args.get("use_gaussian"):
            func = self.gaussian_bandpass
        else:
            func = self.discrete_bandpass

        mask = func(**func_args)

        return {
            "data": be.to_backend_array(mask),
            "sampling_rate": func_args.get("sampling_rate", 1),
            "is_multiplicative_filter": True,
        }
class LinearWhiteningFilter:
    """
    Compute Fourier power spectrums and perform whitening.

    Parameters
    ----------
    **kwargs : Dict, optional
        Additional keyword arguments (ignored).

    References
    ----------
    .. [1] de Teresa-Trueba, I.; Goetz, S. K.; Mattausch, A.; Stojanovska, F.;
       Zimmerli, C. E.; Toro-Nahuelpan, M.; Cheng, D. W. C.; Tollervey, F.;
       Pape, C.; Beck, M.; Diz-Munoz, A.; Kreshuk, A.; Mahamid, J.;
       Zaugg, J. B. Nat. Methods 2023, 20, 284–294.
    .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
       R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
       13375 (2023)
    """

    def __init__(self, **kwargs):
        pass

    @staticmethod
    def _compute_spectrum(
        data_rfft: BackendArray, n_bins: int = None, batch_dimension: int = None
    ) -> Tuple[BackendArray, BackendArray]:
        """
        Compute the normalized inverse radial amplitude spectrum of the data.

        Parameters
        ----------
        data_rfft : BackendArray
            The real Fourier transform of the input data.
        n_bins : int, optional
            The number of bins for computing the spectrum, defaults to None,
            i.e. the maximal number of bins. Values above that are clipped.
        batch_dimension : int, optional
            Batch dimension to average over.

        Returns
        -------
        bins : BackendArray
            Array containing the bin indices for the spectrum (without the
            batch dimension).
        radial_averages : BackendArray
            Reciprocal of the radially averaged amplitude, scaled so its
            maximum is 1.
        """
        shape = tuple(x for i, x in enumerate(data_rfft.shape) if i != batch_dimension)
        # default= keeps 1D input (no non-halved axes) from raising in max().
        max_bins = max(max(shape[:-1], default=0) // 2 + 1, shape[-1])
        n_bins = max_bins if n_bins is None else n_bins
        n_bins = int(min(n_bins, max_bins))

        bins = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=True,
            compute_euclidean_norm=True,
        )
        bins = be.to_numpy_array(bins)

        # Round each frequency to its bin. Labels >= n_bins are excluded by
        # ``index`` below, which acts as an implicit lowpass (presumably at
        # Nyquist, assuming fftfreqn normalizes Nyquist to 1 -- confirm).
        bins = np.floor(bins * (n_bins - 1) + 0.5).astype(int)

        # Shift every axis except the half-spectrum (last) and batch axes so
        # the data layout matches the bin grid.
        fft_shift_axes = tuple(
            i for i in range(data_rfft.ndim - 1) if i != batch_dimension
        )
        fourier_spectrum = np.fft.fftshift(data_rfft, axes=fft_shift_axes)
        fourier_spectrum = np.abs(fourier_spectrum)
        np.square(fourier_spectrum, out=fourier_spectrum)

        # Insert a singleton batch axis so the labels broadcast against the
        # data wherever the batch axis sits, averaging over the batch.
        labels = bins
        if batch_dimension is not None:
            labels = np.expand_dims(bins, batch_dimension)

        radial_averages = ndimean(
            fourier_spectrum, labels=labels, index=np.arange(n_bins)
        )
        # NOTE(review): a bin with zero power makes the reciprocal inf and the
        # normalization below NaN; behavior kept as-is.
        np.sqrt(radial_averages, out=radial_averages)
        np.reciprocal(radial_averages, out=radial_averages)
        np.divide(radial_averages, radial_averages.max(), out=radial_averages)

        return bins, radial_averages

    @staticmethod
    def _interpolate_spectrum(
        spectrum: BackendArray,
        shape: Tuple[int],
        shape_is_real_fourier: bool = True,
        order: int = 1,
    ) -> BackendArray:
        """
        Interpolate a 1D radial spectrum onto a Fourier grid of ``shape``.

        Parameters
        ----------
        spectrum : BackendArray
            Radial spectrum, one value per bin.
        shape : tuple of int
            Shape of the output grid.
        shape_is_real_fourier : bool, optional
            Whether ``shape`` is a real Fourier (half spectrum) shape.
        order : int, optional
            Spline order passed to ``scipy.ndimage.map_coordinates``.

        Returns
        -------
        BackendArray
            The interpolated spectrum with the shape of the frequency grid.

        References
        ----------
        .. [1] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
           R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
           13375 (2023)
        """
        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = be.to_numpy_array(grid)

        # Map frequencies to fractional bin positions. _compute_spectrum
        # rounds f * (n - 1) to label bins, so bin k is centred on
        # f * (n - 1) == k and no +0.5 offset belongs here.
        np.multiply(grid, spectrum.shape[0] - 1, out=grid)
        spectrum = map_coordinates(spectrum, grid.reshape(1, -1), order=order)
        return spectrum.reshape(grid.shape)

    def __call__(
        self,
        data: BackendArray = None,
        data_rfft: BackendArray = None,
        n_bins: int = None,
        batch_dimension: int = None,
        order: int = 1,
        **kwargs: Dict,
    ) -> Dict:
        """
        Apply linear whitening to the data and return the result.

        Parameters
        ----------
        data : BackendArray, optional
            The input data, defaults to None. Only used if ``data_rfft`` is
            None.
        data_rfft : BackendArray, optional
            The Fourier transform of the input data, defaults to None.
        n_bins : int, optional
            The number of bins for computing the spectrum, defaults to None.
        batch_dimension : int, optional
            Batch dimension to average over.
        order : int, optional
            Interpolation order to use. None skips interpolation and looks the
            spectrum up per bin on the data's own grid.
        **kwargs : Dict
            Additional keyword arguments. If ``shape`` (and optionally
            ``shape_is_real_fourier``) is given and ``order`` is not None, the
            filter is built for that shape instead of the data's.

        Returns
        -------
        Dict
            Filter data and associated parameters.
        """
        if data_rfft is None:
            data = be.to_numpy_array(data)
            # Do not transform along the batch axis.
            fft_axes = tuple(i for i in range(data.ndim) if i != batch_dimension)
            data_rfft = np.fft.rfftn(data, axes=fft_axes)

        data_rfft = be.to_numpy_array(data_rfft)
        bins, radial_averages = self._compute_spectrum(
            data_rfft, n_bins, batch_dimension
        )

        if order is None:
            # Direct per-bin lookup; bins beyond the spectrum stay zero.
            cutoff = bins < radial_averages.size
            filter_mask = np.zeros(bins.shape, radial_averages.dtype)
            filter_mask[cutoff] = radial_averages[bins[cutoff]]
        else:
            shape = bins.shape
            if kwargs.get("shape", False):
                shape = compute_fourier_shape(
                    shape=kwargs.get("shape"),
                    shape_is_real_fourier=kwargs.get("shape_is_real_fourier", False),
                )
            filter_mask = self._interpolate_spectrum(
                spectrum=radial_averages,
                shape=shape,
                shape_is_real_fourier=True,
                order=order,
            )

        # The filter carries no batch axis, so undo the centring on all of its
        # axes except the half-spectrum (last) one.
        filter_mask = np.fft.ifftshift(
            filter_mask, axes=tuple(range(filter_mask.ndim - 1))
        )

        return {
            "data": be.to_backend_array(filter_mask),
            "is_multiplicative_filter": True,
        }