"""
Plotting functions for output figures including forward passes, loss curves, positions, etc.
"""
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
from numpy.fft import fft2, fftshift, ifftshift
from ptyrad.utils import make_sigmoid_mask, safe_filename, vprint
[docs]
def plot_sigmoid_mask(Npix, relative_radius, relative_width, img=None, show_circles=False):
    """Plot a sigmoid mask overlaid on an image together with a line profile.

    The left panel shows ``mask * img``. The right panel shows the profile
    along the horizontal line running from the image center to the right edge.

    Args:
        Npix: Side length (px) of the square mask; forwarded to
            ``make_sigmoid_mask``.
        relative_radius: Mask radius, ranging 0 - 1 for center -> edge.
            A radius of 1 corresponds to an inscribed circle.
        relative_width: Width of the sigmoid transition as a fraction of
            ``Npix``. A width of 0.05 means the sigmoid value changes from
            0 to 1 over 5% of the image width.
        img: Optional image to be masked; it must broadcast against the mask
            and is normalized by its maximum. When omitted, a uniform image
            of ones is used for the overlay.
        show_circles: If True, mark ``relative_radius +/- relative_width`` with
            dashed circles on the image and dashed vertical lines on the profile.
    """
    mask = make_sigmoid_mask(Npix, relative_radius, relative_width).detach().cpu().numpy()

    # Remember whether the caller supplied an image *before* substituting the
    # placeholder, otherwise the image profiles would always be drawn.
    has_img = img is not None
    img = img / img.max() if has_img else np.ones((Npix, Npix))
    masked_img = mask * img
    center = Npix // 2

    fig, axs = plt.subplots(1, 2, figsize=(13, 6))
    fig.suptitle(f"Sigmoid mask with radius = {relative_radius}, width = {relative_width}", fontsize=18)

    im = axs[0].imshow(masked_img)
    axs[0].axhline(y=center, xmin=0.5, c='r', linestyle='--')

    axs[1].plot(mask[center, center:], c='r', label='mask')
    if has_img:
        axs[1].plot(img[center, center:], label='image')
        axs[1].plot(masked_img[center, center:], label='masked_img')

    # Mark the inner / outer edge of the sigmoid transition on both panels
    if show_circles:
        for radius in ((relative_radius - relative_width) * Npix / 2,
                       (relative_radius + relative_width) * Npix / 2):
            axs[0].add_artist(plt.Circle((center, center), radius, color='k', fill=False, linestyle='--'))
            axs[1].axvline(x=radius, color='k', linestyle='--')

    fig.colorbar(im, shrink=0.7)
    axs[1].legend()  # The labelled lines live on the profile panel
    plt.show()
[docs]
def plot_forward_pass(model, indices, dp_power, show_fig=True, pass_fig=False):
    """Plot the forward pass for the input torch model.

    One row is drawn per entry of ``indices`` with five panels: probe
    intensity, object amplitude, object phase, modelled diffraction pattern
    and measured diffraction pattern.

    Notes:
        - The model attributes are torch tensors and are converted to numpy here.
        - probes_int = (N_i, Ny, Nx)
        - obj_ROI = (N_i, omode, Nz, Ny, Nx) -> (N_i, Nz, Ny, Nx)
        - For the probe, only the intensity of the incoherently summed
          mixed-state probe is plotted.
        - For the object, the occupancy-weighted sum over object modes is
          plotted; amplitude is multiplied over z-slices, phase is summed
          over z-slices.

    Args:
        model: Object providing ``get_probes``, ``get_obj_patches``,
            ``get_measurements``, ``omode_occu`` and ``__call__``.
        indices: Sequence of indices to plot, one figure row each.
        dp_power: Exponent applied to both diffraction patterns for
            visualization only; the actual loss function has its own setting.
        show_fig: If True, call ``plt.show()``.
        pass_fig: If True, return the figure.

    Returns:
        The matplotlib figure when ``pass_fig`` is True, otherwise None.
    """
    with torch.no_grad():
        probes_int = model.get_probes(indices).abs().pow(2).sum(1)
        model_DP = model(indices)
        # The cache is cleared right after the mini-batch update, so the
        # object patches have to be re-calculated here
        obj_patches = model.get_obj_patches(indices)
        omode_occu = model.omode_occu
        measured_DP = model.get_measurements(indices)

        occu = omode_occu[:, None, None, None]
        probes_int = probes_int.detach().cpu().numpy()
        obja_ROI = (obj_patches[..., 0] * occu).sum(1).detach().cpu().numpy()  # (N_i, Nz, Ny, Nx)
        objp_ROI = (obj_patches[..., 1] * occu).sum(1).detach().cpu().numpy()  # (N_i, Nz, Ny, Nx)
        model_DP = model_DP.detach().cpu().numpy()
        measured_DP = measured_DP.detach().cpu().numpy()

    n_rows = len(indices)
    plt.ioff()  # Temporarily disable the interactive plotting mode
    # squeeze=False keeps axs 2-D so that a single index still supports axs[i, j]
    fig, axs = plt.subplots(n_rows, 5, figsize=(24, 5 * n_rows), squeeze=False)
    fig.suptitle("Forward pass", fontsize=24)

    # Looping over the N_i dimension
    for i, idx in enumerate(indices):
        panels = (
            (probes_int[i], f"Probe intensity idx{idx}"),
            (obja_ROI[i].prod(0), f"Object amp. (osum, zprod) idx{idx}"),
            (objp_ROI[i].sum(0), f"Object phase (osum, zsum) idx{idx}"),
            (model_DP[i] ** dp_power, f"Model DP^{dp_power} idx{idx}"),
            (measured_DP[i] ** dp_power, f"Data DP^{dp_power} idx{idx}"),
        )
        for ax, (image, title) in zip(axs[i], panels):
            im = ax.imshow(image)
            ax.set_title(title, fontsize=16)
            fig.colorbar(im, ax=ax, shrink=0.6)

    plt.tight_layout()
    if show_fig:
        plt.show()
    if pass_fig:
        return fig
[docs]
def plot_obj_tilts_avg(avg_tilt_iters, last_n_iters=2, show_fig=True, pass_fig=False):
    """Plot the averaged object tilts (tilt_y, tilt_x) versus iteration.

    Args:
        avg_tilt_iters: Non-empty sequence of ``(iteration, tilt)`` pairs. The
            tilts are stacked with ``np.vstack`` into an (N, 2) array whose
            columns are plotted as tilt_y and tilt_x.
        last_n_iters: Number of trailing records shown in a zoomed inset. The
            inset is only drawn when there are more than 20 records. ``None``
            or a non-positive value disables the inset.
        show_fig: If True, call ``plt.show()``.
        pass_fig: If True, return the figure.

    Returns:
        The matplotlib figure when ``pass_fig`` is True, otherwise None.
    """
    # Convert only when a value is given so that None can disable the inset
    n_last = None if last_n_iters is None else int(last_n_iters)

    # Unpack iteration numbers and tilt values
    iters, tilts = zip(*avg_tilt_iters)  # Separates into two tuples
    tilts = np.vstack(tilts)  # Converts list of (1,2) arrays to (N,2) array
    iters = np.array(iters)

    # A non-positive count would turn the [-n:] slice into "everything"
    draw_inset = n_last is not None and n_last > 0 and len(iters) > 20
    components = (('y', 'C0'), ('x', 'C1'))  # column 0 -> tilt_y, column 1 -> tilt_x

    plt.ioff()  # Temporarily disable interactive mode
    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(8, 10), sharex=True)

    for i, (ax, (name, color)) in enumerate(zip(axes, components)):
        ax.plot(iters, tilts[:, i], marker='o', color=color)
        ax.set_ylabel(f'Avg Obj tilt_{name} (mrad)', fontsize=16)
        ax.set_title(f'Avg Obj tilt_{name} (mrad): {tilts[-1,i]:.3f} at iter {iters[-1]}', fontsize=16)
        ax.grid(True)
    axes[1].set_xlabel('Iterations', fontsize=16)

    if draw_inset:
        for i, (ax, (name, color)) in enumerate(zip(axes, components)):
            # Plot the last n iters as an inset, matching the main plot's color
            axins = ax.inset_axes([0.45, 0.3, 0.4, 0.5])
            axins.plot(iters[-n_last:], tilts[-n_last:, i], marker='o', color=color)
            axins.set_xlabel('Iterations', fontsize=12)
            axins.set_ylabel(f'Avg Obj tilt_{name} (mrad)', fontsize=12)
            axins.yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:.3f}'))
            ax.indicate_inset_zoom(axins, edgecolor="gray")
            axins.set_title(f'Last {n_last} iterations', fontsize=12, pad=10)
            axins.grid(True)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.tight_layout()
    if show_fig:
        plt.show()
    if pass_fig:
        return fig
[docs]
def plot_obj_tilts(pos, tilts, figsize=(16,16), show_fig=True, pass_fig=False):
    """Plot the object tilts as a quiver plot at the given probe positions.

    Args:
        pos: (N, 2) array of positions; column 0 is used as y and column 1 as x.
        tilts: Tilts broadcastable to (N, 2); column 0 is the y component and
            column 1 the x component. The colorbar is labelled in mrad.
        figsize: Figure size passed to ``plt.figure``.
        show_fig: If True, call ``plt.show()``.
        pass_fig: If True, return the figure and its axes.

    Returns:
        ``(fig, ax)`` when ``pass_fig`` is True, otherwise None.
    """
    plt.ioff()  # Temporarily disable the interactive plotting mode
    fig = plt.figure(figsize=figsize)
    # There's only 1 ax for plt.figure(); the Axes is passed out so callers
    # can adjust the Axes-level title
    ax = plt.gca()
    plt.title("Object tilts", fontsize=16)

    tilts = np.broadcast_to(tilts, shape=(len(pos), 2))

    # Both components must be negligible before the quiver plot is skipped;
    # checking tilt_y alone would hide a pure tilt_x field.
    if np.allclose(tilts, 0, atol=1e-3):
        ax.text(
            0.5, 0.5, "All tilts are effectively zero (<1e-3) mrad, no quiver plot",
            ha="center", va="center", fontsize=18, color="gray", transform=ax.transAxes
        )
    else:
        magnitude = np.hypot(tilts[:, 0], tilts[:, 1])
        q = ax.quiver(pos[:, 1], pos[:, 0], tilts[:, 1], tilts[:, 0], magnitude,
                      pivot='mid', angles='xy', scale_units='xy', label='Obj tilts')
        cbar = fig.colorbar(q, shrink=0.75)
        cbar.ax.set_ylabel('mrad')
        cbar.ax.get_yaxis().labelpad = 15

    ax.set_aspect('equal', adjustable='box')
    ax.invert_yaxis()  # Flip the y-axis so it is image-like
    ax.set_xlabel('X (obj coord, px)')
    ax.set_ylabel('Y (obj coord, px)')
    plt.tight_layout()
    if show_fig:
        plt.show()
    if pass_fig:
        return fig, ax
[docs]
def plot_scan_positions(pos, init_pos=None, img=None, offset=None, figsize=(16,16), dot_scale=0.001, show_arrow=True, show_fig=True, pass_fig=False):
    """Plot the scan positions given an array of shape (N, 2).

    Each row of ``pos`` is rendered as (y, x), or equivalently (height, width).
    Without ``init_pos`` the dots are drawn with ascending size and changing
    color to represent their relative order. With ``init_pos`` the initial and
    current positions are drawn in two flat colors and the view is limited to
    the extent of ``init_pos`` plus a 10 px margin.

    Args:
        pos: (N, 2) array of scan positions.
        init_pos: Optional (N, 2) array of initial positions for comparison.
        img: Optional background image drawn under the positions.
        offset: Optional offset added to ``pos`` when ``img`` is given;
            ``None`` means no offset.
        figsize: Figure size passed to ``plt.figure``.
        dot_scale: Marker size (or per-index size increment without ``init_pos``).
        show_arrow: If True, draw a red arrow from the 1st position towards the
            10th (or the last one if there are fewer than 10).
        show_fig: If True, call ``plt.show()``.
        pass_fig: If True, return the figure and its axes.

    Returns:
        ``(fig, ax)`` when ``pass_fig`` is True, otherwise None.
    """
    plt.ioff()  # Temporarily disable the interactive plotting mode
    fig = plt.figure(figsize=figsize)
    # There's only 1 ax for plt.figure(); the Axes is passed out so callers
    # can adjust the Axes-level title
    ax = plt.gca()
    plt.title("Scan positions", fontsize=16)

    if img is not None:
        plt.imshow(img)
        pos = np.array(pos)
        if offset is not None:  # An omitted offset must not break the overlay
            pos = pos + np.array(offset)
        # NOTE(review): the offset is not applied to init_pos -- confirm whether that is intended
        plt.gca().invert_yaxis()  # Pre-flip y-axis so the y-axis is image-like no matter what

    if init_pos is None:
        order = np.arange(len(pos))
        plt.scatter(x=pos[:,1], y=pos[:,0], c=order, s=dot_scale*order, label='Scan positions')
    else:
        plt.scatter(x=init_pos[:,1], y=init_pos[:,0], c='C0', s=dot_scale, label='Init scan positions')
        plt.scatter(x=pos[:,1], y=pos[:,0], c='C1', s=dot_scale, label='Opt scan positions')
        plt.ylim(init_pos[:,0].min()-10, init_pos[:,0].max()+10)
        plt.xlim(init_pos[:,1].min()-10, init_pos[:,1].max()+10)

    plt.gca().set_aspect('equal', adjustable='box')
    plt.gca().invert_yaxis()  # Flipped y-axis if there's only scatter plot
    plt.xlabel('X (obj coord, px)')
    plt.ylabel('Y (obj coord, px)')

    # Draw arrow from the 1st position to the 10th position, falling back to
    # the last available one for short scans; a single point has no direction.
    if show_arrow and len(pos) > 1:
        tip = min(9, len(pos) - 1)
        plt.arrow(pos[0, 1], pos[0, 0], pos[tip, 1] - pos[0, 1], pos[tip, 0] - pos[0, 0],
                  color='red', head_width=2.5, head_length=5)

    plt.legend()
    plt.tight_layout()
    if show_fig:
        plt.show()
    if pass_fig:
        return fig, ax
[docs]
def plot_pos_grouping(pos, batches, circle_diameter=False, diameter_type='90%', figsize=(16,8), dot_scale=1, show_fig=True, pass_fig=False):
    """Plot how the scan positions are split into groups.

    The left panel scatters every group, the right panel only group 0. When
    ``circle_diameter`` is truthy, a dashed red circle of that diameter is
    drawn around the first position of group 0 and annotated with
    ``"{diameter_type} probe intensity"``.

    Args:
        pos: (N, 2) array of positions; column 0 is used as y and column 1 as x.
        batches: Sequence of index collections, one per group.
        circle_diameter: Diameter of the reference circle; falsy disables it.
        diameter_type: Text prefix of the circle annotation.
        figsize: Figure size passed to ``plt.subplots``.
        dot_scale: Marker size for the scatter plots.
        show_fig: If True, call ``plt.show(block=False)``.
        pass_fig: If True, return the figure.

    Returns:
        The matplotlib figure when ``pass_fig`` is True, otherwise None.
    """
    plt.ioff()  # Temporarily disable the interactive plotting mode
    fig, (ax_all, ax_first) = plt.subplots(1, 2, figsize=figsize)

    # Left panel: every group
    ax_all.set_title(f"Scan positions for all {len(batches)} groups", fontsize=18)
    for group in batches:
        ax_all.scatter(x=pos[group, 1], y=pos[group, 0], s=dot_scale)

    # Right panel: group 0 only
    ax_first.set_title("Scan positions from group 0", fontsize=18)
    group0 = batches[0]
    ax_first.scatter(x=pos[group0, 1], y=pos[group0, 0], s=dot_scale)

    if circle_diameter:
        # Circle centered on the first point of group 0
        anchor = pos[group0[0]]
        cy, cx = anchor[0], anchor[1]
        ring = plt.Circle((cx, cy), circle_diameter / 2, fill=False, color='r', linestyle='--')
        ax_first.scatter(x=cx, y=cy, s=dot_scale, color='r')
        ax_first.add_artist(ring)
        # Label placed relative to the circle's bounding-box corner
        ax_first.annotate(f"{diameter_type} probe intensity",
                          xy=(cx - circle_diameter / 2, cy - circle_diameter / 2 - 3))

    # Show the full range on both panels to better visualize if a sub-group
    # (like 'center') is selected
    for panel in (ax_all, ax_first):
        panel.set_xlabel('X (obj coord, px)')
        panel.set_ylabel('Y (obj coord, px)')
        panel.set_xlim(pos[:,1].min()-10, pos[:,1].max()+10)
        panel.set_ylim(pos[:,0].min()-10, pos[:,0].max()+10)
        panel.invert_yaxis()
        panel.set_aspect('equal', adjustable='box')

    plt.tight_layout()
    if show_fig:
        plt.show(block=False)
    if pass_fig:
        return fig
[docs]
def plot_loss_curves(loss_iters, last_n_iters=10, show_fig=True, pass_fig=False):
    """Plot the loss value versus iteration.

    Args:
        loss_iters: Non-empty sequence convertible to an (N, 2) array; column 0
            is the iteration number and column 1 the loss value.
        last_n_iters: Number of trailing records shown in a zoomed inset. The
            inset is only drawn when there are more than 20 records. ``None``
            or a non-positive value disables the inset.
        show_fig: If True, call ``plt.show()``.
        pass_fig: If True, return the figure.

    Returns:
        The matplotlib figure when ``pass_fig`` is True, otherwise None.
    """
    # Convert only when a value is given so that None can disable the inset
    n_last = None if last_n_iters is None else int(last_n_iters)
    data = np.array(loss_iters)

    plt.ioff()  # Temporarily disable the interactive plotting mode
    fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))

    # Plot all loss values
    axs.plot(data[:,0], data[:,1], marker='o')

    # Plot the last n iters as an inset; a non-positive count would turn the
    # [-n:] slice into "everything", so it is treated as "no inset"
    if n_last is not None and n_last > 0 and len(data) > 20:
        axins = axs.inset_axes([0.45, 0.3, 0.4, 0.5])
        axins.plot(data[-n_last:,0], data[-n_last:,1], marker='o')
        axins.set_xlabel('Iterations', fontsize=12)
        axins.set_ylabel('Loss value', fontsize=12)
        axins.yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:.5f}'))
        axs.indicate_inset_zoom(axins, edgecolor="gray")
        axins.set_title(f'Last {n_last} iterations', fontsize=12, pad=10)

    # Set labels and title for the main plot
    axs.set_xlabel('Iterations', fontsize=16)
    axs.set_ylabel('Loss value', fontsize=16)
    axs.set_title(f'Loss value: {data[-1,1]:.5f} at iter {int(data[-1,0])}', fontsize=16)
    axs.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    plt.yticks(fontsize=14)
    plt.xticks(fontsize=14)
    plt.tight_layout()
    if show_fig:
        plt.show()
    if pass_fig:
        return fig
[docs]
def plot_slice_thickness(dz_iters, last_n_iters=10, show_fig=True, pass_fig=False):
    """Plot the slice thickness versus iteration.

    Args:
        dz_iters: Non-empty sequence convertible to an (N, 2) array; column 0
            is the iteration number and column 1 the slice thickness (labelled
            in Ang on the plot).
        last_n_iters: Number of trailing records shown in a zoomed inset. The
            inset is only drawn when there are more than 20 records. ``None``
            or a non-positive value disables the inset.
        show_fig: If True, call ``plt.show()``.
        pass_fig: If True, return the figure.

    Returns:
        The matplotlib figure when ``pass_fig`` is True, otherwise None.
    """
    # Convert only when a value is given so that None can disable the inset
    n_last = None if last_n_iters is None else int(last_n_iters)
    data = np.array(dz_iters)

    plt.ioff()  # Temporarily disable the interactive plotting mode
    fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))

    # Plot all slice thickness values
    axs.plot(data[:,0], data[:,1], marker='o')
    axs.grid(True)

    # Plot the last n iters as an inset; a non-positive count would turn the
    # [-n:] slice into "everything", so it is treated as "no inset"
    if n_last is not None and n_last > 0 and len(data) > 20:
        axins = axs.inset_axes([0.45, 0.3, 0.4, 0.5])
        axins.plot(data[-n_last:,0], data[-n_last:,1], marker='o')
        axins.set_xlabel('Iterations', fontsize=12)
        axins.set_ylabel('Slice thickness (Ang)', fontsize=12)
        axins.yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:.5f}'))
        axs.indicate_inset_zoom(axins, edgecolor="gray")
        axins.set_title(f'Last {n_last} iterations', fontsize=12, pad=10)

    # Set labels and title for the main plot
    axs.set_xlabel('Iterations', fontsize=16)
    axs.set_ylabel('Slice thickness (Ang)', fontsize=16)
    axs.set_title(f'Slice thickness (Ang): {data[-1,1]:.5f} at iter {int(data[-1,0])}', fontsize=16)
    axs.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    plt.yticks(fontsize=14)
    plt.xticks(fontsize=14)
    plt.tight_layout()
    if show_fig:
        plt.show()
    if pass_fig:
        return fig
[docs]
def plot_probe_modes(init_probe=None, opt_probe=None, amp_or_phase='amplitude', real_or_fourier='real', phase_cmap=None, amplitude_cmap=None, show_fig=True, pass_fig=False):
    """Plot every probe mode of the initial and/or optimized probe.

    One row is drawn per provided probe and one column per mode. The title of
    each panel carries the mode's share of the total probe intensity. This is
    for visualization, so each mode has its own colorbar; see the actual probe
    amplitude output for an absolute scale.

    Args:
        init_probe: Optional numpy array (pmode, Ny, Nx) of the initial probe.
        opt_probe: Optional numpy array (pmode, Ny, Nx) of the optimized probe.
        amp_or_phase: 'amplitude' (or 'amp') or 'phase'.
        real_or_fourier: 'real' or 'fourier'.
        phase_cmap: Colormap for phase plots; falls back to 'twilight'.
        amplitude_cmap: Colormap for amplitude plots; falls back to 'viridis'.
        show_fig: If True, call ``plt.show()``.
        pass_fig: If True, return the figure.

    Returns:
        The matplotlib figure when ``pass_fig`` is True, otherwise None.

    Raises:
        ValueError: If no probe is given, if the two probes differ in their
            number of modes, or if a mode keyword is not recognized.
    """
    if init_probe is None and opt_probe is None:
        raise ValueError("At least one of init_probe or opt_probe must be provided.")
    if init_probe is not None and opt_probe is not None and init_probe.shape[0] != opt_probe.shape[0]:
        raise ValueError(f"All provided probes must have the same number of probe modes (axis 0), got {init_probe.shape} and {opt_probe.shape}.")

    def _prepare(raw):
        """Return (image stack, per-mode power fraction, colormap) for one probe."""
        intensity = np.abs(raw)**2
        power = np.sum(intensity, axis=(-2,-1))/np.sum(intensity)

        # The fftshift(fft2(ifftshift(probe))) sandwich is needed: fft2(probe)
        # alone gives the same magnitude, but moving the probe back to the
        # corner first avoids the extra phase shifts that show up as a
        # checkerboard artifact in angle(probe_fourier).
        if real_or_fourier == 'fourier':
            field = fftshift(fft2(ifftshift(raw, axes=(-2,-1)), norm='ortho'), axes=(-2,-1))
        elif real_or_fourier == 'real':
            field = raw
        else:
            raise ValueError("Please use 'real' or 'fourier' for probe mode visualization!")

        # Negative sign for consistency with chi(k), because psi = exp(-i*chi(k)):
        # overfocus (negative df = positive C1) should give a positive phase
        # shift near the edge of the aperture. The phase is weighted by the
        # amplitude to emphasize the relevant region, although a noisy
        # amplitude will make the phase appear noisy too.
        if amp_or_phase == 'phase':
            return -np.angle(field)*np.abs(field), power, (phase_cmap or 'twilight')
        if amp_or_phase in ('amplitude', 'amp'):
            return np.abs(field), power, (amplitude_cmap or 'viridis')
        raise ValueError("Please use 'amplitude' or 'phase' for probe mode visualization!")

    # Collect one (row title, images, power) entry per provided probe
    entries = []
    for row_title, raw_probe in (("Init pmode", init_probe), ("Opt pmode", opt_probe)):
        if raw_probe is None:
            continue
        images, power, cmap = _prepare(raw_probe)
        entries.append((row_title, images, power))

    n_rows = len(entries)
    n_modes = entries[0][1].shape[0]  # images = (pmode, Ny, Nx)

    plt.ioff()  # Temporarily disable the interactive plotting mode
    # squeeze=False always yields a 2-D (n_rows, n_modes) array of axes
    fig, axs = plt.subplots(n_rows, n_modes, figsize=(n_modes*2.5, n_rows*3), squeeze=False)

    for row_axes, (row_title, images, power) in zip(axs, entries):
        for mode, ax in enumerate(row_axes):
            ax.set_title(f"{row_title} {mode}: {power[mode]:.1%}")
            im = ax.imshow(images[mode], cmap=cmap)
            ax.axis('off')
            fig.colorbar(im, ax=ax, shrink=0.6)

    plt.suptitle(f"Probe modes {amp_or_phase} in {real_or_fourier} space", fontsize=18)
    plt.tight_layout()
    if show_fig:
        plt.show()
    if pass_fig:
        return fig
[docs]
@torch.compiler.disable
def plot_summary(output_path, model, niter, indices, init_variables, selected_figs=('loss', 'forward', 'probe_r_amp', 'probe_k_amp', 'probe_k_phase', 'pos'), collate_str='', show_fig=True, save_fig=False, verbose=True):
    """Wrapper that produces, shows and/or saves the summary figures.

    Args:
        output_path: Directory prefix; file names are appended as
            ``output_path + "/summary_<name><collate_str>_iterNNNN.png"``.
        model: Reconstruction model providing the plotted quantities
            (``loss_iters``, ``crop_pos``, ``opt_probe_pos_shifts``,
            ``opt_obj_tilts``, ``get_complex_probe_view()``, ...).
        niter: Iteration number used in titles and, zero-padded to 4 digits,
            in file names.
        indices: Indexable array of scan-position indices to visualize.
        init_variables: Mapping with the initial ``'probe'``, ``'crop_pos'``
            and ``'probe_pos_shifts'``.
        selected_figs: Collection of figure keys: 'loss', 'forward',
            'probe_r_amp', 'probe_k_amp', 'probe_k_phase', 'pos', 'tilt',
            'tilt_avg', 'slice_thickness', or 'all' for everything.
        collate_str: Extra string inserted into every file name.
        show_fig: If True, show each figure.
        save_fig: If True, save each figure. Forced to True when both
            ``show_fig`` and ``save_fig`` are False.
        verbose: If True, print a message when figures are being saved.

    Note:
        Set ``show_fig=False`` and ``save_fig=True`` to save the figures
        without showing them. All figures are closed before returning.
    """
    # Enable saving if both show_fig and save_fig were accidentally disabled
    if show_fig is False and save_fig is False:
        save_fig = True
    if save_fig and verbose:
        vprint(f"Saving summary figures for iter {niter}")

    iter_str = '_iter' + str(niter).zfill(4)

    def _wanted(key):
        """Return True when the figure `key` was requested (directly or via 'all')."""
        return key in selected_figs or 'all' in selected_figs

    def _show_and_save(fig, stem, **savefig_kwargs):
        """Show and/or save `fig` as summary_<stem><collate_str><iter_str>.png."""
        if show_fig:
            fig.show()
        if save_fig:
            fig.savefig(safe_filename(output_path + f"/summary_{stem}{collate_str}{iter_str}.png"), **savefig_kwargs)

    # Loss curves
    # NOTE(review): loss / tilt_avg / slice_thickness forward show_fig to the
    # plotting function and are then shown again below, unlike the other
    # figures -- confirm whether the double show is intended.
    if _wanted('loss'):
        fig_loss = plot_loss_curves(model.loss_iters, last_n_iters=10, show_fig=show_fig, pass_fig=True)
        _show_and_save(fig_loss, "loss")

    # Forward pass
    if _wanted('forward'):
        n2 = int(len(indices))
        n = int(n2**0.5)
        # The idea is to get 2 regions of (N/2)x(N/2) that are +-N/4 from the
        # center of the FOV. Clamp so very small index sets stay in range.
        picks = np.clip(np.int32([n2/2 + n/4, n2/2 + 3*n/4]), 0, n2 - 1)
        plot_indices = indices[picks]
        fig_forward = plot_forward_pass(model, plot_indices, 0.5, show_fig=False, pass_fig=True)
        fig_forward.suptitle(f"Forward pass at iter {niter}", fontsize=24)
        _show_and_save(fig_forward, "forward_pass")

    # Probe modes in real and reciprocal space
    init_probe = init_variables['probe']
    opt_probe = model.get_complex_probe_view().detach().cpu().numpy()
    probe_views = (
        # (figure key, space, quantity, file-name tag)
        ('probe_r_amp', 'real', 'amplitude', 'real_amp'),
        ('probe_k_amp', 'fourier', 'amplitude', 'fourier_amp'),
        ('probe_k_phase', 'fourier', 'phase', 'fourier_phase'),
    )
    for key, space, quantity, tag in probe_views:
        if _wanted(key):
            fig_probe = plot_probe_modes(init_probe, opt_probe, real_or_fourier=space, amp_or_phase=quantity, show_fig=False, pass_fig=True)
            fig_probe.suptitle(f"Probe modes {quantity} in {space} space at iter {niter}", fontsize=18)
            _show_and_save(fig_probe, f"probe_modes_{tag}", bbox_inches='tight')

    # Scan positions and tilts
    init_pos = init_variables['crop_pos'] + init_variables['probe_pos_shifts']
    pos = (model.crop_pos + model.opt_probe_pos_shifts).detach().cpu().numpy()
    tilts = model.opt_obj_tilts.detach().cpu().numpy()
    tilts = np.broadcast_to(tilts, (len(pos), 2))  # tilts has to be (N_scan, 2)

    if _wanted('pos'):
        fig_scan_pos, ax = plot_scan_positions(pos=pos[indices], init_pos=init_pos[indices], dot_scale=1, show_fig=False, pass_fig=True)
        ax.set_title(f"Scan positions at iter {niter}", fontsize=16)
        _show_and_save(fig_scan_pos, "scan_pos")

    if _wanted('tilt'):
        fig_obj_tilts, ax = plot_obj_tilts(pos=pos[indices], tilts=tilts[indices], show_fig=False, pass_fig=True)
        ax.set_title(f"Object tilts at iter {niter}", fontsize=16)
        _show_and_save(fig_obj_tilts, "obj_tilts")

    if _wanted('tilt_avg'):
        fig_avg_obj_tilts = plot_obj_tilts_avg(model.avg_tilt_iters, last_n_iters=10, show_fig=show_fig, pass_fig=True)
        _show_and_save(fig_avg_obj_tilts, "obj_tilts_avg")

    # Slice thickness
    if _wanted('slice_thickness'):
        fig_slice_thickness = plot_slice_thickness(model.dz_iters, last_n_iters=10, show_fig=show_fig, pass_fig=True)
        _show_and_save(fig_slice_thickness, "slice_thickness")

    # Close figures after saving
    plt.close('all')