1- from collections .abc import Callable , MutableSequence , Sequence , Mapping
1+ from collections .abc import Callable , MutableSequence , Sequence
22
33import numpy as np
44
1818 Keep ,
1919 Log ,
2020 MapTransform ,
21- NNPE ,
2221 NumpyTransform ,
2322 OneHot ,
2423 Rename ,
@@ -87,16 +86,14 @@ def get_config(self) -> dict:
8786 return serialize (config )
8887
8988 def forward (
90- self , data : dict [str , any ], * , stage : str = "inference" , log_det_jac : bool = False , ** kwargs
89+ self , data : dict [str , any ], * , log_det_jac : bool = False , ** kwargs
9190 ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
9291 """Apply the transforms in the forward direction.
9392
9493 Parameters
9594 ----------
96- data : dict
95+ data : dict[str, any]
9796 The data to be transformed.
98- stage : str, one of ["training", "validation", "inference"]
99- The stage the function is called in.
10097 log_det_jac: bool, optional
10198 Whether to return the log determinant of the Jacobian of the transforms.
10299 **kwargs : dict
@@ -110,28 +107,26 @@ def forward(
110107 data = data .copy ()
111108 if not log_det_jac :
112109 for transform in self .transforms :
113- data = transform (data , stage = stage , ** kwargs )
110+ data = transform (data , ** kwargs )
114111 return data
115112
116113 log_det_jac = {}
117114 for transform in self .transforms :
118- transformed_data = transform (data , stage = stage , ** kwargs )
115+ transformed_data = transform (data , ** kwargs )
119116 log_det_jac = transform .log_det_jac (data , log_det_jac , ** kwargs )
120117 data = transformed_data
121118
122119 return data , log_det_jac
123120
124121 def inverse (
125- self , data : dict [str , np . ndarray ], * , stage : str = "inference" , log_det_jac : bool = False , ** kwargs
122+ self , data : dict [str , any ], * , log_det_jac : bool = False , ** kwargs
126123 ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
127124 """Apply the transforms in the inverse direction.
128125
129126 Parameters
130127 ----------
131- data : dict
128+ data : dict[str, any]
132129 The data to be transformed.
133- stage : str, one of ["training", "validation", "inference"]
134- The stage the function is called in.
135130 log_det_jac: bool, optional
136131 Whether to return the log determinant of the Jacobian of the transforms.
137132 **kwargs : dict
@@ -145,18 +140,18 @@ def inverse(
145140 data = data .copy ()
146141 if not log_det_jac :
147142 for transform in reversed (self .transforms ):
148- data = transform (data , stage = stage , inverse = True , ** kwargs )
143+ data = transform (data , inverse = True , ** kwargs )
149144 return data
150145
151146 log_det_jac = {}
152147 for transform in reversed (self .transforms ):
153- data = transform (data , stage = stage , inverse = True , ** kwargs )
148+ data = transform (data , inverse = True , ** kwargs )
154149 log_det_jac = transform .log_det_jac (data , log_det_jac , inverse = True , ** kwargs )
155150
156151 return data , log_det_jac
157152
158153 def __call__ (
159- self , data : Mapping [str , any ], * , inverse : bool = False , stage = "inference" , ** kwargs
154+ self , data : dict [str , any ], * , inverse : bool = False , ** kwargs
160155 ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
161156 """Apply the transforms in the given direction.
162157
@@ -166,8 +161,6 @@ def __call__(
166161 The data to be transformed.
167162 inverse : bool, optional
168163 If False, apply the forward transform, else apply the inverse transform (default False).
169- stage : str, one of ["training", "validation", "inference"]
170- The stage the function is called in.
171164 **kwargs
172165 Additional keyword arguments passed to each transform.
173166
@@ -177,9 +170,9 @@ def __call__(
177170 The transformed data or tuple of transformed data and log determinant of the Jacobian.
178171 """
179172 if inverse :
180- return self .inverse (data , stage = stage , ** kwargs )
173+ return self .inverse (data , ** kwargs )
181174
182- return self .forward (data , stage = stage , ** kwargs )
175+ return self .forward (data , ** kwargs )
183176
184177 def __repr__ (self ):
185178 result = ""
@@ -701,43 +694,6 @@ def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
701694 self .transforms .append (transform )
702695 return self
703696
704- def nnpe (
705- self ,
706- keys : str | Sequence [str ],
707- * ,
708- spike_scale : float | None = None ,
709- slab_scale : float | None = None ,
710- per_dimension : bool = True ,
711- seed : int | None = None ,
712- ):
713- """Append an :py:class:`~transforms.NNPE` transform to the adapter.
714-
715- Parameters
716- ----------
717- keys : str or Sequence of str
718- The names of the variables to transform.
719- spike_scale : float or np.ndarray or None, default=None
720- The scale of the spike (Normal) distribution. Automatically determined if None.
721- slab_scale : float or np.ndarray or None, default=None
722- The scale of the slab (Cauchy) distribution. Automatically determined if None.
723- per_dimension : bool, default=True
724- If true, noise is applied per dimension of the last axis of the input data.
725- If false, noise is applied globally.
726- seed : int or None
727- The seed for the random number generator. If None, a random seed is used.
728- """
729- if isinstance (keys , str ):
730- keys = [keys ]
731-
732- transform = MapTransform (
733- {
734- key : NNPE (spike_scale = spike_scale , slab_scale = slab_scale , per_dimension = per_dimension , seed = seed )
735- for key in keys
736- }
737- )
738- self .transforms .append (transform )
739- return self
740-
741697 def one_hot (self , keys : str | Sequence [str ], num_classes : int ):
742698 """Append a :py:class:`~transforms.OneHot` transform to the adapter.
743699
@@ -857,6 +813,8 @@ def standardize(
857813 self ,
858814 include : str | Sequence [str ] = None ,
859815 * ,
816+ mean : int | float | np .ndarray ,
817+ std : int | float | np .ndarray ,
860818 predicate : Predicate = None ,
861819 exclude : str | Sequence [str ] = None ,
862820 ** kwargs ,
@@ -865,10 +823,14 @@ def standardize(
865823
866824 Parameters
867825 ----------
868- predicate : Predicate, optional
869- Function that indicates which variables should be transformed.
870826 include : str or Sequence of str, optional
871827 Names of variables to include in the transform.
828+ mean : int or float
829+ Specifies the mean (location) of the transform.
830+ std : int or float
831+ Specifies the standard deviation (scale) of the transform.
832+ predicate : Predicate, optional
833+ Function that indicates which variables should be transformed.
872834 exclude : str or Sequence of str, optional
873835 Names of variables to exclude from the transform.
874836 **kwargs :
@@ -879,6 +841,8 @@ def standardize(
879841 predicate = predicate ,
880842 include = include ,
881843 exclude = exclude ,
844+ mean = mean ,
845+ std = std ,
882846 ** kwargs ,
883847 )
884848 self .transforms .append (transform )
0 commit comments