ptyrad.core.functional

PyTorch implementation of core functionals like image shifts, blurring, DCT, etc.

Functions

approx_torch_quantile(t, q[, sample_size])

Approximate quantile that works around the current 2^24-element (roughly 16.7M) input limit of torch.quantile.

complex_object_z_resample_torch(obj, dz_now, ...)

Resample a complex 3D object along the depth (z) axis while conserving amplitude product, phase sum, and total thickness.

dct_2d(x)

Computes a 2D DCT-II (orthonormalized except for constant factors) using FFT.

fftshift2(x)

A wrapper over torch.fft.fftshift for the last 2 dims

find_probe_focus_dz(probe, dx, lambd[, ...])

Performs a line search along z to find the focal plane using the L4 norm (sharpness).

gaussian_blur_1d(tensor[, kernel_size, sigma])

Applies a 1D Gaussian blur to a PyTorch tensor along its second dimension (dim 1).

get_center_of_mass(image[, corner_centered])

Finds and returns the center of mass of a real-valued 2D or 3D tensor.

get_gaussian1d(size, std[, norm])

Generates a 1D Gaussian kernel.

idct_2d(x)

Computes a 2D inverse DCT-II (IDCT) using FFT.

ifftshift2(x)

A wrapper over torch.fft.ifftshift for the last 2 dims

imshift_batch(img, shifts, grid)

Generates a batch of shifted images from a single input image (..., Ny, Nx) with arbitrary leading dimensions.

make_freq_grid_2d(shape[, indexing, dtype, ...])

Return normalized (Fy, Fx) frequency grids, corner-centered (DC at [0,0]).

make_real_grid_2d(shape[, indexing, dtype, ...])

Return normalized (Ry, Rx) real space grids, [0,N).

make_sigmoid_mask(Npix[, relative_radius, ...])

Create a 2D circular mask with a sigmoid transition.

make_super_gaussian_mask(Npix[, ...])

Creates a 2D Super-Gaussian flat-top mask.

near_field_evolution_torch(Npix_shape, dx, ...)

Fresnel propagator

torch_phasor(phase)

Creates a complex tensor with unit magnitude using the phase.

ptyrad.core.functional.fftshift2(x)

A wrapper over torch.fft.fftshift for the last 2 dims

ptyrad.core.functional.ifftshift2(x)

A wrapper over torch.fft.ifftshift for the last 2 dims
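
As a minimal sketch of what such wrappers might look like (the function names below are hypothetical and this is not necessarily the library's implementation), restricting the shift to the last two dimensions leaves any leading batch or mode dimensions untouched:

```python
import torch

def fftshift2_sketch(x):
    # Shift only the last two dims (Ny, Nx); leading dims are left as-is.
    return torch.fft.fftshift(x, dim=(-2, -1))

def ifftshift2_sketch(x):
    # Inverse of fftshift2_sketch, also restricted to the last two dims.
    return torch.fft.ifftshift(x, dim=(-2, -1))

x = torch.arange(2 * 4 * 5, dtype=torch.float32).reshape(2, 4, 5)
shifted = fftshift2_sketch(x)
restored = ifftshift2_sketch(shifted)
```

Calling torch.fft.fftshift without dim would also shift the leading dimension, which is usually not wanted for a stack of images.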

ptyrad.core.functional.complex_object_z_resample_torch(obj, dz_now, resample_mode, resample_value, output_type='complex', return_np=True)

Resample a complex 3D object along the depth (z) axis while conserving amplitude product, phase sum, and total thickness.

This function performs interpolation along the z-axis of a complex-valued object using PyTorch. The object is decomposed into amplitude and phase, resampled with conservation laws applied, and recombined into the desired output representation.

Parameters:
  • obj (ndarray or torch.Tensor) – Input complex object with shape (…, Nz, Ny, Nx). Can be a NumPy array or a torch.Tensor.

  • dz_now (float) – Current slice thickness along the z-axis.

  • resample_mode (str) – Resampling mode for the depth axis. Must be one of:
      - “scale_Nlayer”: Scale the number of layers by a float factor.
      - “scale_slice_thickness”: Scale slice thickness by a float factor.
      - “target_Nlayer”: Resample to a target integer number of layers.
      - “target_slice_thickness”: Resample to a target slice thickness.

  • resample_value (int or float) – Parameter value for the resampling mode:
      - Positive float for “scale_Nlayer” or “scale_slice_thickness”.
      - Positive integer (>=1) for “target_Nlayer”.
      - Positive float for “target_slice_thickness”.

  • output_type (str, optional) – Output representation. Must be one of:
      - “complex”: Return recombined complex object (default).
      - “amplitude”: Return amplitude only.
      - “phase”: Return phase only.
      - “amp_phase”: Return tuple (amplitude, phase).

  • return_np (bool, optional) – If True (default), convert outputs to NumPy arrays. If False, return PyTorch tensors.

Returns:

The resampled object in the requested representation:
  - Complex ndarray/tensor if output_type == “complex”.
  - Real ndarray/tensor if output_type == “amplitude” or “phase”.
  - Tuple of (amplitude, phase) if output_type == “amp_phase”.

Whether arrays or tensors are returned depends on return_np.

Return type:

ndarray or torch.Tensor or tuple

Raises:
  • ValueError – If resample_mode is invalid.

  • ValueError – If the target number of layers is less than 1.

  • ValueError – If the input object has unsupported dimensionality.

  • ValueError – If output_type is not one of the allowed options.

Examples

Resample by doubling the number of z-layers:

>>> out = complex_object_z_resample_torch(
...     obj, dz_now=0.5, resample_mode="scale_Nlayer",
...     resample_value=2.0, output_type="complex"
... )
>>> out.shape

Resample to a target of 64 layers, keeping total thickness fixed:

>>> out_amp, out_phase = complex_object_z_resample_torch(
...     obj, dz_now=0.5, resample_mode="target_Nlayer",
...     resample_value=64, output_type="amp_phase"
... )
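
The four resampling modes all reduce to a new layer count and slice thickness whose product equals the original total thickness. The helper below is a hypothetical sketch of that bookkeeping only (it does not touch the object data). Its name and its use of round() for non-integer layer counts are assumptions, not the library's documented behavior:

```python
def resolve_z_resampling_sketch(nz_now, dz_now, resample_mode, resample_value):
    """Return (nz_new, dz_new) such that nz_new * dz_new == nz_now * dz_now."""
    total_thickness = nz_now * dz_now
    if resample_mode == "scale_Nlayer":
        nz_new = round(nz_now * resample_value)  # rounding is an assumption
    elif resample_mode == "scale_slice_thickness":
        nz_new = round(total_thickness / (dz_now * resample_value))
    elif resample_mode == "target_Nlayer":
        nz_new = int(resample_value)
    elif resample_mode == "target_slice_thickness":
        nz_new = round(total_thickness / resample_value)
    else:
        raise ValueError(f"Unknown resample_mode: {resample_mode!r}")
    if nz_new < 1:
        raise ValueError("Target number of layers must be >= 1")
    # Total thickness is conserved, so the new thickness follows from nz_new.
    return nz_new, total_thickness / nz_new

doubled = resolve_z_resampling_sketch(32, 0.5, "scale_Nlayer", 2.0)
thicker = resolve_z_resampling_sketch(32, 0.5, "target_slice_thickness", 1.0)
```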

ptyrad.core.functional.approx_torch_quantile(t, q, sample_size=16000000)

Approximate quantile that works around the current 2^24-element (roughly 16.7M) input limit of torch.quantile, which otherwise fails with "RuntimeError: quantile() input tensor is too large" (see pytorch/pytorch#64947). Because the estimate is computed from randomly selected elements, the result has some randomness.

Parameters:
  • t (torch.Tensor) – Input torch tensor

  • q (float) – Target quantile, in the range [0, 1].

  • sample_size (int, optional) – Number of randomly selected elements used to approximate the true quantile. Defaults to 16_000_000.

Returns:

The approximated quantile value for the input tensor

Return type:

float
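
One way to approximate a quantile on an oversized tensor is to evaluate torch.quantile on a random subset. The sketch below illustrates that idea. The function name is hypothetical, and sampling with replacement via torch.randint is an assumption rather than the library's confirmed sampling scheme:

```python
import torch

def approx_quantile_sketch(t, q, sample_size=16_000_000):
    flat = t.reshape(-1)
    if flat.numel() > sample_size:
        # Draw random indices (with replacement) so the input to
        # torch.quantile stays below its element limit.
        idx = torch.randint(0, flat.numel(), (sample_size,), device=flat.device)
        flat = flat[idx]
    return torch.quantile(flat, q).item()

median = approx_quantile_sketch(torch.arange(101, dtype=torch.float32), 0.5)
```

When the tensor is smaller than sample_size, the sketch falls back to the exact quantile, so randomness only appears for large inputs.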

ptyrad.core.functional.get_gaussian1d(size, std, norm=False)

Generates a 1D Gaussian kernel.

Parameters:
  • size (int) – The number of points in the output window.

  • std (float) – The standard deviation (sigma) of the Gaussian distribution.

  • norm (bool, optional) – If True, normalizes the kernel so that its elements sum to 1. Defaults to False.

Returns:

The 1D Gaussian kernel.

Return type:

numpy.ndarray

ptyrad.core.functional.gaussian_blur_1d(tensor, kernel_size=5, sigma=0.5)

Applies a 1D Gaussian blur to a PyTorch tensor along its second dimension (dim 1).

Designed for 4D object tensors of shape [omode, z, H, W]. The blur is applied along the z-axis (dim 1), treating each spatial position (H, W) and object mode independently. Replicate padding is used along z to properly handle boundaries for both object amplitude and phase, avoiding the edge artifacts caused by standard zero-padding.

Uses F.conv2d with a (kernel_size, 1) kernel on a reshaped [omode, 1, z, H*W] view so that z stays in its natural position without any permutation. conv2d is used instead of conv3d because conv3d silently produces incorrect results on the MPS backend.

Parameters:
  • tensor (torch.Tensor) – Input tensor of shape [omode, z, H, W].

  • kernel_size (int, optional) – Length of the 1D Gaussian kernel. Defaults to 5.

  • sigma (float, optional) – Standard deviation of the Gaussian kernel in pixels. Defaults to 0.5.

Returns:

Blurred tensor with the same shape, dtype, and device as input.

Return type:

torch.Tensor
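
The reshape-then-conv2d approach described above can be sketched as follows. This is an illustration under stated assumptions, not the library's code: the function name is hypothetical, kernel_size is assumed to be odd, and the Gaussian kernel is built inline and normalized to sum to 1.

```python
import torch
import torch.nn.functional as F

def gaussian_blur_z_sketch(tensor, kernel_size=5, sigma=0.5):
    omode, nz, h, w = tensor.shape
    # 1D Gaussian kernel centered on the middle tap, normalized to sum to 1.
    x = torch.arange(kernel_size, dtype=tensor.dtype, device=tensor.device)
    x = x - (kernel_size - 1) / 2
    kernel = torch.exp(-0.5 * (x / sigma) ** 2)
    kernel = kernel / kernel.sum()
    # View as [omode, 1, z, H*W] so z is the "height" axis seen by conv2d.
    view = tensor.reshape(omode, 1, nz, h * w)
    pad = kernel_size // 2
    # Replicate-pad along z only: pad order is (left, right, top, bottom).
    padded = F.pad(view, (0, 0, pad, pad), mode="replicate")
    out = F.conv2d(padded, kernel.view(1, 1, kernel_size, 1))
    return out.reshape(omode, nz, h, w)

obj = torch.rand(2, 6, 3, 4)
blurred = gaussian_blur_z_sketch(obj)
```

With replicate padding, a tensor that is constant along z is returned unchanged, whereas zero-padding would darken the first and last slices.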

ptyrad.core.functional.make_sigmoid_mask(Npix, relative_radius=0.6666666666666666, relative_width=0.2, center=None)

Create a 2D circular mask with a sigmoid transition.

Parameters:
  • Npix (int) – Size of the square mask (Npix x Npix).

  • relative_radius (float) – Relative radius of the circular mask where the sigmoid equals 0.5, as a fraction of the image size.

  • relative_width (float) – Relative width of the sigmoid transition, as a fraction of the image size.

  • center (Optional[Tuple[float, float]]) – (y, x) coordinates of the center of the circle. Defaults to the center of the image.

Returns:

A 2D circular mask with a sigmoid transition.

Return type:

torch.Tensor

Notes

  • The default relative_radius=2/3 is inspired by its use in abTEM to reduce edge artifacts in diffraction patterns. It sets an antialias cutoff frequency at 2/3 of the simulated kMax. https://abtem.readthedocs.io/en/latest/user_guide/appendix/antialiasing.html

  • The relative_width controls the steepness of the sigmoid transition. Smaller values result in sharper transitions, while larger values produce smoother transitions.

ptyrad.core.functional.make_super_gaussian_mask(Npix, relative_radius=0.95, order=6, device='cuda')

Creates a 2D Super-Gaussian flat-top mask. order=1 is a standard Gaussian. order=4 to 6 gives a nice flat top with smooth edges.

ptyrad.core.functional.find_probe_focus_dz(probe, dx, lambd, z_range_ang=(-500, 500), z_steps=101)

Performs a line search along z to find the focal plane using the L4 norm (sharpness). Returns the optimal dz in Angstroms; a positive value corresponds to forward propagation.
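
To illustrate the idea of a sharpness-based line search, the sketch below propagates a probe to each candidate dz and keeps the one that maximizes the sum of |psi|^4. All names are hypothetical. The Fresnel kernel exp(-i*pi*lambd*dz*k^2) and its sign convention are assumptions for the sake of the example and may differ from the library's propagator:

```python
import math
import torch

def fresnel_kernel_sketch(shape, dx, dz, lambd):
    ny, nx = shape
    # Spatial frequencies in 1/Angstrom, corner-centered (DC at [0, 0]).
    ky = torch.fft.fftfreq(ny, d=dx, dtype=torch.float64)
    kx = torch.fft.fftfreq(nx, d=dx, dtype=torch.float64)
    k2 = ky[:, None] ** 2 + kx[None, :] ** 2
    # Unit-magnitude paraxial transfer function (sign convention assumed).
    return torch.polar(torch.ones_like(k2), -math.pi * lambd * dz * k2)

def propagate_sketch(psi, dx, dz, lambd):
    H = fresnel_kernel_sketch(psi.shape[-2:], dx, dz, lambd)
    return torch.fft.ifft2(torch.fft.fft2(psi) * H)

def find_focus_sketch(probe, dx, lambd, z_range=(-500.0, 500.0), z_steps=101):
    dzs = torch.linspace(z_range[0], z_range[1], z_steps, dtype=torch.float64)
    # Sharpness metric: sum of |psi|^4, largest when the probe is most compact.
    sharpness = torch.stack([
        propagate_sketch(probe, dx, float(dz), lambd).abs().pow(4).sum()
        for dz in dzs
    ])
    return float(dzs[sharpness.argmax()])

# Build a focused Gaussian probe, defocus it by +100 Angstrom, then search.
n, dx, lambd = 64, 0.25, 0.02
r = (torch.arange(n, dtype=torch.float64) - n // 2) * dx
yy, xx = torch.meshgrid(r, r, indexing="ij")
focused = torch.exp(-(yy ** 2 + xx ** 2) / 2.0).to(torch.complex128)
defocused = propagate_sketch(focused, dx, 100.0, lambd)
best_dz = find_focus_sketch(defocused, dx, lambd)
```

In this toy setup the search should recover a dz close to -100 Angstrom, i.e. the propagation that undoes the applied defocus. The resolution of the answer is limited by the step size of the z grid.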

ptyrad.core.functional.dct_2d(x)

Computes a 2D DCT-II (orthonormalized except for constant factors) using FFT.

Supports arbitrary batch dimensions. The DCT is applied over the last two dimensions (H, W).

Parameters:

x (torch.Tensor) – Real-valued input tensor of shape (…, H, W).

Returns:

DCT coefficients of shape (…, H, W).

Return type:

torch.Tensor
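
For readers unfamiliar with computing a DCT through an FFT, the following sketch shows one common construction: mirror the signal, take an FFT, and apply a half-sample phase twiddle. It computes the unnormalized DCT-II, C_k = sum_n x_n cos(pi (n + 1/2) k / N). The function names are hypothetical, and the scaling used by dct_2d itself may differ by constant factors:

```python
import math
import torch

def dct_1d_sketch(x):
    """Unnormalized DCT-II along the last dim, via an FFT of length 2N."""
    n = x.shape[-1]
    # Even-symmetric extension: [x_0 .. x_{N-1}, x_{N-1} .. x_0].
    y = torch.cat([x, x.flip(-1)], dim=-1)
    Y = torch.fft.fft(y, dim=-1)[..., :n]
    k = torch.arange(n, dtype=x.dtype, device=x.device)
    # Half-sample phase shift turns the FFT of the mirrored signal into 2*C_k.
    twiddle = torch.polar(torch.ones_like(k), -math.pi * k / (2 * n))
    return 0.5 * (Y * twiddle).real

def dct_2d_sketch(x):
    x = dct_1d_sketch(x)                                      # along W
    return dct_1d_sketch(x.transpose(-1, -2)).transpose(-1, -2)  # along H

coeffs = dct_2d_sketch(torch.randn(3, 4, 6, dtype=torch.float64))
```

Because the transform is separable, the 2D version is just the 1D transform applied to each of the last two dimensions in turn, with any leading batch dimensions carried along.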

ptyrad.core.functional.idct_2d(x)

Computes a 2D inverse DCT-II (IDCT) using FFT.

The inverse restores a real-valued signal and supports arbitrary batch dimensions.

Parameters:

x (torch.Tensor) – DCT coefficients of shape (…, H, W).

Returns:

Reconstructed signal of shape (…, H, W).

Return type:

torch.Tensor

ptyrad.core.functional.make_freq_grid_2d(shape, indexing='ij', dtype=None, device=None)

Return normalized (Fy, Fx) frequency grids, corner-centered (DC at [0,0]).

Values are dimensionless fftfreq in [-0.5, 0.5). Multiply by 2π/dx to get rad/Å.

Parameters:
  • shape – (Ny, Nx) tuple of ints

  • indexing – ‘ij’ or ‘xy’, default is ‘ij’

  • dtype – optional torch dtype for the output

  • device – torch device

Returns:

Fy and Fx, each of shape (Ny, Nx), dimensionless, with values in [-0.5, 0.5).

Return type:

(Fy, Fx)
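
A corner-centered, normalized frequency grid of this kind can be sketched with torch.fft.fftfreq and torch.meshgrid. The function name below is hypothetical, and only the default 'ij' indexing is covered:

```python
import torch

def freq_grid_sketch(shape, dtype=None, device=None):
    ny, nx = shape
    # fftfreq with the default spacing d=1 yields cycles per pixel in
    # [-0.5, 0.5), ordered with the DC term first.
    fy = torch.fft.fftfreq(ny, dtype=dtype, device=device)
    fx = torch.fft.fftfreq(nx, dtype=dtype, device=device)
    Fy, Fx = torch.meshgrid(fy, fx, indexing="ij")
    return Fy, Fx

Fy, Fx = freq_grid_sketch((4, 6))
```

Here Fy varies along the first axis and Fx along the second, and both are zero at index [0, 0].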

ptyrad.core.functional.make_real_grid_2d(shape, indexing='ij', dtype=None, device=None)

Return normalized (Ry, Rx) real space grids, [0,N).

Values are px indices in [0, N).

Parameters:
  • shape – (Ny, Nx) tuple of ints

  • indexing – ‘ij’ or ‘xy’, default is ‘ij’

  • dtype – optional torch dtype for the output

  • device – torch device

Returns:

Ry and Rx, each of shape (Ny, Nx), in units of pixels, with values in [0, N).

Return type:

(Ry, Rx)

ptyrad.core.functional.near_field_evolution_torch(Npix_shape, dx, dz, lambd, dtype=torch.complex64, device='cuda')

Fresnel propagator

ptyrad.core.functional.torch_phasor(phase)

Creates a complex tensor with unit magnitude using the phase.

Parameters:

phase (torch.Tensor) – Phase angle theta, used as exp(i*theta).

Note

This utility function exists so that torch.compile can handle complex tensors properly: torch.exp(1j*phase) involves 1j, a Python built-in complex literal that can’t be traced.
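
One way to build a unit-magnitude complex tensor without the Python literal 1j is torch.polar, which takes a magnitude and an angle as real tensors. The sketch below uses that approach. The function name is hypothetical, and the library may instead combine cosine and sine explicitly:

```python
import torch

def phasor_sketch(phase):
    # Magnitude 1 everywhere, angle given by `phase`; no Python complex scalar.
    return torch.polar(torch.ones_like(phase), phase)

phase = torch.linspace(-3.0, 3.0, 7)
out = phasor_sketch(phase)
```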

ptyrad.core.functional.imshift_batch(img, shifts, grid)

Generates a batch of shifted images from a single input image (…, Ny, Nx) with arbitrary leading dimensions.

This function shifts a complex/real-valued input image by applying phase shifts in the Fourier domain, achieving subpixel shifts in both x and y directions.

Parameters:
  • img (torch.Tensor) – The input image to be shifted. img can be either a mixed-state complex probe, given as a (pmode, Ny, Nx) complex64 tensor, or a mixed-state pseudo-complex object stack, given as a (2, omode, Nz, Ny, Nx) float32 tensor.

  • shifts (torch.Tensor) – The shifts to be applied to the image. It should be a (Nb, 2) tensor in which each row is (shift_y, shift_x).

  • grid (torch.Tensor) – The k-space grid used for computing the shifts in the Fourier domain. It should be a tensor with shape=(2, Ny, Nx), where Ny and Nx are the height and width of the images, respectively. The grid is normalized so that its values span [-0.5, 0.5).

Returns:

The batch of shifted images. It has one more dimension than the input image, i.e., shape=(Nb, …, Ny, Nx), where Nb is the number of samples in the input batch.

Return type:

torch.Tensor

Note

  • The shifts are in units of pixels. For example, a shift of (0.5, 0.5) shifts the image by half a pixel in both the y and x directions; positive values shift down and to the right.

  • The function uses the fast Fourier transform (FFT) to perform the shifting operation efficiently.

  • Make sure to move the input image and shifts tensor to the desired device before passing them to this function.

  • The fft2 and fftshift operations are applied to the last 2 dimensions only, so the image is shifted only along the y and x directions.

  • Indexing with tensor[None, …] adds an extra dimension at position 0, so *[None]*ndim unpacks a list of ndim None values, i.e. [None, None, …], to add ndim new dimensions.

  • The img is automatically broadcast to (Nb, *img.shape), so if a batch of images is passed in, each image is shifted independently.
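
The Fourier-domain shifting described in the notes above can be sketched as follows for a complex input. The function name is hypothetical, the sketch always returns a complex tensor, and it does not reproduce the library's handling of pseudo-complex float32 object stacks:

```python
import math
import torch

def imshift_batch_sketch(img, shifts, grid):
    """Shift img (..., Ny, Nx) by every (shift_y, shift_x) row of shifts."""
    # Reshape (Nb,) shift components to (Nb, 1, ..., 1) so they broadcast
    # against the (Ny, Nx) grids and all leading dims of img.
    lead = [1] * img.ndim
    sy = shifts[:, 0].reshape(-1, *lead)
    sx = shifts[:, 1].reshape(-1, *lead)
    # Fourier shift theorem: shifting by s multiplies the spectrum by
    # exp(-2*pi*i*f*s), with f in cycles per pixel.
    phase = -2.0 * math.pi * (sy * grid[0] + sx * grid[1])
    ramp = torch.polar(torch.ones_like(phase), phase)
    # fft2/ifft2 act on the last two dims; the result is (Nb, *img.shape).
    return torch.fft.ifft2(torch.fft.fft2(img) * ramp)

ny, nx = 8, 8
fy = torch.fft.fftfreq(ny, dtype=torch.float64)
fx = torch.fft.fftfreq(nx, dtype=torch.float64)
grid = torch.stack(torch.meshgrid(fy, fx, indexing="ij"))  # (2, Ny, Nx)
probe = torch.randn(2, ny, nx, dtype=torch.complex128)     # (pmode, Ny, Nx)
shifts = torch.tensor([[1.0, 0.0], [0.0, 2.0]], dtype=torch.float64)
shifted = imshift_batch_sketch(probe, shifts, grid)
```

For integer shifts the result matches a circular roll of the image, which makes a convenient sanity check. Non-integer shifts give the subpixel interpolation that a roll cannot.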

ptyrad.core.functional.get_center_of_mass(image, corner_centered=False)

Finds and returns the center of mass of a real-valued 2D or 3D tensor.
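
As a rough illustration of an intensity-weighted center of mass over the last two dimensions, consider the sketch below. The function name is hypothetical, coordinates are plain pixel indices, and the corner_centered option is not modeled because its exact behavior is not described here:

```python
import torch

def center_of_mass_sketch(image):
    ny, nx = image.shape[-2:]
    y = torch.arange(ny, dtype=image.dtype, device=image.device)
    x = torch.arange(nx, dtype=image.dtype, device=image.device)
    total = image.sum(dim=(-2, -1))
    # Project onto each axis, then take the intensity-weighted mean index.
    com_y = (image.sum(dim=-1) * y).sum(dim=-1) / total
    com_x = (image.sum(dim=-2) * x).sum(dim=-1) / total
    return com_y, com_x

img = torch.zeros(8, 8)
img[3, 5] = 2.0
com_y, com_x = center_of_mass_sketch(img)
```

For a 3D input of shape (N, Ny, Nx), the same code returns one (y, x) pair per leading index.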