Source code for tme.filters.bandpass

"""
Implements the classes BandPass and BandPassReconstructed.

Copyright (c) 2024 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from math import log, sqrt
from typing import Tuple, Union, Optional
from pydantic.dataclasses import dataclass

import numpy as np

from ..types import BackendArray
from ..backends import backend as be
from .compose import ComposableFilter
from ._utils import pad_to_length, frequency_grid_at_angle, fftfreqn

__all__ = ["BandPass", "BandPassReconstructed"]


@dataclass(config=dict(extra="allow"))
class BandPass(ComposableFilter):
    """
    Generate per-tilt Bandpass filter.

    Examples
    --------
    The following creates an instance of :py:class:`BandPass`

    >>> from tme.filters import BandPass
    >>> bpf_instance = BandPass(
    >>>    angles=(-70, 0, 30),
    >>>    lowpass=30,
    >>>    sampling_rate=10
    >>> )

    Differently from :py:class:`tme.filters.BandPassReconstructed`, the filter
    masks are intended to be used in subsequent reconstruction using
    :py:class:`tme.filters.ReconstructFromTilt`.

    The ``opening_axis``, ``tilt_axis`` and ``angles`` parameters are used to
    determine the correct frequencies for non-cubical input shapes. The
    ``shape`` argument contains the shape of the reconstruction.

    >>> ret = bpf_instance(shape=(50,50,25))
    >>> mask = ret["data"]
    >>> mask.shape # 3, 50, 50

    Note that different from its reconstructed counterpart, the DC component
    is at the center of the array.

    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots(nrows=1, ncols=3)
    >>> _ = [ax[i].imshow(mask[i]) for i in range(mask.shape[0])]
    >>> plt.show()
    """

    #: The lowpass cutoff, defaults to None.
    lowpass: Optional[float] = None
    #: The highpass cutoff, defaults to None.
    highpass: Optional[float] = None
    #: The sampling rate, defaults to 1 Ångstrom / voxel.
    sampling_rate: Union[Tuple[float, ...], float] = 1
    #: Whether to use Gaussian bandpass filter, defaults to True.
    use_gaussian: bool = True
    #: The tilt angles in degrees.
    angles: Tuple[float, ...] = None
    #: Axis the plane is tilted over, defaults to 0 (x).
    tilt_axis: int = 0
    #: The projection axis, defaults to 2 (z).
    opening_axis: int = 2

    def _evaluate(self, shape: Tuple[int, ...], **kwargs):
        """
        Returns a Bandpass stack of chosen parameters with DC component in the center.

        Parameters
        ----------
        shape : tuple of int
            Shape forwarded to ``frequency_grid_at_angle`` and echoed back in
            the returned dictionary.
        **kwargs : dict
            Must provide ``angles``, ``lowpass``, ``highpass``, ``tilt_axis``
            and ``opening_axis``. ``use_gaussian`` selects the mask function.
            Everything is forwarded to the selected mask function.

        Returns
        -------
        dict
            ``data`` holds the masks concatenated along a new leading axis
            (one entry per angle), ``shape`` is the input ``shape``.
        """
        make_mask = (
            gaussian_bandpass if kwargs.get("use_gaussian") else discrete_bandpass
        )

        tilt_angles = np.atleast_1d(kwargs["angles"])
        # One cutoff per tilt; pad_to_length expands the inputs to that length.
        lowpass_per_tilt = pad_to_length(kwargs["lowpass"], tilt_angles.size)
        highpass_per_tilt = pad_to_length(kwargs["highpass"], tilt_angles.size)

        stack = []
        for position, tilt in enumerate(tilt_angles):
            tilt_grid = frequency_grid_at_angle(
                shape=shape,
                tilt_axis=kwargs["tilt_axis"],
                opening_axis=kwargs["opening_axis"],
                angle=tilt,
                sampling_rate=1,
            )
            # Same keyword arguments for every tilt, except the cutoffs.
            tilt_kwargs = {
                **kwargs,
                "lowpass": lowpass_per_tilt[position],
                "highpass": highpass_per_tilt[position],
            }
            tilt_mask = make_mask(grid=tilt_grid, **tilt_kwargs)
            # Prepend an axis so the masks can be stacked per tilt.
            stack.append(be.to_backend_array(tilt_mask[None]))

        return {"data": be.concatenate(stack, axis=0), "shape": shape}
@dataclass(config=dict(extra="allow"))
class BandPassReconstructed(ComposableFilter):
    """
    Generate Bandpass filter for reconstructions.

    Examples
    --------
    The following creates an instance of :py:class:`BandPassReconstructed`

    >>> from tme.filters import BandPassReconstructed
    >>> bpf_instance = BandPassReconstructed(
    >>>    lowpass=30,
    >>>    sampling_rate=10
    >>> )

    We can use its call method to create filters of given shape

    >>> import matplotlib.pyplot as plt
    >>> ret = bpf_instance(shape=(50,50))

    The ``data`` key of the returned dictionary contains the corresponding
    Fourier filter mask. The DC component is located at the origin.

    >>> plt.imshow(ret["data"])
    >>> plt.show()
    """

    #: The lowpass cutoff, defaults to None.
    lowpass: Optional[float] = None
    #: The highpass cutoff, defaults to None.
    highpass: Optional[float] = None
    #: The sampling rate, defaults to 1 Ångstrom / voxel.
    sampling_rate: Union[Tuple[float, ...], float] = 1
    #: Whether to use Gaussian bandpass filter, defaults to True.
    use_gaussian: bool = True

    def _evaluate(self, shape: Tuple[int, ...], **kwargs):
        """
        Returns a Bandpass mask of chosen parameters for the given shape.

        Parameters
        ----------
        shape : tuple of int
            Shape passed to ``fftfreqn`` and echoed back in the returned
            dictionary.
        **kwargs : dict
            ``use_gaussian`` selects the mask function. All keyword arguments
            are forwarded to the selected mask function.

        Returns
        -------
        dict
            ``data`` holds the mask as backend array, ``shape`` is the input
            ``shape``.
        """
        make_mask = (
            gaussian_bandpass if kwargs.get("use_gaussian") else discrete_bandpass
        )

        # The grid is requested as Euclidean norm and without fftshift.
        frequency_norm = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=False,
            compute_euclidean_norm=True,
            fftshift=False,
        )

        mask = make_mask(grid=frequency_norm, **kwargs)
        return {"data": be.to_backend_array(mask), "shape": shape}
def discrete_bandpass(
    grid: BackendArray,
    lowpass: Optional[float] = None,
    highpass: Optional[float] = None,
    sampling_rate: Union[Tuple[float, ...], float] = 1,
    **kwargs,
) -> BackendArray:
    """
    Generate a bandpass filter using discrete frequency cutoffs.

    Parameters
    ----------
    grid : BackendArray
        Frequencies in Fourier space.
    lowpass : float, optional
        The lowpass cutoff in spatial units of sampling rate. If None, the
        upper cutoff is the largest value in ``grid``. Defaults to None.
    highpass : float, optional
        The highpass cutoff in spatial units of sampling rate. If None, the
        lower cutoff is zero. Defaults to None.
    sampling_rate : float or tuple of float, optional
        The sampling rate in Fourier space, defaults to one. If several values
        are given, the largest resulting cutoff frequency is used.
    **kwargs : dict
        Additional keyword arguments, ignored.

    Returns
    -------
    BackendArray
        The bandpass filter in Fourier space. One where
        ``lowcut <= grid <= highcut`` and zero elsewhere.
    """
    grid = be.astype(be.to_backend_array(grid), be._float_dtype)
    sampling_rate = be.to_backend_array(sampling_rate)

    # A spatial cutoff c corresponds to the frequency 2 * sampling_rate / c.
    if lowpass is not None:
        highcut = be.max(2 * sampling_rate / lowpass)
    else:
        highcut = be.max(grid)

    lowcut = 0
    if highpass is not None:
        lowcut = be.max(2 * sampling_rate / highpass)

    # NOTE(review): the result dtype follows from the ``* 1.0`` promotion and
    # not from be._float_dtype as in gaussian_bandpass - confirm this is intended.
    bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
    return bandpass_filter


def _gaussian_term(
    neg_squared_grid: BackendArray, cutoff: float, upper_sampling: float, norm: float
) -> BackendArray:
    """
    Evaluate ``exp(neg_squared_grid / (2 * sigma ** 2))`` for a single cutoff.

    ``sigma = upper_sampling / (cutoff * norm)``, with ``cutoff`` clamped to the
    machine epsilon of the backend float type to avoid dividing by zero.
    """
    cutoff = be.maximum(float(cutoff), be.eps(be._float_dtype))
    sigma = upper_sampling / (cutoff * norm)
    term = be.divide(neg_squared_grid, be.multiply(2, be.square(sigma)))
    return be.exp(term, out=term)


def gaussian_bandpass(
    grid: BackendArray,
    lowpass: Optional[float] = None,
    highpass: Optional[float] = None,
    sampling_rate: Union[Tuple[float, ...], float] = 1,
    **kwargs,
) -> BackendArray:
    """
    Generate a bandpass filter using Gaussians.

    Parameters
    ----------
    grid : BackendArray
        Frequency grid in Fourier space.
    lowpass : float, optional
        The lowpass cutoff in units of sampling rate, defaults to None.
    highpass : float, optional
        The highpass cutoff in units of sampling rate, defaults to None.
    sampling_rate : float or tuple of float, optional
        The sampling rate in Fourier space, defaults to one. If several values
        are given, the largest one is used.
    **kwargs : dict
        Additional keyword arguments, ignored.

    Returns
    -------
    BackendArray
        The bandpass filter in Fourier space. The product of the Gaussian
        lowpass and one minus the Gaussian highpass term, using whichever of
        the two cutoffs is given. All ones if both cutoffs are None.
    """
    grid = be.astype(be.to_backend_array(grid), be._float_dtype)
    # NOTE(review): the square is computed in-place. If be.astype returns its
    # input without copying, the caller's grid is modified - confirm.
    grid = -be.square(grid, out=grid)

    # With norm = sqrt(2 * ln(2)), each Gaussian term equals 0.5 where the
    # frequency grid equals upper_sampling / cutoff.
    norm = float(sqrt(2 * log(2)))
    upper_sampling = float(be.max(be.multiply(2, be.to_backend_array(sampling_rate))))

    lowpass_filter, highpass_filter = None, None
    if lowpass is not None:
        lowpass_filter = _gaussian_term(grid, lowpass, upper_sampling, norm)
    if highpass is not None:
        highpass_filter = _gaussian_term(grid, highpass, upper_sampling, norm)
        # The highpass is the complement of the corresponding Gaussian lowpass.
        highpass_filter = be.subtract(1, highpass_filter, out=highpass_filter)

    if lowpass_filter is None and highpass_filter is None:
        return be.full(grid.shape, fill_value=1, dtype=be._float_dtype)
    if highpass_filter is None:
        return lowpass_filter
    if lowpass_filter is None:
        return highpass_filter
    return be.multiply(lowpass_filter, highpass_filter, out=lowpass_filter)