# Source code for ptyrad.utils.dev_tools

"""
Developer tools for logging, testing, checking sizes and types, etc.

"""

import ast
import os
from collections import defaultdict

import numpy as np
import torch





def has_nan_or_inf(tensor):
    """
    Check if a tensor contains any NaN or Inf values.

    Parameters:
        tensor (torch.Tensor or np.ndarray): Input to check. Real and complex
            dtypes are supported.

    Returns:
        bool: True if the input contains any NaN or +/-Inf value, False otherwise
            (an empty input returns False).
    """
    if isinstance(tensor, np.ndarray):
        return not bool(np.isfinite(tensor).all())
    # One isfinite pass covers both NaN and +/-Inf. bool() turns the 0-d result
    # tensor into a Python bool so the return type matches the documented contract
    # (`has_nan or has_inf` on tensors evaluates to a 0-d tensor, not a bool).
    return not bool(torch.isfinite(tensor).all())
def get_size_bytes(x):
    """
    Report the memory footprint of a tensor's elements and return it in bytes.

    Prints the tensor's shape, dtype and device, then its size in MiB (below
    128 MiB) or GiB (128 MiB and above).

    Parameters:
        x (torch.Tensor): Tensor to measure.

    Returns:
        int: Number of bytes occupied by the elements (numel * element_size).
    """
    print(f"Input tensor has shape {x.shape}, dtype {x.dtype}, and live on {x.device}")

    one_mib = 1024 ** 2
    one_gib = 1024 ** 3
    total_bytes = x.numel() * x.element_size()

    # Small tensors read better in MiB; switch to GiB from 128 MiB upward.
    if total_bytes >= 128 * one_mib:
        print(f"The size of the tensor is {total_bytes / one_gib:.2f} GiB")
    else:
        print(f"The size of the tensor is {total_bytes / one_mib:.2f} MiB")
    return total_bytes
def check_modes_ortho(tensor, atol=2e-4):
    """
    Print whether each pair of modes along the first axis is mutually orthogonal.

    Each mode ``tensor[i]`` is flattened to 1D. For every pair (i, j) with i < j,
    the complex inner product ``dot(mode_i, conj(mode_j))`` is computed and compared
    to 0 with absolute tolerance ``atol``. Orthogonal vectors have a zero inner
    product; note that orthonormality would additionally require unit length,
    which is not checked here.

    Matlab's dot(p2, p1) implicitly conjugates complex input, but torch.dot does
    not (pointed out by @dong-zehao in issue #11). A plain torch.dot(a, b) would
    therefore give an unintended result for complex modes, so the conjugate is
    applied explicitly.

    Parameters:
        tensor (torch.Tensor or np.ndarray): Modes stacked along dim 0, i.e. shape
            (Nmodes, ...). A numpy array is converted to a torch tensor. The input
            is cast to complex128 before the check and may be non-contiguous.
        atol (float): Absolute tolerance on the inner product for a pair to count
            as orthogonal.

    Returns:
        None. Results are printed, one line per pair.
    """
    # Automatically convert numpy array to torch tensor
    if isinstance(tensor, np.ndarray):
        print("Casting input tensor from 'np.ndarray' to 'torch.tensor'")
        tensor = torch.tensor(tensor)
    tensor = tensor.to(dtype=torch.complex128)
    print(f"Input tensor has shape {tensor.shape} and dtype {tensor.dtype}")

    # torch.dot only accepts 1D tensors. Flatten each mode once, up front.
    # reshape (unlike view) also works when the modes are non-contiguous. That case
    # arises e.g. for a transposed complex128 input, which .to() returns unchanged.
    flat_modes = [mode.reshape(-1) for mode in tensor]

    n_modes = len(flat_modes)
    for i in range(n_modes):
        for j in range(i + 1, n_modes):
            dot_product = torch.dot(flat_modes[i], flat_modes[j].conj())
            abs_dot = dot_product.abs().detach().cpu().numpy()
            if torch.allclose(dot_product, torch.zeros_like(dot_product), atol=atol):
                print(f"Modes {i} and {j} are orthogonal with abs(dot) = {abs_dot}")
            else:
                print(f"Modes {i} and {j} are not orthogonal with abs(dot) = {abs_dot}")
def yaml2json(input_filepath, output_filepath):
    """
    Convert a YAML file to an indented (indent=4) JSON file.

    The YAML is parsed with ``yaml.safe_load``. If parsing fails, the error is
    printed, no output file is written, and the function returns None without
    raising. I/O errors (missing input, unwritable output) still propagate.

    Parameters:
        input_filepath (str or path-like): Path of the YAML file to read.
        output_filepath (str or path-like): Path of the JSON file to write.
    """
    import json

    import yaml

    # Open in binary mode so PyYAML detects the encoding itself (UTF-8 / UTF-16 via
    # BOM), rather than decoding with the platform's locale encoding.
    with open(input_filepath, 'rb') as file:
        try:
            data = yaml.safe_load(file)
        except yaml.YAMLError as e:
            print("Error parsing YAML file:", e)
            return

    with open(output_filepath, 'w', encoding='utf-8') as json_file:
        json.dump(data, json_file, indent=4)
    print(f"YAML {input_filepath} has been successfully converted and saved to JSON {output_filepath}")
# Testing functions
def test_loss_fn(model, indices, loss_fn):
    """
    Print loss values for each term for convenient weight tuning.

    Runs one forward pass without gradient tracking and prints every individual
    loss term. Nothing is returned.

    Parameters:
        model: PtychoAD model. It must be callable as ``model(indices)`` returning
            ``(model_CBEDs, objp_patches)``, and must provide
            ``get_measurements(indices)`` and ``omode_occu``.
        indices: Array-like indices indicating which probe positions to evaluate.
        loss_fn: Loss function object created from CombinedLoss. It must be callable
            as ``loss_fn(model_CBEDs, measured_CBEDs, objp_patches, omode_occu)``
            returning ``(_, losses)``, and must expose ``loss_params``.
    """
    with torch.no_grad():
        model_CBEDs, objp_patches = model(indices)
        measured_CBEDs = model.get_measurements(indices)
        _, losses = loss_fn(model_CBEDs, measured_CBEDs, objp_patches, model.omode_occu)

    # Names are paired with values positionally.
    # NOTE(review): assumes `losses` is ordered like loss_fn.loss_params — confirm.
    # Print loss_name and loss_value with padding
    for loss_name, loss_value in zip(loss_fn.loss_params.keys(), losses):
        print(f"{loss_name.ljust(11)}: {loss_value.detach().cpu().numpy():.8f}")


# This is a manual diagnostic helper, not a unit test. Without this flag pytest
# collects it (because of the `test_` prefix) whenever it is imported into a test
# module, then errors on the missing fixtures `model`, `indices`, `loss_fn`.
test_loss_fn.__test__ = False
def test_constraint_fn(test_model, constraint_fn, plot_forward_pass):
    """
    Test run of the constraint_fn.

    Calls ``constraint_fn(test_model, niter=1)`` once. If ``plot_forward_pass`` is
    not None, it is then called as ``plot_forward_pass(test_model, indices, 0.5)``,
    where ``indices`` are 2 random integers drawn from
    ``[0, len(test_model.measurements))`` via the global numpy RNG.

    The constraint modifies ``test_model`` in place. This function does NOT make a
    copy, so the caller should pass a disposable copy of the real model.

    Parameters:
        test_model: Model to constrain; must expose ``measurements`` (supports len()).
        constraint_fn: Callable invoked as ``constraint_fn(test_model, niter=1)``.
        plot_forward_pass: Optional callable invoked as
            ``plot_forward_pass(test_model, indices, 0.5)``, or None to skip plotting.
    """
    indices = np.random.randint(0, len(test_model.measurements), 2)
    constraint_fn(test_model, niter=1)
    if plot_forward_pass is not None:
        plot_forward_pass(test_model, indices, 0.5)
    # The original `del test_model` only unbound this local name. It freed nothing
    # the caller still references, so it has been dropped.


# This is a manual diagnostic helper, not a unit test. Without this flag pytest
# collects it (because of the `test_` prefix) whenever it is imported into a test
# module, then errors on the missing fixtures.
test_constraint_fn.__test__ = False