""" Compute memory consumption of template matching components.
Copyright (c) 2023 European Molecular Biology Laboratory
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
from abc import ABC, abstractmethod
from typing import Tuple
import numpy as np
from pyfftw import next_fast_len
class MatchingMemoryUsage(ABC):
    """
    Class specification for estimating the memory requirements of template matching.

    Parameters
    ----------
    fast_shape : tuple of int
        Shape of the real array.
    ft_shape : tuple of int
        Shape of the complex array.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Attributes
    ----------
    real_array_size : int
        Number of elements in real array.
    complex_array_size : int
        Number of elements in complex array.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Methods
    -------
    base_usage():
        Returns the base memory usage in bytes.
    per_fork():
        Returns the memory usage in bytes per fork.
    """

    def __init__(
        self,
        fast_shape: Tuple[int, ...],
        ft_shape: Tuple[int, ...],
        float_nbytes: int,
        complex_nbytes: int,
        integer_nbytes: int,
    ):
        # Count elements with an explicit 64-bit accumulator and convert to a
        # Python int. Without ``dtype``, np.prod uses the platform default
        # integer, which is 32-bit on Windows with NumPy < 2 and silently
        # wraps for arrays with more than 2**31 elements. Python ints also
        # keep every downstream byte computation overflow-free.
        self.real_array_size = int(np.prod(fast_shape, dtype=np.int64))
        self.complex_array_size = int(np.prod(ft_shape, dtype=np.int64))
        self.float_nbytes = float_nbytes
        self.complex_nbytes = complex_nbytes
        self.integer_nbytes = integer_nbytes

    @abstractmethod
    def base_usage(self) -> int:
        """Return the base memory usage in bytes."""

    @abstractmethod
    def per_fork(self) -> int:
        """Return the memory usage per fork in bytes."""
class CCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CC scoring.

    Both the base allocation and every fork are estimated as one real-valued
    array plus one complex-valued array.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.cc_setup`.
    """

    def base_usage(self) -> int:
        # One real-space buffer plus one Fourier-space buffer.
        return (
            self.real_array_size * self.float_nbytes
            + self.complex_array_size * self.complex_nbytes
        )

    def per_fork(self) -> int:
        # Each fork holds the same pair of buffers as the base allocation.
        return (
            self.real_array_size * self.float_nbytes
            + self.complex_array_size * self.complex_nbytes
        )
class LCCMemoryUsage(CCMemoryUsage):
    """
    Memory usage estimation for LCC scoring.

    No method is overridden: the estimate is exactly the one computed by
    :py:class:`CCMemoryUsage`.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.lcc_setup`.
    """
class CORRMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CORR scoring.

    The base allocation counts four real-valued arrays and one complex-valued
    array; each fork adds one of each.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.corr_setup`.
    """

    def base_usage(self) -> int:
        # Four real-space buffers plus one Fourier-space buffer.
        return (
            self.real_array_size * self.float_nbytes * 4
            + self.complex_array_size * self.complex_nbytes
        )

    def per_fork(self) -> int:
        # One real-space and one Fourier-space buffer per fork.
        return (
            self.real_array_size * self.float_nbytes
            + self.complex_array_size * self.complex_nbytes
        )
class CAMMemoryUsage(CORRMemoryUsage):
    """
    Memory usage estimation for CAM scoring.

    No method is overridden: the estimate is exactly the one computed by
    :py:class:`CORRMemoryUsage`.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.cam_setup`.
    """
class FLCSphericalMaskMemoryUsage(CORRMemoryUsage):
    """
    Memory usage estimation for FLCSphericalMask scoring.

    No method is overridden: the estimate is exactly the one computed by
    :py:class:`CORRMemoryUsage`.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flcSphericalMask_setup`.
    """
class FLCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for FLC scoring.

    The base allocation counts two real-valued and two complex-valued arrays;
    each fork counts three real-valued and two complex-valued arrays.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flc_setup`.
    """

    def base_usage(self) -> int:
        return (
            self.real_array_size * self.float_nbytes * 2
            + self.complex_array_size * self.complex_nbytes * 2
        )

    def per_fork(self) -> int:
        return (
            self.real_array_size * self.float_nbytes * 3
            + self.complex_array_size * self.complex_nbytes * 2
        )
class MCCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for MCC scoring.

    The base allocation counts two real-valued and three complex-valued
    arrays; each fork counts six real-valued and one complex-valued array.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.mcc_setup`.
    """

    def base_usage(self) -> int:
        return (
            self.real_array_size * self.float_nbytes * 2
            + self.complex_array_size * self.complex_nbytes * 3
        )

    def per_fork(self) -> int:
        return (
            self.real_array_size * self.float_nbytes * 6
            + self.complex_array_size * self.complex_nbytes
        )
class MaxScoreOverRotationsMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for the MaxScoreOverRotations analyzer.

    Two real-valued arrays are counted once; nothing is added per fork.

    See Also
    --------
    :py:class:`tme.analyzer.MaxScoreOverRotations`.
    """

    def base_usage(self) -> int:
        return self.real_array_size * self.float_nbytes * 2

    def per_fork(self) -> int:
        # No per-fork overhead is attributed to this analyzer.
        return 0
class PeakCallerMaximumFilterMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for the PeakCallerMaximumFilter analyzer.

    One real-valued array is counted for the base allocation and one more
    per fork.

    See Also
    --------
    :py:class:`tme.analyzer.PeakCallerMaximumFilter`.
    """

    def base_usage(self) -> int:
        return self.real_array_size * self.float_nbytes

    def per_fork(self) -> int:
        return self.real_array_size * self.float_nbytes
class CupyBackendMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CupyBackend.

    Only a fixed backend overhead is counted; nothing is added per fork.

    See Also
    --------
    :py:class:`tme.backends.CupyBackend`.
    """

    def base_usage(self) -> int:
        # FFT plans, overhead from assigning FFT result, rotation interpolation.
        # NOTE(review): unlike the other estimators in this module, element
        # counts and byte widths are cross-paired here (real-sized arrays of
        # complex_nbytes, complex-sized arrays of float_nbytes). This may be
        # deliberate or a swap; confirm against the backend before changing.
        return (
            self.complex_array_size * self.float_nbytes * 2
            + self.real_array_size * self.complex_nbytes * 3
        )

    def per_fork(self) -> int:
        # Backend overhead is not attributed per fork.
        return 0
def _compute_convolution_shapes(
    arr1_shape: Tuple[int], arr2_shape: Tuple[int]
) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]:
    """
    Compute the regular, FFT-optimized and Fourier-space convolution shapes.

    Parameters
    ----------
    arr1_shape : tuple of int
        Shape of the first array.
    arr2_shape : tuple of int
        Shape of the second array.

    Returns
    -------
    tuple
        ``(convolution_shape, fast_shape, fast_ft_shape)``:

        - ``convolution_shape``: NumPy array equal to
          ``arr1_shape + arr2_shape - 1`` elementwise.
        - ``fast_shape``: list with each entry of ``convolution_shape``
          rounded up by ``next_fast_len``.
        - ``fast_ft_shape``: list equal to ``fast_shape`` except the last
          axis, which becomes ``n // 2 + 1`` (see :py:meth:`build_fft`).
    """
    full_shape = np.add(arr1_shape, arr2_shape) - 1
    padded_shape = list(map(next_fast_len, full_shape))
    # Only the last axis is halved; all leading axes keep their padded size.
    halved_shape = [*padded_shape[:-1], padded_shape[-1] // 2 + 1]
    return full_shape, padded_shape, halved_shape
# Maps a name to its memory estimator class. Matching methods, analyzers and
# backends share this single lookup table; both the "cupy" and "pytorch"
# backend names resolve to CupyBackendMemoryUsage.
MATCHING_MEMORY_REGISTRY = dict(
    CC=CCMemoryUsage,
    LCC=LCCMemoryUsage,
    CORR=CORRMemoryUsage,
    CAM=CAMMemoryUsage,
    MCC=MCCMemoryUsage,
    FLCSphericalMask=FLCSphericalMaskMemoryUsage,
    FLC=FLCMemoryUsage,
    MaxScoreOverRotations=MaxScoreOverRotationsMemoryUsage,
    PeakCallerMaximumFilter=PeakCallerMaximumFilterMemoryUsage,
    cupy=CupyBackendMemoryUsage,
    pytorch=CupyBackendMemoryUsage,
)
def estimate_ram_usage(
    shape1: Tuple[int, ...],
    shape2: Tuple[int, ...],
    matching_method: str,
    ncores: int,
    analyzer_method: str = None,
    backend: str = None,
    float_nbytes: int = 4,
    complex_nbytes: int = 8,
    integer_nbytes: int = 4,
) -> int:
    """
    Estimate the RAM usage for a given convolution operation based on input shapes,
    matching_method, and number of cores.

    The estimate is computed per estimator as ``base_usage() + per_fork() * ncores``.
    It is summed over the estimator registered for ``matching_method`` and,
    if present in ``MATCHING_MEMORY_REGISTRY``, those for ``analyzer_method``
    and ``backend``.

    Parameters
    ----------
    shape1 : tuple of int
        The shape of the input target.
    shape2 : tuple of int
        The shape of the input template.
    matching_method : str
        The method used for the operation. Must be a key of
        ``MATCHING_MEMORY_REGISTRY``.
    ncores : int
        The number of CPU cores used for the operation. Per-fork usage is
        multiplied by this value.
    analyzer_method : str, optional
        The method used for score analysis. Names without a registered
        estimator (including None) are silently ignored.
    backend : str, optional
        Backend used for computation. Names without a registered estimator
        (including None) are silently ignored.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Returns
    -------
    int
        The estimated RAM usage for the operation in bytes.

    Raises
    ------
    ValueError
        If an unsupported matching_method is provided.

    Notes
    -----
    Residual memory from other objects that may remain allocated during
    template matching, e.g. the full sized target when using splitting,
    are not considered by this function.
    """
    if matching_method not in MATCHING_MEMORY_REGISTRY:
        raise ValueError(
            f"Supported options are {','.join(MATCHING_MEMORY_REGISTRY.keys())}"
        )

    # Only the padded real shape and its Fourier counterpart drive the estimate.
    _, fast_shape, ft_shape = _compute_convolution_shapes(shape1, shape2)
    estimator_kwargs = {
        "fast_shape": fast_shape,
        "ft_shape": ft_shape,
        "float_nbytes": float_nbytes,
        "complex_nbytes": complex_nbytes,
        "integer_nbytes": integer_nbytes,
    }

    nbytes = 0
    # matching_method is guaranteed to be registered by the check above;
    # analyzer_method and backend contribute only when they are registered.
    for name in (matching_method, analyzer_method, backend):
        estimator_class = MATCHING_MEMORY_REGISTRY.get(name, None)
        if estimator_class is None:
            continue
        estimator = estimator_class(**estimator_kwargs)
        nbytes += estimator.base_usage() + estimator.per_fork() * ncores

    # Return a plain Python int as annotated, even if an estimator produced
    # NumPy integers.
    return int(nbytes)