"""
Reconstruction and hypertune workflows for ptychographic reconstructions
"""
from copy import deepcopy
import logging
from random import shuffle
import warnings
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from ptyrad.constraints import CombinedConstraint
from ptyrad.initialization import Initializer
from ptyrad.losses import CombinedLoss, get_objp_contrast
from ptyrad.models import PtychoAD
from ptyrad.save import copy_params_to_dir, make_output_folder, save_results
from ptyrad.utils import (
get_blob_size,
get_time,
ndarrays_to_tensors,
parse_hypertune_params_to_str,
parse_sec_to_time_str,
safe_filename,
set_random_seed,
time_sync,
vprint,
)
from ptyrad.visualization import plot_pos_grouping, plot_summary
# Enable TF32 tensor cores for float32 matmuls. This silences the inductor hint
# '..._inductor/compile_fx.py:236: UserWarning: TensorFloat32 tensor cores for float32 matrix
# multiplication available but not enabled. Consider setting
# `torch.set_float32_matmul_precision('high')` for better performance.'
# The speed-up itself is small because PtyRAD performs very little matrix multiplication.
torch.set_float32_matmul_precision('high')

# Warning messages known to be harmless for PtyRAD, silenced for clarity:
#   1. Inductor's complex-operator notice: compiled code is still significantly faster than eager.
#   2. Profiler notice that shows up with torch.compile.
#   3. DDP "no device id" notice under `accelerate`; multiGPU runs are unaffected.
#   4. Dynamo tracing notice that shows up with multiGPU + compile; it has no effect.
_HARMLESS_WARNING_PATTERNS = (
    "Torchinductor does not support code generation for complex operators. Performance may be worse than eager.",
    ".*Profiler function.*will be ignored.*",
    ".*No device id is provided.*",
    ".*Dynamo does not know how to trace.*",
)
for _pattern in _HARMLESS_WARNING_PATTERNS:
    warnings.filterwarnings("ignore", message=_pattern)
del _pattern
[docs]
class PtyRADSolver(object):
    """
    A wrapper class to perform ptychographic reconstruction or hyperparameter tuning.

    The PtyRADSolver class initializes the necessary components (Initializer, loss
    function, and constraint function) for ptychographic reconstruction and provides
    methods to execute the reconstruction or perform hyperparameter tuning using Optuna.

    Args:
        params (dict): Dictionary containing all the parameters required for
            initialization, loss functions, constraints, model, reconstruction, and
            optional hyperparameter tuning. A deep copy is stored, so the caller's
            dict is not modified by the solver.
        device (str, optional): The device to run the computations on (e.g., 'cuda' for GPU,
            'cpu' for CPU). Defaults to None; if `acc` is given, `acc.device` is used instead.
        seed (int, optional): Random seed forwarded to the Initializer.
        acc (optional): Accelerator object (from `accelerate`) used for multiGPU /
            mixed-precision runs. The solver reads its `device` and `num_processes`
            attributes and calls its `prepare()` method.
        logger (optional): Custom logger object. The solver uses its `flush_file` flag,
            `flush_to_file()`, `close()`, and its `buffer_handler`, `console_handler`,
            and `file_handler` handlers.

    Attributes:
        params (dict): The deep-copied parameters.
        if_hypertune (bool): Whether hyperparameter tuning should be performed instead of
            regular reconstruction. Read from `params['hypertune_params']['if_hypertune']`,
            defaulting to False.
        verbose (bool): Controls the verbosity of the output. True unless
            `params['recon_params']['if_quiet']` is True.
        use_acc_device (bool): True when no explicit `device` was given but `acc` was.
        device: The device used for the computations.
    """

    def __init__(self, params, device=None, seed=None, acc=None, logger=None):
        self.params = deepcopy(params)
        # `hypertune_params` may be absent or explicitly null (None) in the params file
        self.if_hypertune = (self.params.get('hypertune_params') or {}).get('if_hypertune', False)
        self.verbose = not self.params['recon_params']['if_quiet']
        self.accelerator = acc
        # Only defer device placement to `accelerate` when no explicit device was requested
        self.use_acc_device = device is None and acc is not None
        self.device = self.accelerator.device if self.use_acc_device else device
        self.random_seed = seed
        self.logger = logger

        # The model and optimizer are instantiated inside reconstruct() and hypertune()
        self.init_initializer()
        self.init_loss()
        self.init_constraint()
        vprint("### Done initializing PtyRADSolver ###")
        vprint(" ")

    def init_initializer(self):
        """Initializes the variables and objects needed for the reconstruction process."""
        # These components are organized into individual methods so we can re-initialize some of them if needed
        vprint("### Initializing Initializer ###")
        self.init = Initializer(self.params['init_params'], seed=self.random_seed).init_all()
        vprint(" ")

    def init_loss(self):
        """Initializes the loss function using the provided parameters."""
        vprint("### Initializing loss function ###")
        loss_params = self.params['loss_params']

        # Print only the loss terms whose 'state' is enabled
        vprint("Active loss types:")
        for key, value in loss_params.items():
            if value.get('state', False):
                vprint(f" {key.ljust(12)}: {value}")

        self.loss_fn = CombinedLoss(loss_params, device=self.device)
        vprint(" ")

    def init_constraint(self):
        """Initializes the constraint function using the provided parameters."""
        vprint("### Initializing constraint function ###")
        constraint_params = self.params['constraint_params']

        # Print only the constraints that have a 'start_iter' set
        vprint("Active constraint types:")
        for key, value in constraint_params.items():
            if value.get('start_iter', None) is not None:
                vprint(f" {key.ljust(14)}: {value}")

        self.constraint_fn = CombinedConstraint(constraint_params, device=self.device, verbose=self.verbose)
        vprint(" ")

    def reconstruct(self):
        """Executes the ptychographic reconstruction process by creating the model,
        optimizer, and running the reconstruction loop.

        The resulting model and optimizer are stored as `self.reconstruct_results`
        and `self.optimizer`.
        """
        params = self.params
        device = self.device
        logger = self.logger

        # Create the model and optimizer, prepare indices, batches, and output_path
        model = PtychoAD(self.init.init_variables, params['model_params'], device=device, verbose=self.verbose)
        optimizer = create_optimizer(model.optimizer_params, model.optimizable_params)

        if not self.use_acc_device:
            indices, batches, output_path = prepare_recon(model, self.init, params)
        else:
            if params['model_params']['optimizer_params']['name'] == 'LBFGS' and self.accelerator.num_processes > 1:
                vprint(f"WARNING: Optimizer 'LBFGS' is not supported for multiGPU mode (accelerator.num_processes = {self.accelerator.num_processes}), switch to default optimizer 'Adam'")
                params['model_params']['optimizer_params']['name'] = 'Adam'
                model.optimizer_params['name'] = 'Adam'
                optimizer = create_optimizer(model.optimizer_params, model.optimizable_params)
            vprint(f"params['recon_params']['GROUP_MODE'] is set to 'random' because `use_acc_device` = {self.use_acc_device}", verbose=self.verbose)
            params['recon_params']['GROUP_MODE'] = 'random'
            # `batches` is replaced by a random DataLoader when use_acc_device, because specified indices
            # are not yet supported through the DataLoader. In other words, only `random` grouping is
            # available for accelerate-powered multiGPU and mixed-precision.
            indices, batches, output_path = prepare_recon(model, self.init, params)
            ds = IndicesDataset(indices)
            dl = torch.utils.data.DataLoader(ds, batch_size = params['recon_params']['BATCH_SIZE']['size'], shuffle = True)  # This will do the batching
            batches = self.accelerator.prepare(dl)  # Note that `batches` is replaced by a DataLoader (accelerate mode) that is also an iterable object
            model, optimizer = self.accelerator.prepare(model, optimizer)
            vprint(f"len(DataLoader) = num_batches = {len(dl)}, DataLoader.batch_size = {len(indices)//len(dl)}", verbose=self.verbose)
            vprint("Note that the DataLoader will be duplicated for each process, while DataLoader.batch_size is the effective batch size (batch_size_per_process * num_process)", verbose=self.verbose)
            vprint("The actual batch_size_per_process will be printed below for the reported batches from the main process", verbose=self.verbose)
            vprint("For example, batch size = 512 with 2 GPUs (2 processes), the reported/observed batch size per GPU will be 512/2=256.", verbose=self.verbose)

        if logger is not None and logger.flush_file:
            logger.flush_to_file(log_dir=output_path)  # Note that output_path can be None, and there's an internal flag of self.flush_file controls the actual file creation

        recon_loop(model, self.init, params, optimizer, self.loss_fn, self.constraint_fn, indices, batches, output_path, acc=self.accelerator)
        self.reconstruct_results = model
        self.optimizer = optimizer

    def hypertune(self):
        """Performs hyperparameter tuning using Optuna.

        The output directory is rewritten into a single hypertune folder named after the
        sampler, pruner, and error metric. Each trial calls `optuna_objective`, and the best
        parameters are printed once the study finishes.

        Raises:
            ValueError: If `hypertune_params['error_metric']` is not 'contrast' or 'loss'.
        """
        import optuna

        hypertune_params = self.params['hypertune_params']
        params_path = self.params.get('params_path')
        n_trials = hypertune_params.get('n_trials')
        timeout = hypertune_params.get('timeout')
        study_name = hypertune_params.get('study_name')
        storage_path = hypertune_params.get('storage_path')
        sampler_params = hypertune_params['sampler_params']
        pruner_params = hypertune_params['pruner_params']
        error_metric = hypertune_params['error_metric']
        sampler = create_optuna_sampler(sampler_params)
        pruner = create_optuna_pruner(pruner_params)
        logger = self.logger

        # Print hypertune params
        vprint("### Hypertune params ###")
        for key, value in hypertune_params.items():
            if key == 'tune_params':  # Check if 'tune_params' exists
                vprint("Active tune_params:")
                for param, param_config in value.items():
                    if param_config.get('state', False):  # Print only if 'state' is True
                        vprint(f" {param.ljust(12)}: {param_config}")
            else:
                vprint(f"{key.ljust(16)}: {value}")
        vprint(" ")

        # Check error metric validity
        valid_metrics = {"contrast", "loss"}
        if error_metric not in valid_metrics:
            raise ValueError(f"Invalid error metric: '{error_metric}'. Expected one of {valid_metrics}.")

        copy_params = self.params['recon_params']['copy_params']
        output_dir = self.params['recon_params']['output_dir']  # This will be later modified
        prefix_time = self.params['recon_params']['prefix_time']
        prefix = self.params['recon_params']['prefix']
        postfix = self.params['recon_params']['postfix']

        # Retrieve Optuna's logger
        optuna_logger = logging.getLogger("optuna")
        optuna_logger.setLevel(logging.INFO)

        if logger is not None:
            # Replace Optuna's console handlers with the custom logger's handlers to avoid duplicate logs.
            # Iterate over a snapshot: removing from `handlers` while looping over it would skip entries.
            # Note that logging.FileHandler is a StreamHandler subclass, so it is removed as well.
            for handler in list(optuna_logger.handlers):
                if isinstance(handler, logging.StreamHandler):
                    optuna_logger.removeHandler(handler)
            optuna_logger.addHandler(logger.buffer_handler)
            optuna_logger.addHandler(logger.console_handler)

        # Create a study object and optimize the objective function
        study = optuna.create_study(
            direction='minimize',
            sampler=sampler,
            pruner=pruner,  # In Optuna default, setting pruner=None will change to a MedianPruner which is a bit odd. In PtyRAD optuna_objective we will skip the pruning if pruner=None.
            storage=storage_path,  # Specify the storage URL here.
            study_name=study_name,
            load_if_exists=True)

        # Modify the 'output_dir' and reset the params dict specifically for hypertune mode
        # Note this will change the params saved with model.pt, but has no effect to the 'copy_params'
        prefix = prefix + '_' if prefix != '' else ''
        postfix = '_' + postfix if postfix != '' else ''

        # Attach time string if prefix_time is true or non-empty str
        if prefix_time is True or (isinstance(prefix_time, str) and prefix_time):
            time_str = get_time(prefix_time)  # e.g. '20250606'
            prefix = f"{time_str}_{prefix}"

        sampler_str = sampler_params['name']
        pruner_str = '_' + pruner_params['name'] if pruner_params is not None else ''
        output_dir += f"/{prefix}hypertune_{sampler_str}{pruner_str}_{error_metric}{postfix}"
        self.params['recon_params']['output_dir'] = output_dir
        self.params['recon_params']['prefix_time'] = ''
        self.params['recon_params']['prefix'] = ''
        self.params['recon_params']['postfix'] = ''

        if copy_params:
            copy_params_to_dir(params_path, output_dir, self.params)

        # Set output_dir to None if the user doesn't want to create the output_dir at all
        if not copy_params and self.params['recon_params']['SAVE_ITERS'] is None and not hypertune_params['collate_results']:
            output_dir = None

        if logger is not None and logger.flush_file:
            logger.flush_to_file(log_dir=output_dir)  # Note that there's an internal flag of self.flush_file controls the actual file creation
            optuna_logger.addHandler(logger.file_handler)

        study.optimize(lambda trial: optuna_objective(trial, self.params, self.init, self.loss_fn, self.constraint_fn, self.device, self.verbose),
                       n_trials=n_trials,
                       timeout=timeout)

        vprint(f"Hypertune study is finished due to either (1) n_trials = {n_trials} or (2) study timeout = {timeout} sec has reached")
        vprint("Best hypertune params:")
        for key, value in study.best_params.items():
            vprint(f"\t{key}: {value}")

    # Wrapper function to run either "reconstruction" or "hypertune" modes
    def run(self):
        """A wrapper method to run the solver in either reconstruction or hyperparameter
        tuning mode based on the if_hypertune flag.

        Prints the total solver time, closes the logger when it flushes to file, and
        destroys the torch.distributed process group if one was initialized.
        """
        start_t = time_sync()
        solver_mode = 'hypertune' if self.if_hypertune else 'reconstruct'

        vprint(f"### Starting the PtyRADSolver in {solver_mode} mode ###")
        vprint(" ")

        if self.if_hypertune:
            self.hypertune()
        else:
            self.reconstruct()

        end_t = time_sync()
        solver_t = end_t - start_t
        time_str = "" if solver_t < 60 else f", or {parse_sec_to_time_str(solver_t)}"
        vprint(f"### The PtyRADSolver is finished in {solver_t:.3f} sec{time_str} ###")
        vprint(" ")

        if self.logger is not None and self.logger.flush_file:
            self.logger.close()

        # End the process properly when in DDP mode
        if dist.is_initialized():
            dist.destroy_process_group()
[docs]
class IndicesDataset(Dataset):
    """
    Minimal map-style Dataset that hands out scan-position indices.

    Used specifically for the multiGPU (DDP) mode, where a DataLoader built on
    top of this dataset takes over the batching of the indices.

    Args:
        indices (sequence): The indices to serve. Stored as-is, without copying.
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        """Return how many indices the dataset holds."""
        n_items = len(self.indices)
        return n_items

    def __getitem__(self, idx):
        """Return the index stored at position `idx`."""
        stored_indices = self.indices
        return stored_indices[idx]
###### Reconstruction workflow related functions ######
# These are called within PtyRADSolver and from the detailed walkthrough notebook
[docs]
def create_optimizer(optimizer_params, optimizable_params, verbose=True):
    """
    Create a `torch.optim` optimizer, optionally restoring its state from a PtyRAD file.

    Args:
        optimizer_params (dict): Optimizer settings with keys
            'name' (str): name of an optimizer class in `torch.optim` (e.g. 'Adam', 'LBFGS').
            'configs' (dict or None): keyword arguments forwarded to the optimizer
                constructor. A missing or None entry is treated as an empty dict.
                The dict is copied, so it is never modified.
            'load_state' (str or None): path passed to `load_ptyrad`; its
                'optim_state_dict' entry is loaded into the new optimizer.
        optimizable_params (list of dict): Parameter groups, each with a 'params' list of
            tensors and an 'lr' entry.
        verbose (bool): Whether to print progress messages. Defaults to True.

    Returns:
        torch.optim.Optimizer: The constructed optimizer. If loading the saved state
        fails, a message is printed and the fresh optimizer is returned.

    Raises:
        ValueError: If `name` is not an attribute of `torch.optim`, or if 'LBFGS' is
            requested but every parameter group has a zero learning rate.
    """

    def _fix_optimizer_state_dict_format(optim_state_dict: dict) -> dict:
        """
        Fix HDF5-loaded optimizer state dict by:
        - Recovering integer keys (HDF5 forces strings as keys).
        - Converting param_groups from dicts back to list format, if needed.
        - Converting any remaining param indices to lists.

        Args:
            optim_state_dict (dict): Loaded optimizer state dict (e.g. from HDF5).

        Returns:
            dict: Fixed optimizer state dict.
        """
        fixed = {}
        for key, val in optim_state_dict.items():
            # If the value is a dict (like 'state'), fix its integer keys
            if isinstance(val, dict):
                fixed_val = {}
                for nested_key, nested_val in val.items():
                    try:
                        fixed_nested_key = int(nested_key)  # Convert '0', '1' etc. to 0, 1
                    except (ValueError, TypeError):
                        fixed_nested_key = nested_key  # Keep string keys as-is
                    fixed_val[fixed_nested_key] = nested_val
                fixed[key] = fixed_val
            else:
                fixed[key] = val

        # Fix param_groups format if it was accidentally stored as a dict
        if isinstance(fixed.get("param_groups"), dict):
            param_groups_dict = fixed["param_groups"]
            # Convert {0: {...}, 1: {...}} -> [{...}, {...}]
            fixed["param_groups"] = [
                param_groups_dict[k] for k in sorted(param_groups_dict, key=lambda x: int(x))
            ]

        # Ensure 'params' field is a list of ints, not tensors or ndarrays
        for group in fixed.get("param_groups", []):
            if isinstance(group.get("params"), (torch.Tensor, np.ndarray)):
                group["params"] = group["params"].tolist()

        return fixed

    # Extract the optimizer name and configs.
    # Copy the configs so the LBFGS 'lr' injection below never leaks back into the caller's dict.
    optimizer_name = optimizer_params['name']
    optimizer_configs = dict(optimizer_params.get('configs') or {})
    ptyrad_path = optimizer_params.get('load_state')
    vprint(f"### Creating PyTorch '{optimizer_name}' optimizer with configs = {optimizer_configs} ###", verbose=verbose)

    # Get the optimizer class from torch.optim
    optimizer_class = getattr(torch.optim, optimizer_name, None)
    if optimizer_class is None:
        raise ValueError(f"Optimizer '{optimizer_name}' is not supported.")

    if optimizer_name == 'LBFGS':
        vprint("Note: LBFGS optimizer is a quasi-Newton 2nd order optimizer that will run multiple forward passes (default: 20) for 1 update step", verbose=verbose)
        vprint("Note: LBFGS usually converges faster for convex problem with full-batch non-noisy gradients, but each update step is computationally slower", verbose=verbose)
        non_zero_lr = [p['lr'] for p in optimizable_params if p['lr'] != 0]
        if not non_zero_lr:
            raise ValueError("LBFGS optimizer requires at least one parameter group with a non-zero learning rate.")
        min_lr = min(non_zero_lr)
        optimizer_configs['lr'] = min_lr
        vprint(f"Note: LBFGS optimizer does not support per parameter learning rate so it'll be set to the minimal non-zero learning rate = {min_lr}", verbose=verbose)
        # LBFGS only takes 1 params group as an iterable, so flatten to the trainable tensors
        optimizable_params = [p['params'][0] for p in optimizable_params if p['params'][0].requires_grad]

    optimizer = optimizer_class(optimizable_params, **optimizer_configs)

    if isinstance(ptyrad_path, str):
        # Read the device from the constructed optimizer: for LBFGS, `optimizable_params` was
        # flattened to bare tensors above, so indexing it with 'params' would fail.
        device = optimizer.param_groups[0]['params'][0].device
        try:
            from ptyrad.load import load_ptyrad
            optim_state_dict = load_ptyrad(ptyrad_path)['optim_state_dict']
            optim_state_dict = _fix_optimizer_state_dict_format(optim_state_dict)
            # Convert 'state' to tensors on the right device, while 'param_groups' are kept as generic scalars/arrays/boolean/None/list of int
            optim_state_dict['state'] = ndarrays_to_tensors(optim_state_dict['state'], device=device)
            optimizer.load_state_dict(optim_state_dict)
            vprint(f"Loaded optimizer state from '{ptyrad_path}'", verbose=verbose)
        except Exception as e:
            # Deliberate best-effort: fall back to the freshly created optimizer
            vprint(f"Failed to load optimizer state from '{ptyrad_path}': {e}. Using fresh optimizer.", verbose=verbose)

    vprint(" ", verbose=verbose)
    return optimizer
[docs]
def prepare_recon(model, init, params):
    """
    Prepares the indices, batches, and output path for ptychographic reconstruction.

    This function parses the necessary parameters and generates the indices for scanning,
    creates batches based on the probe positions, and sets up the output directory for
    saving results. It also plots and saves a figure illustrating the grouping of probe
    positions.

    Args:
        model (PtychoAD): The ptychographic model containing the object, probe,
            probe positions, and other relevant parameters.
        init (Initializer): The initializer object containing the initialized variables
            needed for reconstruction.
        params (dict): A dictionary containing various parameters needed for the
            reconstruction process, including experimental parameters, loss parameters,
            constraint parameters, and reconstruction settings.

    Returns:
        tuple: A tuple containing the following:
            - indices (numpy.ndarray): Array of indices for scanning positions.
            - batches (list of numpy.ndarray): List of batches where each batch contains
              indices grouped according to the selected grouping mode.
            - output_path (str or None): The path to the directory where reconstruction
              results and figures will be saved, or None when `SAVE_ITERS` is None.
    """
    verbose = not params['recon_params']['if_quiet']
    vprint("### Generating indices, batches, and output_path ###", verbose=verbose)

    # Parse the variables
    init_variables = init.init_variables
    init_params = init.init_params  # These could be modified by Optuna, hence can be different from params['init_params']
    params_path = params.get('params_path')
    loss_params = params.get('loss_params')
    constraint_params = params.get('constraint_params')
    recon_params = params.get('recon_params')
    INDICES_MODE = recon_params['INDICES_MODE'].get("mode")
    subscan_slow = recon_params['INDICES_MODE'].get("subscan_slow")
    subscan_fast = recon_params['INDICES_MODE'].get("subscan_fast")
    GROUP_MODE = recon_params['GROUP_MODE']
    SAVE_ITERS = recon_params['SAVE_ITERS']
    batch_size = recon_params['BATCH_SIZE'].get("size")
    # Default to 1 (no accumulation), the same default recon_loop uses. Without it, a missing key
    # would make the effective-batch-size printout below fail on `batch_size * None`.
    grad_accumulation = recon_params['BATCH_SIZE'].get("grad_accumulation", 1)
    output_dir = recon_params['output_dir']
    recon_dir_affixes = recon_params['recon_dir_affixes']
    copy_params = recon_params['copy_params']
    # `hypertune_params` may be absent or explicitly null (None) in the params file
    if_hypertune = (params.get('hypertune_params') or {}).get('if_hypertune', False)

    # Generate the indices, batches, and fig_grouping
    pos = (model.crop_pos + model.opt_probe_pos_shifts).detach().cpu().numpy()
    probe_int = model.get_complex_probe_view().abs().pow(2).sum(0).detach().cpu().numpy()
    dx = init_variables['dx']
    d_out = get_blob_size(dx, probe_int, output='d90', verbose=verbose)  # d_out unit is in Ang
    indices = select_scan_indices(init_variables['N_scan_slow'], init_variables['N_scan_fast'], subscan_slow=subscan_slow, subscan_fast=subscan_fast, mode=INDICES_MODE, verbose=verbose)
    batches = make_batches(indices, pos, batch_size, mode=GROUP_MODE, seed=init_variables['random_seed'], verbose=verbose)
    fig_grouping = plot_pos_grouping(pos, batches, circle_diameter=d_out/dx, diameter_type='90%', dot_scale=1, show_fig=False, pass_fig=True)
    vprint(f"The effective batch size (i.e., how many probe positions are simultaneously used for 1 update of ptychographic parameters) is batch_size * grad_accumulation = {batch_size} * {grad_accumulation} = {batch_size*grad_accumulation}", verbose=verbose)

    # Create the output path, save fig_grouping, and copy params file
    if SAVE_ITERS is not None:
        output_path = make_output_folder(output_dir, indices, init_params, recon_params, model, constraint_params, loss_params, recon_dir_affixes, verbose=verbose)
        fig_grouping.savefig(safe_filename(output_path + "/summary_pos_grouping.png"))
        if copy_params and not if_hypertune:
            # Save params.yml to separate reconstruction folder for normal mode. Hypertune mode params copying is handled at hypertune()
            copy_params_to_dir(params_path, output_path, params, verbose=verbose)
    else:
        output_path = None

    plt.close(fig_grouping)
    vprint(" ", verbose=verbose)
    return indices, batches, output_path
[docs]
def select_scan_indices(N_scan_slow, N_scan_fast, subscan_slow=None, subscan_fast=None, mode='full', verbose=True):
    """
    Select the flattened (row-major) probe-position indices to reconstruct.

    Args:
        N_scan_slow (int): Number of scan positions along the slow scan direction.
        N_scan_fast (int): Number of scan positions along the fast scan direction.
        subscan_slow (int, optional): Number of positions kept along the slow direction
            for the 'center' and 'sub' modes. Defaults to N_scan_slow // 2.
        subscan_fast (int, optional): Number of positions kept along the fast direction
            for the 'center' and 'sub' modes. Defaults to N_scan_fast // 2.
        mode (str): 'full' selects every position; 'center' selects a centered rectangular
            block of subscan_slow x subscan_fast positions; 'sub' selects
            subscan_slow x subscan_fast positions evenly sub-sampled over the whole scan.
        verbose (bool): Whether to print progress messages.

    Returns:
        numpy.ndarray: 1D array of flattened indices (row * N_scan_fast + col) in row-major order.

    Raises:
        ValueError: If `mode` is not 'full', 'center', or 'sub', or if a subscan size is not
            between 1 and the number of scan positions in its direction.
    """
    N_scans = N_scan_slow * N_scan_fast
    vprint(f"Selecting indices with the '{mode}' mode ", verbose=verbose)

    # Generate flattened indices for the entire FOV
    if mode == 'full':
        return np.arange(N_scans)

    if mode not in ('center', 'sub'):
        raise ValueError(f"Indices selection mode {mode} not implemented, please use either 'full', 'center', or 'sub'")

    # Default any missing subscan size to half of the total scan in that direction
    if subscan_slow is None and subscan_fast is None:
        vprint("Subscan params are not provided, setting subscans to default as half of the total scan for both directions", verbose=verbose)
    elif subscan_slow is None or subscan_fast is None:
        vprint("One of the subscan params is not provided, setting it to default as half of the total scan in that direction", verbose=verbose)
    if subscan_slow is None:
        subscan_slow = N_scan_slow // 2
    if subscan_fast is None:
        subscan_fast = N_scan_fast // 2

    # Out-of-range sizes would otherwise silently yield negative or out-of-range indices
    for name, size, total in (('subscan_slow', subscan_slow, N_scan_slow), ('subscan_fast', subscan_fast, N_scan_fast)):
        if not 1 <= size <= total:
            raise ValueError(f"{name} = {size} must be between 1 and the number of scan positions in that direction ({total})")

    vprint(f"Choosing subscan with {(subscan_slow, subscan_fast)}", verbose=verbose)
    # 2D map of flattened indices laid out as (slow, fast), i.e. row-major
    full_indices = np.arange(N_scans).reshape(N_scan_slow, N_scan_fast)

    if mode == 'center':
        # Flattened indices of the centered rectangular region
        start_row = (N_scan_slow - subscan_slow) // 2
        start_col = (N_scan_fast - subscan_fast) // 2
        indices = full_indices[start_row:start_row + subscan_slow, start_col:start_col + subscan_fast].reshape(-1)
    else:  # 'sub'
        # Flattened indices of the entire FOV, sub-sampled on an even grid
        subscan_slow_id = np.linspace(0, N_scan_slow - 1, num=subscan_slow, dtype=int)
        subscan_fast_id = np.linspace(0, N_scan_fast - 1, num=subscan_fast, dtype=int)
        slow_grid, fast_grid = np.meshgrid(subscan_slow_id, subscan_fast_id, indexing='ij')
        indices = full_indices[slow_grid, fast_grid].reshape(-1)

    return indices
[docs]
def make_batches(indices, pos, batch_size, mode='random', seed=None, verbose=True):
    """
    Make mini-batches from input indices.

    Args:
        indices (numpy.ndarray): Integer array of shape (Ns,). `indices` could be a subset of
            all indices; it is not modified by this function.
        pos (numpy.ndarray): int/float array of shape (N, 2). Always pass in the full
            positions, because `indices` are used to index directly into it.
        batch_size (int): The target number of indices of each mini-batch.
        mode (str): Choose between 'random', 'compact', or 'sparse' grouping.
        seed (int, optional): Seed for the random permutation ('random') and for KMeans
            ('compact' / 'sparse').
        verbose (bool): Whether to print progress messages.

    Returns:
        list of numpy.ndarray: `num_batch` arrays [batch0, batch1, ...] of absolute indices,
        with ``num_batch = max(1, len(indices) // batch_size)``.

    Raises:
        ValueError: If `mode` is unknown, `indices` is empty or longer than `pos`, or an index
            lies outside ``[0, len(pos) - 1]``.

    Notes:
        - For 'random' grouping the actual batch size is only "close" to `batch_size` when
          len(indices) is not divisible by it.
        - For 'compact' or 'sparse', it generally fluctuates around the specified batch size.
        - 'sparse' can be quite slow for large scans (256x256 takes more than 10 min, and
          128x128 more than 1 min on a CPU). PtychoShelves automatically switches to 'random'
          for len(pos) > 1e3 and relies on the random statistics.
        - 'compact' and 'sparse' need scikit-learn; if it is missing, 'random' is used instead.
        - To check the correctness of each grouping, you may visualize the pos.
    """
    from time import time

    valid_modes = ('random', 'compact', 'sparse')
    if mode not in valid_modes:
        raise ValueError(f"Group mode '{mode}' not implemented, please use one of {valid_modes}")

    # scikit-learn is only needed for the clustering-based modes
    if mode != 'random':
        try:
            from sklearn.cluster import MiniBatchKMeans
        except ImportError as e:
            missing_package = str(e).split()[-1]
            vprint(f"### {missing_package} is not available, group mode set to 'random'. 'scikit-learn' is needed for 'sparse' and 'compact' ###")
            mode = 'random'

    if len(indices) == 0:
        raise ValueError("No indices were provided, check your indices generation params")
    if len(indices) > len(pos):
        raise ValueError(f"len(indices) = '{len(indices)}' is larger than total number of probe positions ({len(pos)}), check your indices generation params")
    # Valid absolute indices are 0 .. len(pos) - 1
    if indices.min() < 0 or indices.max() >= len(pos):
        raise ValueError(f"Indices must lie within [0, {len(pos) - 1}] but got min = '{indices.min()}' and max = '{indices.max()}', check your indices generation params")

    # Always make at least one batch, even when batch_size exceeds the number of indices
    num_batch = max(1, len(indices) // batch_size)
    t_start = time()

    if mode == 'random':
        rng = np.random.default_rng(seed=seed)
        shuffled_indices = rng.permutation(indices)  # This will make a shuffled copy
        random_batches = np.array_split(shuffled_indices, num_batch)
        vprint(f"Generated {num_batch} '{mode}' groups of ~{batch_size} scan positions in {time() - t_start:.3f} sec", verbose=verbose)
        return random_batches

    # Either 'compact' or 'sparse': cluster the selected positions with KMeans
    pos_s = pos[indices]
    kmeans = MiniBatchKMeans(init="k-means++", n_init=10, n_clusters=num_batch, max_iter=10, batch_size=3072, random_state=seed)
    kmeans.fit(pos_s)
    labels = kmeans.labels_

    # Separate data points into groups (mapping cluster members back to absolute indices)
    compact_batches = [indices[labels == batch_idx] for batch_idx in range(num_batch)]

    if mode == 'compact':
        vprint(f"Generated {num_batch} '{mode}' groups of ~{batch_size} scan positions in {time() - t_start:.3f} sec", verbose=verbose)
        return compact_batches

    # 'sparse' mode
    from scipy.spatial.distance import cdist

    sparse_indices = indices.copy()  # Copy so that elements can be deleted without touching `indices`
    sparse_batches = []

    # The centroids of the compact groups are naturally sparse, so use them as the seeds of the sparse groups
    centroids = np.array([np.mean(pos[cbatch], axis=0) for cbatch in compact_batches])
    pairwise_distances = cdist(pos, pos)  # Distances for ALL pos keep absolute indexing, no conversion needed
    used_indices = []  # Positions (within `indices`) used to seed the sparse groups

    # Find the indices closest to the centroids of compact groups; these are the initial point of each sparse group
    for batch_idx in range(num_batch):
        distances = np.linalg.norm(pos_s - centroids[batch_idx], axis=1)  # Only for selected pos (pos_s = pos[indices])
        closest_idx_s = np.argmin(distances)  # Position of the minimal distance within `indices`
        closest_idx = indices[closest_idx_s]  # Actual index closest to the centroid
        sparse_batches.append([closest_idx])
        used_indices.append(closest_idx_s)
    # Delete the seeds after the loop; deleting inside the loop would shift positions
    sparse_indices = np.delete(sparse_indices, used_indices)

    # Add each remaining point to the group whose nearest member is farthest away
    for idx in sparse_indices:
        min_distances = [np.min(pairwise_distances[batch, idx]) for batch in sparse_batches]
        max_group_index = np.argmax(min_distances)
        sparse_batches[max_group_index].append(idx)

    # Final check because this procedure is fairly complicated (sorted copies; `indices` is left untouched)
    flatten_indices = np.sort(np.concatenate(sparse_batches))
    assert np.array_equal(flatten_indices, np.sort(indices)), "Sorry, something went wrong with the sparse grouping, please try 'random' for now"
    vprint(f"Generated {num_batch} '{mode}' groups of ~{batch_size} scan positions in {time() - t_start:.3f} sec", verbose=verbose)

    # Final process to make batches a list of arrays
    return [np.array(batch) for batch in sparse_batches]
[docs]
def parse_torch_compile_configs(configs):
    """
    Convert user-facing compiler configs into keyword arguments for `torch.compile`.

    The params file defines ``enable: bool``, while `torch.compile` only accepts
    ``disable: bool``, so ``enable`` is translated into ``disable = not enable``.
    All other entries are passed through unchanged.

    Args:
        configs (dict or None): Compiler configs (e.g. ``recon_params['compiler_configs']``).
            None is treated as an empty dict.

    Returns:
        dict: A new dict suitable for ``torch.compile(fn, **configs)``. The input dict is
        not modified, so the caller's params keep the user-facing 'enable' key.
    """
    parsed = dict(configs) if configs is not None else {}
    if 'enable' in parsed:
        parsed['disable'] = not parsed.pop('enable')
    return parsed
[docs]
def recon_loop(model, init, params, optimizer, loss_fn, constraint_fn, indices, batches, output_path, acc=None):
    """
    Executes the iterative optimization loop for ptychographic reconstruction.

    This function performs the iterative reconstruction process by optimizing the model
    parameters over `NITER` iterations. During each iteration, it toggles which tensors
    require gradients, runs the (torch.compile'd) `recon_step`, and, every `SAVE_ITERS`
    iterations, saves intermediate results and a summary figure.

    Args:
        model (PtychoAD): The ptychographic model containing the parameters and variables
            to be optimized. May be wrapped (e.g. by DDP), in which case `model.module`
            is used for the model-level methods and attributes.
        init (Initializer): The initializer object containing the initialized variables
            needed for reconstruction.
        params (dict): A dictionary containing various parameters for the reconstruction
            process, including experimental parameters, source parameters, loss parameters,
            constraint parameters, and reconstruction settings.
        optimizer (torch.optim.Optimizer): The optimizer used to update the model parameters.
        loss_fn (CombinedLoss): The loss function object used to compute the loss during
            each iteration.
        constraint_fn (CombinedConstraint): The constraint function object applied during
            each iteration to enforce specific constraints on the model.
        indices (numpy.ndarray): Array of indices for scanning positions.
        batches (list of numpy.ndarray or DataLoader): Batches of indices grouped according
            to the selected grouping mode.
        output_path (str or None): The path to the directory where reconstruction results and
            figures will be saved.
        acc (optional): Accelerator object for multiGPU runs. When given, saving and plotting
            only happen on the main process (`acc.is_main_process`).

    Returns:
        None. Results are written to `output_path`; the final timing summary is computed
        from `model.iter_times`.
    """
    # Parse the variables
    init_variables = init.init_variables
    recon_params = params.get('recon_params')
    NITER = recon_params['NITER']
    SAVE_ITERS = recon_params['SAVE_ITERS']
    grad_accumulation = recon_params['BATCH_SIZE'].get("grad_accumulation", 1)
    selected_figs = recon_params['selected_figs']
    compiler_configs = parse_torch_compile_configs(recon_params['compiler_configs'])
    verbose = not recon_params['if_quiet']

    # Use the method on the wrapped model (DDP) if it exists
    model_instance = model.module if hasattr(model, "module") else model

    vprint("### Start the PtyRAD iterative ptycho reconstruction ###", verbose=verbose)

    # Optimization loop
    recon_step_compiled = None
    for niter in range(1, NITER + 1):

        # Toggle the grad calculation to enable or disable AD update on tensors at certain iterations
        toggle_grad_requires(model_instance, niter, verbose)

        # Apply torch.compile to `recon_step` at the requested iterations. compilation_iters is expected
        # to always contain niter=1; the `is None` guard keeps the loop safe even if it does not.
        if recon_step_compiled is None or niter in model_instance.compilation_iters:
            vprint(f"Setting up PyTorch compiler with {compiler_configs}", verbose=verbose)
            torch._dynamo.reset()
            recon_step_compiled = torch.compile(recon_step, **compiler_configs)

        batch_losses = recon_step_compiled(batches, grad_accumulation, model, optimizer, loss_fn, constraint_fn, niter, verbose=verbose, acc=acc)

        # Only log the main process
        if acc is None or acc.is_main_process:
            ## Saving intermediate results
            if SAVE_ITERS is not None and niter % SAVE_ITERS == 0:
                with torch.no_grad():
                    # Note that `params` stores the original params from the configuration file,
                    # while `model` contains the actual params that could be updated by meas_crop, meas_pad, or meas_resample
                    save_results(output_path, model_instance, params, optimizer, niter, indices, batch_losses)

                    ## Saving summary
                    plot_summary(output_path, model_instance, niter, indices, init_variables, selected_figs=selected_figs, show_fig=False, save_fig=True, verbose=verbose)

    vprint(f"### Finished {NITER} iterations, averaged iter_t = {np.mean(model_instance.iter_times):.5g} with std = {np.std(model_instance.iter_times):.3f} ###", verbose=verbose)
    vprint(" ", verbose=verbose)
[docs]
def recon_step(batches, grad_accumulation, model, optimizer, loss_fn, constraint_fn, niter, verbose=True, acc=None):
    """
    Performs one iteration (or step) of the ptychographic reconstruction in the optimization loop.

    This function executes a single iteration of the reconstruction process, including:
    - Computing the forward model to generate diffraction patterns.
    - Calculating the loss by comparing the modeled and measured diffraction patterns.
    - Performing a backward pass to compute gradients and update the model parameters using the optimizer.
    - Applying iteration-wise constraints after all batches are processed.

    Args:
        batches (list of numpy.ndarray): List of batches where each batch contains indices
            grouped according to the selected grouping mode.
        grad_accumulation (int): Number of mini-batches whose gradients are accumulated before
            each optimizer step. For LBFGS the shuffled mini-batches are split into
            ``max(1, len(batches) // grad_accumulation)`` groups, with one closure-based
            ``optimizer.step`` per group.
        model (PtychoAD): The ptychographic model containing the parameters and variables
            to be optimized. May be wrapped (e.g. by DDP); attributes and methods are then
            accessed on ``model.module``.
        optimizer (torch.optim.Optimizer): The optimizer used to update the model parameters.
        loss_fn (CombinedLoss): The loss function object used to compute the loss for each batch.
        constraint_fn (CombinedConstraint): The constraint function object applied after each iteration
            to enforce specific constraints on the model.
        niter (int): The current iteration number in the optimization loop.
        verbose (bool, optional): If True, prints progress information during the batch processing.
            Defaults to True.
        acc (optional): Accelerator-like object providing ``backward``, ``wait_for_everyone``
            (and ``autocast``, used by ``compute_loss``). If None, plain ``loss.backward()`` is used.
            Defaults to None.

    Returns:
        dict: ``batch_losses``, mapping each loss component name to the list of loss values
            collected for each (group of) batch(es) in the iteration.

    Side effects:
        Appends this iteration's total loss, wall time, slice thickness and mean object tilt to
        ``loss_iters``, ``iter_times``, ``dz_iters`` and ``avg_tilt_iters`` of the (unwrapped) model.
    """
    batch_losses = {name: [] for name in loss_fn.loss_params.keys()}
    start_iter_t = time_sync()

    # Use the method on the wrapped model (DDP) if it exists
    model_instance = model.module if hasattr(model, "module") else model

    # Run the iteration with closure for LBFGS optimizer
    if isinstance(optimizer, torch.optim.LBFGS):
        # Make nested list of batches for the closure with internal grad accumulation over mini-batches
        num_batch = len(batches)
        batch_indices = np.arange(num_batch)
        # Read the seed from the unwrapped model: a DDP wrapper does not forward custom attributes
        if model_instance.random_seed is not None:
            set_random_seed(seed=model_instance.random_seed + niter)  # This ensures batch_indices is different for each iter in a reproducible way
        np.random.shuffle(batch_indices)
        # np.array_split rejects 0 sections, which `num_batch // grad_accumulation` yields when
        # grad_accumulation > num_batch; fall back to a single group containing every batch.
        num_groups = max(1, num_batch // grad_accumulation)
        accu_batch_indices = np.array_split(batch_indices, num_groups)

        def closure(accu_batch_idx):
            """Forward/backward over one group of mini-batches; returns (mean group loss, last batch's loss terms)."""
            optimizer.zero_grad()
            total_loss = 0
            # Run grad accumulation inside the closure for LBFGS, note that each closure is ideally 1 full iter with grad_accu
            for batch_idx in accu_batch_idx:
                batch = batches[batch_idx]
                model_DP = model(batch)  # Forward pass is handled automatically by DDP, but methods/attributes should use the unwrapped model
                measured_DP = model_instance.get_measurements(batch)
                object_patches = model_instance._current_object_patches
                loss_batch, losses = loss_fn(model_DP, measured_DP, object_patches, model_instance.omode_occu)
                total_loss += loss_batch  # LBFGS uses the returned loss to perform the line-search so it's better to return the loss that's associated to all the batches
            total_loss = total_loss / len(accu_batch_idx)
            if acc is not None:
                acc.backward(total_loss)
            else:
                total_loss.backward()
            return total_loss, losses

        # Iterate through all accumulated batches. accu_batches = [[batch1],[batch2],[batch3]...], batches = [[accu_batches1],[accu_batches2],[accu_batches3]...]
        for accu_batch_idx in accu_batch_indices:
            # LBFGS calls the closure (possibly several times) and only uses the returned loss
            optimizer.step(lambda: closure(accu_batch_idx)[0])
            # This extra evaluation on accumulated batches is just to get the `losses` for logging purpose
            _, losses = closure(accu_batch_idx)
            optimizer.zero_grad()

            # Clear the model cache after the mini-batch
            model_instance.clear_cache()

            # Append losses and log batch progress
            if acc is not None:
                acc.wait_for_everyone()
            for loss_name, loss_value in zip(loss_fn.loss_params.keys(), losses):
                batch_losses[loss_name].append(loss_value.detach().cpu().numpy())

    # Start mini-batch optimization for all other optimizers doesn't require a closure
    else:
        optimizer.zero_grad()  # Since PyTorch 2.0 the default behavior is set_to_none=True for performance https://github.com/pytorch/pytorch/issues/92656
        for batch_idx, batch in enumerate(batches):
            start_batch_t = time_sync()

            # Compute forward pass and loss (wrapped in autocast if accelerate is enabled)
            loss_batch, losses = compute_loss(batch, model, model_instance, loss_fn, acc)

            # Normalize the `loss_batch` before populating the gradients
            # We only want to scale the `loss_batch` so the grad/update is scaled accordingly
            # while keeping `losses` to be batch-size-independent for logging purpose
            loss_batch = loss_batch / grad_accumulation

            # Perform backward pass
            if acc is not None:
                acc.backward(loss_batch)
            else:
                loss_batch.backward()

            # Perform the optimizer step when batch_idx + 1 is divisible by grad_accumulation or it's the last batch
            if (batch_idx + 1) % grad_accumulation == 0 or (batch_idx + 1) == len(batches):
                if acc is not None:
                    acc.wait_for_everyone()
                optimizer.step()
                optimizer.zero_grad()

            batch_t = time_sync() - start_batch_t

            # Clear the model cache after the mini-batch
            model_instance.clear_cache()

            # Append losses and log batch progress
            if acc is not None:
                acc.wait_for_everyone()
            for loss_name, loss_value in zip(loss_fn.loss_params.keys(), losses):
                batch_losses[loss_name].append(loss_value.detach().cpu().numpy())
            if batch_idx in np.linspace(0, len(batches) - 1, num=6, dtype=int):
                vprint(f"Done batch {batch_idx+1} with {len(batch)} indices ({batch[:5].tolist()}...) in {batch_t:.3f} sec", verbose=verbose)

    # Iteration-wise constraints are applied once all mini-batches have been processed
    constraint_fn(model_instance, niter)

    iter_t = time_sync() - start_iter_t
    model_instance.loss_iters.append((niter, loss_logger(batch_losses, niter, iter_t, verbose=verbose)))
    model_instance.iter_times.append(iter_t)
    model_instance.dz_iters.append((niter, model_instance.opt_slice_thickness.detach().cpu().numpy()))
    model_instance.avg_tilt_iters.append((niter, model_instance.opt_obj_tilts.detach().mean(0).cpu().numpy()))
    return batch_losses
[docs]
def toggle_grad_requires(model, niter, verbose=True):
    """
    Switch ``requires_grad`` on or off for every optimizable tensor at iteration ``niter``.

    A tensor is trainable when its start iteration has been reached and its (exclusive)
    end iteration has not. A start iteration of None keeps the tensor frozen; an end
    iteration of None never switches it off.

    Args:
        model: Unwrapped model exposing ``optimizable_tensors`` (name -> tensor) and the
            ``start_iter`` / ``end_iter`` mappings keyed by the same names.
        niter (int): The current iteration number.
        verbose (bool, optional): If True, prints the resulting flag for each tensor. Defaults to True.
    """
    vprint(" ", verbose=verbose)  # Blank separator line at the start of each iteration

    for name, tensor in model.optimizable_tensors.items():
        first_iter = model.start_iter.get(name)
        stop_iter = model.end_iter.get(name)

        has_started = first_iter is not None and niter >= first_iter
        has_stopped = stop_iter is not None and niter + 1 > stop_iter  # end_iter is exclusive
        is_active = has_started and not has_stopped

        tensor.requires_grad = is_active
        vprint(f"Iter: {niter}, {name}.requires_grad = {is_active}", verbose=verbose)
[docs]
def compute_loss(batch, model, model_instance, loss_fn, acc=None):
    """
    Run the forward model on ``batch`` and evaluate the loss.

    Args:
        batch: Indices of the probe positions in this mini-batch.
        model: The (possibly wrapped) model; it is called to produce the modeled patterns.
        model_instance: The unwrapped model, used for measurements, the cached object
            patches and ``omode_occu``.
        loss_fn: Callable ``(model_DP, measured_DP, object_patches, omode_occu) -> (loss, losses)``.
        acc (optional): Accelerator-like object; when given, the computation runs inside
            ``acc.autocast()``. Defaults to None.

    Returns:
        tuple: ``(loss_batch, losses)`` as returned by ``loss_fn``.
    """
    def _forward_and_loss():
        # Call the wrapped model for the forward pass, the unwrapped one for everything else
        predicted = model(batch)
        measured = model_instance.get_measurements(batch)
        patches = model_instance._current_object_patches
        return loss_fn(predicted, measured, patches, model_instance.omode_occu)

    if acc is None:
        loss_batch, losses = _forward_and_loss()
    else:
        with acc.autocast():
            loss_batch, losses = _forward_and_loss()
    return loss_batch, losses
[docs]
@torch.compiler.disable
def loss_logger(batch_losses, niter, iter_t, verbose=True):
    """
    Summarize and print the losses of one reconstruction iteration.

    Each loss component is averaged over the batches of the iteration; the total is the
    sum of those averages.

    Args:
        batch_losses (dict): Loss component name -> list of per-batch loss values.
        niter (int): The current iteration number.
        iter_t (float): Wall time of the iteration in seconds.
        verbose (bool, optional): If True, prints the summary line. Defaults to True.

    Returns:
        float: Sum of the per-component average losses for this iteration.
    """
    avg_losses = {}
    for loss_name, loss_values in batch_losses.items():
        avg_losses[loss_name] = np.mean(loss_values)

    loss_iter = sum(avg_losses.values())
    loss_str = ', '.join(f"{loss_name}: {avg:.4f}" for loss_name, avg in avg_losses.items())
    vprint(f"Iter: {niter}, Total Loss: {loss_iter:.4f}, {loss_str}, in {parse_sec_to_time_str(iter_t)}", verbose=verbose)
    return loss_iter
###### Hypertune / Optuna related functions ######
# These are called inside PtyRADSolver.hypertune
[docs]
def create_optuna_sampler(sampler_params, verbose=True):
    """
    Create an Optuna sampler from a ``{'name': ..., 'configs': ...}`` specification.

    Args:
        sampler_params (dict): ``'name'`` is the class name of a sampler in ``optuna.samplers``
            (e.g. ``'TPESampler'``); ``'configs'`` (optional, may be None) holds the keyword
            arguments passed to that class's constructor.
        verbose (bool, optional): If True, prints the sampler being created. Defaults to True.

    Returns:
        The instantiated Optuna sampler.

    Raises:
        ValueError: If ``name`` is not an attribute of ``optuna.samplers`` or is
            ``'PartialFixedSampler'``.
    """
    # Note that this function supports all Optuna samplers except "PartialFixedSampler" because it requires a sequential sampler setup
    # Different samplers have different available configurations so please refer to https://optuna.readthedocs.io/en/stable/reference/samplers/index.html for more details
    # For example, GridSampler would need to pass in the 'search_space' so you need to explicitly specify every target variable range in 'sampler_params' : {'name': GridSampler, 'configs': {'search_space': {'optimizer': ['Adam', 'AdamW', 'RMSprop'], 'batch_size': [16,24,32,64,128,256,512], 'oalr': [1.0e-4, 1.0e-3, 1.0e-2], 'oplr': [1.0e-4, 1.0e-3, 1.0e-2]}}}
    # Also the GridSampler would only use the defined search_space and will ignore the range/step setup in 'tune_params'.
    # A handy usage of GridSampler is to exhaust some combination of reconstruction parameters
    # The recommended setup for PtyRAD is `sampler_params = {'name': 'TPESampler', 'configs': {'multivariate':True, 'group':True, 'constant_liar':True}}`
    import optuna

    # Extract the sampler name and configs
    sampler_name = sampler_params['name']
    sampler_configs = sampler_params.get('configs') or {}  # if "None" is provided or missing, it'll default an empty dict {}
    vprint(f"### Creating Optuna '{sampler_name}' sampler with configs = {sampler_configs} ###", verbose=verbose)

    # Get the sampler class from optuna.samplers
    sampler_class = getattr(optuna.samplers, sampler_name, None)
    # PartialFixedSampler needs a base sampler and a sequential setup, so it is explicitly rejected
    if sampler_class is None or sampler_name == 'PartialFixedSampler':
        raise ValueError(f"Optuna sampler '{sampler_name}' is not supported.")

    sampler = sampler_class(**sampler_configs)
    vprint(" ", verbose=verbose)
    return sampler
[docs]
def create_optuna_pruner(pruner_params, verbose=True):
    """
    Create an Optuna pruner from a ``{'name': ..., 'configs': ...}`` specification.

    Args:
        pruner_params (dict or None): ``'name'`` is the class name of a pruner in
            ``optuna.pruners`` (e.g. ``'HyperbandPruner'``); ``'configs'`` (optional, may be None)
            holds the constructor keyword arguments. For ``'PatientPruner'`` the configs must also
            contain ``'wrapped_pruner_configs'``, itself a pruner specification of the same form.
            If ``pruner_params`` is None, no pruner is created.
        verbose (bool, optional): If True, prints the pruner(s) being created. Defaults to True.

    Returns:
        The instantiated Optuna pruner, or None when ``pruner_params`` is None.

    Raises:
        ValueError: If the name is not an attribute of ``optuna.pruners``, or is
            ``'WilcoxonPruner'`` or ``'NopPruner'``.
        KeyError: If a ``'PatientPruner'`` spec lacks ``'wrapped_pruner_configs'``.
    """
    # Note that this function supports all Optuna pruners except "WilcoxonPruner" because it requires a nested evaluation setup
    # Different pruners have different available configurations so please refer to https://optuna.readthedocs.io/en/stable/reference/pruners.html for more details
    # PatientPruner and PercentilePruner have required fields that need to be passed in with 'configs'
    # For PatientPruner that wraps around a base pruner, you need to specify the base pruner name and configs in a nested way
    # pruner_params = {'name': 'PatientPruner',
    #                  'configs': {'patience': 1,
    #                              'wrapped_pruner_configs':{'name': 'MedianPruner',
    #                                                        'configs': {}}}}
    # If you're testing pruner with some other objective function, note that the objective function must contain iterative steps for you to prune (early termination)
    # The recommendation setup for PtyRAD is `pruner_params = {'name': 'HyperbandPruner', 'configs': {'min_resource': 5, 'reduction_factor': 2}}`
    if pruner_params is None:
        return None

    import optuna

    # Extract the pruner name and configs.
    # Copy the configs so popping 'wrapped_pruner_configs' below never mutates the caller's dict
    # (otherwise a second call with the same params would fail with a KeyError).
    pruner_name = pruner_params['name']
    pruner_configs = dict(pruner_params.get('configs') or {})  # if "None" is provided or missing, it'll default an empty dict {}
    vprint(f"### Creating Optuna '{pruner_name}' pruner with configs = {pruner_configs} ###", verbose=verbose)

    # Get the pruner class from optuna.pruners
    pruner_class = getattr(optuna.pruners, pruner_name, None)
    if pruner_class is None or pruner_name == 'WilcoxonPruner':
        raise ValueError(f"Optuna pruner '{pruner_name}' is not supported.")
    elif pruner_name == 'NopPruner':
        raise ValueError("Optuna NopPruner is an empty pruner, please set pruner_params = None if you don't want to prune.")
    elif pruner_name == 'PatientPruner':
        # PatientPruner takes the wrapped pruner positionally; its remaining configs are keyword args
        wrapped_pruner = create_optuna_pruner(pruner_configs['wrapped_pruner_configs'], verbose=verbose)
        pruner_configs.pop('wrapped_pruner_configs', None)
        pruner = pruner_class(wrapped_pruner, **pruner_configs)
    else:
        pruner = pruner_class(**pruner_configs)
    vprint(" ", verbose=verbose)
    return pruner
# Major Optuna routine
[docs]
def optuna_objective(trial, params, init, loss_fn, constraint_fn, device='cuda', verbose=False):
    """
    Objective function for Optuna hyperparameter tuning in ptychographic reconstruction.

    This function is used by Optuna to optimize the hyperparameters of the ptychographic reconstruction
    process. The function updates the reconstruction parameters based on the trial's suggestions and
    runs the reconstruction loop to evaluate the performance. The function also implements Optuna's
    pruning mechanism to stop unpromising trials early.

    Args:
        trial (optuna.trial.Trial): A trial object that suggests hyperparameter values and handles
            pruning.
        params (dict): A dictionary containing all the parameters for the reconstruction, including
            experimental parameters, model parameters, and hyperparameter tuning configurations.
            It is deep-copied, so the caller's dict is not modified.
        init (Initializer): An instance of the Initializer class that holds initialized variables
            and methods for updating them based on the trial's suggestions. Note that ``init`` is
            modified in place (its ``init_params`` / ``init_variables`` receive this trial's values).
        loss_fn (CombinedLoss): The loss function object that calculates the reconstruction loss.
        constraint_fn (CombinedConstraint): The constraint function object that applies constraints
            during optimization.
        device (str, optional): The device to run the reconstruction on, e.g., 'cuda'. Defaults to 'cuda'.
        verbose (bool, optional): If True, enables verbose output. Defaults to False.

    Returns:
        float: The trial error from ``compute_optuna_error`` using ``hypertune_params['error_metric']``
            ('loss': last iteration's total loss, or 'contrast': negated object-phase contrast),
            evaluated after the final iteration. Optuna minimizes this value.

    Raises:
        optuna.exceptions.TrialPruned: Raised when the trial should be pruned based on the
            intermediate results.
    """
    import optuna

    init.verbose = verbose  # This would affect the initialization printing for each hypertune trials
    params = deepcopy(params)

    # Parse the recon_params
    recon_params = params.get('recon_params')
    NITER = recon_params['NITER']
    SAVE_ITERS = recon_params['SAVE_ITERS']
    grad_accumulation = recon_params['BATCH_SIZE'].get("grad_accumulation", 1)
    output_dir = recon_params['output_dir']
    selected_figs = recon_params['selected_figs']
    compiler_configs = parse_torch_compile_configs(recon_params['compiler_configs'])

    # Parse the hypertune_params
    hypertune_params = params['hypertune_params']
    collate_results = hypertune_params['collate_results']
    append_params = hypertune_params['append_params']
    error_metric = hypertune_params['error_metric']
    tune_params = hypertune_params['tune_params']
    trial_id = 't' + str(trial.number).zfill(4)
    params['recon_params']['prefix'] += trial_id

    ## Currently only re-initialize the required parts for performance, but once there're too many correlated params need to be re-initialized,
    ## we might put the entire initialization inside optuna_objective for readability, although init_measurements for every trial would be a large overhead.
    ## TODO After the refactoring of `init_calibration` and better dx setting logic, it's possible to include more optimizable params without exploding the logic here

    # Batch size
    if tune_params['batch_size']['state']:
        vname = 'batch_size'
        vparams = tune_params[vname]
        params['recon_params']['BATCH_SIZE']['size'] = get_optuna_suggest(trial, vparams['suggest'], vname, vparams['kwargs'])

    # Optimizer
    if tune_params['optimizer']['state']:
        vname = 'optimizer'
        vparams = tune_params[vname]
        # `optim_configs` is stored next to the suggest kwargs (e.g. 'choices') but is not an argument of
        # `trial.suggest_categorical`, so it must be stripped before the kwargs are forwarded.
        optim_configs = vparams['kwargs'].get('optim_configs') or {}
        suggest_kwargs = {key: value for key, value in vparams['kwargs'].items() if key != 'optim_configs'}
        optim_name = get_optuna_suggest(trial, vparams['suggest'], vname, suggest_kwargs)
        params['model_params']['optimizer_params']['name'] = optim_name
        params['model_params']['optimizer_params']['configs'] = optim_configs.get(optim_name, {})  # Update optimizer_configs if the user has specified them for each optimizer

    # learning rates
    lr_to_tensor = {'plr': 'probe', 'oalr': 'obja', 'oplr': 'objp', 'slr': 'probe_pos_shifts', 'tlr': 'obj_tilts', 'dzlr': 'slice_thickness'}
    for vname in ['plr', 'oalr', 'oplr', 'slr', 'tlr', 'dzlr']:
        if tune_params[vname]['state']:
            vparams = tune_params[vname]
            params['model_params']['update_params'][lr_to_tensor[vname]]['lr'] = get_optuna_suggest(trial, vparams['suggest'], vname, vparams['kwargs'])

    # dx (calibration)
    if tune_params['dx']['state']:
        vname = 'dx'
        vparams = tune_params[vname]
        init.init_params['meas_calib'] = {'mode': vname, 'value': get_optuna_suggest(trial, vparams['suggest'], vname, vparams['kwargs'])}
        init.init_calibration()
        init.set_variables_dict()
        init.init_probe()
        init.init_pos()
        init.init_obj()
        init.init_H()

    # probe_params (pmode_max, conv_angle, defocus, z_shift, c3, c5)
    remake_probe = False
    for vname in ['pmode_max', 'conv_angle', 'defocus', 'z_shift', 'c3', 'c5']:
        if tune_params[vname]['state']:
            vparams = tune_params[vname]
            init.init_params['probe_' + vname] = get_optuna_suggest(trial, vparams['suggest'], vname, vparams['kwargs'])
            remake_probe = True
    if remake_probe:
        init.init_probe()

    # Nlayer
    if tune_params['Nlayer']['state']:
        vname = 'Nlayer'
        vparams = tune_params[vname]
        init.init_params['obj_Nlayer'] = get_optuna_suggest(trial, vparams['suggest'], vname, vparams['kwargs'])
        init.init_obj()

    # slice_thickness
    if tune_params['dz']['state']:
        vname = 'dz'
        vparams = tune_params[vname]
        init.init_params['obj_slice_thickness'] = get_optuna_suggest(trial, vparams['suggest'], vname, vparams['kwargs'])
        init.set_variables_dict()
        init.init_obj()  # Currently the slice_thickness only modifies the printed obj_extent value, but eventually we'll add obj resampling so let's keep it for now
        init.init_H()

    # scan_affine
    scan_affine = []
    scan_affine_init = params['init_params']['pos_scan_affine']
    if scan_affine_init is not None:
        default_affine = {'scale': scan_affine_init[0], 'asymmetry': scan_affine_init[1], 'rotation': scan_affine_init[2], 'shear': scan_affine_init[3]}
    else:
        default_affine = {'scale': 1, 'asymmetry': 0, 'rotation': 0, 'shear': 0}
    for vname in ['scale', 'asymmetry', 'rotation', 'shear']:
        if tune_params[vname]['state']:
            vparams = tune_params[vname]
            scan_affine.append(get_optuna_suggest(trial, vparams['suggest'], vname, vparams['kwargs']))
        else:
            scan_affine.append(default_affine[vname])
    if scan_affine != [1, 0, 0, 0]:
        init.init_params['pos_scan_affine'] = scan_affine
        init.init_pos()
        init.init_obj()  # Update obj initialization because the scan range has changed

    # tilt (This will override the current tilts and force it to be a global tilt (2,1))
    obj_tilts = []
    for vname in ['tilt_y', 'tilt_x']:
        if tune_params[vname]['state']:
            vparams = tune_params[vname]
            obj_tilts.append(get_optuna_suggest(trial, vparams['suggest'], vname, vparams['kwargs']))
        else:
            obj_tilts.append(0)
    obj_tilts = [obj_tilts]  # Make it into [[tilt_y, tilt_x]]
    if obj_tilts != [[0, 0]]:
        init.init_variables['obj_tilts'] = obj_tilts  # No need to update init_params['tilt_params'] because the pass-in value is only used when `tilt_params = 'custom'`

    # Create the model and optimizer, prepare indices, batches, and output_path
    model = PtychoAD(init.init_variables, params['model_params'], device=device, verbose=verbose)
    optimizer = create_optimizer(model.optimizer_params, model.optimizable_params, verbose=verbose)
    indices, batches, output_path = prepare_recon(model, init, params)

    # Optimization loop
    for niter in range(1, NITER + 1):

        # Toggle the grad calculation to enable or disable AD update on tensors at certain iterations
        toggle_grad_requires(model, niter, verbose)

        # Apply torch.compile to `recon_step`
        if niter in model.compilation_iters:  # compilation_iters always contain niter=1
            vprint(f"Setting up PyTorch compiler with {compiler_configs}", verbose=verbose)
            torch._dynamo.reset()
            recon_step_compiled = torch.compile(recon_step, **compiler_configs)

        if model.random_seed is not None:
            set_random_seed(seed=model.random_seed + niter)  # This ensures the batches order are different for each iter in a reproducible way
        shuffle(batches)
        batch_losses = recon_step_compiled(batches, grad_accumulation, model, optimizer, loss_fn, constraint_fn, niter, verbose=verbose)

        ## Saving intermediate results
        if SAVE_ITERS is not None and niter % SAVE_ITERS == 0:
            save_results(output_path, model, params, optimizer, niter, indices, batch_losses, collate_str='')
            plot_summary(output_path, model, niter, indices, init.init_variables, selected_figs=selected_figs, collate_str='', show_fig=False, save_fig=True, verbose=verbose)

        ## Pruning logic for optuna
        if hypertune_params['pruner_params'] is not None:
            optuna_error = compute_optuna_error(model, indices, error_metric)
            trial.report(optuna_error, niter)

            # Handle pruning based on the intermediate value.
            if trial.should_prune():
                # Save the current results of the pruned trials
                params_str = parse_hypertune_params_to_str(trial.params) if append_params else ''
                collate_str = f"_error_{optuna_error:.5f}_{trial_id}{params_str}"
                if collate_results:
                    save_results(output_dir, model, params, optimizer, niter, indices, batch_losses, collate_str=collate_str)
                    plot_summary(output_dir, model, niter, indices, init.init_variables, selected_figs=selected_figs, collate_str=collate_str, show_fig=False, save_fig=True, verbose=verbose)
                raise optuna.exceptions.TrialPruned()

    ## Final optuna_error evaluation (only needed if pruner never ran)
    if hypertune_params['pruner_params'] is None:
        optuna_error = compute_optuna_error(model, indices, error_metric)

    ## Saving collate results and figs of the finished trials
    params_str = parse_hypertune_params_to_str(trial.params) if append_params else ''
    collate_str = f"_error_{optuna_error:.5f}_{trial_id}{params_str}"
    if collate_results:
        save_results(output_dir, model, params, optimizer, niter, indices, batch_losses, collate_str=collate_str)
        plot_summary(output_dir, model, niter, indices, init.init_variables, selected_figs=selected_figs, collate_str=collate_str, show_fig=False, save_fig=True, verbose=verbose)

    vprint(f"### Finished {NITER} iterations, averaged iter_t = {np.mean(model.iter_times):.3g} sec ###", verbose=verbose)
    vprint(" ", verbose=verbose)
    return optuna_error
[docs]
def get_optuna_suggest(trial, suggest, name, kwargs):
    """
    Draw one hyperparameter value from an Optuna trial.

    Args:
        trial (optuna.trial.Trial): The trial that suggests the value.
        suggest (str): ``'cat'``, ``'int'`` or ``'float'``, selecting ``trial.suggest_categorical``,
            ``trial.suggest_int`` or ``trial.suggest_float`` respectively.
        name (str): Parameter name registered with the trial.
        kwargs (dict): Keyword arguments forwarded unchanged to the selected suggest method.

    Returns:
        The value suggested by the trial.

    Raises:
        ValueError: If ``suggest`` is not one of the supported methods.
    """
    if suggest == 'cat':
        return trial.suggest_categorical(name, **kwargs)
    elif suggest == 'int':
        return trial.suggest_int(name, **kwargs)
    elif suggest == 'float':
        return trial.suggest_float(name, **kwargs)
    # Raising a bare string is itself a TypeError in Python 3, so raise a real exception
    raise ValueError(f"Optuna trial.suggest method '{suggest}' is not supported.")
[docs]
def compute_optuna_error(model, indices, metric):
    """
    Compute the scalar error that Optuna minimizes for the current reconstruction state.

    Args:
        model: The reconstruction model. For ``metric='loss'`` only ``model.loss_iters`` is read:
            the last element of its last entry, i.e. the most recently logged iteration loss.
        indices: Indices forwarded to ``get_objp_contrast`` when ``metric='contrast'``.
        metric (str): ``'contrast'`` or ``'loss'``.

    Returns:
        The negated object-phase contrast for ``'contrast'`` (negated so that minimizing the
        error maximizes contrast), or the latest loss value for ``'loss'``.

    Raises:
        ValueError: If ``metric`` is neither ``'contrast'`` nor ``'loss'``.
    """
    if metric == 'loss':
        latest_entry = model.loss_iters[-1]
        return latest_entry[-1]
    if metric == 'contrast':
        contrast = get_objp_contrast(model, indices)
        return -1 * contrast  # Negative for minimization
    raise ValueError(f"Unsupported hypertune error metric: '{metric}'. Expected 'contrast' or 'loss'.")