Source code for ptyrad.utils.dev_tools

"""
Developer tools for logging, testing, checking sizes and types, etc.

"""

import ast
import os
from collections import defaultdict

import numpy as np





def has_nan_or_inf(tensor):
    """
    Check if a torch.Tensor contains any NaN or Inf values.

    Parameters:
        tensor (torch.Tensor): Input tensor to check.

    Returns:
        bool: True if the tensor contains any NaN or Inf values, False otherwise.
    """
    import torch

    # `a or b` on two 0-dim tensors returns one of the *tensors*, not a Python
    # bool. Coerce each reduction explicitly so the return type matches the
    # documented contract; the `or` then also skips the Inf scan when a NaN
    # has already been found.
    return bool(torch.isnan(tensor).any()) or bool(torch.isinf(tensor).any())
def get_size_bytes(x):
    """
    Report the shape, dtype, device and memory footprint of a tensor.

    The footprint is the number of elements times the per-element size. It is
    printed in MiB when below 128 MiB and in GiB otherwise.

    Parameters:
        x (torch.Tensor): Tensor to inspect.

    Returns:
        int: Size of the tensor data in bytes.
    """
    import torch

    print(f"Input tensor has shape {x.shape}, dtype {x.dtype}, and live on {x.device}")

    mib = 1024 * 1024
    n_bytes = torch.numel(x) * x.element_size()

    # Switch the reporting unit at 128 MiB so large tensors stay readable.
    if n_bytes >= 128 * mib:
        print(f"The size of the tensor is {n_bytes / (1024 * mib):.2f} GiB")
    else:
        print(f"The size of the tensor is {n_bytes / mib:.2f} MiB")

    return n_bytes
def check_modes_ortho(tensor, rtol=1e-3):
    """
    Check whether the modes stacked along the first axis are mutually orthogonal.

    Each mode ``tensor[i]`` is flattened to a 1D vector and every pair (i, j)
    with i < j is compared through its relative overlap
    ``|<a_i, a_j>| / (|a_i| * |a_j|)``. One line is printed per pair.

    Parameters:
        tensor (torch.Tensor or np.ndarray): Modes with shape (Nmodes, ...).
            NumPy arrays are converted to torch tensors; the data is cast to
            complex128 before the check.
        rtol (float): A pair counts as orthogonal when its relative overlap is
            strictly below this value.

    Returns:
        bool: True if every pair passes the tolerance check (also when there
        are fewer than two modes), False otherwise.
    """
    # Orthogonal vectors have a dot product equal to 0 (note that `orthonormal`
    # additionally requires unit length). Because of floating point precision
    # the overlap is compared against a tolerance rather than against 0.
    # Matlab's dot(p2, p1) implicitly applies the complex conjugate for complex
    # input while torch.dot does not, so the conjugate is applied explicitly
    # below to get the correct inner product (pointed out by @dong-zehao in
    # issue #11).
    import torch
    import numpy as np

    # Automatically convert numpy array to torch tensor
    if isinstance(tensor, np.ndarray):
        print("Casting input tensor from 'np.ndarray' to 'torch.tensor'")
        tensor = torch.tensor(tensor)

    tensor = tensor.to(dtype=torch.complex128)
    print(f"Input tensor has shape {tensor.shape} and dtype {tensor.dtype}")

    n_modes = tensor.shape[0]

    # reshape (not view) so that modes with a non-viewable memory layout, e.g.
    # an input whose trailing axes were transposed, are flattened instead of
    # raising a RuntimeError. Both operands are flattened in the same logical
    # order, so the inner product is unaffected.
    flat_modes = [tensor[i].reshape(-1) for i in range(n_modes)]

    # Each norm is needed for every pair the mode takes part in; compute once.
    norms = [torch.linalg.norm(mode) for mode in flat_modes]

    # Master flag, flipped if any pair fails the tolerance check
    is_orthogonal = True

    for i in range(n_modes):
        for j in range(i + 1, n_modes):
            # torch.dot only takes 1D tensors
            dot = torch.dot(flat_modes[i], flat_modes[j].conj())

            # Relative overlap (dimensionless); the epsilon guards against
            # division by zero for all-zero modes.
            rel_overlap = (dot.abs() / (norms[i] * norms[j] + 1e-16)).item()

            if rel_overlap < rtol:
                print(f"Modes {i}, {j} orthogonal: rel_overlap = {rel_overlap:.3e}")
            else:
                print(f"Modes {i}, {j} NOT orthogonal: rel_overlap = {rel_overlap:.3e}")
                is_orthogonal = False

    return is_orthogonal
# Testing functions
def test_loss_fn(model, indices, loss_fn):
    """
    Print loss values for each term for convenient weight tuning.

    Parameters:
        model: PtychoModel model. It is called as ``model(indices)`` and must
            provide ``get_measurements(indices)`` and ``omode_occu``.
        indices: Array-like indices indicating which probe positions to evaluate.
        loss_fn: Loss function object created from CombinedLoss. It is called
            with the modelled patterns, the measured patterns, the object
            patches and ``model.omode_occu``; its second return value is
            iterated as the per-term losses, and ``loss_params`` supplies the
            term names.

    Returns:
        None. The values are only printed.
    """
    import torch

    # Evaluation only, so no gradients are tracked.
    with torch.no_grad():
        simulated_patterns, object_patches = model(indices)
        measured_patterns = model.get_measurements(indices)
        _, per_term_losses = loss_fn(
            simulated_patterns,
            measured_patterns,
            object_patches,
            model.omode_occu,
        )

        # One line per term, with the name padded to a common width.
        for term_name, term_value in zip(loss_fn.loss_params.keys(), per_term_losses):
            padded_name = term_name.ljust(11)
            value_np = term_value.detach().cpu().numpy()
            print(f"{padded_name}: {value_np:.8f}")
    return
def test_constraint_fn(test_model, constraint_fn, plot_forward_pass):
    """
    Test run of the constraint_fn.

    Note that this directly modifies the model that is passed in, so callers
    should hand over a dedicated test model.

    Parameters:
        test_model: Model passed to ``constraint_fn``; must expose a sized
            ``measurements`` attribute.
        constraint_fn: Callable invoked as ``constraint_fn(test_model, niter=1)``.
        plot_forward_pass: Optional callable invoked as
            ``plot_forward_pass(test_model, indices, 0.5)`` after the
            constraints ran; pass None to skip plotting.

    Returns:
        None.
    """
    # Draw two random indices in [0, len(measurements)) before the model is
    # touched, so the draw happens regardless of whether plotting is requested.
    n_measurements = len(test_model.measurements)
    sample_indices = np.random.randint(low=0, high=n_measurements, size=2)

    constraint_fn(test_model, niter=1)

    if plot_forward_pass is not None:
        plot_forward_pass(test_model, sample_indices, 0.5)

    # Only drops this function's local reference to the model.
    del test_model
    return