12. JIT Compile
Enables PyTorch’s JIT (just-in-time) compiler via compiler_configs: {enable: true}, which fuses and optimizes GPU kernels at runtime for a measured 1.3–1.9× speedup.
When to use: Production runs on Linux or macOS where the one-time compilation warmup is acceptable and maximum throughput is desired. Particularly beneficial for large reconstructions run over many iterations. If the hardware permits, it’s almost always better to run in JIT mode for significant speedup.
Tradeoffs & limitations: The first epoch incurs a compilation overhead before the speedup takes effect. On Windows, requires the triton-windows package. Speedup follows a complicated scaling law with problem size (Npix, probe modes, slice count, batch sizes) — small problems see less benefit.
1# Created with PtyRAD v1.0.0
2# Documentation: https://ptyrad.readthedocs.io/en/latest/
3# Detailed description for each option: https://ptyrad.readthedocs.io/en/latest/_autosummary/ptyrad.params.html
4
5# Switch on PyTorch JIT compilation on supported hardware for 1.3-1.9x speedup
6
7# PSO
8
9init_params:
10 # Experimental params
11 probe_kv : 300 # [kV] Acceleration voltage
12 probe_conv_angle : 21.4 # [mrad] Semi-convergence angle for probe-forming aperture
13 probe_aberrations : {'C10': 200} # [Angstrom, degree] Aberration coefficients in Krivanek polar notation. C10 = -df, and positive C10 refers to overfocus (stronger lens).
14 meas_Npix : 256 # Detector pixel number, EMPAD is 128. Only supports square detector for simplicity
15 pos_N_scan_slow : 64 # Number of scan positions along the slow scan direction. Usually it's the vertical direction of the acquisition GUI
16 pos_N_scan_fast : 64 # Number of scan positions along the fast scan direction. Usually it's the horizontal direction of the acquisition GUI
17 pos_scan_step_size : 0.410 # [Angstrom] Step size between probe positions in a rectangular raster scan pattern
18 # Model complexity
19 probe_pmode_max : 4 # Maximum number of mixed probe modes
20 obj_Nlayer : 21 # Number of slices for multislice object
21 obj_slice_thickness : 10 # [Angstrom] Slice thickness (propagation distance) for multislice ptychography. Typical values are between 1 and 20 Ang.
22 # Preprocessing
23 meas_permute : null # Permute meas array with a list of ints to reorder datasets into (N_scans, ky, kx) if needed.
24 meas_reshape : null # Reshape meas array with a list of 3 ints to convert the 4D diffraction dataset (Ry,Rx,ky,kx) into 3D (N_scans,ky,kx) for PtyRAD.
25 meas_flipT : null # Flip meas orientation with a list of 3 binary booleans (0 or 1) as [flipud, fliplr, transpose]
26 meas_crop : [null,null,[68,188],[68,188]] # Crops the 4D dataset with [[scan_slow_start, scan_slow_end], [scan_fast_start, scan_fast_end], [ky_start, ky_end], [kx_start, kx_end]].
27 meas_pad : {'mode': 'on_the_fly', 'padding_type': 'power', 'target_Npix': 256, 'value': 0, 'threshold': 70} # Pads the diffraction pattern to side length = 'target_Npix' and correspondingly change the kMax, dx, Npix.
28 pos_scan_affine : null # Affine transformation [scale, asymmetry, rotation, shear] of scan patterns, e.g., [1,0,3,0]. Rotation and shear are in units of degrees.
29 # Input source and params
30 meas_params : {'path': 'data/PSO/sample_data_PrScO3.mat', 'key': 'dp'} # Supports EMPAD .raw, .hdf5, .mat, and .tif
31
32model_params:
33 detector_blur_std : 1 # [k-space px] Gaussian blur std of forward simulated diffraction patterns. Typical value is 0-1 px.
34
35loss_params:
36 loss_single: {'state': true, 'weight': 1.0, 'dp_pow': 0.5} # Amplitude noise model for typical dataset (dose-sufficient) under the maximum-likelihood formalism
37 loss_sparse: {'state': true, 'weight': 0.1, 'ln_order': 1} # L_n norm sparsity regularization calculated for object phase ('objp')
38
39constraint_params:
40 obj_zblur : {'start_iter': 1, 'step': 1, 'end_iter': null, 'obj_type': 'both', 'kernel_size': 5, 'std': 1} # Apply a "z-direction" 1D Gaussian blur to the object.
41 mirrored_amp : {'start_iter': 1, 'step': 1, 'end_iter': null, 'relax': 0.1, 'scale': 0.03, 'power': 4} # Apply a more flexible, ad hoc constraint on the amplitude using 1-scale*phase**power, which provides more arbitrary parameters to tune the constrained amplitude based on the phase.
42 obja_thresh : {'start_iter': 1, 'step': 1, 'end_iter': null, 'relax': 0, 'thresh': [0.96, 1.04]} # Thresholds the object amplitude around 1 with specified range in 'thresh'.
43 objp_postiv : {'start_iter': 1, 'step': 1, 'end_iter': null, 'relax': 0} # Apply a positivity constraint of the object phase by clipping negative values
44
45recon_params:
46 NITER: 200 # Total number of reconstruction iterations. 1 iteration means a full pass of all selected diffraction patterns.
47 BATCH_SIZE: {'size': 32, 'grad_accumulation': 1} # Number of diffraction patterns processed simultaneously to get the gradient update.
48 SAVE_ITERS: 10 # Number of completed iterations before saving the current reconstruction results (model, probe, object) and summary figures.
49 output_dir: 'output/walkthrough/12_jit_compile/'
50 recon_dir_affixes: ['minimal', 'model', 'loss', 'constraint'] # Customizable affixes of reconstruction folder name with presets like 'minimal', 'default', 'all'. See docs for 19 more detailed controls.
51 prefix_time: false # type: boolean, preset strings, and time format strings. Set to true to prepend a date str like '20240903_' in front of the reconstruction folder name
52 prefix: '' # Prefix this string to the reconstruction folder name. Note that a "_" will be automatically generated.
53 postfix: 'compile' # Postfix this string to the reconstruction folder name. Note that a "_" will be automatically generated.
54
55 ## NOTE: PyTorch provides JIT compilation that can deliver a 1.3-1.9x speedup on supported hardware.
56 ## Linux and macOS should support the PyTorch JIT compiler out-of-the-box.
57 ## For Windows users, please follow the instructions and download triton-windows from https://github.com/woct0rdho/triton-windows
58 ## Note that the speedup factor may depend on the task dimensions, including Npix, pmode, Nlayer, and batch sizes.
59 compiler_configs: {'enable': true} # Set to {'enable': true} to enable PyTorch JIT compilation for a 1.3-1.9x speedup on supported hardware.