"""
Aberrations class for aberration coefficients parsing and conversion
"""
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional, Tuple
import numpy as np
# ------------------------------------------------------------
# Specification & Constants
# ------------------------------------------------------------
ABERRATION_SPEC = {
(1, 0): ("C1", 1.0, "Defocus (C10 = -df)"),
(1, 2): ("A1", 1.0, "2-fold astigmatism"),
(2, 1): ("B2", 1/3.0, "Axial coma"),
(2, 3): ("A2", 1.0, "3-fold astigmatism"),
(3, 0): ("C3", 1.0, "Spherical aberration"),
(3, 2): ("S3", 1/4.0, "Axial star aberration"),
(3, 4): ("A3", 1.0, "4-fold astigmatism"),
(4, 1): ("B4", 1/4.0, "Axial coma(4th)"),
(4, 3): ("D4", 1/4.0, "3-lobe aberration"),
(4, 5): ("A4", 1.0, "5-fold astigmatism"),
(5, 0): ("C5", 1.0, "Spherical aberration (5th)"),
(5, 2): ("S5", 1/6.0, "Axial star aberration(5th)"),
(5, 4): ("R5", 1/6.0, "4-lobe aberration"),
(5, 6): ("A5", 1.0, "6-fold astigmatism"),
}
KRIVANEK_TO_HAIDER = {nm: (h, s) for nm, (h, s, _) in ABERRATION_SPEC.items()}
HAIDER_TO_KRIVANEK = {}
for nm, (h, s, _) in ABERRATION_SPEC.items():
HAIDER_TO_KRIVANEK[h] = (1/s, nm)
if nm[1] > 0:
HAIDER_TO_KRIVANEK[f"{h}phi"] = (1.0, nm) # Angle doesn't need scaling
# Aliases: (target_key, scale_factor)
# defocus convention: positive defocus = underfocus = negative C10
ALIASES = {
"defocus": ("C10", -1.0),
"Cs": ("C30", 1.0),
}
# ------------------------------------------------------------
# Internal Structures
# ------------------------------------------------------------
[docs]
@dataclass(frozen=True)
class ParsedKey:
"""The resolved address and context for a user key."""
nm: Tuple[int, int]
param: Literal["magnitude", "angle", "a", "b"]
scale: float
[docs]
@dataclass(frozen=True)
class AberrationValue:
"""Intermediate storage for export."""
nm: Tuple[int, int]
mag: float
angle: float
# ------------------------------------------------------------
# Main Class
# ------------------------------------------------------------
[docs]
class Aberrations:
"""
Handles probe aberrations with support for Krivanek/Haider notations.
Internal state is always Polar Krivanek (Magnitude [Å], Angle [deg]).
Note that for round lens aberrations (m=0),
the internal storage is always (mag, 0.0) for type uniformity.
However, setting or getting "components" from these round lens aberrations
is strictly forbidden for users. For example, 'phi30', 'C10a' are not allowed.
"""
def __init__(self, data: Optional[Dict] = None):
self._data: Dict[Tuple[int, int], Tuple[float, float]] = {}
if data is not None:
self._parse_and_normalize(data)
# ============================================================
# Public API
# ============================================================
def __getitem__(self, key: str) -> float:
pk = self._parse_external_key(key)
if pk.nm not in self._data:
raise KeyError(f"Aberration {key} (Order {pk.nm}) not set.")
mag, ang = self._data[pk.nm]
# Apply Inverse Scale (User Units = Internal / Scale)
scale = 1.0 / pk.scale
if pk.param == "magnitude":
return mag * scale
if pk.param == "angle":
return ang # Angles are not scaled
if pk.param in ["a", "b"]:
a, b = self._polar_to_cartesian(mag, ang, pk.nm[1])
return (a if pk.param == "a" else b) * scale
def __setitem__(self, key: str, value: float):
pk = self._parse_external_key(key)
value = float(value) * pk.scale
# --- Path 1: Round Aberrations (m=0) ---
if pk.nm[1] == 0:
# This angle=0 is only for internal representation
# Setting 'phi10' or 'C10a' will be blocked by _parse_external_key
# Exporting m=0 aberrations would also only have a single value across 'polar', 'complex', 'cartesian'.
self._data[pk.nm] = (value, 0.0)
return
# --- Path 2: Non-Round Aberrations (m > 0) ---
curr_mag, curr_ang = self._data.get(pk.nm, (0.0, 0.0))
if pk.param == "magnitude":
self._data[pk.nm] = (value, curr_ang)
return
if pk.param == "angle":
self._data[pk.nm] = (curr_mag, value)
return
if pk.param in ["a", "b"]:
a, b = self._polar_to_cartesian(curr_mag, curr_ang, pk.nm[1])
if pk.param == "a":
self._data[pk.nm] = self._cartesian_to_polar(value, b, pk.nm[1])
else:
self._data[pk.nm] = self._cartesian_to_polar(a, value, pk.nm[1])
return
[docs]
def get_coefficients(self, style: Literal["polar", "cartesian", "complex"] = "cartesian") -> Dict[Tuple[int, int], Any]:
"""Get raw coefficients for computation (nested dictionary)."""
return self.export(notation="krivanek", style=style, layout="nested", round_decimals=None)
[docs]
def get_haider(self, decimals=3) -> Dict[str, float]:
"""Get flattened dictionary in Haider notation."""
return self.export(notation="haider", style="polar", layout="flat", round_decimals=decimals)
[docs]
def get_krivanek_polar(self, decimals=3) -> Dict[str, float]:
"""Export Krivanek Polar (Flat dictionary)."""
return self.export(notation='krivanek', style='polar', layout='flat', round_decimals=decimals)
[docs]
def get_krivanek_cartesian(self, decimals=3) -> Dict[str, float]:
"""Export Krivanek Cartesian (Flat dictionary)."""
return self.export(notation='krivanek', style='cartesian', layout='flat', round_decimals=decimals)
# ============================================================
# Export Engine
# ============================================================
[docs]
def export(self,
notation: Literal["krivanek", "haider"] = "krivanek",
style: Literal["polar", "cartesian", "complex"] = "polar",
layout: Literal["flat", "nested"] = "flat",
round_decimals: Optional[int] = 3) -> Dict:
values = self._collect_values()
result = {}
for v in values:
# 1. Map & Scale
name, mag, ang, m = self._map_notation(v, notation)
if name is None:
continue # Skip if not in notation
# 2. Format Value
payload = self._format_style(mag, ang, m, style, round_decimals)
# 3. Apply Layout
self._apply_layout(result, v.nm, name, payload, notation, style, layout)
return result
def _collect_values(self):
""" Return a sorted list of aberration values """
return [AberrationValue(nm, mag, ang) for nm, (mag, ang) in sorted(self._data.items())]
def _map_notation(self, v: AberrationValue, notation: str):
""" Convert the names and values between Haider and Krivanek notations """
n, m = v.nm
if notation == "haider":
if v.nm not in KRIVANEK_TO_HAIDER:
return None, None, None, None
name, scale = KRIVANEK_TO_HAIDER[v.nm]
return name, v.mag * scale, v.angle, m
# Default Krivanek
return f"C{n}{m}", v.mag, v.angle, m
def _format_style(self, mag, ang, m, style, decimals):
""" Convert aberration values to 'polar', 'cartesian', and 'complex' format """
if style == "complex":
rad = np.radians(ang)
return mag * np.exp(1j * m * rad)
if style == "cartesian":
a, b = self._polar_to_cartesian(mag, ang, m)
if decimals is not None:
a, b = round(a, decimals), round(b, decimals)
if m == 0:
return a if decimals is None else round(a, decimals)
return {"a": a, "b": b}
# Polar
if decimals is not None:
mag, ang = round(mag, decimals), round(ang, decimals)
if m == 0:
return mag
return {"mag": mag, "phi": ang}
def _apply_layout(self, result, nm, base, payload, notation, style, layout):
""" Construct output dict as 'nested' or 'flat' layout """
if layout == "nested":
result[nm] = payload
return
# Flat Layout
if not isinstance(payload, dict):
# Scalar payload (m=0 or complex)
result[base] = self._to_native(payload)
return
for k, v in payload.items():
# Suffix logic
if style == "cartesian":
key = f"{base}{k}"
else: # polar
if k == "mag":
key = base
else: # phi
key = f"{base}phi" if notation == "haider" else f"phi{nm[0]}{nm[1]}"
result[key] = self._to_native(v)
# ============================================================
# Parsing Logic
# ============================================================
def _parse_and_normalize(self, raw: Dict):
store = defaultdict(dict) # Temporary dict to stage the intermediate items
for key, val in raw.items():
# Tuple Keys (Canonical)
if isinstance(key, tuple):
n, m = key
if not self._is_valid_nm(n, m):
raise ValueError(f"Invalid order {key}")
self._parse_tuple_val(store, key, val)
continue
# String Keys (User Input)
pk = self._parse_external_key(key)
val = float(val) * pk.scale
self._assign(store, pk.nm, pk.param, val)
# Finalize (Store -> Internal State)
for nm, params in store.items():
self._finalize_term(nm, params)
def _finalize_term(self, nm, params):
m = nm[1]
# Conflict Check
has_polar = "magnitude" in params or "angle" in params
has_cart = "a" in params or "b" in params
if has_polar and has_cart:
raise ValueError(f"Conflicting notation for term {nm}")
if m == 0:
mag = params.get("magnitude", params.get("a", 0.0))
if mag != 0:
self._data[nm] = (mag, 0.0)
return
if has_polar:
mag = params.get("magnitude", 0.0)
ang = params.get("angle", 0.0)
if mag != 0:
self._data[nm] = (mag, ang)
else:
a = params.get("a", 0.0)
b = params.get("b", 0.0)
mag, ang = self._cartesian_to_polar(a, b, m)
if mag != 0:
self._data[nm] = (mag, ang)
def _parse_external_key(self, key: str) -> ParsedKey:
scale = 1.0
raw = key
# Resolve Alias/Haider
if raw in ALIASES:
raw, s = ALIASES[raw]
scale *= s
elif raw in HAIDER_TO_KRIVANEK:
s, nm = HAIDER_TO_KRIVANEK[raw]
if raw.endswith("phi"):
# Strict check: Haider angle aliases (if any exist for m=0)
if nm[1] == 0:
raise ValueError(f"Invalid key '{key}': Round aberration {nm} has no angle.")
return ParsedKey(nm, "angle", scale)
return ParsedKey(nm, "magnitude", scale * s)
# Parse Standard Keys
if raw.startswith("phi"):
nm = self._parse_nm(raw[3:])
if nm[1] == 0:
raise ValueError(f"Invalid key '{key}': Round aberration C{nm[0]}0 has no angle.")
return ParsedKey(nm, "angle", scale)
if raw.startswith("C"):
body = raw[1:]
# 1. Identify Component
if body.endswith("a"):
param = "a"
nm_str = body[:-1]
elif body.endswith("b"):
param = "b"
nm_str = body[:-1]
else:
param = "magnitude"
nm_str = body
# 2. Parse Order
nm = self._parse_nm(nm_str)
# 3. STRICT VALIDATION for Round Lenses (m=0)
if nm[1] == 0:
if param in ("a", "b"):
raise ValueError(f"Invalid key '{key}': Round aberration C{nm[0]}0 is a scalar. Use 'C{nm[0]}0', not 'C{nm[0]}0{param}'.")
return ParsedKey(nm, param, scale)
raise KeyError(f"Unknown key: {key}")
def _parse_nm(self, s: str) -> Tuple[int, int]:
if len(s) >= 2 and s.isdigit():
n, m = int(s[0]), int(s[1:])
if self._is_valid_nm(n, m):
return (n, m)
raise KeyError(f"Invalid order string: {s}")
def _parse_tuple_val(self, store, nm, val):
if isinstance(val, complex):
self._assign(store, nm, "a", val.real)
self._assign(store, nm, "b", val.imag)
elif isinstance(val, dict):
for k, v in val.items():
target = None
if k in ("a", "b"):
target = k
elif k in ("mag", "magnitude"):
target = "magnitude"
elif k in ("phi", "angle"):
target = "angle"
if target:
self._assign(store, nm, target, v)
elif isinstance(val, (int, float, np.number)):
if nm[1] != 0:
raise ValueError(f"Ambiguous scalar for non-round term {nm}")
self._assign(store, nm, "magnitude", float(val))
else:
raise TypeError(f"Invalid value type for {nm}: {type(val)}")
def _assign(self, store, nm, param, val):
if param in store[nm]:
raise ValueError(f"Conflicting value for {nm} parameter '{param}'")
store[nm][param] = float(val)
# ============================================================
# Helpers
# ============================================================
def _polar_to_cartesian(self, mag, ang, m):
r = np.radians(ang)
return mag * np.cos(m * r), mag * np.sin(m * r)
def _cartesian_to_polar(self, a, b, m):
mag = np.sqrt(a * a + b * b)
if mag == 0:
return 0.0, 0.0
return mag, np.degrees(np.arctan2(b, a) / m)
@staticmethod
def _is_valid_nm(n: int, m: int) -> bool:
if n < 1 or m < 0 or m > n + 1:
return False
return (m % 2) == (0 if n % 2 == 1 else 1)
@staticmethod
def _to_native(val):
if isinstance(val, np.number):
return val.item()
return val
# ============================================================
# Human Readability
# ============================================================
[docs]
def pretty_print(self) -> str:
"""
Generate a formatted table of current aberrations.
"""
if not self._data:
return "Aberrations(Empty)"
rows = []
max_mag_len = 9 # Start with min width for "Magnitude" header
for (n, m), (mag, ang) in sorted(self._data.items()):
# 1. Krivanek Label
kriv_label = f"C{n},{m}"
# 2. Haider Label (e.g. "3*B2")
haider_label = "-"
desc = ""
if (n, m) in ABERRATION_SPEC:
code, scale, name = ABERRATION_SPEC[(n, m)]
desc = name
# Format: "3*B2" if scale is 1/3, "B2" if scale is 1
if scale == 1.0:
haider_label = code
else:
# Calculate inverse factor (e.g. 1 / 0.3333 = 3.0)
inv_scale = 1.0 / scale
# Check if integer (with tolerance for float math)
if abs(inv_scale - round(inv_scale)) < 1e-5:
factor = int(round(inv_scale))
haider_label = f"{factor}*{code}"
else:
# Fallback for non-integers (e.g. 1.5*C1)
haider_label = f"{inv_scale:.2g}*{code}"
# 3. Format Magnitude (Track max width)
mag_str = f"{mag:.4f}"
max_mag_len = max(max_mag_len, len(mag_str))
# 4. Format Angle
if m == 0:
ang_str = "-"
else:
ang_str = f"{ang:.2f}"
rows.append((kriv_label, haider_label, mag_str, ang_str, desc))
# --- Table Construction ---
# Dynamic padding for Magnitude column
col_mag_width = max_mag_len + 2
# Format: Krivanek | Haider | Mag | Angle | Desc
# Fixed widths for labels, dynamic for values
row_fmt = f"{{:<8}} {{:<10}} {{:<{col_mag_width}}} {{:<10}} {{}}"
header = row_fmt.format("Krivanek", "Haider", "Magnitude", "Angle (°)", "Description")
separator = "-" * (8 + 10 + col_mag_width + 10 + 40) # Rough total length
lines = [header, separator]
for r in rows:
lines.append(row_fmt.format(*r))
return "\n".join(lines)
def __str__(self) -> str:
"""Return the pretty-printed table for print(model)."""
return self.pretty_print()
def __repr__(self) -> str:
"""Concise debug representation."""
return f"<Aberrations: {len(self._data)} terms set>"