Source code for tme.matching_exhaustive

""" Implements cross-correlation based template matching using different metrics.

    Copyright (c) 2023 European Molecular Biology Laboratory

    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
import sys
import warnings
import traceback
from functools import wraps
from itertools import product
from typing import Callable, Tuple, Dict, Optional

from joblib import Parallel, delayed
from multiprocessing.managers import SharedMemoryManager

from .backends import backend as be
from .preprocessing import Compose
from .matching_utils import split_shape
from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
from .types import CallbackClass, MatchingData
from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY


def _handle_traceback(last_type, last_value, last_traceback):
    """
    Handle sys.exc_info().

    Parameters
    ----------
    last_type : type
        The type of the last exception.
    last_value :
        The value of the last exception.
    last_traceback : traceback
        Traceback call stack at the point where the Exception occured.

    Raises
    ------
    Exception
        Re-raises the last exception.
    """
    if last_type is None:
        return None
    traceback.print_tb(last_traceback)
    raise Exception(last_value)


def _wrap_backend(func):
    @wraps(func)
    def wrapper(*args, backend_name: str, backend_args: Dict, **kwargs):
        from tme.backends import backend as be

        be.change_backend(backend_name, **backend_args)
        return func(*args, **kwargs)

    return wrapper


def _setup_template_filter_apply_target_filter(
    matching_data: MatchingData,
    rfftn: Callable,
    irfftn: Callable,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    pad_template_filter: bool = True,
):
    """
    Create the template filter and apply the target filter to the target.

    A filter is only considered if the corresponding attribute of
    ``matching_data`` is a :py:class:`Compose` instance.

    Parameters
    ----------
    matching_data : MatchingData
        Source of target, template and their filters. Its ``_target`` is
        overwritten with the filtered target when a target filter is set, and
        ``_set_matching_dimension`` is called on it when a template filter is
        set and ``pad_template_filter`` is False.
    rfftn : Callable
        Forward transform, called as ``rfftn(real_array, complex_output)``.
    irfftn : Callable
        Inverse transform, called as ``irfftn(complex_array, real_output)``.
    fast_shape : tuple of int
        Shape used to pad the target prior to the forward transform, after
        entries zeroed by the batch mask have been removed.
    fast_ft_shape : tuple of int
        Shape of the buffer receiving the forward transform.
    pad_template_filter : bool, optional
        If False, the template filter is built for the template shape rather
        than for ``fast_shape``. Default is True.

    Returns
    -------
    backend array
        The template filter. An array of shape ``(1,)`` filled with ones if
        no template filter is set.
    """
    has_template_filter = isinstance(matching_data.template_filter, Compose)
    has_target_filter = isinstance(matching_data.target_filter, Compose)

    # Neutral filter, returned as is unless a template filter replaces it.
    template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
    if not (has_template_filter or has_target_filter):
        return template_filter

    # presumably 0 along batch axes and 1 elsewhere -- TODO confirm _batch_mask
    keep_axis = be.subtract(1, be.to_backend_array(matching_data._batch_mask))

    # Entries zeroed by the mask become 1 in the Fourier-space filter shape ...
    masked_ft_shape = be.multiply(be.to_backend_array(fast_ft_shape), keep_axis)
    filter_shape = tuple(int(n) if n != 0 else 1 for n in masked_ft_shape)

    # ... and are dropped entirely from the real-space shape.
    masked_shape = be.multiply(be.to_backend_array(fast_shape), keep_axis)
    fast_shape = tuple(int(n) for n in masked_shape if n != 0)

    template_filter_shape = fast_shape
    template_filter_ft_shape = filter_shape
    if has_template_filter and not pad_template_filter:
        # Build the filter on the unpadded template shape instead.
        template_filter_shape = matching_data._template.shape
        matching_data._set_matching_dimension(
            target_dims=matching_data._target_dims,
            template_dims=matching_data._template_dims,
        )
        template_filter_ft_shape = [
            int(n) for n in matching_data._output_template_shape
        ]
        # The last axis is reduced to n // 2 + 1 entries.
        template_filter_ft_shape[-1] = template_filter_ft_shape[-1] // 2 + 1

    padded_target = be.topleft_pad(matching_data.target, fast_shape)
    target_ft = rfftn(padded_target, be.zeros(fast_ft_shape, be._complex_dtype))

    if has_template_filter:
        filter_output = matching_data.template_filter(
            shape=template_filter_shape,
            return_real_fourier=True,
            shape_is_real_fourier=False,
            data_rfft=target_ft,
            batch_dimension=matching_data._target_dims,
        )
        template_filter = be.reshape(
            filter_output["data"], template_filter_ft_shape
        )

    if not has_target_filter:
        return template_filter

    target_weights = matching_data.target_filter(
        shape=fast_shape,
        return_real_fourier=True,
        shape_is_real_fourier=False,
        data_rfft=target_ft,
        weight_type=None,
        batch_dimension=matching_data._target_dims,
    )["data"]
    target_weights = be.reshape(target_weights, filter_shape)

    # Filter in Fourier space, transform back and restore the target's shape.
    target_ft = be.multiply(target_ft, target_weights, out=target_ft)
    padded_target = irfftn(target_ft, padded_target)
    matching_data._target = be.topleft_pad(
        padded_target, matching_data.target.shape
    )

    return template_filter


def device_memory_handler(func: Callable):
    """
    Decorator providing a SharedMemory handler and a device context.

    The wrapped function is executed inside a ``SharedMemoryManager`` context,
    which is passed to it as ``shared_memory_handler``, and inside
    ``be.set_device(gpu_index)``. ``gpu_index`` is taken from the keyword
    arguments (default 0) and is not forwarded to ``func``.

    Exceptions raised during the call are printed and then handed to
    ``_handle_traceback``, which prints the traceback and raises an
    ``Exception``.

    Parameters
    ----------
    func : Callable
        Function to decorate. Has to accept a ``shared_memory_handler``
        keyword argument.

    Returns
    -------
    Callable
        The decorated function.
    """

    @wraps(func)
    def inner_function(*args, **kwargs):
        return_value = None
        # Start without a recorded error. Seeding these from sys.exc_info()
        # would pick up an unrelated exception that the caller is currently
        # handling and re-raise it even though ``func`` succeeded.
        last_type, last_value, last_traceback = None, None, None
        try:
            with SharedMemoryManager() as smh:
                gpu_index = kwargs.pop("gpu_index", 0)
                with be.set_device(gpu_index):
                    return_value = func(shared_memory_handler=smh, *args, **kwargs)
        except Exception as e:
            print(e)
            last_type, last_value, last_traceback = sys.exc_info()
        finally:
            # No-op when nothing was recorded, raises otherwise.
            _handle_traceback(last_type, last_value, last_traceback)
        return return_value

    return inner_function


@device_memory_handler
def scan(
    matching_data: MatchingData,
    matching_setup: Callable,
    matching_score: Callable,
    n_jobs: int = 4,
    callback_class: CallbackClass = None,
    callback_class_args: Dict = {},
    pad_fourier: bool = True,
    pad_template_filter: bool = True,
    interpolation_order: int = 3,
    jobs_per_callback_class: int = 8,
    shared_memory_handler=None,
) -> Optional[Tuple]:
    """
    Run template matching.

    .. warning:: ``matching_data`` might be altered or destroyed during
        computation.

    Parameters
    ----------
    matching_data : :py:class:`tme.matching_data.MatchingData`
        Template matching data.
    matching_setup : Callable
        Function pointer to setup function.
    matching_score : Callable
        Function pointer to scoring function.
    n_jobs : int, optional
        Number of parallel jobs. Default is 4.
    callback_class : type, optional
        Analyzer class pointer to operate on computed scores.
    callback_class_args : dict, optional
        Arguments passed to the callback_class. They take precedence over the
        arguments assembled by this function. Default is an empty dictionary.
    pad_fourier : bool, optional
        Whether to pad target and template to the full convolution shape.
    pad_template_filter : bool, optional
        Whether to pad potential template filters to the full convolution
        shape.
    interpolation_order : int, optional
        Order of spline interpolation for rotations.
    jobs_per_callback_class : int, optional
        How many jobs should be processed by a single callback_class
        instance, if one is provided. Reset to 1 when the ``shared``
        attribute of ``callback_class`` is truthy or missing.
    shared_memory_handler : type, optional
        Manager for shared memory objects, None by default.

    Returns
    -------
    Optional[Tuple]
        The merged results from callback_class if provided otherwise None.

    Examples
    --------
    Schematically, :py:meth:`scan` is identical to :py:meth:`scan_subsets`,
    with the distinction that the objects contained in ``matching_data`` are
    not split and the search is only parallelized over angles.
    Assuming you have followed the example in :py:meth:`scan_subsets`,
    :py:meth:`scan` can be invoked like so

    >>> from tme.matching_exhaustive import scan
    >>> results = scan(
    >>>    matching_data = matching_data,
    >>>    matching_score = matching_score,
    >>>    matching_setup = matching_setup,
    >>>    callback_class = callback_class,
    >>>    callback_class_args = callback_class_args,
    >>> )
    """
    matching_data.to_backend()
    padding = matching_data.fourier_padding(pad_fourier=pad_fourier)
    conv_shape, fast_shape, fast_ft_shape, fourier_shift = padding
    template_shape = matching_data.template.shape

    rfftn, irfftn = be.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=be._float_dtype,
        complex_dtype=be._complex_dtype,
    )
    raw_filter = _setup_template_filter_apply_target_filter(
        matching_data=matching_data,
        rfftn=rfftn,
        irfftn=irfftn,
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        pad_template_filter=pad_template_filter,
    )
    template_filter = be.astype(be.to_backend_array(raw_filter), be._float_dtype)

    setup = matching_setup(
        rfftn=rfftn,
        irfftn=irfftn,
        template=matching_data.template,
        template_filter=template_filter,
        template_mask=matching_data.template_mask,
        target=matching_data.target,
        target_mask=matching_data.target_mask,
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        shared_memory_handler=shared_memory_handler,
    )
    # The transform functions are not needed past this point.
    rfftn = irfftn = None
    setup["interpolation_order"] = interpolation_order
    setup["template_filter"] = be.to_sharedarr(template_filter, shared_memory_handler)

    translation_offset = be.to_backend_array(matching_data._translation_offset)
    if getattr(matching_data, "_is_padded", False):
        convolution_mode = "valid"
    else:
        convolution_mode = "same"

    default_callback_args = {
        "offset": be.astype(translation_offset, be._int_dtype),
        "thread_safe": n_jobs > 1,
        "fourier_shift": fourier_shift,
        "convolution_mode": convolution_mode,
        "targetshape": matching_data.target.shape,
        "templateshape": template_shape,
        "convolution_shape": conv_shape,
        "fast_shape": fast_shape,
        "indices": getattr(matching_data, "indices", None),
        "shared_memory_handler": shared_memory_handler,
        "only_unique_rotations": True,
    }
    # User supplied arguments override the defaults assembled above.
    default_callback_args.update(callback_class_args)

    matching_data._free_data()
    be.free_cache()

    # For highly parallel jobs, blocking in certain analyzers becomes a bottleneck
    if getattr(callback_class, "shared", True):
        jobs_per_callback_class = 1
    n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)

    if callback_class is None:
        callback_classes = [None] * n_callback_classes
    else:
        callback_classes = [
            callback_class(shape=fast_shape, **default_callback_args)
            for _ in range(n_callback_classes)
        ]

    # One job per rotation subset; analyzers are assigned round-robin.
    rotation_subsets = matching_data._split_rotations_on_jobs(n_jobs)
    job_outputs = Parallel(n_jobs=n_jobs)(
        delayed(_wrap_backend(matching_score))(
            backend_name=be._backend_name,
            backend_args=be._backend_args,
            rotations=rotation_subset,
            callback=callback_classes[job_index % n_callback_classes],
            **setup,
        )
        for job_index, rotation_subset in enumerate(rotation_subsets)
    )

    callbacks = []
    for job_output in job_outputs:
        if job_output is None:
            continue
        callbacks.append(tuple(job_output._postprocess(**default_callback_args)))
    be.free_cache()

    if callback_class is None:
        return None
    return callback_class.merge(callbacks, **default_callback_args)


def scan_subsets(
    matching_data: MatchingData,
    matching_score: Callable,
    matching_setup: Callable,
    callback_class: CallbackClass = None,
    callback_class_args: Dict = {},
    job_schedule: Tuple[int] = (1, 1),
    target_splits: Dict = {},
    template_splits: Dict = {},
    pad_target_edges: bool = False,
    pad_fourier: bool = True,
    pad_template_filter: bool = True,
    interpolation_order: int = 3,
    jobs_per_callback_class: int = 8,
    backend_name: str = None,
    backend_args: Dict = {},
) -> Optional[Tuple]:
    """
    Wrapper around :py:meth:`scan` that supports matching on splits
    of ``matching_data``.

    Parameters
    ----------
    matching_data : :py:class:`tme.matching_data.MatchingData`
        MatchingData instance containing relevant data.
    matching_score : Callable
        Function pointer to scoring function.
    matching_setup : Callable
        Function pointer to setup function.
    callback_class : type, optional
        Analyzer class pointer to operate on computed scores.
    callback_class_args : dict, optional
        Arguments passed to the callback_class. Default is an empty dictionary.
    job_schedule : tuple of int, optional
        Job scheduling scheme, default is (1, 1). First value corresponds
        to the number of splits that are processed in parallel, the second
        to the number of angles evaluated in parallel on each split.
    target_splits : dict, optional
        Splits for target. Default is an empty dictionary, i.e. no splits.
        See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
    template_splits : dict, optional
        Splits for template. Default is an empty dictionary, i.e. no splits.
        See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
    pad_target_edges : bool, optional
        Whether to pad the target boundaries by half the template shape
        along each axis.
    pad_fourier : bool, optional
        Whether to pad target and template to the full convolution shape.
    pad_template_filter : bool, optional
        Whether to pad potential template filters to the full convolution
        shape.
    interpolation_order : int, optional
        Order of spline interpolation for rotations.
    jobs_per_callback_class : int, optional
        How many jobs should be processed by a single callback_class
        instance, if one is provided. Forwarded to :py:meth:`scan`.
    backend_name : str, optional
        Not used by this function. The name of the currently active backend
        is forwarded to the workers instead.
    backend_args : dict, optional
        Not used by this function. The arguments of the currently active
        backend are forwarded to the workers instead.

    Returns
    -------
    Optional[Tuple]
        The merged results from callback_class if provided otherwise None.

    Examples
    --------
    All data relevant to template matching will be contained in
    ``matching_data``, which is a
    :py:class:`tme.matching_data.MatchingData` instance and can be created
    like so

    >>> import numpy as np
    >>> from tme.matching_data import MatchingData
    >>> from tme.matching_utils import get_rotation_matrices
    >>> target = np.random.rand(50,40,60)
    >>> template = target[15:25, 10:20, 30:40]
    >>> matching_data = MatchingData(target, template)
    >>> matching_data.rotations = get_rotation_matrices(
    >>>    angular_sampling = 60, dim = target.ndim
    >>> )

    The template matching procedure is determined by ``matching_setup`` and
    ``matching_score``, which are unique to each score. In the following,
    we will be using the `FLCSphericalMask` score, which is composed of
    :py:meth:`tme.matching_scores.flcSphericalMask_setup` and
    :py:meth:`tme.matching_scores.corr_scoring`

    >>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
    >>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
    >>> matching_setup, matching_score = funcs

    Computed scores are flexibly analyzed by being passed through an
    analyzer. In the following, we will use
    :py:class:`tme.analyzer.MaxScoreOverRotations` to aggregate scores over
    rotations

    >>> from tme.analyzer import MaxScoreOverRotations
    >>> callback_class = MaxScoreOverRotations
    >>> callback_class_args = {"score_threshold" : 0}

    In case the entire template matching problem does not fit into memory,
    we can determine the splitting procedure. In this case, we halve the
    first axis of the target once. Splitting and ``job_schedule`` is
    typically computed using
    :py:meth:`tme.matching_utils.compute_parallelization_schedule`.

    >>> target_splits = {0 : 1}

    Finally, we can perform template matching. Note that the data
    contained in ``matching_data`` will be destroyed when running the
    following

    >>> from tme.matching_exhaustive import scan_subsets
    >>> results = scan_subsets(
    >>>    matching_data = matching_data,
    >>>    matching_score = matching_score,
    >>>    matching_setup = matching_setup,
    >>>    callback_class = callback_class,
    >>>    callback_class_args = callback_class_args,
    >>>    target_splits = target_splits,
    >>> )

    The ``results`` tuple contains the output of the chosen analyzer.

    See Also
    --------
    :py:meth:`tme.matching_utils.compute_parallelization_schedule`
    """
    template_splits = split_shape(matching_data._template.shape, splits=template_splits)
    target_splits = split_shape(matching_data._target.shape, splits=target_splits)
    if (len(target_splits) > 1) and not pad_target_edges:
        warnings.warn(
            "Target splitting without padding target edges leads to unreliable "
            "similarity estimates around the split border."
        )
    # Every target split is matched against every template split.
    splits = tuple(product(target_splits, template_splits))

    outer_jobs, inner_jobs = job_schedule
    target_pad = matching_data.target_padding(pad_target=pad_target_edges)
    if hasattr(be, "scan"):
        # The backend ships its own scan implementation; delegate to it.
        corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
        results = be.scan(
            matching_data=matching_data,
            splits=splits,
            n_jobs=outer_jobs,
            rotate_mask=matching_score != corr_scoring,
            callback_class=callback_class,
        )
    else:
        results = Parallel(n_jobs=outer_jobs)(
            delayed(_wrap_backend(scan))(
                backend_name=be._backend_name,
                backend_args=be._backend_args,
                matching_data=matching_data.subset_by_slice(
                    target_slice=target_split,
                    target_pad=target_pad,
                    template_slice=template_split,
                ),
                matching_score=matching_score,
                matching_setup=matching_setup,
                n_jobs=inner_jobs,
                callback_class=callback_class,
                callback_class_args=callback_class_args,
                interpolation_order=interpolation_order,
                pad_fourier=pad_fourier,
                # Previously dropped, so the documented parameter had no effect.
                jobs_per_callback_class=jobs_per_callback_class,
                # Consumed by the device handler wrapping ``scan``.
                gpu_index=index % outer_jobs,
                pad_template_filter=pad_template_filter,
            )
            for index, (target_split, template_split) in enumerate(splits)
        )

    matching_data._free_data()
    if callback_class is not None:
        return callback_class.merge(results, **callback_class_args)
    return None


def register_matching_exhaustive(
    matching: str,
    matching_setup: Callable,
    matching_scoring: Callable,
    memory_class: MatchingMemoryUsage,
) -> None:
    """
    Add a new matching scheme to the registries.

    The setup and scoring functions are stored as a tuple in
    ``MATCHING_EXHAUSTIVE_REGISTER`` and the memory class in
    ``MATCHING_MEMORY_REGISTRY``, both under the key ``matching``.

    Parameters
    ----------
    matching : str
        Name of the matching method.
    matching_setup : Callable
        Corresponding setup function.
    matching_scoring : Callable
        Corresponding scoring function.
    memory_class : MatchingMemoryUsage
        Child of :py:class:`tme.memory.MatchingMemoryUsage`.

    Raises
    ------
    ValueError
        If a function with the name ``matching`` already exists in the
        registry, or if ``memory_class`` is no child of
        :py:class:`tme.memory.MatchingMemoryUsage`.
    """
    name_taken = matching in MATCHING_EXHAUSTIVE_REGISTER
    if name_taken:
        raise ValueError(f"A method with name '{matching}' is already registered.")

    is_memory_class = issubclass(memory_class, MatchingMemoryUsage)
    if not is_memory_class:
        raise ValueError(f"{memory_class} is not a subclass of {MatchingMemoryUsage}.")

    # Both checks passed, record the scheme in both registries.
    registered_functions = (matching_setup, matching_scoring)
    MATCHING_EXHAUSTIVE_REGISTER[matching] = registered_functions
    MATCHING_MEMORY_REGISTRY[matching] = memory_class