Source code for mosaic.formats.parser

"""
IO methods to parse a variety of file formats.

Copyright (c) 2024 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import warnings
from string import ascii_lowercase
from dataclasses import dataclass
import xml.etree.ElementTree as ET
from typing import List, Dict, Any, Optional

import numpy as np
from scipy.spatial.transform import Rotation

from .. import meshing
from ..utils import volume_to_points, compute_bounding_box, NORMAL_REFERENCE


def _parse_data_array(data_array: ET.Element, dtype: type = float) -> np.ndarray:
    """
    Parse a DataArray element into a numpy array.

    The element text is split into non-blank lines, each line is split on
    whitespace and every token is converted with ``dtype``. Singleton
    dimensions are removed from the result, so a single-row array comes back
    one-dimensional.

    Parameters
    ----------
    data_array : ET.Element
        XML element containing array data.
    dtype : type, optional
        Callable used to convert each token, by default float.

    Returns
    -------
    np.ndarray
        Parsed numpy array. An element without text, or with only whitespace,
        yields an empty array of shape ``(0,)``.
    """
    # ElementTree reports ``text`` as None for elements such as <DataArray/>;
    # treat that the same way as an element holding only whitespace.
    text = data_array.text or ""
    parsed_rows = [
        [dtype(token) for token in line.split()]
        for line in text.split("\n")
        if line.strip()
    ]
    return np.squeeze(np.array(parsed_rows))


def _parse_dtype(xml_element) -> object:
    """
    Pick the Python scalar type matching an element's ``type`` attribute.

    Parameters
    ----------
    xml_element : ET.Element
        XML element whose ``type`` attribute is inspected.

    Returns
    -------
    object
        ``float`` when the attribute starts with ``"Float"``, otherwise
        ``int`` (including when the attribute is absent).
    """
    declared = xml_element.get("type", "")
    if declared.startswith("Float"):
        return float
    return int


@dataclass
class GeometryData:
    """
    Plain record describing one geometry entity.

    Every field is optional; all default to ``None`` except ``sampling``.

    Parameters
    ----------
    vertices : np.ndarray, optional
        3D vertex coordinates.
    normals : np.ndarray, optional
        Normal vectors at each vertex.
    faces : np.ndarray, optional
        Face connectivity indices.
    quaternions : np.ndarray, optional
        Orientation quaternions for each vertex.
    vertex_properties : VertexPropertyContainer, optional
        Additional vertex properties.
    shape : List[int], optional
        Bounding box dimensions.
    sampling : List[float], optional
        Sampling rates along each axis, by default (1, 1, 1).
    """

    # Per-vertex geometry
    vertices: np.ndarray = None
    normals: np.ndarray = None
    faces: np.ndarray = None
    quaternions: np.ndarray = None
    vertex_properties: "VertexPropertyContainer" = None

    # Entity-wide metadata
    shape: List[int] = None
    sampling: List[float] = (1, 1, 1)
@dataclass
class GeometryDataContainer:
    """
    Collection of geometry entities that is normalised on construction.

    After initialisation vertices, normals and quaternions are float32 copies,
    faces are integer copies, normals have unit length (near-zero normals are
    replaced by ``NORMAL_REFERENCE``), and missing normals, vertex properties,
    shape and quaternions are filled with defaults.

    Parameters
    ----------
    vertices : List[np.ndarray], optional
        List of vertex arrays for each geometry entity.
    normals : List[np.ndarray], optional
        List of normal arrays for each geometry entity.
    faces : List[np.ndarray], optional
        List of face arrays for each geometry entity.
    quaternions : List[np.ndarray], optional
        List of quaternion arrays for each geometry entity.
    vertex_properties : List[VertexPropertyContainer], optional
        List of vertex property containers for each geometry entity.
    shape : List[int], optional
        Bounding box dimensions.
    sampling : List[float], optional
        Sampling rates along each axis, by default (1, 1, 1).

    Raises
    ------
    ValueError
        If the number of normal or face arrays differs from the number of
        vertex arrays.
    """

    vertices: List[np.ndarray] = None
    normals: List[np.ndarray] = None
    faces: List[np.ndarray] = None
    quaternions: List[np.ndarray] = None
    vertex_properties: List["VertexPropertyContainer"] = None
    shape: List[int] = None
    sampling: List[float] = (1, 1, 1)

    def __post_init__(self):
        # Cast the array-valued fields; None is passed through untouched.
        casts = (
            ("vertices", np.float32),
            ("normals", np.float32),
            ("faces", int),
            ("quaternions", np.float32),
        )
        for field_name, target in casts:
            current = getattr(self, field_name)
            setattr(self, field_name, self._to_dtype(current, target))

        if self.normals is None:
            self.normals = [
                np.full_like(x, fill_value=NORMAL_REFERENCE) for x in self.vertices
            ]

        if self.vertex_properties is None:
            self.vertex_properties = [VertexPropertyContainer() for _ in self.vertices]

        # Normalise to unit length; near-zero normals fall back to the reference.
        for position, normal in enumerate(self.normals):
            lengths = np.linalg.norm(normal, axis=1)
            degenerate = lengths < 1e-12
            lengths[degenerate] = 1
            normal[degenerate] = NORMAL_REFERENCE
            self.normals[position] = normal / lengths[:, None]

        if self.shape is None:
            self.shape, _ = compute_bounding_box(self.vertices)

        n_sets = len(self.vertices)
        if n_sets != len(self.normals):
            raise ValueError("Normals need to be specified for each vertex set.")

        if self.faces is not None and n_sets != len(self.faces):
            raise ValueError("Faces need to be specified for each vertex set.")

        if self.quaternions is None:
            self.quaternions = [None] * len(self.vertices)

    def __len__(self):
        return len(self.vertices)

    def __iter__(self):
        # Entities are materialised up front, then yielded one by one.
        entities = [self[position] for position in range(len(self))]
        yield from entities

    def __getitem__(self, index: int) -> GeometryData:
        entity_faces = None if self.faces is None else self.faces[index]
        return GeometryData(
            vertices=self.vertices[index],
            normals=self.normals[index],
            faces=entity_faces,
            quaternions=self.quaternions[index],
            vertex_properties=self.vertex_properties[index],
            shape=self.shape,
            sampling=self.sampling,
        )

    @staticmethod
    def _to_dtype(data: List[np.ndarray], dtype=np.float32):
        """Return copies of the arrays in ``data`` cast to ``dtype``; pass None through."""
        if data is None:
            return data
        return [x.astype(dtype) for x in data]
class VertexPropertyContainer:
    """
    Container for managing custom vertex properties with automatic synchronization.

    All properties are stored as numpy arrays and must have the same length
    along their first axis.

    Parameters
    ----------
    properties : dict of str -> np.ndarray, optional
        Dictionary mapping property names to vertex data arrays.

    Raises
    ------
    ValueError
        If the properties do not all have the same length.
    """

    def __init__(self, properties: Optional[Dict[str, np.ndarray]] = None):
        """
        Initialize vertex property container.

        Parameters
        ----------
        properties : dict of str -> np.ndarray, optional
            Dictionary mapping property names to vertex data arrays
        """
        properties = {} if properties is None else properties
        properties = {name: np.asarray(data) for name, data in properties.items()}

        # We use len instead of size for future vector field support
        self._n_vertices = max((*(len(x) for x in properties.values()), 0))
        for name, data in properties.items():
            if len(data) != self._n_vertices:
                raise ValueError(
                    f"Property '{name}' has {len(data)} values, "
                    f"but expected {self._n_vertices} to match vertex count"
                )
        self._properties = properties

    def __getitem__(self, idx: Any) -> "VertexPropertyContainer":
        """Array-like indexing using int/bool numpy arrays, slices or ellipses."""
        if not self._properties:
            return VertexPropertyContainer()

        if isinstance(idx, (int, np.integer)):
            idx = [idx]
        elif isinstance(idx, slice) or idx is ...:
            idx = np.arange(self._n_vertices)[idx]

        idx = np.asarray(idx)
        if idx.dtype == bool:
            idx = np.where(idx)[0]
        elif idx.size == 0:
            # An empty list becomes a float array, which numpy rejects as an
            # index; an empty selection is valid and yields empty properties.
            idx = idx.astype(int)

        return VertexPropertyContainer(
            {k: v[idx].copy() for k, v in self._properties.items()}
        )

    @property
    def properties(self):
        """List available vertex properties."""
        return list(self._properties.keys())

    def get_property(self, name: str, default: Any = None) -> Optional[np.ndarray]:
        """Get property data by name, or ``default`` if it does not exist."""
        return self._properties.get(name, default)

    def remove_property(self, name: str) -> None:
        """Remove a property by name; unknown names are ignored."""
        _ = self._properties.pop(name, None)

    def copy(self) -> "VertexPropertyContainer":
        """Create a deep copy of the container."""
        return self[...]

    @classmethod
    def merge(
        cls, containers: List["VertexPropertyContainer"]
    ) -> "VertexPropertyContainer":
        """
        Merge multiple property containers.

        Only properties present in every non-empty container are kept; their
        arrays are concatenated along the first axis in container order. A
        warning is issued for dropped properties.

        Parameters
        ----------
        containers : list of VertexPropertyContainer
            Containers to merge

        Returns
        -------
        VertexPropertyContainer
            New container with merged properties
        """
        containers = [c for c in containers if c._properties]
        if not containers:
            return cls()

        all_props = set(containers[0].properties)
        common_props = set(containers[0].properties)
        for container in containers[1:]:
            container_props = set(container.properties)
            common_props &= container_props
            all_props |= container_props

        if not common_props:
            warnings.warn("No common properties found across containers to merge")
            return cls()

        dropped_props = all_props - common_props
        if dropped_props:
            warnings.warn(
                f"Properties {sorted(dropped_props)} were not common across all "
                f"containers and were dropped during merge"
            )

        # Follow the first container's property order so the result does not
        # depend on set iteration order (which varies with hash randomisation).
        ordered_props = [p for p in containers[0].properties if p in common_props]
        merged_props = {
            prop_name: np.concatenate(
                [container.get_property(prop_name) for container in containers],
                axis=0,
            )
            for prop_name in ordered_props
        }
        return cls(merged_props)


def _read_orientations(filename: str):
    """
    Read orientation data from file and convert to geometry format.

    Entries are grouped by the integer value of the ``details`` column of the
    parsed orientations, yielding one array per group.

    Parameters
    ----------
    filename : str
        Path to orientation file.

    Returns
    -------
    dict
        Dictionary containing vertices, normals, quaternions and
        vertex_properties (``None`` if scores could not be attached).
    """
    from tme import Orientations

    data = Orientations.from_file(filename)

    # Remap as active (push) rotation
    angles = Rotation.from_euler(seq="ZYZ", angles=data.rotations, degrees=True).inv()
    normals = angles.apply(NORMAL_REFERENCE)
    quaternions = angles.as_quat(scalar_first=True)

    cluster = data.details.astype(int)
    indices = [np.where(cluster == x) for x in np.unique(cluster)]

    # Scores are attached on a best-effort basis; geometry is still returned
    # when they cannot be split per cluster.
    try:
        vertex_properties = [
            VertexPropertyContainer({"pytme_score": data.scores[x]}) for x in indices
        ]
    except Exception:
        vertex_properties = None

    return {
        "vertices": [data.translations[x] for x in indices],
        "normals": [normals[x] for x in indices],
        "quaternions": [quaternions[x] for x in indices],
        "vertex_properties": vertex_properties,
    }


def read_star(filename: str):
    """
    Read RELION star file format.

    Parameters
    ----------
    filename : str
        Path to star file.

    Returns
    -------
    GeometryDataContainer
        Parsed geometry data container.
    """
    return GeometryDataContainer(**_read_orientations(filename))


def read_txt(filename: str):
    """
    Read text-based point cloud files.

    The delimiter is a comma for names ending in ``csv``/``xyz``, a tab for
    ``txt``/``tsv`` and any whitespace otherwise. If the first row contains
    ``x`` it is used as header; otherwise columns are named ``x, y, z, a, b,
    ...``. An ``id`` column splits the points into one entity per unique id,
    and ``nx, ny, nz`` columns are used as normals when present.

    Parameters
    ----------
    filename : str
        Path to text file (txt, csv, xyz).

    Returns
    -------
    GeometryDataContainer
        Parsed geometry data container.

    Raises
    ------
    ValueError
        If the file contains no data or lacks an x, y or z column.
    """
    delimiter = None
    if filename.endswith(("csv", "xyz")):
        delimiter = ","
    elif filename.endswith(("txt", "tsv")):
        delimiter = "\t"

    with open(filename, mode="r") as ifile:
        data = ifile.read().split("\n")
    data = [x.strip().split(delimiter) for x in data if x.strip()]

    if not data:
        raise ValueError(f"No data found in {filename}.")

    header = ("x", "y", "z", *ascii_lowercase)[: len(data[0])]
    if "x" in data[0]:
        header = data.pop(0)

    required_columns = ("x", "y", "z")
    if any(rc not in header for rc in required_columns):
        raise ValueError(f"Colums {required_columns} are required.")

    # Transpose rows into named columns (values remain strings at this point).
    data = {c: np.asarray(d) for c, d in zip(header, zip(*data))}
    if "id" in data:
        ids = data["id"]
        data = [
            {c: d[ids == cluster] for c, d in data.items()}
            for cluster in np.unique(ids)
        ]
    else:
        data = [data]

    vertices, normals = [], []
    for cluster in data:
        cols = ("x", "y", "z")
        vertices.append(np.hstack([cluster[k][:, None] for k in cols]))
        try:
            cols = ("nx", "ny", "nz")
            normals.append(np.hstack([cluster[k][:, None] for k in cols]))
        except KeyError:
            # Normal columns are optional.
            continue

    if len(normals) == 0:
        normals = None
    return GeometryDataContainer(vertices=vertices, normals=normals)


def read_tsv(filename: str) -> "GeometryDataContainer":
    """
    Read tab-separated values file with orientation data.

    Files whose first line does not contain ``euler`` are handed to
    :func:`read_txt` instead.

    Parameters
    ----------
    filename : str
        Path to tsv file.

    Returns
    -------
    GeometryDataContainer
        Parsed geometry data container.
    """
    with open(filename, mode="r") as infile:
        header = infile.readline()

    if "euler" not in header:
        return read_txt(filename)
    return GeometryDataContainer(**_read_orientations(filename))


def read_tsi(filename: str) -> "GeometryDataContainer":
    """
    Read topology surface information file format.

    Parameters
    ----------
    filename : str
        Path to tsi file.

    Returns
    -------
    GeometryDataContainer
        Parsed geometry data container. If the file lists inclusions, an
        ``inclusion`` vertex property holds the inclusion type per vertex
        (zero where no inclusion is assigned).
    """
    data = _read_tsi_file(filename)
    mesh = meshing.utils.to_open3d(data["vertices"][:, 1:4], data["faces"][:, 1:4])

    vertex_properties = {}
    # Best effort: a malformed inclusion section must not prevent the mesh
    # from loading, so failures here leave the properties empty.
    try:
        if "inclusions" in data:
            inclusions = np.zeros((len(data["vertices"])))
            inclusion_type = data["inclusions"][:, 1]
            inclusion_vert = data["inclusions"][:, 2].astype(int)
            inclusions[inclusion_vert] = inclusion_type
            vertex_properties = {"inclusion": inclusions}
    except Exception:
        pass

    return _return_mesh(mesh, vertex_properties=vertex_properties)


def read_vtu(filename: str) -> "GeometryDataContainer":
    """
    Read VTK unstructured grid XML file format.

    Parameters
    ----------
    filename : str
        Path to vtu file.

    Returns
    -------
    GeometryDataContainer
        Parsed geometry data container.
    """
    data = _read_vtu_file(filename)
    mesh = meshing.utils.to_open3d(data["points"], data["connectivity"])
    return _return_mesh(mesh, vertex_properties=data.get("point_data", {}))


def read_mesh(filename: str) -> "GeometryDataContainer":
    """
    Read 3D mesh files using Open3D.

    Parameters
    ----------
    filename : str
        Path to mesh file.

    Returns
    -------
    GeometryDataContainer
        Parsed geometry data container.
    """
    import open3d as o3d

    return _return_mesh(o3d.io.read_triangle_mesh(filename))


def _return_mesh(mesh, vertex_properties: dict = None) -> "GeometryDataContainer":
    """
    Convert Open3D mesh to GeometryDataContainer.

    Vertex normals are (re)computed on the mesh before conversion.

    Parameters
    ----------
    mesh : o3d.geometry.TriangleMesh
        Open3D triangle mesh object.
    vertex_properties : dict, optional
        Vertex property data.

    Returns
    -------
    GeometryDataContainer
        Converted geometry data container holding a single entity.
    """
    mesh.compute_vertex_normals()

    vertices = np.asarray(mesh.vertices)
    faces = np.asarray(mesh.triangles)
    normals = np.asarray(mesh.vertex_normals)
    return GeometryDataContainer(
        vertices=[vertices],
        faces=[faces],
        normals=[normals],
        vertex_properties=[VertexPropertyContainer(vertex_properties)],
    )


def read_structure(filename: str) -> "GeometryDataContainer":
    """
    Read molecular structure files.

    Parameters
    ----------
    filename : str
        Path to structure file (pdb, cif, gro).

    Returns
    -------
    GeometryDataContainer
        Parsed geometry data container.
    """
    from tme import Structure

    data = Structure.from_file(filename)
    return GeometryDataContainer(vertices=[data.atom_coordinate])


def read_volume(filename: str):
    """
    Read 3D volume data and convert to point cloud.

    Parameters
    ----------
    filename : str
        Path to volume file.

    Returns
    -------
    GeometryDataContainer
        Parsed geometry data container.
    """
    volume = load_density(filename)
    ret = volume_to_points(
        volume.data, volume.sampling_rate, reverse_order=True, max_cluster=10000
    )
    shape = np.multiply(volume.shape, volume.sampling_rate)
    return GeometryDataContainer(
        vertices=ret, shape=shape, sampling=volume.sampling_rate
    )


def _read_tsi_file(file_path: str) -> Dict:
    """
    Reads a topology file [1]_.

    Parameters
    ----------
    file_path : str
        The path to the topology file to be parsed.

    Returns
    -------
    Dict
        Topology file content with keys ``version``, ``box``, ``n_vertices``,
        ``vertices``, ``n_faces`` and ``faces``, plus ``n_inclusions`` and
        ``inclusions`` when an inclusion section is present.

    References
    ----------
    .. [1] https://github.com/weria-pezeshkian/FreeDTS/wiki/Manual-for-version-1
    """
    from ._utils import _drop_prefix

    _keys = ("version", "box", "n_vertices", "vertices", "n_faces", "faces")
    ret = {k: None for k in _keys}

    with open(file_path, mode="r", encoding="utf-8") as infile:
        data = [x.strip() for x in infile.read().split("\n") if len(x.strip())]

    # Version prefix
    if "version" in data[0]:
        ret["version"] = data.pop(0).split()[1]

    # Box prefix
    box = _drop_prefix(data.pop(0).split(), 4)
    ret["box"] = tuple(float(x) for x in box)

    # Vertex prefix
    n_vertices = _drop_prefix(data.pop(0).split(), 2)
    n_vertices = int(n_vertices[0])
    vertices, data = data[:n_vertices], data[n_vertices:]
    ret["n_vertices"] = n_vertices
    ret["vertices"] = np.array([x.split() for x in vertices], dtype=np.float64)

    # Face prefix
    n_faces = _drop_prefix(data.pop(0).split(), 2)
    n_faces = int(n_faces[0])
    faces, data = data[:n_faces], data[n_faces:]
    ret["n_faces"] = n_faces
    ret["faces"] = np.array([x.split() for x in faces], dtype=np.float64)

    # Skip every line preceding the inclusion section, so an unrelated line is
    # never mistaken for the inclusion header.
    while data and not data[0].startswith("inclusion"):
        data.pop(0)

    if len(data) == 0:
        return ret

    n_inclusions = _drop_prefix(data.pop(0).split(), 2)
    n_inclusions = int(n_inclusions[0])
    incl, data = data[:n_inclusions], data[n_inclusions:]
    ret["n_inclusions"] = n_inclusions
    ret["inclusions"] = np.array([x.split() for x in incl], dtype=np.float64)
    return ret


def _read_vtu_file(file_path: str) -> Dict:
    """
    Parse a VTK XML file into a dictionary of numpy arrays.

    Parameters
    ----------
    file_path : str
        The path to the VTK XML file to be parsed.

    Returns
    -------
    Dict
        Content of the first ``Piece`` element: ``num_points``, ``num_cells``,
        ``point_data`` (name -> array), ``points`` and one entry per named
        ``DataArray`` found under ``Cells``.
    """
    with open(file_path, mode="r") as ifile:
        data = ifile.read()

    root = ET.fromstring(data)
    piece = root.find(".//Piece")
    result = {
        "num_points": int(piece.get("NumberOfPoints")),
        "num_cells": int(piece.get("NumberOfCells")),
        "point_data": {},
        "points": None,
        "connectivity": None,
        "offsets": None,
        "types": None,
    }

    # Parse point data arrays
    if (point_data := piece.find("PointData")) is not None:
        for array in point_data.findall("DataArray"):
            data_type = _parse_dtype(array)
            result["point_data"][array.get("Name")] = _parse_data_array(
                array, data_type
            )

    if (points_array := piece.find(".//Points/DataArray")) is not None:
        # The dtype must come from the points array itself, not from whichever
        # point-data array happened to be processed last.
        data_type = _parse_dtype(points_array)
        result["points"] = _parse_data_array(points_array, data_type)

    if (cells := piece.find("Cells")) is not None:
        for array in cells.findall("DataArray"):
            # Cell arrays are parsed as float regardless of their declared
            # type, matching the previous behaviour of this reader.
            result[array.get("Name")] = _parse_data_array(array, float)

    return result
def load_density(filename: str, **kwargs) -> "Density":
    """
    Load 3D density data from file.

    If every sampling rate reported by the file is (close to) zero, a warning
    is issued and the sampling rate is replaced by 1.

    Parameters
    ----------
    filename : str
        Path to density file.
    **kwargs
        Additional keyword arguments passed to Density.from_file.

    Returns
    -------
    Density
        Loaded density object.
    """
    from tme import Density

    volume = Density.from_file(filename, **kwargs)
    if np.allclose(volume.sampling_rate, 0):
        # Adjacent literals are concatenated verbatim, so the separating
        # space has to be part of the first fragment.
        warnings.warn(
            "All sampling rates are 0 - Setting them to 1 for now. Some functions might "
            "not behave properly. Make sure to define sampling rates if you forgot."
        )
        volume.sampling_rate = 1
    return volume