1- from collections .abc import Mapping
1+ from collections .abc import Mapping , Callable
22
33import numpy as np
44
@@ -23,8 +23,37 @@ def __init__(
2323 num_samples : int = None ,
2424 * ,
2525 stage : str = "training" ,
26+ augmentations : Mapping [str , Callable ] | Callable = None ,
2627 ** kwargs ,
2728 ):
29+ """
30+ Initialize an OfflineDataset instance for offline training with optional data augmentations.
31+
32+ Parameters
33+ ----------
34+ data : Mapping[str, np.ndarray]
35+ Pre-simulated data stored in a dictionary, where each key maps to a NumPy array.
36+ batch_size : int
37+ Number of samples per batch.
38+ adapter : Adapter or None
39+ Optional adapter to transform the batch.
40+ num_samples : int, optional
41+ Number of samples in the dataset. If None, it will be inferred from the data.
42+ stage : str, default="training"
43+ Current stage (e.g., "training", "validation", etc.) used by the adapter.
44+ augmentations : dict of str to Callable or Callable, optional
45+ Dictionary of augmentation functions to apply to each corresponding key in the batch
46+ or a function to apply to the entire batch (possibly adding new keys).
47+
48+ If you provide a dictionary of functions, each function should accept one element
49+ of your output batch and return the corresponding transformed element. Otherwise,
50+ your function should accept the entire dictionary output and return a dictionary.
51+
52+ Note - augmentations are applied before the adapter is called and are generally
53+ transforms that you only want to apply during training.
54+ **kwargs
55+ Additional keyword arguments passed to the base `PyDataset`.
56+ """
2857 super ().__init__ (** kwargs )
2958 self .batch_size = batch_size
3059 self .data = data
@@ -39,10 +68,29 @@ def __init__(
3968
4069 self .indices = np .arange (self .num_samples , dtype = "int64" )
4170
71+ self .augmentations = augmentations
72+
4273 self .shuffle ()
4374
4475 def __getitem__ (self , item : int ) -> dict [str , np .ndarray ]:
45- """Get a batch of pre-simulated data"""
76+ """
77+ Load a batch of data from disk.
78+
79+ Parameters
80+ ----------
81+ item : int
82+ Index of the batch to retrieve.
83+
84+ Returns
85+ -------
86+ dict of str to np.ndarray
87+ A batch of loaded (and optionally augmented/adapted) data.
88+
89+ Raises
90+ ------
91+ IndexError
92+ If the requested batch index is out of range.
93+ """
4694 if not 0 <= item < self .num_batches :
4795 raise IndexError (f"Index { item } is out of bounds for dataset with { self .num_batches } batches." )
4896
@@ -54,6 +102,16 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
54102 for key , value in self .data .items ()
55103 }
56104
105+ if self .augmentations is None :
106+ pass
107+ elif isinstance (self .augmentations , Mapping ):
108+ for key , fn in self .augmentations .items ():
109+ batch [key ] = fn (batch [key ])
110+ elif isinstance (self .augmentations , Callable ):
111+ batch = self .augmentations (batch )
112+ else :
113+ raise RuntimeError (f"Could not apply augmentations of type { type (self .augmentations )} ." )
114+
57115 if self .adapter is not None :
58116 batch = self .adapter (batch , stage = self .stage )
59117
0 commit comments