# Source code for ptyrad.analysis.extract

"""
Pure extraction functions for PtyRAD ``model.hdf5`` outputs.

Every public extraction worker accepts either the full loaded dict (from
:func:`ptyrad.io.load.load_ptyrad`) or pre-sliced sub-dicts
(``opt_tensors``, ``model_attrs``) so callers can use these helpers
without instantiating :class:`ptyrad.analysis.Analyzer`.

Geometry / FOV helpers (``get_probe_center_positions``,
``get_scanned_fov_bbox``, ``apply_fov``) are re-exported on purpose:
they encode the single most error-prone convention in the saved file
(``crop_pos`` is the top-left of the probe window, not the center).
The public position extractor exposes both conventions explicitly via
``target='probe'`` and ``target='crop'``.
"""

from __future__ import annotations

import os
from typing import Any, Literal, Mapping

import numpy as np

from ptyrad.io.hierarchy import get_nested, list_nested_keys


# ---------------------------------------------------------------------------
# Geometry / FOV helpers (single source of truth)
# ---------------------------------------------------------------------------


def _add_shifts(pos: np.ndarray, shifts: np.ndarray | None) -> np.ndarray:
    """Offset ``pos`` by ``shifts``, insisting that both have the same shape.

    Parameters
    ----------
    pos : np.ndarray
        Positions to offset.
    shifts : np.ndarray or None
        Offsets to add. ``None`` means "no correction".

    Returns
    -------
    np.ndarray
        ``pos + shifts``; when ``shifts`` is ``None`` the very same ``pos``
        object is handed back (no copy is made).

    Raises
    ------
    ValueError
        If ``shifts`` is given and ``shifts.shape != pos.shape``. No
        broadcasting is attempted.
    """
    if shifts is not None:
        if shifts.shape == pos.shape:
            return pos + shifts
        raise ValueError(
            f"probe_pos_shifts shape {shifts.shape} does not match positions shape {pos.shape}"
        )
    return pos


def get_probe_center_positions(
    crop_pos: np.ndarray,
    probe_shape: tuple[int, int],
    probe_pos_shifts: np.ndarray | None = None,
) -> np.ndarray:
    """Turn top-left ``crop_pos`` entries into probe-center positions.

    ``crop_pos`` as saved in ``model.hdf5`` marks the top-left corner of each
    probe window in object pixels (the form consumed by
    :meth:`PtychoModel.get_obj_patches`), not the visual probe center. The
    center lies ``(Hp // 2, Wp // 2)`` further in, which is the offset added
    here.

    Parameters
    ----------
    crop_pos : array-like, shape (N, 2), int
        Top-left ``[y, x]`` of each probe window in object pixels.
    probe_shape : tuple
        Probe shape; only the trailing two entries ``(Hp, Wp)`` are read.
    probe_pos_shifts : array-like, shape (N, 2), optional
        Optimized sub-pixel offsets. When given they are cast to ``float32``
        and added to the centers; the shape must equal that of ``crop_pos``.

    Returns
    -------
    np.ndarray
        Shape ``(N, 2)``, ``float32``, ``[y, x]`` order, in object pixels.

    Raises
    ------
    ValueError
        If ``probe_pos_shifts`` is given with a shape different from the
        positions array.
    """
    top_left = np.asarray(crop_pos)
    half_h = int(probe_shape[-2]) // 2
    half_w = int(probe_shape[-1]) // 2
    centers = top_left.astype(np.float32) + np.array([half_h, half_w], dtype=np.float32)

    sub_px = None
    if probe_pos_shifts is not None:
        sub_px = np.asarray(probe_pos_shifts, dtype=np.float32)
    # The shape check (and the None pass-through) lives in the shared helper.
    return _add_shifts(centers, sub_px)
def get_scanned_fov_bbox(
    crop_pos: np.ndarray, probe_shape: tuple[int, int]
) -> tuple[int, int, int, int]:
    """Bounding box of the probe-center positions, both ends inclusive.

    The box is built from the integer probe centers
    (``crop_pos + probe_shape // 2``) and deliberately ignores
    ``probe_pos_shifts``. This mirrors the bbox math in
    :func:`ptyrad.io.save.save_results`, so the result lines up with the
    saved ``*_crop`` TIFFs, which are sliced as
    ``[y_min:y_max + 1, x_min:x_max + 1]``.

    Parameters
    ----------
    crop_pos : array-like, shape (N, 2)
        Top-left ``[y, x]`` of each probe window in object pixels.
    probe_shape : tuple
        Probe shape; only the trailing ``(Hp, Wp)`` are used.

    Returns
    -------
    tuple of int
        ``(y_min, y_max, x_min, x_max)`` as plain Python ints.
    """
    centers = get_probe_center_positions(crop_pos, probe_shape).astype(np.int64)
    rows = centers[:, 0]
    cols = centers[:, 1]
    return int(rows.min()), int(rows.max()), int(cols.min()), int(cols.max())
def apply_fov(arr: np.ndarray, bbox: tuple[int, int, int, int] | None) -> np.ndarray:
    """Slice the trailing two axes of ``arr`` down to an inclusive ``bbox``.

    Parameters
    ----------
    arr : np.ndarray
        Array whose last two axes are ``(y, x)``; leading axes are untouched.
    bbox : tuple of int or None
        ``(y_min, y_max, x_min, x_max)`` with *both* bounds inclusive.
        ``None`` disables cropping.

    Returns
    -------
    np.ndarray
        ``arr`` itself when ``bbox is None``; otherwise the result of basic
        slicing ``arr[..., y_min:y_max + 1, x_min:x_max + 1]``.
    """
    if bbox is not None:
        top, bottom, left, right = bbox
        # +1 converts the inclusive upper bounds into Python's exclusive stops.
        rows = slice(top, bottom + 1)
        cols = slice(left, right + 1)
        return arr[..., rows, cols]
    return arr
# ---------------------------------------------------------------------------
# Internal helpers (private to this module)
# ---------------------------------------------------------------------------


def _resolve_opt_tensors(data_or_opt: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the ``optimizable_tensors`` mapping from either input form.

    If ``data_or_opt`` has an ``'optimizable_tensors'`` key, that entry is
    returned; otherwise ``data_or_opt`` is assumed to already be the sub-dict
    and is returned as-is.
    """
    is_full_dict = "optimizable_tensors" in data_or_opt
    return data_or_opt["optimizable_tensors"] if is_full_dict else data_or_opt


def _resolve_model_attrs(
    data_or_opt: Mapping[str, Any], model_attrs: Mapping[str, Any] | None
) -> Mapping[str, Any] | None:
    """Pick the ``model_attributes`` mapping, preferring the explicit argument.

    Resolution order: an explicit ``model_attrs`` wins; otherwise the
    ``'model_attributes'`` entry of ``data_or_opt`` (when it is a mapping that
    has one); otherwise ``None``.
    """
    if model_attrs is None:
        has_attrs = isinstance(data_or_opt, Mapping) and "model_attributes" in data_or_opt
        return data_or_opt["model_attributes"] if has_attrs else None
    return model_attrs


def _probe_shape_from(
    opt_tensors: Mapping[str, Any], probe_shape: tuple[int, int] | None
) -> tuple[int, int]:
    """Return the probe shape, inferring it from the stored probe if needed.

    An explicit ``probe_shape`` is returned as a tuple unchanged in content.
    Otherwise the last two dims of ``opt_tensors['probe']`` are used.

    Raises
    ------
    KeyError
        If ``probe_shape`` is ``None`` and ``opt_tensors`` has no usable
        ``'probe'`` entry.
    """
    if probe_shape is None:
        probe = opt_tensors.get("probe")
        if probe is None:
            raise KeyError("opt_tensors has no 'probe' to infer probe_shape from.")
        probe_shape = probe.shape[-2:]
    return tuple(probe_shape)  # type: ignore[return-value]


def _crop_bbox(
    opt_tensors: Mapping[str, Any],
    model_attrs: Mapping[str, Any] | None,
    probe_shape: tuple[int, int] | None,
) -> tuple[int, int, int, int]:
    """Compute the scanned-FOV bbox needed by ``fov='crop'``.

    Raises
    ------
    ValueError
        If ``model_attrs`` is ``None`` or lacks ``'crop_pos'``.
    KeyError
        If the probe shape has to be inferred but no probe is available.
    """
    has_positions = model_attrs is not None and "crop_pos" in model_attrs
    if not has_positions:
        raise ValueError(
            "fov='crop' requires model_attrs with 'crop_pos'. "
            "Pass the full data dict or supply model_attrs explicitly."
        )
    crop_pos = np.asarray(model_attrs["crop_pos"])
    shape_2d = _probe_shape_from(opt_tensors, probe_shape)
    return get_scanned_fov_bbox(crop_pos, shape_2d)


def _to_torch(arr: np.ndarray, device: str):
    """Wrap ``arr`` with ``torch.as_tensor`` on ``device``.

    ``torch`` is imported here rather than at module import time, so it is
    only required by callers that actually request tensors.
    """
    from torch import as_tensor

    return as_tensor(arr, device=device)


# ---------------------------------------------------------------------------
# Object extractors
# ---------------------------------------------------------------------------
def extract_object(
    data_or_opt: Mapping[str, Any],
    *,
    fov: Literal["full", "crop"] = "full",
    model_attrs: Mapping[str, Any] | None = None,
    probe_shape: tuple[int, int] | None = None,
    as_torch: bool = False,
    device: str = "cpu",
):
    """Return the complex object ``obja * exp(1j * objp)``.

    Shape ``(omode, Nz, Ny, Nx)``, ``complex64``. The 4D shape is preserved
    even for single-slice runs (``Nz == 1``) so downstream code can branch on
    ``Nz`` directly; squeeze yourself if you need 2D.

    Parameters
    ----------
    data_or_opt
        Full loaded data dict, or the ``optimizable_tensors`` sub-dict. Must
        provide ``'obja'`` and ``'objp'``.
    fov
        ``'full'`` returns the whole saved object. ``'crop'`` clips the last
        two axes to the scanned-FOV bbox (see :func:`get_scanned_fov_bbox`)
        and requires ``crop_pos`` from ``model_attrs``.
    model_attrs
        Optional explicit ``model_attributes`` sub-dict; takes precedence over
        the one found in ``data_or_opt``.
    probe_shape
        Only needed for ``fov='crop'`` when the probe has been stripped from
        the tensors; otherwise inferred from ``opt_tensors['probe']``.
    as_torch, device
        When ``as_torch`` is true the result is passed through
        ``torch.as_tensor(..., device=device)``.

    Returns
    -------
    np.ndarray or torch.Tensor
        The recomposed complex object.

    Raises
    ------
    ValueError
        If ``fov`` is not ``'full'`` / ``'crop'``, or ``fov='crop'`` is
        requested without ``crop_pos``.
    KeyError
        If ``'obja'`` / ``'objp'`` are missing, or the probe shape cannot be
        inferred for ``fov='crop'``.
    """
    # Reject a bad option before touching any (potentially large) array.
    if fov not in ("full", "crop"):
        raise ValueError(f"fov must be 'full' or 'crop', got {fov!r}")

    opt = _resolve_opt_tensors(data_or_opt)
    obja = np.asarray(opt["obja"])
    objp = np.asarray(opt["objp"])

    if fov == "crop":
        attrs = _resolve_model_attrs(data_or_opt, model_attrs)
        bbox = _crop_bbox(opt, attrs, probe_shape)
        # Crop the inputs first: the recomposition is element-wise, so the
        # values are identical, but the complex exponential is evaluated only
        # over the scanned region and the result is a freshly allocated array
        # sized to the crop rather than a view into a full-size temporary.
        obja = apply_fov(obja, bbox)
        objp = apply_fov(objp, bbox)

    obj = (obja * np.exp(1j * objp)).astype(np.complex64)
    return _to_torch(obj, device) if as_torch else obj
def extract_object_amplitude(
    data_or_opt: Mapping[str, Any],
    *,
    fov: Literal["full", "crop"] = "full",
    model_attrs: Mapping[str, Any] | None = None,
    probe_shape: tuple[int, int] | None = None,
    as_torch: bool = False,
    device: str = "cpu",
):
    """Return the stored object amplitude ``obja`` as ``float32``.

    The array is read straight from the tensors (nothing is recomposed);
    expected shape is ``(omode, Nz, Ny, Nx)``.

    Parameters
    ----------
    data_or_opt
        Full loaded data dict, or the ``optimizable_tensors`` sub-dict.
    fov
        ``'full'`` or ``'crop'``; same meaning as in :func:`extract_object`.
    model_attrs
        Optional explicit ``model_attributes``; needed for ``fov='crop'``
        when only the tensor sub-dict is passed.
    probe_shape
        Optional probe shape override used to build the crop bbox.
    as_torch, device
        Convert the result with ``torch.as_tensor`` on ``device`` when
        ``as_torch`` is true.

    Raises
    ------
    ValueError
        If ``fov`` is neither ``'full'`` nor ``'crop'``, or cropping is
        requested without ``crop_pos``.
    """
    tensors = _resolve_opt_tensors(data_or_opt)
    amplitude = np.asarray(tensors["obja"], dtype=np.float32)

    if fov == "crop":
        bbox = _crop_bbox(tensors, _resolve_model_attrs(data_or_opt, model_attrs), probe_shape)
        amplitude = apply_fov(amplitude, bbox)
    elif fov != "full":
        raise ValueError(f"fov must be 'full' or 'crop', got {fov!r}")

    if as_torch:
        return _to_torch(amplitude, device)
    return amplitude
def extract_object_phase(
    data_or_opt: Mapping[str, Any],
    *,
    fov: Literal["full", "crop"] = "full",
    model_attrs: Mapping[str, Any] | None = None,
    probe_shape: tuple[int, int] | None = None,
    as_torch: bool = False,
    device: str = "cpu",
):
    """Return the stored object phase ``objp`` as ``float32``.

    The array is read straight from the tensors (nothing is recomposed);
    expected shape is ``(omode, Nz, Ny, Nx)``. The sign convention is the one
    of :class:`PtychoModel` (complex object ``obja * exp(1j * objp)``); no
    sign flip is applied here.

    Parameters
    ----------
    data_or_opt
        Full loaded data dict, or the ``optimizable_tensors`` sub-dict.
    fov
        ``'full'`` or ``'crop'``; same meaning as in :func:`extract_object`.
    model_attrs
        Optional explicit ``model_attributes``; needed for ``fov='crop'``
        when only the tensor sub-dict is passed.
    probe_shape
        Optional probe shape override used to build the crop bbox.
    as_torch, device
        Convert the result with ``torch.as_tensor`` on ``device`` when
        ``as_torch`` is true.

    Raises
    ------
    ValueError
        If ``fov`` is neither ``'full'`` nor ``'crop'``, or cropping is
        requested without ``crop_pos``.
    """
    tensors = _resolve_opt_tensors(data_or_opt)
    phase = np.asarray(tensors["objp"], dtype=np.float32)

    if fov == "crop":
        bbox = _crop_bbox(tensors, _resolve_model_attrs(data_or_opt, model_attrs), probe_shape)
        phase = apply_fov(phase, bbox)
    elif fov != "full":
        raise ValueError(f"fov must be 'full' or 'crop', got {fov!r}")

    if as_torch:
        return _to_torch(phase, device)
    return phase


# ---------------------------------------------------------------------------
# Probe extractor
# ---------------------------------------------------------------------------


def extract_probe(
    data_or_opt: Mapping[str, Any],
    *,
    space: Literal["real", "fourier"] = "real",
    as_torch: bool = False,
    device: str = "cpu",
):
    """Return the probe modes as ``complex64``, shape ``(pmode, Ny, Nx)``.

    Parameters
    ----------
    data_or_opt
        Full loaded data dict, or the ``optimizable_tensors`` sub-dict. Must
        provide ``'probe'``.
    space
        ``'real'`` returns the stored real-space complex wavefunction.
        ``'fourier'`` returns
        ``fftshift(fft2(ifftshift(probe), norm='ortho'))`` over the last two
        axes. The shift "sandwich" is meant to match
        :func:`ptyrad.plotting.plot_probe_modes` so getters and plotters
        agree; moving the probe center to the corner with ``ifftshift``
        before the transform avoids the checkerboard-phase artifact of a
        plain ``fft2``.
    as_torch, device
        Convert the result with ``torch.as_tensor`` on ``device`` when
        ``as_torch`` is true.

    Raises
    ------
    ValueError
        If ``space`` is neither ``'real'`` nor ``'fourier'``.
    """
    tensors = _resolve_opt_tensors(data_or_opt)
    modes = np.asarray(tensors["probe"]).astype(np.complex64)

    if space == "fourier":
        image_axes = (-2, -1)
        corner_centered = np.fft.ifftshift(modes, axes=image_axes)
        spectrum = np.fft.fft2(corner_centered, norm="ortho")
        # fft2 may promote to double precision; cast back to the documented dtype.
        result = np.fft.fftshift(spectrum, axes=image_axes).astype(np.complex64)
    elif space == "real":
        result = modes
    else:
        raise ValueError(f"space must be 'real' or 'fourier', got {space!r}")

    if as_torch:
        return _to_torch(result, device)
    return result


# ---------------------------------------------------------------------------
# Position extractor
# ---------------------------------------------------------------------------


def extract_probe_positions(
    data_or_opt: Mapping[str, Any],
    model_attrs: Mapping[str, Any] | None = None,
    *,
    fov: Literal["full", "crop"] = "full",
    target: Literal["probe", "crop"] = "probe",
    units: Literal["px", "pixel", "Ang"] = "px",
    include_sub_px_shifts: bool = True,
    probe_shape: tuple[int, int] | None = None,
    as_torch: bool = False,
    device: str = "cpu",
):
    """Return positions, shape ``(N, 2)`` as ``[y, x]``, ``float32``.

    Parameters
    ----------
    data_or_opt
        Full data dict from :func:`ptyrad.io.load.load_ptyrad`, or the
        ``optimizable_tensors`` sub-dict. ``model_attrs`` must be supplied
        when the latter is passed.
    model_attrs
        Optional explicit ``model_attributes`` sub-dict. Must contain
        ``crop_pos`` (always) and ``dx`` (only when ``units='Ang'``).
    fov
        ``'full'`` returns positions in the full saved-object coordinate
        frame. ``'crop'`` subtracts the **probe-center** scanned-FOV bbox
        top-left (see :func:`get_scanned_fov_bbox`) so that, regardless of
        ``target``, the returned positions live in the same coordinate frame
        as the array returned by ``extract_object(fov='crop')``. Because that
        cropped FOV is anchored on probe centers, ``target='crop'`` can
        produce negative local coordinates near the top/left edge.
    target
        ``'probe'`` (default) returns probe-center positions
        (``crop_pos + (Hp // 2, Wp // 2)``). ``'crop'`` returns the top-left
        crop-window positions used by :meth:`PtychoModel.get_obj_patches`.
        The two differ by exactly ``(Hp // 2, Wp // 2)`` and that offset
        survives ``fov='crop'`` because the same bbox is subtracted in both
        cases.
    units
        ``'px'`` / ``'pixel'`` returns object-space pixels. ``'Ang'``
        multiplies by ``model_attrs['dx']`` to return Ångströms.
    include_sub_px_shifts
        When ``True`` (default), add ``opt_tensors['probe_pos_shifts']`` if
        present. These are the optimized sub-pixel offsets relative to the
        integer ``crop_pos`` grid. Saves without the key silently fall back
        to zero shifts.
    probe_shape
        Optional probe shape override, mirroring :func:`extract_object`. It
        is only consulted when the probe shape is actually needed
        (``target='probe'`` or ``fov='crop'``); when omitted the shape is
        inferred from ``opt_tensors['probe']``.
    as_torch, device
        Convert the result with ``torch.as_tensor`` on ``device`` when
        ``as_torch`` is true.

    Returns
    -------
    np.ndarray or torch.Tensor
        Positions of shape ``(N, 2)`` in ``[y, x]`` order.

    Raises
    ------
    ValueError
        If ``target``, ``fov`` or ``units`` is not one of the accepted
        values; if ``crop_pos`` is unavailable; if ``units='Ang'`` is
        requested without ``dx``; or if the stored shifts do not have the
        same shape as the positions.
    KeyError
        If the probe shape is needed, ``probe_shape`` is not given, and the
        tensors hold no ``'probe'``.
    """
    # Validate every option up front so a typo fails before any work is done.
    if target not in ("probe", "crop"):
        raise ValueError(f"target must be 'probe' or 'crop', got {target!r}")
    if fov not in ("full", "crop"):
        raise ValueError(f"fov must be 'full' or 'crop', got {fov!r}")
    if units not in ("px", "pixel", "Ang"):
        raise ValueError(f"units must be 'px', 'pixel', or 'Ang', got {units!r}")

    opt = _resolve_opt_tensors(data_or_opt)
    attrs = _resolve_model_attrs(data_or_opt, model_attrs)
    if attrs is None or "crop_pos" not in attrs:
        raise ValueError(
            "extract_probe_positions requires model_attrs with 'crop_pos'. "
            "Pass the full data dict or supply model_attrs explicitly."
        )
    crop_pos = np.asarray(attrs["crop_pos"])

    shifts = (
        np.asarray(opt["probe_pos_shifts"], dtype=np.float32)
        if include_sub_px_shifts and "probe_pos_shifts" in opt
        else None
    )

    # The probe shape only matters for probe centers and for the cropped
    # frame. Resolving it lazily lets target='crop', fov='full' work on
    # tensors whose probe has been stripped.
    needs_probe_shape = target == "probe" or fov == "crop"
    shape_2d = _probe_shape_from(opt, probe_shape) if needs_probe_shape else None

    if target == "probe":
        pos = get_probe_center_positions(crop_pos, shape_2d, shifts)
    else:
        pos = _add_shifts(crop_pos.astype(np.float32), shifts)

    if fov == "crop":
        y_min, _, x_min, _ = get_scanned_fov_bbox(crop_pos, shape_2d)
        pos = pos - np.array([y_min, x_min], dtype=np.float32)

    if units == "Ang":
        dx = attrs.get("dx")
        if dx is None:
            raise ValueError("units='Ang' requires 'dx' in model_attrs.")
        pos = pos * float(np.asarray(dx))

    return _to_torch(pos, device) if as_torch else pos


# ---------------------------------------------------------------------------
# Misc extractors
# ---------------------------------------------------------------------------


def extract_loss_curves(data: Mapping[str, Any]) -> dict[str, Any]:
    """Collect the loss-history entries of a loaded data dict.

    The values are looked up with ``data.get`` and passed through untouched;
    nothing is validated or converted here. Fields that are absent map to
    ``None``.

    Returns
    -------
    dict
        Always has exactly these keys, in this order:

        - ``loss_iters`` — expected to be a ``(niter, 2)`` array of
          ``(iter_number, total_loss)`` pairs.
        - ``batch_losses`` — expected to be per-loss batch values from the
          most recent iteration.
        - ``avg_losses`` — expected to be the averaged batch losses.
        - ``niter`` — expected to be the final iteration number.
    """
    wanted = ("loss_iters", "batch_losses", "avg_losses", "niter")
    return {field: data.get(field) for field in wanted}
def extract_provenance(path: str | os.PathLike) -> dict | None:
    """Read the ``provenance_json`` root attribute of an HDF5 file.

    The provenance JSON is stored as a root-level HDF5 attribute rather than
    a dataset, so :func:`ptyrad.io.load.load_ptyrad` does not surface it.
    This helper delegates to
    :func:`ptyrad.io.provenance.load_provenance_from_h5` with ``str(path)``.

    Parameters
    ----------
    path : str or os.PathLike
        Location of the HDF5 file.

    Returns
    -------
    dict or None
        The parsed provenance (with keys like ``'probe'``, ``'obj'``,
        ``'pos'``, ``'tilt'``). Any falsy result from the loader, e.g. a
        missing attribute or a failed JSON parse, is normalized to ``None``.
    """
    from ptyrad.io.provenance import load_provenance_from_h5

    provenance = load_provenance_from_h5(str(path))
    if not provenance:
        return None
    return provenance
def extract_keys(data: Mapping[str, Any], delimiter: str = ".") -> list[str]:
    """List every nested key of a loaded data dict as a flat path string.

    Thin wrapper over :func:`ptyrad.io.hierarchy.list_nested_keys`. Given the
    nested ``dict`` returned by ``load_ptyrad`` it yields entries such as
    ``'optimizable_tensors.probe'`` or ``'model_attributes.crop_pos'``.

    Parameters
    ----------
    data
        Nested mapping to walk.
    delimiter
        String placed between path components (default ``"."``).

    Returns
    -------
    list of str
        Whatever ``list_nested_keys`` returns for ``data``.
    """
    flat_paths = list_nested_keys(data, delimiter=delimiter)
    return flat_paths
# Re-export ``get_nested`` so callers can resolve arbitrary paths without # pulling in the ``io.hierarchy`` module separately. __all__ = [ "apply_fov", "extract_keys", "extract_loss_curves", "extract_object", "extract_object_amplitude", "extract_object_phase", "extract_probe", "extract_probe_positions", "extract_provenance", "get_nested", "get_probe_center_positions", "get_scanned_fov_bbox", ]