"""
Defines available options and validation rules for the ``model_params`` dictionary.
"""
from __future__ import annotations
import pathlib
from typing import Any, Dict, Literal, Optional, Union
from pydantic import BaseModel, Field, FilePath, field_validator, model_serializer, model_validator
class OptimizerParams(BaseModel):
    """Selection and configuration of the optimizer.

    Keys other than the fields declared below are rejected (``extra="forbid"``).
    """

    model_config = {"extra": "forbid"}

    name: Literal[
        "Adadelta", "Adafactor", "Adagrad", "Adam",
        "AdamW", "SparseAdam", "Adamax", "ASGD",
        "LBFGS", "Muon", "Nadam", "RAdam",
        "RMSprop", "Rprop", "SGD",
    ] = Field(default="Adam", description="Optimizer name")
    configs: Dict[str, Any] = Field(
        default_factory=dict, description="Optimizer configurations"
    )
    load_state: Optional[FilePath] = Field(
        default=None,
        description="Path str of a PtyRAD model file to load previous optimizer state",
    )

    @model_serializer
    def serialize_model(self):
        """Custom serializer to convert pathlib.Path back to str.

        Returns:
            A shallow copy of the instance's field dict in which a
            ``pathlib.Path`` value of ``load_state`` is replaced by ``str(path)``.
            Every other entry is passed through unchanged.
        """
        data = self.__dict__.copy()
        # ``None`` is not a Path, so this single check also covers the unset case.
        if isinstance(data.get("load_state"), pathlib.Path):
            data["load_state"] = str(data["load_state"])
        return data
class SchedulerParams(BaseModel):
    """Selection and configuration of the learning-rate scheduler.

    ``name`` has no default, so it must be given whenever this model is built.
    Keys other than the fields declared below are rejected (``extra="forbid"``).
    """

    model_config = {"extra": "forbid"}

    name: Literal[
        "LambdaLR", "MultiplicativeLR", "StepLR", "MultiStepLR",
        "ConstantLR", "LinearLR", "ExponentialLR", "PolynomialLR",
        "CosineAnnealingLR", "ReduceLROnPlateau",
        "CyclicLR", "OneCycleLR", "CosineAnnealingWarmRestarts",
    ] = Field(description="Scheduler class name from torch.optim.lr_scheduler")
    configs: Dict[str, Any] = Field(
        default_factory=dict, description="Scheduler configurations"
    )
    load_state: Optional[FilePath] = Field(
        default=None,
        description="Path str of a PtyRAD model file to load previous scheduler state",
    )
    step_unit: Literal["iter", "batch"] = Field(
        default="iter",
        description=(
            "When to call scheduler.step(): 'iter' (once per outer iteration, default) or "
            "'batch' (once per optimizer.step() call, i.e. per grad-accumulation boundary). "
            "Use 'batch' for CyclicLR and OneCycleLR, which are designed to step every optimizer "
            "update. ReduceLROnPlateau ignores this setting and always steps per iteration."
        ),
    )

    @model_serializer
    def serialize_model(self):
        """Custom serializer to convert pathlib.Path back to str.

        Returns:
            A shallow copy of the instance's field dict in which a
            ``pathlib.Path`` value of ``load_state`` is replaced by ``str(path)``.
            Every other entry is passed through unchanged.
        """
        data = self.__dict__.copy()
        # ``None`` is not a Path, so this single check also covers the unset case.
        if isinstance(data.get("load_state"), pathlib.Path):
            data["load_state"] = str(data["load_state"])
        return data
class UpdateParams(BaseModel):
    """Per-variable update settings for the optimizable tensors.

    Every field is a dict that the validators read through the keys
    ``"start_iter"`` and ``"lr"``. Keys other than the fields declared below
    are rejected at the model level (``extra="forbid"``).
    """

    model_config = {"extra": "forbid"}

    obja: Dict[str, Union[int, float, None]] = Field(
        default={"start_iter": 1, "lr": 5.0e-4}, description="Object amplitude update params"
    )
    objp: Dict[str, Union[int, float, None]] = Field(
        default={"start_iter": 1, "lr": 5.0e-4}, description="Object phase update params"
    )
    obj_tilts: Dict[str, Union[int, float, None]] = Field(
        default={"start_iter": None, "lr": 0.0}, description="Object tilts update params"
    )
    slice_thickness: Dict[str, Union[int, float, None]] = Field(
        default={"start_iter": None, "lr": 0.0}, description="Slice thickness update params"
    )
    probe: Dict[str, Union[int, float, None]] = Field(
        default={"start_iter": 1, "lr": 1.0e-4}, description="Probe update params"
    )
    probe_pos_shifts: Dict[str, Union[int, float, None]] = Field(
        default={"start_iter": 1, "lr": 5.0e-4},
        description="Sub-pixel probe position shifts update params",
    )

    @field_validator(
        "obja", "objp", "obj_tilts", "slice_thickness", "probe", "probe_pos_shifts", mode="after"
    )
    @classmethod
    def validate_update_params(cls, v: Dict[str, Any], field) -> Dict[str, Any]:
        """Validate start_iter and lr for update parameters.

        Args:
            v: The already type-validated dict of one update field.
            field: Validation info object; only ``field.field_name`` is read,
                to name the offending field in error messages.

        Returns:
            ``v`` unchanged when every check passes.

        Raises:
            ValueError: If ``start_iter`` is neither ``None`` nor an int >= 1,
                if ``start_iter`` is set while ``lr`` equals 0, or if ``lr`` is
                not a non-negative number.
        """
        start_iter = v.get("start_iter")
        # A missing "lr" key is treated as a learning rate of 0.
        lr = v.get("lr", 0.0)

        # start_iter must be None or >= 1
        start_iter_ok = start_iter is None or (isinstance(start_iter, int) and start_iter >= 1)
        if not start_iter_ok:
            raise ValueError(f"{field.field_name}.start_iter must be None or an integer >= 1")

        # If start_iter is not None, lr must be non-zero
        if start_iter is not None and lr == 0.0:
            raise ValueError(f"{field.field_name}.lr must be non-zero when start_iter is not None")

        # lr must be >= 0 (this also rejects lr=None, which the type annotation allows)
        if not (isinstance(lr, (int, float)) and lr >= 0.0):
            raise ValueError(f"{field.field_name}.lr must be a non-negative number")

        return v

    @model_validator(mode="after")
    def validate_all_start_iter(self):
        """Ensure the start_iter values are neither all None nor all > 1.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If every field has ``start_iter`` of ``None``, or if
                every non-None ``start_iter`` is greater than 1.
        """
        field_names = (
            "obja", "objp", "obj_tilts", "slice_thickness", "probe", "probe_pos_shifts"
        )
        start_iters = [getattr(self, name).get("start_iter") for name in field_names]

        # start_iter can not be all None or all > 1
        if all(si is None for si in start_iters):
            raise ValueError("start_iter values can not be all None")

        # At least one value is non-None past the check above, so this is never vacuous.
        if all(si > 1 for si in start_iters if si is not None):
            raise ValueError(
                "Non-None start_iter values can not be all > 1"
            )  # Early iterations would have no gradients to work with

        return self
class ModelParams(BaseModel):
    """
    Optimizer configurations are specified in 'optimizer_params', see https://pytorch.org/docs/stable/optim.html for detailed information of available optimizers and configs.
    Update behaviors of optimizable variables (tensors) are specified in 'update_params'.
    'start_iter' specifies the iteration at which the variables (tensors) can start being updated by automatic differentiation (AD).
    'lr' specifies the learning rate for the variables (tensors).
    Usually a slower learning rate leads to better convergence/results, but it also updates more slowly.
    The variable optimization has 2 steps: (1) calculate the gradient and (2) apply the update based on learning rate * gradient.
    'start_iter: null' will disable grad calculation and will not update the variable regardless of the learning rate throughout the whole reconstruction.
    'start_iter: N(int)' will only calculate the grad when iteration >= N, so no grad will be calculated when iteration < N.
    Therefore, only the variables with non-zero learning rate will be optimized when iteration > start_iter.
    If you don't want/need to optimize certain parameters, set their start_iter to null AND learning rate to 0 for faster computation.
    Typical learning rate is 1e-3 to 1e-4.
    """

    model_config = {"extra": "forbid"}

    detector_blur_std: Optional[float] = Field(
        default=None,
        ge=0.0,
        description="Gaussian blur std for simulated diffraction patterns. unit: px (k-space)",
    )
    """
    This applies Gaussian blur to the forward model simulated diffraction patterns to emulate the PSF of high-energy electrons on detector for experimental data.
    Typical value is 0-1 px (std) based on the acceleration voltage.
    """

    preload_data: Optional[bool] = Field(
        default=True,
        description="Boolean flag for either to preload data into device memory or not",
    )
    """
    type: bool.
    This flag determines how the measurements data is stored and transferred to device during reconstruction.
    If true, measurement data will be fully loaded into device memory during model initialization for best performance.
    However, a dataset larger than device memory (i.e., GPU VRAM) would throw an Out-Of-Memory error.
    If 'preload_data': false, measurement data is kept on host memory (CPU RAM) and only the mini-batch is transferred to device memory in a streaming way.
    This would enable reconstruction of large datasets that don't fit into GPU VRAM with a little data transfer overhead.
    The default is 'true' for performance although the difference is negligible on demo datasets.
    """

    optimizer_params: OptimizerParams = Field(
        default_factory=OptimizerParams, description="Optimizer configuration"
    )
    """
    Supports all PyTorch optimizers.
    The suggested optimizer is 'Adam' with default configs (null).
    You can load the previous optimizer state by passing the path of `model.hdf5` to `load_state`, this way you can continue a previous reconstruction smoothly without abrupt gradients.
    (Because lots of the optimizers are adaptive and have history-dependent learning rate manipulation, loading the optimizer state is necessary if you want to continue the previous optimization trajectory).
    However, the optimizer state must come from a previous reconstruction with the same set of optimization variables with identical size of the dimensions, otherwise it won't run.
    """

    scheduler_params: Optional[SchedulerParams] = Field(
        default=None, description="LR scheduler configuration"
    )
    """
    Optional learning rate scheduler.
    Set to e.g. {name: CosineAnnealingLR, configs: {T_max: 500}} to decay the learning rate during reconstruction.
    Supports any class from torch.optim.lr_scheduler (https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate)
    except ChainedScheduler and SequentialLR because both of them require other scheduler objects as constructor arguments,
    so you will have to specify them after param loading if you want to use them.
    ReduceLROnPlateau is handled automatically (mean loss is passed).
    Set load_state to a model.hdf5 path to resume mid-schedule.
    Note: not compatible with LBFGS optimizer (scheduler will be ignored with a warning).
    """

    update_params: UpdateParams = Field(
        default_factory=UpdateParams, description="Update parameters for optimizable tensors"
    )
# Explicit export list: autodoc_pydantic can sort by this order when conf.py sets
# `autodoc_pydantic_model_member_order = 'bysource'`, so keep the entries in the
# order they should appear in the docs.
__all__ = ["ModelParams", "OptimizerParams", "SchedulerParams", "UpdateParams"]