Source code for ptyrad.io.hierarchy

"""
Hierarchical file handling (load/save) for pt, mat, hdf5 formats

"""

import os
from typing import Any, List, Optional, Union
import logging

import h5py
import numpy as np
import scipy.io as sio

# Accepted forms of the `key` argument taken by the loaders in this module:
# a single key string, a list of key strings, or None.
KeyType = Union[str, list[str], None]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

def load_pt(file_path, weights_only=False):
    """Load the object(s) stored in a PyTorch ``.pt`` file.

    Warning:
        This function defaults to ``weights_only=False`` because PtyRAD .pt files
        often contain complex objects and dictionaries, not just state dictionaries.
        As of PyTorch 2.6 (2025.01.29), ``torch.load`` defaults to
        ``weights_only=True`` for security:
        https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
        Loading with ``weights_only=False`` can execute arbitrary code if the file
        contains a malicious payload, so only use this function on trusted,
        PtyRAD-generated files.

    Args:
        file_path (str): The path to the PyTorch .pt file.
        weights_only (bool, optional): Forwarded to ``torch.load``. If True, the
            unpickler is restricted to tensors, primitive types, and dictionaries.
            Defaults to False.

    Returns:
        Any: The deserialized Python object(s) stored in the file.

    Raises:
        FileNotFoundError: If the specified file does not exist.
    """
    # Validate the path first: a wrong path should fail immediately with the
    # documented FileNotFoundError instead of after the (slow) torch import.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"The specified file '{file_path}' does not exist. Please check your file path and working directory.")

    # Imported lazily inside the function, only once a load is actually going to happen.
    import torch

    # `weights_only` is passed explicitly because the torch default flipped to True in
    # PyTorch 2.6, which would break loading PtyRAD .pt files (they are not pure
    # state dicts). See the docstring warning about the security trade-off.
    return torch.load(file_path, weights_only=weights_only)
def load_mat(
    file_path: str,
    key: KeyType = None,
    delimiter: str = ".",
    squeeze_me=True,
    simplify_cells=True,
) -> Union[np.ndarray, dict[str, np.ndarray]]:
    """
    Read dataset(s) out of a MATLAB .mat file.

    The file version decides the backend: files reported as major version 2
    (v7.3, HDF5-based) and files whose version scipy cannot determine are handed
    to `load_hdf5`; everything else is read with `scipy.io.loadmat`.

    Args:
        file_path (str): Path to the .mat file.
        key (str | list[str] | None): Which dataset(s) to return.
            - None, '', or []: the whole file as a (nested) dict.
            - str: one dataset or group; may be hierarchical (e.g., 'group1.dataset1').
            - list[str]: several datasets, returned in a flat dict keyed by the
              requested key strings.
        delimiter (str): Separator between levels of a hierarchical key (default: ".").
        squeeze_me (bool): Passed to `scipy.io.loadmat`; not used on the HDF5 path.
        simplify_cells (bool): Passed to `scipy.io.loadmat`; not used on the HDF5 path.

    Returns:
        np.ndarray or dict: The loaded dataset(s), structured as described under `key`.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        KeyError: If provided key(s) are not found in the file.
        TypeError: If the key is not None, a string, or a list of strings.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"The specified file '{file_path}' does not exist. Please check your file path or working directory."
        )

    from scipy.io.matlab import matfile_version as get_matfile_version

    try:
        version = get_matfile_version(file_path)
    except ValueError as err:
        logger.warning(f"{err}. Switching to `load_hdf5` as it's probably not generated by MATLAB.")
        # No recognizable MATLAB header: treat it as an HDF5 file carrying a .mat extension
        version = (2, 0)

    # Major version 2 is the HDF5-based v7.3 format
    if version[0] == 2:
        logger.info("Detected .mat v7.3 (HDF5 format). Delegating to `load_hdf5`.")
        return load_hdf5(file_path, key=key, delimiter=delimiter)

    logger.info("Detected .mat version less than v7.3. Using `scipy.io.loadmat`.")

    # The whole file is read up front; key selection happens on the resulting dict
    contents = sio.loadmat(file_path, squeeze_me=squeeze_me, simplify_cells=simplify_cells)

    # No key requested -> return everything
    if key in (None, "", []):
        return contents

    # Single (possibly hierarchical) key
    if isinstance(key, str):
        return get_nested(contents, key=key, delimiter=delimiter)

    if not isinstance(key, list):
        raise TypeError(
            f"`key` must be None, a string, or a list of strings but got key = '{key}'"
        )

    if not all(isinstance(k, str) for k in key):
        raise TypeError(
            f"All elements in 'key' list must be strings, got {[type(k).__name__ for k in key]}"
        )

    # Several keys: gather what exists and report every missing key in one error
    found = {}
    not_found = []
    for name in key:
        try:
            found[name] = get_nested(contents, key=name, delimiter=delimiter)
        except KeyError:
            not_found.append(name)

    if not_found:
        raise KeyError(
            f"Key(s) = {not_found} not found. "
            f"Available key(s) in this mat file are {list_nested_keys(contents)}. "
            "Tip: If you don't know the correct key, try 'key=None' to load the entire file as a dict."
        )
    return found
def load_hdf5(file_path: str, key: KeyType = None, delimiter: str = ".") -> Union[np.ndarray, dict[str, np.ndarray]]:
    """
    Read dataset(s) from an HDF5 file, descending into groups recursively.

    Args:
        file_path (str): Path to the HDF5 file.
        key (str | list[str] | None): Which object(s) to load.
            - None, '', or []: every top-level entry, loaded recursively into a
              nested dict that mirrors the file's group structure.
            - str: a single dataset or group. Hierarchical keys are split on
              `delimiter` (e.g., 'group1.dataset1') and followed level by level.
            - list[str]: several datasets/groups. The result is a flat dict whose
              keys are the requested key strings.
        delimiter (str): Separator between levels of a hierarchical key (default: ".").

    Returns:
        np.ndarray or dict:
            - For a string key: the dataset value (passed through
              `handle_hdf5_types`), or a nested dict when the key names a group.
            - For a list of keys: a dict mapping each requested key string to its value.
            - For no key: a nested dict of the whole file.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        KeyError: If provided key(s) are not found in the file.
        TypeError: If the key is not None, a string, or a list of strings, or if
            an object in the file is neither a dataset nor a group.

    Notes:
        With an explicit key only the addressed object (and, for a group, its
        members) is read; with no key the entire file is read into memory.
    """

    def _materialize(node):
        """Read a Dataset into memory, or a Group into a nested dict of its members."""
        if isinstance(node, h5py.Dataset):
            return handle_hdf5_types(node[()])
        if isinstance(node, h5py.Group):
            return {name: _materialize(node[name]) for name in node}
        raise TypeError(f"Unsupported HDF5 object type: {type(node)}")

    def _fetch(root, path):
        """Walk `path` down from `root` one level at a time, then load what it points to."""
        node = root
        for part in path.split(delimiter):
            is_container = isinstance(node, (h5py.Group, h5py.File))
            if not (is_container and part in node):
                raise KeyError(
                    f"Key '{path}' not found. Failed at '{part}'. "
                    f"Available key(s) in this HDF5 file are {list_nested_keys(root)}. "
                    "Tip: If you don't know the correct key, try 'key=None' to load the entire file as a dict."
                )
            node = node[part]
        return _materialize(node)

    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"The specified file '{file_path}' does not exist. Please check your file path or working directory."
        )

    with h5py.File(file_path, "r") as hf:
        # No key requested -> load every top-level entry
        if key in (None, "", []):
            return {name: _materialize(hf[name]) for name in hf.keys()}

        # Single (possibly hierarchical) key
        if isinstance(key, str):
            return _fetch(hf, key)

        if not isinstance(key, list):
            raise TypeError(
                f"`key` must be None, a string, or a list of strings but got key = '{key}'"
            )

        if not all(isinstance(k, str) for k in key):
            raise TypeError(
                f"All elements in 'key' list must be strings, got {[type(k).__name__ for k in key]}"
            )

        # Several keys: gather what exists and report every missing key in one error
        loaded = {}
        absent = []
        for path in key:
            try:
                loaded[path] = _fetch(hf, path)
            except KeyError:
                absent.append(path)

        if absent:
            raise KeyError(
                f"Key(s) = {absent} not found. Available key(s) in this HDF5 file are {list_nested_keys(hf)}. "
                "Tip: If you don't know the correct key, try 'key=None' to load the entire file as a dict."
            )
        return loaded
def write_hdf5(file_path, data, dataset_name="meas", **kwargs):
    """
    Save an array as a single dataset in an HDF5 file.

    The file is opened in 'w' mode, so an existing file at `file_path` is overwritten.

    Args:
        file_path (str): Destination path of the HDF5 file.
        data: The array to store; passed as `data=` to `create_dataset`.
        dataset_name (str): Name of the dataset inside the file (default: "meas").
        **kwargs: Extra keyword arguments forwarded to `h5py` `create_dataset`.
            `compression` defaults to "gzip" when not given and can be overridden
            here (e.g., `compression=None` or `compression="lzf"`).
    """
    # gzip is a default, not a fixed argument: hard-coding it next to **kwargs made
    # any caller-supplied `compression` fail with "multiple values for keyword argument".
    kwargs.setdefault("compression", "gzip")

    with h5py.File(file_path, "w") as hf:  # 'w' will override if the file already exists
        hf.create_dataset(dataset_name, data=data, **kwargs)
def load_ND_with_key(
    file_path: str,
    key: Optional[str] = None,
    ndims: Optional[List[int]] = None,
) -> np.ndarray:
    """
    Return exactly one ND array from a (possibly nested) .mat or .h5/.hdf5 file.

    Args:
        file_path (str): Path to the file. The loader is picked from the
            (case-insensitive) extension.
        key (str, optional): Key of the dataset to load. When None or '', the
            whole file is loaded and searched for arrays whose ndim is in `ndims`;
            exactly one such array must exist.
        ndims (list, optional): Accepted dimensionalities for the search
            (default: [3, 4]). Only used when no key is given.

    Returns:
        numpy.ndarray: The loaded dataset.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        ValueError: If the file type is unsupported, if the key does not point to
            an ndarray, or if the search finds zero or multiple eligible datasets.
        TypeError: If `key` is neither None nor a string.
    """
    if ndims is None:
        ndims = [3, 4]

    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"The specified file '{file_path}' does not exist. Please check your file path and working directory."
        )

    # Pick the loader from the lower-cased file extension
    suffix = os.path.splitext(file_path)[1].lower()
    if suffix not in (".mat", ".h5", ".hdf5"):
        raise ValueError(
            f"Unsupported file type: '{suffix}'. Supported types are .mat, .h5, .hdf5."
        )
    loader = load_mat if suffix == ".mat" else load_hdf5

    # No key: load the whole file and search it recursively for eligible arrays
    if key in (None, ""):
        candidates = collect_ND_datasets(loader(file_path), ndims=ndims)
        n_found = len(candidates)
        if n_found == 0:
            raise ValueError(
                f"No eligible datasets found in file with ndims = {ndims}. Please check the file and file path."
            )
        if n_found > 1:
            raise ValueError(
                f"Multiple eligible ND datasets found: {list(candidates.keys())}. Please specify the dataset key explicitly."
            )
        (only_dataset,) = candidates.values()
        return only_dataset

    if not isinstance(key, str):
        raise TypeError(f"`key` must be None or a string, but got key = '{key}'")

    # A string key normally yields an ndarray, but a wrong key may point to a group or anything else
    loaded = loader(file_path, key)
    if not isinstance(loaded, np.ndarray):
        raise ValueError(
            f"The returned value at key '{key}' is not an ndarray dataset, got type = {type(loaded).__name__}. "
            "If you don't know the correct dataset key, try 'key=None' to search for eligible ND datasets from the entire file."
        )
    return loaded
def collect_ND_datasets(
    data_dict: dict[str, Any],
    ndims: Optional[list[int]] = None,
    delimiter: str = ".",
    _parent_key: Optional[str] = None,
) -> dict[str, np.ndarray]:
    """
    Collect ND numpy arrays from a (possibly nested) dictionary that match desired dimensionalities.

    Nested dictionaries are traversed recursively and the matching arrays are
    returned in a flat dict whose keys are the full paths joined with `delimiter`.
    Values that are neither ndarrays nor dicts are ignored.

    Args:
        data_dict (dict): Dictionary of datasets (flat or nested).
        ndims (list of int, optional): Desired dimensionalities to match
            (default: [3, 4]).
        delimiter (str): String used to separate the levels of the full path to
            a dataset (default: ".").
        _parent_key (str, optional): **Internal use only.** Tracks nested keys
            during recursion. Do not set manually.

    Returns:
        dict[str, np.ndarray]: Matching datasets with flattened hierarchical keys.
            Empty if nothing matches.

    Raises:
        ValueError: If input is not a dict.
    """
    if not isinstance(data_dict, dict):
        raise ValueError("Input must be a dictionary containing datasets.")

    if ndims is None:
        ndims = [3, 4]

    results: dict[str, np.ndarray] = {}

    for key, val in data_dict.items():
        full_key = f"{_parent_key}{delimiter}{key}" if _parent_key else key

        if isinstance(val, np.ndarray):
            if val.ndim in ndims:
                results[full_key] = val
        elif isinstance(val, dict):
            # Forward `delimiter` too: without it every nested level silently fell
            # back to "." and the caller's delimiter never took effect.
            results.update(
                collect_ND_datasets(
                    val, ndims=ndims, delimiter=delimiter, _parent_key=full_key
                )
            )

    # Report only from the outermost call; logging at every recursion level
    # repeated the same nested datasets once per enclosing dict.
    if results and not _parent_key:
        logger.info(f"Found the following ND datasets with ndim in {ndims}:")
        for k, arr in results.items():
            logger.info(f"  Key: '{k}', Shape: {arr.shape}, Dtype: {arr.dtype}")

    return results
def handle_hdf5_types(x):
    """
    Convert data to native Python or NumPy types. Especially when loaded by h5py.

    The conversions are applied in this order:
        1. NumPy scalars and 0-dimensional arrays become Python scalars.
        2. `bytes` are decoded as UTF-8; undecodable bytes are returned unchanged.
        3. The sentinel string "__NONE__" becomes None.
        4. Arrays with the compound dtype [('real', '<f8'), ('imag', '<f8')]
           (MATLAB-style complex data) are viewed as complex128.
        5. 1D arrays of str or object dtype become a Python list of str.
        6. Strings are parsed with `ast.literal_eval`; strings that are not valid
           Python literals are returned unchanged.

    Args:
        x: The input data to be converted.

    Returns:
        The converted data into native Python or NumPy types.
    """
    # Handle scalar Numpy types
    if isinstance(x, np.generic):
        x = x.item()

    # Handle 0-dimensional Numpy arrays (convert to Python scalars) as they were probably forced by HDF5
    if isinstance(x, np.ndarray) and x.ndim == 0:
        x = x.item()

    # Handle bytes (e.g., HDF5 strings or sentinel)
    if isinstance(x, bytes):
        try:
            x = x.decode('utf-8')
        except UnicodeDecodeError:
            return x  # Leave undecodable bytes unchanged

    # Convert sentinel string to None — only safe for scalar strings
    if isinstance(x, str) and x == "__NONE__":
        return None

    # Handle MATLAB-style complex128 compound dtype
    if isinstance(x, np.ndarray) and x.dtype == [('real', '<f8'), ('imag', '<f8')]:
        logger.info(f"Detected data.shape = {x.shape} with data.dtype = {x.dtype}. Casting back to 'complex128'.")
        return x.view(np.complex128)

    # Convert 1D array of strings (or object-dtype strings) to Python list of str
    if isinstance(x, np.ndarray) and x.ndim == 1:
        if np.issubdtype(x.dtype, np.str_) or np.issubdtype(x.dtype, np.object_):
            try:
                return [i.decode('utf-8') if isinstance(i, bytes) else str(i) for i in x]
            except Exception:
                # Deliberate best effort: any element that cannot be converted
                # means the array is returned as-is below.
                pass

    # Try parsing stringified literals
    if isinstance(x, str):
        import ast
        try:
            return ast.literal_eval(x)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval raises more than ValueError/SyntaxError on malformed
            # input (e.g. TypeError for "{[1]: 2}", RecursionError for deep nesting).
            # A stored string that merely looks like broken Python must not abort
            # the load; keep it as a plain string.
            pass

    return x
def get_nested(d, key, delimiter='.', safe=False, default=None):
    """
    Get a value from a nested dictionary, either safely (return `default` if not
    found) or strictly (raise to fail early).

    Args:
        d (dict): The dictionary to traverse.
        key (str, or list or tuple of str): The path to the value. A string is
            split on `delimiter`; a list/tuple is used as the sequence of keys.
        delimiter (str): The string that separates the levels of a string key.
            Also used when listing the available key paths in the strict-mode error.
        safe (bool): Switches between safe and strict lookup.
        default: Value returned in safe mode when a key is missing, when an
            intermediate value is not a dict, or when a value on the path is None.

    Returns:
        The nested value if found; otherwise `default` in safe mode.

    Raises:
        ValueError: If `key` is empty.
        TypeError: If `key` is not a str or a list/tuple of str.
        KeyError: In strict mode, if the path cannot be followed to the end
            (missing key, or a non-dict value reached before the path ends).
    """
    if not key:
        raise ValueError("Please specify a non-empty 'key' to get the value from a nested dict.")

    # Parse the input key (str with delimiter, or sequence of strings)
    if isinstance(key, str):
        parts = key.split(delimiter)
    elif isinstance(key, (tuple, list)):
        if not all(isinstance(k, str) for k in key):
            raise TypeError(
                f"All elements in 'key' must be strings, got {[type(k).__name__ for k in key]}"
            )
        parts = key
    else:
        raise TypeError(f"'key' must be a str, or a sequence (list, tuple) of strings, got {type(key).__name__}.")

    # Getting value safely with a default return
    if safe:
        node = d
        for k in parts:
            if not isinstance(node, dict):
                return default
            node = node.get(k)
            if node is None:
                return default
        return node

    # Getting value strictly with raised error
    node = d
    for k in parts:
        if not isinstance(node, dict) or k not in node:
            # List the keys of the ROOT dict, not of the node where the walk stopped:
            # that node may not be a dict at all (listing it raised ValueError and
            # masked the KeyError), and root-relative paths are what the caller can
            # actually pass back in as `key`.
            available = list_nested_keys(d, delimiter=delimiter) if isinstance(d, dict) else []
            raise KeyError(
                f"Key '{key}' not found. Failed at '{k}'. "
                f"Available key(s) in this nested dict are {available}. "
                "Tip: If you don't know the correct key, use `print_nested_dict()` from `ptyrad.io.hdf5` to check your nested dict first."
            )
        node = node[k]
    return node
def list_nested_keys(hobj, delimiter=".", prefix=""):
    """
    Recursively list the full key paths of all leaf entries in an HDF5 file,
    HDF5 group, or dict.

    Nested containers (HDF5 groups, or dicts when the input is a dict) are
    descended into and do not appear as entries themselves, so an empty nested
    container contributes nothing to the result.

    Args:
        hobj (h5py.File, h5py.Group, or dict): The hierarchical object to traverse.
        delimiter (str): The string placed between the levels of a key path.
        prefix (str): The path accumulated so far (used for recursion).

    Returns:
        list[str]: All leaf keys with their hierarchical paths.

    Raises:
        ValueError: If `hobj` is not an HDF5 file, HDF5 group, or dict.
    """
    # Decide which type counts as "a nested level" for this kind of container
    if isinstance(hobj, dict):
        container_type = dict
    elif isinstance(hobj, (h5py.Group, h5py.File)):
        container_type = h5py.Group
    else:
        raise ValueError(f"Expected hobj is an HDF5 file, HDF5 group, or a dict, got {type(hobj).__name__}.")

    collected = []
    for name in hobj.keys():
        if prefix == "":
            path = f"{prefix}{name}"
        else:
            path = f"{prefix}{delimiter}{name}"

        child = hobj[name]
        if not isinstance(child, container_type):
            # Leaf entry (dataset / plain value)
            collected.append(path)
            continue

        # Nested group / dict: collect its leaves under the extended path
        collected.extend(list_nested_keys(child, delimiter=delimiter, prefix=path))

    return collected