Source code for tme.memory

""" Compute memory consumption of template matching components.

    Copyright (c) 2023 European Molecular Biology Laboratory

    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from abc import ABC, abstractmethod
from typing import Tuple

import numpy as np

from .backends import backend as be


class MatchingMemoryUsage(ABC):
    """
    Class specification for estimating the memory requirements of template matching.

    Parameters
    ----------
    fast_shape : tuple of int
        Shape of the real array.
    ft_shape : tuple of int
        Shape of the complex array.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Attributes
    ----------
    real_array_size : int
        Number of elements in real array.
    complex_array_size : int
        Number of elements in complex array.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.
    """

    def __init__(
        self,
        fast_shape: Tuple[int, ...],
        ft_shape: Tuple[int, ...],
        float_nbytes: int,
        complex_nbytes: int,
        integer_nbytes: int,
    ):
        # Accumulate explicitly in int64 and convert to a Python int. Without
        # the dtype, np.prod uses the platform default integer, which can be
        # 32-bit and overflow for large volumes, and an empty shape would
        # yield the float 1.0. Python ints also keep the byte arithmetic in
        # subclasses free of fixed-width overflow.
        self.real_array_size = int(np.prod(fast_shape, dtype=np.int64))
        self.complex_array_size = int(np.prod(ft_shape, dtype=np.int64))
        self.float_nbytes = float_nbytes
        self.complex_nbytes = complex_nbytes
        self.integer_nbytes = integer_nbytes

    @abstractmethod
    def base_usage(self) -> int:
        """Return the base memory usage in bytes."""

    @abstractmethod
    def per_fork(self) -> int:
        """Return the memory usage per fork in bytes."""
class CCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CC scoring.

    See Also
    --------
    :py:meth:`tme.matching_scores.cc_setup`.
    """

    def base_usage(self) -> int:
        """Return the bytes of one real array plus one complex array."""
        real_bytes = self.real_array_size * self.float_nbytes
        fourier_bytes = self.complex_array_size * self.complex_nbytes
        return real_bytes + fourier_bytes

    def per_fork(self) -> int:
        """Return the per-fork bytes: one real array plus one complex array."""
        real_bytes = self.real_array_size * self.float_nbytes
        fourier_bytes = self.complex_array_size * self.complex_nbytes
        return real_bytes + fourier_bytes
class LCCMemoryUsage(CCMemoryUsage):
    """
    Memory usage estimation for LCC scoring.

    Inherits both estimates unchanged from :py:class:`CCMemoryUsage`.

    See Also
    --------
    :py:meth:`tme.matching_scores.lcc_setup`.
    """
class CORRMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CORR scoring.

    See Also
    --------
    :py:meth:`tme.matching_scores.corr_setup`.
    """

    def base_usage(self) -> int:
        """Return the bytes of four real arrays plus one complex array."""
        real_bytes = self.real_array_size * self.float_nbytes * 4
        fourier_bytes = self.complex_array_size * self.complex_nbytes
        return real_bytes + fourier_bytes

    def per_fork(self) -> int:
        """Return the per-fork bytes: one real array plus one complex array."""
        real_bytes = self.real_array_size * self.float_nbytes
        fourier_bytes = self.complex_array_size * self.complex_nbytes
        return real_bytes + fourier_bytes
class CAMMemoryUsage(CORRMemoryUsage):
    """
    Memory usage estimation for CAM scoring.

    Inherits both estimates unchanged from :py:class:`CORRMemoryUsage`.

    See Also
    --------
    :py:meth:`tme.matching_scores.cam_setup`.
    """
class FLCSphericalMaskMemoryUsage(CORRMemoryUsage):
    """
    Memory usage estimation for FLCSphericalMask scoring.

    Inherits both estimates unchanged from :py:class:`CORRMemoryUsage`.

    See Also
    --------
    :py:meth:`tme.matching_scores.flcSphericalMask_setup`.
    """
class FLCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for FLC scoring.

    See Also
    --------
    :py:meth:`tme.matching_scores.flc_setup`.
    """

    def base_usage(self) -> int:
        """Return the bytes of two real arrays plus two complex arrays."""
        real_bytes = self.real_array_size * self.float_nbytes * 2
        fourier_bytes = self.complex_array_size * self.complex_nbytes * 2
        return real_bytes + fourier_bytes

    def per_fork(self) -> int:
        """Return the per-fork bytes: three real arrays plus two complex arrays."""
        real_bytes = self.real_array_size * self.float_nbytes * 3
        fourier_bytes = self.complex_array_size * self.complex_nbytes * 2
        return real_bytes + fourier_bytes
class MCCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for MCC scoring.

    See Also
    --------
    :py:meth:`tme.matching_scores.mcc_setup`.
    """

    def base_usage(self) -> int:
        """Return the bytes of two real arrays plus three complex arrays."""
        real_bytes = self.real_array_size * self.float_nbytes * 2
        fourier_bytes = self.complex_array_size * self.complex_nbytes * 3
        return real_bytes + fourier_bytes

    def per_fork(self) -> int:
        """Return the per-fork bytes: six real arrays plus one complex array."""
        real_bytes = self.real_array_size * self.float_nbytes * 6
        fourier_bytes = self.complex_array_size * self.complex_nbytes
        return real_bytes + fourier_bytes
class MaxScoreOverRotationsMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for the MaxScoreOverRotations analyzer.

    See Also
    --------
    :py:class:`tme.analyzer.MaxScoreOverRotations`.
    """

    def base_usage(self) -> int:
        """Return the bytes of two real arrays."""
        return self.real_array_size * self.float_nbytes * 2

    def per_fork(self) -> int:
        """Return zero; no additional per-fork memory is accounted for."""
        return 0
class PeakCallerMaximumFilterMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for the PeakCallerMaximumFilter analyzer.

    See Also
    --------
    :py:class:`tme.analyzer.PeakCallerMaximumFilter`.
    """

    def base_usage(self) -> int:
        """Return the bytes of one real array."""
        return self.real_array_size * self.float_nbytes

    def per_fork(self) -> int:
        """Return the per-fork bytes of one real array."""
        return self.real_array_size * self.float_nbytes
class CupyBackendMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CupyBackend.

    See Also
    --------
    :py:class:`tme.backends.CupyBackend`.
    """

    def base_usage(self) -> int:
        """Return the estimated backend overhead in bytes."""
        # FFT plans, overhead from assigning FFT result, rotation interpolation
        # NOTE(review): the complex term is sized by the real array and the
        # float term by the complex array, the reverse of the other
        # estimators in this module. Confirm whether this is intentional.
        complex_bytes = self.real_array_size * self.complex_nbytes * 3
        float_bytes = self.complex_array_size * self.float_nbytes * 2
        return float_bytes + complex_bytes

    def per_fork(self) -> int:
        """Return zero; no additional per-fork memory is accounted for."""
        return 0
# Maps a scoring method, analyzer, or backend name to the class that estimates
# its memory footprint. All three kinds of names share this one table.
# "pytorch" and the "batch*" entries reuse the estimators of their closest
# counterparts rather than defining their own.
MATCHING_MEMORY_REGISTRY = {
    "CC": CCMemoryUsage,
    "LCC": LCCMemoryUsage,
    "CORR": CORRMemoryUsage,
    "CAM": CAMMemoryUsage,
    "MCC": MCCMemoryUsage,
    "FLCSphericalMask": FLCSphericalMaskMemoryUsage,
    "FLC": FLCMemoryUsage,
    "MaxScoreOverRotations": MaxScoreOverRotationsMemoryUsage,
    "PeakCallerMaximumFilter": PeakCallerMaximumFilterMemoryUsage,
    "cupy": CupyBackendMemoryUsage,
    "pytorch": CupyBackendMemoryUsage,
    "batchFLCSpherical": FLCSphericalMaskMemoryUsage,
    "batchFLC": FLCMemoryUsage,
}
def estimate_memory_usage(
    shape1: Tuple[int],
    shape2: Tuple[int],
    matching_method: str,
    ncores: int,
    analyzer_method: str = None,
    backend: str = None,
    float_nbytes: int = 4,
    complex_nbytes: int = 8,
    integer_nbytes: int = 4,
) -> int:
    """
    Estimate the memory usage for a given template matching operation.

    The estimate is the sum, over the matching method and the optional
    analyzer and backend, of ``base_usage() + per_fork() * ncores``.

    Parameters
    ----------
    shape1 : tuple
        Shape of the target array.
    shape2 : tuple
        Shape of the template array.
    matching_method : str
        Matching method to estimate memory usage for.
    ncores : int
        The number of CPU cores used for the operation.
    analyzer_method : str, optional
        The method used for score analysis. Names without an entry in
        ``MATCHING_MEMORY_REGISTRY`` contribute nothing to the estimate.
    backend : str, optional
        Backend used for computation. Names without an entry in
        ``MATCHING_MEMORY_REGISTRY`` contribute nothing to the estimate.
    float_nbytes : int
        Number of bytes of the used float, defaults to 4 (float32).
    complex_nbytes : int
        Number of bytes of the used complex, defaults to 8 (complex64).
    integer_nbytes : int
        Number of bytes of the used integer, defaults to 4 (int32).

    Returns
    -------
    int
        The estimated memory usage for the operation in bytes.

    Raises
    ------
    ValueError
        If an unsupported matching_method is provided.
    """
    if matching_method not in MATCHING_MEMORY_REGISTRY:
        raise ValueError(
            f"Supported options are {','.join(MATCHING_MEMORY_REGISTRY.keys())}"
        )

    # Only the padded real shape and the Fourier shape are needed here.
    _, fast_shape, ft_shape = be.compute_convolution_shapes(shape1, shape2)

    estimator_kwargs = {
        "fast_shape": fast_shape,
        "ft_shape": ft_shape,
        "float_nbytes": float_nbytes,
        "complex_nbytes": complex_nbytes,
        "integer_nbytes": integer_nbytes,
    }

    nbytes = 0
    # matching_method is guaranteed to be registered by the check above;
    # the analyzer and backend are optional and skipped when not registered.
    for component in (matching_method, analyzer_method, backend):
        estimator_class = MATCHING_MEMORY_REGISTRY.get(component)
        if estimator_class is None:
            continue
        estimator = estimator_class(**estimator_kwargs)
        nbytes += estimator.base_usage() + estimator.per_fork() * ncores

    # The estimators may hand back NumPy scalars; honour the declared return type.
    return int(nbytes)