"""
Hierarchical file handling (load/save) for pt, mat, hdf5 formats
"""
import os
from typing import Any, List, Optional, Union
import logging
import h5py
import numpy as np
import scipy.io as sio
# Accepted forms of the `key` selector used by the loaders below:
# a single key path, a list of key paths, or None.
KeyType = Union[str, list[str], None]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
[docs]
def load_pt(file_path, weights_only=False):
    """Deserialize the contents of a PyTorch .pt file.

    Warning:
        `weights_only` defaults to False here because PtyRAD .pt files
        often hold complex objects and dictionaries rather than plain state
        dictionaries. Since PyTorch 2.6, `torch.load` itself defaults to
        `weights_only=True` for security reasons. Unpickling with
        `weights_only=False` can execute arbitrary code when the file carries
        a malicious payload, so only use this function on trusted,
        PtyRAD-generated files.

    Args:
        file_path (str): Path to the .pt file.
        weights_only (bool, optional): Forwarded to `torch.load`. If True, the
            unpickler is restricted to tensors, primitive types, and
            dictionaries. Defaults to False.

    Returns:
        Any: Whatever object(s) were serialized into the file.

    Raises:
        FileNotFoundError: If `file_path` does not exist.
    """
    import torch

    if os.path.exists(file_path):
        # torch.load defaults to `weights_only=True` since PyTorch 2.6 (2025.01.29):
        # https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
        # A PtyRAD .pt file is not a true PyTorch model, so that default would break loading;
        # the value is therefore always passed explicitly. See the warning in the docstring.
        return torch.load(file_path, weights_only=weights_only)

    raise FileNotFoundError(f"The specified file '{file_path}' does not exist. Please check your file path and working directory.")
[docs]
def load_mat(
    file_path: str, key: KeyType = None, delimiter: str = ".",
    squeeze_me=True, simplify_cells=True
) -> Union[np.ndarray, dict[str, np.ndarray]]:
    """
    Load dataset(s) from a MATLAB .mat file, covering both the classic and the v7.3 (HDF5) formats.

    The file version decides the backend: v7.3 files (and files whose version
    scipy cannot determine) are handed to `load_hdf5`; everything else goes
    through `scipy.io.loadmat`.

    Parameters:
        file_path (str): Path to the .mat file.
        key (str | list[str] | None): Name(s) of the dataset(s) to load.
            - If None, '', or []: Return everything, keeping the nested structure.
            - If str: Return one dataset or group. Hierarchical keys (e.g., 'group1.dataset1') are supported.
            - If list[str]: Return several datasets in a flat dict keyed by the requested key strings.
        delimiter (str): Separator between levels of a hierarchical key (default: ".").
        squeeze_me (bool): Forwarded to `scipy.io.loadmat` (non-HDF5 files only).
        simplify_cells (bool): Forwarded to `scipy.io.loadmat` (non-HDF5 files only).

    Returns:
        data (np.ndarray or dict): The loaded dataset(s), structured like the output of `load_hdf5`.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        KeyError: If provided key(s) are not found in the file.
        TypeError: If the key is not None, a string, or a list of strings.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"The specified file '{file_path}' does not exist. Please check your file path or working directory."
        )

    # Probe the file header to find out which MATLAB format we are dealing with
    from scipy.io.matlab import matfile_version as get_matfile_version
    try:
        major_version = get_matfile_version(file_path)[0]
    except ValueError as err:
        # scipy could not identify the header; treat the file as HDF5 disguised as .mat
        logger.warning(f"{err}. Switching to `load_hdf5` as it's probably not generated by MATLAB.")
        major_version = 2

    # Major version 2 means v7.3, which is HDF5 under the hood
    if major_version == 2:
        logger.info("Detected .mat v7.3 (HDF5 format). Delegating to `load_hdf5`.")
        return load_hdf5(file_path, key=key, delimiter=delimiter)

    logger.info("Detected .mat version less than v7.3. Using `scipy.io.loadmat`.")
    # The whole file is read up front; the result is already a nested dict
    contents = sio.loadmat(file_path, squeeze_me=squeeze_me, simplify_cells=simplify_cells)

    # No key: hand back the full structure
    if key in (None, "", []):
        return contents

    # Single key: one nested lookup
    if isinstance(key, str):
        return get_nested(contents, key=key, delimiter=delimiter)

    if not isinstance(key, list):
        raise TypeError(
            f"`key` must be None, a string, or a list of strings but got key = '{key}'"
        )

    # List of keys: every entry has to be a string
    if any(not isinstance(k, str) for k in key):
        raise TypeError(
            f"All elements in 'key' list must be strings, got {[type(k).__name__ for k in key]}"
        )

    # Look up every requested key and report all misses together
    found = {}
    not_found = []
    for name in key:
        try:
            found[name] = get_nested(contents, key=name, delimiter=delimiter)
        except KeyError:
            not_found.append(name)

    if not_found:
        raise KeyError(
            f"Key(s) = {not_found} not found. "
            f"Available key(s) in this mat file are {list_nested_keys(contents)}. "
            "Tip: If you don't know the correct key, try 'key=None' to load the entire file as a dict."
        )
    return found
[docs]
def load_hdf5(
    file_path: str, key: KeyType = None, delimiter: str = ".") -> Union[np.ndarray, dict[str, np.ndarray]]:
    """
    Load dataset(s) from an HDF5 file, descending into groups recursively.

    Parameters:
        file_path (str): Path to the HDF5 file.
        key (str | list[str] | None): Name(s) of the dataset(s) to load.
            - If None, '', or []: Load everything recursively, keeping the nested structure.
            - If str: Load one dataset or group. Hierarchical keys (e.g., 'group1.dataset1') are supported.
            - If list[str]: Load several datasets into a flat dict keyed by the requested key strings.
        delimiter (str): Separator between levels of a hierarchical key (default: ".").

    Returns:
        data (np.ndarray or dict): The loaded dataset(s).
            - String `key`: the value of that dataset, or a nested dict if the key points to a group.
            - List `key`: a dict mapping each requested key string to its loaded value.
            - No `key`: a nested dict mirroring the structure of the file.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        KeyError: If provided key(s) are not found in the file.
        TypeError: If the key is not None, a string, or a list of strings,
            or if an unsupported HDF5 object type is encountered.

    Notes:
        - Each dataset value is passed through `handle_hdf5_types` before being returned.
        - Hierarchical keys are split on `delimiter` and resolved one level at a time.
        - Performance: providing an exact key (e.g., `key="group1/dataset1"`) is
          significantly faster than recursively loading the entire file.
    """
    def _materialize(node):
        """Turn an h5py Dataset into its value, or a Group into a nested dict."""
        if isinstance(node, h5py.Dataset):
            return handle_hdf5_types(node[()])
        if isinstance(node, h5py.Group):
            return {name: _materialize(node[name]) for name in node}
        raise TypeError(f"Unsupported HDF5 object type: {type(node)}")

    def _load_at(path):
        """Walk from the file root along `path`, then load whatever is found there."""
        node = hf  # bound by the `with` block below before this helper is ever called
        for part in path.split(delimiter):
            can_descend = isinstance(node, (h5py.Group, h5py.File))
            if not (can_descend and part in node):
                raise KeyError(
                    f"Key '{path}' not found. Failed at '{part}'. "
                    f"Available key(s) in this HDF5 file are {list_nested_keys(hf)}. "
                    "Tip: If you don't know the correct key, try 'key=None' to load the entire file as a dict."
                )
            node = node[part]
        return _materialize(node)

    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"The specified file '{file_path}' does not exist. Please check your file path or working directory."
        )

    with h5py.File(file_path, "r") as hf:
        # No key: load every top-level entry recursively
        if key in (None, "", []):
            return {name: _materialize(hf[name]) for name in hf.keys()}

        # Single key: resolve and load it
        if isinstance(key, str):
            return _load_at(key)

        if not isinstance(key, list):
            raise TypeError(
                f"`key` must be None, a string, or a list of strings but got key = '{key}'"
            )

        if any(not isinstance(k, str) for k in key):
            raise TypeError(
                f"All elements in 'key' list must be strings, got {[type(k).__name__ for k in key]}"
            )

        # List of keys: load each one and report all misses together
        loaded = {}
        not_found = []
        for name in key:
            try:
                loaded[name] = _load_at(name)
            except KeyError:
                not_found.append(name)

        if not_found:
            raise KeyError(
                f"Key(s) = {not_found} not found. Available key(s) in this HDF5 file are {list_nested_keys(hf)}. "
                "Tip: If you don't know the correct key, try 'key=None' to load the entire file as a dict."
            )
        return loaded
[docs]
def write_hdf5(file_path, data, dataset_name="meas", **kwargs):
    """
    Save an array as a single dataset in a new HDF5 file.

    Args:
        file_path (str): Destination path. The file is opened in "w" mode, so an
            existing file at this path is overwritten.
        data: The array to store; passed as `data=` to `create_dataset`.
        dataset_name (str): Name of the dataset inside the file (default: "meas").
        **kwargs: Extra keyword arguments forwarded to `h5py.File.create_dataset`.
            `compression` defaults to "gzip" but can be overridden here
            (e.g. `compression=None` or `compression="lzf"`).
    """
    # Treat gzip as a default rather than a fixed keyword: passing it explicitly next to
    # **kwargs would raise "got multiple values for keyword argument 'compression'"
    # whenever the caller supplies their own compression setting.
    kwargs.setdefault("compression", "gzip")
    with h5py.File(file_path, "w") as hf:  # 'w' will override if the file already exists
        hf.create_dataset(dataset_name, data=data, **kwargs)
[docs]
def load_ND_with_key(
    file_path: str,
    key: Optional[str] = None,
    ndims: Optional[List[int]] = None,
) -> np.ndarray:
    """
    Load exactly one ND dataset from a (possibly nested) .mat or HDF5 file.

    Args:
        file_path (str): Path to the file. The loader is chosen from the extension
            (.mat, .h5, .hdf5; case-insensitive).
        key (str, optional): Key of the dataset to load. If None or '', the whole
            file is loaded and searched for arrays whose ndim is in `ndims`.
        ndims (list): Accepted dimensionalities for the automatic search.
            Defaults to [3, 4]. Not applied when `key` is given.

    Returns:
        numpy.ndarray: The loaded dataset.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the file type is unsupported, the key does not point to an
            ndarray, or the automatic search finds zero or multiple candidates.
        TypeError: If `key` is neither None nor a string.
    """
    ndims = [3, 4] if ndims is None else ndims

    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"The specified file '{file_path}' does not exist. Please check your file path and working directory."
        )

    # Choose the loader from the lower-cased file extension
    suffix = os.path.splitext(file_path)[1].lower()
    if suffix in [".h5", ".hdf5"]:
        load_func = load_hdf5
    elif suffix == ".mat":
        load_func = load_mat
    else:
        raise ValueError(
            f"Unsupported file type: '{suffix}'. Supported types are .mat, .h5, .hdf5."
        )

    # No key: load everything and look for a unique ND candidate
    if key in (None, ""):
        whole_file = load_func(file_path)  # without a key the loader returns a dict of the file
        candidates = collect_ND_datasets(whole_file, ndims=ndims)  # recursive search
        n_found = len(candidates)
        if n_found == 0:
            raise ValueError(
                f"No eligible datasets found in file with ndims = {ndims}. Please check the file and file path."
            )
        if n_found > 1:
            raise ValueError(
                f"Multiple eligible ND datasets found: {list(candidates.keys())}. Please specify the dataset key explicitly."
            )
        return next(iter(candidates.values()))

    if not isinstance(key, str):
        raise TypeError(f"`key` must be None or a string, but got key = '{key}'")

    # A string key normally yields an ndarray, but a wrong key may point to a group or something else
    loaded = load_func(file_path, key)
    if not isinstance(loaded, np.ndarray):
        raise ValueError(
            f"The returned value at key '{key}' is not an ndarray dataset, got type = {type(loaded).__name__}. "
            "If you don't know the correct dataset key, try 'key=None' to search for eligible ND datasets from the entire file."
        )
    return loaded
[docs]
def collect_ND_datasets(
    data_dict: dict[str, Any],
    ndims: Optional[list[int]] = None,
    delimiter: str = ".",
    _parent_key: Optional[str] = None,
) -> dict[str, np.ndarray]:
    """
    Collect ND numpy arrays from a (possibly nested) dictionary that match desired dimensionalities.

    Nested dictionaries are traversed recursively and the keys of matching arrays
    are flattened into a single path string joined with `delimiter`.

    Args:
        data_dict (dict): Dictionary of datasets (flat or nested).
        ndims (list of int): Desired dimensionalities to match. Defaults to [3, 4].
        delimiter (str): String used to separate the levels of the full path to a dataset (default: ".").
        _parent_key (str, optional): **Internal use only.** Tracks nested keys during recursion. Do not set manually.

    Returns:
        dict[str, np.ndarray]: Matching datasets keyed by their flattened hierarchical
        path. Empty if nothing matches.

    Raises:
        ValueError: If `data_dict` is not a dict.
    """
    if not isinstance(data_dict, dict):
        raise ValueError("Input must be a dictionary containing datasets.")
    if ndims is None:
        ndims = [3, 4]

    results: dict[str, np.ndarray] = {}
    for key, val in data_dict.items():
        full_key = f"{_parent_key}{delimiter}{key}" if _parent_key else key
        if isinstance(val, np.ndarray):
            if val.ndim in ndims:
                results[full_key] = val
        elif isinstance(val, dict):
            # Forward the delimiter so deeper levels are joined with the same separator
            # as the first one instead of silently falling back to the default ".".
            results.update(
                collect_ND_datasets(
                    val, ndims=ndims, delimiter=delimiter, _parent_key=full_key
                )
            )

    # Report once from the outermost call; recursive calls would otherwise log
    # the same datasets again at every nesting level.
    if results and _parent_key is None:
        logger.info(f"Found the following ND datasets with ndim in {ndims}:")
        for k, arr in results.items():
            logger.info(f" Key: '{k}', Shape: {arr.shape}, Dtype: {arr.dtype}")
    return results
[docs]
def handle_hdf5_types(x):
    """
    Normalize a value (typically one read through h5py) into native Python or NumPy types.

    Conversions applied, in order:
        - NumPy scalars and 0-dimensional arrays become Python scalars.
        - Bytes are decoded as UTF-8; undecodable bytes are returned unchanged.
        - The sentinel string "__NONE__" becomes None.
        - Other strings are parsed with `ast.literal_eval` when they form a valid
          Python literal, and returned as-is otherwise.
        - Arrays with the compound dtype [('real', '<f8'), ('imag', '<f8')]
          (MATLAB v7.3 style complex data) are viewed as complex128.
        - 1D arrays of string or object dtype become a Python list of str.

    Args:
        x: The input data to be converted.

    Returns:
        The converted data; anything not covered above is returned unchanged.
    """
    # NumPy scalar -> Python scalar
    if isinstance(x, np.generic):
        x = x.item()

    # 0-d arrays were probably forced by HDF5; unwrap them as well
    if isinstance(x, np.ndarray) and x.ndim == 0:
        x = x.item()

    # Bytes (e.g., HDF5 strings or the sentinel) -> str when decodable
    if isinstance(x, bytes):
        try:
            x = x.decode('utf-8')
        except UnicodeDecodeError:
            return x  # Leave undecodable bytes unchanged

    if isinstance(x, str):
        # The sentinel stands in for None, which HDF5 cannot store directly
        if x == "__NONE__":
            return None
        # Try parsing stringified literals; keep the plain string if that fails
        import ast
        try:
            return ast.literal_eval(x)
        except (ValueError, SyntaxError):
            return x

    if isinstance(x, np.ndarray):
        # MATLAB-style complex128 stored as a (real, imag) compound dtype
        if x.dtype == [('real', '<f8'), ('imag', '<f8')]:
            logger.info(f"Detected data.shape = {x.shape} with data.dtype = {x.dtype}. Casting back to 'complex128'.")
            return x.view(np.complex128)

        # 1D array of strings (or object-dtype strings) -> list of str
        is_stringlike = np.issubdtype(x.dtype, np.str_) or np.issubdtype(x.dtype, np.object_)
        if x.ndim == 1 and is_stringlike:
            try:
                return [item.decode('utf-8') if isinstance(item, bytes) else str(item) for item in x]
            except Exception:
                pass  # fallback to returning as-is

    return x
[docs]
def get_nested(d, key, delimiter='.', safe=False, default=None):
    """
    Get a value from a nested dictionary, either safely (return a default if not found) or strictly (fail early).

    Parameters:
        - d (dict): The dictionary to traverse.
        - key (str, or list or tuple of str): Either one string whose levels are
          separated by `delimiter`, or a sequence with one string per level.
        - delimiter (str): The string used to separate the levels inside a string `key`.
        - safe (bool): Switch between safe and strict mode.
        - default: The value returned in safe mode if any key is missing, an
          intermediate value is not a dict, or a looked-up value is None.

    Returns:
        - The nested value if found; otherwise `default` in safe mode.

    Raises:
        ValueError: If `key` is empty.
        TypeError: If `key` is not a str or a list/tuple of str.
        KeyError: In strict mode, if a level of the key is missing or the path
            runs through a value that is not a dict.
    """
    if not key:
        raise ValueError("Please specify a non-empty 'key' to get the value from a nested dict.")

    # Parse the input key (str with delimiter, or sequence of strings)
    if isinstance(key, str):
        parts = key.split(delimiter)
    elif isinstance(key, (tuple, list)):
        if not all(isinstance(k, str) for k in key):
            raise TypeError(
                f"All elements in 'key' must be strings, got {[type(k).__name__ for k in key]}"
            )
        parts = key
    else:
        raise TypeError(f"'key' must be a str, or a sequence (list, tuple) of strings, got {type(key).__name__}.")

    # Getting value safely with a default return
    if safe:
        for k in parts:
            if not isinstance(d, dict):
                return default
            d = d.get(k)
            if d is None:
                return default
        return d

    # Getting value strictly with raised error
    for k in parts:
        if not isinstance(d, dict):
            # The path runs through a leaf value. Listing its "available keys" is
            # impossible (list_nested_keys rejects non-dicts with ValueError), so
            # report the real cause while still raising the documented KeyError.
            raise KeyError(
                f"Key '{key}' not found. Failed at '{k}' because its parent value is a "
                f"{type(d).__name__}, not a dict. "
                "Tip: If you don't know the correct key, use `print_nested_dict()` from `ptyrad.io.hdf5` to check your nested dict first."
            )
        if k not in d:
            raise KeyError(
                f"Key '{key}' not found. Failed at '{k}'. "
                f"Available key(s) in this nested dict are {list_nested_keys(d)}. "
                "Tip: If you don't know the correct key, use `print_nested_dict()` from `ptyrad.io.hdf5` to check your nested dict first."
            )
        d = d[k]
    return d
[docs]
def list_nested_keys(hobj, delimiter=".", prefix=""):
    """
    Recursively list the keys of every leaf in an HDF5 file, HDF5 group, or dict, as hierarchical paths.

    Args:
        hobj (h5py.File, h5py.Group, or dict): The hierarchical object to traverse.
        delimiter (str): The string placed between the levels of each key path.
        prefix (str): Path accumulated so far (used for recursion).

    Returns:
        list[str]: One path per entry that is not itself a container. Containers
        are expanded into their children rather than listed.

    Raises:
        ValueError: If `hobj` is not an HDF5 file, an HDF5 group, or a dict.
    """
    # Decide which type counts as a "container" to descend into
    if isinstance(hobj, (h5py.Group, h5py.File)):
        container_type = h5py.Group
    elif isinstance(hobj, dict):
        container_type = dict
    else:
        raise ValueError(f"Expected hobj is an HDF5 file, HDF5 group, or a dict, got {type(hobj).__name__}.")

    paths = []
    for name in hobj.keys():
        # The first level carries no delimiter in front of it
        if prefix == "":
            path = f"{prefix}{name}"
        else:
            path = f"{prefix}{delimiter}{name}"

        child = hobj[name]
        if not isinstance(child, container_type):
            # Leaf entry (e.g. a dataset): record its full path
            paths.append(path)
            continue

        # Container: expand it into the paths of its own children
        paths.extend(list_nested_keys(child, delimiter=delimiter, prefix=path))
    return paths
[docs]
def print_nested_dict(d, indent=0, leaf_inline_threshold=6):
    """Recursively logs a nested dictionary with structured formatting.

    To improve log readability and save vertical space, small "leaf" dictionaries
    (dictionaries containing no further nested dicts or lists) are logged inline
    on a single line, provided their length does not exceed `leaf_inline_threshold`.
    Flat lists are also logged inline. Everything is emitted at INFO level through
    the module logger.

    Args:
        d (dict): The dictionary to log.
        indent (int, optional): The current indentation level; the indent string
            is repeated this many times in front of each line. Defaults to 0.
        leaf_inline_threshold (int, optional): The maximum number of key-value
            pairs a flat leaf dictionary can have to be formatted inline. Applies
            at every nesting level. Defaults to 6.
    """
    indent_str = " " * indent
    for key, value in d.items():
        if isinstance(value, dict):
            # A flat leaf dict holds no nested dicts or lists
            is_flat_leaf = all(not isinstance(v, (dict, list)) for v in value.values())
            if is_flat_leaf and len(value) <= leaf_inline_threshold:
                flat = ", ".join(f"{k}: {repr(v)}" for k, v in value.items())
                logger.info(f"{indent_str}{key}: {{{flat}}}")
            else:
                logger.info(f"{indent_str}{key}:")
                # Forward the threshold so a caller-supplied value also governs
                # deeper levels instead of reverting to the default there.
                print_nested_dict(
                    value, indent + 1, leaf_inline_threshold=leaf_inline_threshold
                )
        elif isinstance(value, list) and all(not isinstance(i, (dict, list)) for i in value):
            logger.info(f"{indent_str}{key}: {value}")
        else:
            logger.info(f"{indent_str}{key}: {repr(value)}")