"""
Functions for loading, normalizing (for backward compatibility), and exporting params files
"""
import logging
import os
from typing import Dict
logger = logging.getLogger(__name__)
###### These are params loading functions ######
[docs]
def load_params(file_path: str, validate: bool = True):
    """Load a params file, apply backward-compatibility fixes, and optionally validate it.

    The loader is chosen by file extension (case-insensitive): ``.yml``/``.yaml``,
    ``.toml``, ``.json``, or ``.py``.

    Legacy settings are normalized before validation:
      - ``constraint_params`` entries using ``freq`` (pre v0.1.0b11) are converted.
      - Legacy probe aberration keys in ``init_params`` (pre v0.1.0b13) are migrated
        into ``probe_aberrations``.
      - ``model_params.obj_preblur_std`` and ``recon_params.if_quiet`` (deprecated
        since v0.1.0b13) are removed with a warning.

    Args:
        file_path (str): Path to the params file.
        validate (bool): If True (default), fill defaults and validate through the
            ``PtyRADParams`` pydantic model and return its ``model_dump()``.

    Returns:
        dict: The params dictionary, with ``'params_path'`` set to ``file_path``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file extension is not supported, or the loaded file does not
            contain a mapping at the top level (e.g. an empty YAML file).
    """
    # Check if the file exists
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"The specified file '{file_path}' does not exist. Please check your file path and working directory.")

    logger.info("### Loading params file ###")
    logger.info(f"params_path = {file_path}")

    param_path, param_type = os.path.splitext(file_path)
    param_type = param_type.lower()  # Accept e.g. '.YAML' or '.Json' as well
    if param_type in (".yml", ".yaml"):
        params_dict = load_yml_params(file_path)
    elif param_type == ".toml":
        params_dict = load_toml_params(file_path)
    elif param_type == ".json":
        params_dict = load_json_params(file_path)
    elif param_type == ".py":
        # The .py loader receives the path without its extension.
        params_dict = load_py_params(param_path)
    else:
        raise ValueError(
            f"Unsupported params file extension '{param_type}'. "
            "param_type needs to be either 'yml', 'yaml', 'toml', 'json', or 'py'")

    # An empty YAML file loads as None and a JSON file may hold a list; the
    # dict-based corrections below (and validation) need a top-level mapping.
    if not isinstance(params_dict, dict):
        raise ValueError(
            f"The params file '{file_path}' must contain a mapping of parameters at the top level, "
            f"but it was loaded as {type(params_dict).__name__}.")

    # Additional correction for constraint_params (temporarily added for smooth transition to v0.1.0b11)
    if params_dict.get('constraint_params') is not None:
        params_dict['constraint_params'] = normalize_constraint_params(params_dict['constraint_params'])

    # Additional correction for the probe aberrations in init_params (temporarily added for smooth transition to v0.1.0b13)
    if params_dict.get('init_params') is not None:
        params_dict['init_params'] = normalize_probe_params(params_dict['init_params'])

    # Drop flags removed in v0.1.0b13 (temporarily added for smooth transition to v0.1.0b13).
    # Each entry: (section, deprecated flag, extra explanation for the warning).
    deprecated_flags = (
        ('model_params', 'obj_preblur_std', ""),
        ('recon_params', 'if_quiet', "PtyRAD now uses a central LoggingManager. "),
    )
    for section, flag, note in deprecated_flags:
        section_dict = params_dict.get(section)
        if section_dict is not None and flag in section_dict:
            logger.warning(
                f"WARNING: The '{flag}' parameter is deprecated since v0.1.0b13. {note}"
                "This flag will be ignored.")
            section_dict.pop(flag)

    # Pass into PtyRADParams (pydantic model) for default filling and validation
    if validate:
        from .ptyrad_params import PtyRADParams
        logger.info("validate = True: Filling defaults and validating the params file...")
        params_dict = PtyRADParams(**params_dict).model_dump()
        logger.info("Success! Params file validated and defaults applied.")
    else:
        logger.warning("WARNING: validate = False: Skipping validation and default filling.")
        logger.warning("         Ensure your params file is complete and consistent.")
        logger.warning("         If you encounter issues, consider enabling validation or report the bug.")

    # Add the file path to the params_dict while we save the params file to output folder
    params_dict['params_path'] = file_path
    logger.info(" ")
    return params_dict
[docs]
def load_json_params(file_path):
    """Load parameters from a JSON file.

    The file is decoded as UTF-8. ``utf-8-sig`` is used so that a leading byte-order
    mark, which some editors write and which ``json`` rejects with
    "Unexpected UTF-8 BOM", is stripped. Files without a BOM decode identically.

    Args:
        file_path (str): Path to the JSON file.

    Returns:
        The parsed JSON value (a dict for a well-formed params file).
    """
    import json
    with open(file_path, "r", encoding='utf-8-sig') as file:
        params_dict = json.load(file)
    return params_dict
[docs]
def load_toml_params(file_path):
    """
    Load parameters from a TOML file.

    "A TOML file must be a valid UTF-8 encoded Unicode document", and ``tomllib.load``
    expects a binary file handle. Encoding mismatches have been observed when the script
    runs in terminals with a different default encoding, so the file is read explicitly
    as UTF-8 text and the string is handed to ``loads``. ``utf-8-sig`` additionally
    strips a leading byte-order mark, which the TOML parser would otherwise reject.

    Parameters:
        file_path (str): The path to the TOML file to be loaded.
    Returns:
        dict: A dictionary containing the parameters loaded from the TOML file.
    Raises:
        FileNotFoundError: If the specified file does not exist.
        ImportError: If the tomli package is not installed for Python < 3.11.
    """
    with open(file_path, "r", encoding='utf-8-sig') as file:
        content = file.read()

    try:
        import tomllib  # For Python 3.11+
    except ImportError:
        try:
            import tomli as tomllib  # type: ignore  # For Python < 3.11
        except ImportError as err:
            raise ImportError("TOML support requires 'tomli' package for Python < 3.11 or built-in 'tomllib' for Python 3.11+. ") from err

    params_dict = tomllib.loads(content)
    return params_dict
[docs]
def load_yml_params(file_path):
    """Parse a YAML params file with ``yaml.safe_load``.

    Args:
        file_path (str): Path to the ``.yml``/``.yaml`` file, decoded as UTF-8.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.
    """
    import yaml

    with open(file_path, encoding="utf-8") as yml_file:
        loaded = yaml.safe_load(yml_file)
    return loaded
[docs]
def load_py_params(file_path):
    """Load parameters defined as module-level names in a Python params file.

    ``load_params`` passes the path without its ``.py`` extension. When that path
    (with ``.py`` appended if needed) points to an existing file, the file is executed
    directly from disk. This works for paths containing directories and cannot pick
    up an unrelated, same-named module from ``sys.path``. Otherwise ``file_path`` is
    treated as an importable module name and loaded with ``importlib.import_module``,
    as before.

    Args:
        file_path (str): Path to the params file (with or without ``.py``), or a
            module name importable from ``sys.path``.

    Returns:
        dict: Every module-level name that does not start with ``"__"``, mapped to its value.
    """
    import importlib
    import importlib.util

    source_path = file_path if str(file_path).endswith(".py") else f"{file_path}.py"
    if os.path.isfile(source_path):
        module_name = os.path.splitext(os.path.basename(source_path))[0]
        spec = importlib.util.spec_from_file_location(module_name, source_path)
        params_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(params_module)
    else:
        # Fallback: dotted module name resolvable via sys.path (original behavior).
        params_module = importlib.import_module(file_path)

    params_dict = {
        name: getattr(params_module, name)
        for name in dir(params_module)
        if not name.startswith("__")
    }
    return params_dict
###### These are sanitization functions for backward compatibility #####
[docs]
def normalize_probe_params(init_params: Dict) -> Dict:
    """Normalize the probe-aberration entries of `init_params` in place and return it.

    Two passes are applied, before any optional pydantic validation:
      1. Legacy migration: the pre-v0.1.0b13 keys `probe_defocus`, `probe_c3` and
         `probe_c5` are moved into the `probe_aberrations` dict (as `defocus`, `C30`
         and `C50`). This happens only if none of their aliases is already present
         there, so an explicit modern entry wins. The legacy keys are removed from
         `init_params` either way.
      2. Canonicalization: a non-empty `probe_aberrations` dict is rewritten into the
         standard Krivanek polar format {'Cnm': XX, 'phinm': XX} via
         `Aberrations(...).export(notation='krivanek', style='polar')`.

    Args:
        init_params (Dict): The `init_params` section of a params dict; mutated in place.

    Returns:
        Dict: The same `init_params` object after normalization.
    """
    from ptyrad.optics.aberrations import Aberrations

    # legacy key -> (key it becomes in probe_aberrations, aliases meaning "already set there")
    legacy_map = {
        'probe_defocus': ('defocus', ['defocus', 'C1', 'C10', (1,0), '(1,0)']),
        'probe_c3': ('C30', ['C30', 'C3', 'Cs', (3,0), '(3,0)']),
        'probe_c5': ('C50', ['C50', 'C5', (5,0), '(5,0)']),
    }

    # --- Pass 1: move legacy keys ---
    # A missing or None `probe_aberrations` becomes an empty dict that can be filled.
    if init_params.get('probe_aberrations') is None:
        init_params['probe_aberrations'] = {}
    aberrations = init_params['probe_aberrations']

    migrated_keys = []
    for legacy_key, (modern_key, blocking_aliases) in legacy_map.items():
        if legacy_key not in init_params:
            continue
        # The legacy key is always dropped, whether or not its value is migrated.
        legacy_val = init_params.pop(legacy_key)
        if any(alias in aberrations for alias in blocking_aliases):
            logger.warning(f"WARNING: Ignoring '{legacy_key}' because it is already defined in 'probe_aberrations' as one of {blocking_aliases}")
            continue
        aberrations[modern_key] = legacy_val
        migrated_keys.append(legacy_key)

    if migrated_keys:
        logger.warning(f"WARNING: Probe aberrations '{migrated_keys}' in 'init_params' are depracated since PtyRAD v0.1.0b13 and are automatically converted to 'probe_aberrations' dict.")

    # --- Pass 2: canonicalize ---
    if not aberrations:
        return init_params
    init_params['probe_aberrations'] = Aberrations(aberrations).export(notation='krivanek', style='polar')
    return init_params
[docs]
def normalize_constraint_params(constraint_params):
    """Normalize every constraint entry to the ``start_iter``/``step``/``end_iter`` keys.

    Before PtyRAD v0.1.0b11 a constraint was scheduled with a single ``freq`` key.
    This runs before the optional pydantic validation, so an entry may use either form:

    - ``freq`` given (not None): ``start_iter`` defaults to 1, ``step`` defaults to ``freq``.
    - otherwise: ``start_iter`` defaults to None, ``step`` defaults to 1.
    - ``end_iter`` defaults to None in both cases.

    Explicitly provided ``start_iter``/``step``/``end_iter`` values always win. ``freq``
    is dropped, and every other key is copied after the three scheduling keys. One
    deprecation warning is logged if any entry used ``freq``. The input dicts are not
    mutated.

    Args:
        constraint_params (dict): Mapping of constraint name -> its params dict.

    Returns:
        dict: A new mapping with the same names and normalized params dicts.
    """
    scheduling_keys = ("freq", "step", "start_iter", "end_iter")
    saw_legacy_freq = False
    result = {}

    for constraint_name, cfg in constraint_params.items():
        legacy_freq = cfg.get("freq", None)
        is_legacy = legacy_freq is not None
        if is_legacy:
            saw_legacy_freq = True

        extras = {key: val for key, val in cfg.items() if key not in scheduling_keys}
        result[constraint_name] = {
            "start_iter": cfg.get("start_iter", 1 if is_legacy else None),
            "step": cfg.get("step", legacy_freq if is_legacy else 1),
            "end_iter": cfg.get("end_iter", None),
            **extras,
        }

    if saw_legacy_freq:
        logger.warning("WARNING: For constraint_params, 'freq' is depracated since PtyRAD v0.1.0b11 and is automatically converted to 'step'.")
    return result
###### Params exporting / copying #####
[docs]
def copy_params_to_dir(params_path, output_dir, params=None):
    """
    Copies the params file to the output directory if it exists. If the params file does not exist,
    it dumps the provided params dictionary to a YAML file in the output directory.

    Args:
        params_path (str): Path to the params file (can be None if params are programmatically generated).
        output_dir (str): Directory where the params file or YAML dump will be saved (created if missing).
        params (dict, optional): The programmatically generated params dictionary to save if no file exists.

    Returns:
        str or None: Path of the file now present in ``output_dir``, or None when there was
        neither a params file nor a params dict (a warning is logged in that case).
    """
    import shutil

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    if params_path and os.path.isfile(params_path):
        # If the params file exists, copy it to the output directory
        output_path = os.path.join(output_dir, os.path.basename(params_path))
        # shutil.copy2 raises SameFileError if output_dir is the file's own directory.
        if os.path.exists(output_path) and os.path.samefile(params_path, output_path):
            return output_path
        shutil.copy2(params_path, output_path)
        return output_path

    if params is not None:
        # If no file exists, dump the params dictionary to a YAML file
        import yaml  # Only needed for the dump branch
        output_path = os.path.join(output_dir, "params_dumped.yml")
        with open(output_path, "w", encoding="utf-8") as f:
            yaml.safe_dump(params, f, sort_keys=False)
        return output_path

    # If neither a file nor params are provided, skip with a warning
    logger.warning(f"WARNING: No params file found at '{params_path}' and no params dict provided; skipping params copy to '{output_dir}'.")
    return None
[docs]
def yaml2json(input_filepath, output_filepath):
    """Convert a YAML file into an indented (4-space) JSON file.

    Both files are handled as UTF-8, matching the params loaders. If the YAML cannot be
    parsed, the error is printed and nothing is written.

    Args:
        input_filepath (str): Path of the YAML file to read.
        output_filepath (str): Path of the JSON file to write (overwritten if present).
    """
    import json
    import yaml

    with open(input_filepath, 'r', encoding='utf-8') as file:
        try:
            # Load as YAML
            data = yaml.safe_load(file)
        except yaml.YAMLError as e:
            print("Error parsing YAML file:", e)
            return

    # Serialize before opening the output, so a non-JSON-serializable value
    # (TypeError) cannot leave a truncated JSON file behind.
    json_text = json.dumps(data, indent=4)
    # Save to JSON
    with open(output_filepath, 'w', encoding='utf-8') as json_file:
        json_file.write(json_text)
    print(f"YAML {input_filepath} has been successfully converted and saved to JSON {output_filepath}")