Skip to content

Commit a322ff1

Browse files
authored
Adapter keeps track of the transform jacobians (#419)
* minimal working case (.scale) * concatenate * keep, drop, rename * scale, log, sqrt * standardize * constraint transforms * continuous approximator returns log_prob with volume correction * loop for inverse jacobian * inverse for elementwise * inverse for Transforms * raise error with numpy transform (for now) * do not fail if no transform is used * take care of log1p as well * fix filter transforms, boundary condition * add tests for adapter jacobians * document jacobian arg * jacobian -> log_det_jac * add test for inverse concatenation * fix standardize * correct nesting in map_transform
1 parent a7f9162 commit a322ff1

File tree

18 files changed

+312
-28
lines changed

18 files changed

+312
-28
lines changed

bayesflow/adapters/adapter.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ def get_config(self) -> dict:
7979

8080
return serialize(config)
8181

82-
def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]:
82+
def forward(
83+
self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
84+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
8385
"""Apply the transforms in the forward direction.
8486
8587
Parameters
@@ -88,22 +90,33 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -
8890
The data to be transformed.
8991
stage : str, one of ["training", "validation", "inference"]
9092
The stage the function is called in.
93+
log_det_jac: bool, optional
94+
Whether to return the log determinant of the Jacobian of the transforms.
9195
**kwargs : dict
9296
Additional keyword arguments passed to each transform.
9397
9498
Returns
9599
-------
96-
dict
97-
The transformed data.
100+
dict | tuple[dict, dict]
101+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
98102
"""
99103
data = data.copy()
104+
if not log_det_jac:
105+
for transform in self.transforms:
106+
data = transform(data, stage=stage, **kwargs)
107+
return data
100108

109+
log_det_jac = {}
101110
for transform in self.transforms:
102-
data = transform(data, stage=stage, **kwargs)
111+
transformed_data = transform(data, stage=stage, **kwargs)
112+
log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
113+
data = transformed_data
103114

104-
return data
115+
return data, log_det_jac
105116

106-
def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]:
117+
def inverse(
118+
self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
119+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
107120
"""Apply the transforms in the inverse direction.
108121
109122
Parameters
@@ -112,24 +125,32 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw
112125
The data to be transformed.
113126
stage : str, one of ["training", "validation", "inference"]
114127
The stage the function is called in.
128+
log_det_jac: bool, optional
129+
Whether to return the log determinant of the Jacobian of the transforms.
115130
**kwargs : dict
116131
Additional keyword arguments passed to each transform.
117132
118133
Returns
119134
-------
120-
dict
121-
The transformed data.
135+
dict | tuple[dict, dict]
136+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
122137
"""
123138
data = data.copy()
139+
if not log_det_jac:
140+
for transform in reversed(self.transforms):
141+
data = transform(data, stage=stage, inverse=True, **kwargs)
142+
return data
124143

144+
log_det_jac = {}
125145
for transform in reversed(self.transforms):
126146
data = transform(data, stage=stage, inverse=True, **kwargs)
147+
log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs)
127148

128-
return data
149+
return data, log_det_jac
129150

130151
def __call__(
131152
self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
132-
) -> dict[str, np.ndarray]:
153+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
133154
"""Apply the transforms in the given direction.
134155
135156
Parameters
@@ -145,8 +166,8 @@ def __call__(
145166
146167
Returns
147168
-------
148-
dict
149-
The transformed data.
169+
dict | tuple[dict, dict]
170+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
150171
"""
151172
if inverse:
152173
return self.inverse(data, stage=stage, **kwargs)

bayesflow/adapters/transforms/concatenate.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,37 @@ def extra_repr(self) -> str:
115115
result += f", axis={self.axis}"
116116

117117
return result
118+
119+
def log_det_jac(
120+
self,
121+
data: dict[str, np.ndarray],
122+
log_det_jac: dict[str, np.ndarray],
123+
*,
124+
strict: bool = False,
125+
inverse: bool = False,
126+
**kwargs,
127+
) -> dict[str, np.ndarray]:
128+
# copy to avoid side effects
129+
log_det_jac = log_det_jac.copy()
130+
131+
if inverse:
132+
if log_det_jac.get(self.into) is not None:
133+
raise ValueError(
134+
"Cannot obtain an inverse Jacobian of concatenation. "
135+
"Transform your variables before you concatenate."
136+
)
137+
138+
return log_det_jac
139+
140+
required_keys = set(self.keys)
141+
available_keys = set(log_det_jac.keys())
142+
common_keys = available_keys & required_keys
143+
144+
if len(common_keys) == 0:
145+
return log_det_jac
146+
147+
parts = [log_det_jac.pop(key) for key in common_keys]
148+
149+
log_det_jac[self.into] = sum(parts)
150+
151+
return log_det_jac

bayesflow/adapters/transforms/constrain.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ def constrain(x):
8787

8888
def unconstrain(x):
8989
return inverse_sigmoid((x - lower) / (upper - lower))
90+
91+
def ldj(x):
92+
x = (x - lower) / (upper - lower)
93+
return -np.log(x) - np.log1p(-x) - np.log(upper - lower)
94+
9095
case str() as name:
9196
raise ValueError(f"Unsupported method name for double bounded constraint: '{name}'.")
9297
case other:
@@ -101,13 +106,22 @@ def constrain(x):
101106

102107
def unconstrain(x):
103108
return inverse_softplus(x - lower)
109+
110+
def ldj(x):
111+
x = x - lower
112+
return x - np.log(np.exp(x) - 1)
113+
104114
case "exp" | "log":
105115

106116
def constrain(x):
107117
return np.exp(x) + lower
108118

109119
def unconstrain(x):
110120
return np.log(x - lower)
121+
122+
def ldj(x):
123+
return -np.log(x - lower)
124+
111125
case str() as name:
112126
raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
113127
case other:
@@ -122,13 +136,21 @@ def constrain(x):
122136

123137
def unconstrain(x):
124138
return -inverse_softplus(-(x - upper))
139+
140+
def ldj(x):
141+
x = -(x - upper)
142+
return x - np.log(np.exp(x) - 1)
143+
125144
case "exp" | "log":
126145

127146
def constrain(x):
128147
return -np.exp(-x) + upper
129148

130149
def unconstrain(x):
131150
return -np.log(-x + upper)
151+
152+
def ldj(x):
153+
return -np.log(-x + upper)
132154
case str() as name:
133155
raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
134156
case other:
@@ -142,6 +164,7 @@ def unconstrain(x):
142164

143165
self.constrain = constrain
144166
self.unconstrain = unconstrain
167+
self.ldj = ldj
145168

146169
# do this last to avoid serialization issues
147170
match inclusive:
@@ -178,3 +201,9 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
178201
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
179202
# inverse means network space -> data space, so constrain the data
180203
return self.constrain(data)
204+
205+
def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
206+
ldj = self.ldj(data)
207+
if inverse:
208+
ldj = -ldj
209+
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))

bayesflow/adapters/transforms/drop.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
4646

4747
def extra_repr(self) -> str:
4848
return "[" + ", ".join(map(repr, self.keys)) + "]"
49+
50+
def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
51+
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac)

bayesflow/adapters/transforms/elementwise_transform.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
2525

2626
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
2727
raise NotImplementedError
28+
29+
def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray | None:
30+
return None

bayesflow/adapters/transforms/filter_transform.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,35 @@ def _should_transform(self, key: str, value: np.ndarray, inverse: bool = False)
150150
return predicate(key, value, inverse=inverse)
151151

152152
def _apply_transform(self, key: str, value: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
153+
transform = self._get_transform(key)
154+
155+
return transform(value, inverse=inverse, **kwargs)
156+
157+
def _get_transform(self, key: str) -> ElementwiseTransform:
153158
if key not in self.transform_map:
154159
self.transform_map[key] = self.transform_constructor(**self.kwargs)
155160

156-
transform = self.transform_map[key]
161+
return self.transform_map[key]
157162

158-
return transform(value, inverse=inverse, **kwargs)
163+
def log_det_jac(
164+
self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs
165+
):
166+
data = data.copy()
167+
168+
if strict and self.include is not None:
169+
missing_keys = set(self.include) - set(data.keys())
170+
if missing_keys:
171+
raise KeyError(f"Missing keys from include list: {missing_keys!r}")
172+
173+
for key, value in data.items():
174+
if self._should_transform(key, value, inverse=False):
175+
transform = self._get_transform(key)
176+
ldj = transform.log_det_jac(value, **kwargs)
177+
if ldj is None:
178+
continue
179+
elif key in log_det_jac:
180+
log_det_jac[key] += ldj
181+
else:
182+
log_det_jac[key] = ldj
183+
184+
return log_det_jac

bayesflow/adapters/transforms/keep.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
5757

5858
def extra_repr(self) -> str:
5959
return "[" + ", ".join(map(repr, self.keys)) + "]"
60+
61+
def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
62+
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac)

bayesflow/adapters/transforms/log.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,12 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
3737

3838
def get_config(self) -> dict:
3939
return serialize({"p1": self.p1})
40+
41+
def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
42+
if self.p1:
43+
ldj = -np.log1p(data)
44+
else:
45+
ldj = -np.log(data)
46+
if inverse:
47+
ldj = -ldj
48+
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))

bayesflow/adapters/transforms/map_transform.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,8 @@ def get_config(self) -> dict:
4141
def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs) -> dict[str, np.ndarray]:
4242
data = data.copy()
4343

44-
required_keys = set(self.transform_map.keys())
45-
available_keys = set(data.keys())
46-
missing_keys = required_keys - available_keys
47-
48-
if strict and missing_keys:
49-
raise KeyError(f"Missing keys: {missing_keys!r}")
44+
if strict:
45+
self._check_keys(data)
5046

5147
for key, transform in self.transform_map.items():
5248
if key in data:
@@ -57,15 +53,40 @@ def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs)
5753
def inverse(self, data: dict[str, np.ndarray], *, strict: bool = False, **kwargs) -> dict[str, np.ndarray]:
5854
data = data.copy()
5955

60-
required_keys = set(self.transform_map.keys())
61-
available_keys = set(data.keys())
62-
missing_keys = required_keys - available_keys
63-
64-
if strict and missing_keys:
65-
raise KeyError(f"Missing keys: {missing_keys!r}")
56+
if strict:
57+
self._check_keys(data)
6658

6759
for key, transform in self.transform_map.items():
6860
if key in data:
6961
data[key] = transform.inverse(data[key], **kwargs)
7062

7163
return data
64+
65+
def log_det_jac(
66+
self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs
67+
) -> dict[str, np.ndarray]:
68+
data = data.copy()
69+
70+
if strict:
71+
self._check_keys(data)
72+
73+
for key, transform in self.transform_map.items():
74+
if key in data:
75+
ldj = transform.log_det_jac(data[key], **kwargs)
76+
77+
if ldj is None:
78+
continue
79+
elif key in log_det_jac:
80+
log_det_jac[key] += ldj
81+
else:
82+
log_det_jac[key] = ldj
83+
84+
return log_det_jac
85+
86+
def _check_keys(self, data: dict[str, np.ndarray]):
87+
required_keys = set(self.transform_map.keys())
88+
available_keys = set(data.keys())
89+
missing_keys = required_keys - available_keys
90+
91+
if missing_keys:
92+
raise KeyError(f"Missing keys: {missing_keys!r}")

bayesflow/adapters/transforms/numpy_transform.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,6 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
7272

7373
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
7474
return self._inverse(data)
75+
76+
def log_det_jac(self, data, inverse=False, **kwargs):
77+
raise NotImplementedError("log determinant of the Jacobian of the numpy transforms are not implemented yet")

0 commit comments

Comments
 (0)