Source code for ptyrad.utils.image_proc

"""
Image processing tools for fitting, cropping, normalization, etc.

"""

from typing import Optional, Tuple

import numpy as np
import logging

logger = logging.getLogger(__name__)


# Some quick estimation analysis tools
def get_blob_size(dx, blob, output='d90', plot_profile=False):
    """
    Estimate the size of a probe / blob from its radially binned intensity.

    Args:
        dx (float): Pixel size in Ang.
        blob (array): 2D probe/blob image. The input is assumed to be directly
            measurable (no squaring needed), centered on ``shape // 2`` and
            background free.
        output (str): Selects the single quantity that is returned:
            'd50', 'd90', 'd99', 'd995', 'd999' (diameter enclosing that
            fraction of the summed intensity, in Ang), 'radius_rms' (RMS
            radius in Ang), 'FWHM' (in Ang), 'radial_profile' (radially
            averaged profile), 'radial_sum' (radial profile without
            normalizing by the ring area) or 'fig' (line profile figure).
        plot_profile (bool): If True, the line profile figure is built and
            shown with ``plt.show()``.

    Returns:
        float, np.ndarray or matplotlib Figure: the quantity selected by
        ``output``. Scalar quantities are also logged in px and Ang.

    Raises:
        ValueError: If ``output`` is not one of the keys listed above.
    """
    scalar_keys = ('d50', 'd90', 'd99', 'd995', 'd999', 'radius_rms', 'FWHM')
    other_keys = ('radial_profile', 'radial_sum', 'fig')
    # Reject an unknown request before doing any work (or opening a figure).
    if output not in scalar_keys + other_keys:
        raise ValueError(f"output ={output} not implemented!")

    def get_radial_profile(data, center):
        # The radial intensity is calculated up to the corners,
        # so len(radial_profile) is roughly len(data)/sqrt(2).
        # The bin width equals the original pixel spacing (dr = dx).
        y, x = np.indices(data.shape)
        r = np.sqrt((x - center[0])**2 + (y - center[1])**2)
        r = r.astype(int)

        tbin = np.bincount(r.ravel(), data.ravel())  # summed intensity per ring
        nr = np.bincount(r.ravel())                  # number of px per ring
        return tbin / nr, tbin

    blob = np.asarray(blob)
    n_rows, n_cols = blob.shape
    # center is (x, y): x runs along columns and y along rows, so each axis
    # needs its own half-size for non-square inputs.
    radial_profile, radial_sum = get_radial_profile(blob, (n_cols // 2, n_rows // 2))

    # RMS radius in px (the extra factor of x is the 2D ring-area weight)
    x = np.arange(len(radial_profile))
    radius_rms = np.sqrt(np.sum(x**2 * radial_profile * x) / np.sum(radial_profile * x))

    # Half width: outermost radial bin still at or above half of the maximum
    HWHM = np.max(np.where((radial_profile / radial_profile.max()) >= 0.5)[0])

    # Radii enclosing a given fraction of the total (un-normalized) intensity
    cum_sum = np.cumsum(radial_sum)
    total = np.sum(radial_sum)

    def enclosing_radius(fraction):
        return np.min(np.where(cum_sum >= fraction * total)[0])

    R50 = enclosing_radius(0.50)
    R90 = enclosing_radius(0.90)
    R99 = enclosing_radius(0.99)
    R995 = enclosing_radius(0.995)
    R999 = enclosing_radius(0.999)

    # Diameter in px = both sides of the radius plus the central pixel
    scalars_px = {
        'd50': 2 * R50 + 1,
        'd90': 2 * R90 + 1,
        'd99': 2 * R99 + 1,
        'd995': 2 * R995 + 1,
        'd999': 2 * R999 + 1,
        'radius_rms': radius_rms,
        'FWHM': 2 * HWHM + 1,
    }

    fig = None
    # The figure is also needed when the caller asks for it via output='fig',
    # even if it should not be displayed.
    if plot_profile or output == 'fig':
        # Imported lazily so the numerical outputs do not require matplotlib.
        import matplotlib.pyplot as plt

        num_ticks = 11
        x_ang = dx * np.arange(len(radial_profile))
        extent_ang = len(radial_profile) * dx
        tick_step = np.round(extent_ang / num_ticks,
                             decimals=1 - int(np.floor(np.log10(extent_ang))))

        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.set_title("Radially averaged profile")
        ax.margins(x=0, y=0)
        ax.plot(x_ang, radial_profile / np.max(radial_profile), label='Radially averaged profile')
        # Vertical lines are drawn at data coordinates, i.e. in Ang.
        ax.vlines(x=R50 * dx, ymin=0, ymax=1, color="tab:orange", linestyle=":", label='R50')
        ax.vlines(x=R90 * dx, ymin=0, ymax=1, color="tab:red", linestyle=":", label='R90')
        ax.vlines(x=HWHM * dx, ymin=0, ymax=1, color="tab:blue", linestyle=":", label='FWHM')
        ax.vlines(x=radius_rms * dx, ymin=0, ymax=1, color="tab:green", linestyle=":", label='Radius_RMS')
        ax.set_xticks(np.arange(num_ticks) * tick_step)
        ax.set_xlabel(r"Distance from blob center ($\AA$)")
        ax.set_ylabel("Normalized intensity")
        ax.legend()
        if plot_profile:
            plt.show()

    if output in scalars_px:
        out = scalars_px[output] * dx
        logger.info(f'{output} = {out/dx:.3f} px or {out:.3f} Ang')
        return out
    if output == 'radial_profile':
        return radial_profile
    if output == 'radial_sum':
        return radial_sum
    return fig
def guess_radius_of_bright_field_disk(image: np.ndarray, thresh: float = 0.5):
    """
    Estimate the bright-field disk radius (rbf) of a CBED pattern, in px.

    Args:
        image (np.ndarray): CBED intensity; the original notes describe it as
            a 2D array ordered (ky, kx).
        thresh (float): Fraction of the image maximum used as the cutoff.
            0.5 corresponds to FWHM, 0.1 to the full width at 10th maximum.

    Returns:
        float: Radius of the circle whose area equals the number of pixels
        strictly above ``thresh * image.max()``.
    """
    cutoff = np.max(image) * thresh
    n_bright_px = (image > cutoff).sum()
    # Treat the bright region as a filled circle: area = pi * r**2
    return np.sqrt(n_bright_px / np.pi)
# Used for the initial estimation of the CBED geometry (center, radius, and edge blur)
def fit_cbed_pattern(image: np.ndarray, initial_guess=None):
    """
    Estimate the center, radius, and std of a CBED pattern by minimizing the
    mean squared difference between the observed image and a synthetic
    Gaussian-blurred disk from ``make_gaussian_mask``.

    Args:
        image (np.ndarray): The input image to fit. Must be 2D and square.
            It is normalized to a maximum of 1 before fitting.
        initial_guess (dict, optional): Initial parameters with optional keys
            'center' ((y, x), default (Npix/2, Npix/2)), 'radius' (default
            Npix/4) and 'std' (default 0.5). When omitted, the center is
            guessed from the center of mass and the radius from
            ``guess_radius_of_bright_field_disk``.

    Returns:
        dict: Fitted parameters under 'center' ((y, x) tuple), 'radius' and
        'std', plus the optimizer status 'success' and final loss 'fun'.

    Raises:
        ValueError: If ``image`` is not a square 2D array.
    """
    from scipy.optimize import minimize

    # Validate before the shape is used; a plain raise also survives `python -O`.
    if image.ndim != 2 or image.shape[0] != image.shape[1]:
        raise ValueError("Only square images supported for now.")
    Npix = image.shape[0]

    # Normalize to max at 1 like the mask. An all-zero image is left untouched
    # so it does not turn into NaNs and the zero-mass fallback below applies.
    peak = image.max()
    if peak != 0:
        image = image / peak

    def loss(params):
        y0, x0, r, std = params  # y0, x0 order matches center=(y, x) in make_gaussian_mask
        model = make_gaussian_mask(Npix, radius=r, std=std, center=(y0, x0))
        return np.mean((image - model) ** 2)  # Mean Squared Error

    if initial_guess is None:
        # Approximate the center with the intensity-weighted center of mass
        y_indices, x_indices = np.indices(image.shape)
        total_mass = np.sum(image)
        if total_mass > 0:
            y0_guess = np.sum(y_indices * image) / total_mass
            x0_guess = np.sum(x_indices * image) / total_mass
        else:
            y0_guess, x0_guess = Npix / 2, Npix / 2

        r_guess = guess_radius_of_bright_field_disk(image)
        std_guess = 0.5  # Start with a reasonable Gaussian blur
    else:
        # Use the provided initial guess, filling in defaults for missing keys
        center = initial_guess.get("center", (Npix / 2, Npix / 2))
        y0_guess, x0_guess = center
        r_guess = initial_guess.get("radius", Npix / 4)
        std_guess = initial_guess.get("std", 0.5)

    p0 = [y0_guess, x0_guess, r_guess, std_guess]
    logger.info(f"Initial guess: center=({y0_guess:.2f}, {x0_guess:.2f}), radius={r_guess:.2f}, Gaussian blur std={std_guess:.2f}")

    # Bounds for (y0, x0, radius, std)
    bounds = [(0, Npix-1), (0, Npix-1), (1, Npix/2), (0, 5)]
    options = {'maxiter': 1000, 'disp': False}

    def run_fit(start):
        return minimize(loss, start, bounds=bounds, method='L-BFGS-B', options=options)

    result = run_fit(p0)
    counts = 1

    # NOTE(review): the model is a thresholded disk that is blurred afterwards,
    # so the loss may be locally flat for sub-pixel parameter changes -
    # presumably why a grid of restarts is used. Confirm.
    if not result.success or result.fun > 0.01:
        logger.info("First optimization attempt didn't converge well, trying different starting points")

        # Restart from a 10 x 10 grid of center shifts within +/- Npix/10
        best_result = result
        shift_range = np.linspace(-Npix/10, Npix/10, 10)

        for shift_y in shift_range:
            for shift_x in shift_range:
                counts += 1
                new_p0 = [y0_guess + shift_y, x0_guess + shift_x, r_guess, std_guess]
                new_result = run_fit(new_p0)

                if new_result.fun < best_result.fun:
                    best_result = new_result
                    logger.info(f"Found better solution with starting point at ({new_p0[0]:.2f}, {new_p0[1]:.2f})")

        logger.info(f"Total fitting trials with different initial guesses = {counts}")
        result = best_result

    y0, x0, r, std = result.x
    logger.info(f"Final fit: center=({y0:.2f}, {x0:.2f}), radius={r:.2f}, Gaussian blur std={std:.2f}")

    return {
        "center": (y0, x0),
        "radius": r,
        "std": std,
        "success": result.success,
        "fun": result.fun
    }
def make_gaussian_mask(Npix: int, radius: float, std: float, center: Optional[Tuple[float, float]] = None):
    """
    Build a filled disk on an Npix x Npix grid and soften its edge with a
    Gaussian filter.

    Args:
        Npix (int): Side length of the square mask in px.
        radius (float): Disk radius in px; pixels whose distance to the
            center is <= radius are set to 1 before blurring.
        std (float): Sigma passed to ``scipy.ndimage.gaussian_filter``.
        center (tuple, optional): (y, x) disk center in px. Defaults to
            (Npix / 2, Npix / 2).

    Returns:
        np.ndarray: The blurred disk as a 2D float array of shape (Npix, Npix).
    """
    from scipy.ndimage import gaussian_filter

    if center is None:
        cy, cx = Npix / 2, Npix / 2
    else:
        cy, cx = center[0], center[1]

    # Pixel coordinates 0 .. Npix-1; a column vector for y and a row vector
    # for x broadcast to the full (Npix, Npix) grid.
    axis = np.arange(Npix, dtype=float)
    rows = axis[:, np.newaxis]
    cols = axis[np.newaxis, :]

    distance = np.sqrt((rows - cy) ** 2 + (cols - cx) ** 2)
    disk = np.where(distance <= radius, 1.0, 0.0)

    return gaussian_filter(disk, sigma=std)
# This is used across the paper figure notebook but not really in the package
def center_crop(image, crop_height, crop_width, offset=(0, 0)):
    """
    Crop a window from the last two axes of a 2D or 3D array, centered on the
    array and optionally shifted.

    Args:
        image (numpy.ndarray): The input array to crop. Can be 2D (H, W) or
            3D with the image axes last, i.e. (C, H, W).
        crop_height (int): The desired height of the crop.
        crop_width (int): The desired width of the crop.
        offset (tuple): (dy, dx) shift in px applied to the centered window.
            No bounds check is applied to the shifted window; it follows
            normal NumPy slicing rules.

    Returns:
        numpy.ndarray: The cropped image.

    Raises:
        ValueError: If the input is not 2D/3D or the crop is larger than the
            image along either axis.
    """
    if len(image.shape) not in [2, 3]:
        raise ValueError("Input image must be a 2D or 3D array.")

    height, width = image.shape[-2:]

    if crop_height > height or crop_width > width:
        raise ValueError("Crop size must be smaller than the input image size.")

    offset_y, offset_x = offset[0], offset[1]
    start_y = (height - crop_height) // 2 + offset_y
    # The horizontal start must use the x component of the offset.
    start_x = (width - crop_width) // 2 + offset_x

    return image[..., start_y:start_y + crop_height, start_x:start_x + crop_width]
# This is used across the paper figure notebook but not really in the package
def mfft2(im):
    """
    FFT of an image split into a periodic and a smooth component.

    Follows "Periodic Artifact Reduction in Fourier Transforms of Full Field
    Atomic Resolution Images", https://doi.org/10.1017/S1431927614014639

    Args:
        im (np.ndarray): 2D image. Integer and boolean inputs are promoted to
            float64 so the edge differences cannot wrap around or truncate.

    Returns:
        tuple: ``(P, S)`` where ``S`` is the FFT of the smooth component built
        from the jump across opposite edges, and ``P = fft2(im) - S`` is the
        FFT of the periodic component.
    """
    im = np.asarray(im)
    if not np.issubdtype(im.dtype, np.inexact):
        # Unsigned/integer arithmetic would wrap on the signed edge differences.
        im = im.astype(np.float64)

    rows, cols = im.shape

    # Boundary term: intensity jump between opposite edges of the image
    top_minus_bottom = im[0, :] - im[rows - 1, :]
    left_minus_right = im[:, 0] - im[:, cols - 1]

    s = np.zeros_like(im)
    s[0, :] = top_minus_bottom
    s[rows - 1, :] = -top_minus_bottom
    s[:, 0] += left_minus_right
    s[:, cols - 1] -= left_minus_right

    # Frequency grid for the discrete Poisson solution
    cx, cy = np.meshgrid(2 * np.pi * np.arange(cols) / cols,
                         2 * np.pi * np.arange(rows) / rows)

    # Smooth component from the Poisson equation with the boundary term above
    D = 2 * (2 - np.cos(cx) - np.cos(cy))
    D[0, 0] = np.inf  # Enforce zero mean & handle division by zero
    S = np.fft.fft2(s) / D

    P = np.fft.fft2(im) - S  # FFT of periodic component

    return P, S
# These are called during save.save_results()
def normalize_from_zero_to_one(arr):
    """
    Linearly rescale an array so its minimum maps to 0 and its maximum to 1.

    Args:
        arr: Array-like object exposing ``min()`` and ``max()`` methods.

    Returns:
        The rescaled array. A constant input (max == min) has no range to
        stretch and is returned as all zeros instead of NaN from 0/0.
    """
    lowest = arr.min()
    span = arr.max() - lowest
    shifted = arr - lowest
    if span == 0:
        # Multiplying by 0.0 keeps the shape and yields a float result.
        return shifted * 0.0
    return shifted / span
def normalize_by_bit_depth(arr, bit_depth):
    """
    Convert an array to the representation requested by ``bit_depth``.

    Args:
        arr: Input array.
        bit_depth (str): '8' -> uint8 scaled to 0..255, '16' -> uint16 scaled
            to 0..65535, '32' -> float32 scaled to 0..1, 'raw' -> float32
            without rescaling. Any other value logs a warning and is treated
            as 'raw'.

    Returns:
        The converted array.
    """
    if bit_depth == 'raw':
        return np.float32(arr)
    if bit_depth == '32':
        return np.float32(normalize_from_zero_to_one(arr))
    if bit_depth == '16':
        return np.uint16(65535*normalize_from_zero_to_one(arr))
    if bit_depth == '8':
        return np.uint8(255*normalize_from_zero_to_one(arr))

    # Unknown request: fall back to the un-normalized float32 data
    logger.warning(f'Unsuported bit_depth :{bit_depth} was passed into `result_modes`, `raw` is used instead')
    return np.float32(arr)
# These are used for meas_pad
def create_one_hot_mask(image, percentile):
    """
    Build a 0/1 mask that selects the pixels at or below a percentile of the
    image values.

    Args:
        image (np.ndarray): Input image.
        percentile (float): Percentile (0-100) of the image values used as the
            threshold.

    Returns:
        np.ndarray: Integer array with 1 where ``image <= threshold`` and 0
        elsewhere. The equivalent radius of the excluded region is logged.
    """
    cutoff = np.percentile(image, percentile)
    selected = image <= cutoff
    logger.info(f"Using percentile = {percentile:.2f}% to create an one-hot mask for measurements amplitude background fitting")

    # Equivalent radius of the excluded (mask == 0) region, treated as a circle
    n_excluded = np.logical_not(selected).sum()
    radius_px = np.sqrt(n_excluded / np.pi)
    half_size = len(selected) // 2
    radius_r = radius_px / half_size
    logger.info(f"The mask has roughly {radius_px:.2f} px in radius, or {radius_r:.2f} of the distance from center to edge of the image")

    return selected.astype(int)
def fit_background(image, mask, fit_type='exp'):
    """
    Fit a radial decay model to the masked pixels of an image.

    The radius of every pixel is measured from ``image.shape // 2`` and only
    pixels where ``mask == 1`` enter the fit.

    Args:
        image (np.ndarray): Image whose background is fitted.
        mask (np.ndarray): Array of the same shape; pixels equal to 1 are used.
        fit_type (str): 'exp' for ``y = a*exp(-b*r)`` or 'power' for
            ``y = a*r^-b``.

    Returns:
        np.ndarray: Fitted parameters ``[a, b]``, both constrained to be >= 0.

    Raises:
        ValueError: If ``fit_type`` is neither 'exp' nor 'power'.
    """
    from scipy.optimize import curve_fit

    rows, cols = np.indices(image.shape)
    mid = np.array(image.shape) // 2
    # The small constant keeps the center pixel away from r == 0
    radius = np.sqrt((cols - mid[1])**2 + (rows - mid[0])**2) + 1e-10

    selected = mask == 1
    radii = radius[selected]
    values = image[selected]

    if fit_type == 'exp':
        model, b_guess = exponential_decay, 0.1
    elif fit_type == 'power':
        model, b_guess = power_law, 1
    else:
        raise ValueError("fit_type must be 'exp' or 'power'")

    popt, _ = curve_fit(
        model,
        radii,
        values,
        p0=[np.max(values), b_guess],      # [a_guess, b_guess]
        bounds=([0, 0], [np.inf, np.inf]),  # a > 0, b > 0
        maxfev=10000,
    )

    if fit_type == 'exp':
        logger.info(f"Fitted a = {popt[0]:.4f}, b = {popt[1]:.4f} for exponential decay: y = a*exp(-b*r)")
    else:
        logger.info(f"Fitted a = {popt[0]:.4f}, b = {popt[1]:.4f} for power law decay: y = a*r^-b")

    return popt
def exponential_decay(r, a, b):
    """Exponential decay model ``a * exp(-b * r)`` used for background fitting."""
    attenuation = np.exp(-b * r)
    return a * attenuation
def power_law(r, a, b):
    """Power-law decay model ``a * r**(-b)`` used for background fitting."""
    exponent = -b
    return a * r ** exponent