Source code for tme.matching_scores

""" Implements a range of cross-correlation coefficients.

    Copyright (c) 2023-2024 European Molecular Biology Laboratory

    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import warnings
from typing import Callable, Tuple, Dict, Optional

import numpy as np
from scipy.ndimage import laplace

from .backends import backend as be
from .types import CallbackClass, BackendArray, shm_type
from .matching_utils import (
    conditional_execute,
    identity,
    normalize_template,
    _normalize_template_overflow_safe,
)


def _shape_match(shape1: Tuple[int], shape2: Tuple[int]) -> bool:
    """
    Determine whether ``shape1`` is equal to ``shape2``.

    Parameters
    ----------
    shape1, shape2 : tuple of ints
        Shapes to compare.

    Returns
    -------
    Bool
        ``shape1`` is equal to ``shape2``.
    """
    if len(shape1) != len(shape2):
        return False
    return shape1 == shape2


def _setup_template_filtering(
    forward_ft_shape: Tuple[int],
    inverse_ft_shape: Tuple[int],
    template_shape: Tuple[int],
    template_filter: BackendArray,
    rfftn: Optional[Callable] = None,
    irfftn: Optional[Callable] = None,
) -> Callable:
    """
    Configure template filtering function for Fourier transforms.

    Parameters
    ----------
    forward_ft_shape : tuple of ints
        Shape for the forward Fourier transform.
    inverse_ft_shape : tuple of ints
        Shape for the inverse Fourier transform.
    template_shape : tuple of ints
        Shape of the template to be filtered.
    template_filter : BackendArray
        Precomputed filter to apply in the frequency domain.
    rfftn : Callable, optional
        Real-to-complex FFT function. Built via ``be.build_fft`` if omitted.
    irfftn : Callable, optional
        Complex-to-real inverse FFT function. Built via ``be.build_fft``
        if omitted.

    Returns
    -------
    Callable
        Filter function with parameters template, ft_temp and template_filter.

    Notes
    -----
        If ``template_filter`` holds a single element, no filtering is
        performed and a pass-through function is returned.

        If the shape of template_filter does not match inverse_ft_shape
        the template is assumed to be padded and cropped back to template_shape
        prior to filter application. In that case the supplied FFT functions
        cannot be used and dedicated ones matching the cropped shapes are built.
    """
    # A filter with a single element carries no information, skip filtering.
    if be.size(template_filter) == 1:
        return conditional_execute(identity, identity, False)

    shape_mismatch = not _shape_match(template_filter.shape, inverse_ft_shape)
    if shape_mismatch:
        forward_ft_shape = template_shape
        inverse_ft_shape = template_filter.shape

    # Reuse the caller's FFT functions whenever they fit. New ones are only
    # required if none were passed or if they were planned for other shapes.
    if rfftn is None or irfftn is None or shape_mismatch:
        rfftn, irfftn = be.build_fft(
            fast_shape=forward_ft_shape,
            fast_ft_shape=inverse_ft_shape,
            real_dtype=be._float_dtype,
            complex_dtype=be._complex_dtype,
            inverse_fast_shape=forward_ft_shape,
        )

    # Default case, all shapes are correctly matched
    def _apply_template_filter(template, ft_temp, template_filter):
        ft_temp = rfftn(template, ft_temp)
        ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
        return irfftn(ft_temp, template)

    if not shape_mismatch:
        return _apply_template_filter

    # Template is padded, filter is not. Crop and assign for continuous arrays
    real_subset = tuple(slice(0, x) for x in forward_ft_shape)
    _template = be.zeros(forward_ft_shape, be._float_dtype)
    _ft_temp = be.zeros(inverse_ft_shape, be._complex_dtype)

    def _apply_filter_shape_mismatch(template, ft_temp, template_filter):
        # The passed ft_temp has the padded shape and is deliberately unused,
        # the private buffers allocated above match the cropped template.
        _template[:] = template[real_subset]
        template[real_subset] = _apply_template_filter(
            _template, _ft_temp, template_filter
        )
        return template

    return _apply_filter_shape_mismatch


def cc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup function for computing an unnormalized cross-correlation between
    ``target`` (f) and ``template`` (g)

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    Notes
    -----
    To be used with :py:meth:`corr_scoring`. ``irfftn`` and additional keyword
    arguments are accepted for a uniform setup interface but are not used.
    """
    # Fourier transform of the target padded to the fast convolution shape
    padded_target = be.topleft_pad(target, fast_shape)
    ft_target = rfftn(padded_target, be.zeros(fast_ft_shape, be._complex_dtype))

    # Single-element placeholders: subtract zero, multiply by one
    zero_numerator = be.zeros(1, be._float_dtype)
    unit_inv_denominator = be.zeros(1, be._float_dtype) + 1

    return {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(template, shared_memory_handler),
        "ft_target": be.to_sharedarr(ft_target, shared_memory_handler),
        "inv_denominator": be.to_sharedarr(
            unit_inv_denominator, shared_memory_handler
        ),
        "numerator": be.to_sharedarr(zero_numerator, shared_memory_handler),
    }
def lcc_setup(target: BackendArray, template: BackendArray, **kwargs) -> Dict:
    """
    Setup function for computing a laplace cross-correlation between
    ``target`` (f) and ``template`` (g)

    .. math::

        \\mathcal{F}^{-1}(\\mathcal{F}(\\nabla^{2}f) \\cdot \\mathcal{F}(\\nabla^{2} g)^*)

    Notes
    -----
    To be used with :py:meth:`corr_scoring`. The Laplace filter is evaluated
    with scipy on numpy copies of the inputs using wrap-around boundaries,
    the remaining setup is delegated to :py:meth:`cc_setup`.
    """
    target_np = be.to_numpy_array(target)
    template_np = be.to_numpy_array(template)

    kwargs.update(
        target=be.to_backend_array(laplace(target_np, mode="wrap")),
        template=be.to_backend_array(laplace(template_np, mode="wrap")),
    )
    return cc_setup(**kwargs)
def corr_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    template_filter: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup for computing a normalized cross-correlation between a
    ``target`` (f), a ``template`` (g) given ``template_mask`` (m)

    .. math::

        \\frac{CC(f,g) - \\overline{g} \\cdot CC(f, m)}
        {(CC(f^2, m) - \\frac{CC(f, m)^2}{N_g}) \\cdot \\sigma_{g}},

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.

    In this implementation :math:`N_g` is the number of elements of the
    template array, the template term is the mask-weighted sum of squared
    deviations from the template mean, and the square root of the product
    of both denominator terms is taken. The returned ``inv_denominator``
    is the reciprocal of that quantity and is set to zero wherever it does
    not exceed the floating point epsilon.

    A warning is emitted if ``template_filter`` holds more than one element,
    because the filter is not considered in the precomputed terms.

    References
    ----------
    .. [1] Lewis P. J. Fast Normalized Cross-Correlation, Industrial Light and Magic.
    """
    target_pad = be.topleft_pad(target, fast_shape)

    # The exact composition of the denominator is debatable
    # scikit-image match_template multiplies the running sum of the target
    # with a scaling factor derived from the template. This is probably appropriate
    # in pattern matching situations where the template exists in the target
    ft_window = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_window = rfftn(be.topleft_pad(template_mask, fast_shape), ft_window)
    ft_target = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_target2 = be.zeros(fast_ft_shape, be._complex_dtype)
    denominator = be.zeros(fast_shape, be._float_dtype)
    window_sum = be.zeros(fast_shape, be._float_dtype)

    # denominator <- CC(f^2, m), window_sum <- CC(f, m)
    ft_target = rfftn(target_pad, ft_target)
    ft_target2 = rfftn(be.square(target_pad), ft_target2)
    ft_target2 = be.multiply(ft_target2, ft_window, out=ft_target2)
    denominator = irfftn(ft_target2, denominator)
    ft_window = be.multiply(ft_target, ft_window, out=ft_window)
    window_sum = irfftn(ft_window, window_sum)

    # Drop references to intermediates that are no longer needed
    target_pad, ft_target2, ft_window = None, None, None

    # TODO: Factor in template_filter here
    if be.size(template_filter) != 1:
        warnings.warn(
            "CORR scores obtained with template_filter are not correctly scaled. "
            "Please use a different score or consider only relative peak heights."
        )

    # Masks with two-byte datatypes are summed in a wider datatype
    n_obs, norm_func = be.sum(template_mask), normalize_template
    if be.datatype_bytes(template_mask.dtype) == 2:
        norm_func = _normalize_template_overflow_safe
        n_obs = be.sum(be.astype(template_mask, be._overflow_safe_dtype))

    template = norm_func(template, template_mask, n_obs)
    template_mean = be.sum(be.multiply(template, template_mask))
    template_mean = be.divide(template_mean, n_obs)
    template_ssd = be.sum(be.square(template - template_mean) * template_mask)

    template_volume = np.prod(tuple(int(x) for x in template.shape))
    template = be.multiply(template, template_mask, out=template)

    # The numerator has to be computed before window_sum is squared in-place
    numerator = be.multiply(window_sum, template_mean)
    window_sum = be.square(window_sum, out=window_sum)
    window_sum = be.divide(window_sum, template_volume, out=window_sum)
    denominator = be.subtract(denominator, window_sum, out=denominator)
    denominator = be.multiply(denominator, template_ssd, out=denominator)
    # Clamp negative values before taking the square root
    denominator = be.maximum(denominator, 0, out=denominator)
    denominator = be.sqrt(denominator, out=denominator)

    # Invert the denominator. Entries not exceeding eps are replaced by one
    # for the division and set to zero afterwards.
    mask = denominator > be.eps(be._float_dtype)
    denominator = be.multiply(denominator, mask, out=denominator)
    denominator = be.add(denominator, ~mask, out=denominator)
    denominator = be.divide(1, denominator, out=denominator)
    denominator = be.multiply(denominator, mask, out=denominator)

    ret = {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(template, shared_memory_handler),
        "ft_target": be.to_sharedarr(ft_target, shared_memory_handler),
        "inv_denominator": be.to_sharedarr(denominator, shared_memory_handler),
        "numerator": be.to_sharedarr(numerator, shared_memory_handler),
    }

    return ret
def cam_setup(template: BackendArray, target: BackendArray, **kwargs) -> Dict:
    """
    Like :py:meth:`corr_setup` but with standardized ``target``, ``template``

    .. math::

        f' = \\frac{f - \\overline{f}}{\\sigma_f}.

    Notes
    -----
    To be used with :py:meth:`corr_scoring`.
    """

    def _standardize(arr):
        # Shift to zero mean and scale to unit standard deviation
        return (arr - be.mean(arr)) / be.std(arr)

    standardized_template = _standardize(template)
    standardized_target = _standardize(target)
    return corr_setup(
        template=standardized_template, target=standardized_target, **kwargs
    )
def flc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup function for :py:meth:`flc_scoring`.

    Precomputes the Fourier transforms of the padded target and of its
    square, and normalizes the template with respect to ``template_mask``.
    """
    padded = be.topleft_pad(target, fast_shape)

    # Transform the target first, the padded array is squared in-place after
    ft_target = rfftn(padded, be.zeros(fast_ft_shape, be._complex_dtype))
    padded = be.square(padded, out=padded)
    ft_target2 = rfftn(padded, be.zeros(fast_ft_shape, be._complex_dtype))

    mask_sum = be.sum(template_mask)
    template = normalize_template(template, template_mask, mask_sum)

    return {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(template, shared_memory_handler),
        "template_mask": be.to_sharedarr(template_mask, shared_memory_handler),
        "ft_target": be.to_sharedarr(ft_target, shared_memory_handler),
        "ft_target2": be.to_sharedarr(ft_target2, shared_memory_handler),
    }
def flcSphericalMask_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    target: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: type,
    **kwargs,
) -> Dict:
    """
    Setup for :py:meth:`corr_scoring`, like :py:meth:`flc_setup` but for rotation
    invariant masks.

    The mask is not rotated during scoring, so the denominator is computed
    once here and returned in inverted form as ``inv_denominator``. The
    returned ``numerator`` is a single-element zero array.
    """
    # Masks with two-byte datatypes are summed in a wider datatype
    n_obs, norm_func = be.sum(template_mask), normalize_template
    if be.datatype_bytes(template_mask.dtype) == 2:
        norm_func = _normalize_template_overflow_safe
        n_obs = be.sum(be.astype(template_mask, be._overflow_safe_dtype))

    target_pad = be.topleft_pad(target, fast_shape)
    temp = be.zeros(fast_shape, be._float_dtype)
    temp2 = be.zeros(fast_shape, be._float_dtype)
    numerator = be.zeros(1, be._float_dtype)
    ft_target = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_template_mask = be.zeros(fast_ft_shape, be._complex_dtype)
    ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)

    template = norm_func(template, template_mask, n_obs)
    ft_template_mask = rfftn(
        be.topleft_pad(template_mask, fast_shape), ft_template_mask
    )

    # E(X^2) - E(X)^2
    # temp2 <- CC(f^2, m) / n_obs. ft_target temporarily holds the transform
    # of the squared target and is overwritten with the target transform below.
    ft_target = rfftn(be.square(target_pad), ft_target)
    ft_temp = be.multiply(ft_target, ft_template_mask, out=ft_temp)
    temp2 = irfftn(ft_temp, temp2)
    temp2 = be.divide(temp2, n_obs, out=temp2)

    # temp <- (CC(f, m) / n_obs)^2
    ft_target = rfftn(target_pad, ft_target)
    ft_temp = be.multiply(ft_target, ft_template_mask, out=ft_temp)
    temp = irfftn(ft_temp, temp)
    temp = be.divide(temp, n_obs, out=temp)
    temp = be.square(temp, out=temp)

    # Clamp negative values before taking the square root
    temp = be.subtract(temp2, temp, out=temp)
    temp = be.maximum(temp, 0.0, out=temp)
    temp = be.sqrt(temp, out=temp)

    # Avoid divide by zero warnings
    mask = temp > be.eps(be._float_dtype)
    temp = be.multiply(temp, mask * n_obs, out=temp)
    temp = be.add(temp, ~mask, out=temp)
    # The reciprocal is written into temp, temp2 is bound to that result
    temp2 = be.divide(1, temp, out=temp)
    temp2 = be.multiply(temp2, mask, out=temp2)

    ret = {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(template, shared_memory_handler),
        "template_mask": be.to_sharedarr(template_mask, shared_memory_handler),
        "ft_target": be.to_sharedarr(ft_target, shared_memory_handler),
        "inv_denominator": be.to_sharedarr(temp2, shared_memory_handler),
        "numerator": be.to_sharedarr(numerator, shared_memory_handler),
    }

    return ret
def mcc_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: BackendArray,
    template_mask: BackendArray,
    target: BackendArray,
    target_mask: BackendArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    shared_memory_handler: Callable,
    **kwargs,
) -> Dict:
    """
    Setup function for :py:meth:`mcc_scoring`.

    Precomputes the Fourier transforms of the masked target, its square and
    the target mask. Note that ``target`` is multiplied with the binarized
    ``target_mask`` in-place.
    """

    def _forward_ft(arr):
        # Transform into a freshly allocated complex buffer
        return rfftn(arr, be.zeros(fast_ft_shape, be._complex_dtype))

    target = be.multiply(target, target_mask > 0, out=target)
    padded_target = be.topleft_pad(target, fast_shape)

    ft_target = _forward_ft(padded_target)
    ft_target2 = _forward_ft(be.square(padded_target))
    ft_target_mask = _forward_ft(be.topleft_pad(target_mask, fast_shape))

    return {
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "template": be.to_sharedarr(template, shared_memory_handler),
        "template_mask": be.to_sharedarr(template_mask, shared_memory_handler),
        "ft_target": be.to_sharedarr(ft_target, shared_memory_handler),
        "ft_target2": be.to_sharedarr(ft_target2, shared_memory_handler),
        "ft_target_mask": be.to_sharedarr(ft_target_mask, shared_memory_handler),
    }
def corr_scoring(
    template: shm_type,
    template_filter: shm_type,
    ft_target: shm_type,
    inv_denominator: shm_type,
    numerator: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    template_mask: shm_type = None,
) -> Optional[CallbackClass]:
    """
    Calculates a normalized cross-correlation between a target f and a template g.

    .. math::

        (CC(f,g) - \\text{numerator}) \\cdot \\text{inv_denominator},

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    inv_denominator : Union[Tuple[type, tuple of ints, type], BackendArray]
        Inverse denominator data buffer, its shape and datatype.
    numerator : Union[Tuple[type, tuple of ints, type], BackendArray]
        Numerator data buffer, its shape, and its datatype.
    fast_shape: tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape: tuple of ints
        Data shape for the inverse Fourier transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
        Template mask data buffer, its shape and datatype, None by default.
        If given, the rotated template is normalized with respect to the
        mask for every rotation.

    Returns
    -------
    Optional[CallbackClass]
        ``callback`` if provided otherwise None.
    """
    template = be.from_sharedarr(template)
    ft_target = be.from_sharedarr(ft_target)
    inv_denominator = be.from_sharedarr(inv_denominator)
    numerator = be.from_sharedarr(numerator)
    template_filter = be.from_sharedarr(template_filter)

    # Template normalization is only requested if a mask was passed. Masks with
    # two-byte datatypes are summed in a wider datatype.
    norm_func, norm_template, mask_sum = normalize_template, False, 1
    if template_mask is not None:
        template_mask = be.from_sharedarr(template_mask)
        norm_template, mask_sum = True, be.sum(template_mask)
        if be.datatype_bytes(template_mask.dtype) == 2:
            norm_func = _normalize_template_overflow_safe
            mask_sum = be.sum(be.astype(template_mask, be._overflow_safe_dtype))

    # NOTE(review): conditional_execute is called with (func, flag) in the next
    # two lines but with (func, identity, flag) for the numerator and
    # denominator below. Only one positional layout can match the helper's
    # signature - confirm against tme.matching_utils.
    callback_func = conditional_execute(callback, callback is not None)
    norm_template = conditional_execute(norm_func, norm_template)
    # Numerator and denominator are meant to be applied only if they were
    # precomputed with the shape of the score array
    norm_numerator = conditional_execute(
        be.subtract, identity, _shape_match(numerator.shape, fast_shape)
    )
    norm_denominator = conditional_execute(
        be.multiply, identity, _shape_match(inv_denominator.shape, fast_shape)
    )

    # Work buffers that are reused across rotations
    arr = be.zeros(fast_shape, be._float_dtype)
    ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
    rfftn, irfftn = be.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=be._float_dtype,
        complex_dtype=be._complex_dtype,
        temp_real=arr,
        temp_fft=ft_temp,
    )

    template_filter_func = _setup_template_filtering(
        forward_ft_shape=fast_shape,
        inverse_ft_shape=fast_ft_shape,
        template_shape=template.shape,
        template_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
    )

    # Top-left region of the padded buffer with the shape of the template
    unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        arr = be.fill(arr, 0)
        arr, _ = be.rigid_transform(
            arr=template,
            rotation_matrix=rotation,
            out=arr,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )
        arr = template_filter_func(arr, ft_temp, template_filter)
        # Return value is discarded, the slice of arr is expected to be
        # modified in-place
        norm_template(arr[unpadded_slice], template_mask, mask_sum)

        # CC(f, g) via multiplication in Fourier space
        ft_temp = rfftn(arr, ft_temp)
        ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
        arr = irfftn(ft_temp, arr)

        arr = norm_numerator(arr, numerator, out=arr)
        arr = norm_denominator(arr, inv_denominator, out=arr)
        callback_func(arr, rotation_matrix=rotation)

    return callback
def flc_scoring(
    template: shm_type,
    template_mask: shm_type,
    ft_target: shm_type,
    ft_target2: shm_type,
    template_filter: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
) -> Optional[CallbackClass]:
    """
    Computes a normalized cross-correlation between ``target`` (f),
    ``template`` (g), and ``template_mask`` (m)

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        },

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and Nm is the sum of the rotated mask m.

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template mask data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    ft_target2 : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed squared target data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    fast_shape : tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape : tuple of ints
        Data shape for the inverse Fourier transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.

    Returns
    -------
    Optional[CallbackClass]
        ``callback`` if provided otherwise None.

    References
    ----------
    .. [1]  Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
    """
    float_dtype, complex_dtype = be._float_dtype, be._complex_dtype
    template = be.from_sharedarr(template)
    template_mask = be.from_sharedarr(template_mask)
    ft_target = be.from_sharedarr(ft_target)
    ft_target2 = be.from_sharedarr(ft_target2)
    template_filter = be.from_sharedarr(template_filter)

    # Work buffers that are reused across rotations: arr holds the rotated
    # template and later the scores, temp the rotated mask and later CC(f, m),
    # temp2 holds CC(f^2, m).
    arr = be.zeros(fast_shape, float_dtype)
    temp = be.zeros(fast_shape, float_dtype)
    temp2 = be.zeros(fast_shape, float_dtype)
    ft_temp = be.zeros(fast_ft_shape, complex_dtype)
    ft_denom = be.zeros(fast_ft_shape, complex_dtype)

    rfftn, irfftn = be.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=float_dtype,
        complex_dtype=complex_dtype,
        temp_real=arr,
        temp_fft=ft_temp,
    )

    template_filter_func = _setup_template_filtering(
        forward_ft_shape=fast_shape,
        inverse_ft_shape=fast_ft_shape,
        template_shape=template.shape,
        template_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
    )

    eps = be.eps(float_dtype)
    callback_func = conditional_execute(callback, callback is not None)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        arr = be.fill(arr, 0)
        temp = be.fill(temp, 0)
        # Rotate template and mask jointly into arr and temp
        arr, temp = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotations[index],
            out=arr,
            out_mask=temp,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )

        # The mask sum is recomputed from the rotated mask for every rotation
        n_obs = be.sum(temp)
        arr = template_filter_func(arr, ft_temp, template_filter)
        arr = normalize_template(arr, temp, n_obs)

        # ft_temp holds the transform of the rotated mask, which is needed for
        # both denominator terms. temp is overwritten with CC(f, m).
        ft_temp = rfftn(temp, ft_temp)
        ft_denom = be.multiply(ft_target, ft_temp, out=ft_denom)
        temp = irfftn(ft_denom, temp)
        ft_denom = be.multiply(ft_target2, ft_temp, out=ft_denom)
        temp2 = irfftn(ft_denom, temp2)

        # CC(f, g) of the target with the normalized rotated template
        ft_temp = rfftn(arr, ft_temp)
        ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
        arr = irfftn(ft_temp, arr)

        arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
        callback_func(arr, rotation_matrix=rotation)

    return callback
def mcc_scoring(
    template: shm_type,
    template_mask: shm_type,
    template_filter: shm_type,
    ft_target: shm_type,
    ft_target2: shm_type,
    ft_target_mask: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    overlap_ratio: float = 0.3,
) -> Optional[CallbackClass]:
    """
    Computes a normalized cross-correlation score between ``target`` (f),
    ``template`` (g), ``template_mask`` (m) and ``target_mask`` (t)

    .. math::

        \\frac{
            CC(f, g) - \\frac{CC(f, m) \\cdot CC(t, g)}{CC(t, m)}
        }{
        \\sqrt{
            (CC(f ^ 2, m) - \\frac{CC(f, m) ^ 2}{CC(t, m)}) \\cdot
            (CC(t, g^2) - \\frac{CC(t, g) ^ 2}{CC(t, m)})
            }
        },

    where

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*).

    Parameters
    ----------
    template : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template data buffer, its shape and datatype.
    template_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template mask data buffer, its shape and datatype.
    template_filter : Union[Tuple[type, tuple of ints, type], BackendArray]
        Template filter data buffer, its shape and datatype.
    ft_target : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target data buffer, its shape and datatype.
    ft_target2 : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed squared target data buffer, its shape and datatype.
    ft_target_mask : Union[Tuple[type, tuple of ints, type], BackendArray]
        Fourier transformed target mask data buffer, its shape and datatype.
    fast_shape: tuple of ints
        Data shape for the forward Fourier transform.
    fast_ft_shape: tuple of ints
        Data shape for the inverse Fourier transform.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        A callable for processing the result of each rotation.
    interpolation_order : int
        Spline order for template rotations.
    overlap_ratio : float, optional
        Required fractional mask overlap, 0.3 by default.

    Returns
    -------
    Optional[CallbackClass]
        ``callback`` if provided otherwise None.

    Notes
    -----
    Scores are set to zero at positions where the denominator is numerically
    indistinguishable from zero and at positions where the mask overlap is
    below ``overlap_ratio`` times the maximum overlap.

    References
    ----------
    .. [1]  Masked FFT registration, Dirk Padfield, CVPR 2010 conference
    .. [2]  https://scikit-image.org/docs/stable/api/skimage.registration.html
    """
    float_dtype, complex_dtype = be._float_dtype, be._complex_dtype
    template = be.from_sharedarr(template)
    target_ft = be.from_sharedarr(ft_target)
    target_ft2 = be.from_sharedarr(ft_target2)
    template_mask = be.from_sharedarr(template_mask)
    target_mask_ft = be.from_sharedarr(ft_target_mask)
    template_filter = be.from_sharedarr(template_filter)

    axes = tuple(range(template.ndim))
    eps = be.eps(float_dtype)

    # Allocate score and process specific arrays
    template_rot = be.zeros(fast_shape, float_dtype)
    mask_overlap = be.zeros(fast_shape, float_dtype)
    numerator = be.zeros(fast_shape, float_dtype)
    temp = be.zeros(fast_shape, float_dtype)
    temp2 = be.zeros(fast_shape, float_dtype)
    temp3 = be.zeros(fast_shape, float_dtype)
    temp_ft = be.zeros(fast_ft_shape, complex_dtype)

    rfftn, irfftn = be.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=float_dtype,
        complex_dtype=complex_dtype,
        temp_real=numerator,
        temp_fft=temp_ft,
    )

    template_filter_func = _setup_template_filtering(
        forward_ft_shape=fast_shape,
        inverse_ft_shape=fast_ft_shape,
        template_shape=template.shape,
        template_filter=template_filter,
        rfftn=rfftn,
        irfftn=irfftn,
    )

    callback_func = conditional_execute(callback, callback is not None)
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        template_rot = be.fill(template_rot, 0)
        temp = be.fill(temp, 0)
        # Rotate template and mask jointly into template_rot and temp
        be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation,
            out=template_rot,
            out_mask=temp,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )

        # Both helpers are expected to modify template_rot in-place
        template_filter_func(template_rot, temp_ft, template_filter)
        normalize_template(template_rot, temp, be.sum(temp))

        # temp2 = CC(t, g), numerator = CC(f, g)
        temp_ft = rfftn(template_rot, temp_ft)
        temp2 = irfftn(target_mask_ft * temp_ft, temp2)
        numerator = irfftn(target_ft * temp_ft, numerator)

        # temp template_mask_rot | temp_ft template_mask_rot_ft
        # Calculate overlap of masks at every point in the convolution.
        # Locations with low overlap should not be taken into account.
        temp_ft = rfftn(temp, temp_ft)
        mask_overlap = irfftn(temp_ft * target_mask_ft, mask_overlap)
        be.maximum(mask_overlap, eps, out=mask_overlap)
        # temp = CC(f, m)
        temp = irfftn(temp_ft * target_ft, temp)

        be.subtract(
            numerator,
            be.divide(be.multiply(temp, temp2), mask_overlap),
            out=numerator,
        )

        # temp_3 = fixed_denom
        be.multiply(temp_ft, target_ft2, out=temp_ft)
        temp3 = irfftn(temp_ft, temp3)
        be.subtract(temp3, be.divide(be.square(temp), mask_overlap), out=temp3)
        be.maximum(temp3, 0.0, out=temp3)

        # temp = moving_denom
        temp_ft = rfftn(be.square(template_rot), temp_ft)
        be.multiply(target_mask_ft, temp_ft, out=temp_ft)
        temp = irfftn(temp_ft, temp)

        be.subtract(temp, be.divide(be.square(temp2), mask_overlap), out=temp)
        be.maximum(temp, 0.0, out=temp)

        # temp_2 = denom
        be.multiply(temp3, temp, out=temp)
        be.sqrt(temp, temp2)

        # Pixels where `denom` is very small will introduce large
        # numbers after division. To get around this problem,
        # we zero-out problematic pixels. The denominator is set to one to
        # keep the division well-defined, the scores are zeroed afterwards.
        # Without the second step the raw numerator would be reported as
        # score at these positions.
        tol = 1e3 * eps * be.max(be.abs(temp2), axis=axes, keepdims=True)
        invalid = temp2 < tol

        temp2[invalid] = 1
        temp = be.divide(numerator, temp2, out=temp)
        temp[invalid] = 0.0
        temp = be.clip(temp, a_min=-1, a_max=1, out=temp)

        # Apply overlap ratio threshold
        number_px_threshold = overlap_ratio * be.max(
            mask_overlap, axis=axes, keepdims=True
        )
        temp[mask_overlap < number_px_threshold] = 0.0
        callback_func(temp, rotation_matrix=rotation)

    return callback
# Maps the name of each score to its (setup function, scoring function) pair.
# The dictionary returned by a setup function uses the parameter names of the
# scoring function it is paired with as keys.
MATCHING_EXHAUSTIVE_REGISTER = {
    "CC": (cc_setup, corr_scoring),
    "LCC": (lcc_setup, corr_scoring),
    "CORR": (corr_setup, corr_scoring),
    "CAM": (cam_setup, corr_scoring),
    "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
    "FLC": (flc_setup, flc_scoring),
    "MCC": (mcc_setup, mcc_scoring),
}