""" Backend using pytorch and optionally GPU acceleration for
template matching.
Copyright (c) 2023 European Molecular Biology Laboratory
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
from typing import Tuple, Callable
from contextlib import contextmanager
from multiprocessing import shared_memory
from multiprocessing.managers import SharedMemoryManager
import numpy as np
from .npfftw_backend import NumpyFFTWBackend
from ..types import NDArray, TorchTensor, shm_type
class PytorchBackend(NumpyFFTWBackend):
    """
    A pytorch-based matching backend.

    Array operations are dispatched to ``torch``, which is handed to the
    parent class as ``array_backend`` and accessed here through
    ``self._array_backend``. Tensors created by this backend are placed on
    ``self.device``.

    Parameters
    ----------
    device : str, optional
        Torch device specifier such as ``"cuda"``, ``"cuda:1"`` or ``"cpu"``.
        Defaults to ``"cuda"``.
    float_dtype : torch.dtype, optional
        Floating point dtype, ``torch.float32`` if None.
    complex_dtype : torch.dtype, optional
        Complex dtype, ``torch.complex64`` if None.
    int_dtype : torch.dtype, optional
        Integer dtype, ``torch.int32`` if None.
    overflow_safe_dtype : torch.dtype, optional
        Dtype forwarded to the parent as ``overflow_safe_dtype``,
        ``torch.float32`` if None.
    **kwargs
        Accepted for signature compatibility; not used.
    """

    def __init__(
        self,
        device="cuda",
        float_dtype=None,
        complex_dtype=None,
        int_dtype=None,
        overflow_safe_dtype=None,
        **kwargs,
    ):
        # torch is imported here rather than at module level, presumably so
        # that this module can be imported on systems without torch.
        import torch
        import torch.nn.functional as F

        float_dtype = torch.float32 if float_dtype is None else float_dtype
        complex_dtype = torch.complex64 if complex_dtype is None else complex_dtype
        int_dtype = torch.int32 if int_dtype is None else int_dtype
        if overflow_safe_dtype is None:
            overflow_safe_dtype = torch.float32

        super().__init__(
            array_backend=torch,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=overflow_safe_dtype,
        )
        self.device = device
        self.F = F

    def _uses_cuda(self) -> bool:
        """
        Return True if ``self.device`` denotes a CUDA device.

        Parsing through ``torch.device`` makes ``"cuda"`` and indexed
        specifiers such as ``"cuda:1"`` (or ``torch.device`` objects) behave
        the same, so all device-dependent branches agree with each other.
        """
        return self._array_backend.device(self.device).type == "cuda"

    def to_backend_array(self, arr: NDArray) -> TorchTensor:
        """
        Convert ``arr`` to a tensor on ``self.device``.

        Existing tensors are moved with ``Tensor.to``, which returns the
        tensor itself when it already lives on the target device. Any other
        input is converted via ``self.tensor`` (presumably ``torch.tensor``).
        """
        if isinstance(arr, self._array_backend.Tensor):
            return arr.to(self.device)
        return self.tensor(arr, device=self.device)

    def to_numpy_array(self, arr: TorchTensor) -> NDArray:
        """Return ``arr`` as a numpy array, moving tensors to host memory."""
        if isinstance(arr, np.ndarray):
            return arr
        if isinstance(arr, self._array_backend.Tensor):
            return arr.cpu().numpy()
        return np.array(arr)

    def to_cpu_array(self, arr: TorchTensor) -> NDArray:
        """Return ``arr`` in CPU memory (still a torch tensor, not numpy)."""
        return arr.cpu()

    def get_fundamental_dtype(self, arr):
        """Map the dtype of tensor ``arr`` to ``float``, ``complex`` or ``int``."""
        if self._array_backend.is_floating_point(arr):
            return float
        if self._array_backend.is_complex(arr):
            return complex
        return int

    def free_cache(self):
        """Release unused memory held by torch's CUDA caching allocator."""
        self._array_backend.cuda.empty_cache()

    def mod(self, x1, x2, *args, **kwargs):
        """Element-wise remainder via ``torch.remainder``."""
        return self._array_backend.remainder(x1, x2, *args, **kwargs)

    def max(self, *args, **kwargs) -> NDArray:
        """Maximum over all or selected dimensions via ``torch.amax``."""
        # torch.amax always returns a tensor (never a values/indices pair as
        # torch.max(dim=...) does), so no unpacking is required.
        return self._array_backend.amax(*args, **kwargs)

    def min(self, *args, **kwargs) -> NDArray:
        """Minimum over all or selected dimensions via ``torch.amin``."""
        # torch.amin always returns a tensor, see ``max``.
        return self._array_backend.amin(*args, **kwargs)

    def maximum(self, x1, x2, *args, **kwargs) -> NDArray:
        """Element-wise maximum; both operands are moved to ``self.device``."""
        x1 = self.to_backend_array(x1)
        x2 = self.to_backend_array(x2)
        return self._array_backend.maximum(x1, x2, *args, **kwargs)

    def minimum(self, x1, x2, *args, **kwargs) -> NDArray:
        """Element-wise minimum; both operands are moved to ``self.device``."""
        x1 = self.to_backend_array(x1)
        x2 = self.to_backend_array(x2)
        return self._array_backend.minimum(x1, x2, *args, **kwargs)

    def tobytes(self, arr):
        """Raw bytes of ``arr`` after moving it to host memory."""
        return arr.cpu().numpy().tobytes()

    def size(self, arr):
        """Total number of elements in ``arr``."""
        return arr.numel()

    def zeros(self, shape, dtype=None):
        """Zero-filled tensor of ``shape`` on ``self.device``."""
        return self._array_backend.zeros(shape, dtype=dtype, device=self.device)

    def full(self, shape, fill_value, dtype=None):
        """Tensor of ``shape`` filled with ``fill_value`` on ``self.device``."""
        # Normalise a bare integer shape to a 1-tuple so numpy-style calls
        # such as ``full(5, 0.0)`` work with torch's sequence ``size``.
        if isinstance(shape, (int, np.integer)):
            shape = (int(shape),)
        return self._array_backend.full(
            size=shape, dtype=dtype, fill_value=fill_value, device=self.device
        )

    def arange(self, *args, **kwargs):
        """``torch.arange`` on ``self.device`` unless ``device`` is given."""
        kwargs.setdefault("device", self.device)
        return self._array_backend.arange(*args, **kwargs)

    def datatype_bytes(self, dtype: type) -> int:
        """Size in bytes of one element of ``dtype``."""
        temp = self.zeros(1, dtype=dtype)
        return temp.element_size()

    def fill(self, arr: TorchTensor, value: float) -> TorchTensor:
        """Fill ``arr`` with ``value`` in place and return it."""
        arr.fill_(value)
        return arr

    def astype(self, arr: TorchTensor, dtype: type) -> TorchTensor:
        """Return ``arr`` converted to ``dtype``."""
        return arr.to(dtype)

    def flip(self, a, axis=None, **kwargs):
        """
        Reverse ``a`` along ``axis``.

        As in ``numpy.flip``, ``axis`` may be None (all dimensions), a single
        integer or a sequence of integers.
        """
        if axis is None:
            axis = tuple(range(a.ndim))
        elif isinstance(axis, (int, np.integer)):
            axis = (int(axis),)
        return self._array_backend.flip(input=a, dims=axis, **kwargs)

    def topk_indices(self, arr, k):
        """
        Indices of the ``k`` largest entries of ``arr``.

        Returns a tuple with one index tensor of length ``k`` per dimension of
        ``arr``, ordered from the largest to the smallest value.
        """
        _, flat_indices = self._array_backend.topk(arr.reshape(-1), k)
        return self.unravel_index(indices=flat_indices, shape=arr.shape)

    def indices(self, shape: Tuple[int]) -> TorchTensor:
        """Index grid like ``numpy.indices``: shape ``(len(shape), *shape)``."""
        grids = [self.arange(x) for x in shape]
        mesh = self._array_backend.meshgrid(*grids, indexing="ij")
        return self._array_backend.stack(mesh)

    def unravel_index(self, indices, shape):
        """
        Convert flat (C-order) indices into a tuple of coordinate tensors.

        Mirrors ``numpy.unravel_index``: the result holds one tensor per
        dimension of ``shape``, each with ``indices.numel()`` entries. A 0-d
        ``indices`` yields 0-d coordinates.
        """
        indices = self.to_backend_array(indices)
        shape = self.to_backend_array(shape)

        # Row-major strides, e.g. (a, b, c) -> (b * c, c, 1).
        strides = self._array_backend.cumprod(shape.flip(0), dim=0).flip(0)
        strides = self._array_backend.cat(
            (strides[1:], self.to_backend_array([1])),
        )

        coords = indices.reshape(-1, 1) // strides.reshape(1, -1)
        coords = coords % shape.reshape(1, -1)
        if indices.ndim == 0:
            return tuple(coords[0])
        # Always one tensor per axis, including when a single index is given.
        return tuple(coords.T)

    def roll(self, a, shift, axis=None, **kwargs):
        """
        Circularly shift ``a`` by ``shift`` along ``axis``.

        ``shift`` may be a single integer or a sequence. With ``axis=None``
        torch rolls the flattened tensor, as ``numpy.roll`` does.
        """
        if np.ndim(shift) == 0:
            shift = int(shift)
        else:
            shift = tuple(shift)
        return self._array_backend.roll(input=a, shifts=shift, dims=axis, **kwargs)

    def unique(
        self,
        ar,
        return_index: bool = False,
        return_inverse: bool = False,
        return_counts: bool = False,
        axis: int = None,
        sorted: bool = True,
    ):
        """
        numpy-style ``unique`` built on ``torch.unique``.

        ``torch.unique`` has no ``return_index`` option, so first-occurrence
        indices are recovered from a stable argsort of the inverse mapping.

        Returns the unique values alone when no extra output is requested,
        otherwise a list starting with them, followed by the requested index,
        inverse and count tensors in that order.
        """
        # https://github.com/pytorch/pytorch/issues/36748#issuecomment-1478913448
        uniq, inverse, counts = self._array_backend.unique(
            ar, return_inverse=True, return_counts=True, dim=axis, sorted=sorted
        )
        inverse = inverse.reshape(-1)
        if not (return_index or return_inverse or return_counts):
            return uniq

        ret = [uniq]
        if return_index:
            inv_sorted = inverse.argsort(stable=True)
            # Offset of each group's first member inside ``inv_sorted``.
            group_starts = self._array_backend.cat(
                (counts.new_zeros(1), counts.cumsum(dim=0))
            )[:-1]
            ret.append(inv_sorted[group_starts])
        if return_inverse:
            ret.append(inverse)
        if return_counts:
            ret.append(counts)
        return ret

    def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
        """
        Coordinates of the window maxima of a max-pooled ``score_space``.

        Parameters
        ----------
        score_space : TorchTensor
            2D or 3D score tensor.
        min_distance : int or tuple of int
            Pooling window size, either shared by all axes or per axis.
            Padding is half the window (rounded down).

        Returns
        -------
        TorchTensor
            ``(n, ndim)`` coordinate tensor, one row per pooling window
            (assuming ``self.stack`` behaves like ``torch.stack``).

        Raises
        ------
        NotImplementedError
            If ``score_space`` is neither 2D nor 3D.
        """
        if score_space.ndim == 3:
            pool_cls = self._array_backend.nn.MaxPool3d
        elif score_space.ndim == 2:
            pool_cls = self._array_backend.nn.MaxPool2d
        else:
            raise NotImplementedError("Operation only implemented for 2 and 3D inputs.")

        # Accept a per-axis window (as annotated) as well as a single integer.
        if np.ndim(min_distance) == 0:
            kernel_size = int(min_distance)
            padding = kernel_size // 2
        else:
            kernel_size = tuple(int(x) for x in min_distance)
            padding = tuple(x // 2 for x in kernel_size)

        # NOTE: torch's pooling stride defaults to kernel_size, so the windows
        # do not overlap; this is block-wise pooling, not a sliding max filter.
        pool = pool_cls(kernel_size=kernel_size, padding=padding, return_indices=True)
        _, flat_indices = pool(score_space.reshape(1, 1, *score_space.shape))
        coordinates = self.unravel_index(flat_indices.reshape(-1), score_space.shape)
        return self.transpose(self.stack(coordinates))

    def repeat(self, *args, **kwargs):
        """``torch.repeat_interleave``, the counterpart of ``numpy.repeat``."""
        return self._array_backend.repeat_interleave(*args, **kwargs)

    def from_sharedarr(self, args) -> TorchTensor:
        """
        Inverse of ``to_sharedarr``.

        On CUDA devices ``args`` is returned unchanged. Otherwise ``args`` is a
        ``(shm, shape, dtype)`` triple, and the result is a tensor backed by
        ``shm.buf`` (``torch.frombuffer`` does not copy).
        """
        if self._uses_cuda():
            return args
        shm, shape, dtype = args
        required_size = int(self._array_backend.prod(self.to_backend_array(shape)))
        # The shared block may be larger than requested (e.g. page rounding),
        # so only its leading ``required_size`` elements are used.
        flat = self._array_backend.frombuffer(shm.buf, dtype=dtype)
        return flat[:required_size].reshape(shape)

    def to_sharedarr(
        self, arr: TorchTensor, shared_memory_handler: type = None
    ) -> shm_type:
        """
        Place ``arr`` in shared memory.

        On CUDA devices ``arr`` is returned unchanged. Otherwise its bytes are
        copied into a new shared memory block. The block is allocated through
        ``shared_memory_handler`` when that is a ``SharedMemoryManager``;
        otherwise a standalone ``SharedMemory`` is created, which someone must
        eventually close and unlink.

        Returns
        -------
        TorchTensor or tuple
            ``arr`` itself on CUDA, else ``(shm, arr.shape, arr.dtype)``.
        """
        if self._uses_cuda():
            return arr
        nbytes = arr.numel() * arr.element_size()
        if isinstance(shared_memory_handler, SharedMemoryManager):
            shm = shared_memory_handler.SharedMemory(size=nbytes)
        else:
            shm = shared_memory.SharedMemory(create=True, size=nbytes)
        shm.buf[:nbytes] = arr.numpy().tobytes()
        return shm, arr.shape, arr.dtype

    def transpose(self, arr):
        """Reverse the order of all dimensions of ``arr``."""
        return arr.permute(tuple(range(arr.ndim - 1, -1, -1)))

    def power(self, *args, **kwargs):
        """Element-wise power via ``torch.pow``."""
        return self._array_backend.pow(*args, **kwargs)

    def build_fft(
        self,
        fast_shape: Tuple[int],
        fast_ft_shape: Tuple[int],
        inverse_fast_shape: Tuple[int] = None,
        **kwargs,
    ) -> Tuple[Callable, Callable]:
        """
        Return ``(rfftn, irfftn)`` closures around ``torch.fft``.

        The forward transform defaults to ``s=fast_shape``. The inverse
        defaults to ``s=inverse_fast_shape``, or ``fast_shape`` if that is
        None. ``fast_ft_shape`` and ``**kwargs`` are accepted for interface
        compatibility but are not used.
        """
        if inverse_fast_shape is None:
            inverse_fast_shape = fast_shape

        def rfftn(
            arr: TorchTensor, out: TorchTensor, shape: Tuple[int] = fast_shape
        ) -> TorchTensor:
            return self._array_backend.fft.rfftn(arr, s=shape, out=out)

        def irfftn(
            arr: TorchTensor, out: TorchTensor, shape: Tuple[int] = inverse_fast_shape
        ) -> TorchTensor:
            return self._array_backend.fft.irfftn(arr, s=shape, out=out)

        return rfftn, irfftn

    def _affine_transform(
        self,
        arr: TorchTensor,
        rotation_matrix: TorchTensor,
        translation: TorchTensor,
        mode,
    ) -> TorchTensor:
        """
        Resample ``arr`` under an affine map via ``affine_grid``/``grid_sample``.

        ``rotation_matrix`` and ``translation`` are packed into the
        ``(ndim, ndim + 1)`` ``theta`` passed to
        ``torch.nn.functional.affine_grid``, so they are interpreted in
        torch's normalized grid coordinates (see the torch documentation for
        its axis conventions). The returned tensor has ``arr.shape``.
        """
        ndim = arr.ndim
        transformation_matrix = self._array_backend.zeros(
            ndim, ndim + 1, device=arr.device, dtype=arr.dtype
        )
        transformation_matrix[:, :ndim] = rotation_matrix
        transformation_matrix[:, ndim] = translation

        size = self.Size([1, 1, *arr.shape])
        grid = self.F.affine_grid(
            theta=transformation_matrix.unsqueeze(0), size=size, align_corners=False
        )
        output = self.F.grid_sample(
            input=arr.unsqueeze(0).unsqueeze(0),
            grid=grid,
            mode=mode,
            align_corners=False,
        )
        # Drop only the batch and channel axes added above; a bare squeeze()
        # would also remove singleton spatial dimensions of ``arr``.
        return output[0, 0]

    def get_available_memory(self) -> int:
        """
        Free memory in bytes.

        On CUDA this is the free memory of the current CUDA device as reported
        by ``torch.cuda.mem_get_info``; otherwise the parent implementation is
        used.
        """
        if not self._uses_cuda():
            return super().get_available_memory()
        return self._array_backend.cuda.mem_get_info()[0]

    @contextmanager
    def set_device(self, device_index: int):
        """
        Make CUDA device ``device_index`` current inside the ``with`` block.

        Does nothing on non-CUDA devices.
        """
        if self._uses_cuda():
            with self._array_backend.cuda.device(device_index):
                yield
        else:
            yield None

    def device_count(self) -> int:
        """Number of visible CUDA devices, or 1 when not running on CUDA."""
        if not self._uses_cuda():
            return 1
        return self._array_backend.cuda.device_count()

    def reverse(self, arr: TorchTensor) -> TorchTensor:
        """Reverse ``arr`` along every dimension."""
        return self._array_backend.flip(arr, list(range(arr.ndim)))