# Source code for ptyrad.core.functional

"""
PyTorch implementation of core functionals like image shifts, blurring, DCT, etc.

"""

from typing import Literal, Optional, Tuple, Union

import numpy as np
import torch

# This is currently used in core/constraints and core/forward as FFT aliases
def fftshift2(x):
    """Apply torch.fft.fftshift over the last two dimensions only.

    Args:
        x (torch.Tensor): Tensor whose last two dims are (Ny, Nx); leading dims are untouched.

    Returns:
        torch.Tensor: `x` rolled so the zero-frequency term moves from the corner to the center
        of the last two dims.
    """
    # fftshift and ifftshift only coincide for even-sized axes; undo this with ifftshift2.
    spatial_dims = (-2, -1)
    return torch.fft.fftshift(x, dim=spatial_dims)
def ifftshift2(x):
    """Apply torch.fft.ifftshift over the last two dimensions only.

    Args:
        x (torch.Tensor): Tensor whose last two dims are (Ny, Nx); leading dims are untouched.

    Returns:
        torch.Tensor: `x` rolled so a centered zero-frequency term moves back to the corner
        of the last two dims.
    """
    # ifftshift is the exact inverse of fftshift for both even and odd axis lengths.
    spatial_dims = (-2, -1)
    return torch.fft.ifftshift(x, dim=spatial_dims)
# This is currently only used in init/initializer
def complex_object_z_resample_torch(
    obj: Union[torch.Tensor, np.ndarray],
    dz_now: float,
    resample_mode: Literal['scale_Nlayer', 'scale_slice_thickness', 'target_Nlayer', 'target_slice_thickness'],
    resample_value: Union[float, int],
    output_type: Optional[Literal['complex', 'amplitude', 'phase', 'amp_phase']] = 'complex',
    return_np: bool = True,
):
    """Resample a complex 3D object along the depth (z) axis while conserving amplitude product and phase sum.

    The object is split into amplitude and phase. log(amplitude) and phase are resampled along z
    with 'area' interpolation and then divided by the realized depth ratio Nz_out / Nz_now. This
    keeps prod(amplitude, axis=z) and sum(phase, axis=z) unchanged, i.e. the total projected
    object is preserved. The new slice thickness is not returned; the caller is responsible for it.

    Args:
        obj (ndarray or torch.Tensor): Complex object with shape (..., Nz, Ny, Nx) and 3, 4, or 5
            dims in total. It is cast to complex64 and moved to CUDA if available, otherwise CPU.
        dz_now (float): Current slice thickness along the z-axis.
        resample_mode (str): Resampling mode for the depth axis. Must be one of:
            - "scale_Nlayer": Scale the number of layers by a float factor.
            - "scale_slice_thickness": Scale slice thickness by a float factor.
            - "target_Nlayer": Resample to a target integer number of layers.
            - "target_slice_thickness": Resample to a target slice thickness.
        resample_value (int or float): Parameter value for the resampling mode.
            - Positive float for "scale_Nlayer" or "scale_slice_thickness".
            - Positive integer (>=1) for "target_Nlayer".
            - Positive float for "target_slice_thickness".
        output_type (str, optional): Output representation. Must be one of:
            - "complex": Return recombined complex object (default).
            - "amplitude": Return amplitude only.
            - "phase": Return phase only.
            - "amp_phase": Return tuple (amplitude, phase).
        return_np (bool, optional): If True (default), convert outputs to NumPy arrays.
            If False, return PyTorch tensors (on the device chosen above).

    Returns:
        ndarray or torch.Tensor or tuple: The resampled object, with the same number of dims as `obj`:
            - Complex ndarray/tensor if output_type == "complex".
            - Real ndarray/tensor if output_type == "amplitude" or "phase".
            - Tuple of (amplitude, phase) if output_type == "amp_phase".

    Raises:
        ValueError: If `resample_mode` is invalid.
        ValueError: If the target number of layers is less than 1.
        ValueError: If `output_type` is not one of the allowed options.
        ValueError: If the input object has unsupported dimensionality.

    Notes:
        `torch.nn.functional.interpolate` floors Nz_now * scale_factor when computing the output
        depth. The value rescaling therefore uses the actual output depth, not the nominal
        scale. Otherwise the phase sum / amplitude product would drift whenever the product is
        not an integer (e.g. Nz=10 with scale 0.25 gives 2 layers, a ratio of 0.2, not 0.25).

    Examples:
        Resample by doubling the number of z-layers:

        >>> out = complex_object_z_resample_torch(
        ...     obj, dz_now=0.5, resample_mode="scale_Nlayer",
        ...     resample_value=2.0, output_type="complex"
        ... )

        Resample to a target of 64 layers, keeping total thickness fixed:

        >>> out_amp, out_phase = complex_object_z_resample_torch(
        ...     obj, dz_now=0.5, resample_mode="target_Nlayer",
        ...     resample_value=64, output_type="amp_phase"
        ... )
    """
    from torch.nn.functional import interpolate

    Nz_now, Ny_now, Nx_now = obj.shape[-3:]

    # Translate the requested mode into interpolate() arguments and the nominal depth scaling
    if resample_mode == 'scale_Nlayer':
        scale_factors = [resample_value, 1, 1]
        sizes = None
        Nz_scale = resample_value
    elif resample_mode == 'scale_slice_thickness':
        scale_factors = [1/resample_value, 1, 1]
        sizes = None
        Nz_scale = 1/resample_value
    elif resample_mode == 'target_Nlayer':
        scale_factors = None
        sizes = [int(resample_value), Ny_now, Nx_now]
        Nz_scale = resample_value/Nz_now
    elif resample_mode == 'target_slice_thickness':
        scale_factors = [dz_now/resample_value, 1, 1]
        sizes = None
        Nz_scale = dz_now/resample_value
    else:
        raise ValueError(f"Supported obj_z_resample modes are 'scale_Nlayer', 'scale_slice_thickness', 'target_Nlayer', and 'target_slice_thickness', got {resample_mode}.")

    # Check scale factor validity
    if Nz_now * Nz_scale < 1:
        raise ValueError(f"Detected target Nlayer = {Nz_now * Nz_scale:.3f} < 1 (single slice), please check your 'obj_z_resampling' settings.")

    # Validate the requested output up front so no device work is wasted on a bad argument
    if output_type not in ('complex', 'amplitude', 'phase', 'amp_phase'):
        raise ValueError(
            f"output_type must be one of 'complex', 'amplitude', 'phase', 'amp_phase', "
            f"got {output_type}"
        )

    # Preprocess obj into a complex64 torch tensor
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(obj, torch.Tensor):
        obj_tensor = obj.to(dtype=torch.complex64, device=device)
    else:
        obj_tensor = torch.tensor(obj, dtype=torch.complex64, device=device)

    # 3D 'area' interpolation needs a 5D (1, omode, Nz, Ny, Nx) tensor; remember how many axes we add
    orig_ndim = obj_tensor.ndim
    if orig_ndim not in (3, 4, 5):
        raise ValueError(f"Complex object 3D interpolation only supports 3, 4, 5D tensor, got {obj_tensor.ndim}.")
    n_added = 5 - orig_ndim
    for _ in range(n_added):
        obj_tensor = obj_tensor.unsqueeze(0)

    # Split into amplitude and phase parts
    obja = torch.abs(obj_tensor)
    objp = torch.angle(obj_tensor)

    # Area-interpolate log(amp) and phase, then rescale by the realized depth ratio so that
    # prod(amp, axis=z) and sum(phase, axis=z) are conserved even when interpolate() floors Nz_out.
    log_amp_resample = interpolate(torch.log(obja), size=sizes, scale_factor=scale_factors, mode='area')
    phase_resample = interpolate(objp, size=sizes, scale_factor=scale_factors, mode='area')
    Nz_scale_actual = phase_resample.shape[-3] / Nz_now
    obja_resample = torch.exp(log_amp_resample / Nz_scale_actual)
    objp_resample = phase_resample / Nz_scale_actual

    def _finalize(t):
        # Remove the singleton axes added above, then optionally convert to NumPy
        for _ in range(n_added):
            t = t.squeeze(0)
        return t.detach().cpu().numpy() if return_np else t

    if output_type == 'complex':
        return _finalize(torch.polar(obja_resample, objp_resample))
    if output_type == 'amplitude':
        return _finalize(obja_resample)
    if output_type == 'phase':
        return _finalize(objp_resample)
    return _finalize(obja_resample), _finalize(objp_resample)
# This is currently used in core/constraint.py > get_obj_z_shift
def approx_torch_quantile(t, q, sample_size=16_000_000):
    """Approximate quantile that sidesteps torch.quantile's input-size limit.

    torch.quantile currently rejects inputs with more than 2^24 (~16.7M) elements
    (`RuntimeError: quantile() input tensor is too large`, see
    https://github.com/pytorch/pytorch/issues/64947). When `t` is larger than `sample_size`,
    the quantile is computed on a random subsample drawn with replacement, so the result is
    nondeterministic in that case.

    Args:
        t (torch.Tensor): Input tensor of any shape; it is flattened. Contiguity is not required.
        q (float or torch.Tensor): Target quantile(s) in [0, 1].
        sample_size (int, optional): Number of randomly selected elements used to approximate
            the true quantile. Defaults to 16_000_000.

    Returns:
        torch.Tensor: The (approximated) quantile value(s) returned by torch.quantile.
    """
    # reshape rather than view: view(-1) raises on non-contiguous tensors (transposes, slices);
    # reshape only copies when a flat view is impossible.
    flat = t.reshape(-1)
    n_elem = flat.numel()
    if n_elem > sample_size:
        idx = torch.randint(0, n_elem, (sample_size,), device=flat.device)
        flat = flat[idx]
    return torch.quantile(flat, q)
# This is currently used in core/constraints.py > apply_obj_zblur
def get_gaussian1d(size, std, norm=False):
    """Return a 1D Gaussian window built by scipy.signal.windows.gaussian.

    Args:
        size (int): Number of samples in the window.
        std (float): Standard deviation (sigma) of the Gaussian, in samples.
        norm (bool, optional): If True, scale the window so its samples sum to 1.
            Defaults to False (peak value of 1 for odd `size`).

    Returns:
        numpy.ndarray: The 1D Gaussian window.
    """
    from scipy.signal.windows import gaussian

    window = gaussian(size, std)
    if not norm:
        return window
    return window / window.sum()
def gaussian_blur_1d(tensor, kernel_size=5, sigma=0.5):
    """Blur a 4D [omode, z, H, W] tensor with a 1D Gaussian along dim 1 (z).

    Each object mode and each (H, W) position is blurred independently along z. The z-axis is
    replicate-padded rather than zero-padded, so boundary slices are not pulled toward zero.
    This matters for both object amplitude and phase.

    The tensor is viewed as [omode, 1, z, H*W] and convolved with a (kernel_size, 1) kernel via
    F.conv2d, so no permutation is needed. conv2d is used instead of conv3d because conv3d
    silently produces incorrect results on the MPS backend.

    Args:
        tensor (torch.Tensor): Input tensor of shape [omode, z, H, W].
        kernel_size (int, optional): Length of the 1D Gaussian kernel. Defaults to 5.
        sigma (float, optional): Standard deviation of the Gaussian in pixels. Defaults to 0.5.

    Returns:
        torch.Tensor: Blurred tensor with the same shape, dtype, and device as the input.
    """
    import torch.nn.functional as F

    # Normalized kernel laid out as a conv2d weight [out=1, in=1, kZ, 1]
    kernel_np = get_gaussian1d(kernel_size, sigma, norm=True)
    kernel = torch.from_numpy(kernel_np).to(dtype=tensor.dtype, device=tensor.device)
    kernel = kernel.view(1, 1, kernel_size, 1)

    # [omode, z, H, W] -> [omode, 1, z, H*W]: z is the conv "height", every (y, x) a "width" column
    n_modes, n_z, height, width = tensor.shape
    columns = tensor.reshape(n_modes, n_z, height * width).unsqueeze(1)

    # 'same'-style replicate padding along z; for even kernel_size the extra sample goes to the back.
    # F.pad's tuple is ordered from the last dim inward: (W_left, W_right, H_top, H_bottom).
    lead = (kernel_size - 1) // 2
    trail = (kernel_size - 1) - lead
    columns = F.pad(columns, (0, 0, lead, trail), mode='replicate')

    blurred = F.conv2d(columns, kernel)
    return blurred.squeeze(1).reshape(n_modes, n_z, height, width)
# This is currently used in core/constraints.py > kr_filter, probe_mask_k
def make_sigmoid_mask(Npix: int, relative_radius: float = 2/3, relative_width: float = 0.2, center: Optional[Tuple[float, float]] = None):
    """Create a 2D circular mask with a smooth sigmoid edge.

    Args:
        Npix (int): Size of the square mask (Npix x Npix).
        relative_radius (float): Radius where the mask equals 0.5, as a fraction of HALF the image
            size. The absolute radius in pixels is `relative_radius * Npix / 2`, so 1.0 reaches
            the image edge from the center (i.e. kMax for a centered diffraction pattern).
        relative_width (float): Width of the sigmoid transition as a fraction of the FULL image
            size. The mask falls from ~1 to ~0 over roughly `relative_width * Npix` pixels.
        center (Optional[Tuple[float, float]]): (y, x) pixel coordinates of the circle center.
            Defaults to (Npix // 2, Npix // 2).

    Returns:
        torch.Tensor: (Npix, Npix) float32 mask, ~1 inside the circle and ~0 outside.

    Notes:
        - The default `relative_radius=2/3` is inspired by abTEM's antialiasing aperture, which
          cuts off at 2/3 of the simulated kMax to reduce edge artifacts in diffraction patterns.
          https://abtem.readthedocs.io/en/latest/user_guide/appendix/antialiasing.html
        - Smaller `relative_width` gives a sharper edge; larger values give a smoother roll-off.
    """
    def scaled_sigmoid(x, offset=0, scale=1):
        # Decreasing sigmoid equal to 0.5 at x == offset. The factor 10 makes it drop from ~1 to ~0
        # over about `scale` pixels: scale=1 -> within (-0.5, 0.5) of the offset, scale=10 -> ~10 px.
        return 1 / (1 + torch.exp((x - offset) / scale * 10))

    # Default to the (integer) image center, matching fftshift's zero-frequency position
    if center is None:
        center = (Npix // 2, Npix // 2)

    # Pixel coordinate grids
    ky = torch.arange(Npix, dtype=torch.float32)
    kx = torch.arange(Npix, dtype=torch.float32)
    grid_ky, grid_kx = torch.meshgrid(ky, kx, indexing='ij')

    # Radial distance from the chosen center
    kR = torch.sqrt((grid_ky - center[0])**2 + (grid_kx - center[1])**2)

    # Half-value radius is relative_radius * (Npix / 2); transition width scales with the full Npix
    sigmoid_mask = scaled_sigmoid(kR, offset=Npix * relative_radius / 2, scale=relative_width * Npix)
    return sigmoid_mask
# This is currently used in core/constraints.py > kr_thrsh
def dct_2d(x: torch.Tensor) -> torch.Tensor:
    """Compute a scaled, unnormalized 2D DCT-II over the last two dims using FFTs.

    Along each axis of length N the transform is
        X_k = 4 * sum_n x_n * cos(pi * k * (2n + 1) / (2N)),
    i.e. twice scipy.fft.dct(type=2, norm=None) per axis, or 4x scipy.fft.dctn(type=2) in 2D.
    It is NOT orthonormal. Leading (batch) dimensions are passed through unchanged.

    Args:
        x (torch.Tensor): Real-valued input tensor of shape (..., H, W).

    Returns:
        torch.Tensor: Real DCT coefficients of shape (..., H, W).
    """
    n_rows, n_cols = x.shape[-2:]

    # Axis -2: even-symmetric extension (..., 2H, W) -> FFT -> half-sample phase twiddle -> real part
    row_mirror = torch.cat([x, x.flip(dims=[-2])], dim=-2)
    row_spec = torch.fft.fft(row_mirror, dim=-2)
    row_twiddle = torch.exp(-1j * torch.pi * torch.arange(n_rows, device=x.device) / (2 * n_rows))
    along_rows = (row_spec[..., :n_rows, :] * row_twiddle[:, None]).real * 2

    # Axis -1: same recipe on the row-transformed data, extension shape (..., H, 2W)
    col_mirror = torch.cat([along_rows, along_rows.flip(dims=[-1])], dim=-1)
    col_spec = torch.fft.fft(col_mirror, dim=-1)
    col_twiddle = torch.exp(-1j * torch.pi * torch.arange(n_cols, device=x.device) / (2 * n_cols))
    return (col_spec[..., :n_cols] * col_twiddle).real * 2
def idct_2d(x: torch.Tensor) -> torch.Tensor:
    """Invert `dct_2d`: reconstruct a real signal from its scaled 2D DCT-II coefficients.

    `dct_2d` produces, per axis of length N, X_k = 4 * sum_n x_n cos(pi k (2n+1) / (2N)).
    Each axis is inverted here as a DCT-III computed with a length-2N inverse FFT on the
    Hermitian extension [Z_0, ..., Z_{N-1}, 0, conj(Z_{N-1}), ..., conj(Z_1)], where
    Z_k = X_k * exp(i pi k / (2N)) / 2. The zero at index N (the Nyquist term of the
    even-symmetric extension) is required for an exact inverse; without it the FFT length
    is 2N - 1 and the reconstruction is wrong.

    Leading (batch) dimensions are passed through unchanged. Internally the data is cast to
    complex64, so the output is float32.

    Args:
        x (torch.Tensor): DCT coefficients of shape (..., H, W), as returned by `dct_2d`.

    Returns:
        torch.Tensor: Reconstructed real signal of shape (..., H, W).
    """
    H, W = x.shape[-2:]
    X = x.to(torch.complex64)

    # --- Invert along width (dim=-1) ---
    n_w = torch.arange(W, device=x.device)
    scale_w = torch.exp(1j * torch.pi * n_w / (2 * W))
    Xw = X * scale_w / 2
    # Hermitian extension to length 2W with a zero Nyquist term at index W
    zero_w = torch.zeros_like(Xw[..., :1])
    Xw_ext = torch.cat(
        [Xw, zero_w, Xw[..., 1:].flip(dims=[-1]).conj()], dim=-1
    )  # (..., H, 2W)
    x_w = torch.fft.ifft(Xw_ext, dim=-1)[..., :W].real

    # --- Invert along height (dim=-2) ---
    n_h = torch.arange(H, device=x.device)
    scale_h = torch.exp(1j * torch.pi * n_h / (2 * H))
    Xh = x_w.to(torch.complex64) * scale_h[:, None] / 2
    # Hermitian extension to length 2H with a zero Nyquist term at index H
    zero_h = torch.zeros_like(Xh[..., :1, :])
    Xh_ext = torch.cat(
        [Xh, zero_h, Xh[..., 1:, :].flip(dims=[-2]).conj()], dim=-2
    )  # (..., 2H, W)
    out = torch.fft.ifft(Xh_ext, dim=-2)[..., :H, :].real
    return out
# This is currently used in 'obj_z_recenter' constraint to shift the probe defocus.
def near_field_evolution_torch(Npix_shape, dx, dz, lambd, dtype=torch.complex64, device='cuda'):
    """Build an angular-spectrum (Fresnel) propagator H = exp(i * dz * kz).

    Translated and simplified from Yi's fold_slice Matlab implementation into PyTorch by
    Chia-Hao Lee. The forward pass uses the propagator constructed directly in
    `PtychoModel.get_propagators` for efficiency.

    Args:
        Npix_shape (sequence of int): (Ny, Nx) size of the propagator.
        dx (float): Real-space pixel size, used for both y and x.
        dz (float): Propagation distance. Units must match `dx` and `lambd`.
        lambd (float): Wavelength.
        dtype (torch.dtype, optional): Output dtype. Defaults to torch.complex64.
        device (str or torch.device, optional): Device for the grids and output. Defaults to 'cuda'.

    Returns:
        torch.Tensor: (Ny, Nx) propagator with the zero frequency at the corner (ifftshifted).
        Components with kx^2 + ky^2 > k^2 are evanescent and decay as exp(-dz * sqrt(K^2 - k^2)).
        Previously they became NaN, because the real sqrt of a negative number was taken.

    Notes:
        The frequency grid is offset by half a pixel (`+ 0.5`), as in the fold_slice code.
    """
    ny, nx = Npix_shape[0], Npix_shape[1]
    ygrid = (torch.arange(-ny // 2, ny // 2, device=device) + 0.5) / ny
    xgrid = (torch.arange(-nx // 2, nx // 2, device=device) + 0.5) / nx

    # Standard ASM
    k = 2 * torch.pi / lambd
    ky = 2 * torch.pi * ygrid / dx
    kx = 2 * torch.pi * xgrid / dx
    Ky, Kx = torch.meshgrid(ky, kx, indexing="ij")

    # Complex sqrt (principal branch, like Matlab's sqrt) so evanescent waves get a positive
    # imaginary kz and decay, instead of the NaN produced by a real-valued sqrt.
    kz_sq = k ** 2 - Kx ** 2 - Ky ** 2
    kz = torch.sqrt(torch.complex(kz_sq, torch.zeros_like(kz_sq)))
    H = torch.exp(1j * dz * kz)

    # Move zero frequency to the corner (same as ifftshift over the last two dims)
    H = torch.fft.ifftshift(H, dim=(-2, -1))
    return H.to(dtype)
# This is currently used in core/models/ptycho > PtychoModel for AD-optimizable propagators
def torch_phasor(phase):
    """Return exp(i * phase) as a unit-magnitude complex tensor.

    Args:
        phase (torch.Tensor): Real tensor of phase angles in radians.

    Returns:
        torch.Tensor: Complex tensor with |z| = 1 and angle `phase`, same shape as `phase`.

    Note:
        Built with torch.polar rather than torch.exp(1j * phase) so torch.compile can trace it.
        The Python complex literal `1j` cannot be traced.
    """
    unit_magnitude = torch.ones_like(phase)
    return torch.polar(unit_magnitude, phase)
# This is currently used in core/models/ptycho_model > get_probes
def imshift_batch(img, shifts, grid):
    """Shift one image (or image stack) by a batch of sub-pixel offsets using Fourier phase ramps.

    `img` has shape (..., Ny, Nx) with arbitrary leading dimensions. Each of the Nb rows of
    `shifts` produces one shifted copy of the whole `img`, so the output shape is (Nb, *img.shape).

    Args:
        img (torch.Tensor): Real or complex image(s) to shift. In PtyRAD this is either a
            mixed-state complex probe (pmode, Ny, Nx) complex64 tensor, or a mixed-state
            pseudo-complex object stack (2, omode, Nz, Ny, Nx) float32 tensor.
        shifts (torch.Tensor): (Nb, 2) tensor; each row is (shift_y, shift_x) in pixels.
            Positive values shift down / right.
        grid (torch.Tensor): (2, Ny, Nx) k-space grid holding (ky, kx), normalized to span
            [-0.5, 0.5). The resulting phase ramp is applied directly to fft2(img), so the grid
            is expected with its zero frequency at the corner (unshifted FFT order).

    Returns:
        torch.Tensor: Complex tensor of shape (Nb, *img.shape). For real-valued input, take `.real`.

    Note:
        - fft2/ifft2 act on the last two dims only, so shifting happens only along y and x.
        - Move `img`, `shifts`, and `grid` to the desired device before calling.
    """
    assert img.shape[-2:] == grid.shape[-2:], f"Found incompatible dimensions. img.shape[-2:] = {img.shape[-2:]} while grid.shape[-2:] = {grid.shape[-2:]}"

    n_img_dims = img.ndim
    # (Nb, 2) -> (Nb, 2, 1, ..., 1): append one singleton per image dim
    shifts = shifts.reshape(*shifts.shape, *([1] * n_img_dims))
    # (2, Ny, Nx) -> (2, 1, ..., 1, Ny, Nx): insert (n_img_dims - 1) singletons after the ky/kx axis
    grid = grid.reshape(grid.shape[0], *([1] * (n_img_dims - 1)), *grid.shape[1:])

    # shift_y/x: (Nb, 1, ..., 1) with n_img_dims singletons.
    # ky/kx: (1, ..., 1, Ny, Nx) with (n_img_dims - 1) singletons. Both have n_img_dims + 1 dims.
    shift_y, shift_x = shifts[:, 0], shifts[:, 1]
    ky, kx = grid[0], grid[1]

    # Fourier shift theorem: a real-space translation is a linear phase ramp in k-space
    ramp = torch_phasor(-2*torch.pi * (shift_x * kx + shift_y * ky))  # (Nb, 1, ..., 1, Ny, Nx)
    img_k = torch.fft.fft2(img)
    return torch.fft.ifft2(img_k * ramp)
# This is not used in PtyRAD yet, but could be useful for some analysis notebooks
def get_center_of_mass(image, corner_centered=False):
    """Return the intensity-weighted center of mass of a real-valued 2D image or 3D image stack.

    Args:
        image (torch.Tensor): Real tensor of shape (Ny, Nx) or (N, Ny, Nx). For a stack, every
            image is normalized by its OWN total intensity.
        corner_centered (bool, optional): If True, use FFT-ordered coordinates
            (0, 1, ..., -2, -1), which is handy for corner-centered probes or CBED patterns.
            Otherwise use 0..N-1 pixel indices. Idea adapted from py4DSTEM:
            https://github.com/py4dstem/py4DSTEM/blob/dev/py4DSTEM/process/utils/utils.py

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (center_y, center_x), each a scalar tensor for 2D input
        or shape (N,) for 3D input.

    Note:
        For even-sized images the "center" of a uniform image lies between pixels,
        e.g. (63.5, 63.5) for 128x128.
    """
    ndim = image.ndim
    assert ndim in [2, 3], f"image.ndim must be either 2 or 3, we've got {ndim}"

    # Create grid of coordinates
    device = image.device
    (ny, nx) = image.shape[-2:]
    if corner_centered:
        grid_y, grid_x = torch.meshgrid(torch.fft.fftfreq(ny, 1 / ny, device=device), torch.fft.fftfreq(nx, 1 / nx, device=device), indexing='ij')
    else:
        grid_y, grid_x = torch.meshgrid(torch.arange(ny, device=device), torch.arange(nx, device=device), indexing='ij')

    # Per-image total intensity: scalar for 2D, (N,) for 3D.
    # (Averaging over the stack would mis-weight images with different total intensities.)
    total_intensity = torch.sum(image, dim=(-2, -1))

    # Compute weighted sum of y and x coordinates
    center_y = torch.sum(grid_y * image, dim=(-2, -1)) / total_intensity
    center_x = torch.sum(grid_x * image, dim=(-2, -1)) / total_intensity
    return center_y, center_x