# Source code for tme.filters.ctf

"""
Implements classes CTF and CTFReconstructed.

Copyright (c) 2024 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import re
from typing import Tuple, Dict
from dataclasses import dataclass

import numpy as np

from ..types import NDArray
from ..backends import backend as be
from .compose import ComposableFilter
from ..parser import StarParser, XMLParser, MDOCParser
from ._utils import (
    frequency_grid_at_angle,
    compute_tilt_shape,
    fftfreqn,
    pad_to_length,
)

# Public names exported by ``from tme.filters.ctf import *``.
__all__ = [
    "CTF",
    "CTFReconstructed",
    "create_ctf",
]


@dataclass
class CTF(ComposableFilter):
    """
    Generate per-tilt contrast transfer function filter.
    """

    #: The defocus in x direction (in units of sampling rate).
    defocus_x: Tuple[float] = None
    #: The tilt angles in degrees.
    angles: Tuple[float] = None
    #: The microscope projection axis, defaults to 2 (z).
    opening_axis: int = 2
    #: The axis along which the tilt is applied, defaults to 0 (x).
    tilt_axis: int = 0
    #: The sampling rate, defaults to 1 Ångstrom / voxel.
    sampling_rate: Tuple[float] = 1
    #: The acceleration voltage in Volts, defaults to 300e3.
    acceleration_voltage: Tuple[float] = 300e3
    #: The spherical aberration, defaults to 2.7e7 (in units of sampling rate).
    spherical_aberration: Tuple[float] = 2.7e7
    #: The amplitude contrast, defaults to 0.07.
    amplitude_contrast: Tuple[float] = 0.07
    #: The phase shift in radians, defaults to 0.
    phase_shift: Tuple[float] = 0
    #: The defocus angle in radians, defaults to 0.
    defocus_angle: Tuple[float] = 0
    #: The defocus value in y direction, defaults to None (in units of sampling rate).
    defocus_y: Tuple[float] = None
    #: Whether the returned CTF should be phase-flipped, defaults to True.
    flip_phase: bool = True

    @classmethod
    def from_file(cls, filename: str, **kwargs) -> "CTF":
        """
        Initialize :py:class:`CTF` from file.

        Parameters
        ----------
        filename : str
            The path to a file with ctf parameters. Supports extensions are:

            +-------+---------------------------------------------------------+
            | .star | GCTF file                                               |
            +-------+---------------------------------------------------------+
            | .xml  | WARP/M XML file                                         |
            +-------+---------------------------------------------------------+
            | .mdoc | SerialEM file                                           |
            +-------+---------------------------------------------------------+
            | .*    | CTFFIND4 file                                           |
            +-------+---------------------------------------------------------+
        **kwargs : optional
            Overwrite fields that cannot be extracted from input file. A
            keyword is only used when the file did not provide that value,
            and only the keys ``angles``, ``defocus_x``, ``defocus_y``,
            ``sampling_rate``, ``acceleration_voltage``,
            ``spherical_aberration``, ``amplitude_contrast``, ``phase_shift``
            and ``defocus_angle`` are considered; other keys are ignored.
            ``acceleration_voltage`` is used as given (Volts), whereas
            ``phase_shift`` and ``defocus_angle`` are converted from degrees
            to radians exactly like the values read from file.

        Returns
        -------
        CTF
            Filter initialized from the parameters found in ``filename``.
        """
        lowered = filename.lower()
        if lowered.endswith("star"):
            reader = _from_star
        elif lowered.endswith("xml"):
            reader = _from_xml
        elif lowered.endswith("mdoc"):
            reader = _from_mdoc
        else:
            reader = _from_ctffind

        data = reader(filename=filename)

        # File voltages are assumed to be in kV, the field is in Volts. Use
        # .get() throughout so that values absent from the file end up as None
        # and can be filled from kwargs instead of raising KeyError/TypeError.
        voltage = data.get("acceleration_voltage")
        if voltage is not None:
            voltage = np.multiply(voltage, 1e3)

        # Pixel size needs to be overwritten by pixel size the ctf is generated for
        init_kwargs = {
            "angles": data.get("angles"),
            "defocus_x": data.get("defocus_1"),
            "sampling_rate": data.get("pixel_size"),
            "acceleration_voltage": voltage,
            "spherical_aberration": data.get("spherical_aberration"),
            "amplitude_contrast": data.get("amplitude_contrast"),
            "phase_shift": data.get("additional_phase_shift"),
            "defocus_angle": data.get("azimuth_astigmatism"),
            "defocus_y": data.get("defocus_2"),
        }

        # kwargs only fill gaps; they never override values found in the file.
        for key, value in kwargs.items():
            if key in init_kwargs and init_kwargs[key] is None:
                init_kwargs[key] = value
        init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}

        # Moved format conversion from __post__init: angles are given in degrees
        for key in ("phase_shift", "defocus_angle"):
            if key in init_kwargs:
                init_kwargs[key] = np.radians(init_kwargs[key])
        return cls(**init_kwargs)

    def _evaluate(
        self,
        shape: Tuple[int, ...],
        defocus_x: Tuple[float],
        angles: Tuple[float],
        opening_axis: int = 2,
        tilt_axis: int = 0,
        amplitude_contrast: Tuple[float] = 0.07,
        phase_shift: Tuple[float] = 0,
        defocus_angle: Tuple[float] = 0,
        defocus_y: Tuple[float] = None,
        sampling_rate: Tuple[float] = 1,
        acceleration_voltage: float = 300e3,
        spherical_aberration: float = 2.7e7,
        flip_phase: bool = True,
        cutoff_frequency: float = 0.5,
        **kwargs: Dict,
    ) -> Dict:
        """
        Compute the CTF weight tilt stack.

        Parameters
        ----------
        shape : tuple of int
            The shape of the CTF.
        defocus_x : tuple of float
            Defocus along the first principal axis in spatial units of sampling rate.
        angles : tuple of float
            The tilt angles in degrees.
        opening_axis : int, optional
            The axis around which the wedge is opened, defaults to 2.
        tilt_axis : int, optional
            The axis along which the tilt is applied, defaults to 0.
        amplitude_contrast : tuple of float, optional
            Amplitude contrast of microscope, defaults to 0.07.
        phase_shift : tuple of float, optional
            CTF phase shift in radians, defaults to 0.
        defocus_angle : tuple of float, optional
            Astigmatism angle in radians, defaults to 0.
        defocus_y : tuple of float, optional
            Defocus along the second principal axis in spatial units of sampling rate.
        sampling_rate : tuple of float, optional
            The sampling rate, defaults to 1. Only its maximum is used.
        acceleration_voltage : float, optional
            The acceleration voltage in Volts, defaults to 300e3.
        spherical_aberration : float, optional
            Spherical aberration of microscope in units of sampling rate.
        flip_phase : bool, optional
            Whether the returned CTF should be phase-flipped, defaults to True.
        cutoff_frequency : float, optional
            Frequencies above this value are set to zero, defaults to 0.5.
        **kwargs : Dict
            Additional keyword arguments (unused).

        Returns
        -------
        dict
            ``{"data": stack, "shape": shape}`` where ``stack`` has one CTF
            slice per tilt angle, converted to the active backend.
        """
        angles = np.atleast_1d(angles)
        n_angles = angles.size

        # Broadcast per-tilt parameters to one value per angle.
        defoci_x = pad_to_length(defocus_x, n_angles)
        defoci_y = pad_to_length(defocus_y, n_angles)
        phase_shift = pad_to_length(phase_shift, n_angles)
        defocus_angle = pad_to_length(defocus_angle, n_angles)
        spherical_aberration = pad_to_length(spherical_aberration, n_angles)
        amplitude_contrast = pad_to_length(amplitude_contrast, n_angles)
        acceleration_voltage = pad_to_length(acceleration_voltage, n_angles)
        sampling_rate = np.max(sampling_rate)

        ctf_shape = compute_tilt_shape(
            shape=shape, opening_axis=opening_axis, reduce_dim=True
        )
        stack = np.zeros((len(angles), *ctf_shape))

        # Shift tilt axis forward: with reduce_dim=True the opening axis is
        # presumably dropped, so later axes move down by one. Compare against
        # None explicitly, since opening_axis == 0 is a valid axis.
        corrected_tilt_axis = tilt_axis
        if opening_axis is not None and tilt_axis is not None:
            if opening_axis < tilt_axis:
                corrected_tilt_axis -= 1

        for index, angle in enumerate(angles):
            stack[index] = create_ctf(
                angle=angle,
                shape=ctf_shape,
                defocus_x=defoci_x[index],
                defocus_y=defoci_y[index],
                sampling_rate=sampling_rate,
                acceleration_voltage=acceleration_voltage[index],
                spherical_aberration=spherical_aberration[index],
                cutoff_frequency=cutoff_frequency,
                phase_shift=phase_shift[index],
                defocus_angle=defocus_angle[index],
                amplitude_contrast=amplitude_contrast[index],
                tilt_axis=corrected_tilt_axis,
                opening_axis=opening_axis,
                full_shape=shape,
            )

        # Avoid contrast inversion
        stack = np.negative(stack, out=stack)
        if flip_phase:
            stack = np.abs(stack, out=stack)
        return {"data": be.to_backend_array(stack), "shape": shape}
@dataclass
class CTFReconstructed(CTF):
    """
    Generate CTF filter for reconstructions.
    """

    def _evaluate(
        self,
        shape: Tuple[int],
        defocus_x: Tuple[float],
        amplitude_contrast: float = 0.07,
        phase_shift: Tuple[float] = 0,
        defocus_angle: Tuple[float] = 0,
        defocus_y: Tuple[float] = None,
        sampling_rate: Tuple[float] = 1,
        acceleration_voltage: float = 300e3,
        spherical_aberration: float = 2.7e3,
        flip_phase: bool = True,
        cutoff_frequency: float = 0.5,
        **kwargs: Dict,
    ) -> Dict:
        """
        Compute the CTF weight for a reconstructed (non-tilted) volume.

        Parameters
        ----------
        shape : tuple of int
            The shape of the CTF.
        defocus_x : tuple of float
            Defocus along the first principal axis in spatial units of sampling rate.
        amplitude_contrast : float, optional
            The amplitude contrast, defaults to 0.07.
        phase_shift : tuple of float, optional
            CTF phase shift in radians, defaults to 0.
        defocus_angle : tuple of float, optional
            The defocus angle in radians, defaults to 0.
        defocus_y : tuple of float, optional
            Defocus along the second principal axis in spatial units of sampling rate.
        sampling_rate : tuple of float, optional
            The sampling rate, defaults to 1. Only its maximum is used.
        acceleration_voltage : float, optional
            The acceleration voltage in Volts, defaults to 300e3.
        spherical_aberration : float, optional
            The spherical aberration coefficient in units of sampling rate,
            defaults to 2.7e3. NOTE(review): the ``CTF`` field default is
            2.7e7; confirm which default is intended here.
        flip_phase : bool, optional
            Whether the returned CTF should be phase-flipped.
        cutoff_frequency : float, optional
            Frequencies above this value are set to zero, defaults to 0.5.
        **kwargs : Dict
            Additional keyword arguments (unused).

        Returns
        -------
        dict
            ``{"data": ctf, "shape": shape}`` with the CTF converted to the
            active backend.
        """
        stack = create_ctf(
            shape=shape,
            defocus_x=defocus_x,
            defocus_y=defocus_y,
            sampling_rate=np.max(sampling_rate),
            # Use the argument (not self.acceleration_voltage) so call-time
            # values are honoured, consistent with CTF._evaluate.
            acceleration_voltage=acceleration_voltage,
            spherical_aberration=spherical_aberration,
            cutoff_frequency=cutoff_frequency,
            phase_shift=phase_shift,
            defocus_angle=defocus_angle,
            amplitude_contrast=amplitude_contrast,
        )

        # Avoid contrast inversion
        stack = np.negative(stack, out=stack)
        if flip_phase:
            stack = np.abs(stack, out=stack)
        return {"data": be.to_backend_array(stack), "shape": shape}
def _from_xml(filename: str) -> Dict:
    """
    Read CTF parameters from a Warp/M XML file.

    Parameters
    ----------
    filename : str
        Path to the XML file.

    Returns
    -------
    dict
        Parameters keyed by the names consumed by ``CTF.from_file``. ``Cs`` is
        scaled by 1e7 and ``Defocus`` by 1e4 (presumably mm and µm to Angstrom);
        per-tilt phase shifts from ``GridCTFPhase`` are passed through
        ``np.degrees``. Other values are returned as stored.

    Raises
    ------
    ValueError
        If any required parameter is missing from the file.
    """
    data = XMLParser(filename)

    params = {
        "PhaseShift": None,
        "Amplitude": None,
        "Defocus": None,
        "Voltage": None,
        "Cs": None,
        "DefocusAngle": None,
        "PixelSize": None,
        "Angles": data["Angles"],
    }

    for option in data["CTF"]["Param"]:
        attributes = option["@attributes"]
        name = attributes["Name"]
        if name in params:
            params[name] = attributes["Value"]

    # Per-tilt grids, when present, replace the global CTF parameters.
    if "GridCTF" in data:
        params["Defocus"] = [
            node["@attributes"]["Value"] for node in data["GridCTF"]["Node"]
        ]
        params["PhaseShift"] = np.degrees(
            [node["@attributes"]["Value"] for node in data["GridCTFPhase"]["Node"]]
        )
        params["DefocusAngle"] = [
            node["@attributes"]["Value"]
            for node in data["GridCTFDefocusAngle"]["Node"]
        ]

    missing = [k for k, v in params.items() if v is None]
    if missing:
        raise ValueError(f"Could not find {missing} in {filename}.")

    params = {
        k: np.array(v) if hasattr(v, "__len__") else float(v)
        for k, v in params.items()
    }

    # Convert units to sampling rate (we assume it is Angstrom)
    params["Cs"] = float(params["Cs"] * 1e7)
    params["Defocus"] = params["Defocus"] * 1e4

    mapping = {
        "angles": "Angles",
        "defocus_1": "Defocus",
        "defocus_2": "Defocus",
        "azimuth_astigmatism": "DefocusAngle",
        "additional_phase_shift": "PhaseShift",
        "acceleration_voltage": "Voltage",
        "spherical_aberration": "Cs",
        "amplitude_contrast": "Amplitude",
        "pixel_size": "PixelSize",
    }
    return {k: params[v] for k, v in mapping.items()}


def _from_ctffind(filename: str) -> Dict:
    """
    Read CTF parameters from a CTFFIND4 output text file.

    Header lines (starting with ``#``) are scanned for pixel size, voltage,
    spherical aberration and amplitude contrast; every other non-empty line is
    parsed as seven whitespace-separated numeric columns.

    Parameters
    ----------
    filename : str
        Path to the CTFFIND4 output file.

    Returns
    -------
    dict
        One array per column plus any header parameters that were found.
        Spherical aberration is scaled by 1e7 (mm to Angstrom), and the phase
        shift is passed through ``np.degrees``.
    """
    # Compiled once rather than re-looked-up for every header line.
    parameter_regex = {
        "pixel_size": re.compile(r"Pixel size: ([0-9.]+) Angstroms"),
        "acceleration_voltage": re.compile(r"acceleration voltage: ([0-9.]+) keV"),
        "spherical_aberration": re.compile(r"spherical aberration: ([0-9.]+) mm"),
        "amplitude_contrast": re.compile(r"amplitude contrast: ([0-9.]+)"),
    }
    columns = {
        "micrograph_number": 0,
        "defocus_1": 1,
        "defocus_2": 2,
        "azimuth_astigmatism": 3,
        "additional_phase_shift": 4,
        "cross_correlation": 5,
        "spacing": 6,
    }
    output = {key: [] for key in columns}

    with open(filename, mode="r", encoding="utf-8") as infile:
        lines = [x.strip() for x in infile.read().split("\n")]

    for line in lines:
        if not line:
            continue
        if line.startswith("#"):
            for parameter, pattern in parameter_regex.items():
                match = pattern.search(line)
                if match:
                    output[parameter] = float(match.group(1))
            continue
        values = line.split()
        for key, index in columns.items():
            output[key].append(float(values[index]))

    for key in columns:
        output[key] = np.array(output[key])

    # File values are treated as radians; CTF.from_file expects degrees.
    output["additional_phase_shift"] = np.degrees(output["additional_phase_shift"])

    cs = output.get("spherical_aberration")
    if cs is not None:
        output["spherical_aberration"] = float(cs) * 1e7
    return output


def _from_star(filename: str) -> Dict:
    """
    Read CTF parameters from a STAR file (stopgap wedge list or RELION/GCTF).

    Parameters
    ----------
    filename : str
        Path to the STAR file.

    Returns
    -------
    dict
        Parameters keyed by the names consumed by ``CTF.from_file``. Values
        that could not be converted are kept as read; absent ones are None.
    """
    parser = StarParser(filename)

    if "data_stopgap_wedgelist" in parser:
        block_name = "data_stopgap_wedgelist"
        mapping = {
            "angles": ("_tilt_angle", float, 1),
            "defocus_1": ("_defocus", float, 1e4),
            "defocus_2": (None, float, 1e4),
            "pixel_size": ("_pixelsize", float, 1),
            "acceleration_voltage": ("_voltage", float, 1),
            "spherical_aberration": ("_cs", float, 1e7),
            "amplitude_contrast": ("_amp_contrast", float, 1),
            "additional_phase_shift": (None, float, 1),
            "azimuth_astigmatism": (None, float, 1),
        }
    else:
        block_name = "data_"
        mapping = {
            "defocus_1": ("_rlnDefocusU", float, 1),
            "defocus_2": ("_rlnDefocusV", float, 1),
            "pixel_size": ("_rlnDetectorPixelSize", float, 1),
            "acceleration_voltage": ("_rlnVoltage", float, 1),
            # RELION stores Cs in mm; convert to Angstrom like every other parser.
            "spherical_aberration": ("_rlnSphericalAberration", float, 1e7),
            "amplitude_contrast": ("_rlnAmplitudeContrast", float, 1),
            "additional_phase_shift": (None, float, 1),
            "azimuth_astigmatism": ("_rlnDefocusAngle", float, 1),
        }

    output = {}
    ctf_data = parser[block_name]
    for out_key, (star_key, key_dtype, scale) in mapping.items():
        key_value = ctf_data.get(star_key)
        if key_value is not None:
            try:
                key_value = [key_dtype(x) * scale for x in key_value]
            except (TypeError, ValueError):
                # Best effort: keep the raw value when it cannot be converted.
                pass
        output[out_key] = key_value
    return output


def _from_mdoc(filename: str) -> Dict:
    """
    Read CTF parameters from a SerialEM ``.mdoc`` file.

    Only tilt angles, defocus and voltage are looked up in the file; the
    remaining keys are intentionally None so they can be supplied by the user.

    Parameters
    ----------
    filename : str
        Path to the mdoc file.

    Returns
    -------
    dict
        Parameters keyed by the names consumed by ``CTF.from_file``.
    """
    parser = MDOCParser(filename)

    mapping = {
        "angles": ("TiltAngle", float),
        "defocus_1": ("Defocus", float),
        "acceleration_voltage": ("Voltage", float),
        # These will be None, but on purpose
        "pixel_size": ("_rlnDetectorPixelSize", float),
        "defocus_2": ("Defocus2", float),
        "spherical_aberration": ("_rlnSphericalAberration", float),
        "amplitude_contrast": ("_rlnAmplitudeContrast", float),
        "additional_phase_shift": (None, float),
        "azimuth_astigmatism": ("_rlnDefocusAngle", float),
    }
    output = {out_key: parser.get(key, None) for out_key, (key, _) in mapping.items()}

    # Adjust convention and convert to Angstrom; skip when the file has no
    # defocus so callers can provide it instead of hitting a TypeError.
    if output["defocus_1"] is not None:
        output["defocus_1"] = np.multiply(output["defocus_1"], -1e4)
    return output


def _compute_electron_wavelength(acceleration_voltage: float = 300e3):
    """Computes the relativistic wavelength of an electron in angstrom.

    ``acceleration_voltage`` is in Volts.
    """
    # Physical constants expressed in SI units
    planck_constant = 6.62606896e-34
    electron_charge = 1.60217646e-19
    electron_mass = 9.10938215e-31
    light_velocity = 299792458

    # lambda = h * c / sqrt(E^2 + 2 * E * m * c^2) with E = e * U
    energy = electron_charge * acceleration_voltage
    denominator = energy**2
    denominator += 2 * energy * electron_mass * light_velocity**2
    electron_wavelength = np.divide(
        planck_constant * light_velocity, np.sqrt(denominator)
    )
    # Convert to Ångstrom
    electron_wavelength *= 1e10
    return electron_wavelength


def create_ctf(
    shape: Tuple[int],
    defocus_x: float,
    acceleration_voltage: float = 300e3,
    defocus_angle: float = 0,
    phase_shift: float = 0,
    defocus_y: float = None,
    sampling_rate: float = 1,
    spherical_aberration: float = 2.7e7,
    amplitude_contrast: float = 0.07,
    cutoff_frequency: float = 0.5,
    angle: float = None,
    tilt_axis: int = 0,
    opening_axis: int = None,
    full_shape: Tuple[int] = None,
) -> NDArray:
    """
    Create CTF representation using the definition from [1]_.

    Parameters
    ----------
    shape : Tuple[int]
        Shape of the returned CTF mask.
    defocus_x : float
        Defocus along the first principal axis in spatial units of sampling rate,
        e.g. 30000 Angstrom.
    acceleration_voltage : float, optional
        Acceleration voltage in Volts, defaults to 300e3 (300 kV).
    defocus_angle : float, optional
        Astigmatism angle in radians, defaults to 0.
    phase_shift : float, optional
        CTF phase shift in radians, defaults to 0.
    defocus_y : float, optional
        Defocus along the second principal axis in spatial units of sampling rate.
        Requires ``len(shape) >= 2``.
    sampling_rate : float, optional
        Scalar sampling rate, e.g., 4 Angstrom per voxel, defaults to 1.
    spherical_aberration : float, optional
        Spherical aberration of microscope in units of sampling rate.
    amplitude_contrast : float, optional
        Amplitude contrast of microscope, defaults to 0.07.
    cutoff_frequency : float, optional
        Frequencies whose grid magnitude exceeds this value are set to zero,
        defaults to 0.5.
    angle : float, optional
        Assume the created CTF is a projection observed at angle degrees.
    tilt_axis : int, optional
        Axes the specimen was tilted over, defaults to 0 (x-axis).
    opening_axis : int, optional
        Projection axis, only relevant if angle is given.
    full_shape : tuple of ints
        Shape of the entire volume we are observing a projection of. This is required
        to compute aspect ratios for correct scaling. For instance, the 2D CTF slice
        could be (50,50), while the final 3D CTF volume is (50,50,25) with the
        opening_axis being 2, i.e., the z-axis.

    Returns
    -------
    NDArray
        CTF mask.

    Raises
    ------
    ValueError
        If ``defocus_y`` is given and ``shape`` has fewer than two dimensions.

    References
    ----------
    .. [1] CTFFIND4: Fast and accurate defocus estimation from electron micrographs.
        Alexis Rohou and Nikolaus Grigorieff. Journal of Structural Biology 2015.
    """
    # Express all lengths in voxels.
    electron_wavelength = _compute_electron_wavelength(acceleration_voltage)
    electron_wavelength /= sampling_rate
    aberration = (spherical_aberration / sampling_rate) * electron_wavelength**2

    defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
    defocus_y = defocus_y / sampling_rate if defocus_y is not None else None

    if defocus_y is not None:
        if len(shape) < 2:
            raise ValueError(f"Length of shape needs to be at least 2, got {shape}")

        # Axial distance from grid center in voxels; only needed for astigmatism.
        grid = fftfreqn(
            shape=shape,
            sampling_rate=None,
            return_sparse_grid=True,
            fftshift=False,
        )

        # Effective defocus:
        # 0.5 * (dx + dy + cos(2 * (azimuth - astigmatism)) * (dx - dy))
        defocus_sum = np.add(defocus_x, defocus_y)
        defocus_difference = np.subtract(defocus_x, defocus_y)

        # Reusing grid, but in principle pure frequencies would suffice
        angular_grid = np.arctan2(grid[1], grid[0])
        defocus_difference = np.multiply(
            defocus_difference,
            np.cos(2 * (angular_grid - defocus_angle)),
        )
        defocus_x = np.add(defocus_sum, defocus_difference)
        defocus_x *= 0.5

    if angle is not None and opening_axis is not None and full_shape is not None:
        frequency_grid = frequency_grid_at_angle(
            shape=full_shape,
            tilt_axis=tilt_axis,
            opening_axis=opening_axis,
            angle=angle,
            sampling_rate=1,
            fftshift=False,
        )
    else:
        frequency_grid = fftfreqn(
            shape, sampling_rate=1, compute_euclidean_norm=True, fftshift=False
        )
    frequency_mask = frequency_grid <= cutoff_frequency

    # chi = pi * lambda * k^2 * (df - 0.5 * Cs * lambda^2 * k^2)
    #       + phase_shift + amplitude_contrast_term
    frequency_grid = np.square(frequency_grid, out=frequency_grid)
    chi = defocus_x - 0.5 * aberration * frequency_grid
    chi = np.multiply(chi, np.pi * electron_wavelength, out=chi)
    chi = np.multiply(chi, frequency_grid, out=chi)
    chi += phase_shift
    chi += np.arctan(
        np.divide(
            amplitude_contrast,
            np.sqrt(1 - np.square(amplitude_contrast)),
        )
    )
    chi = np.sin(-chi, out=chi)
    return np.multiply(chi, frequency_mask, out=chi)