Source code for tme.filters.compose

"""
Combine filters using an interface analogous to pytorch's Compose.

Copyright (c) 2024 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from typing import Tuple, Dict
from abc import ABC, abstractmethod

from ._utils import crop_real_fourier
from ..backends import backend as be

__all__ = ["Compose", "ComposableFilter"]


class ComposableFilter(ABC):
    """
    Abstract base class for filters that can be chained with :py:class:`Compose`.

    The class defines the calling convention shared by all filters. Invoking
    an instance

    - merges the instance attributes with the runtime keyword arguments
      (runtime arguments win on conflicts),
    - delegates the actual computation to :py:meth:`ComposableFilter._evaluate`,
    - optionally crops the result for use with real-input Fourier transforms,
    - converts the filter data to a backend array and records the
      ``return_real_fourier`` setting in the result dictionary.

    Subclasses only need to implement :py:meth:`ComposableFilter._evaluate`.

    Filters are treated as multiplicative in Fourier space unless their result
    dictionary sets ``is_multiplicative_filter`` to False; :py:class:`Compose`
    assumes True when the key is absent.
    """

    @abstractmethod
    def _evaluate(self, **kwargs) -> Dict:
        """
        Compute the filter for a given set of keyword parameters.

        Parameters
        ----------
        **kwargs : dict
            Instance attributes merged with the runtime arguments passed to
            :py:meth:`__call__`. Runtime arguments override instance
            attributes of the same name.

        Returns
        -------
        Dict
            Dictionary with the filter result and metadata.

            Required keys:

            - data : BackendArray or array-like
                The computed filter data.
            - shape : tuple of int
                Input shape the filter was built for.

            Optional keys:

            - is_multiplicative_filter : bool
                Whether the filter is multiplicative in Fourier space.
                Treated as True by :py:class:`Compose` when omitted.
        """

    def __call__(self, return_real_fourier: bool = False, **kwargs) -> Dict:
        """
        Build the filter mask using the standard composable-filter interface.

        Parameters
        ----------
        return_real_fourier : bool, optional
            Whether to crop the filter mask for compatibility with real input
            FFTs (i.e., :py:func:`numpy.fft.rfft`). Default is False.
        **kwargs : dict
            Additional keyword arguments forwarded to :py:meth:`_evaluate`.
            They override instance attributes with the same name.

        Returns
        -------
        Dict
            The dictionary returned by :py:meth:`_evaluate`, updated with

            - data : BackendArray
                The filter data converted to a backend array, cropped
                beforehand if ``return_real_fourier`` is True.
            - return_real_fourier : bool
                The value of the ``return_real_fourier`` parameter.

            All other entries (e.g. ``shape`` or ``is_multiplicative_filter``)
            are passed through exactly as provided by :py:meth:`_evaluate`.
        """
        # Later entries win, so runtime arguments shadow instance attributes.
        parameters = {**vars(self), **kwargs}
        result = self._evaluate(**parameters)

        # Cropping here only exists so that a filter can be used on its own,
        # outside of a Compose operation. Within Compose the individual filters
        # are always called with return_real_fourier=False, e.g., for filters
        # that require reconstruction.
        data = result["data"]
        if return_real_fourier:
            data = crop_real_fourier(data)

        result["data"] = be.to_backend_array(data)
        result["return_real_fourier"] = return_real_fourier
        return result
class Compose:
    """
    Compose a series of filters.

    Parameters
    ----------
    transforms : tuple of :py:class:`ComposableFilter`.
        Tuple of filter instances.

    Raises
    ------
    ValueError
        If an element of ``transforms`` is not a :py:class:`ComposableFilter`.
    """

    def __init__(self, transforms: Tuple["ComposableFilter", ...]):
        for transform in transforms:
            if not isinstance(transform, ComposableFilter):
                raise ValueError(f"{transform} is not a child of {ComposableFilter}.")

        self.transforms = transforms

    def __call__(self, return_real_fourier: bool = False, **kwargs) -> Dict:
        """
        Apply the sequence of filters in order, chaining their outputs.

        Parameters
        ----------
        return_real_fourier : bool, optional
            Whether to crop the final filter mask for compatibility with real
            input FFTs (i.e., :py:func:`numpy.fft.rfft`). The individual
            filters are always evaluated uncropped; cropping is applied once
            to the composed result. Default is False.
        **kwargs : dict
            Keyword arguments passed to the first filter. The result
            dictionary of each filter is merged into these arguments before
            the next filter is called.

        Returns
        -------
        Dict
            Empty dictionary if no transforms were given. Otherwise the result
            dictionary of the last filter that produced data, containing:

            - data : BackendArray
                The composite filter data. For multiplicative filters this is
                the element-wise product with the data accumulated so far;
                a non-multiplicative filter replaces the accumulated data.
            - return_real_fourier : bool
                The value of the ``return_real_fourier`` parameter.
            - Additional metadata from the filter pipeline.
        """
        if not self.transforms:
            return {}

        meta = self.transforms[0](**kwargs)
        for transform in self.transforms[1:]:
            # Expose everything computed so far to the next filter.
            kwargs.update(meta)
            ret = transform(**kwargs)

            # Filters that yield no data leave the accumulated result untouched.
            if "data" not in ret:
                continue

            if ret.get("is_multiplicative_filter", True):
                # There may be nothing to multiply with if the preceding
                # filters produced no data; the current data is then used as is.
                prev_data = meta.pop("data", None)
                if prev_data is not None:
                    ret["data"] = be.multiply(ret["data"], prev_data)
                # NOTE(review): the purpose of the "merge" entry is not visible
                # in this module; it is kept for backward compatibility.
                ret["merge"] = None
                # Drop the local reference to the superseded array.
                prev_data = None

            meta = ret

        if return_real_fourier and "data" in meta:
            meta["data"] = crop_real_fourier(meta["data"])

        # The filters were evaluated with return_real_fourier=False, so the
        # entry they reported has to be updated to describe the final output.
        meta["return_real_fourier"] = return_real_fourier
        return meta