# Source code for tme.filters.compose
"""
Combine filters using an interface analogous to pytorch's Compose.
Copyright (c) 2024 European Molecular Biology Laboratory
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
from typing import Tuple, Dict
from abc import ABC, abstractmethod
from ._utils import crop_real_fourier
from ..backends import backend as be
__all__ = ["Compose", "ComposableFilter"]
[docs]
class ComposableFilter(ABC):
    """
    Base class for composable filters.

    Subclasses implement :py:meth:`ComposableFilter._evaluate`, which contains
    the core filter computation. Calling an instance then:

    - merges the instance attributes with the runtime keyword arguments
      (runtime arguments take precedence on conflicts),
    - optionally crops the result for use with real-input Fourier transforms,
    - converts the filter data to a backend array, and
    - records ``return_real_fourier`` in the returned dictionary.

    Filters are assumed to be multiplicative in Fourier space unless their
    result sets ``is_multiplicative_filter`` to False; :py:class:`Compose`
    reads that key with a default of True.
    """

    @abstractmethod
    def _evaluate(self, **kwargs) -> Dict:
        """
        Compute the actual filter given a set of keyword parameters.

        Parameters
        ----------
        **kwargs : dict
            Merged parameters from instance attributes and runtime arguments
            passed to :py:meth:`__call__`. This includes both the filter's
            configuration parameters and any runtime overrides.

        Returns
        -------
        Dict
            Dictionary containing the filter result and metadata. Required keys:

            - data : BackendArray or array-like
                The computed filter data.
            - shape : tuple of int
                Input shape the filter was built for.

            Optional keys:

            - is_multiplicative_filter : bool
                Whether the filter is multiplicative in Fourier space
                (treated as True when absent).
        """

    def __call__(self, return_real_fourier: bool = False, **kwargs) -> Dict:
        """
        Evaluate the filter and return its result in the standard format.

        Instance attributes are merged with the runtime parameters and handed
        to :py:meth:`_evaluate`; the resulting data is optionally cropped and
        converted to a backend array.

        Parameters
        ----------
        return_real_fourier : bool, optional
            Whether to crop the filter mask for compatibility with real input
            FFTs (i.e., :py:func:`numpy.fft.rfft`). Default is False.
        **kwargs : dict
            Additional keyword arguments passed to :py:meth:`_evaluate`.
            These override any matching instance attributes during
            parameter merging.

        Returns
        -------
        Dict
            The dictionary returned by :py:meth:`_evaluate`, updated in place:

            - data : BackendArray
                The filter data, cropped if requested and converted to the
                backend array type.
            - return_real_fourier : bool
                The value of the ``return_real_fourier`` parameter.
            - Any other keys provided by :py:meth:`_evaluate` (e.g. ``shape``,
              ``is_multiplicative_filter``) are passed through unchanged.

        Raises
        ------
        KeyError
            If the dictionary returned by :py:meth:`_evaluate` has no
            ``data`` entry.
        """
        # Runtime keyword arguments win over instance attributes of the same name.
        ret = self._evaluate(**(vars(self) | kwargs))

        # This parameter is only here to allow for using Composable filters outside
        # the context of a Compose operation. Internally, we require return_real_fourier
        # to be False, e.g., for filters that require reconstruction.
        if return_real_fourier:
            ret["data"] = crop_real_fourier(ret["data"])

        ret["data"] = be.to_backend_array(ret["data"])
        ret["return_real_fourier"] = return_real_fourier
        return ret
 
[docs]
class Compose:
    """
    Compose a series of filters.

    Parameters
    ----------
    transforms : tuple of :py:class:`ComposableFilter`.
        Tuple of filter instances.

    Raises
    ------
    ValueError
        If any element of ``transforms`` is not an instance of
        :py:class:`ComposableFilter`.
    """

    def __init__(self, transforms: Tuple[ComposableFilter, ...]):
        for transform in transforms:
            if not isinstance(transform, ComposableFilter):
                raise ValueError(f"{transform} is not a child of {ComposableFilter}.")
        self.transforms = transforms

    def __call__(self, return_real_fourier: bool = False, **kwargs) -> Dict:
        """
        Apply the sequence of filters in order, chaining their outputs.

        Parameters
        ----------
        return_real_fourier : bool, optional
            Whether to crop the final filter mask for compatibility with real
            input FFTs (i.e., :py:func:`numpy.fft.rfft`). The individual
            filters are always evaluated uncropped; cropping is applied once
            to the composed result. Default is False.
        **kwargs : dict
            Keyword arguments passed to the first filter. Every later filter
            receives these arguments updated with the result dictionary
            accumulated so far (including its ``data`` entry).

        Returns
        -------
        Dict
            Empty dictionary if no transforms are configured. Otherwise the
            result dictionary of the last filter that produced ``data``:

            - data : BackendArray
                The final composite filter data. For multiplicative filters,
                this is the element-wise product with the data accumulated
                from the preceding filters; a non-multiplicative filter's
                data replaces the accumulated data instead.
            - shape : tuple of int
                Shape reported by that filter.
            - return_real_fourier : bool
                Whether the returned data was cropped for real FFTs.
            - Additional metadata from the filter pipeline.
        """
        if not len(self.transforms):
            return {}

        meta = self.transforms[0](**kwargs)
        for transform in self.transforms[1:]:
            # Make everything computed so far available to the next filter.
            kwargs.update(meta)
            ret = transform(**kwargs)

            # A result without data contributes nothing and is skipped entirely.
            if "data" not in ret:
                continue

            if ret.get("is_multiplicative_filter", True):
                prev_data = meta.pop("data")
                ret["data"] = be.multiply(ret["data"], prev_data)
                # Drop the local reference to the previous array and reset "merge".
                # NOTE(review): the consumer of the "merge" key is not visible here.
                ret["merge"], prev_data = None, None

            meta = ret

        if return_real_fourier:
            meta["data"] = crop_real_fourier(meta["data"])
            # Filters are evaluated with return_real_fourier=False, so the flag they
            # report would otherwise misdescribe the cropped data returned here.
            meta["return_real_fourier"] = True
        return meta