""" Backend using jax for template matching.
Copyright (c) 2023-2024 European Molecular Biology Laboratory
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
from functools import wraps
from typing import Tuple, List, Callable
from ..types import BackendArray
from .npfftw_backend import NumpyFFTWBackend, shm_type
def emulate_out(func):
    """
    Adds an out argument to write output of ``func`` to.

    jax arrays are immutable, so ``out`` cannot be filled in place. When
    ``out`` is given, the wrapper returns a new array obtained from
    ``out.at[...].set(result)``; ``out`` itself is left untouched and callers
    must use the returned value.

    Parameters
    ----------
    func : Callable
        Function whose result should be redirectable to ``out``.

    Returns
    -------
    Callable
        Wrapper accepting the same arguments as ``func`` plus a keyword-only
        ``out``.
    """

    @wraps(func)
    def inner(*args, out=None, **kwargs):
        ret = func(*args, **kwargs)
        if out is None:
            return ret
        # ``...`` rather than ``:`` so that 0-d ``out`` arrays are supported.
        return out.at[...].set(ret)

    return inner
class JaxBackend(NumpyFFTWBackend):
    """
    A jax-based matching backend.

    Reuses the :py:class:`NumpyFFTWBackend` machinery with ``jax.numpy`` as
    the array module. jax arrays are immutable, so operations that other
    backends perform in place (``out=`` arguments, slice assignment) return
    new arrays here; callers must use the returned value.
    """

    def __init__(self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs):
        """
        Parameters
        ----------
        float_dtype : dtype, optional
            Floating point dtype, defaults to ``jnp.float32``. Also used as
            the overflow-safe dtype.
        complex_dtype : dtype, optional
            Complex dtype, defaults to ``jnp.complex64``.
        int_dtype : dtype, optional
            Integer dtype, defaults to ``jnp.int32``.
        **kwargs
            Accepted but currently ignored.
        """
        import jax.scipy as jsp
        import jax.numpy as jnp

        float_dtype = jnp.float32 if float_dtype is None else float_dtype
        complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
        int_dtype = jnp.int32 if int_dtype is None else int_dtype

        super().__init__(
            array_backend=jnp,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=float_dtype,
        )
        self.scipy = jsp
        self._create_ufuncs()

        # Best effort: expose the jax-native scan only if the optional helper
        # module imports cleanly. Otherwise no instance attribute is set and
        # ``scan`` resolves through the class hierarchy (presumably to an
        # inherited implementation, if any).
        try:
            from ._jax_utils import scan as _  # noqa: F401
        except Exception:
            pass
        else:
            self.scan = self._scan

    def from_sharedarr(self, arr: BackendArray) -> BackendArray:
        """
        Identity: this backend does not use shared memory, so ``arr`` is
        returned as-is.
        """
        return arr

    @staticmethod
    def to_sharedarr(arr: BackendArray, shared_memory_handler: type = None) -> shm_type:
        """
        Identity: ``arr`` is returned as-is and ``shared_memory_handler`` is
        ignored.
        """
        return arr

    def topleft_pad(
        self, arr: BackendArray, shape: Tuple[int], padval: int = 0
    ) -> BackendArray:
        """
        Place ``arr`` in the top-left corner of a new array of size ``shape``.

        Along each axis, ``arr`` is cropped if it is larger than ``shape`` and
        padded with ``padval`` if it is smaller.

        Parameters
        ----------
        arr : BackendArray
            Input array.
        shape : tuple of int
            Shape of the output array, one entry per axis of ``arr``.
        padval : int, optional
            Fill value for the padded region, defaults to 0.

        Returns
        -------
        BackendArray
            New array of size ``shape`` with the same dtype as ``arr``.
        """
        padded = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
        # The overlap of both arrays starts at the origin and extends to the
        # smaller extent along every axis.
        overlap = tuple(
            slice(0, min(arr_dim, out_dim))
            for arr_dim, out_dim in zip(arr.shape, shape)
        )
        return padded.at[overlap].set(arr[overlap])

    def _create_ufuncs(self):
        """
        Attach ``jax.numpy`` functions as instance attributes.

        Elementwise operations are wrapped with :py:func:`emulate_out` so they
        accept the ``out`` keyword used by the other backends. Array creation
        functions are attached unchanged.
        """
        elementwise = (
            "add",
            "subtract",
            "multiply",
            "divide",
            "square",
            "sqrt",
            "maximum",
        )
        # Functions stored on an instance are never bound to ``self``, so they
        # are attached directly. A ``staticmethod`` object is only callable
        # itself on Python >= 3.10.
        for name in elementwise:
            setattr(self, name, emulate_out(getattr(self._array_backend, name)))
        for name in ("zeros", "full"):
            setattr(self, name, getattr(self._array_backend, name))

    def fill(self, arr: BackendArray, value: float) -> BackendArray:
        """
        Return a new array shaped like ``arr`` with every element set to
        ``value``. ``arr`` itself is not modified.
        """
        return self._array_backend.full(
            shape=arr.shape, dtype=arr.dtype, fill_value=value
        )

    def build_fft(
        self,
        fast_shape: Tuple[int],
        fast_ft_shape: Tuple[int],
        inverse_fast_shape: Tuple[int] = None,
        **kwargs,
    ) -> Tuple[Callable, Callable]:
        """
        Build forward and inverse real FFT callables.

        Parameters
        ----------
        fast_shape : tuple of int
            Real-space shape passed as ``s`` to the forward transform.
        fast_ft_shape : tuple of int
            Unused here; presumably kept for interface compatibility.
        inverse_fast_shape : tuple of int, optional
            Real-space output shape passed as ``s`` to the inverse transform.
            Defaults to ``fast_shape``.
        **kwargs
            Unused.

        Returns
        -------
        tuple of Callable
            ``(rfftn, irfftn)``. Both accept an ``out`` argument for
            compatibility, ignore it, and return a new array.
        """
        if inverse_fast_shape is None:
            inverse_fast_shape = fast_shape

        def rfftn(arr, out=None, shape=fast_shape):
            return self._array_backend.fft.rfftn(arr, s=shape)

        def irfftn(arr, out=None, shape=inverse_fast_shape):
            return self._array_backend.fft.irfftn(arr, s=shape)

        return rfftn, irfftn

    def max_score_over_rotations(
        self,
        scores: BackendArray,
        max_scores: BackendArray,
        rotations: BackendArray,
        rotation_index: int,
    ) -> Tuple[BackendArray, BackendArray]:
        """
        Merge ``scores`` into the running maximum.

        Wherever ``scores`` is greater than or equal to ``max_scores``, its
        value is taken and ``rotation_index`` is recorded in ``rotations``.

        Returns
        -------
        tuple of BackendArray
            Updated ``(max_scores, rotations)``. These are new arrays; the
            inputs are not modified.
        """
        keep_previous = self.greater(max_scores, scores)
        max_scores = max_scores.at[...].set(
            self.where(keep_previous, max_scores, scores)
        )
        rotations = rotations.at[...].set(
            self.where(keep_previous, rotations, rotation_index)
        )
        return max_scores, rotations

    def _scan(
        self,
        matching_data: type,
        splits: Tuple[Tuple[slice, slice]],
        n_jobs: int,
        callback_class,
        rotate_mask: bool = False,
        **kwargs,
    ) -> List:
        """
        Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
        :py:class:`tme.analyzer.MaxScoreOverRotations`.

        Splits are processed in batches of ``n_jobs``. The target subsets of a
        batch are padded to a common shape, stacked, and passed to the jax
        ``scan`` helper in a single call.

        Parameters
        ----------
        matching_data
            Matching data instance (despite the ``type`` annotation) providing
            target, template, template mask, rotations and optional filters.
        splits
            Sequence of ``(target_slice, template_slice)`` pairs.
        n_jobs
            Number of splits processed per batch.
        callback_class
            Analyzer class instantiated once per split to post-process the
            scores.
        rotate_mask : bool, optional
            Forwarded to the jax ``scan`` helper.
        **kwargs
            Unused.

        Returns
        -------
        List
            One tuple of post-processed analyzer output per split.
        """
        from ._jax_utils import scan as scan_inner

        # With more than one split the target is padded and a "valid"
        # convolution is used; a single split uses "same".
        pad_target = len(splits) > 1
        convolution_mode = "valid" if pad_target else "same"
        target_pad = matching_data.target_padding(pad_target=pad_target)

        # Shapes are derived from the first split only; all splits are
        # assumed to share its extent (TODO confirm against the splitter).
        target_shape = tuple(
            (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
        )
        conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
            target_shape=self.to_numpy_array(target_shape),
            template_shape=self.to_numpy_array(matching_data._template.shape),
            pad_fourier=False,
        )

        analyzer_args = {
            "convolution_mode": convolution_mode,
            "fourier_shift": shift,
            "targetshape": target_shape,
            "templateshape": matching_data.template.shape,
            "convolution_shape": conv_shape,
        }

        create_target_filter = matching_data.target_filter is not None
        create_template_filter = matching_data.template_filter is not None
        create_filter = create_target_filter or create_template_filter

        # Template filters are built on the template's own (unpadded) shape.
        fastt_shape = matching_data._template.shape

        # A filter value of 1 means "no filtering" for scan_inner.
        ret, template_filter, target_filter = [], 1, 1
        rotation_mapping = {
            self.tobytes(matching_data.rotations[i]): i
            for i in range(matching_data.rotations.shape[0])
        }
        for split_start in range(0, len(splits), n_jobs):
            split_subset = splits[split_start : (split_start + n_jobs)]
            if not len(split_subset):
                continue

            targets, translation_offsets = [], []
            for target_split, template_split in split_subset:
                base = matching_data.subset_by_slice(
                    target_slice=target_split,
                    target_pad=target_pad,
                    template_slice=template_split,
                )
                translation_offsets.append(base._translation_offset)
                targets.append(self.topleft_pad(base._target, fast_shape))

            # Filters are computed once, from the first target of the first
            # batch, and reused for every later batch.
            if create_filter:
                filter_args = {
                    "data_rfft": self.fft.rfftn(targets[0]),
                    "return_real_fourier": True,
                    "shape_is_real_fourier": False,
                }
                if create_template_filter:
                    template_filter = matching_data.template_filter(
                        shape=fastt_shape, **filter_args
                    )["data"]
                    # Zero the origin coefficient (DC term of an unshifted
                    # spectrum).
                    template_filter = template_filter.at[
                        (0,) * template_filter.ndim
                    ].set(0)
                if create_target_filter:
                    target_filter = matching_data.target_filter(
                        shape=fast_shape, **filter_args
                    )["data"]
                    target_filter = target_filter.at[
                        (0,) * target_filter.ndim
                    ].set(0)
                create_filter, create_template_filter, create_target_filter = (
                    False,
                ) * 3

            # Drop the reference to the last subset before the heavy call.
            base, targets = None, self._array_backend.stack(targets)
            scores, rotations = scan_inner(
                targets,
                matching_data.template,
                matching_data.template_mask,
                matching_data.rotations,
                template_filter,
                target_filter,
                fast_shape,
                rotate_mask,
            )

            for index in range(scores.shape[0]):
                analyzer = callback_class(
                    scores=scores[index],
                    rotations=rotations[index],
                    thread_safe=False,
                    offset=translation_offsets[index],
                )
                analyzer.rotation_mapping = rotation_mapping
                ret.append(tuple(analyzer._postprocess(**analyzer_args)))
        return ret

    def get_available_memory(self) -> int:
        """
        Return the memory available for computations in bytes.

        The ``bytes_limit`` of every non-CPU jax device is summed. If that
        total is positive, it is returned. Otherwise, the host memory
        reported by the parent backend is returned when a CPU device is
        present, and 0 when none is.
        """
        import jax

        devices = jax.devices()
        accelerator_memory = 0
        for device in devices:
            if device.platform == "cpu":
                continue
            # memory_stats() may return None on backends that do not report
            # statistics; treat that as "no information" instead of crashing.
            mem_stats = device.memory_stats() or {}
            accelerator_memory += mem_stats.get("bytes_limit", 0)

        if accelerator_memory > 0:
            return accelerator_memory
        if any(device.platform == "cpu" for device in devices):
            return super().get_available_memory()
        return 0