"""
Backend using pytorch and optionally GPU acceleration for
template matching.
Copyright (c) 2023 European Molecular Biology Laboratory
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
from typing import Tuple
from contextlib import contextmanager
from multiprocessing import shared_memory
from multiprocessing.managers import SharedMemoryManager
import numpy as np
from .npfftw_backend import NumpyFFTWBackend
from ..types import NDArray, TorchTensor, shm_type
# (removed stray Sphinx "[docs]" link artifact from documentation export)
class PytorchBackend(NumpyFFTWBackend):
"""
A pytorch-based matching backend.
"""
def __init__(
    self,
    device="cuda",
    float_dtype=None,
    complex_dtype=None,
    int_dtype=None,
    overflow_safe_dtype=None,
    **kwargs,
):
    """Initialize the torch backend.

    Defaults to float32 / complex64 / int32 dtypes and float32 as the
    overflow-safe accumulator dtype when none are given.
    """
    import torch
    import torch.nn.functional as F

    # Fall back to standard torch dtypes for anything left unspecified.
    if float_dtype is None:
        float_dtype = torch.float32
    if complex_dtype is None:
        complex_dtype = torch.complex64
    if int_dtype is None:
        int_dtype = torch.int32
    if overflow_safe_dtype is None:
        overflow_safe_dtype = torch.float32

    super().__init__(
        array_backend=torch,
        float_dtype=float_dtype,
        complex_dtype=complex_dtype,
        int_dtype=int_dtype,
        overflow_safe_dtype=overflow_safe_dtype,
    )
    self.device = device
    # Keep functional API at hand for grid_sample / affine_grid use.
    self.F = F
def to_backend_array(self, arr: NDArray, check_device: bool = True) -> TorchTensor:
    """Coerce ``arr`` to a torch tensor on the backend device.

    Existing tensors are returned as-is when already on the right device
    (or when ``check_device`` is False); everything else is copied.
    """
    if not isinstance(arr, self._array_backend.Tensor):
        return self.tensor(arr, device=self.device)
    # Already a tensor: move only when requested and actually necessary.
    if check_device and arr.device != self.device:
        return arr.to(self.device)
    return arr
def to_numpy_array(self, arr: TorchTensor) -> NDArray:
    """Return ``arr`` as a numpy array, copying device data to host first."""
    if isinstance(arr, self._array_backend.Tensor):
        return arr.cpu().numpy()
    if isinstance(arr, np.ndarray):
        return arr
    # Generic sequences / scalars fall through to a fresh numpy array.
    return np.array(arr)
def to_cpu_array(self, arr: TorchTensor) -> NDArray:
    """Move ``arr`` to host memory, keeping it a torch tensor."""
    return arr.to("cpu")
def get_fundamental_dtype(self, arr):
    """Classify ``arr``'s dtype as builtin ``float``, ``complex`` or ``int``."""
    backend = self._array_backend
    # A tensor is never both complex and floating point, so order is free.
    if backend.is_complex(arr):
        return complex
    if backend.is_floating_point(arr):
        return float
    return int
def free_cache(self):
    """Release unused memory held by the CUDA caching allocator.

    No-op when CUDA has not been initialized.
    """
    backend = self._array_backend
    backend.cuda.empty_cache()
def mod(self, x1, x2, *args, **kwargs):
    """Element-wise modulus with numpy-style semantics (torch.remainder)."""
    remainder = self._array_backend.remainder
    return remainder(x1, x2, *args, **kwargs)
def max(self, *args, **kwargs) -> NDArray:
    """Maximum over a tensor (optionally along dims) via torch.amax."""
    result = self._array_backend.amax(*args, **kwargs)
    # amax normally returns a Tensor; keep the defensive unwrap for
    # tuple-like results for parity with the original implementation.
    if not isinstance(result, self._array_backend.Tensor):
        return result[0]
    return result
def min(self, *args, **kwargs) -> NDArray:
    """Minimum over a tensor (optionally along dims) via torch.amin."""
    result = self._array_backend.amin(*args, **kwargs)
    # Defensive unwrap mirroring :py:meth:`max`.
    if not isinstance(result, self._array_backend.Tensor):
        return result[0]
    return result
def maximum(self, x1, x2, *args, **kwargs) -> NDArray:
    """Element-wise maximum; both inputs are coerced to backend tensors."""
    x1 = self.to_backend_array(x1, check_device=False)
    # Co-locate x2 with x1 so torch.maximum does not raise on device mix.
    x2 = self.to_backend_array(x2, check_device=False).to(x1.device)
    return self._array_backend.maximum(x1, x2, *args, **kwargs)
def minimum(self, x1, x2, *args, **kwargs) -> NDArray:
    """Element-wise minimum of two array-likes on the backend."""
    t1 = self.to_backend_array(x1)
    t2 = self.to_backend_array(x2)
    return self._array_backend.minimum(t1, t2, *args, **kwargs)
def tobytes(self, arr):
    """Raw little-endian bytes of ``arr`` after moving it to host memory."""
    host = arr.cpu()
    return host.numpy().tobytes()
def size(self, arr):
    """Total number of elements in ``arr`` (numpy ``size`` analogue)."""
    return arr.numel()
def zeros(self, shape, dtype=None):
    """Zero-filled tensor of ``shape`` allocated on the backend device."""
    backend = self._array_backend
    return backend.zeros(shape, dtype=dtype, device=self.device)
def copy(self, arr: TorchTensor) -> TorchTensor:
    """Return an independent copy of ``arr``."""
    return arr.clone()
def full(self, shape, fill_value, dtype=None):
    """Tensor of ``shape`` filled with ``fill_value`` on the backend device.

    Accepts a bare int as shorthand for a 1D shape.
    """
    shape = (shape,) if isinstance(shape, int) else shape
    return self._array_backend.full(
        size=shape, fill_value=fill_value, dtype=dtype, device=self.device
    )
def arange(self, *args, **kwargs):
    """torch.arange pinned to the backend device."""
    return self._array_backend.arange(*args, device=self.device, **kwargs)
def datatype_bytes(self, dtype: type) -> int:
    """Size in bytes of a single element of ``dtype``.

    Determined empirically by allocating a one-element tensor.
    """
    probe = self.zeros(1, dtype=dtype)
    return probe.element_size()
def fill(self, arr: TorchTensor, value: float) -> TorchTensor:
    """Fill ``arr`` in place with ``value``; returns ``arr`` for chaining."""
    arr.fill_(value)
    return arr
def astype(self, arr: TorchTensor, dtype: type) -> TorchTensor:
    """Cast ``arr`` to ``dtype`` (torch returns ``arr`` unchanged if equal)."""
    return arr.to(dtype=dtype)
[docs]
@staticmethod
def at(arr, idx, value) -> NDArray:
arr[idx] = value
return arr
[docs]
@staticmethod
def addat(arr, indices, *args, **kwargs) -> NDArray:
return arr.index_put_(indices, *args, accumulate=True, **kwargs)
def flip(self, a, axis, **kwargs):
    """Reverse ``a`` along each axis listed in ``axis``."""
    return self._array_backend.flip(a, dims=axis, **kwargs)
def topk_indices(self, arr, k):
    """Coordinates of the ``k`` largest elements of ``arr``.

    Parameters
    ----------
    arr : TorchTensor
        Input tensor, searched over its flattened view.
    k : int
        Number of top elements to locate.

    Returns
    -------
    tuple
        Per-dimension index arrays as produced by :py:meth:`unravel_index`.
    """
    # Only the indices of torch.topk are needed; the previous version also
    # bound the values tensor and never used it.
    flat_indices = self._array_backend.topk(arr.reshape(-1), k).indices
    return self.unravel_index(indices=flat_indices, shape=arr.shape)
def indices(self, shape: Tuple[int], dtype: type = int) -> TorchTensor:
    """Stacked coordinate grid for ``shape`` (numpy.indices analogue)."""
    axes = [self.arange(extent, dtype=dtype) for extent in shape]
    grid = self._array_backend.meshgrid(*axes, indexing="ij")
    return self._array_backend.stack(grid)
def unravel_index(self, indices, shape):
    """Convert flat indices into per-dimension coordinates for ``shape``."""
    indices = self.to_backend_array(indices)
    shape = self.to_backend_array(shape)

    # Row-major strides: cumulative product of trailing extents, then 1.
    strides = self._array_backend.cumprod(shape.flip(0), dim=0).flip(0)
    strides = self._array_backend.cat((strides[1:], self.to_backend_array([1])))

    coords = (indices.view(-1, 1) // strides.view(1, -1)) % shape.view(1, -1)
    if coords.size(0) == 1:
        # Single flat index: return its coordinate vector wrapped in a tuple.
        return (coords[0, :],)
    return tuple(coords.T)
def roll(self, a, shift, axis, **kwargs):
    """Circularly shift ``a`` by ``shift`` along ``axis``."""
    return self._array_backend.roll(a, shifts=tuple(shift), dims=axis, **kwargs)
def unique(
    self,
    ar,
    return_index: bool = False,
    return_inverse: bool = False,
    return_counts: bool = False,
    axis: int = None,
    sorted: bool = True,
):
    """numpy.unique-like wrapper around torch.unique.

    torch does not expose first-occurrence indices directly, so they are
    reconstructed from a stable argsort of the inverse mapping
    (see pytorch/pytorch#36748).
    """
    uniques, inverse, counts = self._array_backend.unique(
        ar, return_inverse=True, return_counts=True, dim=axis, sorted=sorted
    )
    inverse = inverse.reshape(-1)

    index = None
    if return_index:
        # Stable-sort positions by their unique label, then pick the first
        # slot of each label's run via cumulative counts.
        order = inverse.argsort(stable=True)
        starts = self._array_backend.cat(
            (counts.new_zeros(1), counts.cumsum(dim=0))
        )[:-1]
        index = order[starts]

    if not (return_index or return_inverse or return_counts):
        return uniques

    ret = [uniques]
    if return_index:
        ret.append(index)
    if return_inverse:
        ret.append(inverse)
    if return_counts:
        ret.append(counts)
    return ret
def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
    """Coordinates of pooled maxima of ``score_space`` via torch max pooling.

    Parameters
    ----------
    score_space : TorchTensor
        2D or 3D score array.
    min_distance : int
        Pooling kernel size. NOTE(review): despite the ``Tuple[int]``
        annotation it is used with integer division below, so callers
        presumably pass a plain int — confirm against call sites.

    Returns
    -------
    TorchTensor
        One row of (pooled) maximum coordinates per output position.
    """
    if score_space.ndim == 3:
        func = self._array_backend.nn.MaxPool3d
    elif score_space.ndim == 2:
        func = self._array_backend.nn.MaxPool2d
    else:
        raise NotImplementedError("Operation only implemented for 2 and 3D inputs.")

    # return_indices=True yields the flat argmax position of each window.
    pool = func(
        kernel_size=min_distance, padding=min_distance // 2, return_indices=True
    )

    # Add the singleton batch/channel dims expected by torch pooling layers.
    _, indices = pool(score_space.reshape(1, 1, *score_space.shape))
    coordinates = self.unravel_index(indices.reshape(-1), score_space.shape)
    coordinates = self.transpose(self.stack(coordinates))
    return coordinates
def repeat(self, *args, **kwargs):
    """numpy.repeat analogue backed by torch.repeat_interleave."""
    repeat_fn = self._array_backend.repeat_interleave
    return repeat_fn(*args, **kwargs)
def from_sharedarr(self, args) -> TorchTensor:
    """Rebuild a tensor from a shared-memory descriptor.

    On CUDA the tensor is passed through unchanged; on CPU ``args`` is a
    ``(shm, shape, dtype)`` triple and the returned tensor views the
    shared-memory buffer.
    """
    if self.device == "cuda":
        return args

    shm, shape, dtype = args
    n_elements = int(self._array_backend.prod(self.to_backend_array(shape)))
    # The shm segment may be larger than requested; trim to the tensor size.
    flat = self._array_backend.frombuffer(shm.buf, dtype=dtype)
    return flat[:n_elements].reshape(shape)
def to_sharedarr(
    self, arr: TorchTensor, shared_memory_handler: type = None
) -> shm_type:
    """Serialize ``arr`` into shared memory (CPU) or pass it through (CUDA).

    Returns the tensor itself on CUDA, otherwise a ``(shm, shape, dtype)``
    triple suitable for :py:meth:`from_sharedarr`.
    """
    if self.device == "cuda":
        return arr

    nbytes = arr.numel() * arr.element_size()
    # Prefer a managed segment when a SharedMemoryManager is supplied.
    if isinstance(shared_memory_handler, SharedMemoryManager):
        shm = shared_memory_handler.SharedMemory(size=nbytes)
    else:
        shm = shared_memory.SharedMemory(create=True, size=nbytes)

    shm.buf[:nbytes] = arr.numpy().tobytes()
    return shm, arr.shape, arr.dtype
def transpose(self, arr, axes=None):
    """Permute tensor dimensions; defaults to reversing them all."""
    if axes is None:
        axes = tuple(reversed(range(arr.ndim)))
    return arr.permute(axes)
def power(self, *args, **kwargs):
    """Element-wise exponentiation via torch.pow."""
    pow_fn = self._array_backend.pow
    return pow_fn(*args, **kwargs)
def eye(self, *args, **kwargs):
    """Identity matrix, placed on the backend device unless one is given."""
    kwargs.setdefault("device", self.device)
    return self._array_backend.eye(*args, **kwargs)
def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
    """Real-to-complex FFT; translates numpy's ``axes`` kwarg to torch ``dim``."""
    axes = kwargs.pop("axes", None)
    return self._array_backend.fft.rfftn(arr, dim=axes, **kwargs)
def irfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
    """Inverse real FFT; translates numpy's ``axes`` kwarg to torch ``dim``."""
    axes = kwargs.pop("axes", None)
    return self._array_backend.fft.irfftn(arr, dim=axes, **kwargs)
def _build_transform_matrix(
    self,
    shape: Tuple[int],
    rotation_matrix: TorchTensor,
    translation: TorchTensor = None,
    center: TorchTensor = None,
    **kwargs,
) -> TorchTensor:
    """
    Express the transform matrix in normalized coordinates.

    torch's affine_grid (used downstream in ``_rigid_transform``) operates
    in coordinates normalized to [-1, 1], so the translation and rotation
    center are rescaled by 2 / (extent - 1) per axis before delegating to
    the parent implementation. NOTE(review): the rotation matrix is flipped
    along both axes, presumably converting between array-order (z, y, x)
    and grid-order (x, y, z) axis conventions — confirm against the
    base-class contract.
    """
    # Work with extent - 1 so endpoints map exactly onto -1 and 1.
    shape = self.to_backend_array(shape) - 1
    scale_factors = 2.0 / shape
    if center is not None:
        # Re-express the center relative to the array midpoint, then scale.
        center = center - shape / 2
        center = center * scale_factors
    if translation is not None:
        translation = translation * scale_factors
    return super()._build_transform_matrix(
        rotation_matrix=self.flip(rotation_matrix, [0, 1]),
        translation=translation,
        center=center,
    )
def _rigid_transform(
    self,
    arr: TorchTensor,
    matrix: TorchTensor,
    arr_mask: TorchTensor = None,
    out: TorchTensor = None,
    out_mask: TorchTensor = None,
    order: int = 1,
    batched: bool = False,
    **kwargs,
) -> Tuple[TorchTensor, TorchTensor]:
    """Apply rigid transformation using homogeneous transformation matrix.

    Parameters
    ----------
    arr : TorchTensor
        Data to transform; the leading dim is the batch when ``batched``.
    matrix : TorchTensor
        Homogeneous transformation matrix (its last row is dropped below).
    arr_mask : TorchTensor, optional
        Mask resampled with the same sampling grid as ``arr``.
    out, out_mask : TorchTensor, optional
        Optional output buffers; results are written into their leading
        ``arr.shape`` region, otherwise fresh tensors are returned.
    order : int
        Interpolation order: 0 nearest, 1 bilinear, 3 bicubic.
    batched : bool
        Whether ``arr`` carries a leading batch dimension.

    Returns
    -------
    Tuple[TorchTensor, TorchTensor]
        Transformed array and mask (mask is None when none was given).

    Raises
    ------
    ValueError
        If ``order`` is not one of the supported interpolation orders.
    """
    # Map scipy-style integer interpolation orders onto grid_sample modes.
    _mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
    mode = _mode_mapping.get(order, None)
    if mode is None:
        modes = ", ".join([str(x) for x in _mode_mapping.keys()])
        raise ValueError(
            f"Got {order} but supported interpolation orders are: {modes}."
        )
    batch_size, spatial_dims = 1, arr.shape
    out_slice = tuple(slice(0, x) for x in arr.shape)
    if batched:
        # Drop the batch row/column of the homogeneous matrix.
        matrix = matrix[1:, 1:]
        batch_size, *spatial_dims = arr.shape
    # Remove homogeneous row and expand for batch processing
    matrix = matrix[:-1, :].to(arr.dtype)
    matrix = matrix.unsqueeze(0).expand(batch_size, -1, -1)
    grid = self.F.affine_grid(
        theta=matrix.to(arr.dtype),
        size=self.Size([batch_size, 1, *spatial_dims]),
        align_corners=False,
    )
    # grid_sample expects (N, C, ...) input: add batch/channel dims here
    # and strip the channel dim right after resampling.
    arr = arr.unsqueeze(0) if not batched else arr
    ret = self.F.grid_sample(
        input=arr.unsqueeze(1),
        grid=grid,
        mode=mode,
        align_corners=False,
    ).squeeze(1)
    ret_mask = None
    if arr_mask is not None:
        # Resample the mask with the exact same grid as the data.
        arr_mask = arr_mask.unsqueeze(0) if not batched else arr_mask
        ret_mask = self.F.grid_sample(
            input=arr_mask.unsqueeze(1),
            grid=grid,
            mode=mode,
            align_corners=False,
        ).squeeze(1)
    if not batched:
        ret = ret.squeeze(0)
        ret_mask = ret_mask.squeeze(0) if arr_mask is not None else None
    if out is not None:
        out[out_slice] = ret
    else:
        out = ret
    if out_mask is not None:
        out_mask[out_slice] = ret_mask
    else:
        out_mask = ret_mask
    return out, out_mask
def get_available_memory(self) -> int:
    """Free memory in bytes on the active device."""
    if self.device != "cpu":
        # mem_get_info returns (free, total) for the current CUDA device.
        return self._array_backend.cuda.mem_get_info()[0]
    return super().get_available_memory()
@contextmanager
def set_device(self, device_index: int):
    """Context manager selecting a CUDA device; a no-op on other devices."""
    if self.device != "cuda":
        yield None
    else:
        with self._array_backend.cuda.device(device_index):
            yield
def device_count(self) -> int:
    """Number of usable devices: the CUDA count, or 1 for the CPU backend."""
    if self.device == "cpu":
        return 1
    return self._array_backend.cuda.device_count()
def reverse(self, arr: TorchTensor, axis: Tuple[int] = None) -> TorchTensor:
    """Reverse the order of elements along the given axes.

    Parameters
    ----------
    arr : TorchTensor
        Input tensor.
    axis : Tuple[int], optional
        Axes to flip; defaults to all axes.

    Returns
    -------
    TorchTensor
        Tensor with the selected axes reversed.
    """
    if axis is None:
        axis = tuple(range(arr.ndim))
    # Normalize negative axes (e.g. -1 -> ndim - 1); the previous filter
    # `i in axis` over range(ndim) silently ignored them. Out-of-range
    # axes remain ignored, and duplicates are collapsed, as before.
    dims = sorted({a % arr.ndim for a in axis if -arr.ndim <= a < arr.ndim})
    return self._array_backend.flip(arr, dims)
def triu_indices(self, n: int, k: int = 0, m: int = None) -> TorchTensor:
    """Indices of the upper triangle of an ``n`` x ``m`` matrix.

    ``m`` defaults to ``n``; ``k`` is the diagonal offset.
    """
    m = n if m is None else m
    return self._array_backend.triu_indices(n, m, k)