Source code for ptyrad.forward

"""
Physical forwad model that generates diffraction patterns from mixed-state probe/object in a fully vectorized way

"""

import torch
from torch.fft import fft2, ifft2

from ptyrad.utils import fftshift2

# The forward model takes a batch of object patches and probes with their mixed states
# By introducing and aligning the singleton dimensions carefully, 
# we can vectorize all the operations except the serial z-dimension propagation
# For 3D object with n_slices, the for loop would go through n-1 loops and multiply the last slice without further Fresnel propagaiton
# This way we can skip the if statement and make it slightly faster 
# For 2D object (n_slices = 1), the entire for loop is skipped
# Note that element-wise multiplication of tensor (*) is defaulted as out-of-place operation
# So new tensor is being created and referenced to the old graph to keep the gradient flowing

[docs] def multislice_forward_model_vec_all(object_patches, probe, H, omode_occu=None, eps=1e-10): """ Computes the multislice electron diffraction pattern with multiple incoherent probe and object modes using a vectorized forward model. Args: object_patches (torch.Tensor): Tensor of shape (N, omode, Nz, Ny, Nx, 2), representing pseudo-complex object patches with float32 amplitude and phase components. N is the number of samples in a batch, omode is the number of object modes, Nz, Ny, Nx are the dimensions of the object patches. omode_occu (torch.Tensor): Tensor of shape (omode,) with float32 values, representing the occupancy/expectation for each object mode. The sum of all elements should be 1. probe (torch.Tensor): Tensor of shape (N, pmode, Ny, Nx) with complex64 values, representing the probe(s). N is the number of samples in the batch, pmode is the number of probe modes. By default, N is 1, assuming the same probe for all samples. H (torch.Tensor): Tensor of shape (N, Ky, Kx) with complex64 values, representing the Fresnel propagator that propagates the wave by a slice thickness. eps (float, optional): A small value added for numerical stability. Defaults to 1e-10. Returns: torch.Tensor: Tensor of shape (N, Ky, Kx) with float32 positive values, representing the forward diffraction pattern for each sample in the batch. """ # These .contiguous() are needed for torch.compile in Linux object_patches = object_patches.contiguous() probe = probe.contiguous() H = H.contiguous() # Initialize omode_occu if it's not specified if omode_occu is None: objp= object_patches[...,1] device = objp.device dtype = objp.dtype omode = objp.size(1) omode_occu = torch.ones(omode, dtype=dtype, device=device) / omode # Cast the object back to actual complex tensor object_cplx = torch.polar(object_patches[...,0], object_patches[...,1]).contiguous() # (N, omode, Nz, Ny, Nx) n_slices = object_cplx.shape[2] # Expand psi to include omode dimension psi = probe[:, :, None, :, :].contiguous() # (N, pmode, Ny, Nx) -> (N, pmode, omode, Ny, Nx) # Propagating each object layer using broadcasting for n in range(n_slices-1): object_slice = object_cplx[:, :, n, :, :] # object_slice -> (N, omode, Ny, Nx) psi = psi * object_slice[:, None, :, :, :] # psi -> (N, pmode, omode, Ny, Nx). Note that psi is always centered in real space psi = ifft2(H[:,None,None] * fft2(psi)) # Note that fft2 and ifft2 are applying to the last 2 axes. Although preshift psi before fft2 would seem more natural, it's nearly 50% slower to do it as fftshift2(ifft2(fft2(ifftshift2(psi)))) # Interacting with the last layer, and no propagation is needed afterward object_slice = object_cplx[:, :, n_slices-1, :, :] psi = psi * object_slice[:, None, :, :, :] # Propagate the object-modified exit wave psi(r) to detector plane into psi(k) # The contribution from probe / object modes are incoherently summed together # Breaking down the steps for clarity, while combine all of these for lower peak memory consumption # psi_k = fftshift(fft2(psi)) # |psi_k|^2 = psi_k.abs().square() # weighted_psi_k = |psi_k|^2 * omode_occu # dp_fwd = sum(weighted_psi_k) # Note that norm = 'ortho' is needed to ensure that for each sample, sum(|psi|^2) and sum(dp) has the same scale (should be 1) dp_fwd = torch.sum((fftshift2(fft2(psi, norm='ortho'))).abs().square() * omode_occu[:,None,None], dim=(1, 2)) + eps # Add eps for numerical stability return dp_fwd