Source code for tme.filters.reconstruction

""" Defines filters on tomographic tilt series.

    Copyright (c) 2024 European Molecular Biology Laboratory

    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

from typing import Tuple
from dataclasses import dataclass

import numpy as np

from ..types import NDArray
from ..backends import backend as be

from .compose import ComposableFilter
from ..rotations import euler_to_rotationmatrix
from ._utils import (
    crop_real_fourier,
    shift_fourier,
    create_reconstruction_filter,
)

# Public names exported by this module (used by ``from ... import *``).
__all__ = ["ReconstructFromTilt"]


@dataclass
class ReconstructFromTilt(ComposableFilter):
    """
    Reconstruct a volume from a tilt series.

    Each tilt is written into the central plane of the output volume along
    ``opening_axis``. It is optionally weighted by a reconstruction filter,
    rotated by its tilt angle and accumulated. The summed volume is then
    passed through ``shift_fourier`` and, if requested, cropped to an
    rfftn-compatible shape.
    """

    #: Shape of the reconstruction.
    shape: Tuple[int, ...] = None
    #: Angle of each individual tilt, in degrees.
    angles: Tuple[float, ...] = None
    #: The axis around which the volume is opened.
    opening_axis: int = 0
    #: Axis the plane is tilted over.
    tilt_axis: int = 2
    #: Whether to return a shape compliant with rfftn.
    return_real_fourier: bool = True
    #: Interpolation order used for rotation.
    interpolation_order: int = 1
    #: Filter window applied during reconstruction.
    reconstruction_filter: str = None

    def __call__(self, **kwargs):
        """
        Reconstruct using this instance's fields, overridden by ``kwargs``.

        Parameters
        ----------
        **kwargs
            Keyword arguments forwarded to :py:meth:`reconstruct`. ``data`` is
            not a field of this class and is expected here. Any field passed
            here takes precedence over the instance value.

        Returns
        -------
        dict
            ``data`` holds the reconstructed volume and ``shape`` its shape.
            The remaining keys echo the parameters used for the
            reconstruction. ``is_multiplicative_filter`` is always False.
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)
        ret = self.reconstruct(**func_args)
        return {
            "data": ret,
            "shape": ret.shape,
            "shape_is_real_fourier": func_args["return_real_fourier"],
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": False,
        }

    @staticmethod
    def reconstruct(
        data: NDArray,
        shape: Tuple[int, ...],
        angles: Tuple[float, ...],
        opening_axis: int,
        tilt_axis: int,
        interpolation_order: int = 1,
        return_real_fourier: bool = True,
        reconstruction_filter: str = None,
        **kwargs,
    ):
        """
        Reconstruct a volume from a tilt series.

        Parameters
        ----------
        data : NDArray
            The tilt series data. The first axis indexes the individual tilts.
        shape : tuple of int
            Shape of the reconstruction. If ``data`` already has this shape, it
            is returned unchanged. Lists are accepted as well.
        angles : tuple of float
            Angle of each individual tilt, in degrees.
        opening_axis : int
            The axis around which the volume is opened.
        tilt_axis : int
            Axis the plane is tilted over.
        interpolation_order : int, optional
            Interpolation order used for rotation, defaults to 1.
        return_real_fourier : bool, optional
            Whether to return a shape compliant with rfftn, defaults to True.
        reconstruction_filter : str, optional
            Filter window applied during reconstruction. See
            :py:meth:`create_reconstruction_filter` for available options.
            Defaults to None, meaning no filter is applied.
        **kwargs
            Ignored. This absorbs extra keywords forwarded by
            :py:meth:`__call__`.

        Returns
        -------
        NDArray
            The reconstructed volume, or ``data`` itself if it already has
            the requested ``shape``.
        """
        # Normalise to a tuple: a list never compares equal to ``data.shape``.
        # Without this, an already-reconstructed volume passed with a
        # list-valued shape would be re-processed as a tilt series.
        if shape is not None:
            shape = tuple(shape)

        if tuple(data.shape) == shape:
            return data

        data = be.to_backend_array(data)
        volume_temp = be.zeros(shape, dtype=be._float_dtype)
        volume_temp_rotated = be.zeros(shape, dtype=be._float_dtype)
        volume = be.zeros(shape, dtype=be._float_dtype)

        # Index the central plane along ``opening_axis`` (full extent elsewhere).
        slices = tuple(slice(a // 2, (a // 2) + 1) for a in shape)
        subset = tuple(
            slice(None) if i != opening_axis else x for i, x in enumerate(slices)
        )
        angles_loop = be.zeros(len(shape))

        # Insert a singleton axis at ``opening_axis`` (offset by the leading
        # tilt axis) so each tilt fits into ``subset``.
        wedge_dim = list(data.shape)
        wedge_dim.insert(1 + opening_axis, 1)
        wedges = be.reshape(data, wedge_dim)

        rec_filter = 1
        # Rescale tilt angles by the opening/tilt axis aspect ratio. For
        # equal extents this leaves angles within (-90, 90) unchanged.
        aspect_ratio = shape[opening_axis] / shape[tilt_axis]
        angles = np.degrees(np.arctan(np.tan(np.radians(angles)) * aspect_ratio))
        if reconstruction_filter is not None:
            rec_filter = create_reconstruction_filter(
                filter_type=reconstruction_filter,
                filter_shape=(shape[tilt_axis],),
                tilt_angles=angles,
            )
            # Reshape the 1D filter so it broadcasts along ``tilt_axis`` only.
            rec_shape = tuple(1 if i != tilt_axis else x for i, x in enumerate(shape))
            rec_filter = be.to_backend_array(rec_filter)
            rec_filter = be.reshape(rec_filter, rec_shape)

        angles = be.to_backend_array(angles)
        for index in range(len(angles)):
            angles_loop = be.fill(angles_loop, 0)
            volume_temp = be.fill(volume_temp, 0)
            volume_temp_rotated = be.fill(volume_temp_rotated, 0)

            # Jax compatibility: be.at returns the updated array rather than
            # assigning in place.
            volume_temp = be.at(volume_temp, subset, wedges[index] * rec_filter)
            angles_loop = be.at(angles_loop, tilt_axis, angles[index])

            # Move the tilt angle into the Euler-angle slot used by
            # euler_to_rotationmatrix.
            # NOTE(review): the slot mapping depends on that helper's
            # convention — confirm.
            angles_loop = be.roll(angles_loop, (opening_axis - 1,), axis=0)
            rotation_matrix = euler_to_rotationmatrix(be.to_numpy_array(angles_loop))
            rotation_matrix = be.to_backend_array(rotation_matrix)

            volume_temp_rotated, _ = be.rigid_transform(
                arr=volume_temp,
                rotation_matrix=rotation_matrix,
                out=volume_temp_rotated,
                use_geometric_center=True,
                order=interpolation_order,
            )
            volume = be.add(volume, volume_temp_rotated, out=volume)

        volume = shift_fourier(data=volume, shape_is_real_fourier=False)

        if return_real_fourier:
            volume = crop_real_fourier(volume)

        return volume