Source code for ptyrad.utils.image_proc

"""
Image processing tools for CoM, fitting, normalization, shifting, etc.

"""

import numpy as np
import torch
from scipy.optimize import minimize
from torch.fft import fft2, fftfreq, ifft2

from .common import vprint
from .math_ops import fftshift2, ifftshift2, make_gaussian_mask, torch_phasor


# Some quick estimation analysis tools
def get_center_of_mass(image, corner_centered=False):
    """Find the center of mass of a real-valued 2D or 3D tensor.

    Args:
        image (torch.Tensor): Tensor of shape (Ny, Nx) or (N, Ny, Nx).
        corner_centered (bool): If True, pixel coordinates are taken from
            ``fftfreq(n, 1 / n)`` (0, 1, ..., -2, -1), which suits
            corner-centered probes or CBEDs. Otherwise coordinates are the
            plain pixel indices ``0 .. n-1``.

    Returns:
        tuple: ``(center_y, center_x)``. Scalar tensors for 2D input and
        tensors of shape (N,) for 3D input.

    Note:
        For an even-sized uniform image (e.g. 128 x 128 of ones) the center
        falls between pixels, i.e. (63.5, 63.5).
        The ``corner_centered`` idea is adapted from py4DSTEM:
        https://github.com/py4dstem/py4DSTEM/blob/dev/py4DSTEM/process/utils/utils.py
    """
    ndim = image.ndim
    assert ndim in [2, 3], f"image.ndim must be either 2 or 3, we've got {ndim}"

    # Build the coordinate grids on the same device as the image
    device = image.device
    ny, nx = image.shape[-2:]
    if corner_centered:
        axis_y = fftfreq(ny, 1 / ny, device=device)
        axis_x = fftfreq(nx, 1 / nx, device=device)
    else:
        axis_y = torch.arange(ny, device=device)
        axis_x = torch.arange(nx, device=device)
    grid_y, grid_x = torch.meshgrid(axis_y, axis_x, indexing='ij')

    # Normalize every image by its OWN total intensity. Averaging the totals
    # over the batch would scale the CoM of each image by (its sum / mean sum),
    # which is only correct when all images carry the same intensity.
    total_intensity = torch.sum(image, dim=(-2, -1))

    # Intensity-weighted mean of the y and x coordinates
    center_y = torch.sum(grid_y * image, dim=(-2, -1)) / total_intensity
    center_x = torch.sum(grid_x * image, dim=(-2, -1)) / total_intensity

    return center_y, center_x
def get_blob_size(dx, blob, output='d90', plot_profile=False, verbose=True):
    """Measure the size of a probe / blob from its radial profile.

    The blob is assumed to be directly measurable (no squaring needed),
    centered at ``len(blob) // 2`` along both axes, and background free.

    Args:
        dx (float): Pixel size in Ang.
        blob (np.ndarray): 2D probe / blob image.
        output (str): Which quantity to return. One of
            ``'d50'``, ``'d90'``, ``'d99'``, ``'d995'``, ``'d999'``
            (diameter enclosing that fraction of the summed intensity, in Ang),
            ``'radius_rms'`` (RMS radius in Ang), ``'FWHM'`` (in Ang),
            ``'radial_profile'`` (radially averaged profile),
            ``'radial_sum'`` (radial profile without ring-area normalization),
            or ``'fig'`` (the profile figure, requires ``plot_profile=True``).
        plot_profile (bool): Whether to plot the radially averaged profile.
        verbose (bool): Whether to report scalar outputs through ``vprint``.

    Returns:
        The single quantity selected by ``output``.

    Raises:
        ValueError: If ``output`` is not one of the supported names, or if
            ``output='fig'`` is requested while ``plot_profile`` is False.
    """

    def get_radial_profile(data, center):
        # The radial intensity is calculated up to the corners, so the
        # profile is longer than half of the image side.
        # The bin width equals the original pixel spacing (dr = dx).
        y, x = np.indices((data.shape))
        r = np.sqrt((x - center[0])**2 + (y - center[1])**2)
        r = r.astype(int)

        tbin = np.bincount(r.ravel(), data.ravel())
        nr = np.bincount(r.ravel())
        return tbin / nr, tbin

    center_px = len(blob) // 2
    radial_profile, radial_sum = get_radial_profile(blob, (center_px, center_px))

    # RMS radius, in px
    x = np.arange(len(radial_profile))
    radius_rms = np.sqrt(np.sum(x**2*radial_profile*x)/np.sum(radial_profile*x))

    # Half width at half maximum: outermost bin still at >= 50% of the peak
    HWHM = np.max(np.where((radial_profile / radial_profile.max()) >= 0.5))

    # Radii enclosing a given fraction of the (unnormalized) radial sum
    cum_sum = np.cumsum(radial_sum)
    total_sum = np.sum(radial_sum)

    def enclosing_radius(fraction):
        # First radial bin at which the cumulative sum reaches the fraction
        return np.min(np.where(cum_sum >= fraction*total_sum)[0])

    R50 = enclosing_radius(0.50)
    R90 = enclosing_radius(0.90)
    R99 = enclosing_radius(0.99)
    R995 = enclosing_radius(0.995)
    R999 = enclosing_radius(0.999)

    # Diameters in px; the +1 accounts for the central pixel
    scalar_outputs_px = {
        'd50': 2*R50 + 1,
        'd90': 2*R90 + 1,
        'd99': 2*R99 + 1,
        'd995': 2*R995 + 1,
        'd999': 2*R999 + 1,
        'radius_rms': radius_rms,
        'FWHM': 2*HWHM + 1,
    }

    fig = None
    if plot_profile:
        # Imported lazily so that matplotlib is only needed when plotting
        import matplotlib.pyplot as plt

        num_ticks = 11
        r_ang = dx*np.arange(len(radial_profile))
        fig = plt.figure()
        ax = fig.add_subplot(111)
        plt.title("Radially averaged profile")
        plt.margins(x=0, y=0)
        ax.plot(r_ang, radial_profile/np.max(radial_profile), label='Radially averaged profile')
        # Vertical lines are drawn at data coordinates, i.e. in Ang
        plt.vlines(x=R50*dx, ymin=0, ymax=1, color="tab:orange", linestyle=":", label='R50')
        plt.vlines(x=R90*dx, ymin=0, ymax=1, color="tab:red", linestyle=":", label='R90')
        plt.vlines(x=HWHM*dx, ymin=0, ymax=1, color="tab:blue", linestyle=":", label='FWHM')
        plt.vlines(x=radius_rms*dx, ymin=0, ymax=1, color="tab:green", linestyle=":", label='Radius_RMS')
        plt.xticks(np.arange(num_ticks)*np.round(len(radial_profile)*dx/num_ticks, decimals = 1-int(np.floor(np.log10(len(radial_profile)*dx)))))
        ax.set_xlabel(r"Distance from blob center ($\AA$)")
        ax.set_ylabel("Normalized intensity")
        plt.legend()
        plt.show()

    if output in scalar_outputs_px:
        out = scalar_outputs_px[output]*dx
        if verbose:
            vprint(f'{output} = {out/dx:.3f} px or {out:.3f} Ang')
        return out
    if output == 'radial_profile':
        return radial_profile
    if output == 'radial_sum':
        return radial_sum
    if output == 'fig':
        if fig is None:
            # Previously this path failed with an unbound local variable
            raise ValueError("output ='fig' requires plot_profile=True")
        return fig
    raise ValueError(f"output ={output} not implemented!")
def guess_radius_of_bright_field_disk(image: np.ndarray, thresh: float=0.5):
    """Estimate the bright-field disk radius (rbf, in px) of a CBED pattern.

    Args:
        image (np.ndarray): 2D pattern indexed as (ky, kx).
        thresh (float): Fraction of the maximum used as the cut-off,
            e.g. 0.5 for FWHM or 0.1 for full width at 10th maximum.

    Returns:
        Radius of the circle whose area equals the number of pixels above
        ``thresh * max(image)``.
    """
    cutoff = np.max(image) * thresh
    above_cutoff = image > cutoff
    n_bright = np.count_nonzero(above_cutoff)
    # Treat the bright region as a disk: area = pi * r^2
    return np.sqrt(n_bright / np.pi)
# Use in initial estimation of CBED geometry (center, radius, and edge blur)
def fit_cbed_pattern(image: np.ndarray, initial_guess=None, verbose=False):
    """Fit the center, radius and edge blur of a CBED pattern.

    The image is normalized to a maximum of 1 and compared, by mean squared
    error, against the synthetic disk returned by ``make_gaussian_mask``.
    The four parameters are optimized with bounded L-BFGS-B.

    Args:
        image (np.ndarray): Square 2D image to fit.
        initial_guess (dict, optional): May contain ``"center"`` as (y, x),
            ``"radius"`` and ``"std"``. Missing keys fall back to the image
            center, ``Npix / 4`` and 0.5. When omitted entirely, the center
            is seeded from the center of mass and the radius from
            ``guess_radius_of_bright_field_disk``.
        verbose (bool): Whether to print progress information.

    Returns:
        dict: Keys ``"center"`` as (y0, x0), ``"radius"``, ``"std"``,
        ``"success"`` and ``"fun"`` (final loss value).
    """
    Npix = image.shape[0]
    image = image / image.max()  # Normalized to max = 1, like the model mask
    assert image.shape[0] == image.shape[1], "Only square images supported for now."

    def loss(params):
        # Parameter order is (y, x, ...) to match center=(y, x) of the mask
        cy, cx, radius, blur = params
        model = make_gaussian_mask(Npix, radius=radius, std=blur, center=(cy, cx))
        return np.mean((image - model) ** 2)  # Mean squared error

    if initial_guess is not None:
        # Caller-provided starting point
        y0_guess, x0_guess = initial_guess.get("center", (Npix / 2, Npix / 2))
        r_guess = initial_guess.get("radius", Npix / 4)
        std_guess = initial_guess.get("std", 0.5)
    else:
        # Seed the center with the center of mass of the image
        y_indices, x_indices = np.indices(image.shape)
        total_mass = np.sum(image)
        if total_mass > 0:
            y0_guess = np.sum(y_indices * image) / total_mass
            x0_guess = np.sum(x_indices * image) / total_mass
        else:
            y0_guess, x0_guess = Npix / 2, Npix / 2
        r_guess = guess_radius_of_bright_field_disk(image)
        std_guess = 0.5  # Start with a modest Gaussian blur

    p0 = [y0_guess, x0_guess, r_guess, std_guess]
    vprint(f"Initial guess: center=({y0_guess:.2f}, {x0_guess:.2f}), radius={r_guess:.2f}, Gaussian blur std={std_guess:.2f}", verbose=verbose)

    # Shared optimizer settings: (y0, x0, radius, std) bounds and options
    fit_kwargs = dict(
        bounds=[(0, Npix-1), (0, Npix-1), (1, Npix/2), (0, 5)],
        method='L-BFGS-B',
        options={'maxiter': 1000, 'disp': verbose},
    )

    result = minimize(loss, p0, **fit_kwargs)
    counts = 1

    # Fall back to a grid of shifted starting centers if the first fit is poor
    if not result.success or result.fun > 0.01:
        vprint("First optimization attempt didn't converge well, trying different starting points", verbose=verbose)

        best_result = result
        shift_range = np.linspace(-Npix/10, Npix/10, 10)
        shift_pairs = ((dy, dx) for dy in shift_range for dx in shift_range)
        for shift_y, shift_x in shift_pairs:
            counts += 1
            new_p0 = [y0_guess + shift_y, x0_guess + shift_x, r_guess, std_guess]
            new_result = minimize(loss, new_p0, **fit_kwargs)
            if new_result.fun < best_result.fun:
                best_result = new_result
                if verbose:
                    vprint(f"Found better solution with starting point at ({new_p0[0]:.2f}, {new_p0[1]:.2f})")

        vprint(f"Total fitting trials with different initial guesses = {counts}", verbose=verbose)
        result = best_result

    y0, x0, r, std = result.x
    vprint(f"Final fit: center=({y0:.2f}, {x0:.2f}), radius={r:.2f}, Gaussian blur std={std:.2f}", verbose=verbose)

    return {
        "center": (y0, x0),
        "radius": r,
        "std": std,
        "success": result.success,
        "fun": result.fun
    }
def get_local_obj_tilts(pos, objp, dx, slice_thickness, slice_indices, blob_params, window_size=9):
    """Estimate local object tilts from relative atomic column shifts.

    Blobs (atomic columns) are detected on the top slice with
    ``skimage.feature.blob_log``. The center of mass of each column is then
    measured in both the top and bottom slice, and the lateral shift between
    them is converted to a tilt angle. The sparse tilts are cubic-interpolated
    with ``griddata``, and the region outside their convex hull is filled by a
    fitted surface ``a1*x + b1*y + c1*x*y + d``. Several diagnostic figures
    are shown along the way.

    Args:
        pos: Probe positions at integer pixel sites, shape (N, 2) as (y, x).
            Used directly as array indices.
        objp: Object phase stack, shape (Nz, Ny, Nx).
        dx (float): Pixel size in Ang.
        slice_thickness (float): Slice thickness in Ang.
        slice_indices: ``(slice_t, slice_b)`` indices of the top and bottom
            slices to compare.
        blob_params (dict): Keyword arguments forwarded to ``blob_log``.
        window_size (int): Side length in px of the crop around each blob
            used for the center-of-mass measurement.

    Returns:
        np.ndarray: Array of shape (N, 2) holding (tilt_y, tilt_x) in mrad at
        each probe position.

    Raises:
        ValueError: If no blob is detected on the top slice.
    """
    import matplotlib.pyplot as plt
    from scipy.interpolate import griddata
    from scipy.ndimage import center_of_mass
    from scipy.optimize import curve_fit
    from skimage.feature import blob_log

    # Choose the 2 slices from objp and detect blobs from the top slice
    slice_t, slice_b = slice_indices
    height = (slice_b - slice_t)*slice_thickness
    print(f"The height difference between slices {(slice_t, slice_b)} is {height:.2f} Ang")
    target_stack = objp[[slice_t,slice_b]]
    blobs = blob_log(target_stack[0], **blob_params)
    if len(blobs) == 0:
        # Everything below needs at least one column to measure
        raise ValueError("No blobs were detected on the top slice, please adjust `blob_params`.")
    print(f"Found {len(blobs)} blobs with mean radius of {1.414*blobs.mean(0)[-1]:.2f} px or {dx*1.414*blobs.mean(0)[-1]:.2f} Ang")

    # Plot the detected blobs
    fig, ax = plt.subplots(figsize=(18,16))
    ax.imshow(target_stack[0])
    for blob in blobs:
        y, x, r = blob
        c = plt.Circle((x, y), r, linewidth=2, fill=False)
        ax.add_patch(c)
    plt.show()

    # Crop windows around every blob. Signed integers plus clipping are used
    # so that a blob closer than half a window to the top/left edge gets a
    # truncated window instead of an index that wraps around as unsigned.
    half_window = window_size//2
    blob_rows = blobs[:,0].astype(int)
    blob_cols = blobs[:,1].astype(int)
    row_start = np.clip(blob_rows - half_window, 0, None)
    row_end = blob_rows + half_window + 1
    col_start = np.clip(blob_cols - half_window, 0, None)
    col_end = blob_cols + half_window + 1

    # Get the CoM of each atomic column for both top and bottom slices
    coord_t = np.zeros((len(blobs),2))
    coord_b = np.zeros((len(blobs),2))
    for i in range(len(blobs)):
        window = (slice(row_start[i], row_end[i]), slice(col_start[i], col_end[i]))
        crop_img_t = target_stack[0][window]
        crop_img_b = target_stack[1][window]
        # CoM is relative to the crop, so shift by the actual crop origin
        origin = np.array([row_start[i], col_start[i]])
        coord_t[i] = np.array(center_of_mass(crop_img_t)) + origin
        coord_b[i] = np.array(center_of_mass(crop_img_b)) + origin
    shift_vecs = coord_b - coord_t # This is the needed tilt to correct the obj tilt so it's pointing from top to bottom

    # Plot the detected CoM (of the last blob)
    fig, axs = plt.subplots(1,2, figsize=(8,4))
    im0 = axs[0].imshow(crop_img_t)
    im1 = axs[1].imshow(crop_img_b)
    axs[0].set_title(f"crop_img_t \n {coord_t[-1].round(2)}")
    axs[1].set_title(f"crop_img_b \n {coord_b[-1].round(2)}")
    fig.colorbar(im0, shrink=0.7)
    fig.colorbar(im1, shrink=0.7)
    plt.show()

    # Plot the tilt vectors
    X = coord_t[:,1]
    Y = coord_t[:,0]
    U = shift_vecs[:,1]
    V = shift_vecs[:,0]
    M = np.arctan(np.hypot(U,V)*dx/height)*1e3
    fig, ax = plt.subplots(figsize=(16,12))
    plt.title("Needed local object tilts", fontsize=16)
    ax.imshow(target_stack[0], cmap='gray')
    q = ax.quiver(X, Y, U, V, M, pivot='mid', angles='xy', scale_units='xy')
    cbar = fig.colorbar(q, shrink=0.75)
    cbar.ax.set_ylabel('mrad')
    plt.show()

    # Interpolate tilt_y, tilt_x map
    tilt_y = np.arctan(V*dx/height)*1e3
    tilt_x = np.arctan(U*dx/height)*1e3

    # np.mgrid returns the row (y) index first and the column (x) index second
    grid_y, grid_x = np.mgrid[0:target_stack.shape[-2]:1, 0:target_stack.shape[-1]:1]
    tilt_y_interp = griddata(np.stack([Y,X], -1), tilt_y, (grid_y, grid_x), method='cubic')
    tilt_x_interp = griddata(np.stack([Y,X], -1), tilt_x, (grid_y, grid_x), method='cubic')

    fig, axs = plt.subplots(1,2, figsize=(12,6))
    im0=axs[0].imshow(tilt_y_interp)
    im1=axs[1].imshow(tilt_x_interp)
    axs[0].set_title("tilt_y_interp")
    axs[1].set_title("tilt_x_interp")
    cbar0 = fig.colorbar(im0, shrink=0.7)
    cbar0.ax.set_ylabel('mrad')
    cbar1 = fig.colorbar(im1, shrink=0.7)
    cbar1.ax.set_ylabel('mrad')
    plt.show()

    # Use curve_fit to extrapolate to the entire FOV
    def surface_fn(t, a1, b1, c1, d):
        y,x = t
        return a1*x + b1*y + c1*x*y + d

    xdata = np.vstack((Y,X))
    popt_tilt_y, _ = curve_fit(surface_fn, xdata, tilt_y)
    popt_tilt_x, _ = curve_fit(surface_fn, xdata, tilt_x)

    # Implanting griddata interpolated values into the fitted background.
    # The surface must be evaluated with the same (y, x) ordering that was
    # used for the fit above, i.e. (row grid, column grid).
    full_grid = np.stack((grid_y, grid_x))
    surface_tilt_y = surface_fn(full_grid, *popt_tilt_y)
    surface_tilt_x = surface_fn(full_grid, *popt_tilt_x)
    mask_tilt_y = ~np.isnan(tilt_y_interp)
    surface_tilt_y[mask_tilt_y] = tilt_y_interp[mask_tilt_y]
    mask_tilt_x = ~np.isnan(tilt_x_interp)
    surface_tilt_x[mask_tilt_x] = tilt_x_interp[mask_tilt_x]

    fig, axs = plt.subplots(1,2, figsize=(12,6))
    im0=axs[0].imshow(surface_tilt_y)
    im1=axs[1].imshow(surface_tilt_x)
    axs[0].set_title("surface_tilt_y")
    axs[1].set_title("surface_tilt_x")
    cbar0 = fig.colorbar(im0, shrink=0.7)
    cbar0.ax.set_ylabel('mrad')
    cbar1 = fig.colorbar(im1, shrink=0.7)
    cbar1.ax.set_ylabel('mrad')
    plt.show()

    # Sample the surface with our probe position
    tilt_ys = surface_tilt_y[pos[:,0], pos[:,1]]
    tilt_xs = surface_tilt_x[pos[:,0], pos[:,1]]
    obj_tilts = np.stack([tilt_ys, tilt_xs], axis=-1)

    fig, axs = plt.subplots(1,2, figsize=(12,4))
    im0=axs[0].scatter(x=pos[:,1], y=pos[:,0], c=obj_tilts[:,0])
    im1=axs[1].scatter(x=pos[:,1], y=pos[:,0], c=obj_tilts[:,1])
    axs[0].invert_yaxis()
    axs[1].invert_yaxis()
    axs[0].set_title("tilt_ys")
    axs[1].set_title("tilt_xs")
    cbar0 = fig.colorbar(im0, shrink=0.7)
    cbar0.ax.set_ylabel('mrad')
    cbar1 = fig.colorbar(im1, shrink=0.7)
    cbar1.ax.set_ylabel('mrad')
    plt.show()

    return obj_tilts
# This is used across the paper figure notebook but not really in the package
def center_crop(image, crop_height, crop_width, offset = (0,0)):
    """Center crop the last two axes of a 2D or 3D array.

    Args:
        image (numpy.ndarray): Input array, either 2D (H, W) or 3D (N, H, W).
            The crop is always applied to the last two axes.
        crop_height (int): Height of the crop.
        crop_width (int): Width of the crop.
        offset (tuple): ``(offset_y, offset_x)`` shift in px of the crop
            window away from the centered position. The shifted window is
            not bounds-checked.

    Returns:
        numpy.ndarray: The cropped array.

    Raises:
        ValueError: If the input is not 2D/3D, or the crop is larger than
            the image.
    """
    if len(image.shape) not in [2, 3]:
        raise ValueError("Input image must be a 2D or 3D array.")

    height, width = image.shape[-2:]

    if crop_height > height or crop_width > width:
        raise ValueError("Crop size must be smaller than the input image size.")

    offset_y, offset_x = offset[0], offset[1]
    start_y = (height - crop_height) // 2 + offset_y
    # The x start must use the x offset (it previously reused the y offset)
    start_x = (width - crop_width) // 2 + offset_x

    return image[..., start_y:start_y + crop_height, start_x:start_x + crop_width]
# These are called during save.save_results()
def normalize_from_zero_to_one(arr):
    """Linearly rescale an array so its values span [0, 1].

    Args:
        arr: Array-like object providing ``min()`` and ``max()`` methods and
            elementwise arithmetic (e.g. a NumPy array).

    Returns:
        ``(arr - min) / (max - min)``. For a constant array, where the range
        is zero, an all-zero floating array of the same shape is returned
        instead of NaN.
    """
    lowest = arr.min()
    value_range = arr.max() - lowest
    if value_range == 0:
        # Constant input: avoid 0/0 -> NaN, which would turn into garbage
        # once cast to an integer bit depth. Multiplying by 0.0 keeps the
        # result floating point, like the regular branch.
        return (arr - lowest) * 0.0
    return (arr - lowest) / value_range
def normalize_by_bit_depth(arr, bit_depth):
    """Convert an array to the dtype implied by a bit-depth string.

    Args:
        arr: Input array.
        bit_depth (str): ``'8'`` gives uint8 scaled to [0, 255], ``'16'``
            gives uint16 scaled to [0, 65535], ``'32'`` gives float32 scaled
            to [0, 1], and ``'raw'`` gives float32 without rescaling. Any
            other value prints a notice and is treated like ``'raw'``.

    Returns:
        The converted array.
    """
    if bit_depth == '8':
        return np.uint8(255*normalize_from_zero_to_one(arr))
    if bit_depth == '16':
        return np.uint16(65535*normalize_from_zero_to_one(arr))
    if bit_depth == '32':
        return np.float32(normalize_from_zero_to_one(arr))
    if bit_depth != 'raw':
        # Unknown option: tell the user, then fall through to the raw output
        print(f'Unsuported bit_depth :{bit_depth} was passed into `result_modes`, `raw` is used instead')
    return np.float32(arr)
# These are called inside constraints.py / CombinedConstraint > apply_obj_zblur
def get_gaussian1d(size, std, norm=False):
    """Return a 1D Gaussian window from ``scipy.signal.windows.gaussian``.

    Args:
        size (int): Number of points in the window.
        std (float): Standard deviation of the Gaussian, in samples.
        norm (bool): If True, rescale the window so that it sums to 1.

    Returns:
        np.ndarray: The window of length ``size``.
    """
    from scipy.signal.windows import gaussian as gaussian1d

    kernel = gaussian1d(size, std)
    if not norm:
        return kernel
    kernel /= kernel.sum()
    return kernel
def gaussian_blur_1d(tensor, kernel_size=5, sigma=0.5):
    """Blur a tensor along its last axis with a normalized Gaussian kernel.

    Args:
        tensor (torch.Tensor): Input tensor; every 1D line along the last
            axis is blurred independently.
        kernel_size (int): Number of taps in the Gaussian kernel.
        sigma (float): Standard deviation of the kernel, in samples.

    Returns:
        torch.Tensor: Blurred tensor with the same shape as the input.
    """
    # A Conv1d module is used instead of F.conv1d because the functional form
    # has no `padding_mode` and would zero-pad, which is not ideal for obja.
    weights = torch.from_numpy(get_gaussian1d(kernel_size, sigma, norm=True))
    weights = weights.type(tensor.dtype).to(tensor.device)

    blur = torch.nn.Conv1d(1, 1, kernel_size, padding='same', bias=False, padding_mode='replicate')
    blur.weight = torch.nn.Parameter(weights.view(1, 1, -1))

    # Flatten all leading axes into the batch axis with a single channel
    lines = tensor.reshape(-1, 1, tensor.size(-1))
    return blur(lines).view(*tensor.shape)
# These are used for meas_pad
def create_one_hot_mask(image, percentile):
    """Build a 0/1 mask selecting pixels at or below a percentile threshold.

    Args:
        image (np.ndarray): Input image.
        percentile (float): Percentile (0-100) of the image values used as
            the threshold.

    Returns:
        np.ndarray: Integer mask that is 1 where ``image <= threshold`` and
        0 elsewhere.
    """
    cutoff = np.percentile(image, percentile)
    mask = image <= cutoff
    vprint(f"Using percentile = {percentile:.2f}% to create an one-hot mask for measurements amplitude background fitting")

    # Equivalent radius of the excluded (above-threshold) region, treating it
    # as a disk, and the same radius relative to half of the image side
    n_excluded = np.logical_not(mask).sum()
    radius_px = np.sqrt(n_excluded / np.pi)
    half_side = len(mask)//2
    radius_r = radius_px / half_side
    vprint(f"The mask has roughly {radius_px:.2f} px in radius, or {radius_r:.2f} of the distance from center to edge of the image")

    return mask.astype(int)
def fit_background(image, mask, fit_type='exp'):
    """Fit a radially decaying background to the masked pixels of an image.

    The radial distance is measured from ``image.shape // 2`` and only
    pixels where ``mask == 1`` enter the fit. Both parameters are bounded to
    be non-negative.

    Args:
        image (np.ndarray): 2D image to fit.
        mask (np.ndarray): Mask of the same shape; pixels equal to 1 are used.
        fit_type (str): ``'exp'`` for an exponential decay or ``'power'`` for
            a power law decay.

    Returns:
        np.ndarray: Fitted parameters ``[a, b]`` from ``curve_fit``.

    Raises:
        ValueError: If ``fit_type`` is neither ``'exp'`` nor ``'power'``.
    """
    from scipy.optimize import curve_fit

    from ptyrad.utils.math_ops import exponential_decay, power_law

    rows, cols = np.indices(image.shape)
    center = np.array(image.shape) // 2
    # The small offset keeps r strictly positive at the center pixel
    r = np.sqrt((cols - center[1])**2 + (rows - center[0])**2) + 1e-10

    selected = mask == 1
    masked_r = r[selected]
    masked_image = image[selected]

    if fit_type == 'exp':
        model, b_guess = exponential_decay, 0.1
    elif fit_type == 'power':
        model, b_guess = power_law, 1
    else:
        raise ValueError("fit_type must be 'exp' or 'power'")

    initial_guess = [np.max(masked_image), b_guess]  # [a_guess, b_guess]
    bounds = ([0, 0], [np.inf, np.inf])  # a > 0, b > 0
    popt, _ = curve_fit(model, masked_r, masked_image, p0=initial_guess, bounds=bounds, maxfev=10000)

    if fit_type == 'exp':
        vprint(f"Fitted a = {popt[0]:.4f}, b = {popt[1]:.4f} for exponential decay: y = a*exp(-b*r)")
    else:
        vprint(f"Fitted a = {popt[0]:.4f}, b = {popt[1]:.4f} for power law decay: y = a*r^-b")

    return popt
# This is only called inside `models.py / PtychoAD`
def imshift_batch(img, shifts, grid):
    """Generate a batch of sub-pixel shifted copies of an image.

    The shift is applied as a phase ramp in the Fourier domain over the last
    two axes, so ``img`` may carry arbitrary leading dimensions (..., Ny, Nx).

    Args:
        img (torch.Tensor): Image to shift, e.g. a mixed-state complex probe
            (pmode, Ny, Nx) or a pseudo-complex object stack
            (2, omode, Nz, Ny, Nx).
        shifts (torch.Tensor): (Nb, 2) tensor of (shift_y, shift_x) in px.
            Positive values shift down / to the right.
        grid (torch.Tensor): (2, Ny, Nx) k-space grid as (ky, kx), normalized
            so that the values span [-0.5, 0.5).

    Returns:
        torch.Tensor: Shifted images of shape (Nb, ..., Ny, Nx), i.e. one
        extra leading batch axis compared to ``img``.

    Note:
        - ``img`` is broadcast to ``(Nb, *img.shape)``, so each entry of
          ``shifts`` produces an independently shifted copy.
        - The zero-frequency term of the phase ramp sits at the corner, which
          matches the unshifted output of ``fft2``.
        - The result of ``ifft2`` is complex; for real-valued input take
          ``.real`` of the output.
        - Move ``img``, ``shifts`` and ``grid`` to the desired device before
          calling.
    """
    assert img.shape[-2:] == grid.shape[-2:], f"Found incompatible dimensions. img.shape[-2:] = {img.shape[-2:]} while grid.shape[-2:] = {grid.shape[-2:]}"

    ndim = img.ndim  # Total img ndim, so the shifting is dimension independent

    # shifts: (Nb, 2) -> (Nb, 2, 1, 1, ...) with `ndim` trailing singletons.
    # Written with tuples rather than `shifts[..., *[None]*ndim]`, which
    # needs Python 3.11 or above.
    trailing_singletons = (Ellipsis,) + (None,) * ndim
    expanded_shifts = shifts[trailing_singletons]

    # grid: (2, Ny, Nx) -> (2, 1, 1, ..., Ny, Nx) with `ndim - 1` singletons
    inner_singletons = (slice(None),) + (None,) * (ndim - 1) + (Ellipsis,)
    expanded_grid = grid[inner_singletons]

    shift_y, shift_x = expanded_shifts[:, 0], expanded_shifts[:, 1]  # (Nb, 1, 1, ...)
    ky, kx = expanded_grid[0], expanded_grid[1]  # (1, ..., 1, Ny, Nx)

    # Fourier shift theorem: a real-space shift is a linear phase in k-space
    phase = -2*torch.pi * (shift_x * kx + shift_y * ky)
    ramp = torch_phasor(phase)  # Broadcasts to (Nb, 1, ..., Ny, Nx)

    return ifft2(fft2(img) * ramp)