# Source code for ptyrad.plotting.model

"""
Plotting functions related to PyTorch models
"""
# This module isolates the heavy torch import (3-6 s), so it is intentionally not re-exported via the package __init__

import logging

import matplotlib.pyplot as plt
import numpy as np
import torch

from ptyrad.io.save import safe_filename

from .basic import (
    plot_convergence_dashboard,
    plot_loss_curves,
    plot_obj_tilts,
    plot_obj_tilts_avg,
    plot_probe_modes,
    plot_scan_positions,
    plot_slice_thickness,
    plot_learning_rates_schedule,
)

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

def _show_and_save(fig, show_fig, save_fig, path, **savefig_kwargs):
    """Show and/or save ``fig`` according to the flags; does nothing when ``fig`` is None."""
    if fig is None:
        return
    if show_fig:
        fig.show()
    if save_fig:
        fig.savefig(safe_filename(path), **savefig_kwargs)


def plot_summary(output_path, model, niter, indices, init_variables,
                 selected_figs=('loss', 'forward', 'probe_r_amp', 'probe_k_amp', 'probe_k_phase', 'pos'),
                 collate_str='', show_fig=True, save_fig=False):
    """
    Wrapper that renders (and optionally saves) the selected summary figures for a model.

    Args:
        output_path (str or path-like): Directory prefix for saved figures. Files are written to
            ``f"{output_path}/summary_<name>{collate_str}_iter<niter:04d>.png"`` (passed through ``safe_filename``).
        model: Reconstruction model. Only the attributes needed by the requested figures are read,
            e.g. ``loss_iters``, ``lr_iters``, ``get_complex_probe_view()``, ``crop_pos``,
            ``opt_probe_pos_shifts``, ``opt_obj_tilts``, ``avg_tilt_iters``, ``dz_iters``, ``convergence_iters``.
        niter (int): Iteration number used in titles and in the zero-padded file-name suffix.
        indices: Scan indices to visualize; must support integer-array (fancy) indexing.
        init_variables (dict): Initial variables; 'probe', 'crop_pos' and 'probe_pos_shifts' are read
            only when the corresponding figures are requested.
        selected_figs (iterable of str, or str): Any of 'loss', 'learning_rates', 'forward', 'probe_r_amp',
            'probe_k_amp', 'probe_k_phase', 'pos', 'tilt', 'tilt_avg', 'slice_thickness',
            'convergence' / 'convergence_full', 'convergence_dynamic', or 'all'. A single string is
            treated as one key.
        collate_str (str): Extra tag inserted into saved file names.
        show_fig (bool): Call ``fig.show()`` on every produced figure.
        save_fig (bool): Save every produced figure. Set show_fig=False and save_fig=True to only save.
            If both flags are falsy, saving is switched on so the call is not a no-op.

    All pyplot figures are closed at the end (``plt.close('all')``).
    """
    # A bare string would otherwise be matched by substring ('tilt' in 'tilt_avg' is True).
    if isinstance(selected_figs, str):
        selected_figs = (selected_figs,)
    plot_all = 'all' in selected_figs

    def wants(*keys):
        # True if 'all' or any of the given keys was requested.
        return plot_all or any(key in selected_figs for key in keys)

    # Sets figure saving to be True if you accidentally disable both show_fig and save_fig
    if not show_fig and not save_fig:
        save_fig = True

    if save_fig:
        logger.info(f"Saving summary figures for iter {niter}")
    iter_str = '_iter' + str(niter).zfill(4)

    def save_path(name):
        return f"{output_path}/summary_{name}{collate_str}{iter_str}.png"

    # Loss curves
    if wants('loss'):
        fig = plot_loss_curves(model.loss_iters, last_n_iters=10, show_fig=show_fig, pass_fig=True)
        _show_and_save(fig, show_fig, save_fig, save_path('loss'))

    # Learning rate schedule
    if wants('learning_rates'):
        fig = plot_learning_rates_schedule(model.lr_iters, log=True, show_fig=show_fig, pass_fig=True)
        _show_and_save(fig, show_fig, save_fig, save_path('learning_rates'))

    # Forward pass
    if wants('forward'):
        n_total = len(indices)
        if n_total == 0:
            logger.warning("'forward' summary requested but `indices` is empty; skipping forward-pass figure")
        else:
            n_side = int(n_total ** 0.5)
            # The idea is to get 2 regions of (N/2)x(N/2) that are +-N/4 from the center of the FOV.
            # Clip so very small batches (e.g. a single index) don't index past the end.
            picks = np.clip(np.int32([n_total / 2 + n_side / 4, n_total / 2 + 3 * n_side / 4]), 0, n_total - 1)
            plot_indices = indices[picks]
            fig = plot_forward_pass(model, plot_indices, 0.5, plot_raw=False, show_fig=False, pass_fig=True)
            fig.suptitle(f"Forward pass at iter {niter}", fontsize=24)
            _show_and_save(fig, show_fig, save_fig, save_path('forward_pass'))

    # Probe modes in real and reciprocal space: (key, real_or_fourier, amp_or_phase, title, file tag)
    probe_fig_specs = (
        ('probe_r_amp', 'real', 'amplitude', "Probe modes amplitude in real space", 'probe_modes_real_amp'),
        ('probe_k_amp', 'fourier', 'amplitude', "Probe modes amplitude in fourier space", 'probe_modes_fourier_amp'),
        ('probe_k_phase', 'fourier', 'phase', "Probe modes phase in fourier space", 'probe_modes_fourier_phase'),
    )
    requested_probe_figs = [spec for spec in probe_fig_specs if wants(spec[0])]
    if requested_probe_figs:
        # Only pull the probe off the device when at least one probe figure is requested
        init_probe = init_variables['probe']
        opt_probe = model.get_complex_probe_view().detach().cpu().numpy()
        for _, space, quantity, title, tag in requested_probe_figs:
            fig = plot_probe_modes(init_probe, opt_probe, real_or_fourier=space, amp_or_phase=quantity,
                                   show_fig=False, pass_fig=True)
            fig.suptitle(f"{title} at iter {niter}", fontsize=18)
            _show_and_save(fig, show_fig, save_fig, save_path(tag), bbox_inches='tight')

    # Scan positions and tilts
    if wants('pos', 'tilt'):
        pos = (model.crop_pos + model.opt_probe_pos_shifts).detach().cpu().numpy()
        if wants('pos'):
            init_pos = init_variables['crop_pos'] + init_variables['probe_pos_shifts']
            fig, ax = plot_scan_positions(pos=pos[indices], init_pos=init_pos[indices], dot_scale=1,
                                          show_fig=False, pass_fig=True)
            ax.set_title(f"Scan positions at iter {niter}", fontsize=16)
            _show_and_save(fig, show_fig, save_fig, save_path('scan_pos'))
        if wants('tilt'):
            tilts = model.opt_obj_tilts.detach().cpu().numpy()
            tilts = np.broadcast_to(tilts, (len(pos), 2))  # tilts has to be (N_scan, 2)
            fig, ax = plot_obj_tilts(pos=pos[indices], tilts=tilts[indices], show_fig=False, pass_fig=True)
            ax.set_title(f"Object tilts at iter {niter}", fontsize=16)
            _show_and_save(fig, show_fig, save_fig, save_path('obj_tilts'))

    if wants('tilt_avg'):
        fig = plot_obj_tilts_avg(model.avg_tilt_iters, last_n_iters=10, show_fig=show_fig, pass_fig=True)
        _show_and_save(fig, show_fig, save_fig, save_path('obj_tilts_avg'))

    # Slice thickness
    if wants('slice_thickness'):
        fig = plot_slice_thickness(model.dz_iters, last_n_iters=10, show_fig=show_fig, pass_fig=True)
        _show_and_save(fig, show_fig, save_fig, save_path('slice_thickness'))

    # Convergence dashboard — unified time-series figure (loss, LR, dz, tilts, tensor metrics).
    # Old standalone keys 'loss', 'slice_thickness', 'tilt_avg' still work as aliases above.
    # 'convergence' / 'convergence_full': full history from iter 0 (iter_offset=0).
    # 'convergence_dynamic': per-panel Kneedle zoom to the most informative x-range (iter_offset=None).
    want_conv_full = wants('convergence', 'convergence_full')
    want_conv_dyn = wants('convergence_dynamic')
    if want_conv_full or want_conv_dyn:
        conv_kwargs = dict(
            loss_iters=model.loss_iters,
            lr_iters=model.lr_iters,
            dz_iters=model.dz_iters,
            avg_tilt_iters=model.avg_tilt_iters,
            convergence_iters=model.convergence_iters,
            show_fig=show_fig,
            pass_fig=True,
        )
        if want_conv_full:
            fig = plot_convergence_dashboard(**conv_kwargs, iter_offset=0)
            _show_and_save(fig, show_fig, save_fig, save_path('convergence'))
        if want_conv_dyn:
            fig = plot_convergence_dashboard(**conv_kwargs, iter_offset=None)
            _show_and_save(fig, show_fig, save_fig, save_path('convergence_dynamic'))

    # Close figures after saving
    plt.close('all')
def plot_forward_pass(model, indices, dp_power, plot_raw=False, show_fig=True, pass_fig=False):
    """
    Plot the forward pass of the input torch model for the given scan indices.

    Each row corresponds to one entry of ``indices`` and shows, left to right:
      1. probe intensity: ``|probe|^2`` summed over dim 1 (incoherent sum of the mixed-state probe modes),
      2. object amplitude: object-mode-weighted sum (weights ``model.omode_occu``), then product over z-slices,
      3. object phase: object-mode-weighted sum, then sum over z-slices,
      4. model diffraction pattern raised to ``dp_power``,
      5. measured diffraction pattern raised to ``dp_power``.

    Args:
        model: Model exposing ``get_probes(indices)``, ``__call__(indices, return_raw=...)``,
            ``get_obj_patches(indices)`` (presumably an (amplitude, phase) pair of
            (N_i, omode, Nz, Ny, Nx) tensors — TODO confirm), ``omode_occu`` and ``get_measurements(indices)``.
        indices: Scan indices to plot; one figure row per index.
        dp_power (float): Exponent applied to the diffraction patterns for visualization only;
            the actual loss function has its own parameter.
        plot_raw (bool): Forwarded as ``return_raw`` to the model (raw simulated DP, for debugging).
        show_fig (bool): Call ``plt.show()`` after building the figure.
        pass_fig (bool): Return the figure.

    Returns:
        matplotlib.figure.Figure if ``pass_fig`` is True, otherwise None.
    """
    with torch.no_grad():
        probes = model.get_probes(indices)
        probes_int = probes.abs().pow(2).sum(1).detach().cpu().numpy()  # (N_i, Ny, Nx)
        # We can return the raw simulated DP for debugging and development
        model_DP = model(indices, return_raw=plot_raw).detach().cpu().numpy()
        # The cache would be cleared right after the mini-batch update so we have to re-calculate it here
        obj_patches = model.get_obj_patches(indices)
        omode_weights = model.omode_occu[:, None, None, None]
        obja_ROI = (obj_patches[0] * omode_weights).sum(1).detach().cpu().numpy()  # (N_i, Nz, Ny, Nx)
        objp_ROI = (obj_patches[1] * omode_weights).sum(1).detach().cpu().numpy()  # (N_i, Nz, Ny, Nx)
        measured_DP = model.get_measurements(indices).detach().cpu().numpy()

    # One (per-index image stack, title label) pair per figure column
    columns = (
        (probes_int, "Probe intensity"),
        (obja_ROI.prod(1), "Object amp. (osum, zprod)"),
        (objp_ROI.sum(1), "Object phase (osum, zsum)"),
        (model_DP ** dp_power, f"Model DP^{dp_power}"),
        (measured_DP ** dp_power, f"Data DP^{dp_power}"),
    )

    n_rows = len(indices)
    was_interactive = plt.isinteractive()
    plt.ioff()  # Temporarily disable interactive plotting while the figure is built; restored below
    try:
        # squeeze=False keeps `axs` 2-D even when a single index is plotted
        fig, axs = plt.subplots(n_rows, len(columns), figsize=(24, 5 * n_rows), squeeze=False)
        fig.suptitle("Forward pass", fontsize=24)
        for i, idx in enumerate(indices):  # Looping over the N_i dimension
            for j, (images, label) in enumerate(columns):
                ax = axs[i, j]
                im = ax.imshow(images[i])
                ax.set_title(f"{label} idx{idx}", fontsize=16)
                # Attach each colorbar to its own axes (older Matplotlib otherwise steals space from gca())
                fig.colorbar(im, ax=ax, shrink=0.6)
        fig.tight_layout()
        if show_fig:
            plt.show()
    finally:
        if was_interactive:
            plt.ion()

    if pass_fig:
        return fig