# Source code for tme.memory

"""
Compute memory consumption of template matching components.

Copyright (c) 2023-2025 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from abc import ABC, abstractmethod
from typing import Tuple, Optional

import numpy as np

from .backends import backend as be

#: Global registry mapping estimator names to their classes.
MATCHING_MEMORY_REGISTRY = {}


def register_memory(*names: str):
    """
    Class decorator that registers a memory estimator under one or more names.

    Parameters
    ----------
    *names : str
        Names to register the decorated class under in
        :data:`MATCHING_MEMORY_REGISTRY`.
    """

    def decorator(cls):
        # Register the same class under every requested alias.
        MATCHING_MEMORY_REGISTRY.update({name: cls for name in names})
        return cls

    return decorator


class MatchingMemoryUsage(ABC):
    """
    Strategy class for estimating memory requirements.

    Parameters
    ----------
    fast_shape : tuple of int
        Shape of the real array.
    ft_shape : tuple of int
        Shape of the complex array.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.
    """

    def __init__(
        self,
        fast_shape: Tuple[int, ...],
        ft_shape: Tuple[int, ...],
        float_nbytes: int,
        complex_nbytes: int,
        integer_nbytes: int,
    ):
        # Total element counts of the real and complex arrays.
        self.real_array_size = np.prod(fast_shape)
        self.complex_array_size = np.prod(ft_shape)
        self.float_nbytes = float_nbytes
        self.complex_nbytes = complex_nbytes
        self.integer_nbytes = integer_nbytes

    @abstractmethod
    def base_usage(self) -> int:
        """Return the base memory usage in bytes."""

    @abstractmethod
    def per_fork(self) -> int:
        """Return the memory usage per fork in bytes."""
class MemoryProfile(MatchingMemoryUsage):
    """Memory estimator for methods with uniform array requirements."""

    #: Number of shared real arrays
    base_float: int = 0
    #: Number of shared complex arrays
    base_complex: int = 0
    #: Number of real arrays per fork
    fork_float: int = 0
    #: Number of complex arrays per fork
    fork_complex: int = 0

    def base_usage(self) -> int:
        """Bytes for the arrays shared by all forks."""
        float_bytes = self.base_float * self.real_array_size * self.float_nbytes
        complex_bytes = self.base_complex * self.complex_array_size * self.complex_nbytes
        return float_bytes + complex_bytes

    def per_fork(self) -> int:
        """Bytes required by each individual fork."""
        float_bytes = self.fork_float * self.real_array_size * self.float_nbytes
        complex_bytes = self.fork_complex * self.complex_array_size * self.complex_nbytes
        return float_bytes + complex_bytes
@register_memory("CC", "LCC")
class CCMemoryUsage(MemoryProfile):
    """:py:meth:`tme.matching_scores.cc_setup` memory estimator."""

    base_float = 1
    base_complex = 1
    fork_float = 1
    fork_complex = 1
@register_memory("CORR", "NCC", "CAM", "FLCSphericalMask", "batchFLCSphericalMask")
class CORRMemoryUsage(MemoryProfile):
    """:py:meth:`tme.matching_scores.corr_setup` memory estimator."""

    base_float = 4
    base_complex = 1
    fork_float = 1
    fork_complex = 1
@register_memory("FLC", "batchFLC")
class FLCMemoryUsage(MemoryProfile):
    """:py:meth:`tme.matching_scores.flc_setup` memory estimator."""

    base_float = 2
    base_complex = 2
    fork_float = 3
    fork_complex = 2
@register_memory("MCC")
class MCCMemoryUsage(MemoryProfile):
    """:py:meth:`tme.matching_scores.mcc_setup` memory estimator."""

    base_float = 2
    base_complex = 3
    fork_float = 6
    fork_complex = 1
@register_memory("MaxScoreOverRotations")
class MaxScoreOverRotationsMemoryUsage(MemoryProfile):
    """:py:class:`tme.analyzer.MaxScoreOverRotations` memory estimator."""

    # Score array plus rotation-index array, shared by all forks.
    base_float = 2
@register_memory("MaxScoreOverRotationsConstrained")
class MaxScoreOverRotationsConstrainedMemoryUsage(MemoryProfile):
    """:py:class:`tme.analyzer.MaxScoreOverRotationsConstrained` memory estimator."""

    # This ultimately depends on the number of seed points and mask size.
    # Ideally we would use that in the memory estimation, but for now we
    # approximate by requesting memory for another real array.
    base_float = 3
@register_memory("PeakCallerMaximumFilter")
class PeakCallerMaximumFilterMemoryUsage(MemoryProfile):
    """:py:class:`tme.analyzer.peaks.PeakCallerMaximumFilter` memory estimator."""

    base_float = 1
    fork_float = 1
@register_memory("cupy", "pytorch")
class CupyBackendMemoryUsage(MemoryProfile):
    """:py:class:`tme.backends.CupyBackend` memory estimator."""

    # FFT plans, overhead from assigning the FFT result, rotation interpolation.
    base_complex = 3
    base_float = 2
def estimate_memory_usage(
    shape1: Tuple[int, ...],
    shape2: Tuple[int, ...],
    matching_method: str,
    ncores: int,
    analyzer_method: Optional[str] = None,
    backend: Optional[str] = None,
    float_nbytes: int = 4,
    complex_nbytes: int = 8,
    integer_nbytes: int = 4,
) -> int:
    """
    Estimate the memory usage for a given template matching run.

    Parameters
    ----------
    shape1 : tuple
        Shape of the target array.
    shape2 : tuple
        Shape of the template array.
    matching_method : str
        Matching method to estimate memory usage for.
    ncores : int
        The number of CPU cores used for the operation.
    analyzer_method : str, optional
        The method used for score analysis.
    backend : str, optional
        Backend used for computation.
    float_nbytes : int
        Number of bytes of the used float, defaults to 4 (float32).
    complex_nbytes : int
        Number of bytes of the used complex, defaults to 8 (complex64).
    integer_nbytes : int
        Number of bytes of the used integer, defaults to 4 (int32).

    Returns
    -------
    int
        The estimated memory usage for the operation in bytes.

    Raises
    ------
    ValueError
        If an unsupported matching_method is provided.
    """
    if matching_method not in MATCHING_MEMORY_REGISTRY:
        raise ValueError(
            f"Supported options are {','.join(MATCHING_MEMORY_REGISTRY.keys())}"
        )

    # Convolution-compatible shapes of the real and complex arrays.
    _, fast_shape, ft_shape = be.compute_convolution_shapes(shape1, shape2)

    def _total_bytes(name: str) -> int:
        # Base (shared) usage plus one per-fork allotment for every core.
        estimator = MATCHING_MEMORY_REGISTRY[name](
            fast_shape=fast_shape,
            ft_shape=ft_shape,
            float_nbytes=float_nbytes,
            complex_nbytes=complex_nbytes,
            integer_nbytes=integer_nbytes,
        )
        return estimator.base_usage() + estimator.per_fork() * ncores

    nbytes = _total_bytes(matching_method)

    # Analyzer and backend contributions are optional and only counted
    # when a registered estimator exists for them.
    for optional_component in (analyzer_method, backend):
        if optional_component in MATCHING_MEMORY_REGISTRY:
            nbytes += _total_bytes(optional_component)

    return nbytes