Skip to content

Commit 13112bc

Browse files
authored
Pass correct training stage in compute_metrics (#534)
* Pass correct training stage in CouplingFlow.compute_metrics * Pass correct training stage in CIF and PointInferenceNetwork
1 parent f1c0c87 commit 13112bc

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

bayesflow/experimental/cif/cif.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _inverse(
9999
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
100100
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
101101

102-
elbo = self.log_prob(x, conditions=conditions)
102+
elbo = self.log_prob(x, conditions=conditions, training=stage == "training")
103103

104104
loss = -keras.ops.mean(elbo)
105105

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def compute_metrics(
183183
) -> dict[str, Tensor]:
184184
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
185185

186-
z, log_density = self(x, conditions=conditions, inverse=False, density=True)
186+
z, log_density = self(x, conditions=conditions, inverse=False, density=True, training=stage == "training")
187187
loss = weighted_mean(-log_density, sample_weight)
188188

189189
return base_metrics | {"loss": loss}

bayesflow/networks/point_inference_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def call(
145145
def compute_metrics(
146146
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
147147
) -> dict[str, Tensor]:
148-
output = self(x, conditions)
148+
output = self(x, conditions, training=stage == "training")
149149

150150
metrics = {}
151151
# calculate negative score as mean over all scores

0 commit comments

Comments
 (0)