@@ -36,13 +36,21 @@ def test_variable_batch_size(inference_network, random_samples, random_condition
3636 else :
3737 new_conditions = keras .ops .zeros ((bs ,) + keras .ops .shape (random_conditions )[1 :])
3838
39- inference_network (new_input , conditions = new_conditions )
39+ try :
40+ inference_network (new_input , conditions = new_conditions )
41+ except NotImplementedError :
42+ # network is not invertible
43+ pass
4044 inference_network (new_input , conditions = new_conditions , inverse = True )
4145
4246
4347@pytest .mark .parametrize ("density" , [True , False ])
4448def test_output_structure (density , generative_inference_network , random_samples , random_conditions ):
45- output = generative_inference_network (random_samples , conditions = random_conditions , density = density )
49+ try :
50+ output = generative_inference_network (random_samples , conditions = random_conditions , density = density )
51+ except NotImplementedError :
52+ # network not invertible
53+ return
4654
4755 if density :
4856 assert isinstance (output , tuple )
@@ -57,9 +65,13 @@ def test_output_structure(density, generative_inference_network, random_samples,
5765
5866
5967def test_output_shape (generative_inference_network , random_samples , random_conditions ):
60- forward_output , forward_log_density = generative_inference_network (
61- random_samples , conditions = random_conditions , density = True
62- )
68+ try :
69+ forward_output , forward_log_density = generative_inference_network (
70+ random_samples , conditions = random_conditions , density = True
71+ )
72+ except NotImplementedError :
73+ # network is not invertible, not forward function available
74+ return
6375
6476 assert keras .ops .shape (forward_output ) == keras .ops .shape (random_samples )
6577 assert keras .ops .shape (forward_log_density ) == (keras .ops .shape (random_samples )[0 ],)
@@ -74,9 +86,13 @@ def test_output_shape(generative_inference_network, random_samples, random_condi
7486
7587def test_cycle_consistency (generative_inference_network , random_samples , random_conditions ):
7688 # cycle-consistency means the forward and inverse methods are inverses of each other
77- forward_output , forward_log_density = generative_inference_network (
78- random_samples , conditions = random_conditions , density = True
79- )
89+ try :
90+ forward_output , forward_log_density = generative_inference_network (
91+ random_samples , conditions = random_conditions , density = True
92+ )
93+ except NotImplementedError :
94+ # network is not invertible, cycle consistency cannot be tested.
95+ return
8096 inverse_output , inverse_log_density = generative_inference_network (
8197 forward_output , conditions = random_conditions , density = True , inverse = True
8298 )
@@ -88,7 +104,11 @@ def test_cycle_consistency(generative_inference_network, random_samples, random_
88104def test_density_numerically (generative_inference_network , random_samples , random_conditions ):
89105 from bayesflow .utils import jacobian
90106
91- output , log_density = generative_inference_network (random_samples , conditions = random_conditions , density = True )
107+ try :
108+ output , log_density = generative_inference_network (random_samples , conditions = random_conditions , density = True )
109+ except NotImplementedError :
110+ # network does not support density estimation
111+ return
92112
93113 def f (x ):
94114 return generative_inference_network (x , conditions = random_conditions )
0 commit comments