Source code for ptyrad.utils.math_ops

"""
Math operations of FFTs, masks, affine transformation matrices, etc.

"""

from typing import Optional, Tuple

import numpy as np
import torch


[docs] def torch_phasor(phase): """ Creates a complex tensor with unit magnitude using the phase. Args: phase (torch.Tensor): phase angle for the exp(i*theta) Note: This util function is created so torch.compile can properly handle complex tensors, because torch.exp(1j*phase) involves the 1j which is actually a Python built-in that can't be traced. """ return torch.polar(torch.ones_like(phase), phase)
[docs] def exponential_decay(r, a, b): return a * np.exp(-b * r)
[docs] def power_law(r, a, b): return a * r**-b
[docs] def fftshift2(x): """ A wrapper over torch.fft.fftshift for the last 2 dims """ # Note that fftshift and ifftshift are only equivalent when N = even return torch.fft.fftshift(x, dim=(-2,-1))
[docs] def ifftshift2(x): """ A wrapper over torch.fft.ifftshift for the last 2 dims""" # Note that fftshift and ifftshift are only equivalent when N = even return torch.fft.ifftshift(x, dim=(-2,-1))
[docs] def mfft2(im): # Periodic Artifact Reduction in Fourier Transforms of Full Field Atomic Resolution Images # https://doi.org/10.1017/S1431927614014639 rows, cols = im.shape # Compute boundary conditions s = np.zeros_like(im) s[0, :] = im[0, :] - im[rows-1, :] s[rows-1, :] = -s[0, :] s[:, 0] += im[:, 0] - im[:, cols-1] s[:, cols-1] -= im[:, 0] - im[:, cols-1] # Create grid for computing Poisson solution cx, cy = np.meshgrid(2 * np.pi * np.arange(cols) / cols, 2 * np.pi * np.arange(rows) / rows) # Generate smooth component from Poisson Eq with boundary condition D = 2 * (2 - np.cos(cx) - np.cos(cy)) D[0, 0] = np.inf # Enforce zero mean & handle division by zero S = np.fft.fft2(s) / D P = np.fft.fft2(im) - S # FFT of periodic component return P, S
[docs] def make_sigmoid_mask(Npix: int, relative_radius: float = 2/3, relative_width: float = 0.2, center: Optional[Tuple[float, float]] = None): """ Create a 2D circular mask with a sigmoid transition. Args: Npix (int): Size of the square mask (Npix x Npix). relative_radius (float): Relative radius of the circular mask where the sigmoid equals 0.5, as a fraction of the image size. relative_width (float): Relative width of the sigmoid transition, as a fraction of the image size. center (Optional[Tuple[float, float]]): (y, x) coordinates of the center of the circle. Defaults to the center of the image. Returns: torch.Tensor: A 2D circular mask with a sigmoid transition. Notes: - The default `relative_radius=2/3` is inspired by its use in abTEM to reduce edge artifacts in diffraction patterns. It sets an antialias cutoff frequency at 2/3 of the simulated kMax. https://abtem.readthedocs.io/en/latest/user_guide/appendix/antialiasing.html - The `relative_width` controls the steepness of the sigmoid transition. Smaller values result in sharper transitions, while larger values produce smoother transitions. """ def scaled_sigmoid(x, offset=0, scale=1): # If scale = 1, y drops from 1 to 0 between (-0.5,0.5), or effectively 1 px # If scale = 10, it takes roughly 10 px for y to drop from 1 to 0 return 1 / (1 + torch.exp((x - offset) / scale * 10)) # Set default center if not provided if center is None: center = (Npix // 2, Npix // 2) # Use integer division for consistency # Create a grid of coordinates ky = torch.arange(Npix, dtype=torch.float32) kx = torch.arange(Npix, dtype=torch.float32) grid_ky, grid_kx = torch.meshgrid(ky, kx, indexing='ij') # Compute the distance from the specified center kR = torch.sqrt((grid_ky - center[0])**2 + (grid_kx - center[1])**2) # Apply the scaled sigmoid function sigmoid_mask = scaled_sigmoid(kR, offset=Npix * relative_radius / 2, scale=relative_width * Npix) return sigmoid_mask
[docs] def make_gaussian_mask(Npix: int, radius: float, std: float, center: Optional[Tuple[float, float]] = None): """ Create a 2D Gaussian-blurred circular mask. Args: Npix (int): Size of the square mask (Npix x Npix). radius (float): Radius of the circular mask. std (float): Standard deviation of the Gaussian blur. center (tuple): (y, x) coordinates of the center of the circle. Returns: np.ndarray: A 2D Gaussian-blurred circular mask. """ from scipy.ndimage import gaussian_filter # Set default center if not provided if center is None: center = (Npix / 2, Npix / 2) # Create a grid of coordinates y = np.linspace(0, Npix - 1, Npix) x = np.linspace(0, Npix - 1, Npix) grid_y, grid_x = np.meshgrid(y, x, indexing='ij') # Compute the distance from the center dist_from_center = np.sqrt((grid_y - center[0])**2 + (grid_x - center[1])**2) # Create a binary circular mask circular_mask = (dist_from_center <= radius).astype(float) # Apply Gaussian blur to the circular mask gaussian_mask = gaussian_filter(circular_mask, sigma=std) return gaussian_mask
[docs] def approx_torch_quantile(t, q, sample_size=16_000_000): """ Approximated quantile to prevent the 2^24 element (roughly 16.7M) limitation of torch.quantile as of now. See https://github.com/pytorch/pytorch/issues/64947 `RuntimeError: quantile() input tensor is too large` Note that this approximated quantile would have some randomness. Args: t (torch.Tensor): Input torch tensor q (float): Targeted quantile number [0,1] sample_size (int, optional): Number of randomly selected elements used to approximate the true quantile. Defaults to 16_000_000. Returns: float: The approximated quantile value for the input tensor """ # flatten flat = t.view(-1) # random subsample if necessary if flat.numel() > sample_size: idx = torch.randint(0, flat.numel(), (sample_size,), device=flat.device) flat = flat[idx] return torch.quantile(flat, q)
# Affine
[docs] def compose_affine_matrix(scale, asymmetry, rotation, shear): # Adapted from PtychoShelves +math/compose_affine_matrix.m # The input rotation and shear is in unit of degree rotation_rad = np.radians(rotation) shear_rad = np.radians(shear) A1 = np.array([[scale, 0], [0, scale]]) A2 = np.array([[1 + asymmetry/2, 0], [0, 1 - asymmetry/2]]) A3 = np.array([[np.cos(rotation_rad), np.sin(rotation_rad)], [-np.sin(rotation_rad), np.cos(rotation_rad)]]) A4 = np.array([[1, 0], [np.tan(shear_rad), 1]]) affine_mat = A1 @ A2 @ A3 @ A4 return affine_mat
[docs] def decompose_affine_matrix(input_affine_mat): from scipy.optimize import least_squares def err_fun(x): scale, asymmetry, rotation, shear = x fit_affine_mat = compose_affine_matrix(scale, asymmetry, rotation, shear) return (input_affine_mat - fit_affine_mat).ravel() # Initial guess initial_guess = np.array([1, 0, 0, 0]) result = least_squares(err_fun, initial_guess) scale, asymmetry, rotation, shear = result.x return scale, asymmetry, rotation, shear
[docs] def get_decomposed_affine_matrix_from_bases(input, output): """ Fit the affine matrix components from input and output matrices A and B """ # This util function is used to quickly estimate the needed affine transformation for scan positions # If we know the lattice constant and angle between lattice vectors, then we can easily correct the scale, asymmetry, and shear # The global rotation of the object is NOT defined by lattice constant/angle so we still need to compare with the actual CBED # Typical usage of this function is to first construct A by measuring the lattice vectors of a reconstructed object suffers from affine transformation # Then estimate ideal lattice vectors with prior knowledge (lattice constant and angle) # Lastly we use this function to estimate the needed F such that B = F @ A from scipy.optimize import minimize def objective(params, A, B): scale, asymmetry, rotation, shear = params F = compose_affine_matrix(scale, asymmetry, rotation, shear) return np.linalg.norm(B - F @ A) initial_guess = [1, 0, 0, 0] # Initial guess for scale, asymmetry, rotation, shear result = minimize(objective, initial_guess, args=(input, output), method='L-BFGS-B') if result.success: (scale, asymmetry, rotation, shear) = result.x return (scale, asymmetry, rotation, shear) else: raise ValueError("Optimization failed")