ptyrad.reconstruction


Reconstruction and hypertune (hyperparameter tuning) workflows for ptychography.

Classes

IndicesDataset(indices)

Dataset class used specifically in multi-GPU mode with DDP (DistributedDataParallel).
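Under DDP, a dataset is usually split across ranks by `torch.utils.data.distributed.DistributedSampler`, so a dataset that only holds scan indices is enough for each rank to know which probe positions it should process. The sketch below is a minimal, hypothetical stand-in (`IndicesDatasetSketch` is not part of ptyrad) that assumes IndicesDataset does little more than serve indices. It implements the map-style dataset protocol (`__len__` plus `__getitem__`) without importing torch; real code would subclass `torch.utils.data.Dataset`.

```python
from typing import List, Sequence


class IndicesDatasetSketch:
    """Hypothetical map-style dataset whose items are just scan indices.

    In real code this would subclass torch.utils.data.Dataset; here it only
    implements the __len__/__getitem__ protocol that map-style datasets use.
    """

    def __init__(self, indices: Sequence[int]) -> None:
        self.indices: List[int] = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> int:
        return self.indices[i]


dataset = IndicesDatasetSketch(range(10, 20))
print(len(dataset), dataset[0], dataset[-1])
```

Passing such a dataset to a `DataLoader` together with `sampler=DistributedSampler(dataset)` would give each rank a roughly equal share of the indices per epoch (the sampler may repeat a few indices to even out the split).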

PtyRADSolver(params[, device, seed, acc, logger])

A wrapper class to perform ptychographic reconstruction or hyperparameter tuning.

Functions

compute_loss(batch, model, model_instance, ...)

Compute the model output and loss, with optional support for accelerate's autocast.

compute_optuna_error(model, indices, metric)

Helper function to compute the current error for Optuna.

create_optimizer(optimizer_params, ...[, ...])

create_optuna_pruner(pruner_params[, verbose])

create_optuna_sampler(sampler_params[, verbose])

get_optuna_suggest(trial, suggest, name, kwargs)

loss_logger(batch_losses, niter, iter_t[, ...])

Logs and summarizes the loss values for an iteration during the ptychographic reconstruction.
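As an illustration of the kind of per-iteration summary described above, the sketch below averages each loss term over the batches of one iteration and reports the total together with the iteration time. It is a hypothetical helper (`summarize_losses` and the loss names are assumptions, not ptyrad's actual loss keys or log format), and it assumes `batch_losses` maps each loss name to a list of per-batch values.

```python
from statistics import mean
from typing import Dict, List


def summarize_losses(batch_losses: Dict[str, List[float]], niter: int, iter_t: float) -> str:
    """Hypothetical summary: average each loss term over batches and report the total."""
    avg = {name: mean(values) for name, values in batch_losses.items()}
    total = sum(avg.values())
    terms = ", ".join(f"{name} = {value:.4g}" for name, value in avg.items())
    return f"Iter {niter}: total loss = {total:.4g} ({terms}), time = {iter_t:.2f} s"


# Hypothetical loss names and values for a single iteration with two batches.
losses = {"data_loss": [0.5, 0.3], "sparsity_loss": [0.1, 0.1]}
print(summarize_losses(losses, niter=3, iter_t=1.234))
```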

make_batches(indices, pos, batch_size[, ...])

Makes batches from the input indices.
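A minimal sketch of random mini-batching over scan indices follows. It is hypothetical (`make_random_batches` is not ptyrad's function) and deliberately ignores the `pos` argument, which the real `make_batches` presumably uses for position-aware batch selection; that behavior is not reproduced here. A seeded `random.Random` keeps the shuffle reproducible.

```python
import random
from typing import List, Optional, Sequence


def make_random_batches(indices: Sequence[int], batch_size: int,
                        seed: Optional[int] = None) -> List[List[int]]:
    """Shuffle scan indices and split them into consecutive batches.

    The last batch may be smaller than batch_size.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)  # local RNG, so global random state is untouched
    return [shuffled[i:i + batch_size] for i in range(0, len(shuffled), batch_size)]


batches = make_random_batches(range(10), batch_size=4, seed=0)
print([len(b) for b in batches])
```

Every index appears exactly once across the batches, so one pass over all batches visits each probe position once per iteration.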

optuna_objective(trial, params, init, ...[, ...])

Objective function for Optuna hyperparameter tuning in ptychographic reconstruction.
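The sketch below shows the general shape of an Optuna objective: sample each hyperparameter from the trial, run a reconstruction, and return an error to minimize. The dispatch helper mirrors what a function like `get_optuna_suggest(trial, suggest, name, kwargs)` might do by calling Optuna's `trial.suggest_float`, `trial.suggest_int`, or `trial.suggest_categorical`, but `suggest_value`, `make_objective`, the search-space layout, and `toy_error` are all hypothetical; `toy_error` stands in for an actual reconstruction. Optuna itself is not imported, so the trial is duck-typed.

```python
from typing import Any, Callable, Dict


def suggest_value(trial: Any, suggest: str, name: str, kwargs: Dict[str, Any]) -> Any:
    """Dispatch to trial.suggest_<suggest>, e.g. 'float' -> trial.suggest_float."""
    method = getattr(trial, f"suggest_{suggest}", None)
    if method is None:
        raise ValueError(f"Unsupported suggest type: {suggest!r}")
    return method(name, **kwargs)


def make_objective(search_space: Dict[str, Dict[str, Any]],
                   run_recon: Callable[[Dict[str, Any]], float]) -> Callable[[Any], float]:
    """Build an Optuna-style objective: sample each hyperparameter, run, return the error."""
    def objective(trial: Any) -> float:
        sampled = {
            name: suggest_value(trial, spec["suggest"], name, spec["kwargs"])
            for name, spec in search_space.items()
        }
        return run_recon(sampled)  # lower is better, so the study should minimize
    return objective


# Hypothetical search space; kwargs follow Optuna's suggest_float/suggest_int signatures.
search_space = {
    "lr": {"suggest": "float", "kwargs": {"low": 1e-4, "high": 1e-2, "log": True}},
    "batch_size": {"suggest": "int", "kwargs": {"low": 16, "high": 128, "step": 16}},
}


def toy_error(p: Dict[str, Any]) -> float:
    """Toy stand-in for running a short reconstruction and measuring its error."""
    return (p["lr"] - 1e-3) ** 2 + p["batch_size"] / 1000


objective = make_objective(search_space, toy_error)
```

With Optuna installed, `optuna.create_study(direction="minimize")` followed by `study.optimize(objective, n_trials=20)` would drive this objective.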

parse_torch_compile_configs(configs)

Converts the user-facing CompilerConfigs into a dict suitable for torch.compile.
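A minimal sketch of such a conversion, assuming the config is a dataclass whose fields mirror `torch.compile` keyword arguments (`fullgraph`, `dynamic`, `backend`, `mode`, `options`) plus a UI-only on/off switch. `CompilerConfigsSketch`, its `enable` field, and `to_compile_kwargs` are hypothetical; ptyrad's actual CompilerConfigs fields may differ. Unset (`None`) entries are dropped so that `torch.compile` falls back to its own defaults.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class CompilerConfigsSketch:
    """Hypothetical stand-in for the user-facing CompilerConfigs."""
    enable: bool = False              # UI-only switch, not a torch.compile argument
    fullgraph: bool = False
    dynamic: Optional[bool] = None
    backend: str = "inductor"
    mode: Optional[str] = None
    options: Optional[Dict[str, Any]] = None


def to_compile_kwargs(configs: CompilerConfigsSketch) -> Dict[str, Any]:
    """Drop the 'enable' flag and unset entries, keeping torch.compile keyword arguments."""
    kwargs = asdict(configs)
    kwargs.pop("enable")
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    # torch.compile refuses mode and options together, so fail early with a clear message.
    if "mode" in kwargs and "options" in kwargs:
        raise ValueError("torch.compile accepts either 'mode' or 'options', not both")
    return kwargs


cfg = CompilerConfigsSketch(enable=True, mode="max-autotune")
print(to_compile_kwargs(cfg))
```

The result would then be used as `torch.compile(model, **to_compile_kwargs(cfg))` when `cfg.enable` is true.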

prepare_recon(model, init, params)

Prepares the indices, batches, and output path for ptychographic reconstruction.

recon_loop(model, init, params, optimizer, ...)

Executes the iterative optimization loop for ptychographic reconstruction.

recon_step(batches, grad_accumulation, ...)

Performs one iteration (or step) of the ptychographic reconstruction in the optimization loop.
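The `grad_accumulation` parameter suggests that one step can sum gradients over several batches before updating. In PyTorch this is usually done by letting successive `loss.backward()` calls add into `.grad` and calling `optimizer.step()` and `optimizer.zero_grad()` only every N batches. The dependency-free toy below illustrates that pattern on a one-parameter least-squares problem; `toy_recon_step` and `batch_grad` are hypothetical, and whether ptyrad flushes a leftover partial accumulation at the end of an iteration is an assumption.

```python
from typing import List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def batch_grad(w: float, batch: Batch) -> float:
    """Gradient of the mean squared error mean((w*x - y)**2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def toy_recon_step(w: float, batches: List[Batch], grad_accumulation: int, lr: float) -> float:
    """One iteration over all batches, updating w once every `grad_accumulation` batches."""
    acc, n_acc = 0.0, 0
    for batch in batches:
        acc += batch_grad(w, batch) / grad_accumulation  # scaled so a full group averages
        n_acc += 1
        if n_acc == grad_accumulation:
            w -= lr * acc           # the "optimizer.step()" of this toy
            acc, n_acc = 0.0, 0     # the "optimizer.zero_grad()" of this toy
    if n_acc:  # assumption: flush a partial group, rescaled to its own average
        w -= lr * acc * grad_accumulation / n_acc
    return w


data = [(x, 3.0 * x) for x in (0.5, 1.0, 1.5, 2.0, 2.5, 3.0)]  # true slope is 3
batches = [data[i:i + 2] for i in range(0, len(data), 2)]
w = 0.0
for _ in range(50):
    w = toy_recon_step(w, batches, grad_accumulation=2, lr=0.05)
print(round(w, 6))
```

Accumulating over N batches behaves like a larger effective batch while only one batch needs to be held in memory at a time.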

select_scan_indices(N_scan_slow, N_scan_fast)

toggle_grad_requires(model, niter[, verbose])

Toggle requires_grad for each optimizable tensor based on its start and end iterations.
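The sketch below illustrates iteration-window toggling. It is hypothetical: `toggle_requires_grad`, the `schedule` layout, and the window convention (start inclusive, end exclusive, `None` meaning "never stop") are assumptions, not ptyrad's documented behavior. It only sets a `requires_grad` attribute, which works for leaf tensors such as `nn.Parameter`; the test uses plain namespace objects so torch is not required.

```python
from typing import Any, Dict, Optional, Tuple


def toggle_requires_grad(tensors: Dict[str, Any],
                         schedule: Dict[str, Tuple[int, Optional[int]]],
                         niter: int, verbose: bool = False) -> Dict[str, bool]:
    """Set requires_grad on each tensor according to its (start_iter, end_iter) window.

    Assumed convention: active when start_iter <= niter and
    (end_iter is None or niter < end_iter).
    """
    flags: Dict[str, bool] = {}
    for name, tensor in tensors.items():
        start, end = schedule[name]
        active = start <= niter and (end is None or niter < end)
        if verbose and getattr(tensor, "requires_grad", None) != active:
            print(f"Iter {niter}: {name}.requires_grad -> {active}")
        tensor.requires_grad = active  # valid for leaf tensors such as nn.Parameter
        flags[name] = active
    return flags
```

Only tensors with `requires_grad=True` receive gradients, so this lets, for example, the probe be optimized from the first iteration while the object starts later.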