"""
PtyRAD-specific saving functions
"""
import logging
import os
from typing import Any, Dict
import h5py
import numpy as np
import torch
from tifffile import imwrite
from ptyrad.io.provenance import generate_provenance_json, save_provenance_to_hdf5
from ptyrad.utils.image_proc import normalize_by_bit_depth
from ptyrad.utils.time import get_time
logger = logging.getLogger(__name__)
###### These are results saving functions ######
[docs]
def expand_presets(input_list, presets):
    """Replace preset names in a tag list with the tags they stand for.

    Every entry of `input_list` that is a key of `presets` is swapped for the
    sequence stored under that key; any other entry is kept unchanged.
    Repeated tags are then dropped, keeping only the first occurrence so the
    original ordering is preserved.

    Args:
        input_list (list of str): Tags and/or preset names to expand.
        presets (dict): Maps a preset name to the list of tags it represents.

    Returns:
        list of str: The expanded tags without duplicates, in first-seen order.
    """
    first_seen = {}
    for entry in input_list:
        replacements = presets[entry] if entry in presets else (entry,)
        for tag in replacements:
            # dict keys keep insertion order, so the first occurrence wins
            first_seen.setdefault(tag, None)
    return list(first_seen)
[docs]
def safe_filename(filepath):
    """Ensures a filepath is safe and valid across different operating systems.

    This function prevents path-related crashes by:

    1. Converting relative paths to absolute paths.
    2. Truncating individual directory or file components to 255 characters
       (the file extension is preserved whenever it fits in that budget).
    3. Handling the Windows total path limit by applying the "\\\\?\\"
       extended-length prefix ("\\\\?\\UNC\\" for network shares) if necessary.

    Args:
        filepath (str): The original filepath to sanitize.

    Returns:
        str: A modified, cross-platform-safe absolute filepath.
    """
    import platform

    max_component_len = 255
    # MAX_PATH is 260 *including* the terminating NUL, which leaves 259 usable characters.
    max_windows_path_len = 259
    long_prefix = "\\\\?\\"

    # Store original path for reporting
    original_path = filepath
    filepath = os.path.abspath(filepath)

    is_windows = platform.system() == 'Windows'
    sep = '\\' if is_windows else '/'
    has_long_prefix = is_windows and filepath.startswith(long_prefix)

    # Early return if the path is already valid
    if not has_long_prefix:
        components_valid = all(len(part) <= max_component_len for part in filepath.split(sep))
        length_valid = (not is_windows) or len(filepath) <= max_windows_path_len
        if components_valid and length_valid:
            return filepath

    # Path requires correction
    directory, filename = os.path.split(filepath)
    changes_made = False

    # Limit the filename component to 255 chars
    if len(filename) > max_component_len:
        changes_made = True
        name, ext = os.path.splitext(filename)
        if len(ext) <= max_component_len:
            # Keep the extension and shorten only the stem
            filename = name[:max_component_len - len(ext)] + ext
        else:
            # The "extension" alone exceeds the budget, so a negative slice on the
            # stem would not help; hard-truncate the whole component instead.
            filename = filename[:max_component_len]

    # Limit each directory component to 255 chars
    if directory:
        parts = directory.split(sep)
        for i, part in enumerate(parts):
            if len(part) > max_component_len:
                changes_made = True
                parts[i] = part[:max_component_len]
        directory = sep.join(parts)

    result_path = os.path.join(directory, filename)

    # Windows: paths beyond MAX_PATH need the extended-length prefix
    if (
        is_windows
        and len(result_path) > max_windows_path_len
        and not result_path.startswith(long_prefix)
    ):
        changes_made = True
        absolute = os.path.abspath(result_path)
        if absolute.startswith("\\\\") and not absolute.startswith("\\\\.\\"):
            # UNC share: \\server\share\... must become \\?\UNC\server\share\...
            result_path = long_prefix + "UNC\\" + absolute[2:]
        else:
            result_path = long_prefix + absolute

    # Provide feedback if corrections were made
    if changes_made:
        # Same logger object as the module-level `logger` (both are keyed by __name__)
        log = logging.getLogger(__name__)
        log.info("Path corrected for compatibility:")
        log.info(f" Original: {original_path}")
        log.info(f" Corrected: {result_path}")
    return result_path
[docs]
def make_save_dict(output_path, model, params, optimizer, scheduler, niter, indices, batch_losses):
    """Compiles the model state, parameters, and optimizer data into a dictionary.

    This explicitly extracts and formats the current runtime attributes from the
    model (e.g., actual grid sizes, learning rates) rather than relying solely
    on the initial parameter file, as values may have changed during initialization
    (e.g., cropping, resampling) or interactive walkthroughs.

    The optimizer and scheduler states are only included when 'optim_state' /
    'scheduler_state' are listed in `params['recon_params']['save_result']`;
    otherwise the corresponding entries are None.

    Args:
        output_path (str): The directory path where results will be saved.
        model (PtychoModel): The current reconstruction model object.
        params (dict): The initial configuration dictionary.
        optimizer (torch.optim.Optimizer): The PyTorch optimizer.
        scheduler (torch.optim.lr_scheduler.LRScheduler or None): The LR scheduler, or None if not used.
        niter (int): The current iteration number.
        indices (list or numpy.ndarray): The batch indices used in the current iteration.
        batch_losses (dict): A dictionary mapping loss names to lists of batch loss values.

    Returns:
        dict: A comprehensive dictionary ready for HDF5 serialization.
    """
    # Same fallback as `save_results`, so a params dict without 'save_result' does not raise KeyError here
    save_result_list = params['recon_params'].get('save_result', ['model', 'obj', 'probe'])

    avg_losses = {name: np.mean(values) for name, values in batch_losses.items()}
    avg_iter_t = np.mean(model.iter_times)

    # While it might seem redundant to save both `params` and lots of `model_attributes`,
    # one should note that `params` only stores the initial value from params files,
    # the actual values used for reconstruction such as N_scan_slow, N_scan_fast, dx, dk, Npix, N_scans
    # could be different from the initial value due to the meas_crop, meas_resample.
    # The model behavior and learning rates could also be different from the initial params dict if the user
    # runs the reconstruction with manually modified `model_params` in the detailed walkthrough notebook.

    # Detach every optimizable tensor; the probe is stored through its complex view
    optimizable_tensors = {}
    for name, tensor in model.optimizable_tensors.items():
        if name == 'probe':
            optimizable_tensors[name] = model.get_complex_probe_view().detach().clone()
        else:
            optimizable_tensors[name] = tensor.detach().clone()

    # Postprocess the scheduler_state_dict
    if scheduler is not None and 'scheduler_state' in save_result_list:
        scheduler_state_dict = scheduler.state_dict().copy()
        # Remove the monkey-patched closure — it's not serializable and not part of resumable state.
        # `pop` keeps this safe for schedulers whose state dict has no 'step' entry.
        scheduler_state_dict.pop('step', None)
    else:
        scheduler_state_dict = None

    from ptyrad import __version__ as ptyrad_version

    save_dict = {
        'ptyrad_version'       : ptyrad_version,
        'output_path'          : output_path,
        'optimizable_tensors'  : optimizable_tensors,
        'optim_state_dict'     : optimizer.state_dict() if 'optim_state' in save_result_list else None,
        'scheduler_state_dict' : scheduler_state_dict,
        'params'               : params,
        # Explicit saving because only specific fields are wanted, not the entire model with grids and other redundant info
        'model_attributes': {
            'detector_blur_std': model.detector_blur_std,
            'start_iter'       : model.start_iter,
            'lr_params'        : model.lr_params,
            'omode_occu'       : model.omode_occu,
            'H'                : model.H,
            'N_scan_slow'      : model.N_scan_slow,
            'N_scan_fast'      : model.N_scan_fast,
            'crop_pos'         : model.crop_pos,
            'slice_thickness'  : model.slice_thickness,
            'dx'               : model.dx,
            'dk'               : model.dk,
            'lambd'            : model.lambd,
            'meas_Npix'        : model.meas_Npix,
            'simu_Npix'        : model.simu_Npix,
            'simu_match_mode'  : model.simu_match_mode,
            'scan_affine'      : model.scan_affine,
            'tilt_obj'         : model.tilt_obj,
            'probe_int_sum'    : model.probe_int_sum,
        },
        'random_seed'          : model.random_seed,
        'loss_iters'           : model.loss_iters,
        'iter_times'           : model.iter_times,
        'dz_iters'             : model.dz_iters,
        'lr_iters'             : model.lr_iters,
        'avg_tilt_iters'       : model.avg_tilt_iters,
        'convergence_iters'    : dict(model.convergence_iters),
        'avg_iter_t'           : avg_iter_t,
        'niter'                : niter,
        'indices'              : indices,
        'batch_losses'         : batch_losses,
        'avg_losses'           : avg_losses,
    }
    return save_dict
[docs]
def save_dict_to_hdf5(
    d: Dict[str, Any], output_path: str, none_sentinel: str = "__NONE__", **kwargs
) -> None:
    """Saves a nested Python dictionary to an HDF5 file.

    Recursively parses the dictionary and converts common Python, NumPy, and
    PyTorch types into HDF5-compatible formats. Non-compatible types (e.g.,
    lists of tuples, `None`) are converted to safe string representations or sentinels.
    Integer keys (common in optimizer state dicts) are coerced to strings.

    Args:
        d (dict): The nested dictionary to serialize.
        output_path (str): The target file path for the `.hdf5` file. The file
            is opened in "w" mode, so an existing file is overwritten.
        none_sentinel (str, optional): The string used to represent `None`
            values in the HDF5 file. Defaults to "__NONE__".
        **kwargs: Additional keyword arguments passed to `h5py.File.create_dataset()`
            (e.g., `compression="gzip"`). Chunking/filter options are applied to
            array datasets only, because h5py rejects them for scalar datasets.

    Raises:
        RuntimeError: If a value cannot be written; the message contains the
            key and its full path inside the nested dictionary.
    """
    # h5py raises TypeError when these options are given for a scalar dataset,
    # so they are stripped for scalars (None sentinel, numbers, strings, str() fallbacks).
    array_only_options = (
        "chunks", "compression", "compression_opts", "shuffle",
        "fletcher32", "scaleoffset", "maxshape",
    )
    scalar_kwargs = {name: opt for name, opt in kwargs.items() if name not in array_only_options}

    def _create_dataset(h5group: h5py.Group, key: str, data: Any) -> None:
        options = scalar_kwargs if np.ndim(data) == 0 else kwargs
        h5group.create_dataset(key, data=data, **options)

    def _recursively_save_dict_to_hdf5(d: Dict[str, Any], h5group: h5py.Group, path="") -> None:
        for key, value in d.items():
            key = str(key)  # convert to string for HDF5, especially important for optimizer state dict with integer as key
            full_key = f"{path}/{key}" if path else key
            try:
                # Delete existing group/dataset if it exists
                if key in h5group:
                    del h5group[key]
                if value is None:
                    _create_dataset(h5group, key, none_sentinel)
                elif isinstance(value, dict):
                    subgroup = h5group.create_group(key)
                    # Propagate the path so nested failures report where they happened
                    _recursively_save_dict_to_hdf5(value, subgroup, full_key)
                elif isinstance(value, list):
                    if all(isinstance(i, (int, float, np.number)) for i in value):
                        _create_dataset(h5group, key, np.array(value))
                    elif all(isinstance(i, str) for i in value):
                        dt = h5py.special_dtype(vlen=str)
                        _create_dataset(h5group, key, np.array(value, dtype=dt))
                    elif all(isinstance(i, tuple) for i in value):
                        try:
                            arr = np.array([list(t) for t in value])
                            _create_dataset(h5group, key, arr)
                        except Exception:
                            # Ragged or non-numeric tuples: keep a readable string instead
                            _create_dataset(h5group, key, str(value))
                    elif all(isinstance(i, dict) for i in value):
                        # A list of dicts becomes a group with one sub-group per index
                        subgroup = h5group.create_group(key)
                        for idx, item in enumerate(value):
                            item_group = subgroup.create_group(str(idx))
                            _recursively_save_dict_to_hdf5(item, item_group, f"{full_key}/{idx}")
                    elif all(isinstance(i, (np.ndarray, torch.Tensor)) for i in value):
                        try:
                            arr = np.stack([i.detach().cpu().numpy() if isinstance(i, torch.Tensor) else i for i in value])
                            _create_dataset(h5group, key, arr)
                        except Exception:
                            # Arrays that cannot be stacked: keep a readable string instead
                            _create_dataset(h5group, key, str(value))
                    else:
                        # fallback to storing list as strings (warn if needed)
                        _create_dataset(h5group, key, str(value))
                elif isinstance(value, tuple):
                    _create_dataset(h5group, key, np.array(value))
                elif isinstance(value, (int, float, str, np.number)):
                    _create_dataset(h5group, key, value)
                elif isinstance(value, torch.Tensor):
                    _create_dataset(h5group, key, value.detach().cpu().numpy())
                elif isinstance(value, np.ndarray):
                    _create_dataset(h5group, key, value)
                # Fallback option
                else:
                    _create_dataset(h5group, key, str(value))
            except Exception as e:
                raise RuntimeError(f"Failed to save key '{key}' (full path: '{full_key}') of type {type(value)}") from e

    with h5py.File(output_path, "w") as hf:
        _recursively_save_dict_to_hdf5(d, hf)
[docs]
def make_output_folder(
    output_dir,
    indices,
    init_params,
    recon_params,
    model,
    constraint_params,
    loss_params,
    recon_dir_affixes=None,
):
    """Generates a highly descriptive output folder name based on runtime configurations.

    Constructs a detailed directory name by concatenating abbreviations of the
    active reconstruction parameters, model attributes, constraints, and losses.
    This allows users to identify the exact settings of a run just by looking
    at the folder name. The folder is created if it does not exist yet.

    Args:
        output_dir (str): The base directory where the output folder will be created.
        indices (list): The list of indices used in the reconstruction.
        init_params (dict): Initialization parameters.
        recon_params (dict): Reconstruction parameters containing naming prefixes/postfixes.
        model (PtychoModel): The model containing current spatial and optimization attributes.
        constraint_params (dict): Dictionary of active constraints and filters.
        loss_params (dict): Dictionary of active loss functions and weights.
        recon_dir_affixes (list of str, optional): A list of tags or preset keys
            ("minimal", "default", "all") dictating which parameters to include
            in the folder name. None (the default) is treated as ["default"].

    Returns:
        str: The sanitized absolute path to the newly generated output folder.
    """
    # Avoid a mutable default argument; None stands for the "default" preset
    if recon_dir_affixes is None:
        recon_dir_affixes = ["default"]

    prefix_time = recon_params.get("prefix_time", False)
    prefix = recon_params.get("prefix", "")
    postfix = recon_params.get("postfix", "")
    parts = []

    recon_dir_presets = {
        "minimal": ['indices', 'meas', 'batch', 'pmode', 'omode', 'nlayer'],
        "default": ['indices', 'meas', 'batch', 'pmode', 'omode', 'nlayer',
                    'lr', 'model', 'constraint',
                    'loss', 'affine', 'tilt', 'aberrations'],
        "all": ['indices', 'meas', 'batch', 'pmode', 'omode', 'nlayer',
                'optimizer', 'scheduler', 'start_iter', 'lr', 'model', 'constraint',
                'loss', 'conv_angle', 'aberrations', 'Ls', 'z_shift', 'dx', 'affine', 'tilt'],
    }

    # Process recon_dir_affixes to expand presets
    if any(tag in recon_dir_presets for tag in recon_dir_affixes):
        logger.info(f"Original recon_dir_affixes = {recon_dir_affixes}")
        recon_dir_affixes = expand_presets(recon_dir_affixes, recon_dir_presets)
        logger.info(f"Expanded recon_dir_affixes = {recon_dir_affixes}")

    # Attach time string if prefix_time is true or non-empty str
    if prefix_time is True or (isinstance(prefix_time, str) and prefix_time):
        time_str = get_time(prefix_time)  # e.g. '20250606'
        parts.append(time_str)

    # Attach prefix (only if prefix is non-empty str)
    if isinstance(prefix, str) and prefix:
        parts.append(prefix)

    # Attach indices mode (optional)
    if "indices" in recon_dir_affixes:
        indices_mode = recon_params["INDICES_MODE"].get("mode")
        parts.append(f"{indices_mode}_N{len(indices)}")

    # Attach DP size and meas flip (optional)
    if "meas" in recon_dir_affixes:
        dp_size = model.get_complex_probe_view().size(-1)
        parts.append(f"dp{dp_size}")
        meas_flipT = init_params["meas_flipT"]
        # Attach meas flipping
        if meas_flipT is not None:  # Note that [0,0,0] will be attached if specified, for clarity
            flipT_str = "flipT" + "".join(str(x) for x in meas_flipT)
            parts.append(flipT_str)

    # Attach group mode and batch size (optional)
    if "batch" in recon_dir_affixes:
        group_mode = recon_params["GROUP_MODE"]
        batch_size = recon_params["BATCH_SIZE"].get("size")
        grad_accum = recon_params["BATCH_SIZE"].get("grad_accumulation", 1)
        batch_size *= grad_accum  # Affix the effective batch size
        parts.append(f"{group_mode}{batch_size}")

    # Attach pmode (optional)
    if "pmode" in recon_dir_affixes:
        pmode = model.get_complex_probe_view().size(0)
        parts.append(f"p{pmode}")

    # Attach omode (optional)
    if "omode" in recon_dir_affixes:
        omode = model.opt_objp.size(0)
        parts.append(f"{omode}obj")

    # Attach obj Nlayer and dz (optional)
    if "nlayer" in recon_dir_affixes:
        nlayer = model.opt_objp.size(1)
        parts.append(f"{nlayer}slice")
        if nlayer != 1:
            # This is the initialized slice thickness
            slice_thickness = model.slice_thickness.detach().cpu().numpy()
            parts.append(f"dz{slice_thickness:.3g}")

    # Attach optimizer name (optional)
    if "optimizer" in recon_dir_affixes:
        optimizer_str = model.optimizer_params["name"]
        parts.append(f"{optimizer_str}")

    # Attach scheduler name (optional)
    if "scheduler" in recon_dir_affixes and model.scheduler_params is not None:
        scheduler_str = model.scheduler_params["name"]
        parts.append(f"{scheduler_str}")

    # Attach start_iter (optional)
    if "start_iter" in recon_dir_affixes:
        start_iter_map = {
            "probe": "ps",
            "obja": "oas",
            "objp": "ops",
            "probe_pos_shifts": "ss",
            "obj_tilts": "ts",
            "slice_thickness": "dzs",
        }
        for key, tag in start_iter_map.items():
            start_val = model.start_iter.get(key)
            if start_val is not None and start_val > 1:
                parts.append(f"{tag}{start_val}")

    # Attach learning rate (optional)
    if "lr" in recon_dir_affixes:
        lr_map = {
            "probe": "plr",
            "obja": "oalr",
            "objp": "oplr",
            "probe_pos_shifts": "slr",
            "obj_tilts": "tlr",
            "slice_thickness": "dzlr",
        }
        for key, tag in lr_map.items():
            # A missing or None learning rate is skipped, mirroring the start_iter handling,
            # instead of aborting the run while merely naming the folder.
            try:
                lr_val = model.lr_params[key]
            except KeyError:
                continue
            if lr_val is not None and lr_val != 0:
                lr_str = format(lr_val, ".0e").replace("e-0", "e-")
                parts.append(f"{tag}{lr_str}")

    # Attach model params (optional)
    if "model" in recon_dir_affixes:
        if model.detector_blur_std is not None and model.detector_blur_std != 0:
            parts.append(f"dpblur{model.detector_blur_std}")

    # Attach constraint params (optional)
    if "constraint" in recon_dir_affixes:
        parts.extend(_constraint_affixes(constraint_params))

    # Attach loss params (optional)
    if "loss" in recon_dir_affixes:
        loss_map = {
            "loss_single": ("sng", 2),
            "loss_poissn": ("psn", 2),
            "loss_pacbed": ("pcb", 2),
            "loss_sparse": ("spr", 2),
            "loss_simlar": ("sml", 2),
        }
        for key, (tag, digits) in loss_map.items():
            loss = loss_params.get(key, {})
            if loss.get("state"):
                parts.append(f"{tag}{round(loss.get('weight', 0), digits)}")

    # Attach conv_angle (optional)
    if "conv_angle" in recon_dir_affixes and "probe_conv_angle" in init_params:
        parts.append(f"ca{init_params['probe_conv_angle']:.3g}")

    # Attach aberrations (optional)
    if "aberrations" in recon_dir_affixes and "probe_aberrations" in init_params:
        init_aberrations = init_params.get("probe_aberrations")
        if init_aberrations:
            for k, v in init_aberrations.items():
                # Note that the user values are canonicalized to Krivanek polar during load_params
                parts.append(f"{k}_{v:.3g}")

    # Attach probe_Ls (optional, for xray only)
    if "Ls" in recon_dir_affixes and "probe_Ls" in init_params:
        init_Ls = init_params.get("probe_Ls")
        if init_Ls is not None:
            parts.append(f"Ls{init_Ls * 1e9:.0f}")

    # Attach probe_z_shift (optional)
    if "z_shift" in recon_dir_affixes and "probe_z_shift" in init_params:
        init_z_shift = init_params.get("probe_z_shift")
        if init_z_shift is not None and init_z_shift != 0:
            parts.append(f"z_shift{init_z_shift:.3g}")

    # Attach dx (optional)
    if "dx" in recon_dir_affixes:
        dx = model.dx.detach().cpu().numpy()
        parts.append(f"dx{dx:.4g}")

    # Attach scan_affine (optional)
    if "affine" in recon_dir_affixes:
        scan_affine = model.scan_affine  # Note that scan_affine could be None
        if scan_affine is not None and not np.allclose(scan_affine, [1, 0, 0, 0]):
            formats = [".2f", ".2f", ".1f", ".1f"]  # customize per index
            formatted = [format(x, fmt) for x, fmt in zip(scan_affine, formats)]
            affine_str = "aff" + "_".join(formatted)  # (4,)
            parts.append(f"{affine_str}")

    # Attach init tilts (optional)
    if "tilt" in recon_dir_affixes:
        # (2,) regardless tilt_type = 'all' or 'each'
        init_tilts = model.opt_obj_tilts.mean(0).detach().cpu().numpy()
        if np.any(init_tilts):
            parts.append(f"tilt{init_tilts[0]:.2g}_{init_tilts[1]:.2g}")

    # Attach postfix (only if postfix is non-empty str)
    if isinstance(postfix, str) and postfix:
        parts.append(postfix)

    # Make output folder
    output_path = os.path.join(output_dir, "_".join(parts)) if parts else output_dir
    output_path = safe_filename(output_path)
    os.makedirs(output_path, exist_ok=True)
    logger.info(f"output_path = '{output_path}' is generated!")
    return output_path


def _constraint_affixes(constraint_params):
    """Return the folder-name tags of every constraint whose `start_iter` is not None."""
    obj_prefix = {"both": "o", "amplitude": "oa", "phase": "op"}
    tags = []

    kr_filter = constraint_params["kr_filter"]
    if kr_filter["start_iter"] is not None:
        kr_str = {"both": "kr", "amplitude": "kra", "phase": "krp"}.get(kr_filter["obj_type"])
        tags.append(f"{kr_str}f{kr_filter['radius']}")

    kz_filter = constraint_params["kz_filter"]
    if kz_filter["start_iter"] is not None:
        kz_str = {"both": "kz", "amplitude": "kza", "phase": "kzp"}.get(kz_filter["obj_type"])
        tags.append(f"{kz_str}f{kz_filter['beta']}")

    kr_thresh = constraint_params["kr_thresh"]
    if kr_thresh["start_iter"] is not None:
        krt_str = {"both": "krt", "amplitude": "krta", "phase": "krtp"}.get(kr_thresh["obj_type"])
        tags.append(f"{krt_str}{kr_thresh['thresh']}")

    # Real-space and z blurs share the same naming scheme; a zero std means "off"
    for name, label in (("obj_rblur", "rblur"), ("obj_zblur", "zblur")):
        blur = constraint_params[name]
        if blur["start_iter"] is not None and blur["std"] != 0:
            obj_str = obj_prefix.get(blur["obj_type"])
            tags.append(f"{obj_str}{label}{blur['std']}")

    complex_ratio = constraint_params["complex_ratio"]
    if complex_ratio["start_iter"] is not None:
        obj_str = obj_prefix.get(complex_ratio["obj_type"])
        alpha1 = round(complex_ratio["alpha1"], 2)
        alpha2 = round(complex_ratio["alpha2"], 2)
        tags.append(f"{obj_str}cplx{alpha1}_{alpha2}")

    mirrored_amp = constraint_params["mirrored_amp"]
    if mirrored_amp["start_iter"] is not None:
        scale = round(mirrored_amp["scale"], 2)
        power = round(mirrored_amp["power"], 2)
        tags.append(f"mamp{scale}_{power}")

    if constraint_params["obj_z_recenter"]["start_iter"] is not None:
        tags.append("ozrec")

    obja_thresh = constraint_params["obja_thresh"]
    if obja_thresh["start_iter"] is not None:
        tags.append(f"oathr{round(obja_thresh['thresh'][0], 2)}")

    objp_postiv = constraint_params["objp_postiv"]
    if objp_postiv["start_iter"] is not None:
        mode = objp_postiv.get("mode", "clip_neg")
        mode_str = "s" if mode == "subtract_min" else "c"
        relax = objp_postiv["relax"]
        relax_str = "" if relax == 0 else f"{round(relax, 2)}"
        tags.append(f"opos{mode_str}{relax_str}")

    tilt_smooth = constraint_params["tilt_smooth"]
    if tilt_smooth["start_iter"] is not None:
        tags.append(f"tsm{round(tilt_smooth['std'], 2)}")

    probe_mask_k = constraint_params["probe_mask_k"]
    if probe_mask_k["start_iter"] is not None:
        tags.append(f"pmk{round(probe_mask_k['radius'], 2)}")

    probe_mask_r = constraint_params["probe_mask_r"]
    if probe_mask_r["start_iter"] is not None:
        tags.append(f"pmr{round(probe_mask_r['radius'], 2)}")

    return tags
[docs]
def save_results(output_path, model, params, optimizer, scheduler, niter, indices, batch_losses, collate_str=''):
    """Write the reconstruction state and the rendered images to `output_path`.

    What gets written is driven by `params['recon_params']`:

    * `save_result` (defaults to ['model', 'obj', 'probe']) selects the outputs:
      'model' writes the full state to an HDF5 file and attaches provenance,
      'probe' / 'probe_prop' write probe amplitude images, 'obj' / 'objp' /
      'object' write object phase images and 'obja' writes object amplitude images.
    * `result_modes` selects the renderings: every requested 'bit' depth is
      combined with every requested 'FOV' ('crop' restricts the object to the
      scanned region, anything else keeps the full array) and every requested
      'obj_dim' (2, 3 or 4 dimensional outputs such as sums, stacks, or the raw 4D array).

    Args:
        output_path (str): The directory where files will be written.
        model (PtychoModel): The reconstruction model.
        params (dict): The configuration dictionary dictating save preferences.
        optimizer (torch.optim.Optimizer): The active optimizer.
        scheduler (torch.optim.lr_scheduler.LRScheduler or None): The active LR scheduler, or None if not used.
        niter (int): The current iteration number.
        indices (list or numpy.ndarray): The batch indices.
        batch_losses (dict): The recorded loss values.
        collate_str (str, optional): An optional string injected into the
            filenames (useful for distinguishing multiple concurrent outputs).
            Defaults to ''.
    """
    recon_params = params['recon_params']
    save_result_list = recon_params.get('save_result', ['model', 'obj', 'probe'])
    result_modes = recon_params.get('result_modes')
    iter_str = f"_iter{str(niter).zfill(4)}"

    def write_tif(stem, suffix, image, bit):
        # Sanitize the target path, rescale to the requested bit depth, then write
        target = safe_filename(os.path.join(output_path, f"{stem}{suffix}.tif"))
        imwrite(target, normalize_by_bit_depth(image, bit))

    def export_object(rules, volume, suffix, bit):
        # Each rule is (dimensionality, file stem, reduction applied to the (omode, Z, Y, X) volume)
        for dim in result_modes['obj_dim']:
            for rule_dim, stem, reduce_fn in rules:
                if dim == rule_dim:
                    write_tif(stem, suffix, reduce_fn(volume), bit)

    # Full model state + provenance
    if 'model' in save_result_list:
        hdf5_file_path = safe_filename(os.path.join(output_path, f"model{collate_str}{iter_str}.hdf5"))
        state = make_save_dict(output_path, model, params, optimizer, scheduler, niter, indices, batch_losses)
        save_dict_to_hdf5(state, hdf5_file_path)
        provenance_json_str = generate_provenance_json(
            current_provenance=model.recon_provenance, params=params, output_filename=hdf5_file_path
        )
        save_provenance_to_hdf5(hdf5_file_path, provenance_json_str)

    # Probe amplitudes: (pmode, Y, X) -> Permute (Y, pmode, X) -> Flatten (Y, pmode*X)
    probe = model.get_complex_probe_view()
    probe_amp = probe.permute(1, 0, 2).flatten(1).abs().detach().cpu().numpy()
    # np.array([0]) instead of [0] for indices is more consistent with types and safer with torch.compile
    probe_prop = model.get_propagated_probe(np.array([0]))
    # (Z, pmode, Y, X) -> Permute (Z, Y, pmode, X) -> Flatten (Z, Y, pmode*X)
    probe_prop_amp = probe_prop.permute(0, 2, 1, 3).flatten(2).abs().detach().cpu().numpy()

    objp = model.opt_objp.detach().cpu().numpy()
    obja = model.opt_obja.detach().cpu().numpy()
    # omode_occu = model.omode_occu # Currently not used but we'll need it when omode_occu != 'uniform'
    omode = model.opt_objp.size(0)
    zslice = model.opt_objp.size(1)

    # Bounding box of the selected crop positions, offset by half the probe size
    crop_pos = model.crop_pos[indices].detach().cpu().numpy() + np.array(probe.shape[-2:])//2
    y_min, y_max = crop_pos[:, 0].min(), crop_pos[:, 0].max()
    x_min, x_max = crop_pos[:, 1].min(), crop_pos[:, 1].max()

    # The available renderings depend only on how many object modes and slices exist
    if omode == 1 and zslice == 1:
        layout = 'single'
    elif omode == 1 and zslice > 1:
        layout = 'zstack'
    elif omode > 1 and zslice == 1:
        layout = 'ostack'
    else:
        layout = '4D'

    # TODO: For omode_occu != 'uniform', we should do a weighted sum across omode instead
    objp_rules = {
        'single': [
            (2, "objp", lambda a: a[0, 0]),
        ],
        'zstack': [
            (3, "objp_zstack", lambda a: a[0, :]),
            (2, "objp_zsum", lambda a: a[0, :].sum(0)),
        ],
        'ostack': [
            (3, "objp_ostack", lambda a: a[:, 0]),
            (2, "objp_omean", lambda a: a[:, 0].mean(0)),
            (2, "objp_ostd", lambda a: a[:, 0].std(0)),
        ],
        '4D': [
            (4, "objp_4D", lambda a: a[:, :]),
            (3, "objp_ostack_zsum", lambda a: a[:, :].sum(1)),
            (3, "objp_omean_zstack", lambda a: a[:, :].mean(0)),
            (2, "objp_omean_zsum", lambda a: a[:, :].mean(0).sum(0)),
        ],
    }[layout]

    # TODO: For omode_occu != 'uniform', we should do a weighted sum across omode instead
    obja_rules = {
        'single': [
            (2, "obja", lambda a: a[0, 0]),
        ],
        'zstack': [
            (3, "obja_zstack", lambda a: a[0, :]),
            (2, "obja_zmean", lambda a: a[0, :].mean(0)),
            (2, "obja_zprod", lambda a: a[0, :].prod(0)),
        ],
        'ostack': [
            (3, "obja_ostack", lambda a: a[:, 0]),
            (2, "obja_omean", lambda a: a[:, 0].mean(0)),
            (2, "obja_ostd", lambda a: a[:, 0].std(0)),
        ],
        '4D': [
            (4, "obja_4D", lambda a: a[:, :]),
            (3, "obja_ostack_zmean", lambda a: a[:, :].mean(1)),
            (3, "obja_ostack_zprod", lambda a: a[:, :].prod(1)),
            (3, "obja_omean_zstack", lambda a: a[:, :].mean(0)),
            (2, "obja_omean_zmean", lambda a: a[:, :].mean(0).mean(0)),
            (2, "obja_omean_zprod", lambda a: a[:, :].mean(0).prod(0)),
        ],
    }[layout]

    save_phase = any(keyword in save_result_list for keyword in ('obj', 'objp', 'object'))
    save_amplitude = 'obja' in save_result_list
    bit_suffixes = {'8': '_08bit', '16': '_16bit', '32': '_32bit'}

    for bit in result_modes['bit']:
        # 'raw' and any unrecognised value carry no bit-depth tag in the filename
        bit_str = bit_suffixes.get(bit, '')

        if 'probe' in save_result_list:
            write_tif("probe_amp", f"{bit_str}{collate_str}{iter_str}", probe_amp, bit)
        if 'probe_prop' in save_result_list:
            write_tif("probe_prop_amp", f"{bit_str}{collate_str}{iter_str}", probe_prop_amp, bit)

        for fov in result_modes['FOV']:
            if fov == 'crop':
                fov_str = '_crop'
                objp_crop = objp[:, :, y_min:y_max+1, x_min:x_max+1]
                obja_crop = obja[:, :, y_min:y_max+1, x_min:x_max+1]
            else:
                # 'full' and any unrecognised value keep the whole field of view
                fov_str = ''
                objp_crop = objp
                obja_crop = obja

            postfix_str = fov_str + bit_str + collate_str + iter_str

            if save_phase:
                export_object(objp_rules, objp_crop, postfix_str, bit)
            if save_amplitude:
                export_object(obja_rules, obja_crop, postfix_str, bit)