"""
Optimizable model of the ptychographic reconstruction using automatic differentiation (AD)
This is the PyTorch model that holds optimizable tensors and interacts with loss and constraints.
"""
from collections import defaultdict
import logging
import torch
import torch.nn as nn
from torch.fft import fft2, ifft2
from torch.nn.functional import interpolate
from torchvision.transforms.functional import gaussian_blur
from ptyrad.core.forward import multislice_forward
from ptyrad.core.functional import imshift_batch, make_freq_grid_2d, make_real_grid_2d, torch_phasor
from ptyrad.io.dataloader import MeasDataLoader
from ptyrad.utils.image_proc import center_crop
# The obj_ROI_grid is modified from precalculation to on-the-fly generation for memory consumption
# It has very little performance impact but saves lots of memory for large 4D-STEM data
# Added create_grids and print_model_summary for readability and decoupling
# create_optimizable_params_dict is called during the initialization, and it can also be called later if you want to update the optimizer params without re-initializing the object
# obj optimization is now split into objp and obja
# mixed object modes are normalized by the init_omode_occu. By design this is a fixed value because optimizing omode_occu with obj simultaneously could be unstable
# obj_ROI cropping is done with vectorization and the obj_ROI_grid is only generated once
# probe is always sub-px shifted by opt_probe_pos_shifts so the forward model uses exact positions, not the rounded integer crop_pos
# All the sub-px shifted probes in a batch are processed together with vectorization
# Likewise, the multislice forward model is also fully vectorized across samples (in batch), pmode, and omode
# Note that it's possible to reduce the peak-memory consumption by reducing the level of vectorization and rolling back to for loops
# Lastly, the forward pass of this model would output the dp_fwd (N, Ky, Kx) and objp_patches (N, omode, Nz, Ny, Nx) in float32 for later loss calculation
logger = logging.getLogger(__name__)
[docs]
class PtychoModel(torch.nn.Module):
    """
    Main optimization class for ptychographic reconstruction using automatic differentiation (AD).

    This class is responsible for initializing the model parameters, preparing the per-tensor
    optimizer parameter groups, and performing forward passes to compute diffraction patterns
    based on the given input indices.

    Attributes:
        device (str): Device to run computations on ('cuda' by default).
        detector_blur_std (float): Standard deviation for detector blur, or None if no blur.
        preload_data (bool): Value of model_params['preload_data'] (True when the key is absent),
            forwarded to the measurement loader.
        meas_loader (MeasDataLoader): Loader that serves the measurements for the reconstruction.
        lr_params (dict): Learning rate for each optimizable tensor.
        start_iter (dict): 'start_iter' of each entry in model_params['update_params'] (None if absent).
        end_iter (dict): 'end_iter' of each entry in model_params['update_params'] (None if absent).
        opt_obja (torch.nn.Parameter): Amplitude of the object.
        opt_objp (torch.nn.Parameter): Phase of the object.
        opt_obj_tilts (torch.nn.Parameter): Tilts of the object.
        opt_slice_thickness (torch.nn.Parameter): Optimizable slice thickness (dz).
        opt_probe (torch.nn.Parameter): Probe function, stored as the real view
            (torch.view_as_real) of the complex probe.
        opt_probe_pos_shifts (torch.nn.Parameter): Shifts for the probe positions.
        omode_occu (torch.Tensor): Occupation of the object modes (buffer).
        H (torch.Tensor): Propagator matrix (buffer).
        N_scan_slow (torch.Tensor): Number of scans in the slow direction (buffer, kept for reference).
        N_scan_fast (torch.Tensor): Number of scans in the fast direction (buffer, kept for reference).
        crop_pos (torch.Tensor): Integer cropping positions (buffer).
        slice_thickness (torch.Tensor): Initial slice thickness (dz), kept for reference (buffer).
        dx (torch.Tensor): Real-space pixel size (buffer).
        dk (torch.Tensor): K-space sampling interval (buffer).
        lambd (torch.Tensor): Wavelength used to build k = 2*pi/lambd (buffer).
        scan_affine: Affine transformation for scan, kept for reference.
        tilt_obj (bool): True when the 'obj_tilts' learning rate is non-zero or any initial tilt
            value is non-zero.
        change_thickness (bool): True when the 'slice_thickness' learning rate is non-zero.
        probe_int_sum (torch.Tensor): Sum of |probe|**2 of the initial probe.
        optimizable_tensors (dict): Dictionary of tensors that can be optimized.
        optimizable_params (list): One {'params': [...], 'lr': ...} group per tensor with a
            non-zero learning rate.
        compilation_iters (list): Sorted iteration numbers at which torch.compile is required.

    Args:
        init_variables (dict): Dictionary of initial variables required for the model.
        model_params (dict): Dictionary of model parameters including learning rates and blur stds.
        device (str): Device to run computations on. Default is 'cuda'.
    """
def __init__(self, init_variables, model_params, device='cuda'):
    """
    Build the optimizable tensors, buffers, grids, and bookkeeping of the model.

    Args:
        init_variables (dict): Initial arrays and metadata. Keys read here: 'measurements',
            'obj', 'obj_tilts', 'slice_thickness', 'probe', 'probe_pos_shifts', 'omode_occu',
            'H', 'N_scan_slow', 'N_scan_fast', 'crop_pos', 'dx', 'dk', 'lambd', 'random_seed',
            'length_unit', 'scan_affine', 'meas_Npix', 'simu_Npix', 'simu_match_mode',
            'recon_provenance', and optionally 'on_the_fly_meas_padded',
            'on_the_fly_meas_padded_idx', 'on_the_fly_meas_scale_factors'.
        model_params (dict): Model settings. Keys read here: 'detector_blur_std',
            'update_params' (one dict per tensor with 'lr' and optional 'start_iter' /
            'end_iter'), 'optimizer_params', and optionally 'preload_data' and
            'scheduler_params'.
        device (str): Device on which every tensor is created. Default is 'cuda'.

    Raises:
        KeyError: If a required key is missing from `init_variables` or `model_params`.
        ValueError: If 'update_params' names a tensor that is not optimizable.
    """
    super(PtychoModel, self).__init__()
    with torch.no_grad():
        logger.info('### Initializing PtychoModel model ###')

        # Setup model behaviors
        self.device = device
        self.detector_blur_std = model_params['detector_blur_std']
        self.preload_data = model_params.get('preload_data', True)
        self.meas_loader = MeasDataLoader(
            init_variables['measurements'],
            preload_data=self.preload_data,
            device=self.device,
            meas_padded=init_variables.get('on_the_fly_meas_padded', None),
            meas_padded_idx=init_variables.get('on_the_fly_meas_padded_idx', None),
            meas_scale_factors=init_variables.get('on_the_fly_meas_scale_factors', None),
        )

        # Parse the learning rate and start/end iter for optimizable tensors
        update_params = model_params['update_params']
        self.start_iter = {key: params.get('start_iter') for key, params in update_params.items()}
        self.end_iter = {key: params.get('end_iter') for key, params in update_params.items()}
        self.lr_params = {key: params['lr'] for key, params in update_params.items()}
        self.optimizer_params = model_params['optimizer_params']
        self.scheduler_params = model_params.get('scheduler_params')

        # Optimizable parameters.
        # The complex object is moved to the device once and split into amplitude / phase.
        obj = torch.tensor(init_variables['obj'], device=device)
        self.opt_obja = nn.Parameter(torch.abs(obj).to(torch.float32))
        self.opt_objp = nn.Parameter(torch.angle(obj).to(torch.float32))
        del obj  # only the amplitude / phase parameters are kept
        self.opt_obj_tilts = nn.Parameter(torch.tensor(init_variables['obj_tilts'], dtype=torch.float32, device=device))
        self.opt_slice_thickness = nn.Parameter(torch.tensor(init_variables['slice_thickness'], dtype=torch.float32, device=device))
        # The `torch.view_as_real` allows correct handling of DDP via NCCL even in PyTorch 2.4
        self.opt_probe = nn.Parameter(torch.view_as_real(torch.tensor(init_variables['probe'], dtype=torch.complex64, device=device)))
        self.opt_probe_pos_shifts = nn.Parameter(torch.tensor(init_variables['probe_pos_shifts'], dtype=torch.float32, device=device))

        # Buffers. 'omode_occu', 'H' and 'lambd' are used by the forward model; the others are
        # mostly saved for reference (the cropping itself is based on self.obj_ROI_grid).
        # Registration order is kept because it defines the state_dict order.
        buffer_dtypes = (
            ('omode_occu', torch.float32),
            ('H', torch.complex64),
            ('N_scan_slow', torch.int32),
            ('N_scan_fast', torch.int32),
            ('crop_pos', torch.int32),
            ('slice_thickness', torch.float32),
            ('dx', torch.float32),
            ('dk', torch.float32),
            ('lambd', torch.float32),
        )
        for buffer_name, buffer_dtype in buffer_dtypes:
            self.register_buffer(buffer_name, torch.tensor(init_variables[buffer_name], dtype=buffer_dtype, device=device))

        self.random_seed = init_variables['random_seed']  # Saving this for reference
        self.length_unit = init_variables['length_unit']  # Saving this for reference
        self.scan_affine = init_variables['scan_affine']  # Saving this for reference
        # tilt_obj is True if lr_params['obj_tilts'] is not 0 or we have any non-zero tilt values
        self.tilt_obj = bool(self.lr_params['obj_tilts'] != 0 or torch.any(self.opt_obj_tilts))
        self.change_thickness = bool(self.lr_params['slice_thickness'] != 0)
        self.meas_Npix = init_variables['meas_Npix']
        self.simu_Npix = init_variables['simu_Npix']
        self.simu_match_mode = init_variables['simu_match_mode']
        self.probe_int_sum = self.get_complex_probe_view().abs().pow(2).sum()  # This is only used for the `fix_probe_int`

        # Per-iteration bookkeeping
        self.loss_iters = []
        self.iter_times = []
        self.dz_iters = []
        self.avg_tilt_iters = defaultdict(list)
        self.lr_iters = defaultdict(list)
        self.convergence_iters = defaultdict(list)
        self.recon_provenance = init_variables['recon_provenance']
        # Defined up front so the cache attribute exists before the first forward pass
        self._current_object_patches = None

        # Create grids for shifting
        self.create_grids()

        # Create a dictionary to store the optimizable tensors
        self.optimizable_tensors = {
            'obja': self.opt_obja,
            'objp': self.opt_objp,
            'obj_tilts': self.opt_obj_tilts,
            'slice_thickness': self.opt_slice_thickness,
            'probe': self.opt_probe,
            'probe_pos_shifts': self.opt_probe_pos_shifts,
        }
        self.create_optimizable_params_dict(self.lr_params)

        # Initialize propagator-related variables
        self.init_propagator_vars()

        # Initialize iteration numbers that require torch.compile
        self.init_compilation_iters()

        logger.info('### Done initializing PtychoModel model ###')
        logger.info(' ')
[docs]
def get_complex_probe_view(self):
    """
    Return the probe as a complex tensor.

    `opt_probe` is stored through `torch.view_as_real` (limited complex support for DDP /
    multi-GPU via NCCL), so the rest of the code gets the complex probe back through this view.
    """
    real_imag_probe = self.opt_probe
    return torch.view_as_complex(real_imag_probe)
[docs]
def create_grids(self):
    """
    Create the grids used for object ROI selection, probe shifting, and Fresnel propagation.

    Sets `obj_ROI_grid`, `shift_probes_grid`, `shift_object_grid` and `propagator_grid`.
    The object shifting grid is pre-generated for potential future sub-px object shifts.
    """
    dev = self.device
    n_py, n_px = self.get_complex_probe_view().shape[-2:]  # probe pixels along y and x
    n_oy, n_ox = self.opt_objp.shape[-2:]  # object pixels along y and x

    # Real-space integer grid spanning the probe extent [0, Npix), used to select the obj ROI
    roi_axes = make_real_grid_2d((n_py, n_px), indexing='ij', dtype=torch.int32, device=dev)
    self.obj_ROI_grid = torch.stack(roi_axes, dim=0)  # (2,Npy,Npx)

    # Normalized frequency grids: dimensionless fftfreq in [-0.5, 0.5), DC at corner index 0
    probe_freq_axes = make_freq_grid_2d((n_py, n_px), indexing='ij', dtype=torch.float32, device=dev)
    self.shift_probes_grid = torch.stack(probe_freq_axes, dim=0)  # (2,Npy,Npx), sub-px probe shifting via imshift_batch
    object_freq_axes = make_freq_grid_2d((n_oy, n_ox), indexing='ij', dtype=torch.float32, device=dev)
    self.shift_object_grid = torch.stack(object_freq_axes, dim=0)  # (2,Noy,Nox), kept for completeness, not used in PtyRAD

    # Scale normalized fftfreq to physical k-space [rad/Å] for Fresnel propagation
    k_scale = 2 * torch.pi / self.dx
    self.propagator_grid = self.shift_probes_grid * k_scale  # (2,Npy,Npx)
[docs]
def create_optimizable_params_dict(self, lr_params):
    """
    Build the per-tensor optimizer parameter groups from `lr_params`.

    Stores `lr_params` on the model, rebuilds `self.optimizable_params` with one
    {'params': [tensor], 'lr': lr} group for every tensor whose learning rate is non-zero,
    and sets `requires_grad` of each named tensor to True only when its learning rate is
    non-zero and its start_iter equals 1. A model summary is logged at the end.

    Can be called again to refine the learning rates, e.g.:
        model.create_optimizable_params_dict({'obja': 5e-4, 'objp': 5e-4, 'obj_tilts': 1e-4,
                                              'slice_thickness': 0, 'probe': 1e-4,
                                              'probe_pos_shifts': 1e-4})
        optimizer = torch.optim.Adam(model.optimizable_params)

    Args:
        lr_params (dict): Learning rate keyed by tensor name.

    Raises:
        ValueError: If a key of `lr_params` is not a name in `self.optimizable_tensors`.
    """
    self.lr_params = lr_params
    self.optimizable_params = []
    for param_name, lr in lr_params.items():
        if param_name not in self.optimizable_tensors:
            raise ValueError(f"WARNING: '{param_name}' is not a valid parameter name, check your `update_params` and choose from 'obja', 'objp', 'obj_tilts', 'slice_thickness', 'probe', and 'probe_pos_shifts'")
        tensor = self.optimizable_tensors[param_name]
        is_active = lr != 0
        # Gradients are only switched on here for tensors that start at the first iteration
        tensor.requires_grad = is_active and (self.start_iter[param_name] == 1)
        if is_active:
            self.optimizable_params.append({'params': [tensor], 'lr': lr})
    self.print_model_summary()
[docs]
def init_propagator_vars(self):
    """
    Pre-compute the propagator-related tensors reused by `get_propagators`.

    Sets `H_fixed_tilts_full` (propagator for the current tilts and the current, detached
    slice thickness), `k` (2*pi/lambd) and `Kz`. The full tilted propagator is computed once
    here and sliced later by index, which keeps the forward path friendly to torch.compile.
    """
    grid_ky, grid_kx = self.propagator_grid
    thickness = self.opt_slice_thickness.detach()

    # Tilts are divided by 1e3 (noted as mrad in the original code); each is (N,1,1) for broadcasting
    tilt_y = self.opt_obj_tilts[:, 0, None, None] / 1e3
    tilt_x = self.opt_obj_tilts[:, 1, None, None] / 1e3
    tilt_phase = thickness * (grid_ky * torch.tan(tilt_y) + grid_kx * torch.tan(tilt_x))
    self.H_fixed_tilts_full = self.H * torch_phasor(tilt_phase)  # (1,Y,X) or (N,Y,X)

    # Other relevant variables
    self.k = 2 * torch.pi / self.lambd
    # Clamping is defensive: torch_phasor later requires a real-valued phase. Evanescent modes
    # (kx²+ky²>k²) are practically impossible since lambda << dx for all practical usage.
    kz_squared = self.k ** 2 - grid_kx ** 2 - grid_ky ** 2
    self.Kz = torch.sqrt(torch.clamp(kz_squared, min=0.0))  # (Ny,Nx), real-valued kz in rad/Å
[docs]
def init_compilation_iters(self):
    """
    Collect the iteration numbers that require torch.compile.

    Iteration 1 is always included. For every optimizable tensor, its start_iter and end_iter
    are added when they are set and >= 1 (end_iter is exclusive for grad calculation, so the
    transition is handled by compiling at end_iter). The result is stored as a sorted list in
    `self.compilation_iters`.
    """
    milestones = {1}
    for schedule in (self.start_iter, self.end_iter):
        milestones.update(
            iteration
            for iteration in map(schedule.get, self.optimizable_tensors)
            if iteration is not None and iteration >= 1
        )
    self.compilation_iters = sorted(milestones)
[docs]
def print_model_summary(self):
    """
    Log a summary of the optimizable tensors, problem size, and model behavior flags.

    The overdetermined ratio is (number of measurement values) / (number of values in tensors
    that currently require grad). When no tensor requires grad (e.g. every learning rate is 0
    or every start_iter is later than 1) the ratio is reported as inf instead of raising
    ZeroDivisionError.
    """
    logger.info('### PtychoModel optimizable variables ###')
    for name, tensor in self.optimizable_tensors.items():
        logger.info(f"{name.ljust(16)}: {str(tensor.shape).ljust(32)}, {str(tensor.dtype).ljust(16)}, device:{tensor.device}, grad:{str(tensor.requires_grad).ljust(5)}, lr:{self.lr_params[name]:.0e}")

    total_var = sum(tensor.numel() for tensor in self.optimizable_tensors.values() if tensor.requires_grad)
    total_meas = self.meas_loader.meas_arr.size
    # Guard the division: nothing may be trainable at the time the summary is printed
    overdetermined_ratio = total_meas / total_var if total_var else float('inf')

    # When you create a new model, pass the parameter groups to the optimizer,
    # e.g. "optimizer = torch.optim.Adam(model.optimizable_params)"
    logger.info(" ")
    logger.info('### Optimizable variables statitsics ###')
    logger.info(f"Total measurement values : {total_meas:,d}")
    logger.info(f"Total optimizing variables: {total_var:,d}")
    logger.info(f"Overdetermined ratio : {overdetermined_ratio:.2f}")
    logger.info(" ")
    logger.info('### Model behavior ###')
    logger.info(f"Tilt propagator : {self.tilt_obj}")
    logger.info(f"Change slice thickness : {self.change_thickness}")
    logger.info(f"Detector blur : {self.detector_blur_std is not None}")
    logger.info(f"Preload data : {self.preload_data}")
    logger.info(f"On-the-fly meas padding : {self.meas_loader.meas_padded is not None}")
    logger.info(f"On-the-fly meas resample : {self.meas_loader.meas_scale_factors is not None}")
    logger.info(f"On-the-fly simu match mode: {self.simu_match_mode}")
    logger.info(" ")
[docs]
def get_obj_patches(self, indices):
    """
    Crop the object amplitude and phase patches for the given scan indices.

    The ROI grid (Ny,Nx) is offset by the integer crop position of every index, giving
    (N,Ny,Nx) row / column index grids that gather all patches in one vectorized indexing.

    Args:
        indices: Indices of the scan positions in the batch.

    Returns:
        tuple: (obja_patches, objp_patches); the two indexed object axes are replaced by
        (N,Ny,Nx) and the sample axis N is moved to the front, i.e. (N,omode,Nz,Ny,Nx)
        following the shape convention noted in this file. Both are contiguous.
    """
    rows = self.obj_ROI_grid[0].unsqueeze(0) + self.crop_pos[indices, None, None, 0]
    cols = self.obj_ROI_grid[1].unsqueeze(0) + self.crop_pos[indices, None, None, 1]

    def crop(volume):
        # Gathered result has the sample axis at position 2; bring it to the front
        gathered = volume[:, :, rows, cols]
        return gathered.permute(2, 0, 1, 3, 4).contiguous()

    return crop(self.opt_obja), crop(self.opt_objp)
[docs]
def get_probes(self, indices):
    """
    Return the sub-px shifted probes for the given scan indices.

    The shifts from `opt_probe_pos_shifts` are always applied, so the forward model uses the
    exact position: `crop_pos` is integer and `opt_probe_pos_shifts` carries the fractional
    residual.

    Args:
        indices: Indices of the scan positions in the batch.

    Returns:
        torch.Tensor: Contiguous output of `imshift_batch` for the selected shifts.
    """
    batch_shifts = self.opt_probe_pos_shifts[indices]
    complex_probe = self.get_complex_probe_view()
    shifted = imshift_batch(complex_probe, shifts=batch_shifts, grid=self.shift_probes_grid)
    return shifted.contiguous()
[docs]
def get_propagators(self, indices):
    """
    Return the Fresnel propagator(s) for the given scan indices.

    The branch is chosen from three flags: `self.tilt_obj` (tilts are optimized or non-zero),
    `self.change_thickness` (slice thickness is optimized) and whether the 'obj_tilts'
    learning rate is non-zero.

    - Case 1: tilt + optimized thickness -> phasor(dz*Kz) * tilt phasor
    - Case 2A: optimized tilt + fixed thickness -> self.H * tilt phasor
    - Case 2B: fixed non-zero tilt + fixed thickness -> pre-computed `H_fixed_tilts_full`
    - Case 3: zero tilt + optimized thickness -> phasor(dz*Kz)[None]
    - Case 4: zero tilt + fixed thickness -> self.H[None]

    A single propagator is returned when `opt_obj_tilts` has shape (1,2) (one global tilt);
    with shape (N,2) the per-position propagators are stacked along axis 0.
    Zero tilt is numerically equivalent to H, which can be verified by
    "torch.allclose(model.H, model.get_propagators([0]))".

    The exp(i * dz * sqrt(k^2 - kx^2 - ky^2)) form is equivalent to the common
    exp(-i * pi * lambda * dz * k^2) for small angles, see J. Goodman, Introduction to
    Fourier Optics (McGraw-Hill, 1968), eqn 4-20 and 4-21.
    `torch_phasor(phase)` is used instead of torch.exp(1j*phase) because the Python 1j is not
    compatible with torch.compile.

    Args:
        indices: Indices of the scan positions in the batch.

    Returns:
        torch.Tensor: Contiguous propagators, (1,Y,X) or (N,Y,X).
    """
    # Setup boolean flags
    tilt_obj = self.tilt_obj  # Whether we need to apply tilt to the Fresnel propagator
    change_thickness = self.change_thickness  # Whether thickness is optimized or not
    change_tilt = (self.lr_params['obj_tilts'] != 0)  # Whether tilts are optimized or not
    global_tilt = (self.opt_obj_tilts.shape[0] == 1)  # 'tilt_type' = 'all' or 'each'

    if not tilt_obj:
        # Case 3 (thickness optimizing) / Case 4 (thickness fixed): tilt is fixed at 0
        H = torch_phasor(self.opt_slice_thickness * self.Kz) if change_thickness else self.H
        return H[None,].contiguous()

    if not change_thickness and not change_tilt:
        # Case 2B: fixed non-zero tilts (1,2) or (N,2) and fixed thickness.
        # The propagator is pre-calculated during init_propagator_vars.
        cached = self.H_fixed_tilts_full if global_tilt else self.H_fixed_tilts_full[indices]
        return cached.contiguous()

    # Case 1 / Case 2A: the tilt phasor is needed, so only now gather and scale the tilts
    dz = self.opt_slice_thickness
    # H has zero frequency at the corner in k-space
    H = torch_phasor(dz * self.Kz) if change_thickness else self.H
    Ky, Kx = self.propagator_grid
    tilts = self.opt_obj_tilts if global_tilt else self.opt_obj_tilts[indices]  # (1,2) or (N,2)
    tilts_y = tilts[:, 0, None, None] / 1e3  # mrad, broadcasts to (N,Y,X)
    tilts_x = tilts[:, 1, None, None] / 1e3
    propagators = H * torch_phasor(dz * (Ky * torch.tan(tilts_y) + Kx * torch.tan(tilts_x)))
    return propagators.contiguous()
[docs]
def get_propagated_probe(self, index):
    """
    Propagate the probe of one scan position through every slice (without the object).

    The probe and the propagator of the first entry returned for `index` are detached and the
    probe is propagated slice by slice with psi <- ifft2(H * fft2(psi)).

    Args:
        index: Indices forwarded to `get_probes` and `get_propagators`; only the first
            returned probe / propagator is used.

    Returns:
        torch.Tensor: Stack of the probe at the entrance of each slice, shaped
        (n_slices, *probe.shape) where n_slices is `opt_objp.shape[1]`.
    """
    probe = self.get_probes(index)[0].detach()  # (pmode, Ny, Nx), just grab the probe at 1st index
    # (1, Ny, Nx) or (N, Ny, Nx) depends on tilt_type ('all' or 'each'), so grab the 1st
    # index without reducing dimension; the leading 1 broadcasts over pmode.
    H = self.get_propagators(index)[[0]].detach()
    n_slices = self.opt_objp.shape[1]
    probe_prop = torch.zeros((n_slices, *probe.shape), dtype=probe.dtype, device=probe.device)

    psi = probe
    for n in range(n_slices):
        probe_prop[n] = psi
        if n + 1 < n_slices:
            # Keep psi at (pmode, Ny, Nx); nothing is propagated past the last slice
            psi = ifft2(H * fft2(psi))
    return probe_prop
[docs]
def get_forward_pattern(self, obja_patches, objp_patches, probes, propagators):
    """
    Return the diffraction pattern at simulation precision, before any detector effect.

    This is a thin wrapper around `multislice_forward` with the model's `omode_occu`; it is
    the place where a switch for different forward methods may be added in the future.
    """
    return multislice_forward(
        obja_patches,
        objp_patches,
        probes,
        propagators,
        omode_occu=self.omode_occu,
    )
[docs]
def get_detector_pattern(self, dp):
    """
    Turn a simulated diffraction pattern into the 'measured' one.

    First the geometry (sampling and extent) of the simulated pattern is matched with the
    measured pattern according to `simu_match_mode` ('crop' or 'resample'; anything else
    leaves the pattern as is), then the optional Gaussian detector blur is applied. The blur
    is therefore always applied at the sampling of the actual measured data.

    Args:
        dp (torch.Tensor): Simulated diffraction patterns.

    Returns:
        torch.Tensor: Patterns after geometry matching and detector blur.
    """
    match_mode = self.simu_match_mode
    if match_mode == 'crop':
        dp = center_crop(dp, crop_height=self.meas_Npix, crop_width=self.meas_Npix)
    elif match_mode == 'resample':
        # Rescale by scale_factor**2 to keep the measured intensity the same
        scale_factor = self.meas_Npix / self.simu_Npix
        # interpolate takes (N,C,H,W) as input and asks for [H', W'] for sizes
        with_channel = dp.unsqueeze(1)
        resized = interpolate(with_channel, size=[self.meas_Npix, self.meas_Npix], mode='bilinear')
        dp = resized.squeeze(1) / scale_factor**2

    blur_std = self.detector_blur_std
    if blur_std is None or blur_std == 0:
        return dp

    # Kernel size would have minimum 5, and scale with 6sigma+1
    kernel_size = max(5, 2*round(3*blur_std)+1)
    return gaussian_blur(dp, kernel_size=kernel_size, sigma=blur_std)
[docs]
@torch.compiler.disable
def get_measurements(self, indices=None):
""" Get measurements for each position through the data loader"""
# Return the selected measurements based on input indices
# If no indices are passed, return the original numpy arr of measurements
if indices is not None:
return self.meas_loader[indices]
else:
return self.meas_loader.meas_arr
[docs]
def clear_cache(self):
    """
    Drop temporary attributes kept between calls.

    Currently this resets the object patches cached by the last forward pass.
    """
    self._current_object_patches = None
[docs]
def forward(self, indices, return_raw=False):
    """
    Run the forward pass and return one diffraction pattern per input index.

    Args:
        indices: Array of scan-position indices representing the whole batch.
        return_raw (bool): If True, return the pattern straight from the forward model
            (`get_forward_pattern`). If False (default), return it after the detector-stage
            transforms of `get_detector_pattern`.

    Returns:
        torch.Tensor: Simulated diffraction patterns for the batch.

    Side effects:
        Stores (obja_patches, objp_patches) in `self._current_object_patches` for later
        object-specific loss terms.
    """
    obja_patches, objp_patches = self.get_obj_patches(indices)
    probes = self.get_probes(indices)
    propagators = self.get_propagators(indices)
    dp_fwd = self.get_forward_pattern(obja_patches, objp_patches, probes, propagators)

    # The detector-stage transform is only needed when its result is actually returned
    dp_out = dp_fwd if return_raw else self.get_detector_pattern(dp_fwd)

    # Keep the object_patches for later object-specific loss
    self._current_object_patches = (obja_patches, objp_patches)
    return dp_out