ptyrad.core.models.ptycho#

Optimizable model of the ptychographic reconstruction using automatic differentiation (AD)

This is the PyTorch model that holds the optimizable tensors and interacts with the loss and constraint functions.

Classes

PtychoModel(init_variables, model_params[, ...])

Main optimization class for ptychographic reconstruction using automatic differentiation (AD).

class ptyrad.core.models.ptycho.PtychoModel(init_variables, model_params, device='cuda')[source]#

Bases: Module

Main optimization class for ptychographic reconstruction using automatic differentiation (AD).

This class is responsible for initializing the model parameters, setting up the optimizer, and performing forward passes to compute diffraction patterns based on the given input indices.

device#

Device to run computations on (‘cuda’ by default, per the class signature).

Type:

str

detector_blur_std#

Standard deviation for detector blur, or None if no blur.

Type:

float or None
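As a rough illustration of what a detector blur with a given standard deviation does, and of how a value of None can disable it, the following dependency-free sketch blurs a 1D signal with a normalized Gaussian kernel. The function names gaussian_kernel and blur_1d, the 3-sigma kernel radius, and the edge clamping are assumptions made for this example; they are not taken from the PtyRAD source, which operates on 2D torch tensors.

```python
import math


def gaussian_kernel(std, radius=None):
    """Return a normalized 1D Gaussian kernel as a list of floats."""
    if radius is None:
        radius = max(1, math.ceil(3 * std))  # assumption: cover about +/- 3 sigma
    weights = [math.exp(-0.5 * (i / std) ** 2) for i in range(-radius, radius + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def blur_1d(signal, std):
    """Blur a 1D signal; std=None means no blur and returns an unchanged copy."""
    if std is None:
        return list(signal)
    kernel = gaussian_kernel(std)
    radius = len(kernel) // 2
    n = len(signal)
    out = []
    for i in range(n):
        acc = 0.0
        for j, w in enumerate(kernel):
            k = min(max(i + j - radius, 0), n - 1)  # clamp indices at the edges
            acc += w * signal[k]
        out.append(acc)
    return out


pattern = [0.0, 0.0, 1.0, 0.0, 0.0]
blurred = blur_1d(pattern, std=1.0)      # the peak is spread over its neighbours
unblurred = blur_1d(pattern, std=None)   # identical to the input
```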

lr_params#

Learning rate parameters for optimizable tensors.

Type:

dict

opt_obja#

Amplitude of the object.

Type:

torch.Tensor

opt_objp#

Phase of the object.

Type:

torch.Tensor
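Storing the object as separate amplitude and phase arrays suggests that the complex object is formed as amplitude times exp(i · phase). The sketch below shows that combination on plain nested lists. The helper name complex_object is hypothetical, and the exact way PtychoModel combines opt_obja and opt_objp is an assumption here.

```python
import cmath


def complex_object(amplitude, phase):
    """Combine amplitude and phase arrays into complex values a * exp(i * phi)."""
    return [
        [a * cmath.exp(1j * p) for a, p in zip(a_row, p_row)]
        for a_row, p_row in zip(amplitude, phase)
    ]


# Toy 2x2 amplitude and phase maps standing in for opt_obja and opt_objp
obja = [[1.0, 0.9], [1.0, 0.8]]
objp = [[0.0, 0.1], [0.2, 0.3]]
obj = complex_object(obja, objp)
```

Keeping amplitude and phase as separate real-valued arrays lets each one receive its own learning rate and constraints, which is one plausible reason for this parametrization.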

opt_obj_tilts#

Tilts of the object.

Type:

torch.Tensor

opt_probe#

Probe function.

Type:

torch.Tensor

opt_probe_pos_shifts#

Shifts for the probe positions.

Type:

torch.Tensor

omode_occu#

Occupancy of the object modes.

Type:

torch.Tensor

H#

Propagator matrix.

Type:

torch.Tensor

measurements#

Measurements for the ptychographic reconstruction.

Type:

torch.Tensor

N_scan_slow#

Number of scan positions along the slow scan direction.

Type:

torch.Tensor

N_scan_fast#

Number of scan positions along the fast scan direction.

Type:

torch.Tensor

crop_pos#

Cropping positions.

Type:

torch.Tensor
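One common way to relate integer cropping positions to sub-pixel probe shifts is to split each scan position into an integer part (where the object patch is cropped) and a fractional remainder (applied as a shift of the probe). The sketch below shows that split. Whether crop_pos and opt_probe_pos_shifts are related in exactly this way in PtyRAD, and whether floor or round is used, are assumptions; split_positions is a hypothetical helper.

```python
import math


def split_positions(positions):
    """Split (row, col) positions into integer crop corners and sub-pixel shifts."""
    crops, shifts = [], []
    for row, col in positions:
        r0, c0 = math.floor(row), math.floor(col)  # assumption: floor, not round
        crops.append((r0, c0))
        shifts.append((row - r0, col - c0))        # remainders lie in [0, 1)
    return crops, shifts


crops, shifts = split_positions([(10.25, 3.75), (11.0, 4.5)])
```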

slice_thickness#

Slice thickness (dz) parameter.

Type:

torch.Tensor

dx#

Pixel size in the x direction.

Type:

torch.Tensor

dk#

K-space sampling interval.

Type:

torch.Tensor
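For an FFT-based model, the real-space pixel size and the k-space sampling interval are usually tied together by dk = 1 / (N · dx), where N is the number of pixels along one side of the probe array. The sketch below evaluates that relation together with the corresponding maximum spatial frequency. It assumes this standard convention (without a factor of 2π); the units and convention actually used by PtychoModel are not stated on this page.

```python
def kspace_sampling(n_pixels, dx):
    """k-space sampling interval for an n_pixels-wide array with pixel size dx."""
    return 1.0 / (n_pixels * dx)


def max_frequency(dx):
    """Highest spatial frequency representable with pixel size dx (Nyquist)."""
    return 1.0 / (2.0 * dx)


# Hypothetical numbers: 128 pixels with a pixel size of 0.1 (length units)
dk = kspace_sampling(128, 0.1)
k_max = max_frequency(0.1)
```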

scan_affine#

Affine transformation for scan.

Type:

affine.Affine

tilt_obj#

Whether object tilts are being optimized.

Type:

bool

probe_int_sum#

Sum of squared probe intensities.

Type:

torch.Tensor

optimizable_tensors#

Dictionary of tensors that can be optimized.

Type:

dict

Parameters:
  • init_variables (dict) – Dictionary of initial variables required for the model.

  • model_params (dict) – Dictionary of model parameters including learning rates and blur stds.

  • device (str) – Device to run computations on. Default is ‘cuda’.

get_complex_probe_view()[source]#

Retrieve a complex view of the probe.

create_grids()[source]#

Create the grids used for shifting the probes, selecting the object ROI, and building the Fresnel propagator, in a vectorized manner.
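To make the role of a propagator grid concrete, the following sketch builds a frequency grid in FFT ordering and evaluates a paraxial Fresnel kernel H(kx, ky) = exp(−iπλ·dz·(kx² + ky²)) on it. This is a textbook form written with plain Python lists; the sign convention, the absence of tilt terms, and the helper names fft_frequencies and fresnel_propagator are assumptions and may differ from what create_grids and the model actually compute.

```python
import cmath
import math


def fft_frequencies(n, dx):
    """Spatial frequencies in FFT ordering (0, positive, then negative)."""
    dk = 1.0 / (n * dx)
    return [(i if i < (n + 1) // 2 else i - n) * dk for i in range(n)]


def fresnel_propagator(n, dx, wavelength, dz):
    """Paraxial Fresnel kernel on an n x n frequency grid (assumed convention)."""
    k = fft_frequencies(n, dx)
    return [
        [cmath.exp(-1j * math.pi * wavelength * dz * (kx ** 2 + ky ** 2)) for kx in k]
        for ky in k
    ]


# Hypothetical values: 4x4 grid, pixel size 0.5, wavelength 0.02, slice thickness 10
H = fresnel_propagator(n=4, dx=0.5, wavelength=0.02, dz=10.0)
H_zero = fresnel_propagator(n=4, dx=0.5, wavelength=0.02, dz=0.0)
```

The kernel is a pure phase factor, so every entry has unit magnitude, and a slice thickness of zero reduces it to the identity.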

create_optimizable_params_dict(lr_params)[source]#

Create the dictionary of optimizable parameters used to set up the optimizer, using the learning rates in lr_params.
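The idea of pairing each optimizable tensor with its own learning rate can be sketched as follows. The helper build_param_groups, the keys used in lr_params, and the rule that a learning rate of zero (or a missing entry) leaves a tensor fixed are all assumptions for illustration; only the {"params": [...], "lr": ...} layout is the per-parameter option format that torch.optim optimizers accept.

```python
def build_param_groups(tensors, lr_params):
    """Pair each named tensor with its learning rate, skipping fixed ones."""
    groups = []
    for name, tensor in tensors.items():
        lr = lr_params.get(name, 0.0)
        if lr > 0:  # assumption: lr <= 0 or missing means "do not optimize"
            groups.append({"params": [tensor], "lr": lr})
    return groups


# Plain lists stand in for torch tensors so the sketch has no dependencies
tensors = {
    "opt_obja": [1.0, 1.0],
    "opt_objp": [0.0, 0.0],
    "opt_probe": [0.5, 0.5],
    "opt_obj_tilts": [0.0, 0.0],
}
lr_params = {"opt_obja": 5e-4, "opt_objp": 1e-4, "opt_probe": 0.0}
groups = build_param_groups(tensors, lr_params)
```

With real tensors, a list like groups could be passed directly as the first argument of an optimizer such as torch.optim.Adam.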

init_propagator_vars()[source]#

Initialize propagator-related variables.

init_compilation_iters()[source]#

Initialize the iteration numbers that require torch.compile.

print_model_summary()[source]#

Prints a summary of the model’s optimizable variables and statistics.

get_obj_patches(indices)[source]#

Get the object patches for the specified indices.
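Conceptually, extracting object patches means cropping a probe-sized window out of the object at the position associated with each index. The sketch below does this with nested lists. The helper get_patches, its argument names, and the top-left corner convention are hypothetical; the actual method works on the model's tensors and, per create_grids above, in a vectorized way rather than in a Python loop.

```python
def get_patches(obj, crop_pos, indices, patch_size):
    """Crop one patch_size x patch_size window per index from a 2D object."""
    patches = []
    for i in indices:
        top, left = crop_pos[i]  # assumption: positions mark the top-left corner
        patch = [row[left:left + patch_size] for row in obj[top:top + patch_size]]
        patches.append(patch)
    return patches


# 4x4 toy object whose value encodes its own location (row * 4 + col)
obj = [[r * 4 + c for c in range(4)] for r in range(4)]
crop_pos = [(0, 0), (1, 2), (2, 1)]
patches = get_patches(obj, crop_pos, indices=[1, 2], patch_size=2)
```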

get_probes(indices)[source]#

Get the probes for each position.

get_propagators(indices)[source]#

Get the propagators for each position.

get_propagated_probe(index)[source]#

get_forward_pattern(obja_patches, objp_patches, probes, propagators)[source]#

get_detector_pattern(dp)[source]#

get_measurements(indices=None)[source]#

Get the measurements for each position through the data loader.

clear_cache()[source]#

Clear temporary attributes like cached object patches.

forward(indices, return_raw=False)[source]#

Run the forward pass and return an output diffraction pattern for each input index.
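The core of a ptychographic forward pass can be illustrated with a heavily simplified toy: multiply the probe by the object patch to get an exit wave, Fourier transform it, and take the squared magnitude as the diffraction intensity. The sketch below does this in 1D with a naive DFT and a single slice. It omits propagators, multiple modes, object tilts, and detector blur, all of which the real forward method may involve, and the function names dft and forward_pattern are hypothetical.

```python
import cmath
import math


def dft(x):
    """Naive discrete Fourier transform of a 1D sequence."""
    n = len(x)
    return [
        sum(x[m] * cmath.exp(-2j * math.pi * k * m / n) for m in range(n))
        for k in range(n)
    ]


def forward_pattern(probe, obj_patch):
    """Single-slice toy model: intensity of the DFT of probe * object."""
    exit_wave = [p * o for p, o in zip(probe, obj_patch)]
    return [abs(v) ** 2 for v in dft(exit_wave)]


n = 4
probe = [0.5] * n
flat_obj = [1 + 0j] * n                                         # no phase variation
ramp_obj = [cmath.exp(2j * math.pi * m / n) for m in range(n)]  # linear phase ramp

pattern_flat = forward_pattern(probe, flat_obj)  # all intensity in bin 0
pattern_ramp = forward_pattern(probe, ramp_obj)  # peak moves to bin 1
```

A flat object leaves all intensity in the zero-frequency bin, while a linear phase ramp shifts the peak by one bin. That sensitivity of the diffraction pattern to object phase is what the reconstruction exploits.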