Source code for ptyrad.optics.probe

"""
Numpy-based electron / x-ray probe generation functions

"""

import logging
from typing import Dict, Literal, Tuple, Union

import numpy as np
from numpy.fft import fft2, fftfreq, fftshift, ifft2, ifftshift

from ptyrad.optics.aberrations import Aberrations
from ptyrad.optics.constants import get_wavelength_ang

logger = logging.getLogger(__name__)

# Initialize probes
def make_aberration_surface_krivanek_polar(
    aberrations: Dict[Tuple[int, int], Dict[str, float]],
    kX: np.ndarray,
    kY: np.ndarray,
    wavelength: float
) -> np.ndarray:
    """Evaluate the aberration phase surface chi(k) from Krivanek polar coefficients.

    Sums the standard polar expansion (Kirkland Eqn. 2.22) over every order
    supplied: each term is ``C_nm * alpha**(n+1) * cos(m*(phi - phi_nm)) / (n+1)``
    and the total is scaled by ``2*pi / wavelength``.

    Args:
        aberrations: A dictionary mapping order (n, m) to polar coefficients.
            For ``m != 0`` the value is ``{'mag': float, 'phi': float}`` where
            'mag' is the coefficient magnitude in Angstroms and 'phi' is the
            azimuthal angle in degrees. For ``m == 0`` (rotationally symmetric
            terms) the value is used directly as the scalar magnitude.
        kX: Spatial frequency coordinate X (1/Angstrom).
        kY: Spatial frequency coordinate Y (1/Angstrom).
        wavelength: Electron wavelength in Angstroms.

    Returns:
        np.ndarray: The aberration phase surface in radians.
    """
    # Polar angular coordinates: alpha = |k| * lambda, azimuth from arctan2
    alpha = np.sqrt(kX**2 + kY**2) * wavelength
    azimuth = np.arctan2(kY, kX)

    chi = np.zeros_like(alpha)
    for (n, m), coeffs in aberrations.items():
        radial_part = alpha**(n+1)
        if m != 0:
            magnitude = coeffs['mag']
            phi_nm = np.radians(coeffs['phi'])
            # Kirkland Eq 2.22
            chi += (magnitude * radial_part * np.cos(m*(azimuth - phi_nm))) / (n+1)
        else:
            # No azimuthal dependence: the entry itself is the magnitude
            chi += (coeffs * radial_part) / (n+1)

    chi *= 2 * np.pi / wavelength
    return chi
def make_aberration_surface_krivanek_complex(
    aberrations: Dict[Tuple[int, int], complex],
    kX: np.ndarray,
    kY: np.ndarray,
    wavelength: float
) -> np.ndarray:
    """Evaluate the aberration phase surface chi(k) from Krivanek complex coefficients.

    Implements the complex power series expansion (Kirkland Eqn. 2.19/2.20) in
    the complex angular coordinate ``omega = alpha_x + i*alpha_y``. Each order
    contributes ``C_nm * conj(omega)**s * omega**(n+1-s) / (n+1)`` with
    ``s = (n+m+1)//2``; the real part of the sum is scaled by
    ``2*pi / wavelength``.

    Note that the exponents of omega and conj(omega) are swapped so that the
    angle convention is consistent with the Cartesian and Polar forms.

    Args:
        aberrations: A dictionary mapping order (n, m) to complex coefficients.
            Format: {(n, m): complex_value}
        kX: Spatial frequency coordinate X (1/Angstrom).
        kY: Spatial frequency coordinate Y (1/Angstrom).
        wavelength: Electron wavelength in Angstroms.

    Returns:
        np.ndarray: The aberration phase surface in radians (real-valued).
    """
    # Angular coordinates in radians
    alpha_x = kX * wavelength
    alpha_y = kY * wavelength
    omega = alpha_x + 1.0j*alpha_y

    chi_complex = np.zeros_like(alpha_x, dtype=complex)
    for (n, m), C_nm in aberrations.items():  # C_nm is complex valued
        s = (n+m+1)//2
        # Exponents of omega and conj(omega) are swapped so the angle
        # conventions stay consistent with the polar / cartesian forms
        chi_complex += (C_nm * np.conj(omega)**s * omega**(n+1-s)) / (n+1)

    # Only the real part carries the physical phase
    return 2 * np.pi / wavelength * chi_complex.real
def make_aberration_surface_krivanek_cartesian(
    aberrations: Dict[Tuple[int, int], Dict[str, float]],
    kX: np.ndarray,
    kY: np.ndarray,
    wavelength: float
) -> np.ndarray:
    r"""Evaluate the aberration phase surface using recursive Cartesian polynomials.

    No trigonometric functions are evaluated: the angular dependence is built
    by repeated multiplication, which the original authors report as faster
    than the polar or complex forms for large arrays.

    Mathematical Derivation:
    ------------------------
    1. Start with the standard Polar form (Kirkland Eq 2.22):
       $\chi(\alpha, \phi) = \frac{2\pi}{\lambda} \frac{1}{n+1} C_{nm} \alpha^{n+1} \cos[m(\phi - \phi_{nm})]$

    2. Expand the cosine term $\cos(m\phi - m\phi_{nm})$ and ignore the
       prefactor and summation for now:
       $\chi \propto \alpha^{n+1} [ \cos(m\phi)\cos(m\phi_{nm}) + \sin(m\phi)\sin(m\phi_{nm}) ]$

    3. Define Cartesian coefficients $C_{nma}, C_{nmb}$ to substitute $C_{nm}$ and $\phi_{nm}$:
       $C_{nma} = C_{nm} \cos(m\phi_{nm})$
       $C_{nmb} = C_{nm} \sin(m\phi_{nm})$

    4. Split the radial term $\alpha^{n+1}$ into $\alpha^{n+1-m} \cdot \alpha^m$
       to isolate angular parts:
       $\chi \propto \alpha^{n+1-m} [ C_{nma} (\alpha^m \cos m\phi) + C_{nmb} (\alpha^m \sin m\phi) ]$

    5. Define Cartesian Angular Polynomials $X_m, Y_m$ using complex variable
       $Z = \alpha_x + i\alpha_y$:
       $Z^m = (\alpha e^{i\phi})^m = \alpha^m (\cos m\phi + i \sin m\phi)$
       Therefore:
       $X_m = \text{Re}(Z^m) = \alpha^m \cos(m\phi)$
       $Y_m = \text{Im}(Z^m) = \alpha^m \sin(m\phi)$

    6. Final Calculation:
       $X_m, Y_m$ are pre-calculated using the recurrence $Z_{m+1} = Z_m \cdot Z$.
       $\chi = \frac{2\pi}{\lambda} \sum \frac{1}{n+1} (\alpha^2)^{\frac{n+1-m}{2}} [ C_{nma}X_m + C_{nmb}Y_m ]$

    Args:
        aberrations: A dictionary mapping order (n, m) to Cartesian coefficients.
            For ``m != 0`` the value is ``{'a': float, 'b': float}`` where
            'a' is Cnma (cosine-like, real part) and 'b' is Cnmb (sine-like,
            imaginary part); a missing key counts as 0. For ``m == 0`` the
            value is used directly as the scalar coefficient.
        kX: Spatial frequency coordinate X (1/Angstrom).
        kY: Spatial frequency coordinate Y (1/Angstrom).
        wavelength: Electron wavelength in Angstroms.

    Returns:
        np.ndarray: The aberration phase surface in radians.
    """
    ax = kX * wavelength
    ay = kY * wavelength
    alpha_squared = ax**2 + ay**2

    # Highest azimuthal order requested decides how far the recurrence runs
    highest_m = max((m for _, m in aberrations), default=0)

    # Angular basis: re_Z[m] + i*im_Z[m] = (ax + i*ay)**m, seeded with Z**0 = 1
    re_Z = {0: np.ones_like(ax)}
    im_Z = {0: np.zeros_like(ax)}
    for order in range(highest_m):
        # Z_{m+1} = Z_m * (x + iy), split into real and imaginary parts
        re_Z[order+1] = re_Z[order] * ax - im_Z[order] * ay
        im_Z[order+1] = re_Z[order] * ay + im_Z[order] * ax

    chi = np.zeros_like(ax)
    for (n, m), val in aberrations.items():
        if m == 0:
            # Rotationally symmetric term: a bare scalar, no sine component
            coeff_a, coeff_b = val, 0.0
        else:
            coeff_a, coeff_b = val.get('a', 0.0), val.get('b', 0.0)

        # Nothing to add for an all-zero coefficient pair
        if coeff_a == 0 and coeff_b == 0:
            continue

        # Radial part: alpha^(n+1-m) written as (alpha^2)^((n+1-m)/2)
        radial = alpha_squared ** ((n + 1 - m) / 2.0)
        # Angular part: C_a * X_m + C_b * Y_m
        angular = coeff_a * re_Z[m] + coeff_b * im_Z[m]

        chi += radial * angular / (n + 1)

    chi *= (2 * np.pi / wavelength)
    return chi
def make_stem_probe(
    kv: float,
    conv_angle: float,
    Npix: int,
    dx: float,
    aberrations: Union[dict, Aberrations],
    method: Literal['polar', 'cartesian', 'complex'] = 'cartesian',
) -> np.ndarray:
    """Simulate a real-space STEM probe with a selectable chi(k) evaluation method.

    The probe is assembled in Fourier space: a circular aperture mask is
    multiplied by the phase factor ``exp(-i*chi)``, then inverse-FFT'd to the
    sample plane and normalized. The original documentation states that the
    three methods agree within numerical precision and that 'cartesian' is the
    default because it is the fastest (all take only a few ms).

    Args:
        kv: Acceleration voltage in kilovolts (kV).
        conv_angle: Convergence semi-angle in milliradians (mrad).
        Npix: Number of pixels for the square simulation grid.
        dx: Real-space pixel size in Angstroms.
        aberrations: An Aberrations instance, or dictionary of aberration
            coefficients (a dictionary is wrapped as ``Aberrations(aberrations)``).
            The dictionary can be in Haider (e.g., {'C1': 10}), or Krivanek
            (e.g., {'C12': 10, 'phi12': 30}) notation in polar / cartesian /
            complex form. Mix-match and aliases like 'defocus', 'Cs' are supported.
        method: The computation approach for chi(k) calculation. Options:
            - 'polar': Standard Krivanek polar form
              (C_nm * alpha^(n+1) * cos[m(phi-phi_nm)]).
            - 'cartesian': Recursive Cartesian polynomials
              (C_nma * X[m] + C_nmb * Y[m]).
            - 'complex': Analytic complex power series
              (C_nm * conj(w)^s * w^(n+1-s)).

    Returns:
        np.ndarray: A 2D complex array representing the probe wave function in
        real space, normalized such that the total intensity sums to 1.

    Raises:
        ValueError: If ``method`` is not 'polar', 'cartesian', or 'complex'.
    """
    # Accept either a ready Aberrations object or a raw coefficient dict
    ab = Aberrations(aberrations) if isinstance(aberrations, dict) else aberrations

    # Derived optical quantities
    wavelength = get_wavelength_ang(kv)  # wavelength in Ang
    k_aperture = conv_angle/1e3/wavelength
    dk = 1/(dx*Npix)

    # Info printing
    logger.info("Start simulating STEM probe")
    logger.info(f' kv = {kv} kV')
    logger.info(f' wavelength = {wavelength:.4f} Ang')
    logger.info(f' conv_angle = {conv_angle} mrad')
    logger.info(f' Npix = {Npix} px')
    logger.info(f' dk = {dk:.4f} Ang^-1')
    logger.info(f' kMax = {(Npix*dk/2):.4f} Ang^-1')
    logger.info(f' alpha_max = {(Npix*dk/2*wavelength*1000):.4f} mrad')
    logger.info(f' dx = {dx:.4f} Ang, Nyquist-limited dmin = 2*dx = {2*dx:.4f} Ang')
    logger.info(f' Rayleigh-limited resolution = {(0.61*wavelength/conv_angle*1e3):.4f} Ang (0.61*lambda/alpha for focused probe )')
    logger.info(f' Real space probe extent = {dx*Npix:.4f} Ang')
    for line in ab.pretty_print().splitlines():
        logger.info(line)

    # Centered k-space grid (Ang^-1) and the probe-forming aperture
    k_axis = fftshift(fftfreq(Npix, dx))
    kX, kY = np.meshgrid(k_axis, k_axis, indexing='xy')
    aperture = np.sqrt(kX**2+kY**2) <= k_aperture

    # Each method pairs an export style with the matching chi(k) routine
    surface_builders = (
        ('polar', make_aberration_surface_krivanek_polar),
        ('cartesian', make_aberration_surface_krivanek_cartesian),
        ('complex', make_aberration_surface_krivanek_complex),
    )
    for style_name, build_surface in surface_builders:
        if method == style_name:
            break
    else:
        raise ValueError(f"Unknown calculation method = {method}, please choose between 'polar', 'cartesian', or 'complex'")
    aberrations_by_order = ab.export(notation='krivanek', style=style_name, layout='nested')

    # chi(k) in radians
    chi = build_surface(aberrations=aberrations_by_order, kX=kX, kY=kY, wavelength=wavelength)

    # Masked wave function at the aperture plane
    aperture_wave = aperture * np.exp(-1j*chi)
    # Propagate from the aperture plane to the sample plane
    probe = fftshift(ifft2(ifftshift(aperture_wave)))
    # Normalize so that sum(abs(probe)^2) = 1
    return probe/np.sqrt(np.sum((np.abs(probe))**2))
def make_fzp_probe(
    beam_kev: float,
    Npix: int,
    dx: float,
    Ls: float,
    Rn: float,
    dRn: float,
    D_FZP: float,
    D_H: float,
) -> np.ndarray:
    """Generate a Fresnel zone plate (FZP) probe for x-ray ptychography simulations.

    Builds the field in the zone-plate plane (thin-lens phase, outer circular
    aperture, central stop) and Fresnel-propagates it over ``fl + Ls``, where
    ``fl = 2*Rn*dRn/lambda`` is the focal length, to obtain the field in the
    sample plane.

    Args:
        beam_kev: Energy of the x-ray photon in keV.
        Npix: Number of pixels.
        dx: Pixel size (in meters) in the sample plane.
        Ls: Distance (in meters) from the focal plane to the sample.
        Rn: Radius of outermost zone (in meters).
        dRn: Width of outermost zone (in meters).
        D_FZP: Diameter of the circular aperture applied in the zone-plate
            plane (the original docs call it the pinhole diameter);
            presumably in meters like the other lengths.
        D_H: Diameter of the central beamstop (in meters).

    Returns:
        np.ndarray: Calculated probe field in the sample plane.
    """
    wavelength = 1.23984193e-9 / beam_kev  # wavelength in m for energy in keV
    focal_length = 2 * Rn * dRn / wavelength  # focal length at this wavelength

    logger.info("Start simulating FZP probe")

    # Pixel size and coordinate grid in the FZP plane
    dx_fzp = wavelength * focal_length / Npix / dx
    axis_fzp = np.linspace(-dx_fzp * Npix / 2, dx_fzp * Npix / 2, Npix)
    x, y = np.meshgrid(axis_fzp, axis_fzp)
    r_sq = x**2 + y**2

    # Input field: lens phase, clipped by the outer aperture and central stop
    lens_phase = np.exp(-1j * 2 * np.pi / wavelength * r_sq / (2 * focal_length))
    outer_aperture = (np.sqrt(r_sq) <= (D_FZP / 2)).astype(np.float64)  # circular function of FZP
    central_stop = (np.sqrt(r_sq) >= (D_H / 2)).astype(np.float64)  # central block
    field_in = outer_aperture * lens_phase * central_stop

    n_side = field_in.shape[0]
    k = 2 * np.pi / wavelength

    # Coordinate grid for the output plane (stored in FFT order via ifftshift)
    z = focal_length + Ls
    extent_out = wavelength * z * (1 / dx_fzp)
    axis_out = ifftshift(np.linspace(-extent_out / 2, extent_out / 2, n_side))
    u, v = np.meshgrid(axis_out, axis_out)
    uv_sq = u**2 + v**2

    if z > 0:
        # Propagation in the positive z direction
        prefactor = np.exp(1j * k * z) * np.exp(1j * k * uv_sq / (2 * z))
        chirped = field_in * np.exp(1j * k * r_sq / (2 * z))
        spectrum = fft2(fftshift(chirped))
        probe = fftshift(spectrum * prefactor)
    else:
        # Propagation in the negative z direction (or backward propagation)
        z = abs(z)
        prefactor = np.exp(1j * k * z) * np.exp(1j * k * r_sq / (2 * z))
        spectrum = ifft2(ifftshift(field_in) / np.exp(1j * k * uv_sq / (2 * z)))
        probe = fftshift(spectrum) / prefactor
    return probe
def make_mixed_probe(probe, pmodes, pmode_init_pows):
    """Make a mixed-state probe from a single-state probe.

    Higher-order modes are produced by ``hermite_like`` and truncated to
    ``pmodes``; each mode is then scaled by the square root of its assigned
    relative power.

    Args:
        probe: (Ny, Nx) complex array.
        pmodes: Number of incoherent probe modes, scalar int.
        pmode_init_pows: Integrated intensity of the modes after the first.
            List of a single value (e.g. [0.02]) or a few values for the
            first few higher modes; the last value is reused for any
            remaining modes. The resulting total must not exceed 1.

    Returns:
        A mixed-state probe with shape (pmodes, Ny, Nx).

    Raises:
        ValueError: If the powers assigned to modes 1..pmodes-1 sum to more than 1.
    """
    logger.info(f"Start making mixed-state STEM probe with {pmodes} incoherent probe modes")

    # Smallest (M+1) x (N+1) Hermite grid that holds at least `pmodes` modes
    max_order_x = np.ceil(pmodes**0.5)-1
    max_order_y = np.ceil(pmodes/(max_order_x+1))-1
    mixed_probe = hermite_like(probe, max_order_x, max_order_y)[:pmodes]

    # Assign powers to modes 1..pmodes-1; mode 0 receives the remainder below
    n_given = np.size(pmode_init_pows)
    pmode_pows = np.zeros(pmodes)
    for mode_idx in range(1, pmodes):
        source_idx = mode_idx - 1 if mode_idx < n_given else -1
        pmode_pows[mode_idx] = pmode_init_pows[source_idx]

    higher_mode_power = sum(pmode_pows)
    if higher_mode_power > 1:
        raise ValueError('Modes total power exceeds 1, check pmode_init_pows')
    pmode_pows[0] = 1 - higher_mode_power

    # Scale amplitudes so each mode's intensity matches its assigned power
    mixed_probe = mixed_probe * np.sqrt(pmode_pows)[:, None, None]
    logger.info(f"Relative power of probe modes = {pmode_pows}")
    return mixed_probe
def hermite_like(fundam, M, N):
    """Generates orthogonal Hermite-like probe modes from a fundamental mode.

    The fundamental mode is multiplied by the monomials
    ``(X - cenx)**m * (Y - ceny)**n`` for ``m = 0..M`` and ``n = 0..N``
    (``m`` varies fastest). Every mode after the first is additionally damped
    by a Gaussian whose widths are the intensity-weighted variances of the
    fundamental, then orthogonalized against all previously generated modes
    and normalized to unit total intensity.

    Note:
        This is a Python implementation ported from
        `ptycho/+core/hermite_like.m` in PtychoShelves, with the following
        modifications:

        * Array indexing is converted from MATLAB (1-based) to Python (0-based).
        * The X and Y spatial meshgrids are generated internally rather than
          passed as arguments.
        * The output tensor `H` has the shape `(pmode, Ny, Nx)` to be
          consistent with PtyRAD conventions.
        * The function always outputs $(M+1)(N+1)$ modes, which may be slightly
          more than a user's target `pmode` count (requiring subsequent
          truncation).

    Args:
        fundam (numpy.ndarray): The base fundamental probe function, typically
            a 2D complex array of shape `(Ny, Nx)`.
        M (int or float): The maximum $x$-order of the Hermite basis. Plain
            Python numbers and NumPy scalars are both accepted; the value is
            truncated to an integer.
        N (int or float): The maximum $y$-order of the Hermite basis. Same
            accepted types as `M`.

    Returns:
        numpy.ndarray: A 3D array of the generated orthonormalized modes with
        shape `((M+1)*(N+1), Ny, Nx)` and the dtype of `fundam`.
    """
    # int() works for Python ints/floats as well as NumPy scalars, whereas
    # `.astype` only exists on NumPy objects.
    M = int(M)
    N = int(N)

    rows, cols = fundam.shape[-2:]
    H = np.zeros(((M+1)*(N+1), rows, cols), dtype=fundam.dtype)

    # Pixel grid centred on the array
    x = np.arange(cols) - cols / 2
    y = np.arange(rows) - rows / 2
    X, Y = np.meshgrid(x, y)

    # Intensity-weighted centroid and variance of the fundamental mode
    intensity = np.abs(fundam)**2
    total_intensity = np.sum(intensity)
    cenx = np.sum(X * intensity) / total_intensity
    ceny = np.sum(Y * intensity) / total_intensity
    dX = X - cenx
    dY = Y - ceny
    varx = np.sum(dX**2 * intensity) / total_intensity
    vary = np.sum(dY**2 * intensity) / total_intensity

    # Gaussian damping applied to every mode except the fundamental
    envelope = np.exp(-(dX**2 / (2*varx)) - (dY**2 / (2*vary)))

    counter = 0
    for nii in range(N+1):
        for mii in range(M+1):
            auxfunc = (dX**mii) * (dY**nii) * fundam
            if counter > 0:
                auxfunc = auxfunc * envelope
                auxfunc = auxfunc / np.sqrt(np.sum(np.abs(auxfunc)**2))

                # Gram-Schmidt against the earlier modes. np.vdot conjugates
                # its first argument, giving the projection <H[ii], auxfunc>.
                for ii in range(counter):
                    auxfunc = auxfunc - np.vdot(H[ii], auxfunc) * H[ii]

            # Normalize each mode so that its intensity sums to 1
            auxfunc = auxfunc / np.sqrt(np.sum(np.abs(auxfunc)**2))
            H[counter] = auxfunc
            counter += 1

    return H
def sort_by_mode_int_np(modes):
    """Reorder modes along the first axis from highest to lowest total intensity.

    A mode's intensity is the sum of its squared amplitude over every axis
    except the first. This is commonly used so that the dominant probe or
    object modes sit at the lowest indices.

    Args:
        modes (numpy.ndarray): An N-dimensional array of modes, where the
            first dimension represents the mode index (e.g., `(pmode, Ny, Nx)`
            for 2D modes or `(omode, Nz, Ny, Nx)` for 3D modes).

    Returns:
        numpy.ndarray: The input array sorted in descending order of total
        intensity along the first dimension.
    """
    non_mode_axes = tuple(range(1, modes.ndim))
    intensity_per_mode = np.sum(np.abs(modes)**2, axis=non_mode_axes)
    # argsort is ascending; flip it to get the strongest mode first
    descending_order = np.flip(np.argsort(intensity_per_mode))
    return modes[descending_order]
def orthogonalize_modes_vec_np(modes, sort=False):
    """Orthogonalize probe modes via Gram matrix eigendecomposition (NumPy version).

    Mirrors the logic of ``orthogonalize_modes_vec`` (PyTorch) but operates on
    NumPy arrays, making it suitable for probe initialization where PyTorch is
    not yet involved.

    The modes are flattened and upcast to complex128, their Gram matrix is
    symmetrized and decomposed with ``np.linalg.eigh`` (the Hermitian
    eigensolver, used instead of the general ``eig`` for numerical stability),
    and the modes are rotated by the conjugate-transposed eigenvectors.

    If ``_validate_ortho_update_np`` rejects the result (poor orthogonality or
    a change in total norm) the original modes are returned unchanged; the
    validator logs a warning in that case.

    Parameters
    ----------
    modes : np.ndarray
        Input modes of shape (Nmode, Ny, Nx), complex.
    sort : bool, optional
        Whether to sort modes by their intensity (descending), by default False.

    Returns
    -------
    np.ndarray
        Orthogonalized modes of the same shape and dtype as input, or the
        original modes unchanged if the validity check fails.
    """
    source_dtype = modes.dtype
    source_shape = modes.shape

    # Flatten to (Nmode, Ny*Nx) and work in complex128 for numerical stability
    flat_modes = modes.reshape(source_shape[0], -1).astype(np.complex128)

    # Gram matrix M @ M^H (Nmode x Nmode), forced to be exactly Hermitian
    gram = flat_modes @ flat_modes.conj().T
    gram = 0.5 * (gram + gram.conj().T)

    # Only the eigenvectors are needed
    eigvecs = np.linalg.eigh(gram)[1]

    # Rotate the modes into the eigenbasis (still double precision)
    rotated = eigvecs.conj().T @ flat_modes

    # Reject the update if orthogonality or total norm was not preserved
    if not _validate_ortho_update_np(flat_modes, rotated):
        return modes  # the original modes, unmodified

    # Restore the caller's dtype and shape
    ortho_modes = rotated.astype(source_dtype).reshape(source_shape)
    return sort_by_mode_int_np(ortho_modes) if sort else ortho_modes
def _validate_ortho_update_np(orig_modes_flat, new_modes_flat, ortho_tol=1e-3, norm_rtol=1e-3): """Defensive check that orthogonalization preserved orthogonality and total norm. Mirrors ``_validate_ortho_update`` (PyTorch) for NumPy arrays. Both inputs are expected to be 2D (Nmode, Y*X) complex128 arrays. Returns True if the update is valid, False if it should be rejected. """ # Test A: Orthogonality leakage O_gram = new_modes_flat @ new_modes_flat.conj().T N = O_gram.shape[0] if N <= 1: return True # single mode is trivially orthogonal off_diag_mask = ~np.eye(N, dtype=bool) max_off_diag = np.abs(O_gram[off_diag_mask]).max() max_diag = np.abs(np.diag(O_gram)).max() # If relative error exceeds the tolerance, reject the update if max_diag > 0 and (max_off_diag / max_diag) > ortho_tol: rel_err = max_off_diag / max_diag logger.warning(f"WARNING: Orthogonality warning, high mode leakage detected (rel error: {rel_err:.2e}). Skipping orthogonalization for this iteration.") return False # Test B: Norm preservation orig_intensity = np.sum(np.abs(orig_modes_flat) ** 2) new_intensity = np.sum(np.abs(new_modes_flat) ** 2) if orig_intensity > 0 and abs(new_intensity - orig_intensity) / orig_intensity > norm_rtol: logger.warning(f"WARNING: Norm-preserving warning, relative total intensity changed more than {norm_rtol} from orthogonalization. Skipping orthogonalization for this iteration.") return False return True