@@ -79,7 +79,9 @@ def get_config(self) -> dict:
7979
8080 return serialize (config )
8181
82- def forward (self , data : dict [str , any ], * , stage : str = "inference" , ** kwargs ) -> dict [str , np .ndarray ]:
82+ def forward (
83+ self , data : dict [str , any ], * , stage : str = "inference" , log_det_jac : bool = False , ** kwargs
84+ ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
8385 """Apply the transforms in the forward direction.
8486
8587 Parameters
@@ -88,22 +90,33 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -
8890 The data to be transformed.
8991 stage : str, one of ["training", "validation", "inference"]
9092 The stage the function is called in.
93+ log_det_jac: bool, optional
94+ Whether to return the log determinant of the Jacobian of the transforms.
9195 **kwargs : dict
9296 Additional keyword arguments passed to each transform.
9397
9498 Returns
9599 -------
96- dict
97- The transformed data.
100+ dict | tuple[dict, dict]
101+ The transformed data or tuple of transformed data and log determinant of the Jacobian .
98102 """
99103 data = data .copy ()
104+ if not log_det_jac :
105+ for transform in self .transforms :
106+ data = transform (data , stage = stage , ** kwargs )
107+ return data
100108
109+ log_det_jac = {}
101110 for transform in self .transforms :
102- data = transform (data , stage = stage , ** kwargs )
111+ transformed_data = transform (data , stage = stage , ** kwargs )
112+ log_det_jac = transform .log_det_jac (data , log_det_jac , ** kwargs )
113+ data = transformed_data
103114
104- return data
115+ return data , log_det_jac
105116
106- def inverse (self , data : dict [str , np .ndarray ], * , stage : str = "inference" , ** kwargs ) -> dict [str , any ]:
117+ def inverse (
118+ self , data : dict [str , np .ndarray ], * , stage : str = "inference" , log_det_jac : bool = False , ** kwargs
119+ ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
107120 """Apply the transforms in the inverse direction.
108121
109122 Parameters
@@ -112,24 +125,32 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw
112125 The data to be transformed.
113126 stage : str, one of ["training", "validation", "inference"]
114127 The stage the function is called in.
128+ log_det_jac: bool, optional
129+ Whether to return the log determinant of the Jacobian of the transforms.
115130 **kwargs : dict
116131 Additional keyword arguments passed to each transform.
117132
118133 Returns
119134 -------
120- dict
121- The transformed data.
135+ dict | tuple[dict, dict]
136+ The transformed data or tuple of transformed data and log determinant of the Jacobian .
122137 """
123138 data = data .copy ()
139+ if not log_det_jac :
140+ for transform in reversed (self .transforms ):
141+ data = transform (data , stage = stage , inverse = True , ** kwargs )
142+ return data
124143
144+ log_det_jac = {}
125145 for transform in reversed (self .transforms ):
126146 data = transform (data , stage = stage , inverse = True , ** kwargs )
147+ log_det_jac = transform .log_det_jac (data , log_det_jac , inverse = True , ** kwargs )
127148
128- return data
149+ return data , log_det_jac
129150
130151 def __call__ (
131152 self , data : Mapping [str , any ], * , inverse : bool = False , stage = "inference" , ** kwargs
132- ) -> dict [str , np .ndarray ]:
153+ ) -> dict [str , np .ndarray ] | tuple [ dict [ str , np . ndarray ], dict [ str , np . ndarray ]] :
133154 """Apply the transforms in the given direction.
134155
135156 Parameters
@@ -145,8 +166,8 @@ def __call__(
145166
146167 Returns
147168 -------
148- dict
149- The transformed data.
169+ dict | tuple[dict, dict]
170+ The transformed data or tuple of transformed data and log determinant of the Jacobian .
150171 """
151172 if inverse :
152173 return self .inverse (data , stage = stage , ** kwargs )
0 commit comments