"""
ConvergenceMonitor: periodic convergence metric tracking for optimizable tensors.
"""
import logging
from typing import Optional
import torch
from ptyrad.core.functional import approx_torch_quantile
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class ConvergenceMonitor:
    """
    Tracks convergence of optimizable tensors during ptychographic reconstruction.

    Takes periodic snapshots of tracked tensors and computes the iter-to-iter change
    (relative to the previous snapshot) at each snapshot.

    Tracked tensors: ``obja``, ``objp``, ``probe``, ``probe_pos_shifts``.
    ``slice_thickness`` and ``obj_tilts`` are excluded — they are already tracked every
    iteration via ``model.dz_iters`` and ``model.avg_tilt_iters`` and fed directly to the
    dashboard.

    For ``obja`` and ``objp``, metrics are computed on the ROI crop (scanned area bounding
    box) only. ``obja`` is transformed as ``1 - obja`` so vacuum → 0 and material → >0. Two
    scalars are stored per tensor per step using percentile masking on the current snapshot:
    a background metric (pixels below ``p_low``) and a signal metric (pixels above
    ``p_high``). The results are stored under keys ``obja_bg``, ``obja_fg``, ``objp_bg``,
    ``objp_fg``.

    For ``probe``, the fractional intensity change (``sum|ΔI| / sum(I_prev)``) of
    mode-summed probe intensity is tracked.

    For ``probe_pos_shifts``, the RMS displacement change in Å is tracked.

    Results are stored in ``model.convergence_iters`` as a dict of lists of 2-tuples
    ``(niter, value)``.

    Args:
        params: Parsed ``ConvergenceMonitorParams`` dict (with keys ``tensors``,
            ``every_n_iters``, ``percentile_range``).
        model: ``PtychoModel`` instance. An initial snapshot is taken during ``__init__``
            so the baseline is the state before the first optimizer update.
    """

    # Maps tensor name → metric type used by _compute_metric
    # (obja/objp use _compute_bg_fg_metric instead and are therefore not listed).
    _METRIC_TYPE = {
        "probe": "norm_l1",
        "probe_pos_shifts": "rms",
    }

    def __init__(self, params: dict, model) -> None:
        self._tensors: list = list(params["tensors"])
        self._every_n: Optional[int] = params.get("every_n_iters")
        self._percentile_range: list = list(params.get("percentile_range", [15.0, 85.0]))
        self._dx: float = float(model.dx)  # pixel size [Å]; used to convert probe_pos_shifts to Å

        # Precompute scanned-area ROI from all scan positions — matches save.py crop convention.
        # crop_pos stores top-left corners of probe patches; adding probe_half gives center positions.
        with torch.no_grad():
            probe_half = torch.tensor(model.get_complex_probe_view().shape[-2:]).cpu() // 2
            centers = model.crop_pos.cpu() + probe_half  # (N_scans, 2)
            self._y_min = int(centers[:, 0].min().item())
            self._y_max = int(centers[:, 0].max().item())
            self._x_min = int(centers[:, 1].min().item())
            self._x_max = int(centers[:, 1].max().item())

        # Baseline snapshots; clone so a later in-place parameter update cannot alias them.
        self._prev: dict = {}
        for name in self._tensors:
            snaps = self._snapshot(model, name)
            for key, tensor in snaps.items():
                self._prev[key] = tensor.clone()

        logger.info(
            f"### Creating ConvergenceMonitor with {params} ### ")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def should_step(self, niter: int, save_iters: Optional[int]) -> bool:
        """Return True if a convergence snapshot should be taken at ``niter``.

        The interval is ``every_n_iters`` when it was configured, otherwise ``save_iters``.
        When neither is set, or the selected interval is 0, no snapshot is taken; the zero
        case returns False instead of raising ``ZeroDivisionError``.
        """
        interval = self._every_n if self._every_n is not None else save_iters
        if interval is None or interval == 0:
            return False
        return niter % interval == 0

    def step(self, model, niter: int) -> None:
        """Compute and record convergence metrics for all tracked tensors.

        Each metric compares the current snapshot with the previously stored one and is
        appended to ``model.convergence_iters`` as ``(niter, value)``. The stored snapshot
        is then replaced by the current one.
        """
        for name in self._tensors:
            snaps = self._snapshot(model, name)
            for key, current in snaps.items():
                if key in ("obja", "objp"):
                    bg_change, fg_change = self._compute_bg_fg_metric(
                        current, self._prev[key], self._percentile_range
                    )
                    model.convergence_iters[f"{key}_bg"].append((niter, bg_change))
                    model.convergence_iters[f"{key}_fg"].append((niter, fg_change))
                else:
                    metric_type = self._METRIC_TYPE[key]
                    iter_change = self._compute_metric(current, self._prev[key], metric_type)
                    if key == "probe_pos_shifts":
                        # Scale by the pixel size so the value is reported in Å.
                        iter_change *= self._dx
                    model.convergence_iters[key].append((niter, iter_change))
                self._prev[key] = current.clone()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _crop_roi(self, obj: torch.Tensor) -> torch.Tensor:
        """Crop the last two axes of ``obj`` to the precomputed scanned-area bounding box.

        The start indices are clamped at 0: without the clamp a minimum centre of 0 would
        yield a start of -1, which wraps to the end of the axis and gives an empty crop.
        """
        y_start = max(self._y_min - 1, 0)
        x_start = max(self._x_min - 1, 0)
        return obj[:, :, y_start:self._y_max, x_start:self._x_max]

    def _snapshot(self, model, name: str) -> dict:
        """
        Return a dict of detached CPU tensors for the given parameter name.

        For 'obja': returns ``{"obja": 1 - roi}`` where roi is the scanned-area crop of obja
        (cast to float32).
        For 'objp': returns ``{"objp": roi}`` (scanned-area crop cast to float32, no transform).
        For 'probe': returns ``{"probe": probe_intensity}``, the squared magnitude of the
        complex probe view summed over its first axis.
        For others: returns ``{name: tensor}`` (cast to float32).
        """
        if name == "obja":
            obj = model.optimizable_tensors["obja"].detach().float()
            return {"obja": (1.0 - self._crop_roi(obj)).cpu()}
        if name == "objp":
            obj = model.optimizable_tensors["objp"].detach().float()
            return {"objp": self._crop_roi(obj).cpu()}
        if name == "probe":
            probe_c = model.get_complex_probe_view().detach()
            probe_int = probe_c.abs().square().sum(0).cpu()
            return {"probe": probe_int}
        return {name: model.optimizable_tensors[name].detach().float().cpu()}

    @staticmethod
    def _compute_metric(current: torch.Tensor, reference: torch.Tensor, metric_type: str) -> float:
        """Compute a scalar convergence metric between current and reference tensors.

        Args:
            current: Latest snapshot.
            reference: Previous snapshot, same shape as ``current``.
            metric_type: ``"norm_l1"`` or ``"rms"``.

        Returns:
            The metric as a Python float.

        Raises:
            ValueError: If ``metric_type`` is not one of the supported names.
        """
        diff = current - reference
        if metric_type == "norm_l1":
            # Fractional intensity change: total absolute change as fraction of total
            # reference intensity. The small epsilon avoids division by zero.
            return (diff.abs().sum() / (reference.sum() + 1e-8)).item()
        if metric_type == "rms":
            # Per-position RMS displacement magnitude (for (N, 2) probe_pos_shifts tensors)
            return diff.pow(2).sum(dim=-1).mean().sqrt().item()
        raise ValueError(f"Unknown metric_type: {metric_type!r}")

    @staticmethod
    def _compute_bg_fg_metric(
        current: torch.Tensor,
        reference: torch.Tensor,
        percentile_range: list,
    ) -> tuple:
        """Compute background and signal mean absolute change using percentile masking.

        Percentiles are computed on the flattened current snapshot. Background mask selects
        pixels below p_low (vacuum region); signal mask selects pixels above p_high (material).
        An empty mask contributes 0.0 for its metric.

        Returns (bg_change, fg_change) as floats.
        """
        flat_curr = current.flatten()
        # percentile_range is given in percent; the quantile helper is called with fractions.
        p_lo = approx_torch_quantile(flat_curr, percentile_range[0] / 100.0).item()
        p_hi = approx_torch_quantile(flat_curr, percentile_range[1] / 100.0).item()
        flat_diff = (current - reference).abs().flatten()
        bg_mask = flat_curr < p_lo
        fg_mask = flat_curr > p_hi
        bg_change = flat_diff[bg_mask].mean().item() if bg_mask.any() else 0.0
        fg_change = flat_diff[fg_mask].mean().item() if fg_mask.any() else 0.0
        return bg_change, fg_change
def create_convergence_monitor(convergence_monitor_params, model) -> Optional[ConvergenceMonitor]:
    """
    Factory that returns a ``ConvergenceMonitor`` when ``convergence_monitor_params`` is not None.

    Args:
        convergence_monitor_params: Parsed dict from ``ReconParams.convergence_monitor``,
            or ``None`` to disable monitoring.
        model: ``PtychoModel`` instance used for the initial snapshot.

    Returns:
        A configured ``ConvergenceMonitor``, or None if params is None.
    """
    # Only ``None`` disables monitoring; any other value is passed to the constructor.
    if convergence_monitor_params is None:
        return None
    return ConvergenceMonitor(convergence_monitor_params, model)