Skip to content

Commit d1624ee

Browse files
enable sampling_fn with no arguments for non root nodes, change output dimensionality rules for sample method
1 parent 60e589a commit d1624ee

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

bayesflow/experimental/graphical_simulator/graphical_simulator.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import itertools
23
from collections.abc import Callable
34
from typing import Any, Optional
@@ -68,7 +69,8 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
6869
if not parent_nodes:
6970
# root node: generate independent samples
7071
node_samples = [
71-
{"__batch_idx": batch_idx, f"__{node}_idx": i} | sampling_fn() for i in range(1, reps + 1)
72+
{"__batch_idx": batch_idx, f"__{node}_idx": i} | self._call_sampling_fn(sampling_fn, {})
73+
for i in range(1, reps + 1)
7274
]
7375
else:
7476
# non-root node: depends on parent samples
@@ -79,9 +81,12 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
7981
index_entries = {k: v for k, v in merged.items() if k.startswith("__")}
8082
variable_entries = {k: v for k, v in merged.items() if not k.startswith("__")}
8183

84+
sampling_fn_input = variable_entries | meta_dict
8285
node_samples.extend(
8386
[
84-
index_entries | {f"__{node}_idx": i} | sampling_fn(**variable_entries)
87+
index_entries
88+
| {f"__{node}_idx": i}
89+
| self._call_sampling_fn(sampling_fn, sampling_fn_input)
8590
for i in range(1, reps + 1)
8691
]
8792
)
@@ -92,13 +97,16 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
9297
for node in nx.topological_sort(self.graph):
9398
output_dict.update(self._collect_output(samples_by_node[node]))
9499

100+
output_dict.update(meta_dict)
101+
95102
return output_dict
96103

97104
def _collect_output(self, samples):
98105
output_dict = {}
99106

100107
index_entries = [k for k in samples.flat[0][0].keys() if k.startswith("__")]
101108
node = index_entries[-1].removeprefix("__").removesuffix("_idx")
109+
node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0])
102110
ancestors = non_root_ancestors(self.graph, node)
103111
variable_names = self._variable_names(samples)
104112

@@ -108,12 +116,13 @@ def _collect_output(self, samples):
108116

109117
for batch_idx in np.ndindex(samples.shape):
110118
for sample in samples[batch_idx]:
111-
idx = tuple(
112-
[*batch_idx]
113-
+ [sample[f"__{a}_idx"] - 1 for a in ancestors]
114-
+ [sample[f"__{node}_idx"] - 1] # - 1 for 0-based indexing
115-
)
116-
output_dict[variable][idx] = sample[variable]
119+
idx = [*batch_idx]
120+
for ancestor in ancestors:
121+
idx.append(sample[f"__{ancestor}_idx"] - 1)
122+
if not is_root_node(self.graph, node):
123+
if node_reps != 1:
124+
idx.append(sample[f"__{node}_idx"] - 1) # -1 for 0-based indexing
125+
output_dict[variable][tuple(idx)] = sample[variable]
117126

118127
return output_dict
119128

@@ -137,14 +146,22 @@ def _output_shape(self, samples, variable):
137146
# add node reps
138147
if not is_root_node(self.graph, node):
139148
node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0])
140-
output_shape.append(node_reps)
149+
if node_reps != 1:
150+
output_shape.append(node_reps)
141151

142152
# add variable shape
143153
variable_shape = np.atleast_1d(samples.flat[0][0][variable]).shape
144154
output_shape.extend(variable_shape)
145155

146156
return tuple(output_shape)
147157

158+
def _call_sampling_fn(self, sampling_fn, args):
159+
signature = inspect.signature(sampling_fn)
160+
fn_args = signature.parameters
161+
accepted_args = {k: v for k, v in args.items() if k in fn_args}
162+
163+
return sampling_fn(**accepted_args)
164+
148165

149166
def non_root_ancestors(graph, node):
150167
return [n for n in nx.topological_sort(graph) if n in nx.ancestors(graph, node) and not is_root_node(graph, n)]

0 commit comments

Comments
 (0)