# Source code for tme.backends.jax_backend

"""
Backend using jax for template matching.

Copyright (c) 2023-2024 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from functools import wraps
from typing import Tuple, List, Dict, Any

import numpy as np

from ..types import JaxArray
from .npfftw_backend import NumpyFFTWBackend, shm_type


def emulate_out(func):
    """
    Adds an out argument to write output of ``func`` to.

    The wrapper does not assign into ``out`` directly. It returns the result
    of the functional update ``out.at[...].set(ret)``, so callers have to use
    the returned array instead of relying on ``out`` having been modified.

    Parameters
    ----------
    func : callable
        Function whose output should optionally be written to ``out``.

    Returns
    -------
    callable
        Wrapper around ``func`` that accepts an additional keyword-only
        ``out`` argument. If ``out`` is None the output of ``func`` is
        returned unchanged.
    """

    @wraps(func)
    def inner(*args, out=None, **kwargs):
        ret = func(*args, **kwargs)
        if out is None:
            return ret
        # Ellipsis rather than ``:`` so zero-dimensional ``out`` arrays, for
        # which a slice index is invalid, are handled like any other rank.
        return out.at[...].set(ret)

    return inner


class JaxBackend(NumpyFFTWBackend):
    """
    A jax-based matching backend.

    Array operations are delegated to ``jax.numpy``. Methods that would write
    into an array use the functional ``.at[...]`` update and return the
    resulting array, so callers have to use the return values.
    """

    def __init__(
        self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs
    ):
        """
        Parameters
        ----------
        float_dtype : dtype, optional
            Float data type, defaults to ``jax.numpy.float32``.
        complex_dtype : dtype, optional
            Complex data type, defaults to ``jax.numpy.complex64``.
        int_dtype : dtype, optional
            Integer data type, defaults to ``jax.numpy.int32``.
        **kwargs : dict, optional
            Accepted but not forwarded to the parent class.
        """
        # Imported lazily so jax is only required when this backend is used
        import jax.scipy as jsp
        import jax.numpy as jnp

        float_dtype = jnp.float32 if float_dtype is None else float_dtype
        complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
        int_dtype = jnp.int32 if int_dtype is None else int_dtype

        super().__init__(
            array_backend=jnp,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=float_dtype,
        )
        self.scipy = jsp
        self._create_ufuncs()

    def from_sharedarr(self, arr: JaxArray) -> JaxArray:
        """Return ``arr`` unchanged."""
        return arr

    @staticmethod
    def to_sharedarr(arr: JaxArray, shared_memory_handler: type = None) -> shm_type:
        """Return ``arr`` unchanged, ``shared_memory_handler`` is ignored."""
        return arr

    @staticmethod
    def at(arr, idx, value) -> JaxArray:
        """Return the result of setting ``arr[idx]`` to ``value``."""
        return arr.at[idx].set(value)

    def addat(self, arr, indices, values):
        """Return the result of adding ``values`` to ``arr`` at ``indices``."""
        return arr.at[indices].add(values)

    def topleft_pad(
        self, arr: JaxArray, shape: Tuple[int], padval: int = 0
    ) -> JaxArray:
        """
        Place ``arr`` in the top-left corner of an array of given shape.

        Axes of ``arr`` that exceed ``shape`` are cropped, axes that are
        smaller are padded with ``padval``.

        Parameters
        ----------
        arr : JaxArray
            Input array.
        shape : tuple of int
            Shape of the returned array.
        padval : int, optional
            Fill value for elements not covered by ``arr``, 0 by default.

        Returns
        -------
        JaxArray
            Array of shape ``shape`` with dtype of ``arr``.
        """
        b = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
        # Along each axis, the region shared by both arrays starts at zero and
        # extends to the smaller of the two axis lengths.
        overlap = tuple(
            slice(0, min(arr.shape[i], shape[i])) for i in range(arr.ndim)
        )
        return b.at[overlap].set(arr[overlap])

    def _create_ufuncs(self):
        """
        Attach array backend functions as attributes of this instance.

        Element-wise functions are wrapped with :py:func:`emulate_out` to
        accept an ``out`` argument, array creation functions are attached
        unmodified.
        """
        ufuncs = [
            "add",
            "subtract",
            "multiply",
            "divide",
            "square",
            "sqrt",
            "maximum",
            "exp",
            "mod",
            "dot",
        ]
        for ufunc in ufuncs:
            backend_method = emulate_out(getattr(self._array_backend, ufunc))
            setattr(self, ufunc, staticmethod(backend_method))

        ufuncs = ["zeros", "full"]
        for ufunc in ufuncs:
            backend_method = getattr(self._array_backend, ufunc)
            setattr(self, ufunc, staticmethod(backend_method))

    def fill(self, arr: JaxArray, value: float) -> JaxArray:
        """Return a new array with shape and dtype of ``arr`` set to ``value``."""
        return self._array_backend.full(
            shape=arr.shape, dtype=arr.dtype, fill_value=value
        )

    def rfftn(self, arr: JaxArray, *args, **kwargs) -> JaxArray:
        """
        Compute the real-input Fourier transform of ``arr``.

        Additional positional arguments are accepted but ignored, keyword
        arguments are forwarded to the array backend.
        """
        return self._array_backend.fft.rfftn(arr, **kwargs)

    def irfftn(self, arr: JaxArray, *args, **kwargs) -> JaxArray:
        """
        Compute the inverse of :py:meth:`JaxBackend.rfftn` for ``arr``.

        Additional positional arguments are accepted but ignored, keyword
        arguments are forwarded to the array backend.
        """
        return self._array_backend.fft.irfftn(arr, **kwargs)

    def _interpolate(self, arr, indices, order: int = 1):
        """Sample ``arr`` at ``indices`` and return data in shape of ``arr``."""
        ret = self.scipy.ndimage.map_coordinates(arr, indices, order=order)
        return ret.reshape(arr.shape)

    def _index_grid(self, shape: Tuple[int]) -> JaxArray:
        """
        Create homogeneous coordinate grid.

        Parameters
        ----------
        shape : tuple of int
            Shape to create the grid for

        Returns
        -------
        JaxArray
            Coordinate grid of shape (ndim + 1, n_points), the last row
            contains ones.
        """
        indices = self._array_backend.indices(shape, dtype=self._float_dtype)
        indices = indices.reshape((len(shape), -1))
        ones = self._array_backend.ones((1, indices.shape[1]), dtype=indices.dtype)
        return self._array_backend.concatenate([indices, ones], axis=0)

    def _transform_indices(self, indices: JaxArray, matrix: JaxArray) -> JaxArray:
        """Apply all but the last row of ``matrix`` to ``indices``."""
        return self._array_backend.matmul(matrix[:-1], indices)

    def _rigid_transform(
        self,
        arr: JaxArray,
        matrix: JaxArray,
        out: JaxArray = None,
        out_mask: JaxArray = None,
        arr_mask: JaxArray = None,
        order: int = 1,
        **kwargs,
    ) -> Tuple[JaxArray, JaxArray]:
        """
        Apply the transform ``matrix`` to ``arr`` and optionally ``arr_mask``.

        Parameters
        ----------
        arr : JaxArray
            Array to transform.
        matrix : JaxArray
            Transformation matrix applied to the homogeneous coordinate grid
            of ``arr``.
        out : JaxArray, optional
            Not used by this implementation.
        out_mask : JaxArray, optional
            Not used by this implementation.
        arr_mask : JaxArray, optional
            Mask to transform alongside ``arr``.
        order : int, optional
            Interpolation order, 1 by default.

        Returns
        -------
        tuple of JaxArray
            Transformed ``arr`` and transformed ``arr_mask``. The latter is
            None if no ``arr_mask`` was given.
        """
        indices = self._index_grid(arr.shape)
        indices = self._transform_indices(indices, matrix)

        arr = self._interpolate(arr, indices, order)
        if arr_mask is not None:
            # The input mask is what needs resampling. out_mask is an output
            # buffer in other backends and None by default.
            arr_mask = self._interpolate(arr_mask, indices, order)
        return arr, arr_mask

    def max_score_over_rotations(
        self,
        scores: JaxArray,
        max_scores: JaxArray,
        rotations: JaxArray,
        rotation_index: int,
    ) -> Tuple[JaxArray, JaxArray]:
        """
        Update maximum scores and corresponding rotation indices.

        Parameters
        ----------
        scores : JaxArray
            Scores of the current rotation.
        max_scores : JaxArray
            Maximum scores so far.
        rotations : JaxArray
            Rotation indices corresponding to ``max_scores``.
        rotation_index : int
            Index of the current rotation.

        Returns
        -------
        tuple of JaxArray
            Updated ``max_scores`` and ``rotations``.
        """
        # Previous values are kept only where they strictly exceed the new
        # scores, ties are assigned to the current rotation.
        update = self.greater(max_scores, scores)
        max_scores = max_scores.at[:].set(self.where(update, max_scores, scores))
        rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
        return max_scores, rotations

    def compute_convolution_shapes(
        self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        Compute shapes for Fourier-based convolution of two arrays.

        Parameters
        ----------
        arr1_shape : tuple of int
            Shape of the first array.
        arr2_shape : tuple of int
            Shape of the second array.

        Returns
        -------
        tuple of list of int
            Shape of the full convolution, shape padded to lengths that are
            fast to transform, and shape of the real-input Fourier transform
            of the padded shape.
        """
        from scipy.fft import next_fast_len

        convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
        fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
        fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]

        return convolution_shape, fast_shape, fast_ft_shape

    def _to_hashable(self, obj: Any) -> Tuple[str, Tuple]:
        """Convert numpy arrays to a tagged tuple, tag other objects as-is."""
        if isinstance(obj, np.ndarray):
            return ("array", (tuple(obj.flatten().tolist()), obj.shape))
        return ("other", obj)

    def _from_hashable(self, type_info: str, data: Any) -> Any:
        """Inverse of :py:meth:`JaxBackend._to_hashable`."""
        if type_info == "array":
            data, shape = data
            return self.array(data).reshape(shape)
        return data

    def _dict_to_tuple(self, data: Dict) -> Tuple:
        """Convert ``data`` to a tuple of (key, tagged value) pairs."""
        return tuple((k, self._to_hashable(v)) for k, v in data.items())

    def _tuple_to_dict(self, data: Tuple) -> Dict:
        """Inverse of :py:meth:`JaxBackend._dict_to_tuple`."""
        return {x[0]: self._from_hashable(*x[1]) for x in data}

    def _unbatch(self, data, target_ndim, index):
        """
        Select element ``index`` along the first axis of batched arrays.

        Objects that are not backend arrays and arrays with at most
        ``target_ndim`` dimensions are returned unchanged.
        """
        if not isinstance(data, type(self.zeros(1))):
            return data
        elif data.ndim <= target_ndim:
            return data
        return data[index]

    def scan(
        self,
        matching_data: type,
        splits: Tuple[Tuple[slice, slice]],
        n_jobs: int,
        callback_class: object,
        callback_class_args: Dict,
        rotate_mask: bool = False,
        background_correction: str = None,
        match_projection: bool = False,
        **kwargs,
    ) -> List:
        """
        Emulates output of :py:meth:`tme.matching_exhaustive._match_exhaustive`.

        Parameters
        ----------
        matching_data : object
            Container providing target, template, masks, rotations and
            filters.
        splits : tuple of tuple of slice
            Pairs of target and template slices to process. Splits are
            processed in batches of ``n_jobs``.
        n_jobs : int
            Number of splits that are stacked and processed together.
        callback_class : object
            Analyzer class instantiated once per split.
        callback_class_args : dict
            Arguments for ``callback_class``. Take precedence over the
            arguments derived from ``matching_data``.
        rotate_mask : bool, optional
            Forwarded to the scan setup, False by default.
        background_correction : str, optional
            If "phase-scrambling", a phase randomized template is passed on
            as background template.
        match_projection : bool, optional
            Forwarded to the scan setup, False by default.
        **kwargs : dict, optional
            Not used by this implementation.

        Returns
        -------
        list
            Output of the ``result`` method of each analyzer, one element per
            split.
        """
        from ._jax_utils import setup_scan
        from ..matching_utils import setup_filter
        from ..analyzer import MaxScoreOverRotations

        pad_target = len(splits) > 1
        target_pad = matching_data.target_padding(pad_target=pad_target)
        template_shape = matching_data._batch_shape(
            matching_data.template.shape, matching_data._target_batch
        )

        score_mask = 1
        # All splits are padded to the Fourier shape derived from the first
        target_shape = tuple(
            (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
        )
        conv_shape, fast_shape, fast_ft_shape, shift = matching_data.fourier_padding(
            target_shape=target_shape
        )

        analyzer_args = {
            "shape": fast_shape,
            "fourier_shift": shift,
            "fast_shape": fast_shape,
            "templateshape": template_shape,
            "convolution_shape": conv_shape,
            "convolution_mode": "valid" if pad_target else "same",
            "thread_safe": False,
            "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
            "n_rotations": matching_data.rotations.shape[0],
            "jax_mode": True,
        }
        analyzer_args.update(callback_class_args)

        create_target_filter = matching_data.target_filter is not None
        create_template_filter = matching_data.template_filter is not None
        create_filter = create_target_filter or create_template_filter

        bg_tmpl = 1
        if background_correction == "phase-scrambling":
            bg_tmpl = matching_data.transform_template(
                "phase_randomization", reverse=True
            )
            bg_tmpl = self.astype(bg_tmpl, self._float_dtype)

        rotations = self.astype(matching_data.rotations, self._float_dtype)
        ret, template_filter, target_filter = [], 1, 1
        rotation_mapping = {
            self.tobytes(rotations[i]): i for i in range(rotations.shape[0])
        }

        for split_start in range(0, len(splits), n_jobs):
            analyzer_kwargs = []

            split_subset = splits[split_start : (split_start + n_jobs)]
            if len(split_subset) == 0:
                continue

            targets = []
            for target_split, template_split in split_subset:
                base, translation_offset = matching_data.subset_by_slice(
                    target_slice=target_split,
                    target_pad=target_pad,
                    template_slice=template_split,
                )

                cur_args = analyzer_args.copy()
                cur_args["offset"] = translation_offset
                cur_args["targetshape"] = base._output_shape
                analyzer_kwargs.append(cur_args)

                if pad_target:
                    score_mask = base._score_mask(fast_shape, shift)

                # We prepad outside of jit to guarantee the stack operation works
                targets.append(self.topleft_pad(base._target, fast_shape))

                if create_filter:
                    # This is technically inaccurate for whitening filters
                    template_filter, target_filter = setup_filter(
                        matching_data=base,
                        fast_shape=fast_shape,
                        fast_ft_shape=fast_ft_shape,
                        pad_template_filter=False,
                        apply_target_filter=False,
                    )

                    # For projection matching we allow broadcasting the first
                    # dimension. This becomes problematic when applying
                    # per-tilt filters to the target as the number of tilts
                    # does not necessarily coincide with the ideal fourier
                    # shape. Hence we pad the target_filter with zeros here
                    if target_filter.shape != (1,):
                        target_filter = self.topleft_pad(target_filter, fast_ft_shape)

            # Drop the reference to the last subset before running the scan
            base, targets = None, self._array_backend.stack(targets)
            scan_inner = setup_scan(
                analyzer_kwargs=analyzer_kwargs,
                analyzer=callback_class,
                fast_shape=fast_shape,
                rotate_mask=rotate_mask,
                match_projection=match_projection,
            )

            states = scan_inner(
                self.astype(targets, self._float_dtype),
                self.astype(matching_data.template, self._float_dtype),
                self.astype(matching_data.template_mask, self._float_dtype),
                rotations,
                template_filter,
                target_filter,
                score_mask,
                bg_tmpl,
            )

            # The leading axis of targets enumerates the splits of this batch
            ndim = targets.ndim - 1
            for index, split_kwargs in enumerate(analyzer_kwargs):
                analyzer = callback_class(**split_kwargs)
                state = [self._unbatch(x, ndim, index) for x in states]

                if isinstance(analyzer, MaxScoreOverRotations):
                    state[2] = rotation_mapping

                ret.append(analyzer.result(state, **split_kwargs))
        return ret

    def get_available_memory(self) -> int:
        """
        Return the memory available to the backend.

        Returns
        -------
        int
            Sum of ``bytes_limit`` over all non-CPU devices if that sum is
            positive, otherwise the value reported by the parent class if a
            CPU device is present, and 0 if neither applies.
        """
        import jax

        _memory = {"cpu": 0, "gpu": 0}
        for device in jax.devices():
            if device.platform == "cpu":
                _memory["cpu"] = super().get_available_memory()
            else:
                # memory_stats may return None on platforms that do not
                # provide memory statistics
                mem_stats = device.memory_stats() or {}
                _memory["gpu"] += mem_stats.get("bytes_limit", 0)

        if _memory["gpu"] > 0:
            return _memory["gpu"]
        return _memory["cpu"]