Skip to content

Commit e59b2b2

Browse files
rename sampling_fn argument to sample_fn in GraphicalSimulator.add_node method
1 parent 56f7681 commit e59b2b2

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

bayesflow/experimental/graphical_simulator/example_simulators.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
23
from .graphical_simulator import GraphicalSimulator
34

45

@@ -60,21 +61,21 @@ def meta_fn():
6061
simulator = GraphicalSimulator(meta_fn=meta_fn)
6162
simulator.add_node(
6263
"schools",
63-
sampling_fn=sample_school,
64+
sample_fn=sample_school,
6465
)
6566
simulator.add_node(
6667
"exams",
67-
sampling_fn=sample_exam,
68+
sample_fn=sample_exam,
6869
reps="num_exams",
6970
)
7071
simulator.add_node(
7172
"questions",
72-
sampling_fn=sample_question,
73+
sample_fn=sample_question,
7374
reps="num_questions",
7475
)
7576
simulator.add_node(
7677
"students",
77-
sampling_fn=sample_student,
78+
sample_fn=sample_student,
7879
reps="num_students",
7980
)
8081

@@ -109,8 +110,8 @@ def meta():
109110

110111
simulator = GraphicalSimulator(meta_fn=meta)
111112

112-
simulator.add_node("prior", sampling_fn=prior)
113-
simulator.add_node("likelihood", sampling_fn=likelihood)
113+
simulator.add_node("prior", sample_fn=prior)
114+
simulator.add_node("likelihood", sample_fn=likelihood)
114115

115116
simulator.add_edge("prior", "likelihood")
116117

@@ -140,15 +141,15 @@ def sample_y(local_mean, shared_std):
140141
return {"y": float(y)}
141142

142143
simulator = GraphicalSimulator()
143-
simulator.add_node("hypers", sampling_fn=sample_hypers, reps=1)
144+
simulator.add_node("hypers", sample_fn=sample_hypers, reps=5)
144145

145146
simulator.add_node(
146147
"locals",
147148
sampling_fn=sample_locals,
148149
reps=6,
149150
)
150-
simulator.add_node("shared", sampling_fn=sample_shared, reps=1)
151-
simulator.add_node("y", sampling_fn=sample_y, reps=10)
151+
simulator.add_node("shared", sample_fn=sample_shared, reps=1)
152+
simulator.add_node("y", sample_fn=sample_y, reps=10)
152153

153154
simulator.add_edge("hypers", "locals")
154155
simulator.add_edge("locals", "y")
@@ -184,19 +185,19 @@ def sample_y(level_3_mean, shared_std):
184185
return {"y": y}
185186

186187
simulator = GraphicalSimulator()
187-
simulator.add_node("level1", sampling_fn=sample_level_1)
188+
simulator.add_node("level1", sample_fn=sample_level_1)
188189
simulator.add_node(
189190
"level2",
190-
sampling_fn=sample_level_2,
191+
sample_fn=sample_level_2,
191192
reps=10,
192193
)
193194
simulator.add_node(
194195
"level3",
195-
sampling_fn=sample_level_3,
196+
sample_fn=sample_level_3,
196197
reps=20,
197198
)
198-
simulator.add_node("shared", sampling_fn=sample_shared)
199-
simulator.add_node("y", sampling_fn=sample_y, reps=10)
199+
simulator.add_node("shared", sample_fn=sample_shared)
200+
simulator.add_node("y", sample_fn=sample_y, reps=10)
200201

201202
simulator.add_edge("level1", "level2")
202203
simulator.add_edge("level2", "level3")

bayesflow/experimental/graphical_simulator/graphical_simulator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def __init__(self, meta_fn: Optional[Callable[[], dict[str, Any]]] = None, *args
2929
self.graph = nx.DiGraph()
3030
self.meta_fn = meta_fn
3131

32-
def add_node(self, node: str, sampling_fn: Callable[..., dict[str, Any]], reps: int | str = 1):
33-
self.graph.add_node(node, sampling_fn=sampling_fn, reps=reps)
32+
def add_node(self, node: str, sample_fn: Callable[..., dict[str, Any]], reps: int | str = 1):
33+
self.graph.add_node(node, sample_fn=sample_fn, reps=reps)
3434

3535
def add_edge(self, from_node: str, to_node: str):
3636
self.graph.add_edge(from_node, to_node)
@@ -62,7 +62,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
6262
node_samples = []
6363

6464
parent_nodes = list(self.graph.predecessors(node))
65-
sampling_fn = self.graph.nodes[node]["sampling_fn"]
65+
sampling_fn = self.graph.nodes[node]["sample_fn"]
6666
reps_field = self.graph.nodes[node]["reps"]
6767
reps = reps_field if isinstance(reps_field, int) else meta_dict[reps_field]
6868

0 commit comments

Comments
 (0)