Source code for ptyrad.io.dataloader

"""
Custom DataLoader class to load batched measurements either from GPU device memory or host RAM.

"""

import torch
import numpy as np
from typing import Optional, Union, List
from torch.utils.data import Dataset

class MeasDataLoader:
    """
    Data loader for PtyRAD experimental measurements with on-the-fly processing.

    Handles indexed slicing of experimental pattern arrays with flexible device
    placement (CPU/GPU), on-demand or pre-loaded options, and optional
    on-the-fly padding/resampling.

    Args:
        meas_arr: np.ndarray of experimental diffraction patterns [N, H, W]
        preload_data: If True, load all data into device memory at initialization.
            If False, load on-demand per access. Default: True
        device: torch.device or str ('cpu', 'cuda', etc.). Default: 'cuda'
        dtype: torch data type for output tensors. Default: torch.float32
        meas_padded: Optional np.ndarray for on-the-fly padding. Padded pattern template.
        meas_padded_idx: Optional tuple (pad_h1, pad_h2, pad_w1, pad_w2) for padding regions.
            Required whenever ``meas_padded`` is given.
        meas_scale_factors: Optional tuple (scale_h, scale_w) for on-the-fly resampling.

    Raises:
        ValueError: If ``meas_padded`` is given without ``meas_padded_idx``.
    """

    def __init__(
        self,
        meas_arr: np.ndarray,
        preload_data: bool = True,
        device: Union[str, torch.device] = 'cuda',
        dtype: torch.dtype = torch.float32,
        meas_padded: Optional[np.ndarray] = None,
        meas_padded_idx: Optional[tuple] = None,
        meas_scale_factors: Optional[tuple] = None,
    ):
        # Fail early: padding cannot be applied without knowing where the
        # measured pattern sits inside the padded template.
        if meas_padded is not None and meas_padded_idx is None:
            raise ValueError("meas_padded_idx must be provided when meas_padded is given")

        self.meas_arr = meas_arr
        self.device = torch.device(device) if isinstance(device, str) else device
        self.dtype = dtype
        self.preload_data = preload_data
        self.N_scans = len(meas_arr)

        # On-the-fly processing parameters
        self.meas_padded = (
            torch.tensor(meas_padded, dtype=torch.float32, device=device)
            if meas_padded is not None
            else None
        )
        self.meas_padded_idx = (
            torch.tensor(meas_padded_idx, dtype=torch.int32, device=device)
            if meas_padded_idx is not None
            else None
        )
        self.meas_scale_factors = meas_scale_factors

        if self.preload_data:
            # Load everything into device memory at init.
            # PyTorch can't create a tensor from a numpy array with negative
            # strides, so a contiguous RAM copy is temporarily needed.
            if not meas_arr.flags['C_CONTIGUOUS']:
                meas_arr = np.ascontiguousarray(meas_arr)
            self.data = torch.from_numpy(meas_arr).to(device=self.device, dtype=self.dtype)
        else:
            # Keep as numpy array, load on-demand
            self.data = meas_arr

    def __len__(self) -> int:
        """Return the total number of diffraction patterns."""
        return self.N_scans

    def __getitem__(self, idx: Union[int, List, np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Get measurement data by index or indices with optional on-the-fly processing.

        Args:
            idx: Single index (int), array of indices (list, np.ndarray), or tensor indices

        Returns:
            Tensor of experimental patterns on the specified device, with optional
            padding/resampling applied. A single integer index yields one 2D
            pattern; any other index yields a batch with a leading axis.
        """
        # Convert tensor indices to numpy on CPU for slicing. This should only
        # happen when using Accelerator and DDP with multiGPU.
        if isinstance(idx, torch.Tensor):
            idx = idx.cpu().numpy()

        if self.preload_data:
            # Data already on device
            measurements = self.data[idx]
        else:
            # Load from numpy and convert to tensor.
            # PyTorch can't create a tensor from a numpy array with negative
            # strides, so a contiguous RAM copy is needed.
            sliced_data = np.asarray(self.data[idx])
            if not sliced_data.flags['C_CONTIGUOUS']:
                sliced_data = np.ascontiguousarray(sliced_data)
            measurements = torch.from_numpy(sliced_data).to(device=self.device, dtype=self.dtype)

        # A scalar index drops the batch axis ([H, W] instead of [n, H, W]).
        # The padding and resampling steps below both read shape[0] as the
        # batch size, so add a temporary batch axis and remove it on return.
        is_single = measurements.ndim == 2
        if is_single:
            measurements = measurements.unsqueeze(0)

        # Apply on-the-fly padding if configured
        if self.meas_padded is not None:
            # One tolist() call instead of unpacking the index tensor into
            # four separate 0-d tensors that are each converted when slicing.
            pad_h1, pad_h2, pad_w1, pad_w2 = self.meas_padded_idx.tolist()
            canvas = torch.zeros(
                (measurements.shape[0], *self.meas_padded.shape[-2:]),
                dtype=self.dtype,
                device=self.device,
            )
            canvas += self.meas_padded
            canvas[..., pad_h1:pad_h2, pad_w1:pad_w2] = measurements
            measurements = canvas

        # Apply on-the-fly resampling if configured
        if self.meas_scale_factors is not None:
            scale_h, scale_w = self.meas_scale_factors
            if scale_h != 1 or scale_w != 1:
                # 2D interpolate requires 4D input (N, C, H, W)
                measurements = torch.nn.functional.interpolate(
                    measurements.unsqueeze(1),
                    scale_factor=(scale_h, scale_w),
                    mode='bilinear',
                ).squeeze(1)
                # Normalize to preserve intensity scale
                measurements = measurements / (scale_h * scale_w)

        return measurements.squeeze(0) if is_single else measurements
class IndicesDataset(Dataset):
    """
    The Dataset class used specifically for the multiGPU mode for DDP.

    Wraps a sequence of indices so each item is simply the stored index at
    the requested position.
    """

    def __init__(self, indices):
        """Store the index sequence that this dataset serves."""
        self.indices = indices

    def __len__(self):
        """Return how many indices are stored."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the stored index (or indices) at position ``idx``."""
        return self.indices[idx]