Source code for tme.filters.reconstruction

"""
Defines filters on tomographic tilt series.

Copyright (c) 2024 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from typing import Tuple
from dataclasses import dataclass

import numpy as np

from ..backends import backend as be
from ..types import NDArray, BackendArray

from .compose import ComposableFilter
from ..rotations import euler_to_rotationmatrix
from ._utils import crop_real_fourier, shift_fourier, create_reconstruction_filter

__all__ = ["ReconstructFromTilt", "ShiftFourier"]


[docs] @dataclass class ReconstructFromTilt(ComposableFilter): """ Reconstruct a d+1 array from a d-dimensional input projection using weighted backprojection (WBP). """ #: Shape of the reconstruction. shape: Tuple[int] = None #: Angle of each individual tilt. angles: Tuple[float] = None #: Projection axis, defaults to 2 (z). opening_axis: int = 2 #: Tilt axis, defaults to 0 (x). tilt_axis: int = 0 #: Whether to return a share compliant with rfftn. return_real_fourier: bool = True #: Interpolation order used for rotation interpolation_order: int = 1 #: Filter window applied during reconstruction. reconstruction_filter: str = None
[docs] def __call__(self, return_real_fourier: bool = False, **kwargs): """ Reconstruct a d+1 array from a d-dimensional input using WBP. Parameters ---------- shape : tuple of int The shape of the reconstruction volume. data : BackendArray D-dimensional image stack with shape (n, ...). The data is assumed to be a Fourier transform of the stack you are trying to reconstruct with DC component in the center. angles : tuple of float Angle of each individual tilt. return_real_fourier : bool, optional Return a shape compliant return_real_fourier : tuple of int Return a shape compliant with rfft, i.e., omit the negative frequencies terms resulting in a return shape (*shape[:-1], shape[-1]//2+1). Defaults to False. reconstruction_filter : bool, optional Filter window applied during reconstruction. See :py:meth:`create_reconstruction_filter` for available options. tilt_axis : int Axis the plane is tilted over, defaults to 0 (x). opening_axis : int The projection axis, defaults to 2 (z). Returns ------- dict data: BackendArray The filter mask. shape: tuple of ints The requested filter shape return_real_fourier: bool Whether data is compliant with rfftn. is_multiplicative_filter: bool Whether the filter is multiplicative in Fourier space. """ func_args = vars(self).copy() func_args.update(kwargs) ret = self.reconstruct(**func_args) ret = shift_fourier(data=ret, shape_is_real_fourier=False) if return_real_fourier: ret = crop_real_fourier(ret) return { "data": ret, "shape": func_args["shape"], "return_real_fourier": return_real_fourier, "is_multiplicative_filter": False, }
[docs] @staticmethod def reconstruct( data: NDArray, shape: Tuple[int], angles: Tuple[float], opening_axis: int, tilt_axis: int, interpolation_order: int = 1, reconstruction_filter: str = None, **kwargs, ): """ Reconstruct a volume from a tilt series. Parameters ---------- data : NDArray The Fourier transform of tilt series data. shape : tuple of int Shape of the reconstruction. angles : tuple of float Angle of each individual tilt. opening_axis : int The axis around which the volume is opened. tilt_axis : int Axis the plane is tilted over. interpolation_order : int, optional Interpolation order used for rotation, defaults to 1. reconstruction_filter : bool, optional Filter window applied during reconstruction. See :py:meth:`create_reconstruction_filter` for available options. Returns ------- NDArray The reconstructed volume. """ if data.shape == shape: return data data = be.to_backend_array(data) volume_temp = be.zeros(shape, dtype=data.dtype) volume_temp_rotated = be.zeros(shape, dtype=data.dtype) volume = be.zeros(shape, dtype=data.dtype) slices = tuple(slice(a // 2, (a // 2) + 1) for a in shape) subset = tuple( slice(None) if i != opening_axis else x for i, x in enumerate(slices) ) angles_loop = be.zeros(len(shape)) wedge_dim = [x for x in data.shape] wedge_dim.insert(1 + opening_axis, 1) wedges = be.reshape(data, wedge_dim) rec_filter = 1 aspect_ratio = shape[opening_axis] / shape[tilt_axis] angles = np.degrees(np.arctan(np.tan(np.radians(angles)) * aspect_ratio)) if reconstruction_filter is not None: rec_filter = create_reconstruction_filter( filter_type=reconstruction_filter, filter_shape=(shape[tilt_axis],), tilt_angles=angles, ) rec_shape = tuple(1 if i != tilt_axis else x for i, x in enumerate(shape)) rec_filter = be.to_backend_array(rec_filter) rec_filter = be.reshape(rec_filter, rec_shape) angles = be.to_backend_array(angles) for index in range(len(angles)): angles_loop = be.fill(angles_loop, 0) volume_temp = be.fill(volume_temp, 0) volume_temp_rotated = be.fill(volume_temp_rotated, 0) # Jax compatibility volume_temp = be.at(volume_temp, subset, wedges[index] * rec_filter) angles_loop = be.at(angles_loop, tilt_axis, angles[index]) angles_loop = be.roll(angles_loop, (opening_axis - 1,), axis=0) rotation_matrix = euler_to_rotationmatrix(be.to_numpy_array(angles_loop)) rotation_matrix = be.to_backend_array(rotation_matrix) volume_temp_rotated, _ = be.rigid_transform( arr=volume_temp, rotation_matrix=rotation_matrix, out=volume_temp_rotated, use_geometric_center=True, order=interpolation_order, ) volume = be.add(volume, volume_temp_rotated, out=volume) return volume
class ShiftFourier(ComposableFilter): def __call__( self, data: BackendArray, shape_is_real_fourier: bool = False, return_real_fourier: bool = True, **kwargs, ): ret = [] for index in range(data.shape[0]): mask = be.to_numpy_array(data[index]) mask = shift_fourier(data=mask, shape_is_real_fourier=shape_is_real_fourier) if return_real_fourier: mask = crop_real_fourier(mask) ret.append(mask[None]) ret = np.concatenate(ret, axis=0) return { "data": ret, "shape": kwargs.get("shape"), "return_real_fourier": return_real_fourier, "is_multiplicative_filter": False, }