ptyrad.core.functional
PyTorch implementations of core functionals such as image shifts, blurring, and the DCT.
Functions
- Approximated quantile to work around the 2^24-element (roughly 16.7M) limit of torch.quantile.
- Resample a complex 3D object along the depth (z) axis while conserving amplitude product, phase sum, and total thickness.
- Computes a 2D DCT-II (orthonormalized except for constant factors) using FFT.
- A wrapper over torch.fft.fftshift for the last 2 dims.
- Performs a line search along z to find the focal plane using the L4 norm (sharpness).
- Applies a 1D Gaussian blur to a PyTorch tensor along its second dimension (dim 1).
- Finds and returns the center of mass of a real-valued 2D or 3D tensor.
- Generates a 1D Gaussian kernel.
- Computes a 2D inverse DCT-II (IDCT) using FFT.
- A wrapper over torch.fft.ifftshift for the last 2 dims.
- Generates a batch of shifted images from a single input image (..., Ny, Nx) with arbitrary leading dimensions.
- Return normalized (Fy, Fx) frequency grids, corner-centered (DC at [0,0]).
- Return (Ry, Rx) real-space grids in pixel units, [0, N).
- Create a 2D circular mask with a sigmoid transition.
- Creates a 2D super-Gaussian flat-top mask.
- Fresnel propagator.
- Creates a complex tensor with unit magnitude from the given phase.
- ptyrad.core.functional.ifftshift2(x)
A wrapper over torch.fft.ifftshift that applies it only to the last 2 dims.
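A minimal sketch of what last-two-dimension shift wrappers can look like, relying only on the dim argument of torch.fft.fftshift and torch.fft.ifftshift. The names fftshift2_sketch and ifftshift2_sketch are hypothetical stand-ins, not ptyrad's source.

```python
import torch

def fftshift2_sketch(x: torch.Tensor) -> torch.Tensor:
    # Move the zero-frequency element to the center of the last two dims only
    return torch.fft.fftshift(x, dim=(-2, -1))

def ifftshift2_sketch(x: torch.Tensor) -> torch.Tensor:
    # Exact inverse of fftshift2_sketch, including for odd sizes
    return torch.fft.ifftshift(x, dim=(-2, -1))

# Usage: a batch of 2 images with an odd (5) and an even (4) spatial size
x = torch.arange(2 * 5 * 4).reshape(2, 5, 4)
centered = fftshift2_sketch(x)
restored = ifftshift2_sketch(centered)
```

The leading (batch) dimension is left untouched, and ifftshift rather than a second fftshift is needed to undo the shift whenever a spatial size is odd.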
- ptyrad.core.functional.complex_object_z_resample_torch(obj, dz_now, resample_mode, resample_value, output_type='complex', return_np=True)
Resample a complex 3D object along the depth (z) axis while conserving amplitude product, phase sum, and total thickness.
This function performs interpolation along the z-axis of a complex-valued object using PyTorch. The object is decomposed into amplitude and phase, resampled with conservation laws applied, and recombined into the desired output representation.
- Parameters:
obj (ndarray or torch.Tensor) – Input complex object with shape (…, Nz, Ny, Nx). Can be a NumPy array or a torch.Tensor.
dz_now (float) – Current slice thickness along the z-axis.
resample_mode (str) – Resampling mode for the depth axis. Must be one of:
  - “scale_Nlayer”: Scale the number of layers by a float factor.
  - “scale_slice_thickness”: Scale slice thickness by a float factor.
  - “target_Nlayer”: Resample to a target integer number of layers.
  - “target_slice_thickness”: Resample to a target slice thickness.
resample_value (int or float) – Parameter value for the resampling mode:
  - Positive float for “scale_Nlayer” or “scale_slice_thickness”.
  - Positive integer (>=1) for “target_Nlayer”.
  - Positive float for “target_slice_thickness”.
output_type (str, optional) – Output representation. Must be one of:
  - “complex”: Return recombined complex object (default).
  - “amplitude”: Return amplitude only.
  - “phase”: Return phase only.
  - “amp_phase”: Return tuple (amplitude, phase).
return_np (bool, optional) – If True (default), convert outputs to NumPy arrays. If False, return PyTorch tensors.
- Returns:
The resampled object in the requested representation:
  - Complex ndarray/tensor if output_type == “complex”.
  - Real ndarray/tensor if output_type == “amplitude” or “phase”.
  - Tuple of (amplitude, phase) if output_type == “amp_phase”.
Type depends on return_np.
- Return type:
ndarray or torch.Tensor or tuple
- Raises:
ValueError – If resample_mode is invalid.
ValueError – If the target number of layers is less than 1.
ValueError – If the input object has unsupported dimensionality.
ValueError – If output_type is not one of the allowed options.
Examples
Resample by doubling the number of z-layers:
>>> out = complex_object_z_resample_torch(
...     obj, dz_now=0.5, resample_mode="scale_Nlayer",
...     resample_value=2.0, output_type="complex"
... )
>>> out.shape
Resample to a target of 64 layers, keeping total thickness fixed:
>>> out_amp, out_phase = complex_object_z_resample_torch(
...     obj, dz_now=0.5, resample_mode="target_Nlayer",
...     resample_value=64, output_type="amp_phase"
... )
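The three conservation constraints can be met by re-binning running sums along z. Below is a NumPy sketch of one such approach, covering only the "target number of layers" case for a 3D (Nz, Ny, Nx) object with nonzero amplitude; resample_z_conserving is a hypothetical helper, and ptyrad's actual interpolation scheme may differ. Log-amplitude and phase are treated as piecewise-constant per-slice densities, their running sums are linearly interpolated at the new slice boundaries, and differencing gives the new per-slice values. Because the first and last boundaries coincide with the old ones, the amplitude product, phase sum, and total thickness are preserved up to rounding.

```python
import numpy as np

def resample_z_conserving(obj, dz_now, n_layers_new):
    """Hypothetical sketch: re-bin a (Nz, Ny, Nx) complex object onto n_layers_new slices.

    Assumes all amplitudes are > 0 so that the log-amplitude is finite.
    Returns (amplitude, phase, new slice thickness).
    """
    nz = obj.shape[0]
    log_amp = np.log(np.abs(obj))
    phase = np.angle(obj)

    def rebin(q):
        # Running sum at the old slice boundaries 0, 1, ..., nz (in units of old slices)
        cum = np.concatenate([np.zeros_like(q[:1]), np.cumsum(q, axis=0)], axis=0)
        # New boundaries; linspace hits 0 and nz exactly, so the total is preserved
        t = np.linspace(0.0, nz, n_layers_new + 1)
        i0 = np.minimum(np.floor(t).astype(int), nz - 1)
        w = (t - i0)[:, None, None]
        cum_new = (1.0 - w) * cum[i0] + w * cum[i0 + 1]
        # Per-slice values are differences of the interpolated running sum
        return np.diff(cum_new, axis=0)

    amp_new = np.exp(rebin(log_amp))  # product of amplitudes = exp(sum of log-amplitudes)
    phase_new = rebin(phase)  # phase sum is preserved directly
    dz_new = dz_now * nz / n_layers_new  # total thickness nz * dz_now is unchanged
    return amp_new, phase_new, dz_new

# Usage: 8 slices of thickness 0.5 resampled to 12 slices
rng = np.random.default_rng(0)
obj = (0.9 + 0.1 * rng.random((8, 4, 4))) * np.exp(1j * 0.2 * rng.standard_normal((8, 4, 4)))
amp12, phase12, dz12 = resample_z_conserving(obj, dz_now=0.5, n_layers_new=12)
```

The phase is kept as a separate real array here; recombining it into a complex array wraps each slice's phase to (-π, π], which is why "amp_phase"-style outputs are the natural place to check the phase sum.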
- ptyrad.core.functional.approx_torch_quantile(t, q, sample_size=16000000)
Approximated quantile that works around the 2^24-element (roughly 16.7M) input-size limit of torch.quantile, which otherwise raises "RuntimeError: quantile() input tensor is too large" (see pytorch/pytorch#64947). Because the quantile is estimated from randomly selected elements, the result has some randomness.
- Parameters:
t (torch.Tensor) – Input torch tensor
q (float) – Target quantile, in [0, 1]
sample_size (int, optional) – Number of randomly selected elements used to approximate the true quantile. Defaults to 16_000_000.
- Returns:
The approximated quantile value for the input tensor
- Return type:
float
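A minimal sketch of the subsampling idea, assuming elements are drawn uniformly with replacement via torch.randint (the sampling strategy inside ptyrad is not shown here); approx_quantile_sketch is a hypothetical name. Keeping sample_size below 2^24 keeps the final torch.quantile call within its size limit.

```python
import torch

def approx_quantile_sketch(t, q, sample_size=16_000_000):
    """Hypothetical sketch: estimate torch.quantile(t, q) from a random subsample."""
    flat = t.reshape(-1)
    if not flat.is_floating_point():
        flat = flat.float()  # torch.quantile requires a floating-point input
    if flat.numel() > sample_size:
        # Draw sample_size indices uniformly, with replacement
        idx = torch.randint(0, flat.numel(), (sample_size,), device=flat.device)
        flat = flat[idx]
    return torch.quantile(flat, q).item()

# Usage: the median of 0..1,999,999 estimated from 200k samples
torch.manual_seed(0)
median_est = approx_quantile_sketch(
    torch.arange(2_000_000, dtype=torch.float32), 0.5, sample_size=200_000
)
exact_small = approx_quantile_sketch(torch.tensor([1.0, 2.0, 3.0, 4.0]), 0.5)  # no subsampling
```

Small inputs fall through to the exact torch.quantile, so the randomness only appears once the tensor exceeds sample_size.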
- ptyrad.core.functional.get_gaussian1d(size, std, norm=False)
Generates a 1D Gaussian kernel.
- Parameters:
size (int) – The number of points in the output window.
std (float) – The standard deviation (sigma) of the Gaussian distribution.
norm (bool, optional) – If True, normalizes the kernel so that its elements sum to 1. Defaults to False.
- Returns:
The 1D Gaussian kernel.
- Return type:
numpy.ndarray
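A NumPy sketch of such a kernel, assuming it is centered on the middle of the window (the same form as scipy.signal.windows.gaussian(size, std)); gaussian1d_sketch is a hypothetical name.

```python
import numpy as np

def gaussian1d_sketch(size, std, norm=False):
    """Hypothetical sketch of a centered 1D Gaussian window."""
    n = np.arange(size) - (size - 1) / 2.0  # sample positions relative to the window center
    kernel = np.exp(-0.5 * (n / std) ** 2)
    if norm:
        kernel = kernel / kernel.sum()  # elements sum to 1
    return kernel

peak = gaussian1d_sketch(5, std=1.0)  # peak value 1 at the center
unit = gaussian1d_sketch(5, std=1.0, norm=True)
```

With norm=True the kernel leaves constant signals unchanged when used as a smoothing filter, which is usually what a blur needs.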
- ptyrad.core.functional.gaussian_blur_1d(tensor, kernel_size=5, sigma=0.5)
Applies a 1D Gaussian blur to a PyTorch tensor along its second dimension (dim 1).
Designed for 4D object tensors of shape [omode, z, H, W]. The blur is applied along the z-axis (dim 1), treating each spatial position (H, W) and object mode independently. Replicate padding is used along z to properly handle boundaries for both object amplitude and phase, avoiding the edge artifacts caused by standard zero-padding.
Uses F.conv2d with a (kernel_size, 1) kernel on a reshaped [omode, 1, z, H*W] view so that z stays in its natural position without any permutation. conv2d is used instead of conv3d because conv3d silently produces incorrect results on the MPS backend.
- Parameters:
tensor (torch.Tensor) – Input tensor of shape [omode, z, H, W].
kernel_size (int, optional) – Length of the 1D Gaussian kernel. Defaults to 5.
sigma (float, optional) – Standard deviation of the Gaussian kernel in pixels. Defaults to 0.5.
- Returns:
Blurred tensor with the same shape, dtype, and device as input.
- Return type:
torch.Tensor
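A minimal sketch of the reshape, replicate-pad, and conv2d approach described above, assuming a real-valued input and an odd kernel_size; gaussian_blur_z_sketch is a hypothetical name, not ptyrad's implementation.

```python
import torch
import torch.nn.functional as F

def gaussian_blur_z_sketch(tensor, kernel_size=5, sigma=0.5):
    """Hypothetical sketch: blur a real [omode, z, H, W] tensor along z (dim 1)."""
    if kernel_size % 2 != 1:
        raise ValueError("this sketch assumes an odd kernel_size")
    omode, nz, h, w = tensor.shape
    # Normalized 1D Gaussian kernel on the same dtype/device as the input
    n = torch.arange(kernel_size, dtype=tensor.dtype, device=tensor.device) - (kernel_size - 1) / 2
    kernel = torch.exp(-0.5 * (n / sigma) ** 2)
    kernel = kernel / kernel.sum()
    weight = kernel.view(1, 1, kernel_size, 1)  # (out_ch, in_ch, kH along z, kW = 1)
    # [omode, z, H, W] -> [omode, 1, z, H*W]: z stays in place, pixels become the width axis
    x = tensor.reshape(omode, 1, nz, h * w)
    pad = kernel_size // 2
    x = F.pad(x, (0, 0, pad, pad), mode="replicate")  # pad along z only
    out = F.conv2d(x, weight)  # a (kernel_size, 1) kernel never mixes different pixels
    return out.reshape(omode, nz, h, w)

# Usage
const = torch.full((2, 6, 3, 4), 3.0)
delta = torch.zeros(2, 6, 3, 4)
delta[:, 2] = 1.0
blurred_const = gaussian_blur_z_sketch(const)
blurred_delta = gaussian_blur_z_sketch(delta)
```

With replicate padding a constant z-profile comes back unchanged, whereas zero padding would pull the first and last slices toward zero.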
- ptyrad.core.functional.make_sigmoid_mask(Npix, relative_radius=0.6666666666666666, relative_width=0.2, center=None)
Create a 2D circular mask with a sigmoid transition.
- Parameters:
Npix (int) – Size of the square mask (Npix x Npix).
relative_radius (float) – Relative radius of the circular mask where the sigmoid equals 0.5, as a fraction of the image size.
relative_width (float) – Relative width of the sigmoid transition, as a fraction of the image size.
center (Optional[Tuple[float, float]]) – (y, x) coordinates of the center of the circle. Defaults to the center of the image.
- Returns:
A 2D circular mask with a sigmoid transition.
- Return type:
torch.Tensor
Notes
The default relative_radius=2/3 is inspired by its use in abTEM to reduce edge artifacts in diffraction patterns. It sets an antialias cutoff frequency at 2/3 of the simulated kMax. https://abtem.readthedocs.io/en/latest/user_guide/appendix/antialiasing.html
The relative_width controls the steepness of the sigmoid transition. Smaller values result in sharper transitions, while larger values produce smoother transitions.
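A minimal sketch of a sigmoid-edged circular mask. It assumes distances are normalized by Npix / 2, so that relative_radius=2/3 lands at 2/3 of the Nyquist frequency for a centered diffraction pattern, and that the default center is the pixel (Npix // 2, Npix // 2); ptyrad's normalization and default center may differ. sigmoid_mask_sketch is a hypothetical name.

```python
import torch

def sigmoid_mask_sketch(Npix, relative_radius=2 / 3, relative_width=0.2, center=None):
    """Hypothetical sketch of a circular mask with a sigmoid edge."""
    # Assumed default center: the DC pixel of an fftshifted pattern
    cy, cx = center if center is not None else (Npix // 2, Npix // 2)
    coords = torch.arange(Npix, dtype=torch.float32)
    ry, rx = torch.meshgrid(coords, coords, indexing="ij")
    # Assumed normalization: 1.0 is the edge of the inscribed circle
    r = torch.sqrt((ry - cy) ** 2 + (rx - cx) ** 2) / (Npix / 2)
    # sigmoid((R - r) / w) is 0.5 at r == R, tends to 1 inside and to 0 outside
    return torch.sigmoid((relative_radius - r) / relative_width)

soft = sigmoid_mask_sketch(64)  # default width: gradual roll-off
sharp = sigmoid_mask_sketch(64, relative_width=0.02)  # nearly a hard aperture
```

With the default width the mask does not quite reach 1 even at the center, which illustrates how strongly relative_width trades sharpness for smoothness.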
- ptyrad.core.functional.make_super_gaussian_mask(Npix, relative_radius=0.95, order=6, device='cuda')
Creates a 2D super-Gaussian flat-top mask. order=1 gives a standard Gaussian; order=4 to 6 gives a flat top with smooth edges.
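A minimal sketch using the form exp(-((r/R)^2)^order), which reduces to a plain Gaussian for order=1. The exact formula, the normalization of r by Npix / 2, and the center pixel are assumptions; the sketch also defaults to the CPU rather than 'cuda'. super_gaussian_mask_sketch is a hypothetical name.

```python
import torch

def super_gaussian_mask_sketch(Npix, relative_radius=0.95, order=6, device=None):
    """Hypothetical sketch of a super-Gaussian flat-top mask."""
    c = Npix // 2  # assumed center pixel
    coords = torch.arange(Npix, dtype=torch.float32, device=device)
    ry, rx = torch.meshgrid(coords, coords, indexing="ij")
    r = torch.sqrt((ry - c) ** 2 + (rx - c) ** 2) / (Npix / 2)  # assumed normalization
    # order=1: exp(-(r/R)^2), a Gaussian; larger orders flatten the top and steepen the edge
    return torch.exp(-((r / relative_radius) ** 2) ** order)

gauss = super_gaussian_mask_sketch(64, order=1)
flat = super_gaussian_mask_sketch(64, order=6)
```

Both masks equal exp(-1) at r = relative_radius in this form; raising the order only changes how quickly the mask falls off around that radius.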
- ptyrad.core.functional.find_probe_focus_dz(probe, dx, lambd, z_range_ang=(-500, 500), z_steps=101)
Performs a line search along z to find the focal plane, taking the sharpest plane as measured by the L4 norm. Returns the optimal dz in Angstroms; a positive value corresponds to forward propagation.
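A minimal sketch of such a search: propagate the probe to each of z_steps evenly spaced planes in z_range_ang, score each plane by the sum of |psi|^4, and return the dz with the highest score. The paraxial transfer function exp(-iπλ dz |k|^2) and its sign are assumptions, and _fresnel_propagate and find_focus_dz_sketch are hypothetical names.

```python
import math
import torch

def _fresnel_propagate(psi, dx, dz, lambd):
    # Paraxial propagation by dz (Å); assumed convention: H = exp(-i*pi*lambd*dz*|k|^2)
    ny, nx = psi.shape[-2:]
    ky = torch.fft.fftfreq(ny, d=dx, dtype=torch.float64)  # 1/Å
    kx = torch.fft.fftfreq(nx, d=dx, dtype=torch.float64)
    KY, KX = torch.meshgrid(ky, kx, indexing="ij")
    angle = -math.pi * lambd * dz * (KX**2 + KY**2)
    H = torch.polar(torch.ones_like(angle), angle)
    return torch.fft.ifft2(torch.fft.fft2(psi) * H)

def find_focus_dz_sketch(probe, dx, lambd, z_range_ang=(-500, 500), z_steps=101):
    zs = torch.linspace(z_range_ang[0], z_range_ang[1], z_steps, dtype=torch.float64)
    scores = torch.empty(z_steps, dtype=torch.float64)
    for i, dz in enumerate(zs.tolist()):
        psi = _fresnel_propagate(probe, dx, dz, lambd)
        scores[i] = (psi.abs() ** 4).sum()  # L4 sharpness: largest for the most compact probe
    return zs[scores.argmax()].item()

# Usage: defocus a Gaussian spot by +100 Å, then search for the dz that refocuses it
N, dx, lambd = 128, 0.1, 0.0197  # grid size, pixel size (Å), wavelength (Å, about 300 kV)
r = (torch.arange(N, dtype=torch.float64) - N / 2) * dx
Y, X = torch.meshgrid(r, r, indexing="ij")
focused = torch.exp(-(X**2 + Y**2) / (2 * 0.5**2)).to(torch.complex128)  # sigma = 0.5 Å
defocused = _fresnel_propagate(focused, dx, 100.0, lambd)
dz_best = find_focus_dz_sketch(defocused, dx, lambd)  # expected close to -100
```

Because the total intensity is fixed, the L4 score rewards concentrating that intensity into fewer pixels, so it peaks where the probe is most compact.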
- ptyrad.core.functional.dct_2d(x)
Computes a 2D DCT-II (orthonormalized except for constant factors) using FFT.
Supports arbitrary batch dimensions. The DCT is applied over the last two dimensions (H, W).
- Parameters:
x (torch.Tensor) – Real-valued input tensor of shape (…, H, W).
- Returns:
DCT coefficients of shape (…, H, W).
- Return type:
torch.Tensor
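One standard way to compute a DCT-II with a single FFT per axis is Makhoul's even/odd reordering followed by a twiddle factor. The sketch below (dct_2d_sketch is a hypothetical name) returns the unnormalized sum X[k] = Σ x[n] cos(πk(2n+1)/(2N)) along each of the last two dims; ptyrad's constant factors may differ, and scipy.fft.dct with type=2 and norm=None returns twice this value.

```python
import math
import numpy as np
import torch

def _dct_ii_last(x):
    # Unnormalized DCT-II along the last dim via one FFT (Makhoul's reordering)
    n = x.shape[-1]
    # Even samples in order, then odd samples reversed
    v = torch.cat([x[..., ::2], x[..., 1::2].flip(-1)], dim=-1)
    V = torch.fft.fft(v, dim=-1)
    k = torch.arange(n, dtype=x.dtype, device=x.device)
    twiddle = torch.polar(torch.ones_like(k), -math.pi * k / (2 * n))
    return (V * twiddle).real

def dct_2d_sketch(x):
    # Separable: transform the last dim, then the second-to-last dim
    y = _dct_ii_last(x)
    return _dct_ii_last(y.transpose(-1, -2)).transpose(-1, -2)

# Usage: compare with an explicit cosine matrix C[k, n] = cos(pi*k*(2n+1)/(2N))
N = 6
idx = np.arange(N)
C = np.cos(np.pi * idx[:, None] * (2 * idx[None, :] + 1) / (2 * N))
x = torch.rand(3, N, N, dtype=torch.float64)  # arbitrary leading batch dim
X_fft = dct_2d_sketch(x)
X_ref = torch.from_numpy(C @ x.numpy() @ C.T)
```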
- ptyrad.core.functional.idct_2d(x)
Computes a 2D inverse DCT-II (IDCT) using FFT.
The inverse restores a real-valued signal and supports arbitrary batch dimensions.
- Parameters:
x (torch.Tensor) – DCT coefficients of shape (…, H, W).
- Returns:
Reconstructed signal of shape (…, H, W).
- Return type:
torch.Tensor
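The inverse can also use one FFT per axis: rebuild the complex spectrum from X[k] and X[N−k] (with X[N] taken as 0), undo the twiddle factor, apply an inverse FFT, and undo the even/odd reordering. A sketch with the hypothetical name idct_2d_sketch, inverting the unnormalized DCT-II X[k] = Σ x[n] cos(πk(2n+1)/(2N)); the normalization in ptyrad may differ.

```python
import math
import numpy as np
import torch

def _idct_ii_last(X):
    # Inverse of the unnormalized DCT-II along the last dim via one inverse FFT
    n = X.shape[-1]
    k = torch.arange(n, dtype=X.dtype, device=X.device)
    # X[N - k] for k = 0..N-1, with X[N] := 0
    X_rev = torch.cat([torch.zeros_like(X[..., :1]), X[..., 1:].flip(-1)], dim=-1)
    twiddle = torch.polar(torch.ones_like(k), math.pi * k / (2 * n))
    V = torch.complex(X, -X_rev) * twiddle  # FFT of the reordered signal
    v = torch.fft.ifft(V, dim=-1).real
    x = torch.empty_like(X)
    half = (n + 1) // 2
    x[..., ::2] = v[..., :half]  # even samples were stored first
    x[..., 1::2] = v[..., half:].flip(-1)  # odd samples were stored reversed
    return x

def idct_2d_sketch(X):
    x = _idct_ii_last(X)
    return _idct_ii_last(x.transpose(-1, -2)).transpose(-1, -2)

# Usage: invert reference coefficients built from an explicit cosine matrix
N = 7  # odd size on purpose
idx = np.arange(N)
C = np.cos(np.pi * idx[:, None] * (2 * idx[None, :] + 1) / (2 * N))
x = torch.rand(2, N, N, dtype=torch.float64)
X = torch.from_numpy(C @ x.numpy() @ C.T)  # 2D DCT-II coefficients
x_rec = idct_2d_sketch(X)
```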
- ptyrad.core.functional.make_freq_grid_2d(shape, indexing='ij', dtype=None, device=None)
Return normalized (Fy, Fx) frequency grids, corner-centered (DC at [0,0]).
Values are dimensionless fftfreq in [-0.5, 0.5). Multiply by 2π/dx to get rad/Å.
- Parameters:
shape – (Ny, Nx) tuple of ints
indexing – ‘ij’ or ‘xy’, default is ‘ij’
dtype – optional torch dtype for the output
device – torch device
- Returns:
each of shape (Ny, Nx), dimensionless, range [-0.5, 0.5)
- Return type:
(Fy, Fx)
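Such grids can be built from torch.fft.fftfreq and torch.meshgrid. A sketch for the default 'ij' indexing only (freq_grid_2d_sketch is a hypothetical name); with indexing='xy', torch.meshgrid returns (Nx, Ny)-shaped tensors, and how ptyrad handles that case is not shown here.

```python
import math
import torch

def freq_grid_2d_sketch(shape, dtype=None, device=None):
    """Hypothetical sketch: corner-centered fftfreq grids in cycles/pixel, 'ij' indexing."""
    ny, nx = shape
    fy = torch.fft.fftfreq(ny, dtype=dtype, device=device)  # [0, 1/ny, ..., -1/ny]
    fx = torch.fft.fftfreq(nx, dtype=dtype, device=device)
    return torch.meshgrid(fy, fx, indexing="ij")

Fy, Fx = freq_grid_2d_sketch((6, 4))
dx = 0.5  # illustrative pixel size in Å
kx = 2 * math.pi * Fx / dx  # angular spatial frequency in rad/Å
```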
- ptyrad.core.functional.make_real_grid_2d(shape, indexing='ij', dtype=None, device=None)
Return (Ry, Rx) real-space grids in pixel units, [0, N).
Values are px indices in [0, N).
- Parameters:
shape – (Ny, Nx) tuple of ints
indexing – ‘ij’ or ‘xy’, default is ‘ij’
dtype – optional torch dtype for the output
device – torch device
- Returns:
each of shape (Ny, Nx), unit in px, range [0, N)
- Return type:
(Ry, Rx)
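A minimal sketch with torch.arange and torch.meshgrid ('ij' indexing only), together with a typical use: the distance of every pixel from a chosen center. real_grid_2d_sketch is a hypothetical name, and falling back to the default floating-point dtype is an assumption.

```python
import torch

def real_grid_2d_sketch(shape, dtype=None, device=None):
    """Hypothetical sketch: pixel-index grids in [0, N), 'ij' indexing."""
    dtype = dtype if dtype is not None else torch.get_default_dtype()
    ny, nx = shape
    ry = torch.arange(ny, dtype=dtype, device=device)
    rx = torch.arange(nx, dtype=dtype, device=device)
    return torch.meshgrid(ry, rx, indexing="ij")

Ry, Rx = real_grid_2d_sketch((5, 7))
# Distance of every pixel from the pixel (Ny // 2, Nx // 2)
r = torch.sqrt((Ry - 5 // 2) ** 2 + (Rx - 7 // 2) ** 2)
```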
- ptyrad.core.functional.near_field_evolution_torch(Npix_shape, dx, dz, lambd, dtype=torch.complex64, device='cuda')
Fresnel propagator
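A minimal sketch of a Fresnel (paraxial) transfer function in Fourier space, H(k) = exp(-iπλ dz |k|^2), with k from torch.fft.fftfreq in 1/Å and the constant phase exp(ik dz) dropped. The sign convention, the return value, and the name fresnel_transfer_sketch are assumptions; the actual return value of near_field_evolution_torch is not documented here.

```python
import math
import torch

def fresnel_transfer_sketch(shape, dx, dz, lambd, dtype=torch.complex64, device=None):
    """Hypothetical sketch: Fresnel transfer function on a corner-centered k grid."""
    ny, nx = shape
    ky = torch.fft.fftfreq(ny, d=dx, device=device)  # 1/Å
    kx = torch.fft.fftfreq(nx, d=dx, device=device)
    KY, KX = torch.meshgrid(ky, kx, indexing="ij")
    angle = -math.pi * lambd * dz * (KX**2 + KY**2)  # assumed sign convention
    return torch.polar(torch.ones_like(angle), angle).to(dtype)

# Usage: propagate a random wave by +50 Å and back by -50 Å
shape, dx, lambd = (64, 64), 0.2, 0.0197
H_fwd = fresnel_transfer_sketch(shape, dx, 50.0, lambd)
H_bwd = fresnel_transfer_sketch(shape, dx, -50.0, lambd)
psi = torch.randn(shape, dtype=torch.complex64)
psi_z = torch.fft.ifft2(torch.fft.fft2(psi) * H_fwd)
psi_back = torch.fft.ifft2(torch.fft.fft2(psi_z) * H_bwd)
```

Since |H| = 1, propagation only redistributes phase in Fourier space, so total intensity is conserved and propagating by -dz undoes propagating by +dz.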
- ptyrad.core.functional.torch_phasor(phase)
Creates a complex tensor with unit magnitude using the phase.
- Parameters:
phase (torch.Tensor) – Phase angle θ in exp(iθ)
Note
This utility exists so that torch.compile can handle complex tensors properly: torch.exp(1j*phase) involves 1j, a Python built-in complex literal that can’t be traced.
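One way to build the phasor without the 1j literal is to assemble the real and imaginary parts with torch.complex. Whether torch_phasor uses this exact construction is an assumption, and torch.polar(torch.ones_like(phase), phase) is an equivalent alternative; phasor_sketch is a hypothetical name.

```python
import torch

def phasor_sketch(phase: torch.Tensor) -> torch.Tensor:
    # exp(i*phase) = cos(phase) + i*sin(phase), built without the Python literal 1j
    return torch.complex(torch.cos(phase), torch.sin(phase))

phase = torch.linspace(-3.0, 3.0, 7)
p = phasor_sketch(phase)
```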
- ptyrad.core.functional.imshift_batch(img, shifts, grid)
Generates a batch of shifted images from a single input image (…, Ny, Nx) with arbitrary leading dimensions.
This function shifts a complex/real-valued input image by applying phase shifts in the Fourier domain, achieving subpixel shifts in both x and y directions.
- Parameters:
img (torch.Tensor) – The input image to be shifted. img could be either a mixed-state complex probe (pmode, Ny, Nx) complex64 tensor, or a mixed-state pseudo-complex object stack (2,omode,Nz,Ny,Nx) float32 tensor.
shifts (torch.Tensor) – The shifts to apply, as an (Nb, 2) tensor in which each row is (shift_y, shift_x).
grid (torch.Tensor) – The k-space grid used to compute the shifts in the Fourier domain, with shape=(2, Ny, Nx), where Ny and Nx are the height and width of the images, respectively. The grid is normalized, so its values span [-0.5, 0.5).
- Returns:
The batch of shifted images, with one more dimension than the input image, i.e., shape=(Nb, …, Ny, Nx), where Nb is the number of shifts.
- Return type:
shifted_img (torch.Tensor)
Note
The shifts are in units of pixels. For example, a shift of (0.5, 0.5) moves the image by half a pixel in both y and x; positive values shift the image down and to the right.
The function utilizes the fast Fourier transform (FFT) to perform the shifting operation efficiently.
Make sure to convert the input image and shifts tensor to the desired device before passing them to this function.
The fft2 and fftshifts are all applied to the last 2 dimensions, so the image is only shifted along y and x.
tensor[None, …] adds an extra dimension at position 0, so *[None]*ndim unpacks a list of ndim None entries, i.e., [None, None, …].
The img is automatically broadcast to (Nb, *img.shape), so if a batch of images is passed in, each image is shifted independently.
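A minimal sketch of Fourier-domain batch shifting: multiply the image spectrum by the phase ramp exp(-2πi(ky·sy + kx·sx)) for each shift and invert. It assumes a corner-centered fftfreq grid (DC at [0, 0]) so that no fftshift is needed; the note above mentions fftshifts, so ptyrad's grid convention may differ. imshift_batch_sketch is a hypothetical name.

```python
import math
import torch

def imshift_batch_sketch(img, shifts, grid):
    """Hypothetical sketch: shift img (..., Ny, Nx) by each row of shifts (Nb, 2)."""
    ky, kx = grid  # each (Ny, Nx), cycles/pixel, assumed corner-centered (DC at [0, 0])
    ones = [1] * img.ndim
    sy = shifts[:, 0].reshape(-1, *ones)  # (Nb, 1, ..., 1) so it broadcasts over img dims
    sx = shifts[:, 1].reshape(-1, *ones)
    angle = -2 * math.pi * (ky * sy + kx * sx)  # (Nb, 1, ..., Ny, Nx)
    ramp = torch.polar(torch.ones_like(angle), angle)
    shifted = torch.fft.ifft2(torch.fft.fft2(img)[None] * ramp)  # (Nb, ..., Ny, Nx)
    return shifted if torch.is_complex(img) else shifted.real

# Usage
Ny, Nx = 8, 8
fy = torch.fft.fftfreq(Ny, dtype=torch.float64)
fx = torch.fft.fftfreq(Nx, dtype=torch.float64)
grid = torch.stack(torch.meshgrid(fy, fx, indexing="ij"))  # (2, Ny, Nx)
img = torch.rand(3, Ny, Nx, dtype=torch.float64)  # e.g. 3 modes of a real-valued image
shifts = torch.tensor([[1.0, 2.0], [0.5, -3.0]], dtype=torch.float64)
out = imshift_batch_sketch(img, shifts, grid)  # (2, 3, Ny, Nx)
```

For integer shifts the result matches a circular torch.roll; subpixel shifts interpolate with the same periodic boundary, and the total intensity (DC term) is unchanged.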