""" Implements class BandPassFilter to create Fourier filter representations.
Copyright (c) 2024 European Molecular Biology Laboratory
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
from typing import Tuple
from math import log, sqrt
from ..types import BackendArray
from ..backends import backend as be
from .compose import ComposableFilter
from ._utils import fftfreqn, crop_real_fourier, shift_fourier
__all__ = ["BandPassFilter"]
class BandPassFilter(ComposableFilter):
    """
    Generate bandpass filters in Fourier space.

    Parameters
    ----------
    lowpass : float, optional
        The lowpass cutoff, defaults to None (no lowpass).
    highpass : float, optional
        The highpass cutoff, defaults to None (no highpass).
    sampling_rate : Tuple[float], optional
        The sampling rate used to convert the cutoffs into normalized
        frequencies (as ``2 * sampling_rate / cutoff``), defaults to 1.
    use_gaussian : bool, optional
        Whether to use Gaussian bandpass filter, defaults to True.
    return_real_fourier : bool, optional
        Whether to return only the real Fourier space, defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape represents the real Fourier space, defaults to False.
    """

    def __init__(
        self,
        lowpass: float = None,
        highpass: float = None,
        sampling_rate: Tuple[float] = 1,
        use_gaussian: bool = True,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
    ):
        self.lowpass = lowpass
        self.highpass = highpass
        self.use_gaussian = use_gaussian
        self.return_real_fourier = return_real_fourier
        self.shape_is_real_fourier = shape_is_real_fourier
        self.sampling_rate = sampling_rate

    @staticmethod
    def discrete_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: Tuple[float],
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> BackendArray:
        """
        Generate a bandpass filter using discrete frequency cutoffs.

        Frequencies whose grid norm lies within ``[lowcut, highcut]`` are set
        to 1, all others to 0.

        Parameters
        ----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate. None disables it.
        highpass : float
            The highpass cutoff in units of sampling rate. None disables it.
        sampling_rate : float
            The sampling rate; per-axis values are reduced by their maximum.
        return_real_fourier : bool, optional
            Whether to return only the real Fourier space, defaults to False.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments, ignored.

        Returns
        -------
        BackendArray
            The bandpass filter in Fourier space.
        """
        # A shape that already describes the real Fourier half must not be
        # cropped a second time.
        if shape_is_real_fourier:
            return_real_fourier = False

        # Euclidean frequency norm per voxel. sampling_rate=0.5 presumably
        # normalizes frequencies so that Nyquist maps to 1 -- TODO confirm
        # against fftfreqn.
        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = be.astype(be.to_backend_array(grid), be._float_dtype)
        sampling_rate = be.to_backend_array(sampling_rate)

        # Without cutoffs the band spans the whole grid, i.e. [0, max(grid)].
        highcut = be.max(grid)
        if lowpass is not None:
            # For anisotropic sampling the largest per-axis cutoff is used.
            highcut = be.max(2 * sampling_rate / lowpass)

        lowcut = 0
        if highpass is not None:
            lowcut = be.max(2 * sampling_rate / highpass)

        bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
        bandpass_filter = shift_fourier(
            data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
        )

        if return_real_fourier:
            bandpass_filter = crop_real_fourier(bandpass_filter)

        return bandpass_filter

    @staticmethod
    def gaussian_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: float,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> BackendArray:
        """
        Generate a bandpass filter using Gaussians.

        The lowpass is ``exp(-ln2 * f**2 / fc**2)`` and the highpass is
        ``1 - exp(-ln2 * f**2 / fc**2)``, with ``fc = 2 * sampling_rate /
        cutoff``. Both therefore equal 0.5 at their cutoff frequency. When
        both cutoffs are given, the two responses are multiplied.

        Parameters
        ----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate. None disables it.
        highpass : float
            The highpass cutoff in units of sampling rate. None disables it.
        sampling_rate : float
            The sampling rate; per-axis values are reduced by their maximum.
        return_real_fourier : bool, optional
            Whether to return only the real Fourier space, defaults to False.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments, ignored.

        Returns
        -------
        BackendArray
            The bandpass filter in Fourier space.
        """
        # A shape that already describes the real Fourier half must not be
        # cropped a second time.
        if shape_is_real_fourier:
            return_real_fourier = False

        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = be.astype(be.to_backend_array(grid), be._float_dtype)
        # grid now holds -f**2, the shared exponent numerator of both Gaussians.
        grid = -be.square(grid, out=grid)

        has_lowpass, has_highpass = False, False
        # sqrt(2 ln 2) makes each Gaussian reach exactly 0.5 at its cutoff.
        norm = float(sqrt(2 * log(2)))
        upper_sampling = float(
            be.max(be.multiply(2, be.to_backend_array(sampling_rate)))
        )
        # Clamp to machine epsilon so a zero cutoff does not divide by zero.
        if lowpass is not None:
            lowpass, has_lowpass = float(lowpass), True
            lowpass = be.maximum(lowpass, be.eps(be._float_dtype))
        if highpass is not None:
            highpass, has_highpass = float(highpass), True
            highpass = be.maximum(highpass, be.eps(be._float_dtype))

        if has_lowpass:
            lowpass = upper_sampling / (lowpass * norm)
            lowpass = be.multiply(2, be.square(lowpass))
            # grid may only be overwritten in place if the highpass below
            # does not still need it.
            if not has_highpass:
                lowpass_filter = be.divide(grid, lowpass, out=grid)
            else:
                lowpass_filter = be.divide(grid, lowpass)
            lowpass_filter = be.exp(lowpass_filter, out=lowpass_filter)

        if has_highpass:
            highpass = upper_sampling / (highpass * norm)
            highpass = be.multiply(2, be.square(highpass))
            highpass_filter = be.divide(grid, highpass, out=grid)
            highpass_filter = be.exp(highpass_filter, out=highpass_filter)
            highpass_filter = be.subtract(1, highpass_filter, out=highpass_filter)

        if has_lowpass and not has_highpass:
            bandpass_filter = lowpass_filter
        elif not has_lowpass and has_highpass:
            bandpass_filter = highpass_filter
        elif has_lowpass and has_highpass:
            bandpass_filter = be.multiply(
                lowpass_filter, highpass_filter, out=lowpass_filter
            )
        else:
            # No cutoffs: pass every frequency unchanged.
            bandpass_filter = be.full(shape, fill_value=1, dtype=be._float_dtype)

        bandpass_filter = shift_fourier(
            data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
        )

        if return_real_fourier:
            bandpass_filter = crop_real_fourier(bandpass_filter)

        return bandpass_filter

    def __call__(self, **kwargs) -> dict:
        """
        Build the bandpass filter from the instance settings.

        Parameters
        ----------
        **kwargs : dict
            Per-call overrides of the instance attributes plus arguments
            required by the filter functions (e.g. ``shape``). They are not
            stored on the instance.

        Returns
        -------
        dict
            ``data`` (the filter as backend array), ``sampling_rate`` and
            ``is_multiplicative_filter`` (always True).
        """
        # Merge into a fresh dict: vars(self) is the instance __dict__ itself,
        # so updating it directly would persist per-call overrides.
        func_args = {**vars(self), **kwargs}

        if func_args.get("use_gaussian"):
            func = self.gaussian_bandpass
        else:
            func = self.discrete_bandpass

        mask = func(**func_args)

        return {
            "data": be.to_backend_array(mask),
            "sampling_rate": func_args.get("sampling_rate", 1),
            "is_multiplicative_filter": True,
        }