""" Backend using Apple's MLX library for template matching.
Copyright (c) 2024 European Molecular Biology Laboratory
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
from typing import Tuple, List, Callable
import numpy as np
from .npfftw_backend import NumpyFFTWBackend
from ..types import NDArray, MlxArray, Scalar, shm_type
class MLXBackend(NumpyFFTWBackend):
    """
    A mlx-based matching backend.

    Most operations dispatch to ``mlx.core``. A few (``unique``, ``roll``,
    ``topk_indices``, ``indices``) are computed via NumPy /
    :py:class:`NumpyFFTWBackend` and their results converted back to mlx
    arrays.
    """

    def __init__(
        self,
        device="cpu",
        float_dtype=None,
        complex_dtype=None,
        int_dtype=None,
        overflow_safe_dtype=None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        device : str, optional
            ``"cpu"`` selects ``mlx.core.cpu``; any other value selects
            ``mlx.core.gpu``. The resolved device is stored as ``self.device``.
        float_dtype : optional
            Floating point dtype, defaults to ``mlx.core.float32``.
        complex_dtype : optional
            Complex dtype, defaults to ``mlx.core.complex64``.
        int_dtype : optional
            Integer dtype, defaults to ``mlx.core.int32``.
        overflow_safe_dtype : optional
            Dtype for overflow-sensitive computations, defaults to
            ``mlx.core.float32``.
        **kwargs : dict, optional
            Accepted but not used.
        """
        # Local import: mlx is only imported once this backend is instantiated.
        import mlx.core as mx

        device = mx.cpu if device == "cpu" else mx.gpu
        float_dtype = mx.float32 if float_dtype is None else float_dtype
        complex_dtype = mx.complex64 if complex_dtype is None else complex_dtype
        int_dtype = mx.int32 if int_dtype is None else int_dtype
        if overflow_safe_dtype is None:
            overflow_safe_dtype = mx.float32
        super().__init__(
            array_backend=mx,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=overflow_safe_dtype,
        )
        self.device = device

    @staticmethod
    def _write_or_return(result, out):
        """Copy ``result`` into ``out`` and return ``out``; if ``out`` is None return ``result``."""
        if out is None:
            return result
        # Slice assignment writes into the caller-provided array.
        out[:] = result
        return out

    def to_backend_array(self, arr: NDArray) -> MlxArray:
        """Convert ``arr`` to an mlx array via ``mlx.core.array``."""
        return self._array_backend.array(arr)

    def to_numpy_array(self, arr: MlxArray) -> NDArray:
        """Convert ``arr`` to a NumPy array via ``np.array``."""
        return np.array(arr)

    def to_cpu_array(self, arr: MlxArray) -> NDArray:
        """Return ``arr`` unchanged; no copy or conversion is performed."""
        return arr

    def free_cache(self):
        """No-op for this backend."""
        pass

    def mod(self, arr1: MlxArray, arr2: MlxArray, out: MlxArray = None) -> MlxArray:
        """
        Element-wise ``arr1 % arr2``.

        If ``out`` is given, the result is written into it and ``out`` is
        returned (NumPy ``out=`` semantics); otherwise a new array is returned.
        """
        return self._write_or_return(arr1 % arr2, out)

    def add(self, x1, x2, out: MlxArray = None, **kwargs) -> MlxArray:
        """
        Element-wise ``x1 + x2`` after converting both operands to mlx arrays.

        If ``out`` is given, the result is written into it and ``out`` is
        returned; otherwise a new array is returned. ``kwargs`` are forwarded
        to ``mlx.core.add``.
        """
        x1 = self.to_backend_array(x1)
        x2 = self.to_backend_array(x2)
        return self._write_or_return(self._array_backend.add(x1, x2, **kwargs), out)

    def multiply(self, x1, x2, out: MlxArray = None, **kwargs) -> MlxArray:
        """
        Element-wise ``x1 * x2`` after converting both operands to mlx arrays.

        If ``out`` is given, the result is written into it and ``out`` is
        returned; otherwise a new array is returned. ``kwargs`` are forwarded
        to ``mlx.core.multiply``.
        """
        x1 = self.to_backend_array(x1)
        x2 = self.to_backend_array(x2)
        return self._write_or_return(
            self._array_backend.multiply(x1, x2, **kwargs), out
        )

    def std(self, arr: MlxArray, axis=None) -> Scalar:
        """
        Standard deviation of ``arr`` along ``axis``, computed as ``sqrt(var)``.

        ``axis=None`` reduces over all axes.
        """
        return self._array_backend.sqrt(arr.var(axis=axis))

    def unique(self, *args, **kwargs):
        """
        Wrapper around ``np.unique`` that returns mlx arrays.

        If ``np.unique`` returns a tuple (e.g. with ``return_counts=True``),
        a list with each element converted is returned; otherwise the single
        unique-values array is converted.
        """
        ret = np.unique(*args, **kwargs)
        if isinstance(ret, tuple):
            return [self.to_backend_array(x) for x in ret]
        return self.to_backend_array(ret)

    def tobytes(self, arr):
        """Raw bytes of ``arr``, obtained through its NumPy representation."""
        return self.to_numpy_array(arr).tobytes()

    def full(self, shape, fill_value, dtype=None):
        """Array of ``shape`` filled with ``fill_value`` (``mlx.core.full``)."""
        return self._array_backend.full(shape=shape, dtype=dtype, vals=fill_value)

    def fill(self, arr: MlxArray, value: Scalar) -> MlxArray:
        """Set every element of ``arr`` to ``value`` in place and return ``arr``."""
        arr[:] = value
        return arr

    def zeros(self, shape: Tuple[int], dtype: type = None) -> MlxArray:
        """Zero-filled array of ``shape``; ``dtype`` is passed to ``mlx.core.zeros``."""
        return self._array_backend.zeros(shape=shape, dtype=dtype)

    def roll(self, a: MlxArray, shift, axis, **kwargs):
        """Roll ``a`` using :py:meth:`NumpyFFTWBackend.roll` on a NumPy copy."""
        a = self.to_numpy_array(a)
        ret = NumpyFFTWBackend().roll(
            a,
            shift=shift,
            axis=axis,
            **kwargs,
        )
        return self.to_backend_array(ret)

    def build_fft(
        self, fast_shape: Tuple[int], fast_ft_shape: Tuple[int], **kwargs
    ) -> Tuple[Callable, Callable]:
        """
        Build fft builder functions.

        Parameters
        ----------
        fast_shape : tuple
            Tuple of integers corresponding to fast convolution shape
            (see `compute_convolution_shapes`).
        fast_ft_shape : tuple
            Tuple of integers corresponding to the shape of the fourier
            transform array (see `compute_convolution_shapes`). Not used by
            this backend.
        **kwargs : dict, optional
            Additional parameters that are not used for now.

        Returns
        -------
        tuple
            Tuple containing callable rfft and irfft objects. Each callable
            writes its result into ``out`` in place and returns None.
        """

        # Runs on mlx.core.cpu until Metal support is available
        def rfftn(arr: MlxArray, out: MlxArray, shape: Tuple[int] = fast_shape) -> None:
            out[:] = self._array_backend.fft.rfftn(
                arr, s=shape, stream=self._array_backend.cpu
            )

        def irfftn(
            arr: MlxArray, out: MlxArray, shape: Tuple[int] = fast_shape
        ) -> None:
            out[:] = self._array_backend.fft.irfftn(
                arr, s=shape, stream=self._array_backend.cpu
            )

        return rfftn, irfftn

    def from_sharedarr(self, arr: MlxArray) -> MlxArray:
        """Return ``arr`` unchanged."""
        return arr

    @staticmethod
    def to_sharedarr(arr: MlxArray, shared_memory_handler: type = None) -> shm_type:
        """Return ``arr`` unchanged; ``shared_memory_handler`` is ignored."""
        return arr

    def topk_indices(self, arr: NDArray, k: int):
        """
        Indices of the top ``k`` elements, computed by
        :py:meth:`NumpyFFTWBackend.topk_indices` on a NumPy copy of ``arr``.
        Each returned index array is converted to an mlx array.
        """
        arr = self.to_numpy_array(arr)
        ret = NumpyFFTWBackend().topk_indices(arr=arr, k=k)
        ret = [self.to_backend_array(x) for x in ret]
        return ret

    def indices(self, arr: List) -> MlxArray:
        """Result of :py:meth:`NumpyFFTWBackend.indices` converted to an mlx array."""
        ret = NumpyFFTWBackend().indices(arr)
        return self.to_backend_array(ret)