Source code for mosaic.formats.parser

"""
IO methods to parse a variety of file formats.

Copyright (c) 2024 European Molecular Biology Laboratory

Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import warnings
from string import ascii_lowercase

from typing import List, Dict
from dataclasses import dataclass
import xml.etree.ElementTree as ET

import numpy as np
import open3d as o3d
from scipy.spatial.transform import Rotation
from tme import Density, Structure, Orientations

from ..meshing.utils import to_open3d
from ..utils import volume_to_points, compute_bounding_box, NORMAL_REFERENCE


def _drop_prefix(iterable, target_length: int):
    """
    Remove a leading keyword token from a split line, in place.

    Parameters
    ----------
    iterable : list
        Tokens of one line, e.g. the result of ``str.split()``.
    target_length : int
        Token count at which the first token is considered a prefix.

    Returns
    -------
    list
        The same list object. Its first element has been popped if, and only
        if, its length equalled ``target_length``.
    """
    if len(iterable) != target_length:
        return iterable
    iterable.pop(0)
    return iterable


def _parse_data_array(data_array: ET.Element, dtype: type = float) -> np.ndarray:
    """
    Parse the ASCII payload of a VTK XML ``DataArray`` element into a numpy array.

    Parameters
    ----------
    data_array : xml.etree.ElementTree.Element
        ``DataArray`` element whose text holds whitespace-separated numbers.
    dtype : type, optional
        Callable used to convert each token, e.g. ``float`` or ``int``.

    Returns
    -------
    np.ndarray
        The parsed values with length-1 axes squeezed away.
        - If the element declares ``NumberOfComponents``, values are grouped into
          tuples of that width, regardless of where the line breaks fall.
        - Otherwise each non-empty text line becomes one row.
        - An element without text yields an empty array.
    """
    # ``Element.text`` is None for empty elements such as ``<DataArray ... />``.
    text = data_array.text or ""
    rows = [row.split() for row in text.strip().split("\n") if row.strip()]

    n_components = data_array.get("NumberOfComponents")
    if n_components is not None and rows:
        # VTK writers wrap ASCII data at arbitrary widths, so line breaks need not
        # coincide with tuple boundaries; rebuild tuples from the declared width.
        values = np.array([dtype(token) for row in rows for token in row])
        data = values.reshape(-1, int(n_components))
    else:
        data = np.array([[dtype(token) for token in row] for row in rows])
    return np.squeeze(data)


def _parse_dtype(xml_element) -> object:
    """
    Pick the scalar constructor matching a VTK ``DataArray`` ``type`` attribute.

    Parameters
    ----------
    xml_element : xml.etree.ElementTree.Element
        Element whose ``type`` attribute is inspected.

    Returns
    -------
    type
        ``float`` when the type name starts with ``"Float"``. Otherwise ``int``,
        which includes the case where the attribute is missing.
    """
    type_name = xml_element.get("type", "")
    if type_name.startswith("Float"):
        return float
    return int


@dataclass
class GeometryData:
    """
    A single geometry: one vertex set plus optional per-vertex attributes.

    Every field defaults to ``None`` except ``sampling``, which defaults to
    ``(1, 1, 1)``. ``GeometryDataContainer.__getitem__`` builds these from one
    entry of each of its per-set lists.
    """

    # Point coordinates of this geometry.
    vertices: np.ndarray = None
    # Per-vertex normal vectors.
    normals: np.ndarray = None
    # Face connectivity when the geometry is a mesh; None for point clouds.
    faces: np.ndarray = None
    # Orientation quaternions; the orientation readers emit scalar-first order.
    quaternions: np.ndarray = None
    # Extent of the enclosing box/volume (presumably one entry per axis).
    shape: List[int] = None
    # Sampling rate, presumably one entry per axis.
    sampling: List[float] = (1, 1, 1)
@dataclass
class GeometryDataContainer:
    """
    A collection of vertex sets that share one ``shape`` and ``sampling``.

    Each attribute except ``shape`` and ``sampling`` is a list holding one array
    per vertex set. On construction the data is normalised:

    - provided arrays are cast: vertices, normals and quaternions to
      ``np.float32``, faces to ``int``;
    - missing normals are filled with ``NORMAL_REFERENCE`` for every vertex;
    - normals are rescaled to unit length, and near-zero normals are replaced by
      ``NORMAL_REFERENCE`` first;
    - a missing ``shape`` is derived via ``compute_bounding_box(vertices)``;
    - missing quaternions become one ``None`` per vertex set.

    Raises
    ------
    ValueError
        If the number of normal or face sets differs from the number of vertex
        sets.
    """

    vertices: List[np.ndarray] = None
    normals: List[np.ndarray] = None
    faces: List[np.ndarray] = None
    quaternions: List[np.ndarray] = None
    shape: List[int] = None
    sampling: List[float] = (1, 1, 1)

    def __post_init__(self):
        # Cast each provided list of arrays to its canonical dtype; None stays None.
        casts = (
            ("vertices", np.float32),
            ("normals", np.float32),
            ("faces", int),
            ("quaternions", np.float32),
        )
        for field_name, target_dtype in casts:
            current = getattr(self, field_name)
            setattr(self, field_name, self._to_dtype(current, target_dtype))

        if self.normals is None:
            self.normals = [
                np.full_like(points, fill_value=NORMAL_REFERENCE)
                for points in self.vertices
            ]

        for idx, normal_set in enumerate(self.normals):
            lengths = np.linalg.norm(normal_set, axis=1)
            # Near-zero normals cannot be normalised; substitute the reference
            # direction and divide those rows by 1 instead.
            degenerate = lengths < 1e-12
            lengths[degenerate] = 1
            normal_set[degenerate] = NORMAL_REFERENCE
            self.normals[idx] = normal_set / lengths[:, None]

        if self.shape is None:
            self.shape, _ = compute_bounding_box(self.vertices)

        if len(self.vertices) != len(self.normals):
            raise ValueError("Normals need to be specified for each vertex set.")

        if self.faces is not None and len(self.vertices) != len(self.faces):
            raise ValueError("Faces need to be specified for each vertex set.")

        if self.quaternions is None:
            self.quaternions = [None] * len(self.vertices)

    def __len__(self):
        """Return the number of vertex sets."""
        return len(self.vertices)

    def __iter__(self):
        """Yield one ``GeometryData`` per vertex set (all built up front)."""
        snapshot = [self[k] for k in range(len(self))]
        yield from snapshot

    def __getitem__(self, index: int) -> GeometryData:
        """Return vertex set ``index`` together with its attributes."""
        selected_faces = None if self.faces is None else self.faces[index]
        return GeometryData(
            vertices=self.vertices[index],
            normals=self.normals[index],
            shape=self.shape,
            sampling=self.sampling,
            faces=selected_faces,
            quaternions=self.quaternions[index],
        )

    @staticmethod
    def _to_dtype(data: List[np.ndarray], dtype=np.float32):
        """Return a new list with each array cast to ``dtype``; None passes through."""
        if data is None:
            return data
        return [arr.astype(dtype) for arr in data]
def _read_orientations(filename):
    """
    Read particle orientations and express them as one oriented point set.

    Parameters
    ----------
    filename : str
        Path to a file readable by ``tme.Orientations.from_file``.

    Returns
    -------
    dict
        Keyword arguments for ``GeometryDataContainer`` with one vertex set:

        - ``vertices``: the translations;
        - ``normals``: ``NORMAL_REFERENCE`` rotated by each orientation;
        - ``quaternions``: scalar-first quaternions.
    """
    data = Orientations.from_file(filename)
    # Rotations are read as "zyz" Euler angles in degrees (lowercase axis string,
    # i.e. scipy's extrinsic convention).
    angles = Rotation.from_euler("zyz", data.rotations, degrees=True)
    normals = angles.apply(NORMAL_REFERENCE)
    return {
        "vertices": [data.translations],
        "normals": [normals],
        "quaternions": [angles.as_quat(scalar_first=True)],
    }


def read_star(filename):
    """Read a STAR orientation file into a ``GeometryDataContainer``."""
    return GeometryDataContainer(**_read_orientations(filename))


def read_txt(filename: str):
    """
    Read point coordinates, and optionally normals, from a delimited text file.

    The delimiter is chosen from the file name:

    - names ending in ``xyz`` or ``csv`` are split on commas;
    - names ending in ``txt`` are split on tabs;
    - anything else is split on arbitrary whitespace.

    If the first row contains ``x``, it is used as the header. Otherwise columns
    are named ``x, y, z, a, b, ...`` by position.

    Optional columns:

    - ``nx``/``ny``/``nz`` supply normals;
    - ``id`` splits the points into one vertex set per unique id.

    Parameters
    ----------
    filename : str
        Path to the text file.

    Returns
    -------
    GeometryDataContainer
        One vertex set per cluster (or a single set when there is no ``id``).

    Raises
    ------
    ValueError
        If the file has no data rows, or lacks the ``x``, ``y`` or ``z`` columns.
    """
    delimiter = None
    if filename.endswith(("xyz", "csv")):
        delimiter = ","
    elif filename.endswith("txt"):
        delimiter = "\t"

    with open(filename, mode="r") as ifile:
        lines = ifile.read().splitlines()

    # Split line-by-line. Splitting the whole file on whitespace would also break
    # rows apart at tabs/spaces, turning every single value into its own row.
    rows = [
        [field.strip() for field in line.strip().split(delimiter)]
        for line in lines
        if line.strip()
    ]
    if not rows:
        raise ValueError(f"No data rows found in {filename}.")

    header = ("x", "y", "z", *ascii_lowercase)[: len(rows[0])]
    if "x" in rows[0]:
        header = rows.pop(0)
    if not rows:
        raise ValueError(f"No data rows found in {filename}.")

    required_columns = ("x", "y", "z")
    if any(column not in header for column in required_columns):
        raise ValueError(f"Colums {required_columns} are required.")

    # Values stay strings here; GeometryDataContainer casts them to float32.
    columns = {c: np.asarray(d) for c, d in zip(header, zip(*rows))}
    if "id" in columns:
        ids = columns["id"]
        clusters = [
            {c: d[ids == cluster] for c, d in columns.items()}
            for cluster in np.unique(ids)
        ]
    else:
        clusters = [columns]

    vertex_cols, normal_cols = ("x", "y", "z"), ("nx", "ny", "nz")
    vertices, normals = [], []
    for cluster in clusters:
        vertices.append(np.hstack([cluster[k][:, None] for k in vertex_cols]))
        if all(k in cluster for k in normal_cols):
            normals.append(np.hstack([cluster[k][:, None] for k in normal_cols]))

    return GeometryDataContainer(vertices=vertices, normals=normals or None)


def read_tsv(filename: str) -> GeometryDataContainer:
    """
    Read a TSV file as orientations or as plain points.

    If the first line mentions ``euler``, the file is read as orientations.
    Otherwise it is delegated to ``read_txt``.
    """
    with open(filename, mode="r") as infile:
        header = infile.readline()
    if "euler" not in header:
        return read_txt(filename)
    return GeometryDataContainer(**_read_orientations(filename))


def read_tsi(filename: str) -> GeometryDataContainer:
    """Read a FreeDTS topology (``.tsi``) file as a triangle mesh."""
    data = _read_tsi_file(filename)
    # Columns 1-3 are used as coordinates/vertex ids (column 0 is presumably the
    # row index in the TSI layout).
    return _return_mesh(to_open3d(data["vertices"][:, 1:4], data["faces"][:, 1:4]))


def read_vtu(filename: str) -> GeometryDataContainer:
    """Read an ASCII VTK XML unstructured grid (``.vtu``) as a mesh."""
    data = _read_vtu_file(filename)
    return _return_mesh(to_open3d(data["points"], data["connectivity"]))


def read_mesh(filename: str) -> GeometryDataContainer:
    """Read any triangle mesh format supported by Open3D."""
    return _return_mesh(o3d.io.read_triangle_mesh(filename))


def _return_mesh(mesh: o3d.geometry.TriangleMesh) -> GeometryDataContainer:
    """Wrap an Open3D mesh, with freshly computed vertex normals, as a container."""
    mesh.compute_vertex_normals()
    vertices = np.asarray(mesh.vertices)
    faces = np.asarray(mesh.triangles)
    normals = np.asarray(mesh.vertex_normals)
    return GeometryDataContainer(vertices=[vertices], faces=[faces], normals=[normals])


def read_structure(filename: str) -> GeometryDataContainer:
    """Read atom coordinates of a structure file as a single point set."""
    data = Structure.from_file(filename)
    return GeometryDataContainer(vertices=[data.atom_coordinate])


def read_volume(filename: str):
    """
    Read a density map and convert it into point sets via ``volume_to_points``.

    ``shape`` is the map's shape multiplied by its sampling rate.
    """
    volume = load_density(filename)
    ret = volume_to_points(volume.data, volume.sampling_rate, reverse_order=True)
    shape = np.multiply(volume.shape, volume.sampling_rate)
    return GeometryDataContainer(
        vertices=ret, shape=shape, sampling=volume.sampling_rate
    )


def _read_tsi_file(file_path: str) -> Dict:
    """
    Reads a topology file [1]_.

    Parameters
    ----------
    file_path : str
        The path to the topology file to be parsed.

    Returns
    -------
    Dict
        Topology file content. The keys are ``version``, ``box``, ``n_vertices``,
        ``vertices``, ``n_faces`` and ``faces``. If an inclusion section is
        present, ``n_inclusions`` and ``inclusions`` are added.

    References
    ----------
    .. [1] https://github.com/weria-pezeshkian/FreeDTS/wiki/Manual-for-version-1
    """
    _keys = ("version", "box", "n_vertices", "vertices", "n_faces", "faces")
    ret = {k: None for k in _keys}
    with open(file_path, mode="r", encoding="utf-8") as infile:
        data = [x.strip() for x in infile.read().split("\n") if len(x.strip())]

    # Optional version line; its second token is stored.
    if "version" in data[0]:
        ret["version"] = data.pop(0).split()[1]

    # Box line: a leading keyword is dropped when the line has four tokens.
    box = _drop_prefix(data.pop(0).split(), 4)
    ret["box"] = tuple(float(x) for x in box)

    # Vertex section: a count line followed by that many vertex rows.
    n_vertices = int(_drop_prefix(data.pop(0).split(), 2)[0])
    vertices, data = data[:n_vertices], data[n_vertices:]
    ret["n_vertices"] = n_vertices
    ret["vertices"] = np.array([x.split() for x in vertices], dtype=np.float64)

    # Face section: a count line followed by that many face rows.
    n_faces = int(_drop_prefix(data.pop(0).split(), 2)[0])
    faces, data = data[:n_faces], data[n_faces:]
    ret["n_faces"] = n_faces
    ret["faces"] = np.array([x.split() for x in faces], dtype=np.float64)

    # Skip every line up to the optional inclusion section. The previous loop
    # dropped at most one line, so further trailing lines were misparsed as the
    # inclusion count.
    while data and not data[0].startswith("inclusion"):
        data.pop(0)
    if not data:
        return ret

    n_inclusions = int(_drop_prefix(data.pop(0).split(), 2)[0])
    incl, data = data[:n_inclusions], data[n_inclusions:]
    ret["n_inclusions"] = n_inclusions
    ret["inclusions"] = np.array([x.split() for x in incl], dtype=np.float64)
    return ret


def _read_vtu_file(file_path: str) -> Dict:
    """
    Parse a VTK XML file into a dictionary of numpy arrays.

    Only inline ASCII payloads are handled: each ``DataArray`` text is parsed as
    whitespace-separated numbers, and the element's ``type`` attribute selects
    float or int.

    Parameters
    ----------
    file_path : str
        The path to the VTK XML file to be parsed.

    Returns
    -------
    Dict
        The dictionary contains:

        - ``num_points`` and ``num_cells``;
        - ``point_data``: a name-to-array mapping;
        - ``points``;
        - one entry per ``Cells`` DataArray keyed by its ``Name`` (typically
          ``connectivity``, ``offsets`` and ``types``).

    Raises
    ------
    ValueError
        If the file contains no ``Piece`` element.
    """
    # ET.parse reads bytes and honours the XML encoding declaration.
    root = ET.parse(file_path).getroot()
    piece = root.find(".//Piece")
    if piece is None:
        raise ValueError(f"No <Piece> element found in {file_path}.")

    result = {
        "num_points": int(piece.get("NumberOfPoints")),
        "num_cells": int(piece.get("NumberOfCells")),
        "point_data": {},
        "points": None,
        "connectivity": None,
        "offsets": None,
        "types": None,
    }

    if (point_data := piece.find("PointData")) is not None:
        for array in point_data.findall("DataArray"):
            result["point_data"][array.get("Name")] = _parse_data_array(
                array, _parse_dtype(array)
            )

    # Use the Points element's own type. The previous code read the stale loop
    # variable from the PointData loop, which raised NameError when there was no
    # PointData and picked the wrong dtype otherwise.
    if (points_array := piece.find(".//Points/DataArray")) is not None:
        result["points"] = _parse_data_array(points_array, _parse_dtype(points_array))

    # Cell arrays keep their declared dtype, so integer connectivity stays int.
    if (cells := piece.find("Cells")) is not None:
        for array in cells.findall("DataArray"):
            result[array.get("Name")] = _parse_data_array(array, _parse_dtype(array))

    return result
def load_density(filename: str, **kwargs) -> Density:
    """
    Load a density map, falling back to unit sampling when none is recorded.

    Parameters
    ----------
    filename : str
        Path to a file readable by ``tme.Density.from_file``.
    **kwargs
        Forwarded unchanged to ``Density.from_file``.

    Returns
    -------
    Density
        The loaded volume. If all sampling rates are close to 0, ``sampling_rate``
        is set to 1 and a ``UserWarning`` is emitted.
    """
    volume = Density.from_file(filename, **kwargs)
    if np.allclose(volume.sampling_rate, 0):
        # stacklevel=2 attributes the warning to the caller that loaded the file.
        warnings.warn(
            "All sampling rates are 0 - Setting them to 1 for now. Some functions "
            "might not behave properly. Make sure to define sampling rates if you "
            "forgot.",
            stacklevel=2,
        )
        volume.sampling_rate = 1
    return volume