Source code for tme.filters.bandpass

"""
Implements class BandPassFilter to create Fourier filter representations.

Copyright (c) 2024 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from typing import Tuple
from math import log, sqrt
from dataclasses import dataclass

import numpy as np

from ..types import BackendArray
from ..backends import backend as be
from .compose import ComposableFilter
from ._utils import (
    crop_real_fourier,
    shift_fourier,
    pad_to_length,
    frequency_grid_at_angle,
    fftfreqn,
)

__all__ = ["BandPass", "BandPassReconstructed"]


[docs] @dataclass class BandPass(ComposableFilter): """ Generate per-slice Fourier Bandpass filter """ #: The tilt angles. angles: Tuple[float] #: The lowpass cutoffs. Either one or one per angle, defaults to None. lowpass: Tuple[float] = None #: The highpass cutoffs. Either one or one per angle, defaults to None. highpass: Tuple[float] = None #: The shape of the to-be created mask. shape: Tuple[int] = None #: Axis the plane is tilted over, defaults to 0 (x). tilt_axis: int = 0 #: The projection axis, defaults to 2 (z). opening_axis: int = 2 #: The sampling rate, defaults to 1 Ångstrom / voxel. sampling_rate: Tuple[float] = 1 #: Whether to use Gaussian bandpass filter, defaults to True. use_gaussian: bool = True #: Whether to return a mask for rfft return_real_fourier: bool = False
[docs] def __call__(self, **kwargs): """ Returns a Bandpass stack of chosen parameters with DC component in the center. """ func_args = vars(self).copy() func_args.update(kwargs) func = discrete_bandpass if func_args.get("use_gaussian"): func = gaussian_bandpass return_real_fourier = kwargs.get("return_real_fourier", True) shape_is_real_fourier = kwargs.get("shape_is_real_fourier", False) if shape_is_real_fourier: return_real_fourier = False angles = np.atleast_1d(func_args["angles"]) _lowpass = pad_to_length(func_args["lowpass"], angles.size) _highpass = pad_to_length(func_args["highpass"], angles.size) masks = [] for index, angle in enumerate(angles): frequency_grid = frequency_grid_at_angle( shape=func_args["shape"], tilt_axis=func_args["tilt_axis"], opening_axis=func_args["opening_axis"], angle=angle, sampling_rate=1, ) func_args["lowpass"] = _lowpass[index] func_args["highpass"] = _highpass[index] mask = func(grid=frequency_grid, **func_args) mask = shift_fourier(data=mask, shape_is_real_fourier=shape_is_real_fourier) if return_real_fourier: mask = crop_real_fourier(mask) masks.append(mask[None]) masks = be.concatenate(masks, axis=0) return { "data": be.to_backend_array(masks), "shape": func_args["shape"], "return_real_fourier": return_real_fourier, "is_multiplicative_filter": True, }
[docs] @dataclass class BandPassReconstructed(ComposableFilter): """ Generate reconstructed bandpass filters in Fourier space. """ #: The lowpass cutoff, defaults to None. lowpass: float = None #: The highpass cutoff, defaults to None. highpass: float = None #: The shape of the to-be created mask. shape: Tuple[int] = None #: Axis the plane is tilted over, defaults to 0 (x). tilt_axis: int = 0 #: The projection axis, defaults to 2 (z). opening_axis: int = 2 #: The sampling rate, defaults to 1 Ångstrom / voxel. sampling_rate: Tuple[float] = 1 #: Whether to use Gaussian bandpass filter, defaults to True. use_gaussian: bool = True #: Whether to return a mask for rfft return_real_fourier: bool = False
[docs] def __call__(self, **kwargs): func_args = vars(self).copy() func_args.update(kwargs) func = discrete_bandpass if func_args.get("use_gaussian"): func = gaussian_bandpass return_real_fourier = func_args.get("return_real_fourier", True) shape_is_real_fourier = func_args.get("shape_is_real_fourier", False) if shape_is_real_fourier: return_real_fourier = False grid = fftfreqn( shape=func_args["shape"], sampling_rate=0.5, shape_is_real_fourier=shape_is_real_fourier, compute_euclidean_norm=True, ) mask = func(grid=grid, **func_args) mask = shift_fourier(data=mask, shape_is_real_fourier=shape_is_real_fourier) if return_real_fourier: mask = crop_real_fourier(mask) return { "data": be.to_backend_array(mask), "shape": func_args["shape"], "return_real_fourier": return_real_fourier, "is_multiplicative_filter": True, }
def discrete_bandpass( grid: BackendArray, lowpass: float, highpass: float, sampling_rate: Tuple[float], **kwargs, ) -> BackendArray: """ Generate a bandpass filter using discrete frequency cutoffs. Parameters ---------- grid : BackendArray Frequencies in Fourier space. lowpass : float The lowpass cutoff in units of sampling rate. highpass : float The highpass cutoff in units of sampling rate. return_real_fourier : bool, optional Whether to return only the real Fourier space, defaults to False. sampling_rate : float The sampling rate in Fourier space. **kwargs : dict Additional keyword arguments. Returns ------- BackendArray The bandpass filter in Fourier space. """ grid = be.astype(be.to_backend_array(grid), be._float_dtype) sampling_rate = be.to_backend_array(sampling_rate) highcut = grid.max() if lowpass is not None: highcut = be.max(2 * sampling_rate / lowpass) lowcut = 0 if highpass is not None: lowcut = be.max(2 * sampling_rate / highpass) bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0 return bandpass_filter def gaussian_bandpass( grid: BackendArray, lowpass: float = None, highpass: float = None, sampling_rate: float = 1, **kwargs, ) -> BackendArray: """ Generate a bandpass filter using Gaussians. Parameters ---------- grid : BackendArray Frequency grid in Fourier space. lowpass : float, optional The lowpass cutoff in units of sampling rate, defaults to None. highpass : float, optional The highpass cutoff in units of sampling rate, defaults to None. sampling_rate : float, optional The sampling rate in Fourier space, defaults to one. **kwargs : dict Additional keyword arguments. Returns ------- BackendArray The bandpass filter in Fourier space. """ grid = be.astype(be.to_backend_array(grid), be._float_dtype) grid = -be.square(grid, out=grid) has_lowpass, has_highpass = False, False norm = float(sqrt(2 * log(2))) upper_sampling = float(be.max(be.multiply(2, be.to_backend_array(sampling_rate)))) if lowpass is not None: lowpass, has_lowpass = float(lowpass), True lowpass = be.maximum(lowpass, be.eps(be._float_dtype)) if highpass is not None: highpass, has_highpass = float(highpass), True highpass = be.maximum(highpass, be.eps(be._float_dtype)) if has_lowpass: lowpass = upper_sampling / (lowpass * norm) lowpass = be.multiply(2, be.square(lowpass)) if not has_highpass: lowpass_filter = be.divide(grid, lowpass, out=grid) else: lowpass_filter = be.divide(grid, lowpass) lowpass_filter = be.exp(lowpass_filter, out=lowpass_filter) if has_highpass: highpass = upper_sampling / (highpass * norm) highpass = be.multiply(2, be.square(highpass)) highpass_filter = be.divide(grid, highpass, out=grid) highpass_filter = be.exp(highpass_filter, out=highpass_filter) highpass_filter = be.subtract(1, highpass_filter, out=highpass_filter) if has_lowpass and not has_highpass: bandpass_filter = lowpass_filter elif not has_lowpass and has_highpass: bandpass_filter = highpass_filter elif has_lowpass and has_highpass: bandpass_filter = be.multiply( lowpass_filter, highpass_filter, out=lowpass_filter ) else: bandpass_filter = be.full(grid.shape, fill_value=1, dtype=be._float_dtype) return bandpass_filter