Skip to content

Commit 4cc500e

Browse files
committed
Merge branch 'dev' into fork/vpratz/docs-migration-advice
2 parents 53dda82 + 42fa035 commit 4cc500e

File tree

18 files changed

+292
-14
lines changed

18 files changed

+292
-14
lines changed

.github/workflows/tests.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,11 @@ jobs:
7373
pytest -x -m "not slow"
7474
7575
- name: Run Slow Tests
76-
# run all slow tests only on manual trigger
77-
if: github.event_name == 'workflow_dispatch'
76+
# Run slow tests on manual trigger and pushes/PRs to main.
77+
# Limit to one OS and Python version to save compute.
78+
# Multiline if statements are weird, https://github.com/orgs/community/discussions/25641,
79+
# but feel free to convert it.
80+
if: ${{ ((github.event_name == 'workflow_dispatch') || (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'pull_request' && github.base_ref == 'main')) && ((matrix.os == 'windows-latest') && (matrix.python-version == '3.10')) }}
7881
run: |
7982
pytest -m "slow"
8083

bayesflow/links/ordered.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from keras.saving import register_keras_serializable as serializable
33

44
from bayesflow.utils import layer_kwargs
5+
from bayesflow.utils.decorators import sanitize_input_shape
56

67

78
@serializable(package="links.ordered")
@@ -49,5 +50,6 @@ def call(self, inputs):
4950
x = keras.ops.concatenate([below, anchor_input, above], self.axis)
5051
return x
5152

53+
@sanitize_input_shape
5254
def compute_output_shape(self, input_shape):
5355
return input_shape

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(
7777
The type of transformation used in the coupling layers, such as "affine".
7878
Default is "affine".
7979
permutation : str or None, optional
80-
The type of permutation applied between layers. Can be "random" or None
80+
The type of permutation applied between layers. Can be "orthogonal", "random", "swap", or None
8181
(no permutation). Default is "random".
8282
use_actnorm : bool, optional
8383
Whether to apply ActNorm before each coupling layer. Default is True.

bayesflow/networks/summary_network.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def build(self, input_shape):
2121
if self.base_distribution is not None:
2222
self.base_distribution.build(keras.ops.shape(z))
2323

24+
@sanitize_input_shape
2425
def compute_output_shape(self, input_shape):
2526
return keras.ops.shape(self.call(keras.ops.zeros(input_shape)))
2627

bayesflow/networks/transformers/mab.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from bayesflow.networks import MLP
55
from bayesflow.types import Tensor
66
from bayesflow.utils import layer_kwargs
7+
from bayesflow.utils.decorators import sanitize_input_shape
78
from bayesflow.utils.serialization import serializable
89

910

@@ -122,8 +123,10 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -
122123
return out
123124

124125
# noinspection PyMethodOverriding
126+
@sanitize_input_shape
125127
def build(self, seq_x_shape, seq_y_shape):
126128
self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape))
127129

130+
@sanitize_input_shape
128131
def compute_output_shape(self, seq_x_shape, seq_y_shape):
129132
return keras.ops.shape(self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape)))

bayesflow/networks/transformers/pma.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from bayesflow.networks import MLP
55
from bayesflow.types import Tensor
66
from bayesflow.utils import layer_kwargs
7+
from bayesflow.utils.decorators import sanitize_input_shape
78
from bayesflow.utils.serialization import serializable
89

910
from .mab import MultiHeadAttentionBlock
@@ -125,5 +126,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
125126
summaries = self.mab(seed_tiled, set_x_transformed, training=training, **kwargs)
126127
return ops.reshape(summaries, (ops.shape(summaries)[0], -1))
127128

129+
@sanitize_input_shape
128130
def compute_output_shape(self, input_shape):
129131
return keras.ops.shape(self.call(keras.ops.zeros(input_shape)))

bayesflow/networks/transformers/sab.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import keras
22

33
from bayesflow.types import Tensor
4+
from bayesflow.utils.decorators import sanitize_input_shape
45
from bayesflow.utils.serialization import serializable
56

67
from .mab import MultiHeadAttentionBlock
@@ -16,6 +17,7 @@ class SetAttentionBlock(MultiHeadAttentionBlock):
1617
"""
1718

1819
# noinspection PyMethodOverriding
20+
@sanitize_input_shape
1921
def build(self, input_set_shape):
2022
self.call(keras.ops.zeros(input_set_shape))
2123

@@ -42,5 +44,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
4244
return super().call(input_set, input_set, training=training, **kwargs)
4345

4446
# noinspection PyMethodOverriding
47+
@sanitize_input_shape
4548
def compute_output_shape(self, input_set_shape):
4649
return keras.ops.shape(self.call(keras.ops.zeros(input_set_shape)))

bayesflow/utils/decorators.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def callback(x):
114114

115115

116116
def sanitize_input_shape(fn: Callable):
117-
"""Decorator to replace the first dimension in input_shape with a dummy batch size if it is None"""
117+
"""Decorator to replace the first dimension in ..._shape arguments with a dummy batch size if it is None"""
118118

119119
# The Keras functional API passes input_shape = (None, second_dim, third_dim, ...), which
120120
# causes problems when constructions like self.call(keras.ops.zeros(input_shape)) are used
@@ -126,5 +126,8 @@ def callback(input_shape: Shape) -> Shape:
126126
return tuple(input_shape)
127127
return input_shape
128128

129-
fn = argument_callback("input_shape", callback)(fn)
129+
args = inspect.getfullargspec(fn).args
130+
for arg in args:
131+
if arg.endswith("_shape"):
132+
fn = argument_callback(arg, callback)(fn)
130133
return fn

docsrc/source/about.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
About us
1+
About Us
22
========
33

44
Core maintainers

docsrc/source/conf.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,29 @@
141141
"image_light": "_static/bayesflow_hor.png",
142142
"image_dark": "_static/bayesflow_hor_dark.png",
143143
},
144-
"navbar_center": ["version-switcher", "navbar-nav"],
144+
"icon_links_label": "Icon Links",
145+
"icon_links": [
146+
{
147+
"name": "GitHub",
148+
"url": "https://github.com/bayesflow-org/bayesflow",
149+
"icon": "fa-brands fa-square-github",
150+
"type": "fontawesome",
151+
},
152+
{
153+
"name": "Discourse Forum",
154+
"url": "https://discuss.bayesflow.org/",
155+
"icon": "fa-brands fa-discourse",
156+
"type": "fontawesome",
157+
},
158+
],
159+
"navbar_align": "left",
160+
# -- Template placement in theme layouts ----------------------------------
161+
"navbar_start": ["navbar-logo"],
162+
# Note that the alignment of navbar_center is controlled by navbar_align
163+
"navbar_center": ["navbar-nav"],
164+
"navbar_end": ["theme-switcher", "navbar-icon-links", "version-switcher"],
165+
# navbar_persistent is persistent right (even when on mobiles)
166+
"navbar_persistent": ["search-button"],
145167
"switcher": {
146168
"json_url": "/versions.json",
147169
"version_match": current,

0 commit comments

Comments
 (0)