Skip to content

Commit f6a1708

Browse files
authored
fix unconditional sampling in ContinuousApproximator (#548)
- batch shape was calculated from inference_conditions even if they are known to be None - add approximator test for unconditional setting
1 parent d68c9dd commit f6a1708

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def _sample(
537537
)
538538
batch_shape = keras.ops.shape(inference_conditions)[:-1]
539539
else:
540-
batch_shape = keras.ops.shape(inference_conditions)[1:-1]
540+
batch_shape = (num_samples,)
541541

542542
return self.inference_network.sample(
543543
batch_shape, conditions=inference_conditions, **filter_kwargs(kwargs, self.inference_network.sample)

tests/test_approximators/conftest.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def point_inference_network_with_multiple_parametric_scores():
6868
def point_approximator_with_single_parametric_score(adapter, point_inference_network, summary_network):
6969
from bayesflow import PointApproximator
7070

71+
if "-> 'inference_conditions'" not in str(adapter) and "-> 'summary_conditions'" not in str(adapter):
72+
pytest.skip("point approximator does not support unconditional estimation")
73+
7174
return PointApproximator(
7275
adapter=adapter,
7376
inference_network=point_inference_network,
@@ -81,6 +84,9 @@ def point_approximator_with_multiple_parametric_scores(
8184
):
8285
from bayesflow import PointApproximator
8386

87+
if "-> 'inference_conditions'" not in str(adapter) and "-> 'summary_conditions'" not in str(adapter):
88+
pytest.skip("point approximator does not support unconditional estimation")
89+
8490
return PointApproximator(
8591
adapter=adapter,
8692
inference_network=point_inference_network_with_multiple_parametric_scores,
@@ -128,7 +134,16 @@ def adapter_with_sample_weight():
128134
)
129135

130136

131-
@pytest.fixture(params=["adapter_without_sample_weight", "adapter_with_sample_weight"])
137+
@pytest.fixture()
138+
def adapter_unconditional():
139+
from bayesflow import ContinuousApproximator
140+
141+
return ContinuousApproximator.build_adapter(
142+
inference_variables=["mean", "std"],
143+
)
144+
145+
146+
@pytest.fixture(params=["adapter_unconditional", "adapter_without_sample_weight", "adapter_with_sample_weight"])
132147
def adapter(request):
133148
return request.getfixturevalue(request.param)
134149

0 commit comments

Comments
 (0)