Source code for tme.mask

"""
Utility functions for generating template matching masks.

Copyright (c) 2023 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import numpy as np
from typing import Tuple, Optional

from .types import NDArray
from scipy.ndimage import gaussian_filter
from .matching_utils import rigid_transform

# Names exported on a wildcard import of this module; each one is a mask
# generator defined below.
__all__ = ["elliptical_mask", "tube_mask", "box_mask", "membrane_mask"]


def elliptical_mask(
    shape: Tuple[int],
    radius: Tuple[float],
    center: Optional[Tuple[float]] = None,
    orientation: Optional[NDArray] = None,
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3,
    **kwargs,
) -> NDArray:
    """
    Creates an ellipsoidal mask.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the mask to be created.
    radius : tuple of floats
        Radius of the mask along each axis. A single value is applied to
        every axis.
    center : tuple of floats, optional
        Center of the mask, default to shape // 2. A single value is applied
        to every axis.
    orientation : NDArray, optional.
        Orientation of the mask as rotation matrix with shape (d,d).
    sigma_decay : float, optional
        If larger than zero, voxels outside the ellipsoid decay as
        ``exp(-(dist - 1)**2 / (2 * (sigma_decay / mean(radius))**2))`` where
        ``dist`` is the radius-normalized distance to the center. Defaults
        to 0.0, which yields a binary mask.
    cutoff_sigma : float, optional
        Decay values not exceeding ``exp(-cutoff_sigma**2 / 2)`` are set to
        zero. Only used if ``sigma_decay`` is larger than zero. Defaults to 3.
    **kwargs : dict, optional
        Ignored.

    Returns
    -------
    NDArray
        The created ellipsoidal mask. Integer valued for a binary mask,
        floating point if ``sigma_decay`` is larger than zero.

    Raises
    ------
    ValueError
        If the length of center and radius is not one or the same as shape.

    Examples
    --------
    >>> from tme.mask import elliptical_mask
    >>> mask = elliptical_mask(shape=(20,20), radius=(5,5), center=(10,10))
    """
    shape = np.asarray(shape).astype(int)
    radius = np.asarray(radius)
    n = shape.size

    if center is None:
        center = np.divide(shape, 2).astype(int)
    center = np.asarray(center, dtype=np.float32)

    # Validate before broadcasting. Tiling first would silently accept any
    # length that divides the number of dimensions (e.g. two radii for a 4D
    # shape) and fail with ZeroDivisionError on empty input.
    if radius.size not in (1, n):
        raise ValueError("Length of radius has to be either one or match shape.")
    if center.size not in (1, n):
        raise ValueError("Length of center has to be either one or match shape.")

    radius = np.repeat(radius, n) if radius.size == 1 else radius.ravel()
    center = np.repeat(center, n) if center.size == 1 else center.ravel()

    # Shape (n, 1, ..., 1) so both broadcast against the (n, *shape) index grid
    center = center.reshape((-1,) + (1,) * n)
    radius = radius.reshape((-1,) + (1,) * n)

    indices = np.indices(shape, dtype=np.float32) - center
    if orientation is not None:
        return_shape = indices.shape
        # Flatten to (n, n_voxels); the rotated coordinates are written back
        # into the same array via ``out``
        indices = indices.reshape(n, -1)
        rigid_transform(
            coordinates=indices,
            rotation_matrix=np.asarray(orientation),
            out=indices,
            translation=np.zeros(n),
            use_geometric_center=False,
        )
        indices = indices.reshape(*return_shape)

    # Radius-normalized distance: values <= 1 lie inside the ellipsoid
    dist = np.linalg.norm(indices / radius, axis=0)
    if sigma_decay > 0:
        # Express the decay width in units of the normalized distance
        sigma_decay = 2 * (sigma_decay / np.mean(radius)) ** 2
        mask = np.maximum(0, dist - 1)
        mask = np.exp(-(mask**2) / sigma_decay)
        mask *= mask > np.exp(-(cutoff_sigma**2) / 2)
    else:
        mask = (dist <= 1).astype(int)

    return mask
def box_mask(
    shape: Tuple[int],
    center: Tuple[int],
    size: Tuple[int],
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3.0,
    **kwargs,
) -> np.ndarray:
    """
    Creates a box mask centered around the provided center point.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the output array.
    center : tuple of ints
        Center point coordinates of the box.
    size : tuple of ints
        Side length of the box along each axis.
    sigma_decay : float, optional
        If larger than zero, the box edge is softened using a Gaussian filter
        with this standard deviation. Voxels inside the box remain one.
        Defaults to 0.0, which yields a binary mask.
    cutoff_sigma : float, optional
        Truncation of the Gaussian filter in standard deviations. Values not
        exceeding ``exp(-cutoff_sigma**2 / 2)`` are set to zero. Only used if
        ``sigma_decay`` is larger than zero. Defaults to 3.0.
    **kwargs : dict, optional
        Ignored.

    Returns
    -------
    NDArray
        The created box mask.

    Raises
    ------
    ValueError
        If ``shape`` and ``center`` do not have the same length.
        If ``center`` and ``size`` do not have the same length.
    """
    if len(shape) != len(center) or len(center) != len(size):
        raise ValueError("The length of shape, center, and size must be consistent.")

    shape = tuple(int(x) for x in shape)
    center, size = np.array(center, dtype=int), np.array(size, dtype=int)

    half_heights = size // 2
    starts = np.maximum(center - half_heights, 0)
    # NOTE(review): this exclusive stop makes the box span size + 1 voxels
    # per axis. Kept unchanged for backward compatibility, confirm intent.
    # Clipping at zero is required: a negative stop would be interpreted
    # relative to the end of the axis and fill voxels although the box lies
    # outside of the array.
    stops = np.clip(center + half_heights + np.mod(size, 2) + 1, 0, shape)
    slice_indices = tuple(slice(*coord) for coord in zip(starts, stops))

    out = np.zeros(shape)
    out[slice_indices] = 1

    if sigma_decay > 0:
        mask_filter = gaussian_filter(
            out.astype(np.float32), sigma=sigma_decay, truncate=cutoff_sigma
        )
        # Only voxels outside the box (out == 0) receive the smoothed value
        out = np.add(out, (1 - out) * mask_filter)
        out *= out > np.exp(-(cutoff_sigma**2) / 2)
    return out
def tube_mask(
    shape: Tuple[int],
    symmetry_axis: int,
    center: Tuple[int],
    inner_radius: float,
    outer_radius: float,
    height: int,
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3.0,
    **kwargs,
) -> NDArray:
    """
    Creates a tube mask.

    Parameters
    ----------
    shape : tuple
        Shape of the mask to be created.
    symmetry_axis : int
        The axis of symmetry for the tube.
    center : tuple
        Center of the tube.
    inner_radius : float
        Inner radius of the tube.
    outer_radius : float
        Outer radius of the tube.
    height : int
        Height of the tube.
    sigma_decay : float, optional
        Passed to :py:meth:`elliptical_mask` for the inner and outer disk.
        Defaults to 0.0.
    cutoff_sigma : float, optional
        Passed to :py:meth:`elliptical_mask` for the inner and outer disk.
        Defaults to 3.0.
    **kwargs : dict, optional
        Ignored.

    Returns
    -------
    NDArray
        The created tube mask.

    Raises
    ------
    ValueError
        If ``inner_radius`` is larger than ``outer_radius``.
        If ``symmetry_axis`` is negative or not smaller than the length of
        ``shape``.
        If ``height`` is larger than the symmetry axis.
        If ``center`` and ``shape`` do not have the same length.
    """
    ndim = len(shape)
    if inner_radius > outer_radius:
        raise ValueError("inner_radius should be smaller than outer_radius.")

    # The axis has to be validated before it is used to index shape,
    # otherwise an invalid axis surfaces as IndexError
    if symmetry_axis < 0:
        raise ValueError("symmetry_axis must be non-negative.")
    if symmetry_axis >= ndim:
        raise ValueError(f"symmetry_axis must be less than {ndim}.")

    if height > shape[symmetry_axis]:
        raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")

    if len(center) != ndim:
        raise ValueError("shape and center need to have the same length.")

    shape = tuple(int(x) for x in shape)
    # Cross-section perpendicular to the symmetry axis
    circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
    circle_center = tuple(b for ix, b in enumerate(center) if ix != symmetry_axis)

    inner_circle = np.zeros(circle_shape)
    outer_circle = np.zeros_like(inner_circle)
    if inner_radius > 0:
        inner_circle = elliptical_mask(
            shape=circle_shape,
            radius=inner_radius,
            center=circle_center,
            sigma_decay=sigma_decay,
            cutoff_sigma=cutoff_sigma,
        )
    if outer_radius > 0:
        outer_circle = elliptical_mask(
            shape=circle_shape,
            radius=outer_radius,
            center=circle_center,
            sigma_decay=sigma_decay,
            cutoff_sigma=cutoff_sigma,
        )
    # Ring-shaped cross-section, with a singleton symmetry axis so it
    # broadcasts along the height of the tube
    circle = outer_circle - inner_circle
    circle = np.expand_dims(circle, axis=symmetry_axis)

    axis_center = center[symmetry_axis]
    axis_length = shape[symmetry_axis]
    start_idx = int(axis_center - height // 2)
    stop_idx = int(axis_center + height // 2 + height % 2)

    # Clamp both bounds to the array. A negative stop would otherwise be
    # interpreted relative to the end of the axis and fill the wrong slab.
    start_idx = min(max(start_idx, 0), axis_length)
    stop_idx = min(max(stop_idx, 0), axis_length)

    slice_indices = tuple(
        slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
        for i in range(ndim)
    )
    tube = np.zeros(shape)
    tube[slice_indices] = circle

    return tube
def membrane_mask(
    shape: Tuple[int],
    radius: float,
    thickness: float,
    separation: float,
    symmetry_axis: int = 2,
    center: Optional[Tuple[float]] = None,
    sigma_decay: float = 0.5,
    cutoff_sigma: float = 3,
    **kwargs,
) -> NDArray:
    """
    Creates a membrane mask consisting of two parallel disks with Gaussian
    intensity profile. The mask is the product of a flat disk mask and a
    one-dimensional height profile, combined via broadcasting.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the mask to be created.
    radius : float
        Radius of the membrane disks.
    thickness : float
        Thickness of each disk in the membrane. The Gaussian height profile
        of each disk has a standard deviation of ``thickness / 3``. A value
        of zero restricts each disk to the voxel exactly at its position.
    separation : float
        Distance between the centers of the two disks.
    symmetry_axis : int, optional
        The axis perpendicular to the membrane disks, defaults to 2.
    center : tuple of floats, optional
        Center of the membrane (midpoint between the two disks), defaults to
        shape // 2. A single value is applied to every axis.
    sigma_decay : float, optional
        Controls edge sharpness relative to radius, defaults to 0.5.
    cutoff_sigma : float, optional
        Cutoff for height profile in standard deviations, defaults to 3.
    **kwargs : dict, optional
        Ignored.

    Returns
    -------
    NDArray
        The created membrane mask with Gaussian intensity profile.

    Raises
    ------
    ValueError
        If ``thickness`` is negative.
        If ``separation`` is negative.
        If ``center`` and ``shape`` do not have the same length.
        If ``symmetry_axis`` is out of bounds.

    Examples
    --------
    >>> from tme.mask import membrane_mask
    >>> mask = membrane_mask(shape=(50,50,50), radius=10, thickness=2, separation=15)
    """
    shape = np.asarray(shape, dtype=int)
    n = shape.size

    center_given = center is not None
    if not center_given:
        center = np.divide(shape, 2).astype(float)
    center = np.asarray(center, dtype=np.float32)

    if thickness < 0:
        raise ValueError("thickness must be non-negative.")
    if separation < 0:
        raise ValueError("separation must be non-negative.")
    if symmetry_axis < 0:
        raise ValueError("symmetry_axis must be non-negative.")
    if symmetry_axis >= len(shape):
        raise ValueError(f"symmetry_axis must be less than {len(shape)}.")
    # Validate before broadcasting, tiling first would accept any length
    # that divides the number of dimensions
    if center.size not in (1, n):
        raise ValueError("Length of center has to be either one or match shape.")
    center = np.repeat(center, n) if center.size == 1 else center.ravel()

    in_plane = [i for i in range(n) if i != symmetry_axis]
    # A user-provided center also positions the disk within the membrane
    # plane. Without one, elliptical_mask applies its own default so the
    # output for center=None is unchanged.
    disk_center = center[in_plane] if center_given else None
    disk_mask = elliptical_mask(
        shape=[shape[i] for i in in_plane],
        radius=radius,
        center=disk_center,
        sigma_decay=sigma_decay,
        cutoff_sigma=cutoff_sigma,
    )

    axial_coord = np.arange(shape[symmetry_axis]) - center[symmetry_axis]
    height_profile = np.zeros((shape[symmetry_axis],), dtype=np.float32)

    sigma = thickness / 3
    cutoff_threshold = np.exp(-(cutoff_sigma**2) / 2)
    for leaflet_pos in (-separation / 2, separation / 2):
        offset = axial_coord - leaflet_pos
        if sigma > 0:
            leaflet_profile = np.exp(-(offset**2) / (2 * sigma**2))
        else:
            # Limit of the Gaussian for vanishing width. Evaluating the
            # formula would divide by zero and yield NaN at offset zero.
            leaflet_profile = (offset == 0).astype(np.float64)

        leaflet_profile *= leaflet_profile > cutoff_threshold
        # Overlapping leaflets keep the larger of the two values
        height_profile = np.maximum(height_profile, leaflet_profile)

    # Singleton symmetry axis on the disk and singleton in-plane axes on the
    # profile, so the product broadcasts to the full shape
    disk_mask = disk_mask.reshape(
        [x if i != symmetry_axis else 1 for i, x in enumerate(shape)]
    )
    height_profile = height_profile.reshape(
        [1 if i != symmetry_axis else x for i, x in enumerate(shape)]
    )

    return disk_mask * height_profile