12. JIT Compile

12. JIT Compile

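This walkthrough reconstructs the PSO (PrScO3) sample dataset with multislice ptychography while enabling PyTorch JIT compilation through `compiler_configs`, which can give a 1.3-1.9x speedup on supported hardware.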

 1# Created with PtyRAD 0.1.0b13
 2# Documentation: https://ptyrad.readthedocs.io/en/latest/
 3# Detailed description for each option: https://ptyrad.readthedocs.io/en/latest/_autosummary/ptyrad.params.html
 4
 5# Switch on PyTorch JIT compilation on supported hardware for a 1.3-1.9x speedup
 6
 7# PSO
 8
 9init_params:
10    # Experimental params
11    probe_kv               : 300 # [kV] Acceleration voltage
12    probe_conv_angle       : 21.4 # [mrad] Semi-convergence angle for probe-forming aperture
13    probe_aberrations      : {'C10': 200} # [Angstrom, degree] Aberration coefficients in Krivanek polar notation. C10 = -df, and positive C10 refers to overfocus (stronger lens).
14    meas_Npix              : 256 # Number of detector pixels along one side (EMPAD is 128). Only square detectors are supported for simplicity
15    pos_N_scan_slow        : 64 # Number of scan positions along the slow scan direction. Usually it's the vertical direction of the acquisition GUI
16    pos_N_scan_fast        : 64 # Number of scan positions along the fast scan direction. Usually it's the horizontal direction of the acquisition GUI
17    pos_scan_step_size     : 0.410 # [Angstrom] Step size between probe positions in a rectangular raster scan pattern
18    # Model complexity
19    probe_pmode_max        : 4 # Maximum number of mixed probe modes
20    obj_Nlayer             : 21 # Number of slices for multislice object
21    obj_slice_thickness    : 10 # [Angstrom] Slice thickness (propagation distance) for multislice ptychography. Typical values are between 1 and 20 Ang.
22    # Preprocessing
23    meas_permute           : null # Permute meas array with a list of ints to reorder datasets into (N_scans, ky, kx) if needed.
24    meas_reshape           : null # Reshape meas array with a list of 3 ints to convert the 4D diffraction dataset (Ry,Rx,ky,kx) into 3D (N_scans,ky,kx) for PtyRAD.
25    meas_flipT             : null # Flip meas orientation with a list of 3 binary booleans (0 or 1) as [flipud, fliplr, transpose]
26    meas_crop              : [null,null,[68,188],[68,188]] # Crops the 4D dataset with [[scan_slow_start, scan_slow_end], [scan_fast_start, scan_fast_end], [ky_start, ky_end], [kx_start, kx_end]]. 
27    meas_pad               : {'mode': 'on_the_fly', 'padding_type': 'power', 'target_Npix': 256, 'value': 0, 'threshold': 70} # Pads the diffraction pattern to side length = 'target_Npix' and correspondingly changes kMax, dx, and Npix.
28    pos_scan_affine        : null # Affine transformation [scale, asymmetry, rotation, shear] of scan patterns, e.g., [1,0,3,0]; rotation and shear are in units of degrees.
29    # Input source and params
30    meas_params            : {'path': 'data/PSO/sample_data_PrScO3.mat', 'key': 'dp'} # Supports EMPAD .raw, .hdf5, .mat, and .tif
31
32model_params:
33    detector_blur_std   : 1 # [k-space px] Gaussian blur std of forward simulated diffraction patterns. Typical value is 0-1 px.
34
35loss_params:
36    loss_single: {'state': true, 'weight': 1.0, 'dp_pow': 0.5} # Amplitude noise model for typical dataset (dose-sufficient) under the maximum-likelihood formalism
37    loss_sparse: {'state': true, 'weight': 0.1, 'ln_order': 1} # L_n norm sparsity regularization calculated for object phase ('objp')
38
39constraint_params:
40    obj_zblur     : {'start_iter': 1,    'step': 1, 'end_iter': null, 'obj_type': 'both', 'kernel_size': 5, 'std': 1} # Apply a "z-direction" 1D Gaussian blur to the object.
41    mirrored_amp  : {'start_iter': 1,    'step': 1, 'end_iter': null, 'relax': 0.1, 'scale': 0.03, 'power': 4} # Apply a more flexible, ad hoc amplitude constraint of the form 1-scale*phase**power, which provides more free parameters to tune the constrained amplitude based on the phase.
42    obja_thresh   : {'start_iter': 1,    'step': 1, 'end_iter': null, 'relax': 0, 'thresh': [0.96, 1.04]} # Thresholds the object amplitude around 1 to the range specified in 'thresh'.
43    objp_postiv   : {'start_iter': 1,    'step': 1, 'end_iter': null, 'relax': 0} # Apply a positivity constraint of the object phase by clipping negative values
44
45recon_params:
46    NITER: 200 # Total number of reconstruction iterations. 1 iteration means a full pass of all selected diffraction patterns.
47    BATCH_SIZE: {'size': 32, 'grad_accumulation': 1} # Number of diffraction patterns processed simultaneously to get the gradient update.
48    SAVE_ITERS: 10 # Number of completed iterations before saving the current reconstruction results (model, probe, object) and summary figures.
49    output_dir: 'output/walkthrough/12_jit_compile/'
50    recon_dir_affixes: ['minimal', 'model', 'loss', 'constraint'] # Customizable affixes of reconstruction folder name with presets like 'minimal', 'default', 'all'. See docs for 19 more detailed controls.
51    prefix_time: false # Accepts a boolean, a preset string, or a time format string. Set to true to prepend a date string like '20240903_' to the reconstruction folder name
52    prefix: '' # Prefix this string to the reconstruction folder name. Note that a "_" will be automatically generated.
53    postfix: 'compile' # Postfix this string to the reconstruction folder name. Note that a "_" will be automatically generated.
54    
55    ## NOTE: PyTorch provides JIT compilation that can deliver a 1.3-1.9x speedup on supported hardware.
56    ##       Linux and macOS should support the PyTorch JIT compiler out of the box.
57    ##       Windows users should follow the instructions at https://github.com/woct0rdho/triton-windows to install triton-windows.
58    ##       Note that the speedup factor may depend on the task dimensions, including Npix, pmode, Nlayer, and batch size.
59    compiler_configs: {'enable': true} # Set to {'enable': true} to enable PyTorch JIT compilation for a 1.3-1.9x speedup on supported hardware.