Skip to content

Commit 7ea287f

Browse files
authored
More tests (#437)
* fix docs of coupling flow * add additional tests
1 parent 8ac8aa3 commit 7ea287f

File tree

3 files changed

+203
-1
lines changed

3 files changed

+203
-1
lines changed

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(
7777
The type of transformation used in the coupling layers, such as "affine".
7878
Default is "affine".
7979
permutation : str or None, optional
80-
The type of permutation applied between layers. Can be "random" or None
80+
The type of permutation applied between layers. Can be "orthogonal", "random", "swap", or None
8181
(no permutation). Default is "random".
8282
use_actnorm : bool, optional
8383
Whether to apply ActNorm before each coupling layer. Default is True.
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import pytest
2+
import keras
3+
import numpy as np
4+
5+
from bayesflow.networks.coupling_flow.permutations import (
6+
FixedPermutation,
7+
OrthogonalPermutation,
8+
RandomPermutation,
9+
Swap,
10+
)
11+
12+
13+
@pytest.fixture(params=[FixedPermutation, OrthogonalPermutation, RandomPermutation, Swap])
14+
def permutation_class(request):
15+
return request.param
16+
17+
18+
@pytest.fixture
19+
def input_tensor():
20+
return keras.random.normal((2, 5))
21+
22+
23+
def test_fixed_permutation_build_and_call():
24+
# Since FixedPermutation is abstract, create a subclass for testing build.
25+
class TestPerm(FixedPermutation):
26+
def build(self, xz_shape, **kwargs):
27+
length = xz_shape[-1]
28+
self.forward_indices = keras.ops.arange(length - 1, -1, -1)
29+
self.inverse_indices = keras.ops.arange(length - 1, -1, -1)
30+
31+
layer = TestPerm()
32+
input_shape = (2, 4)
33+
layer.build(input_shape)
34+
35+
x = keras.ops.convert_to_tensor(np.arange(8).reshape(input_shape).astype("float32"))
36+
z, log_det = layer(x, inverse=False)
37+
x_inv, log_det_inv = layer(z, inverse=True)
38+
39+
# Check shape preservation
40+
assert z.shape == x.shape
41+
assert x_inv.shape == x.shape
42+
# Forward then inverse recovers input
43+
np.testing.assert_allclose(keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(x), atol=1e-5)
44+
# log_det values should be zero tensors with the correct shape
45+
assert tuple(log_det.shape) == input_shape[:-1]
46+
assert tuple(log_det_inv.shape) == input_shape[:-1]
47+
48+
49+
def test_orthogonal_permutation_build_and_call(input_tensor):
50+
layer = OrthogonalPermutation()
51+
input_shape = keras.ops.shape(input_tensor)
52+
layer.build(input_shape)
53+
54+
z, log_det = layer(input_tensor)
55+
x_inv, log_det_inv = layer(z, inverse=True)
56+
57+
# Check output shapes
58+
assert z.shape == input_tensor.shape
59+
assert x_inv.shape == input_tensor.shape
60+
61+
# Forward + inverse should approximately recover input (allow some numeric tolerance)
62+
np.testing.assert_allclose(
63+
keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(input_tensor), rtol=1e-5, atol=1e-5
64+
)
65+
66+
# log_det should be scalar or batched scalar
67+
if len(log_det.shape) > 0:
68+
assert log_det.shape[0] == input_tensor.shape[0] # batch dim
69+
else:
70+
assert log_det.shape == ()
71+
72+
# log_det_inv should be negative of log_det (det(inv) = 1/det)
73+
log_det_np = keras.ops.convert_to_numpy(log_det)
74+
log_det_inv_np = keras.ops.convert_to_numpy(log_det_inv)
75+
np.testing.assert_allclose(log_det_inv_np, -log_det_np, rtol=1e-5, atol=1e-5)
76+
77+
78+
def test_random_permutation_build_and_call(input_tensor):
79+
layer = RandomPermutation()
80+
input_shape = keras.ops.shape(input_tensor)
81+
layer.build(input_shape)
82+
83+
# Assert forward_indices and inverse_indices are set and consistent
84+
fwd = keras.ops.convert_to_numpy(layer.forward_indices)
85+
inv = keras.ops.convert_to_numpy(layer.inverse_indices)
86+
# Applying inv on fwd must yield ordered indices
87+
reordered = fwd[inv]
88+
np.testing.assert_array_equal(np.arange(len(fwd)), reordered)
89+
90+
z, log_det = layer(input_tensor)
91+
x_inv, log_det_inv = layer(z, inverse=True)
92+
93+
assert z.shape == input_tensor.shape
94+
assert x_inv.shape == input_tensor.shape
95+
np.testing.assert_allclose(keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(input_tensor), atol=1e-5)
96+
assert tuple(log_det.shape) == input_shape[:-1]
97+
assert tuple(log_det_inv.shape) == input_shape[:-1]
98+
99+
100+
def test_swap_build_and_call(input_tensor):
101+
layer = Swap()
102+
input_shape = keras.ops.shape(input_tensor)
103+
layer.build(input_shape)
104+
105+
fwd = keras.ops.convert_to_numpy(layer.forward_indices)
106+
inv = keras.ops.convert_to_numpy(layer.inverse_indices)
107+
reordered = fwd[inv]
108+
np.testing.assert_array_equal(np.arange(len(fwd)), reordered)
109+
110+
z, log_det = layer(input_tensor)
111+
x_inv, log_det_inv = layer(z, inverse=True)
112+
113+
assert z.shape == input_tensor.shape
114+
assert x_inv.shape == input_tensor.shape
115+
np.testing.assert_allclose(keras.ops.convert_to_numpy(x_inv), keras.ops.convert_to_numpy(input_tensor), atol=1e-5)
116+
assert tuple(log_det.shape) == input_shape[:-1]
117+
assert tuple(log_det_inv.shape) == input_shape[:-1]
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pytest
2+
import keras
3+
4+
from bayesflow.networks.embeddings import (
5+
FourierEmbedding,
6+
RecurrentEmbedding,
7+
Time2Vec,
8+
)
9+
10+
11+
def test_fourier_embedding_output_shape_and_type():
12+
embed_dim = 8
13+
batch_size = 4
14+
15+
emb_layer = FourierEmbedding(embed_dim=embed_dim, include_identity=True)
16+
# use keras.ops.zeros with shape (batch_size, 1) and float32 dtype
17+
t = keras.ops.zeros((batch_size, 1), dtype="float32")
18+
19+
emb = emb_layer(t)
20+
# Expected shape is (batch_size, embed_dim + 1) if include_identity else (batch_size, embed_dim)
21+
expected_dim = embed_dim + 1
22+
assert emb.shape[0] == batch_size
23+
assert emb.shape[1] == expected_dim
24+
# Check type - it should be a Keras tensor, convert to numpy for checking
25+
np_emb = keras.ops.convert_to_numpy(emb)
26+
assert np_emb.shape == (batch_size, expected_dim)
27+
28+
29+
def test_fourier_embedding_without_identity():
30+
embed_dim = 8
31+
batch_size = 3
32+
33+
emb_layer = FourierEmbedding(embed_dim=embed_dim, include_identity=False)
34+
t = keras.ops.zeros((batch_size, 1), dtype="float32")
35+
36+
emb = emb_layer(t)
37+
expected_dim = embed_dim
38+
assert emb.shape[0] == batch_size
39+
assert emb.shape[1] == expected_dim
40+
41+
42+
def test_fourier_embedding_raises_for_odd_embed_dim():
43+
with pytest.raises(ValueError):
44+
FourierEmbedding(embed_dim=7)
45+
46+
47+
def test_recurrent_embedding_lstm_and_gru_shapes():
48+
batch_size = 2
49+
seq_len = 5
50+
dim = 3
51+
embed_dim = 6
52+
53+
# Dummy input
54+
x = keras.ops.zeros((batch_size, seq_len, dim), dtype="float32")
55+
56+
# lstm
57+
lstm_layer = RecurrentEmbedding(embed_dim=embed_dim, embedding="lstm")
58+
emb_lstm = lstm_layer(x)
59+
# Check the concatenated shape: last dimension = original dim + embed_dim
60+
assert emb_lstm.shape == (batch_size, seq_len, dim + embed_dim)
61+
62+
# gru
63+
gru_layer = RecurrentEmbedding(embed_dim=embed_dim, embedding="gru")
64+
emb_gru = gru_layer(x)
65+
assert emb_gru.shape == (batch_size, seq_len, dim + embed_dim)
66+
67+
68+
def test_recurrent_embedding_raises_unknown_embedding():
69+
with pytest.raises(ValueError):
70+
RecurrentEmbedding(embed_dim=4, embedding="unknown")
71+
72+
73+
def test_time2vec_shapes_and_output():
74+
batch_size = 3
75+
seq_len = 7
76+
dim = 2
77+
num_periodic_features = 4
78+
79+
x = keras.ops.zeros((batch_size, seq_len, dim), dtype="float32")
80+
time2vec_layer = Time2Vec(num_periodic_features=num_periodic_features)
81+
82+
emb = time2vec_layer(x)
83+
# The last dimension should be dim + num_periodic_features + 1 (trend + periodic)
84+
expected_dim = dim + num_periodic_features + 1
85+
assert emb.shape == (batch_size, seq_len, expected_dim)

0 commit comments

Comments
 (0)