""" Implements class Wedge and WedgeReconstructed to create Fourier
filter representations.
Copyright (c) 2024 European Molecular Biology Laboratory
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""
from typing import Tuple, Dict
import numpy as np
from ..types import NDArray
from ..backends import backend as be
from .compose import ComposableFilter
from ..matching_utils import centered
from ..rotations import euler_to_rotationmatrix
from ._utils import (
centered_grid,
frequency_grid_at_angle,
compute_tilt_shape,
crop_real_fourier,
fftfreqn,
shift_fourier,
create_reconstruction_filter,
)
__all__ = ["Wedge", "WedgeReconstructed"]
[docs]
class Wedge(ComposableFilter):
"""
Generate wedge mask for tomographic data.
Parameters
----------
shape : tuple of int
The shape of the reconstruction volume.
tilt_axis : int
Axis the plane is tilted over.
opening_axis : int
The axis around which the volume is opened.
angles : tuple of float
The tilt angles.
weights : tuple of float
The weights corresponding to each tilt angle.
weight_type : str, optional
The type of weighting to apply, defaults to None.
frequency_cutoff : float, optional
Frequency cutoff for created mask. Nyquist 0.5 by default.
Returns
-------
Dict
A dictionary containing weighted wedges and related information.
"""
def __init__(
self,
shape: Tuple[int],
tilt_axis: int,
opening_axis: int,
angles: Tuple[float],
weights: Tuple[float],
weight_type: str = None,
frequency_cutoff: float = 0.5,
):
self.shape = shape
self.tilt_axis = tilt_axis
self.opening_axis = opening_axis
self.angles = angles
self.weights = weights
self.frequency_cutoff = frequency_cutoff
[docs]
@classmethod
def from_file(cls, filename: str) -> "Wedge":
"""
Generate a :py:class:`Wedge` instance by reading tilt angles and weights
from a tab-separated text file.
Parameters
----------
filename : str
The path to the file containing tilt angles and weights.
Returns
-------
:py:class:`Wedge`
Class instance instance initialized with angles and weights from the file.
"""
data = cls._from_text(filename)
angles, weights = data.get("angles", None), data.get("weights", None)
if angles is None:
raise ValueError(f"Could not find colum angles in {filename}")
if weights is None:
weights = [1] * len(angles)
if len(weights) != len(angles):
raise ValueError("Length of weights and angles differ.")
return cls(
shape=None,
tilt_axis=0,
opening_axis=2,
angles=np.array(angles, dtype=np.float32),
weights=np.array(weights, dtype=np.float32),
)
@staticmethod
def _from_text(filename: str, delimiter="\t") -> Dict:
"""
Read column data from a text file.
Parameters
----------
filename : str
The path to the text file.
delimiter : str, optional
The delimiter used in the file, defaults to '\t'.
Returns
-------
Dict
A dictionary with one key for each column.
"""
with open(filename, mode="r", encoding="utf-8") as infile:
data = [x.strip() for x in infile.read().split("\n")]
data = [x.split("\t") for x in data if len(x)]
headers = data.pop(0)
ret = {header: list(column) for header, column in zip(headers, zip(*data))}
return ret
[docs]
def __call__(self, **kwargs: Dict) -> NDArray:
func_args = vars(self).copy()
func_args.update(kwargs)
weight_types = {
None: self.weight_angle,
"angle": self.weight_angle,
"relion": self.weight_relion,
"grigorieff": self.weight_grigorieff,
}
weight_type = func_args.get("weight_type", None)
if weight_type not in weight_types:
raise ValueError(
f"Supported weight_types are {','.join(list(weight_types.keys()))}"
)
if weight_type == "angle":
func_args["weights"] = np.cos(np.radians(self.angles))
ret = weight_types[weight_type](**func_args)
frequency_cutoff = func_args.get("frequency_cutoff", None)
if frequency_cutoff is not None:
for index, angle in enumerate(func_args["angles"]):
frequency_grid = frequency_grid_at_angle(
shape=func_args["shape"],
opening_axis=self.opening_axis,
tilt_axis=self.tilt_axis,
angle=angle,
sampling_rate=1,
)
ret[index] = np.multiply(ret[index], frequency_grid <= frequency_cutoff)
ret = be.astype(be.to_backend_array(ret), be._float_dtype)
return {
"data": ret,
"angles": func_args["angles"],
"tilt_axis": func_args["tilt_axis"],
"opening_axis": func_args["opening_axis"],
"sampling_rate": func_args.get("sampling_rate", 1),
"is_multiplicative_filter": True,
}
[docs]
@staticmethod
def weight_angle(
shape: Tuple[int],
weights: Tuple[float],
angles: Tuple[float],
opening_axis: int,
tilt_axis: int,
**kwargs,
) -> NDArray:
"""
Generate weighted wedges based on the cosine of the current angle.
"""
tilt_shape = compute_tilt_shape(
shape=shape, opening_axis=opening_axis, reduce_dim=True
)
wedge, wedges = np.ones(tilt_shape), np.zeros((len(angles), *tilt_shape))
for index, angle in enumerate(angles):
wedge.fill(weights[index])
wedges[index] = wedge
return wedges
[docs]
def weight_relion(
self, shape: Tuple[int], opening_axis: int, tilt_axis: int, **kwargs
) -> NDArray:
"""
Generate weighted wedges based on the RELION 1.4 formalism, weighting each
angle using the cosine of the angle and a Gaussian lowpass filter computed
with respect to the exposure per angstrom.
Returns
-------
NDArray
Weighted wedges.
"""
tilt_shape = compute_tilt_shape(
shape=shape, opening_axis=opening_axis, reduce_dim=True
)
wedges = np.zeros((len(self.angles), *tilt_shape))
for index, angle in enumerate(self.angles):
frequency_grid = frequency_grid_at_angle(
shape=shape,
opening_axis=opening_axis,
tilt_axis=tilt_axis,
angle=angle,
sampling_rate=1,
)
sigma = np.sqrt(self.weights[index] * 4 / (8 * np.pi**2))
sigma = -2 * np.pi**2 * sigma**2
np.square(frequency_grid, out=frequency_grid)
np.multiply(sigma, frequency_grid, out=frequency_grid)
np.exp(frequency_grid, out=frequency_grid)
np.multiply(frequency_grid, np.cos(np.radians(angle)), out=frequency_grid)
wedges[index] = frequency_grid
return wedges
[docs]
def weight_grigorieff(
self,
shape: Tuple[int],
opening_axis: int,
tilt_axis: int,
amplitude: float = 0.245,
power: float = -1.665,
offset: float = 2.81,
**kwargs,
) -> NDArray:
"""
Generate weighted wedges based on the formalism introduced in [1]_.
Returns
-------
NDArray
Weighted wedges.
References
----------
.. [1] Timothy Grant, Nikolaus Grigorieff (2015), eLife 4:e06980.
"""
tilt_shape = compute_tilt_shape(
shape=shape, opening_axis=opening_axis, reduce_dim=True
)
wedges = np.zeros((len(self.angles), *tilt_shape), dtype=be._float_dtype)
for index, angle in enumerate(self.angles):
frequency_grid = frequency_grid_at_angle(
shape=shape,
opening_axis=opening_axis,
tilt_axis=tilt_axis,
angle=angle,
sampling_rate=1,
)
with np.errstate(divide="ignore"):
np.power(frequency_grid, power, out=frequency_grid)
np.multiply(amplitude, frequency_grid, out=frequency_grid)
np.add(frequency_grid, offset, out=frequency_grid)
np.multiply(-2, frequency_grid, out=frequency_grid)
np.divide(
self.weights[index],
frequency_grid,
out=frequency_grid,
)
wedges[index] = np.exp(frequency_grid)
return wedges
[docs]
class WedgeReconstructed:
"""
Initialize :py:class:`WedgeReconstructed`.
Parameters
----------
angles :tuple of float, optional
The tilt angles, defaults to None.
opening_axis : int, optional
The axis around which the wedge is opened.
tilt_axis : int, optional
The axis along which the tilt is applied.
weights : tuple of float, optional
Weights to assign to individual wedge components.
weight_wedge : bool, optional
Whether individual wedge components should be weighted. If True and weights
is None, uses the cosine of the angle otherwise weights.
create_continuous_wedge: bool, optional
Whether to create a continous wedge or a per-component wedge. Weights are only
considered for non-continuous wedges.
frequency_cutoff : float, optional
Filter window applied during reconstruction.
**kwargs : Dict
Additional keyword arguments.
"""
def __init__(
self,
opening_axis: int,
tilt_axis: int,
angles: Tuple[float] = None,
weights: Tuple[float] = None,
weight_wedge: bool = False,
create_continuous_wedge: bool = False,
frequency_cutoff: float = 0.5,
reconstruction_filter: str = None,
**kwargs: Dict,
):
self.angles = angles
self.opening_axis = opening_axis
self.tilt_axis = tilt_axis
self.weights = weights
self.weight_wedge = weight_wedge
self.reconstruction_filter = reconstruction_filter
self.create_continuous_wedge = create_continuous_wedge
self.frequency_cutoff = frequency_cutoff
[docs]
def __call__(self, shape: Tuple[int], **kwargs: Dict) -> Dict:
"""
Generate the reconstructed wedge.
Parameters
----------
shape : tuple of int
The shape of the reconstruction volume.
**kwargs : Dict
Additional keyword arguments.
Returns
-------
Dict
A dictionary containing the reconstructed wedge and related information.
"""
func_args = vars(self).copy()
func_args.update(kwargs)
if kwargs.get("is_fourier_shape", False):
print("Cannot create continuous wedge mask based on real fourier shape.")
func = self.step_wedge
if func_args.get("create_continuous_wedge", False):
func = self.continuous_wedge
weight_wedge = func_args.get("weight_wedge", False)
if func_args.get("wedge_weights") is None and weight_wedge:
func_args["weights"] = np.cos(
np.radians(be.to_numpy_array(func_args.get("angles", (0,))))
)
ret = func(shape=shape, **func_args)
frequency_cutoff = func_args.get("frequency_cutoff", None)
if frequency_cutoff is not None:
frequency_mask = fftfreqn(
shape=shape,
sampling_rate=1,
compute_euclidean_norm=True,
shape_is_real_fourier=False,
)
ret = np.multiply(ret, frequency_mask <= frequency_cutoff, out=ret)
if not weight_wedge:
ret = (ret > 0) * 1.0
ret = be.astype(be.to_backend_array(ret), be._float_dtype)
ret = shift_fourier(data=ret, shape_is_real_fourier=False)
if func_args.get("return_real_fourier", False):
ret = crop_real_fourier(ret)
return {
"data": ret,
"shape_is_real_fourier": func_args["return_real_fourier"],
"shape": ret.shape,
"tilt_axis": func_args["tilt_axis"],
"opening_axis": func_args["opening_axis"],
"is_multiplicative_filter": True,
"angles": func_args["angles"],
}
[docs]
@staticmethod
def continuous_wedge(
shape: Tuple[int],
angles: Tuple[float],
opening_axis: int,
tilt_axis: int,
**kwargs: Dict,
) -> NDArray:
"""
Generate a continous wedge mask with DC component at the center.
Parameters
----------
shape : tuple of int
The shape of the reconstruction volume.
angles : tuple of float
Start and stop tilt angle.
opening_axis : int
The axis around which the wedge is opened.
tilt_axis : int
The axis along which the tilt is applied.
Returns
-------
NDArray
Wedge mask.
"""
aspect_ratio = shape[opening_axis] / shape[tilt_axis]
angles = np.degrees(np.arctan(np.tan(np.radians(angles)) * aspect_ratio))
start_radians = np.tan(np.radians(90 - angles[0]))
stop_radians = np.tan(np.radians(-1 * (90 - angles[1])))
grid = centered_grid(shape)
with np.errstate(divide="ignore", invalid="ignore"):
ratios = np.where(
grid[opening_axis] == 0,
np.tan(np.radians(90)) + 1,
grid[tilt_axis] / grid[opening_axis],
)
wedge = np.logical_or(start_radians <= ratios, stop_radians >= ratios).astype(
np.float32
)
return wedge
[docs]
@staticmethod
def step_wedge(
shape: Tuple[int],
angles: Tuple[float],
opening_axis: int,
tilt_axis: int,
weights: Tuple[float] = None,
reconstruction_filter: str = None,
**kwargs: Dict,
) -> NDArray:
"""
Generate a per-angle wedge shape with DC component at the center.
Parameters
----------
shape : tuple of int
The shape of the reconstruction volume.
angles : tuple of float
The tilt angles.
opening_axis : int
The axis around which the wedge is opened.
tilt_axis : int
The axis along which the tilt is applied.
reconstruction_filter : str
Filter used during reconstruction.
weights : tuple of float, optional
Weights to assign to individual tilts. Defaults to 1.
Returns
-------
NDArray
Wege mask.
"""
from ..backends import NumpyFFTWBackend
angles = np.asarray(be.to_numpy_array(angles))
if weights is None:
weights = np.ones(angles.size)
weights = np.asarray(weights)
shape = tuple(int(x) for x in shape)
opening_axis, tilt_axis = int(opening_axis), int(tilt_axis)
weights = np.repeat(weights, angles.size // weights.size)
plane = np.zeros(
(shape[opening_axis], shape[tilt_axis] + (1 - shape[tilt_axis] % 2)),
dtype=np.float32,
)
aspect_ratio = plane.shape[0] / plane.shape[1]
angles = np.degrees(np.arctan(np.tan(np.radians(angles)) * aspect_ratio))
rec_filter = 1
if reconstruction_filter is not None:
rec_filter = create_reconstruction_filter(
plane.shape[::-1], filter_type=reconstruction_filter, tilt_angles=angles
).T
subset = tuple(
slice(None) if i != 0 else slice(x // 2, x // 2 + 1)
for i, x in enumerate(plane.shape)
)
plane_rotated, wedge_volume = np.zeros_like(plane), np.zeros_like(plane)
for index in range(angles.shape[0]):
plane_rotated.fill(0)
plane[subset] = 1
rotation_matrix = euler_to_rotationmatrix((angles[index], 0))
rotation_matrix = rotation_matrix[np.ix_((0, 1), (0, 1))]
NumpyFFTWBackend().rigid_transform(
arr=plane * rec_filter,
rotation_matrix=rotation_matrix,
out=plane_rotated,
use_geometric_center=True,
order=1,
)
wedge_volume += plane_rotated * weights[index]
wedge_volume = centered(wedge_volume, (shape[opening_axis], shape[tilt_axis]))
np.fmin(wedge_volume, np.max(weights), wedge_volume)
if opening_axis > tilt_axis:
wedge_volume = np.moveaxis(wedge_volume, 1, 0)
reshape_dimensions = tuple(
x if i in (opening_axis, tilt_axis) else 1 for i, x in enumerate(shape)
)
wedge_volume = wedge_volume.reshape(reshape_dimensions)
tile_dimensions = np.divide(shape, reshape_dimensions).astype(int)
wedge_volume = np.tile(wedge_volume, tile_dimensions)
return wedge_volume