"""
Implements classes to analyze outputs from exhaustive template matching.
Copyright (c) 2023 European Molecular Biology Laboratory
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
from typing import Tuple, List, Dict
import numpy as np
from .base import AbstractAnalyzer
from ..types import BackendArray
from ._utils import cart_to_score
from ..backends import backend as be
from ..matching_utils import (
create_mask,
array_to_memmap,
apply_convolution_mode,
generate_tempfile_name,
)
__all__ = [
"MaxScoreOverRotations",
"MaxScoreOverRotationsConstrained",
"MaxScoreOverTranslations",
]
[docs]
class MaxScoreOverRotations(AbstractAnalyzer):
"""
Determine the rotation maximizing the score over all possible translations.
Parameters
----------
shape : tuple of int
Shape of array passed to :py:meth:`MaxScoreOverRotations.__call__`.
offset : BackendArray, optional
Coordinate origin considered during merging, zero by default.
score_threshold : float, optional
Minimum score to be considered, zero by default.
shm_handler : :class:`multiprocessing.managers.SharedMemoryManager`, optional
Shared memory manager, defaults to memory not being shared.
use_memmap : bool, optional
Memmap internal arrays, False by default.
thread_safe: bool, optional
Allow class to be modified by multiple processes, True by default.
inversion_mapping : bool, optional
Do not use rotation matrix bytestrings for intermediate data handling.
This is useful for GPU backend where analyzers are not shared across
devices and every rotation is only observed once. It is generally
safe to deactivate inversion mapping, but at a cost of performance.
Examples
--------
The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
instance
>>> from tme.analyzer import MaxScoreOverRotations
>>> analyzer = MaxScoreOverRotations(shape = (50, 50))
The following simulates a template matching run by creating random data for a range
of rotations and sending it to ``analyzer`` via its __call__ method
>> state = analyzer.init_state()
>>> for rotation_number in range(10):
>>> scores = np.random.rand(50,50)
>>> rotation = np.random.rand(scores.ndim, scores.ndim)
>>> state, analyzer(state, scores = scores, rotation_matrix = rotation)
The aggregated scores can be extracted by invoking the result method of
``analyzer``
>>> results = analyzer.result()
The ``results`` tuple contains (1) the maximum scores for each translation,
(2) an offset which is relevant when merging results from split template matching
using :py:meth:`MaxScoreOverRotations.merge`, (3) the rotation used to obtain a
score for a given translation, (4) a dictionary mapping indices used in (2) to
rotation matrices (2).
We can extract the ``optimal_score``, ``optimal_translation`` and ``optimal_rotation``
as follows
>>> optimal_score = results[0].max()
>>> optimal_translation = np.where(results[0] == results[0].max())
>>> optimal_rotation = results[2][optimal_translation]
The outlined procedure is a trivial method to identify high scoring peaks.
Alternatively, :py:class:`PeakCaller` offers a range of more elaborate approaches
that can be used.
"""
def __init__(
self,
shape: Tuple[int],
offset: BackendArray = None,
score_threshold: float = 0,
shm_handler: object = None,
use_memmap: bool = False,
inversion_mapping: bool = False,
**kwargs,
):
self._use_memmap = use_memmap
self._score_threshold = score_threshold
self._shape = tuple(int(x) for x in shape)
self._inversion_mapping = inversion_mapping
if offset is None:
offset = be.zeros(len(self._shape), be._int_dtype)
self._offset = be.astype(be.to_backend_array(offset), int)
@property
def shareable(self):
return True
[docs]
def init_state(self):
"""
Initialize the analysis state.
Returns
-------
tuple
Initial state tuple containing (scores, rotations, rotation_mapping) where:
- scores : BackendArray of shape `self._shape` filled with `score_threshold`.
- rotations : BackendArray of shape `self._shape` filled with -1.
- rotation_mapping : dict, empty mapping from rotation bytes to indices.
"""
scores = be.full(
shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
)
rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
return scores, rotations, {}
[docs]
def __call__(
self,
state: Tuple,
scores: BackendArray,
rotation_matrix: BackendArray,
) -> Tuple:
"""
Update the parameter store.
Parameters
----------
state : tuple
Current state tuple (scores, rotations, rotation_mapping) where:
- scores : BackendArray, current maximum scores.
- rotations : BackendArray, current rotation indices.
- rotation_mapping : dict, mapping from rotation bytes to indices.
scores : BackendArray
Array of new scores to update analyzer with.
rotation_matrix : BackendArray
Square matrix used to obtain the current rotation.
Returns
-------
tuple
Updated state tuple (scores, rotations, rotation_mapping).
"""
# be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
# If the analyzer is not shared and each rotation is unique, we can
# use index to rotation mapping and invert prior to merging.
prev_scores, rotations, rotation_mapping = state
rotation_index = len(rotation_mapping)
rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
if self._inversion_mapping:
rotation_mapping[rotation_index] = rotation_matrix
else:
rotation = be.tobytes(rotation_matrix)
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
scores, rotations = be.max_score_over_rotations(
scores=scores,
max_scores=prev_scores,
rotations=rotations,
rotation_index=rotation_index,
)
return scores, rotations, rotation_mapping
@staticmethod
def _invert_rmap(rotation_mapping: dict) -> dict:
"""
Invert dictionary from rotation matrix bytestrings mapping to rotation
indices ro rotation indices mapping to rotation matrices.
"""
new_map, ndim = {}, None
for k, v in rotation_mapping.items():
nbytes = be.datatype_bytes(be._float_dtype)
dtype = np.float32 if nbytes == 4 else np.float16
rmat = np.frombuffer(k, dtype=dtype)
if ndim is None:
ndim = int(np.sqrt(rmat.size))
new_map[v] = rmat.reshape(ndim, ndim)
return new_map
[docs]
def result(
self,
state,
targetshape: Tuple[int] = None,
templateshape: Tuple[int] = None,
convolution_shape: Tuple[int] = None,
fourier_shift: Tuple[int] = None,
convolution_mode: str = None,
**kwargs,
) -> Tuple:
"""
Finalize the analysis result with optional postprocessing.
Parameters
----------
state : tuple
Current state tuple (scores, rotations, rotation_mapping) where:
- scores : BackendArray, current maximum scores.
- rotations : BackendArray, current rotation indices.
- rotation_mapping : dict, mapping from rotation indices to matrices.
targetshape : Tuple[int], optional
Shape of the target for convolution mode correction.
templateshape : Tuple[int], optional
Shape of the template for convolution mode correction.
convolution_shape : Tuple[int], optional
Shape used for convolution.
fourier_shift : Tuple[int], optional.
Shift to apply for Fourier correction.
convolution_mode : str, optional
Convolution mode for padding correction.
**kwargs
Additional keyword arguments.
Returns
-------
tuple
Final result tuple (scores, offset, rotations, rotation_mapping).
"""
scores, rotations, rotation_mapping = state
# Apply postprocessing if parameters are provided
if fourier_shift is not None:
axis = tuple(i for i in range(len(fourier_shift)))
scores = be.roll(scores, shift=fourier_shift, axis=axis)
rotations = be.roll(rotations, shift=fourier_shift, axis=axis)
if convolution_mode is not None:
convargs = {
"s1": targetshape,
"s2": templateshape,
"convolution_mode": convolution_mode,
"convolution_shape": convolution_shape,
}
scores = apply_convolution_mode(scores, **convargs)
rotations = apply_convolution_mode(rotations, **convargs)
scores = be.to_numpy_array(scores)
rotations = be.to_numpy_array(rotations)
if self._use_memmap:
scores = array_to_memmap(scores)
rotations = array_to_memmap(rotations)
if self._inversion_mapping:
rotation_mapping = {be.tobytes(v): k for k, v in rotation_mapping.items()}
return (
scores,
be.to_numpy_array(self._offset),
rotations,
self._invert_rmap(rotation_mapping),
)
def _harmonize_states(states: List[Tuple]):
"""
Create consistent reference frame for merging different analyzer
instances, w.r.t. to rotations and output shape from different
splits of the target.
"""
new_rotation_mapping, out_shape = {}, None
for i in range(len(states)):
if states[i] is None:
continue
scores, offset, rotations, rotation_mapping = states[i]
if out_shape is None:
out_shape = np.zeros(scores.ndim, int)
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
new_param = {}
for key, value in rotation_mapping.items():
rotation_bytes = be.tobytes(value)
new_param[rotation_bytes] = key
if rotation_bytes not in new_rotation_mapping:
new_rotation_mapping[rotation_bytes] = len(new_rotation_mapping)
states[i] = (scores, offset, rotations, new_param)
out_shape = tuple(int(x) for x in out_shape)
return new_rotation_mapping, out_shape, states
[docs]
@classmethod
def merge(cls, results: List[Tuple], **kwargs) -> Tuple:
"""
Merge multiple instances of the current class.
Parameters
----------
results : list of tuple
List of instance's internal state created by applying `result`.
**kwargs : dict, optional
Optional keyword arguments.
Returns
-------
NDArray
Maximum score of each translation over all observed rotations.
NDArray
Translation offset, zero by default.
NDArray
Mapping between translations and rotation indices.
Dict
Mapping between rotations and rotation indices.
"""
use_memmap = kwargs.get("use_memmap", False)
if len(results) == 1:
ret = results[0]
if use_memmap:
scores, offset, rotations, rotation_mapping = ret
scores = array_to_memmap(scores)
rotations = array_to_memmap(rotations)
ret = (scores, offset, rotations, rotation_mapping)
return ret
# Determine output array shape and create consistent rotation map
master_rotation_mapping, out_shape, results = cls._harmonize_states(results)
if out_shape is None:
return None
scores_dtype = results[0][0].dtype
rotations_dtype = results[0][2].dtype
if use_memmap:
scores_out_filename = generate_tempfile_name()
rotations_out_filename = generate_tempfile_name()
scores_out = np.memmap(
scores_out_filename, mode="w+", shape=out_shape, dtype=scores_dtype
)
scores_out.fill(kwargs.get("score_threshold", 0))
scores_out.flush()
rotations_out = np.memmap(
rotations_out_filename,
mode="w+",
shape=out_shape,
dtype=rotations_dtype,
)
rotations_out.fill(-1)
rotations_out.flush()
else:
scores_out = np.full(
out_shape,
fill_value=kwargs.get("score_threshold", 0),
dtype=scores_dtype,
)
rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
for i in range(len(results)):
if results[i] is None:
continue
if use_memmap:
scores_out = np.memmap(
scores_out_filename,
mode="r+",
shape=out_shape,
dtype=scores_dtype,
)
rotations_out = np.memmap(
rotations_out_filename,
mode="r+",
shape=out_shape,
dtype=rotations_dtype,
)
scores, offset, rotations, rotation_mapping = results[i]
stops = np.add(offset, scores.shape).astype(int)
indices = tuple(slice(*pos) for pos in zip(offset, stops))
indices_update = scores > scores_out[indices]
scores_out[indices][indices_update] = scores[indices_update]
lookup_table = np.arange(
len(rotation_mapping) + 1, dtype=rotations_out.dtype
)
for key, value in rotation_mapping.items():
lookup_table[value] = master_rotation_mapping[key]
updated_rotations = rotations[indices_update]
if len(updated_rotations):
rotations_out[indices][indices_update] = lookup_table[updated_rotations]
if use_memmap:
scores._mmap.close()
rotations._mmap.close()
scores_out.flush()
rotations_out.flush()
scores_out, rotations_out = None, None
results[i] = None
scores, rotations = None, None
if use_memmap:
scores_out = np.memmap(
scores_out_filename, mode="r", shape=out_shape, dtype=scores_dtype
)
rotations_out = np.memmap(
rotations_out_filename,
mode="r",
shape=out_shape,
dtype=rotations_dtype,
)
return (
scores_out,
np.zeros(scores_out.ndim, dtype=int),
rotations_out,
cls._invert_rmap(master_rotation_mapping),
)
[docs]
class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
"""
Implements constrained template matching using rejection sampling.
Parameters
----------
cone_angle : float
Maximum accepted rotational deviation in degrees.
positions : BackendArray
Array of shape (n, d) with n seed point translations.
positions : BackendArray
Array of shape (n, d, d) with n seed point rotation matrices.
reference : BackendArray
Reference orientation of the template, wlog defaults to (0,0,1).
acceptance_radius : int or tuple of ints
Translational acceptance radius around seed point in voxels.
**kwargs : dict, optional
Keyword aguments passed to the constructor of :py:class:`MaxScoreOverRotations`.
"""
def __init__(
self,
cone_angle: float,
positions: BackendArray,
rotations: BackendArray,
reference: BackendArray = (0, 0, 1),
acceptance_radius: int = 10,
**kwargs,
):
MaxScoreOverRotations.__init__(self, **kwargs)
if not isinstance(acceptance_radius, (int, Tuple)):
raise ValueError("acceptance_radius needs to be of type int or tuple.")
if isinstance(acceptance_radius, int):
acceptance_radius = (
acceptance_radius,
acceptance_radius,
acceptance_radius,
)
acceptance_radius = tuple(int(x) for x in acceptance_radius)
self._cone_angle = float(np.radians(cone_angle))
self._cone_cutoff = float(np.tan(self._cone_angle))
self._reference = be.astype(
be.reshape(be.to_backend_array(reference), (-1,)), be._float_dtype
)
positions = be.astype(be.to_backend_array(positions), be._int_dtype)
ndim = positions.shape[1]
rotate_mask = len(set(acceptance_radius)) != 1
extend = max(acceptance_radius)
mask = create_mask(
mask_type="ellipse",
radius=acceptance_radius,
shape=tuple(2 * extend + 1 for _ in range(ndim)),
center=tuple(extend for _ in range(ndim)),
)
self._score_mask = be.astype(be.to_backend_array(mask), be._float_dtype)
# Map position from real space to shifted score space
lower_limit = be.to_backend_array(self._offset)
positions = be.subtract(positions, lower_limit)
positions, valid_positions = cart_to_score(
positions=positions,
fast_shape=kwargs.get("fast_shape", None),
targetshape=kwargs.get("targetshape", None),
templateshape=kwargs.get("templateshape", None),
fourier_shift=kwargs.get("fourier_shift", None),
convolution_mode=kwargs.get("convolution_mode", None),
convolution_shape=kwargs.get("convolution_shape", None),
)
self._positions = positions[valid_positions]
rotations = be.to_backend_array(rotations)[valid_positions]
ex = be.astype(be.to_backend_array((1, 0, 0)), be._float_dtype)
ey = be.astype(be.to_backend_array((0, 1, 0)), be._float_dtype)
ez = be.astype(be.to_backend_array((0, 0, 1)), be._float_dtype)
self._normals_x = (rotations @ ex[..., None])[..., 0]
self._normals_y = (rotations @ ey[..., None])[..., 0]
self._normals_z = (rotations @ ez[..., None])[..., 0]
# Periodic wrapping could be avoided by padding the target
shape = be.to_backend_array(self._shape)
starts = be.subtract(self._positions, extend)
ret, (n, d), mshape = [], self._positions.shape, self._score_mask.shape
if starts.shape[0] > 0:
for i in range(d):
indices = starts[:, slice(i, i + 1)] + be.arange(mshape[i])[None]
indices = be.mod(indices, shape[i], out=indices)
indices_shape = (n, *tuple(1 if k != i else -1 for k in range(d)))
ret.append(be.reshape(indices, indices_shape))
self._index_grid = tuple(ret)
self._mask_shape = tuple(1 if i != 0 else -1 for i in range(1 + ndim))
if rotate_mask:
self._score_mask = be.zeros(
(rotations.shape[0], *self._score_mask.shape), dtype=be._float_dtype
)
for i in range(rotations.shape[0]):
mask = create_mask(
mask_type="ellipse",
radius=acceptance_radius,
shape=tuple(2 * extend + 1 for _ in range(ndim)),
center=tuple(extend for _ in range(ndim)),
orientation=be.to_numpy_array(rotations[i]),
)
self._score_mask[i] = be.astype(
be.to_backend_array(mask), be._float_dtype
)
[docs]
def __call__(
self, state: Tuple, scores: BackendArray, rotation_matrix: BackendArray
) -> Tuple:
mask = self._get_constraint(rotation_matrix)
mask = self._get_score_mask(mask=mask, scores=scores)
scores = be.multiply(scores, mask, out=scores)
return super().__call__(state, scores=scores, rotation_matrix=rotation_matrix)
def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
"""
Determine whether the angle between projection of reference w.r.t to
a given rotation matrix and a set of rotations fall within the set
cone_angle cutoff.
Parameters
----------
rotation_matrix : BackendArray
Rotation matrix with shape (d,d).
Returns
-------
BackerndArray
Boolean mask of shape (n, )
"""
template_rot = rotation_matrix @ self._reference
x = be.sum(be.multiply(self._normals_x, template_rot), axis=1)
y = be.sum(be.multiply(self._normals_y, template_rot), axis=1)
z = be.sum(be.multiply(self._normals_z, template_rot), axis=1)
return be.sqrt(x**2 + y**2) <= (z * self._cone_cutoff)
def _get_score_mask(self, mask: BackendArray, scores: BackendArray, **kwargs):
score_mask = be.zeros(scores.shape, scores.dtype)
if be.sum(mask) == 0:
return score_mask
mask = be.reshape(mask, self._mask_shape)
score_mask = be.addat(score_mask, self._index_grid, self._score_mask * mask)
return score_mask > 0
[docs]
class MaxScoreOverTranslations(MaxScoreOverRotations):
"""
Determine the translation maximizing the score over all possible rotations.
Parameters
----------
shape : tuple of int
Shape of array passed to :py:meth:`MaxScoreOverTranslations.__call__`.
n_rotations : int
Number of rotations to aggregate over.
aggregate_axis : tuple of int, optional
Array axis to aggregate over, None by default.
shm_handler : :class:`multiprocessing.managers.SharedMemoryManager`, optional
Shared memory manager, defaults to memory not being shared.
**kwargs: dict, optional
Keyword arguments passed to the constructor of the parent class.
"""
def __init__(
self,
shape: Tuple[int],
n_rotations: int,
aggregate_axis: Tuple[int] = None,
shm_handler: object = None,
offset: Tuple[int] = None,
**kwargs: Dict,
):
shape_reduced = [x for i, x in enumerate(shape) if i not in aggregate_axis]
shape_reduced.insert(0, n_rotations)
if offset is None:
offset = be.zeros(len(shape), be._int_dtype)
offset = [x for i, x in enumerate(offset) if i not in aggregate_axis]
offset.insert(0, 0)
super().__init__(
shape=shape_reduced, shm_handler=shm_handler, offset=offset, **kwargs
)
self._aggregate_axis = aggregate_axis
[docs]
def init_state(self):
scores = be.full(
shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
)
rotations = be.full(1, dtype=be._int_dtype, fill_value=-1)
return scores, rotations, {}
[docs]
def __call__(
self, state, scores: BackendArray, rotation_matrix: BackendArray
) -> Tuple:
prev_scores, rotations, rotation_mapping = state
rotation_index = len(rotation_mapping)
if self._inversion_mapping:
rotation_mapping[rotation_index] = rotation_matrix
else:
rotation = be.tobytes(rotation_matrix)
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
max_score = be.max(scores, axis=self._aggregate_axis)
update = prev_scores[rotation_index]
update = be.maximum(max_score, update, out=update)
return prev_scores, rotations, rotation_mapping
[docs]
@classmethod
def merge(cls, states: List[Tuple], **kwargs) -> Tuple:
"""
Merge multiple instances of the current class.
Parameters
----------
states : list of tuple
List of instance's internal state created by applying `tuple(instance)`.
**kwargs : dict, optional
Optional keyword arguments.
Returns
-------
NDArray
Maximum score of each rotation over all observed translations.
NDArray
Translation offset, zero by default.
NDArray
Mapping between translations and rotation indices.
Dict
Mapping between rotations and rotation indices.
"""
if len(states) == 1:
return states[0]
# Determine output array shape and create consistent rotation map
states, master_rotation_mapping, out_shape = cls._harmonize_states(states)
if out_shape is None:
return None
out_shape[0] = len(master_rotation_mapping)
out_shape = tuple(int(x) for x in out_shape)
scores_dtype = states[0][0].dtype
use_memmap = kwargs.get("use_memmap", False)
if use_memmap:
scores_out_filename = generate_tempfile_name()
scores_out = np.memmap(
scores_out_filename, mode="w+", shape=out_shape, dtype=scores_dtype
)
scores_out.fill(kwargs.get("score_threshold", 0))
scores_out.flush()
else:
scores_out = np.full(
out_shape,
fill_value=kwargs.get("score_threshold", 0),
dtype=scores_dtype,
)
for i in range(len(states)):
if states[i] is None:
continue
if use_memmap:
scores_out = np.memmap(
scores_out_filename,
mode="r+",
shape=out_shape,
dtype=scores_dtype,
)
scores, offset, rotations, rotation_mapping = states[i]
outer_table = np.arange(len(rotation_mapping), dtype=int)
lookup_table = np.array(
[master_rotation_mapping[key] for key in rotation_mapping.keys()],
dtype=int,
)
stops = np.add(offset, scores.shape).astype(int)
indices = [slice(*pos) for pos in zip(offset[1:], stops[1:])]
indices.insert(0, lookup_table)
indices = tuple(indices)
scores_out[indices] = np.maximum(scores_out[indices], scores[outer_table])
if use_memmap:
scores._mmap.close()
scores_out.flush()
scores_out = None
states[i], scores = None, None
if use_memmap:
scores_out = np.memmap(
scores_out_filename, mode="r", shape=out_shape, dtype=scores_dtype
)
return (
scores_out,
np.zeros(scores_out.ndim, dtype=int),
states[2],
cls._invert_rmap(master_rotation_mapping),
)
def _postprocess(self, **kwargs):
return self