# Source code for tme.filters.ctf

""" Implements class CTF to create Fourier filter representations.

    Copyright (c) 2024 European Molecular Biology Laboratory

    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import re
import warnings
from typing import Tuple, Dict
from dataclasses import dataclass

import numpy as np

from ..types import NDArray
from ..parser import StarParser
from ..backends import backend as be
from .compose import ComposableFilter
from ._utils import (
    frequency_grid_at_angle,
    compute_tilt_shape,
    crop_real_fourier,
    fftfreqn,
    shift_fourier,
)

__all__ = ["CTF"]


@dataclass
class CTF(ComposableFilter):
    """
    Generate a contrast transfer function mask.

    Instances are callable: calling one evaluates :py:meth:`CTF.weight` with the
    instance attributes (optionally overridden by keyword arguments) and returns
    a dictionary describing the resulting multiplicative filter.

    References
    ----------
    .. [1]  CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
            Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
    """

    #: The shape of the to-be reconstructed volume.
    shape: Tuple[int] = None
    #: The defocus value in x direction.
    defocus_x: float = None
    #: The tilt angles.
    angles: Tuple[float] = None
    #: The axis around which the wedge is opened, defaults to None.
    opening_axis: int = None
    #: The axis along which the tilt is applied, defaults to None.
    tilt_axis: int = None
    #: Whether to correct defocus gradient, defaults to False.
    correct_defocus_gradient: bool = False
    #: The sampling rate, defaults to 1 Angstrom / Voxel.
    sampling_rate: Tuple[float] = 1
    #: The acceleration voltage in Volts, defaults to 300e3.
    acceleration_voltage: float = 300e3
    #: The spherical aberration coefficient, defaults to 2.7e7.
    spherical_aberration: float = 2.7e7
    #: The amplitude contrast, defaults to 0.07.
    amplitude_contrast: float = 0.07
    #: The phase shift, defaults to 0.
    phase_shift: float = 0
    #: The defocus angle, defaults to 0. Passed through np.radians on init.
    defocus_angle: float = 0
    #: The defocus value in y direction, defaults to None.
    defocus_y: float = None
    #: Whether the returned CTF should be phase-flipped.
    flip_phase: bool = True
    #: Whether to return a format compliant with rfft. Only relevant for single angles.
    return_real_fourier: bool = False
    #: Whether the output should not be used for n+1 dimensional reconstruction
    no_reconstruction: bool = True

    @classmethod
    def from_file(cls, filename: str) -> "CTF":
        """
        Initialize :py:class:`CTF` from file.

        Parameters
        ----------
        filename : str
            The path to a file with ctf parameters. Supports the following formats:
            - CTFFIND4
            - STAR files (any filename ending in "star", case-insensitive)

        Returns
        -------
        :py:class:`CTF`
            Instance with ``shape`` and ``angles`` left unset (None).
        """
        if filename.lower().endswith("star"):
            data = cls._from_gctf(filename=filename)
        else:
            data = cls._from_ctffind(filename=filename)

        # The CTFFIND4 header is matched as "... keV" / "... mm", whereas this
        # class documents Volts (default 300e3) and uses 2.7e7 as default
        # spherical aberration, i.e. Angstrom. Convert so the parsed values are
        # on the same scale as the defaults. np.multiply also accepts the plain
        # lists produced by the STAR parser.
        # NOTE(review): assumes _rlnVoltage / _rlnSphericalAberration are kV / mm
        # as well -- confirm against the STAR files in use.
        acceleration_voltage = np.multiply(data["acceleration_voltage"], 1e3)
        spherical_aberration = np.multiply(data["spherical_aberration"], 1e7)

        return cls(
            shape=None,
            angles=None,
            defocus_x=data["defocus_1"],
            sampling_rate=data["pixel_size"],
            acceleration_voltage=acceleration_voltage,
            spherical_aberration=spherical_aberration,
            amplitude_contrast=data["amplitude_contrast"],
            phase_shift=data["additional_phase_shift"],
            # NOTE(review): __post_init__ applies np.radians, so np.degrees here
            # means the file value ends up being used unchanged as radians.
            # Confirm the unit the input files actually use for the azimuth.
            defocus_angle=np.degrees(data["azimuth_astigmatism"]),
            defocus_y=data["defocus_2"],
        )

    @staticmethod
    def _from_ctffind(filename: str):
        """
        Parse a CTFFIND4 text output file.

        Lines starting with ``#`` are scanned for global parameters (pixel size,
        acceleration voltage, spherical aberration, amplitude contrast). Every
        other non-empty line is split on whitespace and read as one row of
        per-micrograph values.

        Returns
        -------
        dict
            Column names mapped to numpy arrays, plus one float per global
            parameter that was found in the header.
        """
        parameter_regex = {
            "pixel_size": r"Pixel size: ([0-9.]+) Angstroms",
            "acceleration_voltage": r"acceleration voltage: ([0-9.]+) keV",
            "spherical_aberration": r"spherical aberration: ([0-9.]+) mm",
            "amplitude_contrast": r"amplitude contrast: ([0-9.]+)",
        }
        # Column name -> position within a whitespace separated data row
        columns = {
            "micrograph_number": 0,
            "defocus_1": 1,
            "defocus_2": 2,
            "azimuth_astigmatism": 3,
            "additional_phase_shift": 4,
            "cross_correlation": 5,
            "spacing": 6,
        }

        with open(filename, mode="r", encoding="utf-8") as infile:
            lines = [x.strip() for x in infile.read().split("\n")]
        lines = [x for x in lines if len(x)]

        output = {k: [] for k in columns}
        for line in lines:
            if line.startswith("#"):
                # Header line, pick up any global parameter it mentions
                for parameter, regex_pattern in parameter_regex.items():
                    match = re.search(regex_pattern, line)
                    if match:
                        output[parameter] = float(match.group(1))
                continue

            values = line.split()
            for key, position in columns.items():
                output[key].append(float(values[position]))

        for key in columns:
            output[key] = np.array(output[key])
        return output

    @staticmethod
    def _from_gctf(filename: str):
        """
        Parse CTF parameters from the ``data_`` block of a STAR file.

        A warning is issued for every expected column that is missing, and the
        corresponding output falls back to ``[0]``. Outputs holding a single
        value are repeated to the length of the longest output.

        Returns
        -------
        dict
            Parameter names mapped to lists of floats.
        """
        parser = StarParser(filename)
        ctf_data = parser["data_"]

        # Output name -> (STAR column, converter). None marks "no such column".
        mapping = {
            "defocus_1": ("_rlnDefocusU", float),
            "defocus_2": ("_rlnDefocusV", float),
            "pixel_size": ("_rlnDetectorPixelSize", float),
            "acceleration_voltage": ("_rlnVoltage", float),
            "spherical_aberration": ("_rlnSphericalAberration", float),
            "amplitude_contrast": ("_rlnAmplitudeContrast", float),
            "additional_phase_shift": (None, float),
            "azimuth_astigmatism": ("_rlnDefocusAngle", float),
        }
        output = {}
        for out_key, (key, key_dtype) in mapping.items():
            if key is not None and key not in ctf_data:
                warnings.warn(f"ctf_data is missing key {key}.")

            key_value = ctf_data.get(key, [0])
            output[out_key] = [key_dtype(x) for x in key_value]

        longest_key = max(map(len, output.values()))
        output = {k: v * longest_key if len(v) == 1 else v for k, v in output.items()}
        return output

    def __post_init__(self):
        """Convert :py:attr:`defocus_angle` with ``np.radians`` after init."""
        self.defocus_angle = np.radians(self.defocus_angle)

    def _compute_electron_wavelength(self, acceleration_voltage: int = None):
        """
        Computes the wavelength of an electron in angstrom.

        Parameters
        ----------
        acceleration_voltage : float or sequence of float, optional
            Acceleration voltage in Volts. Defaults to the instance attribute
            :py:attr:`acceleration_voltage` when None.

        Returns
        -------
        float or NDArray
            Relativistic electron wavelength, one value per input voltage.
        """
        if acceleration_voltage is None:
            acceleration_voltage = self.acceleration_voltage

        # Physical constants expressed in SI units
        planck_constant = 6.62606896e-34
        electron_charge = 1.60217646e-19
        electron_mass = 9.10938215e-31
        light_velocity = 299792458

        # asarray so list input (e.g. parsed from a STAR file) is supported too
        energy = electron_charge * np.asarray(acceleration_voltage, dtype=float)
        denominator = energy**2 + 2 * energy * electron_mass * light_velocity**2
        electron_wavelength = np.divide(
            planck_constant * light_velocity, np.sqrt(denominator)
        )
        # Convert to Ångstrom
        return electron_wavelength * 1e10

    @staticmethod
    def _num_elements(value) -> int:
        """Return len(value), or np.size(value) for unsized input (scalar, None)."""
        try:
            return len(value)
        except TypeError:
            return int(np.size(value))

    def __call__(self, **kwargs) -> Dict:
        """
        Evaluate the CTF using instance attributes updated by ``kwargs``.

        Parameters
        ----------
        **kwargs : Dict
            Overrides for any argument of :py:meth:`CTF.weight`.

        Returns
        -------
        Dict
            Dictionary with keys ``data`` (the filter as backend array in the
            backend float dtype), ``angles``, ``tilt_axis``, ``opening_axis``
            and ``is_multiplicative_filter`` (always True).
        """
        func_args = vars(self).copy()
        func_args.update(kwargs)

        # If the number of angles and defocus values disagree, fall back to the
        # angles stored on the instance and disable tilt-specific handling.
        # _num_elements tolerates scalar / None values, where len() would raise.
        n_angles = self._num_elements(func_args["angles"])
        n_defoci = self._num_elements(func_args["defocus_x"])
        if n_angles != n_defoci:
            func_args["angles"] = self.angles
            func_args["return_real_fourier"] = False
            func_args["tilt_axis"] = None
            func_args["opening_axis"] = None

        ret = self.weight(**func_args)
        ret = be.astype(be.to_backend_array(ret), be._float_dtype)
        return {
            "data": ret,
            "angles": func_args["angles"],
            "tilt_axis": func_args["tilt_axis"],
            "opening_axis": func_args["opening_axis"],
            "is_multiplicative_filter": True,
        }

    @staticmethod
    def _pad_to_length(arr, length: int):
        """
        Repeat the elements of ``arr`` so that it holds ``length`` values.

        Each element is repeated ``length // size`` times, hence the result only
        has exactly ``length`` elements if ``length`` is a multiple of the input
        size. None is turned into an array of None.
        """
        ret = np.atleast_1d(arr)
        return np.repeat(ret, length // ret.size)

    @staticmethod
    def _defocus_grid(
        shape, tilt_shape, sampling_rate, opening_axis, no_reconstruction
    ):
        """Build the sparse frequency grid used for gradient and astigmatism."""
        grid_shape = shape
        sampling = be.divide(sampling_rate, be.to_backend_array(shape))
        sampling = tuple(float(x) for x in sampling)
        if not no_reconstruction:
            # Drop the opening axis to match the reduced tilt shape
            grid_shape = tilt_shape
            sampling = tuple(x for i, x in enumerate(sampling) if i != opening_axis)

        return fftfreqn(
            shape=grid_shape,
            sampling_rate=sampling,
            return_sparse_grid=True,
        )

    def weight(
        self,
        shape: Tuple[int],
        defocus_x: Tuple[float],
        angles: Tuple[float],
        opening_axis: int = None,
        tilt_axis: int = None,
        amplitude_contrast: float = 0.07,
        phase_shift: Tuple[float] = 0,
        defocus_angle: Tuple[float] = 0,
        defocus_y: Tuple[float] = None,
        correct_defocus_gradient: bool = False,
        sampling_rate: Tuple[float] = 1,
        acceleration_voltage: float = 300e3,
        spherical_aberration: float = 2.7e3,
        flip_phase: bool = True,
        return_real_fourier: bool = False,
        no_reconstruction: bool = True,
        cutoff_frequency: float = 0.5,
        **kwargs: Dict,
    ) -> NDArray:
        """
        Compute the CTF weight tilt stack.

        Parameters
        ----------
        shape : tuple of int
            The shape of the CTF.
        defocus_x : tuple of float
            The defocus value in x direction.
        angles : tuple of float
            The tilt angles.
        opening_axis : int, optional
            The axis around which the wedge is opened, defaults to None.
        tilt_axis : int, optional
            The axis along which the tilt is applied, defaults to None.
        amplitude_contrast : float, optional
            The amplitude contrast, defaults to 0.07.
        phase_shift : tuple of float, optional
            The phase shift, defaults to 0.
        defocus_angle : tuple of float, optional
            The defocus angle, defaults to 0. Used as-is inside a cosine, i.e.
            no unit conversion is performed here.
        defocus_y : tuple of float, optional
            The defocus value in y direction, defaults to None.
        correct_defocus_gradient : bool, optional
            Whether to correct defocus gradient, defaults to False. Only takes
            effect for three-dimensional ``shape`` with both ``tilt_axis`` and
            ``opening_axis`` specified.
        sampling_rate : tuple of float, optional
            The sampling rate, defaults to 1. The maximum value is used.
        acceleration_voltage : float, optional
            The acceleration voltage in electron microscopy, defaults to 300e3.
            None falls back to the instance attribute.
        spherical_aberration : float, optional
            The spherical aberration coefficient, defaults to 2.7e3.
        flip_phase : bool, optional
            Whether the returned CTF should be phase-flipped.
        return_real_fourier : bool, optional
            Whether to crop the output with ``crop_real_fourier``. Only applied
            when ``no_reconstruction`` is True. Defaults to False.
        no_reconstruction : bool, optional
            If True (default), each slice is passed through ``shift_fourier``.
            If False, defocus grids are built on the reduced tilt shape.
        cutoff_frequency : float, optional
            Frequencies at or above this value are set to zero, defaults to 0.5.
        **kwargs : Dict
            Additional keyword arguments.

        Returns
        -------
        NDArray
            A stack containing the CTF weight.
        """
        # NOTE(review): the spherical_aberration default above (2.7e3) differs
        # from the dataclass default (2.7e7). It is kept to not alter the
        # signature; __call__ always passes the instance value explicitly.
        angles = np.atleast_1d(angles)
        n_angles = angles.size

        # Broadcast per-tilt parameters to one value per angle
        defoci_x = self._pad_to_length(defocus_x, n_angles)
        defoci_y = self._pad_to_length(defocus_y, n_angles)
        phase_shift = self._pad_to_length(phase_shift, n_angles)
        defocus_angle = self._pad_to_length(defocus_angle, n_angles)
        spherical_aberration = self._pad_to_length(spherical_aberration, n_angles)
        amplitude_contrast = self._pad_to_length(amplitude_contrast, n_angles)

        sampling_rate = np.max(sampling_rate)
        tilt_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        stack = np.zeros((len(angles), *tilt_shape))

        correct_defocus_gradient &= len(shape) == 3
        correct_defocus_gradient &= tilt_axis is not None
        correct_defocus_gradient &= opening_axis is not None

        # Express all lengths in units of the sampling rate. Out-of-place
        # division avoids a casting error for integer-valued input.
        spherical_aberration = spherical_aberration / sampling_rate
        electron_wavelength = self._pad_to_length(
            self._compute_electron_wavelength(acceleration_voltage) / sampling_rate,
            n_angles,
        )
        electron_aberration = spherical_aberration * electron_wavelength**2

        # Identical for all tilts, hence only built once and on demand
        grid = None
        for index, angle in enumerate(angles):
            dx, dy = defoci_x[index], defoci_y[index]
            dx = dx / sampling_rate if dx is not None else None
            dy = dy / sampling_rate if dy is not None else None

            if (correct_defocus_gradient or dy is not None) and grid is None:
                grid = self._defocus_grid(
                    shape=shape,
                    tilt_shape=tilt_shape,
                    sampling_rate=sampling_rate,
                    opening_axis=opening_axis,
                    no_reconstruction=no_reconstruction,
                )

            # This should be done after defocus_x computation
            if correct_defocus_gradient:
                angle_rad = np.radians(angle)
                defocus_gradient = np.multiply(grid[1], np.sin(angle_rad))

                remaining_axis = tuple(
                    i for i in range(len(shape)) if i not in (opening_axis, tilt_axis)
                )[0]

                if tilt_axis > remaining_axis:
                    dx = np.add(dx, defocus_gradient)
                elif tilt_axis < remaining_axis and dy is not None:
                    dy = np.add(dy, defocus_gradient.T)

            # 0.5 * (dx + dy) + cos(2 * (azimuth - astigmatism) * (dx - dy))
            if dy is not None:
                defocus_sum = np.add(dx, dy)
                defocus_difference = np.subtract(dx, dy)

                angular_grid = np.arctan2(grid[1], grid[0])
                defocus_difference = np.multiply(
                    defocus_difference,
                    np.cos(2 * (angular_grid - defocus_angle[index])),
                )
                dx = 0.5 * np.add(defocus_sum, defocus_difference)

            frequency_grid = frequency_grid_at_angle(
                shape=shape,
                opening_axis=opening_axis,
                tilt_axis=tilt_axis,
                angle=angle,
                sampling_rate=1,
            )
            # Needs to be computed before the grid is squared in-place below
            frequency_mask = frequency_grid < cutoff_frequency

            # k^2*π*λ(dx - 0.5 * sph_abb * λ^2 * k^2) + phase_shift + ampl_contrast_term)
            np.square(frequency_grid, out=frequency_grid)
            chi = dx - 0.5 * electron_aberration[index] * frequency_grid
            np.multiply(chi, np.pi * electron_wavelength[index], out=chi)
            np.multiply(chi, frequency_grid, out=chi)

            chi += phase_shift[index]
            chi += np.arctan(
                np.divide(
                    amplitude_contrast[index],
                    np.sqrt(1 - np.square(amplitude_contrast[index])),
                )
            )
            np.sin(-chi, out=chi)
            np.multiply(chi, frequency_mask, out=chi)

            if no_reconstruction:
                chi = shift_fourier(data=chi, shape_is_real_fourier=False)

            stack[index] = chi

        # Avoid contrast inversion
        np.negative(stack, out=stack)
        if flip_phase:
            np.abs(stack, out=stack)

        stack = be.to_backend_array(np.squeeze(stack))
        if no_reconstruction and return_real_fourier:
            stack = crop_real_fourier(stack)
        return stack