Source code for ptyrad.core.constraints

"""
Physical constraints that directly modify optimizable tensors in place at specified iteration intervals

"""

import logging

import torch
from torch.fft import fft2, fftfreq, fftn, ifft2, ifftn
from torch.nn.functional import interpolate
from torchvision.transforms.functional import gaussian_blur

from ptyrad.core.functional import (
    approx_torch_quantile,
    dct_2d,
    fftshift2,
    find_probe_focus_dz,
    gaussian_blur_1d,
    idct_2d,
    ifftshift2,
    make_sigmoid_mask,
    make_super_gaussian_mask,
    near_field_evolution_torch,
)

logger = logging.getLogger(__name__)

[docs] @torch.compiler.disable # Nearly no benefit to compile the iter-wise constraint as it's negligible comparing to forward/backward pass class CombinedConstraint(torch.nn.Module): """Applies iteration-wise in-place constraints on optimizable tensors. This class is designed to apply various constraints to a model's parameters during the optimization process. The constraints are applied at specific iteration frequencies, as determined by the `constraint_params` dictionary. These constraints include orthogonality, probe amplitude constraints in Fourier space, intensity constraints, Gaussian blurring, Fourier filtering, and more. Args: constraint_params (dict): A dictionary containing the configuration for each constraint. Each constraint should have a frequency and other parameters necessary for its application. device (str, optional): The device on which the tensors are located (e.g., 'cuda' or 'cpu'). Defaults to 'cuda'. """ def __init__(self, constraint_params, device='cuda'): super(CombinedConstraint, self).__init__() self.device = device self.constraint_params = constraint_params def _should_apply_at_iter(self, constraint_name, niter): """Check if the constraint should be applied at the current iteration.""" start = self.constraint_params[constraint_name]['start_iter'] step = self.constraint_params[constraint_name]['step'] end = self.constraint_params[constraint_name]['end_iter'] if start is None: return False if niter < start: return False if end is not None and niter >= end: return False return (niter - start) % step == 0
[docs] def apply_probe_mask_k(self, model, niter): ''' Apply probe amplitude constraint in Fourier space ''' # Note that this will change the total probe intensity, please use this with `fix_probe_int` # Although the mask wouldn't change during the iteration, making a mask takes only ~0.5us on CPU so really no need to pre-calculate it # The sandwitch fftshift(fft(ifftshift(probe))) is needed to properly handle the complex probe without serrated phase # fft2 is for real->fourier, while fftshift2 is for corner->center if self._should_apply_at_iter('probe_mask_k', niter): relative_radius = self.constraint_params['probe_mask_k']['radius'] relative_width = self.constraint_params['probe_mask_k']['width'] power_thresh = self.constraint_params['probe_mask_k']['power_thresh'] probe = model.get_complex_probe_view() Npix = probe.size(-1) powers = probe.abs().pow(2).sum((-2,-1)) / probe.abs().pow(2).sum() powers_cumsum = powers.cumsum(0) pmode_index = (powers_cumsum > power_thresh).nonzero()[0].item() # This gives the pmode index that the cumulative power along mode dimension is greater than the power_thresh and should have mask extend to this index mask = torch.ones_like(probe, dtype=torch.float32, device=model.device) mask_value = make_sigmoid_mask(Npix, relative_radius, relative_width).to(model.device) mask[:pmode_index+1] = mask_value probe_k = fftshift2 (fft2(ifftshift2(probe), norm='ortho')) # probe_k at center for later masking probe_r = fftshift2(ifft2(ifftshift2(mask * probe_k), norm='ortho')) # probe_r at center. Note that the norm='ortho' is explicitly specified but not needed for a round-trip # Re-sort the probe modes, note that the masked strong modes might be swapping order with unmasked weak modes model.opt_probe.copy_(torch.view_as_real(sort_by_mode_int(probe_r))) probe_int = model.get_complex_probe_view().abs().pow(2) logger.debug(f"Apply Fourier-space probe amplitude constraint at iter {niter}, pmode_index = {pmode_index} when power_thresh = {power_thresh}. 
Current probe int sum = {probe_int.sum():.4f}")
[docs] def apply_probe_mask_r(self, model, niter): ''' Apply probe amplitude constraint in Real space ''' # Note that this will change the total probe intensity, please use this with `fix_probe_int` # Although the mask wouldn't change during the iteration, making a mask takes only ~0.5us on CPU so really no need to pre-calculate it # The sandwitch fftshift(fft(ifftshift(probe))) is needed to properly handle the complex probe without serrated phase # fft2 is for real->fourier, while fftshift2 is for corner->center if self._should_apply_at_iter('probe_mask_r', niter): params = self.constraint_params['probe_mask_r'] relative_radius = params.get('radius', 0.95) order = params.get('order', 6) # Using order for Super-Gaussian power_thresh = params.get('power_thresh', 0.95) z_range = params.get('z_range', [-500, 500]) z_steps = params.get('z_steps', 101) probe = model.get_complex_probe_view() probe = sort_by_mode_int(probe) Npix = probe.size(-1) # 1. Determine which modes to mask (based on power threshold) powers = probe.abs().pow(2).sum((-2,-1)) / probe.abs().pow(2).sum() powers_cumsum = powers.cumsum(0) pmode_index = (powers_cumsum > power_thresh).nonzero()[0].item() # This gives the pmode index that the cumulative power along mode dimension is greater than the power_thresh and should have mask extend to this index # 2. Make mask mask = torch.ones_like(probe, dtype=torch.float32, device=model.device) mask_value = make_super_gaussian_mask(Npix, relative_radius, order=order, device=model.device) mask[:pmode_index+1] = mask_value # 3. Find the focal plane and get the forward and backward propagator, note that forward means "to focal plane", and backward is "reset" best_dz = find_probe_focus_dz(probe[0], model.dx, model.lambd, z_range, z_steps) # Feed the strongest pmode H_fwd = near_field_evolution_torch(probe.shape[-2:], model.dx, best_dz, model.lambd, dtype=probe.dtype, device=model.device) H_bwd = torch.conj(H_fwd) # 4. 
Roll the probe to focal plane, mask, and then roll back probe_focused = ifft2(fft2(probe) * H_fwd) probe_masked = mask * probe_focused probe_final = ifft2(fft2(probe_masked) * H_bwd) # 5. Update the probe and report debug info # Re-sort the probe modes, note that the masked strong modes might be swapping order with unmasked weak modes model.opt_probe.copy_(torch.view_as_real(sort_by_mode_int(probe_final))) probe_int = model.get_complex_probe_view().abs().pow(2) logger.debug(f"Apply Real-space probe amplitude constraint at iter {niter}, " \ f"pmode_index = {pmode_index} when power_thresh = {power_thresh}. " \ f"Focal plane found at {best_dz} Ang with z_range = {z_range}. " \ f"Mask with raidus = {relative_radius} and order = {order}. " \ f"Current probe int sum = {probe_int.sum():.4f}")
[docs] def apply_ortho_pmode(self, model, niter): ''' Apply orthogonality constraint to probe modes ''' if self._should_apply_at_iter('ortho_pmode', niter): model.opt_probe.copy_(torch.view_as_real(orthogonalize_modes_vec(model.get_complex_probe_view(), sort=True))) # Note that model stores the complex probe as a (pmode, Ny, Nx, 2) float tensor (real view) so we need to do some real-complex view conversion. probe_int = model.get_complex_probe_view().abs().pow(2) probe_pow = (probe_int.sum((1,2))/probe_int.sum()).detach().cpu().numpy().round(3) logger.debug(f"Apply ortho pmode constraint at iter {niter}, relative pmode power = {probe_pow}, probe int sum = {probe_int.sum():.4f}")
[docs] def apply_fix_probe_int(self, model, niter): ''' Apply probe intensity constraint ''' # Note that the probe intensity fluctuation (std/mean) is typically only 0.5%, there's very little point to do a position-dependent probe intensity constraint # Therefore, a mean probe intensity is used here as the target intensity if self._should_apply_at_iter('fix_probe_int', niter): probe = model.get_complex_probe_view() current_amp = probe.abs().pow(2).sum().pow(0.5) target_amp = model.probe_int_sum**0.5 model.opt_probe.copy_(torch.view_as_real(probe * target_amp/current_amp)) probe_int = model.get_complex_probe_view().abs().pow(2) logger.debug(f"Apply fix probe int constraint at iter {niter}, probe int sum = {probe_int.sum():.4f}")
[docs] def apply_obj_rblur(self, model, niter): ''' Apply Gaussian blur to object, this only applies to the last 2 dimension (...,H,W) ''' # Note that it's not clear whether applying blurring after every iteration would ever reach a steady state # However, this is at least similar to PtychoShelves' eng. reg_mu if self._should_apply_at_iter('obj_rblur', niter) and self.constraint_params['obj_rblur']['std'] !=0: obj_type = self.constraint_params['obj_rblur']['obj_type'] obj_rblur_ks = self.constraint_params['obj_rblur']['kernel_size'] obj_rblur_std = self.constraint_params['obj_rblur']['std'] if obj_type in ['amplitude', 'both']: model.opt_obja.copy_(gaussian_blur(model.opt_obja, kernel_size=obj_rblur_ks, sigma=obj_rblur_std)) logger.debug(f"Apply lateral (y,x) Gaussian blur with std = {obj_rblur_std} px on obja at iter {niter}") if obj_type in ['phase', 'both']: model.opt_objp.copy_(gaussian_blur(model.opt_objp, kernel_size=obj_rblur_ks, sigma=obj_rblur_std)) logger.debug(f"Apply lateral (y,x) Gaussian blur with std = {obj_rblur_std} px on objp at iter {niter}")
[docs] def apply_obj_zblur(self, model, niter): ''' Apply Gaussian blur to object along the z-axis (slice dimension) ''' if self._should_apply_at_iter('obj_zblur', niter) and self.constraint_params['obj_zblur']['std'] !=0: obj_type = self.constraint_params['obj_zblur']['obj_type'] obj_zblur_ks = self.constraint_params['obj_zblur']['kernel_size'] obj_zblur_std = self.constraint_params['obj_zblur']['std'] if obj_type in ['amplitude', 'both']: model.opt_obja.copy_(gaussian_blur_1d(model.opt_obja, kernel_size=obj_zblur_ks, sigma=obj_zblur_std)) logger.debug(f"Apply z-direction Gaussian blur with std = {obj_zblur_std} px on obja at iter {niter}") if obj_type in ['phase', 'both']: model.opt_objp.copy_(gaussian_blur_1d(model.opt_objp, kernel_size=obj_zblur_ks, sigma=obj_zblur_std)) logger.debug(f"Apply z-direction Gaussian blur with std = {obj_zblur_std} px on objp at iter {niter}")
[docs] def apply_kr_filter(self, model, niter): ''' Apply kr Fourier filter constraint on object ''' # Note that the `kr_filter` is applied on stacked 2D FFT of object, so it's applying on (omode,z,ky,kx) # The kr filter is similar to a top-hat, so it's more like a cut-off, instead of the weak lateral Gaussian blurring (alpha) included in the `kz_filter` if self._should_apply_at_iter('kr_filter', niter): obj_type = self.constraint_params['kr_filter']['obj_type'] relative_radius = self.constraint_params['kr_filter']['radius'] relative_width = self.constraint_params['kr_filter']['width'] if obj_type in ['amplitude', 'both']: model.opt_obja.copy_(kr_filter(model.opt_obja, relative_radius, relative_width)) logger.debug(f"Apply kr_filter constraint with kr_radius = {relative_radius} on obja at iter {niter}") if obj_type in ['phase', 'both']: model.opt_objp.copy_(kr_filter(model.opt_objp, relative_radius, relative_width)) logger.debug(f"Apply kr_filter constraint with kr_radius = {relative_radius} on objp at iter {niter}")
[docs] def apply_kz_filter(self, model, niter): ''' Apply kz Fourier filter constraint on object ''' # Note that the `kz_filter`` behaves differently for 'amplitude' and 'phase', see `kz_filter` implementaion for details if self._should_apply_at_iter('kz_filter', niter): obj_type = self.constraint_params['kz_filter']['obj_type'] beta_regularize_layers = self.constraint_params['kz_filter']['beta'] alpha_gaussian = self.constraint_params['kz_filter']['alpha'] if obj_type in ['amplitude', 'both']: model.opt_obja.copy_(kz_filter(model.opt_obja, beta_regularize_layers, alpha_gaussian, obj_type='amplitude')) logger.debug(f"Apply kz_filter constraint with beta = {beta_regularize_layers} on obja at iter {niter}") if obj_type in ['phase', 'both']: model.opt_objp.copy_(kz_filter(model.opt_objp, beta_regularize_layers, alpha_gaussian, obj_type='phase')) logger.debug(f"Apply kz_filter constraint with beta = {beta_regularize_layers} on objp at iter {niter}")
[docs] def apply_kr_thresh(self, model, niter): ''' Apply kr threshold constraint on object ''' if self._should_apply_at_iter('kr_thresh', niter): obj_type = self.constraint_params['kr_thresh']['obj_type'] thresh = self.constraint_params['kr_thresh']['thresh'] if obj_type in ['amplitude', 'both']: model.opt_obja.copy_(dct_threshold_filter(model.opt_obja, threshold_ratio=thresh)) logger.debug(f"Apply kr_filter constraint with threshold = {thresh} (ratio of spatial frequencies to keep) on obja at iter {niter}") if obj_type in ['phase', 'both']: model.opt_objp.copy_(dct_threshold_filter(model.opt_objp, threshold_ratio=thresh)) logger.debug(f"Apply kr_filter constraint with threshold = {thresh} (ratio of spatial frequencies to keep) on objp at iter {niter}")
[docs] def apply_complex_ratio(self, model, niter): ''' Apply complex constraint on object ''' # Original paper seems to apply this constraint at each position. I'll try an iteration-wise constraint first if self._should_apply_at_iter('complex_ratio', niter): obj_type = self.constraint_params['complex_ratio']['obj_type'] alpha1 = self.constraint_params['complex_ratio']['alpha1'] alpha2 = self.constraint_params['complex_ratio']['alpha2'] objac, objpc, Cbar = complex_ratio_constraint(model, alpha1, alpha2) if obj_type in ['amplitude', 'both']: model.opt_obja.copy_(objac) amin, amax = model.opt_obja.min().item(), model.opt_obja.max().item() logger.debug(f"Apply complex ratio constraint with alpha1: {alpha1}, alpha2: {alpha2}, and Cbar: {Cbar.item():.3f} on obja at iter {niter}. obja range becomes ({amin:.3f}, {amax:.3f})") if obj_type in ['phase', 'both']: model.opt_objp.copy_(objpc) pmin, pmax = model.opt_objp.min().item(), model.opt_objp.max().item() logger.debug(f"Apply complex ratio constraint with alpha1: {alpha1}, alpha2: {alpha2}, and Cbar: {Cbar.item():.3f} on objp at iter {niter}. objp range becomes ({pmin:.3f}, {pmax:.3f})")
[docs] def apply_mirrored_amp(self, model, niter): '''Apply mirrored amplitude constraint on obja at voxel level''' # The idea is to replace the amplitude with Amp' = exp(-scale*phase^2), because the absorptive potential should scale with V^2 if self._should_apply_at_iter('mirrored_amp', niter): relax = self.constraint_params['mirrored_amp']['relax'] scale = self.constraint_params['mirrored_amp']['scale'] power = self.constraint_params['mirrored_amp']['power'] v_power = model.opt_objp.clamp(min=0).pow(power) # amp_new = torch.exp(-scale*v_power) amp_new = 1-scale*v_power model.opt_obja.copy_(relax * model.opt_obja + (1-relax) * amp_new) amin, amax = model.opt_obja.min().item(), model.opt_obja.max().item() relax_str = f'relaxed ({relax}*obj + ({1-relax}*obj_new))' if relax != 0 else 'hard' logger.debug(f"Apply {relax_str} mirrored amplitude constraint with scale = {scale} and power = {power} on obja at iter {niter}. obja range becomes ({amin:.3f}, {amax:.3f})")
[docs] def apply_obj_z_recenter(self, model, niter): '''Apply object z-recentering along depth dimension ''' # The idea is to recenter the object within the object tensor using CoM along depth # Because fundamentally there's an ambiguity between probe defocus and object z positioning, # so we'll have to shift the object and adjust the probe accordingly. # I thought I saw this idea in other packages but I couldn't seem to find it anymore...... if self._should_apply_at_iter('obj_z_recenter', niter): threshold = self.constraint_params['obj_z_recenter'].get('thresh', 90) scale = self.constraint_params['obj_z_recenter'].get('scale', 0.5) max_shift = self.constraint_params['obj_z_recenter'].get('max_shift', 5) unit_str = model.length_unit dz = model.opt_slice_thickness.detach().item() probe = model.get_complex_probe_view() dx = model.dx lambd = model.lambd obja = model.opt_obja objp = model.opt_objp objc = torch.polar(obja, objp) # Shift the obj along z z_shift = get_obj_z_shift(objp, threshold, scale, max_shift) # unit: px objc = shift_obj_along_z(objc, z_shift) # Update model obj model.opt_obja.copy_(torch.abs(objc)) model.opt_objp.copy_(torch.angle(objc)) # Update model probe H = near_field_evolution_torch(probe.shape[-2:], dx, -z_shift*dz, lambd, device=model.device) # If the object is shifted along +z, then the probe should be back-propagated along -z model.opt_probe.copy_(torch.view_as_real(ifft2(H[None,] * fft2(probe)))) logger.debug(f"Apply obj z-recenter constraint. Complex object and probe defocus are shifted by {z_shift:.3f} slice ({(z_shift*dz):.3f} {unit_str}) along depth dimension. threshold = {threshold}, scale = {scale}, and max_shift = {max_shift}.")
[docs] def apply_obja_thresh(self, model, niter): ''' Apply thresholding on obja at voxel level ''' # Although there's a lot of code repitition with `apply_postiv`, phase positivity itself is important enough as an individual operation if self._should_apply_at_iter('obja_thresh', niter): relax = self.constraint_params['obja_thresh']['relax'] thresh = self.constraint_params['obja_thresh']['thresh'] model.opt_obja.copy_(relax * model.opt_obja + (1-relax) * model.opt_obja.clamp(min=thresh[0], max=thresh[1])) relax_str = f'relaxed ({relax}*obj + ({1-relax}*obj_clamp))' if relax != 0 else 'hard' logger.debug(f"Apply {relax_str} threshold constraint with thresh = {thresh} on obja at iter {niter}")
[docs] def apply_objp_postiv(self, model, niter): ''' Apply positivity constraint on objp at voxel level ''' # Note that this `relax` is defined oppositly to PtychoShelves's `positivity_constraint_object` in `ptycho_solver`. # Here, relax=1 means fully relaxed and essentially no constraint. if self._should_apply_at_iter('objp_postiv', niter): relax = self.constraint_params['objp_postiv']['relax'] mode = self.constraint_params['objp_postiv'].get('mode', 'clip_neg') original_min = model.opt_objp.min() if mode == 'subtract_min': modified_objp = model.opt_objp - original_min else: # 'clip_neg' modified_objp = model.opt_objp.clamp(min=0) model.opt_objp.copy_(relax * model.opt_objp + (1-relax) * modified_objp) omin, omax = model.opt_objp.min().item(), model.opt_objp.max().item() relax_str = f'relaxed ({relax}*obj + ({1-relax}*obj_postiv))' if relax != 0 else 'hard' logger.debug(f"Apply {relax_str} positivity constraint on objp with '{mode}' mode at iter {niter}. Original min = {original_min.item():.3f}. objp range becomes ({omin:.3f}, {omax:.3f})")
[docs] def apply_pos_recenter(self, model, niter): ''' Apply position recentering constraint on probe positions ''' # Here, relax=1 means fully relaxed and essentially no constraint. # Usually we start from reasonable probe and positions, so the object would get reconstructed in place instantly. # This makes object, probe, and probe positions remain relatively aligned so this constraint isn't critically needed. # For certain use cases (i.e., large position learning rates or misplaced object during initialization), the probe positions might accumulate large global offsets. # Then we can use this constraint to remove the global offset, which will soon recenter the object as well given that probe CoM is relatively stable. # TODO: Technically we can shift the object accordingly to accelerate the convergence but it doesn't seem to be critically needed at the moment if self._should_apply_at_iter('pos_recenter', niter): relax = self.constraint_params['pos_recenter']['relax'] pos_shifts = model.opt_probe_pos_shifts # float32 orig_mean = pos_shifts.mean(0) model.opt_probe_pos_shifts.copy_(pos_shifts - (1 - relax) * orig_mean) relax_str = f'relaxed (pos_shifts - ({1-relax:.3f}*original_mean))' if relax != 0 else 'hard' logger.debug(f"Apply {relax_str} position recentering constraint at iter {niter}. Original mean = {orig_mean.detach().cpu().numpy().round(3)}. probe_pos_shifts.mean(0) becomes {model.opt_probe_pos_shifts.mean(0).detach().cpu().numpy().round(3)}")
[docs] def apply_tilt_smooth(self, model, niter): ''' Apply Gaussian blur to object tilts ''' # Note that the smoothing is applied along the last 2 axes, which are scan dimensions, so the unit of std is "scan positions" # Besides, the relative position of the obj_tilts are neglected for simplicity if self._should_apply_at_iter('tilt_smooth', niter) and self.constraint_params['tilt_smooth']['std'] !=0: tilt_smooth_std = self.constraint_params['tilt_smooth']['std'] N_scan_slow = model.N_scan_slow N_scan_fast = model.N_scan_fast if model.opt_obj_tilts.shape[0] == 1: # obj_tilts.shape = (1,2) for tilt_type: 'all', and (N,2) for 'each' logger.debug("`tilt_smooth` constraint requires `tilt_type':'each'`, skip this constraint") return obj_tilts = (model.opt_obj_tilts.reshape(N_scan_slow, N_scan_fast, 2)).permute(2,0,1) model.opt_obj_tilts.copy_(gaussian_blur(obj_tilts, kernel_size=5, sigma=tilt_smooth_std).permute(1,2,0).reshape(-1,2)) logger.debug(f"Apply Gaussian blur with std = {tilt_smooth_std} scan positions on obj_tilts at iter {niter}")
[docs] def forward(self, model, niter): """ Applies constraints to the optimizable `model` parameters if `niter` satisfies the pre-determined conditions (start_iter, step, end_iter) """ # Note that the if check blocks are included in each apply methods so that it's cleaner, and I can print the info with niter with torch.no_grad(): # Probe constraints self.apply_probe_mask_k (model, niter) self.apply_probe_mask_r (model, niter) self.apply_ortho_pmode (model, niter) self.apply_fix_probe_int (model, niter) # Object constraints self.apply_obj_rblur (model, niter) self.apply_obj_zblur (model, niter) self.apply_kr_filter (model, niter) self.apply_kz_filter (model, niter) self.apply_kr_thresh (model, niter) self.apply_complex_ratio (model, niter) self.apply_mirrored_amp (model, niter) self.apply_obj_z_recenter(model, niter) self.apply_obja_thresh (model, niter) self.apply_objp_postiv (model, niter) # Position constraints self.apply_pos_recenter (model, niter) # Local tilt constraint self.apply_tilt_smooth (model, niter)
###### Filter and helper functions for constraints ######
def sort_by_mode_int(modes):
    """Reorder `modes` along the first axis from the strongest to the weakest mode.

    The strength of a mode is its total intensity, i.e. |mode|^2 summed over every axis
    except the first one.
    """
    reduce_dims = tuple(range(1, modes.ndim))  # Every axis but the mode axis
    mode_strength = modes.abs().pow(2).sum(reduce_dims)
    order = torch.sort(mode_strength, descending=True).indices
    return modes[order]
def orthogonalize_modes_vec(modes, sort=False):
    '''
    Orthogonalize probe modes via eigendecomposition of their Gram matrix.

    The (Nmode x Nmode) Gram matrix is built and decomposed in complex128 with the Hermitian
    eigensolver `torch.linalg.eigh`, and the eigenvectors are used to rotate the modes.
    The rotated modes are checked by `_validate_ortho_update` while still in double precision;
    if the check fails, the (complex) input modes are returned unchanged and unsorted, and the
    validator logs a warning.

    Args:
        modes: Tensor whose first axis indexes the modes, e.g. (pmode, Ny, Nx).
            Real input is promoted to complex for the computation.
        sort: If True, the orthogonalized modes are sorted by descending intensity.

    Returns:
        Orthogonalized modes with the input shape, dtype and device.

    Note:
        - MPS does not support complex128 or eigh, so for tensors on an 'mps' device the
          double-precision computation is done on CPU and the result is moved back afterwards.
        - The use of eigh and A = 0.5 * (A + A.conj().T) are suggested in PR #34 by @SoverHHH,
          @EdwardPooh, and @dong-zehao
        - This is a highly vectorized PyTorch implementation of
          ``ptycho\\+core\\probe_modes_ortho.m`` from PtychoShelves.
        - Matlab's dot(p2,p1) for complex input implicitly applies the complex conjugate, so
          Matlab's dot() != torch.dot because torch.dot doesn't automatically apply the complex
          conjugate. This is pointed out by @dong-zehao in issue #11.
    '''
    target_dtype = modes.dtype
    target_device = modes.device

    if not modes.is_complex():
        modes = torch.complex(modes, torch.zeros_like(modes))

    full_shape = modes.shape
    flat_modes = modes.reshape(full_shape[0], -1)  # (Nmode, Npixels)

    # Pick where the double-precision math runs: CPU for Apple's MPS, otherwise stay in place
    if target_device.type == 'mps':
        work_device = torch.device('cpu')
    else:
        work_device = target_device

    # Move first, then upcast
    flat64 = flat_modes.to(device=work_device).to(dtype=torch.complex128)

    # Gram matrix M @ M^H, symmetrized so that it is exactly Hermitian
    gram = flat64 @ flat64.H
    gram = 0.5 * (gram + gram.H)

    # Only the eigenvectors are needed
    eigvecs = torch.linalg.eigh(gram)[1]
    rotated = eigvecs.H @ flat64

    # Validate before any down-casting; on failure hand back the original modes untouched
    if not _validate_ortho_update(flat64, rotated):
        return modes

    result = rotated.to(dtype=target_dtype).to(device=target_device).reshape(full_shape)
    return sort_by_mode_int(result) if sort else result
def _validate_ortho_update(orig_modes_flat, new_modes_flat, ortho_tol=1e-3, norm_rtol=1e-3): """ Defensive check to ensure orthogonalization didn't destroy the modes due to numerical instability (e.g., precision loss on highly correlated inputs). """ # Test A: Orthogonality Leakage O_gram = torch.matmul(new_modes_flat, new_modes_flat.H) N = O_gram.shape[0] if N <= 1: return True # single mode is trivially orthogonal off_diag_mask = ~torch.eye(N, dtype=torch.bool, device=O_gram.device) max_off_diag = torch.abs(O_gram[off_diag_mask]).max() max_diag = torch.abs(torch.diag(O_gram)).max() # If relative error exceeds the tolerance, reject the update if max_diag > 0 and (max_off_diag / max_diag) > ortho_tol: rel_err = (max_off_diag / max_diag).item() logger.warning(f"WARNING: Orthogonality warning, high mode leakage detected (rel error: {rel_err:.2e}). Skipping orthogonalization for this iteration.") return False # Test B: Norm Preservation orig_intensity = torch.sum(orig_modes_flat.abs().square()) new_intensity = torch.sum(new_modes_flat.abs().square()) if not torch.allclose(orig_intensity, new_intensity, rtol=norm_rtol): logger.warning(f"WARNING: Norm-preserving warning, relative total intensity changed more than {norm_rtol} from orthogonalization. Skipping orthogonalization for this iteration.") return False return True
def dct_threshold_filter(
    x: torch.Tensor,
    threshold_ratio: float = 0.05,
) -> torch.Tensor:
    """Hard-threshold a tensor in the 2D DCT domain.

    The DCT coefficients with the largest magnitudes are kept and all others are set to
    zero before transforming back. The cut-off magnitude is determined once over all
    coefficients of the whole tensor (not per image), and at least one coefficient is
    always kept.

    Args:
        x (torch.Tensor): Input real-valued tensor of shape (..., H, W).
        threshold_ratio (float): Fraction of coefficients to keep. Must be in [0, 1].
            For example, 0.05 retains the top 5% coefficients.

    Returns:
        torch.Tensor: Filtered tensor of same shape as ``x``.

    Raises:
        ValueError: If ``threshold_ratio`` is not in [0, 1].
    """
    # Written as a negated chained comparison so that NaN is rejected as well
    if not (0.0 <= threshold_ratio <= 1.0):
        raise ValueError("threshold_ratio must be between 0 and 1.")

    coeffs = dct_2d(x)
    magnitudes = coeffs.abs()
    all_mags = magnitudes.reshape(-1)
    total = all_mags.numel()

    # Keep the n_keep largest magnitudes: the n_keep-th largest is the (total-n_keep+1)-th smallest
    n_keep = max(1, int(threshold_ratio * total))
    cutoff = torch.kthvalue(all_mags, total - n_keep + 1).values

    keep = magnitudes >= cutoff
    return idct_2d(coeffs * keep)
def kr_filter(obj, radius, width):
    """Low-pass filter `obj` over its last two (y, x) axes with a 2D sigmoid mask.

    Args:
        obj: Tensor whose last two axes are (Ny, Nx); the filter is built as (1, 1, Ny, Nx),
            so the leading axes (omode, z) are broadcast.
        radius: Relative radius passed to `make_sigmoid_mask`.
        width: Relative width passed to `make_sigmoid_mask`.

    Returns:
        The real part of the Fourier-filtered object.
    """
    n_rows, n_cols = obj.shape[-2:]

    # Build a square mask on the shorter side, stretch it to (Ny, Nx), then move its center to the corner
    # `interpolate` needs 2 additional leading dimensions (N, C, ...) on its input
    base_mask = make_sigmoid_mask(min(n_rows, n_cols), radius, width).to(obj.device)
    resized = interpolate(base_mask[None, None,], size=(n_rows, n_cols))
    W = ifftshift2(resized).squeeze()

    # fft2/ifft2 only act on the (y, x) axes; the inverse transform is complex, so only the real part is kept
    spectrum = fft2(obj)
    filtered = ifft2(spectrum * W[None, None,])
    return torch.real(filtered)
def kz_filter(obj, beta_regularize_layers=1, alpha_gaussian=1, obj_type='phase'):
    """Filter `obj` in 3D Fourier space with an arctan kz-weight and a lateral Gaussian.

    The weight is W = 1 - atan((beta * |kz| / sqrt(kx^2 + ky^2 + 1e-3))^2) / (pi/2), multiplied
    by exp(-alpha * (kx^2 + ky^2)). It is applied over the last three (z, y, x) axes.

    Args:
        obj: Real tensor whose last three axes are (z, y, x), with one leading axis (omode).
        beta_regularize_layers: Strength of the kz regularization (beta).
        alpha_gaussian: Strength of the lateral Gaussian damping (alpha).
        obj_type: 'phase' or 'amplitude'. For 'amplitude' the filtered result is additionally
            pulled towards 1 via 1 + 0.9 * (fobj - 1).

    Returns:
        The real part of the filtered object.
    """
    # Note: Calculate force of regularization based on the idea that DoF = resolution^2/lambda
    device = obj.device
    nz, ny, nx = obj.shape[-3:]

    # 1D frequency axes, shaped so that they broadcast to a (nz, ny, nx) grid
    abs_kz = torch.abs(fftfreq(nz).to(device))[:, None, None]
    ky = fftfreq(ny).to(device)[None, :, None]
    kx = fftfreq(nx).to(device)[None, None, :]
    kr_sq = kx**2 + ky**2

    # W and Wa follow the same form as PtychoShelves
    ratio = beta_regularize_layers * abs_kz / torch.sqrt(kr_sq + 1e-3)
    W = 1 - torch.atan(ratio**2) / (torch.pi/2)
    Wa = W * torch.exp(-alpha_gaussian * kr_sq)

    # Transform only the spatial axes so the omode axis is broadcast; keep the real part of the complex result
    spatial_dims = (-3, -2, -1)
    spectrum = fftn(obj, dim=spatial_dims)
    fobj = torch.real(ifftn(spectrum * Wa[None,], dim=spatial_dims))

    if obj_type == 'amplitude':
        # This is essentially a soft obja threshold constraint built into the kz_filter routine for obja
        fobj = 1+0.9*(fobj-1)

    return fobj
def get_obj_z_shift(obj_phase, threshold=95, scale=1, max_shift=10):
    """
    Compute z-shift from the center-of-mass (CoM) of the object phase.

    Negative phase is clipped to zero, weak values below the `threshold` percentile are
    optionally removed, and the CoM of the per-slice mean phase is compared with the center
    slice. The input tensor is not modified.

    If there is no positive phase left (e.g. an all-zero or all-negative object), the CoM is
    undefined and a shift of 0 is returned instead of a spurious shift towards one end.

    Args:
        obj_phase: tensor (omode, z, y, x), phase values in radians
        threshold: percentile in [0, 100] below which values are zeroed, or None to skip
        scale: scaling factor applied to the measured shift
        max_shift: maximum allowed shift in pixels; always capped at (nz - 1) / 2

    Returns:
        float, signed shift in pixels (center slice index minus CoM, times `scale`, clipped)
    """
    nz = obj_phase.shape[1]
    half_depth = (nz - 1) / 2
    max_shift = half_depth if max_shift is None else min(max_shift, half_depth)

    # Ensure no negative phase value. `clamp` returns a new tensor, so the in-place edit below
    # does not touch the caller's tensor.
    obj_phase = torch.clamp(obj_phase, min=0)

    # Threshold to focus on actual signal
    if threshold is not None:
        cutoff = approx_torch_quantile(obj_phase, q=threshold/100)  # quantile is between [0,1]
        obj_phase[obj_phase < cutoff] = 0  # Value smaller than cutoff is set to 0

    # Collapse omode,y,x -> mean phase per z-slice, shape (z,)
    phase_z = obj_phase.mean(dim=(0, 2, 3))
    total = phase_z.sum()

    # No signal (or NaN): the CoM is undefined, so do not shift at all
    if not bool(total > 0):
        return 0.0

    # Calculate CoM and shift relative to the center slice
    z_coords = torch.arange(nz, device=obj_phase.device, dtype=obj_phase.dtype)
    com_z = (z_coords * phase_z).sum() / total
    center_z = (nz - 1) / 2.0

    # Apply scaling and clip
    shift = (center_z - com_z) * scale
    shift = torch.clamp(shift, min=-max_shift, max=max_shift)

    return shift.item()
def shift_obj_along_z(objc, z_shift):
    """
    Shift a complex object along its z axis (dim 1) by a possibly fractional number of pixels.

    The shift is done with the Fourier shift theorem: the 1D FFT along z is multiplied by
    exp(-2j*pi*kz*z_shift) and transformed back. Shifts with magnitude below 1e-3 px are
    treated as zero and the input tensor itself is returned.

    Args:
        objc: tensor (omode, z, y, x), complex
        z_shift: float, shift in pixels along +z direction

    Returns:
        shifted tensor, same shape
    """
    if abs(z_shift) < 1e-3:
        return objc

    depth = objc.shape[1]
    kz = torch.fft.fftfreq(depth, d=1.0, device=objc.device)  # cycles/pixel
    ramp = torch.exp(-2j * torch.pi * kz * z_shift)  # (depth,)

    # Transform along z only, apply the linear phase, transform back
    spectrum = torch.fft.fft(objc, dim=1) * ramp.view(1, depth, 1, 1)
    return torch.fft.ifft(spectrum, dim=1)
def complex_ratio_constraint(model, alpha1, alpha2):
    """Couple object amplitude and phase through a single global ratio Cbar.

    References:
        https://doi.org/10.1016/j.ultramic.2024.114068
        https://doi.org/10.1364/OE.18.001981

    With Cbar = sum|log(obja)| / sum|objp| computed over the entire object (omode, z, y, x):
        objac = exp((1 - alpha1) * log(obja) - alpha1 * Cbar * objp)
        objpc = (1 - alpha2) * objp - alpha2 / Cbar * log(obja)

    Suggested values for alpha1, alpha2 are 1 and 0. For alpha1 = 1, alpha2 = 0 this describes
    a phase object and the phase is not updated, namely objac = exp(-Cbar*objp); objpc = objp.

    NOTE: this implementation is slightly different from the papers. For electron ptychography
    we usually define a positive phase shift in the transmission function,
    obj(r) = T(r) = exp(i*sigma*V(r)), so the object phase = angle(obj) = sigma*V(r).
    An electron scattered by the nuclei accumulates positive phase shift and an amplitude
    slightly below 1, hence the negative sign on the second term of both formulas.

    Args:
        model: Object exposing `opt_obja` (amplitude) and `opt_objp` (phase) tensors.
        alpha1: Mixing weight for the amplitude update.
        alpha2: Mixing weight for the phase update.

    Returns:
        Tuple (objac, objpc, Cbar): updated amplitude, updated phase, and the scalar ratio.
        The model tensors themselves are not modified here.
    """
    amp = model.opt_obja
    phase = model.opt_objp

    # NOTE(review): the log requires strictly positive amplitude; zeros or negatives would
    # yield inf/NaN that spread through the global Cbar - confirm callers guarantee this.
    log_amp = torch.log(amp)

    # One ratio for the whole object (could alternatively be computed per z slice or omode)
    Cbar = log_amp.abs().sum() / (phase.abs().sum() + 1e-8)  # Avoid division by zero

    # Updated amplitude, note the negative sign for the second term!
    new_amp = torch.exp((1 - alpha1) * log_amp - alpha1 * Cbar * phase)

    # Updated phase, note the negative sign for the second term!
    new_phase = (1 - alpha2) * phase - alpha2 / (Cbar + 1e-8) * log_amp  # Avoid division by zero

    return new_amp, new_phase, Cbar