# Source code for ptyrad.io.save

"""
PtyRAD-specific saving functions

"""

import logging
import os
from typing import Any, Dict

import h5py
import numpy as np
import torch
from tifffile import imwrite

from ptyrad.io.provenance import generate_provenance_json, save_provenance_to_hdf5
from ptyrad.utils.image_proc import normalize_by_bit_depth
from ptyrad.utils.time import get_time

logger = logging.getLogger(__name__)

###### These are results saving functions ######

def expand_presets(input_list, presets):
    """Expand preset names inside a tag list and drop repeated tags.

    Every entry of ``input_list`` that is a key of ``presets`` is replaced, at
    its position, by the tags that preset maps to. Any other entry is kept
    unchanged. When a tag appears more than once, only its first occurrence is
    kept, so the result preserves first-seen order.

    Args:
        input_list (list of str): Tags and/or preset names to expand.
        presets (dict): Mapping from a preset name to the list of tags it stands for.

    Returns:
        list of str: The expanded, de-duplicated tag list.
    """
    # A dict keeps insertion order, so it doubles as an "ordered set".
    first_seen = {}
    for entry in input_list:
        replacement = presets[entry] if entry in presets else (entry,)
        for tag in replacement:
            first_seen.setdefault(tag, None)
    return list(first_seen)
def safe_filename(filepath):
    """Ensures a filepath is safe and valid across different operating systems.

    This function prevents path-related crashes by:
    1. Converting relative paths to absolute paths.
    2. Truncating individual directory or file components to 255 characters
       (the file extension is preserved when possible).
    3. Handling the Windows MAX_PATH limit. MAX_PATH is 260 characters
       *including* the terminating NUL, so at most 259 visible characters are
       allowed. Longer paths get the "\\\\?\\" extended-length prefix, or
       "\\\\?\\UNC\\" for network (UNC) shares.

    Args:
        filepath (str): The original filepath to sanitize.

    Returns:
        str: A modified, cross-platform-safe absolute filepath.
    """
    import platform

    max_component_len = 255
    # MAX_PATH (260) counts the trailing NUL, so 259 visible characters is the real limit.
    max_windows_path_len = 259
    long_prefix = "\\\\?\\"  # the \\?\ extended-length prefix
    unc_long_prefix = "\\\\?\\UNC\\"  # the \\?\UNC\ prefix for \\server\share paths

    # Store original path for reporting
    original_path = filepath

    # Handle relative paths by converting to absolute
    filepath = os.path.abspath(filepath)

    is_windows = platform.system() == 'Windows'
    sep = '\\' if is_windows else '/'

    # A path that already carries the long-path prefix skips the quick validity check
    has_long_prefix = is_windows and filepath.startswith(long_prefix)

    # Early return if path is already valid
    if not has_long_prefix:
        components_valid = all(len(part) <= max_component_len for part in filepath.split(sep))
        length_valid = (len(filepath) <= max_windows_path_len) if is_windows else True
        if components_valid and length_valid:
            return filepath

    # Path requires correction
    directory, filename = os.path.split(filepath)
    changes_made = False

    # Limit filename component to 255 chars (preserve extension when it fits)
    if len(filename) > max_component_len:
        changes_made = True
        name, ext = os.path.splitext(filename)
        if len(ext) > max_component_len:
            # Pathological "extension" longer than the limit: cut the whole name instead
            filename = filename[:max_component_len]
        else:
            filename = name[:max_component_len - len(ext)] + ext

    # Handle directory components (limit each to 255 chars)
    if directory:
        parts = directory.split(sep)
        for i, part in enumerate(parts):
            if len(part) > max_component_len:
                changes_made = True
                parts[i] = part[:max_component_len]
        directory = sep.join(parts)

    result_path = os.path.join(directory, filename)

    # Handle Windows total path length; a path that is already prefixed needs nothing more
    if (
        is_windows
        and len(result_path) > max_windows_path_len
        and not result_path.startswith(long_prefix)
    ):
        changes_made = True
        absolute = os.path.abspath(result_path)
        if absolute.startswith("\\\\"):
            # UNC share: \\server\share\... must become \\?\UNC\server\share\...
            result_path = unc_long_prefix + absolute[2:]
        else:
            result_path = long_prefix + absolute

    # Provide feedback if corrections were made
    if changes_made:
        logger.info("Path corrected for compatibility:")
        logger.info(f"  Original: {original_path}")
        logger.info(f"  Corrected: {result_path}")

    return result_path
def make_save_dict(output_path, model, params, optimizer, scheduler, niter, indices, batch_losses):
    """Compiles the model state, parameters, and optimizer data into a dictionary.

    This explicitly extracts and formats the current runtime attributes from the
    model (e.g., actual grid sizes, learning rates) rather than relying solely on
    the initial parameter file. Those values may have changed during
    initialization (e.g., cropping, resampling) or interactive walkthroughs.

    Args:
        output_path (str): The directory path where results will be saved.
        model (PtychoModel): The current reconstruction model object.
        params (dict): The initial configuration dictionary. The optional
            ``params['recon_params']['save_result']`` list controls whether the
            optimizer ('optim_state') and scheduler ('scheduler_state') states are
            included. It defaults to ``['model', 'obj', 'probe']``, matching
            ``save_results``.
        optimizer (torch.optim.Optimizer): The PyTorch optimizer.
        scheduler (torch.optim.lr_scheduler.LRScheduler or None): The active LR
            scheduler, or None if not used.
        niter (int): The current iteration number.
        indices (list or numpy.ndarray): The batch indices used in the current iteration.
        batch_losses (dict): A dictionary mapping loss names to lists of batch loss values.

    Returns:
        dict: A comprehensive dictionary ready for HDF5 serialization.
    """
    # Same default as save_results, so a config without 'save_result' does not raise KeyError here
    save_result_list = params['recon_params'].get('save_result', ['model', 'obj', 'probe'])

    avg_losses = {name: np.mean(values) for name, values in batch_losses.items()}
    avg_iter_t = np.mean(model.iter_times)

    # While it might seem redundant to save both `params` and lots of `model_attributes`,
    # note that `params` only stores the initial values from the params file.
    # The actual values used for reconstruction, such as N_scan_slow, N_scan_fast, dx, dk,
    # Npix, N_scans, could differ from the initial values due to meas_crop and meas_resample.
    # The model behavior and learning rates could also differ from the initial params dict
    # if the user ran the reconstruction with manually modified `model_params` in the
    # detailed walkthrough notebook.

    # Snapshot the optimizable tensors; the probe is stored as its complex view
    optimizable_tensors = {}
    for name, tensor in model.optimizable_tensors.items():
        if name == 'probe':
            optimizable_tensors[name] = model.get_complex_probe_view().detach().clone()
        else:
            optimizable_tensors[name] = tensor.detach().clone()

    # Postprocess the scheduler state dict
    if scheduler is not None and 'scheduler_state' in save_result_list:
        scheduler_state_dict = scheduler.state_dict().copy()
        # Remove the monkey-patched `step` closure if present: it is not serializable
        # and not part of the resumable state. pop() tolerates an unpatched scheduler.
        scheduler_state_dict.pop('step', None)
    else:
        scheduler_state_dict = None

    from ptyrad import __version__ as ptyrad_version

    save_dict = {
        'ptyrad_version'       : ptyrad_version,
        'output_path'          : output_path,
        'optimizable_tensors'  : optimizable_tensors,
        'optim_state_dict'     : optimizer.state_dict() if 'optim_state' in save_result_list else None,
        'scheduler_state_dict' : scheduler_state_dict,
        'params'               : params,
        # Explicitly pick specific fields instead of saving the entire model with grids
        # and other redundant info
        'model_attributes': {
            'detector_blur_std' : model.detector_blur_std,
            'start_iter'        : model.start_iter,
            'lr_params'         : model.lr_params,
            'omode_occu'        : model.omode_occu,
            'H'                 : model.H,
            'N_scan_slow'       : model.N_scan_slow,
            'N_scan_fast'       : model.N_scan_fast,
            'crop_pos'          : model.crop_pos,
            'slice_thickness'   : model.slice_thickness,
            'dx'                : model.dx,
            'dk'                : model.dk,
            'lambd'             : model.lambd,
            'meas_Npix'         : model.meas_Npix,
            'simu_Npix'         : model.simu_Npix,
            'simu_match_mode'   : model.simu_match_mode,
            'scan_affine'       : model.scan_affine,
            'tilt_obj'          : model.tilt_obj,
            'probe_int_sum'     : model.probe_int_sum,
        },
        'random_seed'          : model.random_seed,
        'loss_iters'           : model.loss_iters,
        'iter_times'           : model.iter_times,
        'dz_iters'             : model.dz_iters,
        'lr_iters'             : model.lr_iters,
        'avg_tilt_iters'       : model.avg_tilt_iters,
        'convergence_iters'    : dict(model.convergence_iters),
        'avg_iter_t'           : avg_iter_t,
        'niter'                : niter,
        'indices'              : indices,
        'batch_losses'         : batch_losses,
        'avg_losses'           : avg_losses,
    }
    return save_dict
def save_dict_to_hdf5(
    d: Dict[str, Any],
    output_path: str,
    none_sentinel: str = "__NONE__",
    **kwargs
) -> None:
    """Saves a nested Python dictionary to an HDF5 file.

    Recursively parses the dictionary and converts common Python, NumPy, and
    PyTorch types into HDF5-compatible formats:

    * dict -> HDF5 group (recursively);
    * list / tuple of numbers -> numeric array; of str -> variable-length string
      array; of tuples or of arrays/tensors -> stacked array (string fallback);
      of dicts -> group with one subgroup per index; anything else -> string;
    * int / float / str / NumPy scalar -> scalar dataset;
    * torch.Tensor / numpy.ndarray -> array dataset;
    * None -> ``none_sentinel`` string;
    * any other object -> ``str(value)``.

    Integer keys (common in optimizer state dicts) are coerced to strings.

    Args:
        d (dict): The nested dictionary to serialize.
        output_path (str): The target file path for the `.hdf5` file. An existing
            file is overwritten.
        none_sentinel (str, optional): The string used to represent `None` values
            in the HDF5 file. Defaults to "__NONE__".
        **kwargs: Additional keyword arguments passed to
            `h5py.Group.create_dataset()` (e.g., `compression="gzip"`). They are
            applied only to non-empty array datasets, because chunking and filter
            options are not allowed on scalar datasets.

    Raises:
        RuntimeError: If any entry cannot be written. The message names the
            offending key and its full path in the dictionary.
    """

    def _create_dataset(h5group: h5py.Group, key: str, data) -> None:
        # Scalars (numbers, strings, the None sentinel) cannot take chunk/filter
        # options, and empty arrays gain nothing from them, so kwargs are only
        # forwarded for non-empty array data.
        if np.ndim(data) > 0 and np.size(data) > 0:
            h5group.create_dataset(key, data=data, **kwargs)
        else:
            h5group.create_dataset(key, data=data)

    def _save_sequence(items, h5group: h5py.Group, key: str, full_key: str) -> None:
        # Shared handling for lists and tuples
        if all(isinstance(i, (int, float, np.number)) for i in items):
            _create_dataset(h5group, key, np.array(items))
        elif all(isinstance(i, str) for i in items):
            # h5py cannot store NumPy unicode arrays; use variable-length strings
            dt = h5py.special_dtype(vlen=str)
            _create_dataset(h5group, key, np.array(items, dtype=dt))
        elif all(isinstance(i, tuple) for i in items):
            try:
                arr = np.array([list(t) for t in items])
                _create_dataset(h5group, key, arr)
            except Exception:
                _create_dataset(h5group, key, str(items))
        elif all(isinstance(i, dict) for i in items):
            subgroup = h5group.create_group(key)
            for idx, item in enumerate(items):
                item_group = subgroup.create_group(str(idx))
                _recursively_save_dict_to_hdf5(item, item_group, f"{full_key}/{idx}")
        elif all(isinstance(i, (np.ndarray, torch.Tensor)) for i in items):
            try:
                arr = np.stack([i.detach().cpu().numpy() if isinstance(i, torch.Tensor) else i for i in items])
                _create_dataset(h5group, key, arr)
            except Exception:
                _create_dataset(h5group, key, str(items))
        else:
            # Heterogeneous sequence: fall back to its string representation
            _create_dataset(h5group, key, str(items))

    def _recursively_save_dict_to_hdf5(d: Dict[str, Any], h5group: h5py.Group, path: str = "") -> None:
        for key, value in d.items():
            full_key = f"{path}/{key}" if path else str(key)
            key = str(key)  # HDF5 names must be strings (optimizer state dicts use int keys)
            try:
                # Delete existing group/dataset if it exists
                if key in h5group:
                    del h5group[key]

                if value is None:
                    _create_dataset(h5group, key, none_sentinel)
                elif isinstance(value, dict):
                    subgroup = h5group.create_group(key)
                    # Propagate the path so nested error messages report the full key
                    _recursively_save_dict_to_hdf5(value, subgroup, full_key)
                elif isinstance(value, (list, tuple)):
                    _save_sequence(value, h5group, key, full_key)
                elif isinstance(value, (int, float, str, np.number)):
                    _create_dataset(h5group, key, value)
                elif isinstance(value, torch.Tensor):
                    _create_dataset(h5group, key, value.detach().cpu().numpy())
                elif isinstance(value, np.ndarray):
                    _create_dataset(h5group, key, value)
                else:
                    # Fallback option
                    _create_dataset(h5group, key, str(value))
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save key '{key}' (full path: '{full_key}') of type {type(value)}"
                ) from e

    with h5py.File(output_path, "w") as hf:
        _recursively_save_dict_to_hdf5(d, hf)
def make_output_folder(
    output_dir,
    indices,
    init_params,
    recon_params,
    model,
    constraint_params,
    loss_params,
    recon_dir_affixes=None,
):
    """Generates a highly descriptive output folder name based on runtime configurations.

    Constructs a detailed directory name by concatenating abbreviations of the
    active reconstruction parameters, model attributes, constraints, and losses.
    This lets users identify the exact settings of a run just by looking at the
    folder name. The folder is created (``exist_ok=True``) before returning.

    Args:
        output_dir (str): The base directory where the output folder will be created.
        indices (list): The list of indices used in the reconstruction.
        init_params (dict): Initialization parameters.
        recon_params (dict): Reconstruction parameters containing naming prefixes/postfixes.
        model (PtychoModel): The model containing current spatial and optimization attributes.
        constraint_params (dict): Dictionary of active constraints and filters.
        loss_params (dict): Dictionary of active loss functions and weights.
        recon_dir_affixes (list of str, optional): A list of tags or preset keys
            ("minimal", "default", "all") dictating which parameters to include in
            the folder name. Defaults to None, which is treated as ["default"].

    Returns:
        str: The sanitized absolute path to the newly generated output folder.
    """
    # None sentinel instead of a mutable list default shared between calls
    if recon_dir_affixes is None:
        recon_dir_affixes = ["default"]

    prefix_time = recon_params.get("prefix_time", False)
    prefix = recon_params.get("prefix", "")
    postfix = recon_params.get("postfix", "")

    parts = []

    recon_dir_presets = {
        "minimal": ['indices', 'meas', 'batch', 'pmode', 'omode', 'nlayer'],
        "default": ['indices', 'meas', 'batch', 'pmode', 'omode', 'nlayer', 'lr', 'model', 'constraint', 'loss', 'affine', 'tilt', 'aberrations'],
        "all": ['indices', 'meas', 'batch', 'pmode', 'omode', 'nlayer', 'optimizer', 'scheduler', 'start_iter', 'lr', 'model', 'constraint', 'loss', 'conv_angle', 'aberrations', 'Ls', 'z_shift', 'dx', 'affine', 'tilt'],
    }

    # Process recon_dir_affixes to expand presets
    if any(tag in recon_dir_presets for tag in recon_dir_affixes):
        logger.info(f"Original recon_dir_affixes = {recon_dir_affixes}")
        recon_dir_affixes = expand_presets(recon_dir_affixes, recon_dir_presets)
        logger.info(f"Expanded recon_dir_affixes = {recon_dir_affixes}")

    # Attach time string if prefix_time is true or non-empty str
    if prefix_time is True or (isinstance(prefix_time, str) and prefix_time):
        time_str = get_time(prefix_time)  # e.g. '20250606'
        parts.append(time_str)

    # Attach prefix (only if prefix is non-empty str)
    if isinstance(prefix, str) and prefix:
        parts.append(prefix)

    # Attach indices mode (optional)
    if "indices" in recon_dir_affixes:
        indices_mode = recon_params["INDICES_MODE"].get("mode")
        parts.append(f"{indices_mode}_N{len(indices)}")

    # Attach DP size and meas flip (optional)
    if "meas" in recon_dir_affixes:
        dp_size = model.get_complex_probe_view().size(-1)
        parts.append(f"dp{dp_size}")

        # .get() for consistency with the other init_params lookups; a missing key means no flip tag
        meas_flipT = init_params.get("meas_flipT")

        # Attach meas flipping
        if meas_flipT is not None:
            # Note that [0,0,0] will be attached if specified, for clarity
            flipT_str = "flipT" + "".join(str(x) for x in meas_flipT)
            parts.append(flipT_str)

    # Attach group mode and batch size (optional)
    if "batch" in recon_dir_affixes:
        group_mode = recon_params["GROUP_MODE"]
        batch_size = recon_params["BATCH_SIZE"].get("size")
        grad_accum = recon_params["BATCH_SIZE"].get("grad_accumulation", 1)
        batch_size *= grad_accum  # Affix the effective batch size
        parts.append(f"{group_mode}{batch_size}")

    # Attach pmode (optional)
    if "pmode" in recon_dir_affixes:
        pmode = model.get_complex_probe_view().size(0)
        parts.append(f"p{pmode}")

    # Attach omode (optional)
    if "omode" in recon_dir_affixes:
        omode = model.opt_objp.size(0)
        parts.append(f"{omode}obj")

    # Attach obj Nlayer and dz (optional)
    if "nlayer" in recon_dir_affixes:
        nlayer = model.opt_objp.size(1)
        parts.append(f"{nlayer}slice")
        if nlayer != 1:
            slice_thickness = model.slice_thickness.detach().cpu().numpy()  # This is the initialized slice thickness
            parts.append(f"dz{slice_thickness:.3g}")

    # Attach optimizer name (optional)
    if "optimizer" in recon_dir_affixes:
        optimizer_str = model.optimizer_params["name"]
        parts.append(f"{optimizer_str}")

    # Attach scheduler name (optional)
    if "scheduler" in recon_dir_affixes and model.scheduler_params is not None:
        scheduler_str = model.scheduler_params["name"]
        parts.append(f"{scheduler_str}")

    # Attach start_iter (optional); only start iterations later than 1 are tagged
    if "start_iter" in recon_dir_affixes:
        start_iter_map = {
            "probe": "ps",
            "obja": "oas",
            "objp": "ops",
            "probe_pos_shifts": "ss",
            "obj_tilts": "ts",
            "slice_thickness": "dzs",
        }
        for key, tag in start_iter_map.items():
            start_val = model.start_iter.get(key)
            if start_val is not None and start_val > 1:
                parts.append(f"{tag}{start_val}")

    # Attach learning rate (optional); zero learning rates are omitted
    if "lr" in recon_dir_affixes:
        lr_map = {
            "probe": "plr",
            "obja": "oalr",
            "objp": "oplr",
            "probe_pos_shifts": "slr",
            "obj_tilts": "tlr",
            "slice_thickness": "dzlr",
        }
        for key, tag in lr_map.items():
            # .get() like start_iter above: a missing entry is treated as "not optimized"
            lr_val = model.lr_params.get(key, 0)
            if lr_val != 0:
                lr_str = format(lr_val, ".0e").replace("e-0", "e-")
                parts.append(f"{tag}{lr_str}")

    # Attach model params (optional)
    if "model" in recon_dir_affixes:
        if model.detector_blur_std is not None and model.detector_blur_std != 0:
            parts.append(f"dpblur{model.detector_blur_std}")

    # Attach constraint params (optional); only constraints with a start_iter are tagged
    if "constraint" in recon_dir_affixes:
        if constraint_params["kr_filter"]["start_iter"] is not None:
            obj_type = constraint_params["kr_filter"]["obj_type"]
            kr_str = {"both": "kr", "amplitude": "kra", "phase": "krp"}.get(obj_type)
            radius = constraint_params["kr_filter"]["radius"]
            parts.append(f"{kr_str}f{radius}")

        if constraint_params["kz_filter"]["start_iter"] is not None:
            obj_type = constraint_params["kz_filter"]["obj_type"]
            kz_str = {"both": "kz", "amplitude": "kza", "phase": "kzp"}.get(obj_type)
            beta = constraint_params["kz_filter"]["beta"]
            parts.append(f"{kz_str}f{beta}")

        if constraint_params["kr_thresh"]["start_iter"] is not None:
            obj_type = constraint_params["kr_thresh"]["obj_type"]
            krt_str = {"both": "krt", "amplitude": "krta", "phase": "krtp"}.get(obj_type)
            thresh = constraint_params["kr_thresh"]["thresh"]
            parts.append(f"{krt_str}{thresh}")

        if (
            constraint_params["obj_rblur"]["start_iter"] is not None
            and constraint_params["obj_rblur"]["std"] != 0
        ):
            obj_type = constraint_params["obj_rblur"]["obj_type"]
            obj_str = {"both": "o", "amplitude": "oa", "phase": "op"}.get(obj_type)
            parts.append(f"{obj_str}rblur{constraint_params['obj_rblur']['std']}")

        if (
            constraint_params["obj_zblur"]["start_iter"] is not None
            and constraint_params["obj_zblur"]["std"] != 0
        ):
            obj_type = constraint_params["obj_zblur"]["obj_type"]
            obj_str = {"both": "o", "amplitude": "oa", "phase": "op"}.get(obj_type)
            parts.append(f"{obj_str}zblur{constraint_params['obj_zblur']['std']}")

        if constraint_params["complex_ratio"]["start_iter"] is not None:
            obj_type = constraint_params["complex_ratio"]["obj_type"]
            obj_str = {"both": "o", "amplitude": "oa", "phase": "op"}.get(obj_type)
            alpha1 = round(constraint_params["complex_ratio"]["alpha1"], 2)
            alpha2 = round(constraint_params["complex_ratio"]["alpha2"], 2)
            parts.append(f"{obj_str}cplx{alpha1}_{alpha2}")

        if constraint_params["mirrored_amp"]["start_iter"] is not None:
            scale = round(constraint_params["mirrored_amp"]["scale"], 2)
            power = round(constraint_params["mirrored_amp"]["power"], 2)
            parts.append(f"mamp{scale}_{power}")

        if constraint_params["obj_z_recenter"]["start_iter"] is not None:
            parts.append("ozrec")

        if constraint_params["obja_thresh"]["start_iter"] is not None:
            parts.append(f"oathr{round(constraint_params['obja_thresh']['thresh'][0], 2)}")

        if constraint_params["objp_postiv"]["start_iter"] is not None:
            mode = constraint_params["objp_postiv"].get("mode", "clip_neg")
            mode_str = "s" if mode == "subtract_min" else "c"
            relax = constraint_params["objp_postiv"]["relax"]
            relax_str = "" if relax == 0 else f"{round(relax, 2)}"
            parts.append(f"opos{mode_str}{relax_str}")

        if constraint_params["tilt_smooth"]["start_iter"] is not None:
            parts.append(f"tsm{round(constraint_params['tilt_smooth']['std'], 2)}")

        if constraint_params["probe_mask_k"]["start_iter"] is not None:
            parts.append(f"pmk{round(constraint_params['probe_mask_k']['radius'], 2)}")

        if constraint_params["probe_mask_r"]["start_iter"] is not None:
            parts.append(f"pmr{round(constraint_params['probe_mask_r']['radius'], 2)}")

    # Attach loss params (optional); only losses whose "state" is truthy are tagged
    if "loss" in recon_dir_affixes:
        loss_map = {
            "loss_single": ("sng", 2),
            "loss_poissn": ("psn", 2),
            "loss_pacbed": ("pcb", 2),
            "loss_sparse": ("spr", 2),
            "loss_simlar": ("sml", 2),
        }
        for key, (tag, digits) in loss_map.items():
            loss = loss_params.get(key, {})
            if loss.get("state"):
                parts.append(f"{tag}{round(loss.get('weight', 0), digits)}")

    # Attach conv_angle (optional)
    if "conv_angle" in recon_dir_affixes and "probe_conv_angle" in init_params:
        parts.append(f"ca{init_params['probe_conv_angle']:.3g}")

    # Attach aberrations (optional)
    if "aberrations" in recon_dir_affixes and "probe_aberrations" in init_params:
        init_aberrations = init_params.get("probe_aberrations")
        if init_aberrations:
            for k, v in init_aberrations.items():
                # Note that the user values are canonicalized to Krivanek polar during load_params
                parts.append(f"{k}_{v:.3g}")

    # Attach probe_Ls (optional, for xray only); the value is multiplied by 1e9 for the tag
    if "Ls" in recon_dir_affixes and "probe_Ls" in init_params:
        init_Ls = init_params.get("probe_Ls")
        if init_Ls is not None:
            parts.append(f"Ls{init_Ls * 1e9:.0f}")

    # Attach probe_z_shift (optional)
    if "z_shift" in recon_dir_affixes and "probe_z_shift" in init_params:
        init_z_shift = init_params.get("probe_z_shift")
        if init_z_shift is not None and init_z_shift != 0:
            parts.append(f"z_shift{init_z_shift:.3g}")

    # Attach dx (optional)
    if "dx" in recon_dir_affixes:
        dx = model.dx.detach().cpu().numpy()
        parts.append(f"dx{dx:.4g}")

    # Attach scan_affine (optional); the identity affine [1, 0, 0, 0] is omitted
    if "affine" in recon_dir_affixes:
        scan_affine = model.scan_affine  # Note that scan_affine could be None
        if scan_affine is not None and not np.allclose(scan_affine, [1, 0, 0, 0]):
            formats = [".2f", ".2f", ".1f", ".1f"]  # customize per index
            formatted = [format(x, fmt) for x, fmt in zip(scan_affine, formats)]
            affine_str = "aff" + "_".join(formatted)  # (4,)
            parts.append(f"{affine_str}")

    # Attach init tilts (optional); all-zero tilts are omitted
    if "tilt" in recon_dir_affixes:
        init_tilts = model.opt_obj_tilts.mean(0).detach().cpu().numpy()  # (2,) regardless tilt_type = 'all' or 'each'
        if np.any(init_tilts):
            parts.append(f"tilt{init_tilts[0]:.2g}_{init_tilts[1]:.2g}")

    # Attach postfix (only if postfix is non-empty str)
    if isinstance(postfix, str) and postfix:
        parts.append(postfix)

    # Make output folder
    output_path = os.path.join(output_dir, "_".join(parts)) if parts else output_dir
    output_path = safe_filename(output_path)
    os.makedirs(output_path, exist_ok=True)
    logger.info(f"output_path = '{output_path}' is generated!")
    return output_path
def save_results(output_path, model, params, optimizer, scheduler, niter, indices, batch_losses, collate_str=''):
    """Write the checkpoint HDF5 and the TIFF renderings for one iteration.

    What gets written is controlled by ``params['recon_params']``:

    * ``save_result`` (default ``['model', 'obj', 'probe']``) selects the outputs:
      - ``'model'`` writes ``model{collate_str}_iterNNNN.hdf5`` through
        ``make_save_dict`` / ``save_dict_to_hdf5`` and then embeds provenance.
      - ``'probe'`` / ``'probe_prop'`` write probe-amplitude TIFFs.
      - ``'obj'`` / ``'objp'`` / ``'object'`` write object-phase TIFFs.
      - ``'obja'`` writes object-amplitude TIFFs.
    * ``result_modes['bit']`` lists the bit depths; each value is passed to
      ``normalize_by_bit_depth``. '8', '16' and '32' add a ``_08bit`` /
      ``_16bit`` / ``_32bit`` filename tag; any other value (e.g. 'raw') adds none.
    * ``result_modes['FOV']``: 'crop' keeps only the bounding box of
      ``model.crop_pos[indices]`` shifted by half the probe size (tag ``_crop``).
      Any other value keeps the full object.
    * ``result_modes['obj_dim']`` (2, 3 and/or 4) picks which reductions of the
      (omode, slice, Y, X) object arrays are written. Which reductions exist
      depends on whether omode and the slice count are 1 or larger.

    Args:
        output_path (str): The directory where files will be written.
        model (PtychoModel): The reconstruction model.
        params (dict): The configuration dictionary dictating save preferences.
        optimizer (torch.optim.Optimizer): The active optimizer.
        scheduler (torch.optim.lr_scheduler.LRScheduler or None): The active LR scheduler, or None.
        niter (int): The current iteration number (zero-padded to 4 digits in filenames).
        indices (list or numpy.ndarray): The batch indices; also used to select the crop window.
        batch_losses (dict): The recorded loss values.
        collate_str (str, optional): Extra string injected into every filename
            (useful for distinguishing concurrent outputs). Defaults to ''.
    """
    recon_params = params['recon_params']
    requested = recon_params.get('save_result', ['model', 'obj', 'probe'])
    result_modes = recon_params.get('result_modes')
    iter_str = f"_iter{str(niter).zfill(4)}"

    if 'model' in requested:
        hdf5_file_path = safe_filename(os.path.join(output_path, f"model{collate_str}{iter_str}.hdf5"))
        save_dict = make_save_dict(output_path, model, params, optimizer, scheduler, niter, indices, batch_losses)
        save_dict_to_hdf5(save_dict, hdf5_file_path)
        provenance_json_str = generate_provenance_json(
            current_provenance=model.recon_provenance,
            params=params,
            output_filename=hdf5_file_path,
        )
        save_provenance_to_hdf5(hdf5_file_path, provenance_json_str)

    def write_tif(stem, array, bit):
        """Sanitize ``<output_path>/<stem>.tif`` and write ``array`` rescaled for ``bit``."""
        target = safe_filename(os.path.join(output_path, f"{stem}.tif"))
        imwrite(target, normalize_by_bit_depth(array, bit))

    probe = model.get_complex_probe_view()
    # (pmode, Y, X) -> (Y, pmode, X) -> (Y, pmode*X): modes laid side by side
    probe_amp = probe.permute(1, 0, 2).flatten(1).abs().detach().cpu().numpy()
    # np.array([0]) rather than [0] keeps index types consistent (safer with torch.compile)
    probe_prop = model.get_propagated_probe(np.array([0]))
    # (Z, pmode, Y, X) -> (Z, Y, pmode, X) -> (Z, Y, pmode*X)
    probe_prop_amp = probe_prop.permute(0, 2, 1, 3).flatten(2).abs().detach().cpu().numpy()

    objp = model.opt_objp.detach().cpu().numpy()
    obja = model.opt_obja.detach().cpu().numpy()
    # omode_occu is not used yet; it will matter once omode_occu != 'uniform'
    omode, zslice = model.opt_objp.size(0), model.opt_objp.size(1)

    crop_pos = model.crop_pos[indices].detach().cpu().numpy() + np.array(probe.shape[-2:]) // 2
    y_min, y_max = crop_pos[:, 0].min(), crop_pos[:, 0].max()
    x_min, x_max = crop_pos[:, 1].min(), crop_pos[:, 1].max()

    bit_suffixes = {'8': '_08bit', '16': '_16bit', '32': '_32bit'}
    save_objp = any(word in requested for word in ['obj', 'objp', 'object'])
    save_obja = 'obja' in requested

    for bit in result_modes['bit']:
        bit_str = bit_suffixes.get(bit, '')

        if 'probe' in requested:
            write_tif(f"probe_amp{bit_str}{collate_str}{iter_str}", probe_amp, bit)
        if 'probe_prop' in requested:
            write_tif(f"probe_prop_amp{bit_str}{collate_str}{iter_str}", probe_prop_amp, bit)

        for fov in result_modes['FOV']:
            if fov == 'crop':
                fov_str = '_crop'
                objp_crop = objp[:, :, y_min:y_max + 1, x_min:x_max + 1]
                obja_crop = obja[:, :, y_min:y_max + 1, x_min:x_max + 1]
            else:
                # 'full' and any unrecognized value keep the whole field of view
                fov_str = ''
                objp_crop, obja_crop = objp, obja

            postfix_str = fov_str + bit_str + collate_str + iter_str

            if save_objp:
                # TODO: For omode_occu != 'uniform', we should do a weighted sum across omode instead
                for dim in result_modes['obj_dim']:
                    if omode == 1 and zslice == 1:
                        if dim == 2:
                            write_tif(f"objp{postfix_str}", objp_crop[0, 0], bit)
                    elif omode == 1 and zslice > 1:
                        if dim == 3:
                            write_tif(f"objp_zstack{postfix_str}", objp_crop[0, :], bit)
                        if dim == 2:
                            write_tif(f"objp_zsum{postfix_str}", objp_crop[0, :].sum(0), bit)
                    elif omode > 1 and zslice == 1:
                        if dim == 3:
                            write_tif(f"objp_ostack{postfix_str}", objp_crop[:, 0], bit)
                        if dim == 2:
                            write_tif(f"objp_omean{postfix_str}", objp_crop[:, 0].mean(0), bit)
                            write_tif(f"objp_ostd{postfix_str}", objp_crop[:, 0].std(0), bit)
                    else:
                        if dim == 4:
                            write_tif(f"objp_4D{postfix_str}", objp_crop[:, :], bit)
                        if dim == 3:
                            write_tif(f"objp_ostack_zsum{postfix_str}", objp_crop[:, :].sum(1), bit)
                            write_tif(f"objp_omean_zstack{postfix_str}", objp_crop[:, :].mean(0), bit)
                        if dim == 2:
                            write_tif(f"objp_omean_zsum{postfix_str}", objp_crop[:, :].mean(0).sum(0), bit)

            if save_obja:
                # TODO: For omode_occu != 'uniform', we should do a weighted sum across omode instead
                for dim in result_modes['obj_dim']:
                    if omode == 1 and zslice == 1:
                        if dim == 2:
                            write_tif(f"obja{postfix_str}", obja_crop[0, 0], bit)
                    elif omode == 1 and zslice > 1:
                        if dim == 3:
                            write_tif(f"obja_zstack{postfix_str}", obja_crop[0, :], bit)
                        if dim == 2:
                            write_tif(f"obja_zmean{postfix_str}", obja_crop[0, :].mean(0), bit)
                            write_tif(f"obja_zprod{postfix_str}", obja_crop[0, :].prod(0), bit)
                    elif omode > 1 and zslice == 1:
                        if dim == 3:
                            write_tif(f"obja_ostack{postfix_str}", obja_crop[:, 0], bit)
                        if dim == 2:
                            write_tif(f"obja_omean{postfix_str}", obja_crop[:, 0].mean(0), bit)
                            write_tif(f"obja_ostd{postfix_str}", obja_crop[:, 0].std(0), bit)
                    else:
                        if dim == 4:
                            write_tif(f"obja_4D{postfix_str}", obja_crop[:, :], bit)
                        if dim == 3:
                            write_tif(f"obja_ostack_zmean{postfix_str}", obja_crop[:, :].mean(1), bit)
                            write_tif(f"obja_ostack_zprod{postfix_str}", obja_crop[:, :].prod(1), bit)
                            write_tif(f"obja_omean_zstack{postfix_str}", obja_crop[:, :].mean(0), bit)
                        if dim == 2:
                            write_tif(f"obja_omean_zmean{postfix_str}", obja_crop[:, :].mean(0).mean(0), bit)
                            write_tif(f"obja_omean_zprod{postfix_str}", obja_crop[:, :].mean(0).prod(0), bit)