
Commit 3f91e7a

sample method of GraphicalSimulator now returns a dict of appropriately shaped numpy arrays
1 parent 7a85866 commit 3f91e7a
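
In the new output format (see the _output_shape helper added in graphical_simulator.py below), each returned array appears to be laid out as: batch shape, then the reps of every non-root ancestor of the node that produced the variable, then the node's own reps (unless the node is a root), then the variable's own shape (via np.atleast_1d). A rough worked example, with node names and counts invented purely for illustration and assuming a simulator already built with that hierarchy:

    # hypothetical hierarchy: root "school" -> "classroom" (reps=4) -> "student" (reps=20),
    # where the "student" sampling function returns a scalar "score"
    samples = simulator.sample((32,))
    samples["score"].shape  # (32, 4, 20, 1): batch, classroom reps, student reps, atleast_1d scalar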

File tree

2 files changed: +79, -36 lines changed


bayesflow/experimental/graphical_simulator/example_simulators.py

Lines changed: 1 addition & 11 deletions
@@ -1,15 +1,5 @@
 import numpy as np
-from .graphical_simmulator import GraphicalSimulator
-from bayesflow.utils import batched_call
-
-
-def test_batched_call():
-    return batched_call(sample_fn, (10, 2), flatten=True)
-    pass
-
-
-def sample_fn():
-    return {"a": 3, "b": 6}
+from .graphical_simulator import GraphicalSimulator
 
 
 def twolevel_simulator():

bayesflow/experimental/graphical_simulator/graphical_simulator.py

Lines changed: 78 additions & 25 deletions
@@ -1,4 +1,3 @@
-import inspect
 import itertools
 from collections.abc import Callable
 from typing import Any, Optional
@@ -8,6 +7,7 @@
 
 from bayesflow.simulators import Simulator
 from bayesflow.types import Shape
+from bayesflow.utils.decorators import allow_batch_size
 
 
 class GraphicalSimulator(Simulator):
@@ -34,6 +34,7 @@ def add_node(self, node: str, sampling_fn: Callable[..., dict[str, Any]], reps:
     def add_edge(self, from_node: str, to_node: str):
         self.graph.add_edge(from_node, to_node)
 
+    @allow_batch_size
     def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
         """
         Generates samples by topologically traversing the DAG.
@@ -49,10 +50,11 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
         """
         _ = kwargs  # Simulator class requires **kwargs, which are unused here
         meta_dict = self.meta_fn() if self.meta_fn else {}
+        samples_by_node = {}
 
-        # Initialize samples containers for each node
+        # Initialize samples container for each node
         for node in self.graph.nodes:
-            self.graph.nodes[node]["samples"] = np.empty(batch_shape, dtype="object")
+            samples_by_node[node] = np.empty(batch_shape, dtype="object")
 
         for batch_idx in np.ndindex(batch_shape):
            for node in nx.topological_sort(self.graph):
@@ -70,46 +72,97 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
                     ]
                 else:
                     # non-root node: depends on parent samples
-                    parent_samples = [self.graph.nodes[p]["samples"][batch_idx] for p in parent_nodes]
+                    parent_samples = [samples_by_node[p][batch_idx] for p in parent_nodes]
                     merged_dicts = merge_lists_of_dicts(parent_samples)
 
                     for merged in merged_dicts:
-                        index_entries = filter_indices(merged)
-                        variable_entries = filter_variables(merged)
+                        index_entries = {k: v for k, v in merged.items() if k.startswith("__")}
+                        variable_entries = {k: v for k, v in merged.items() if not k.startswith("__")}
 
                         node_samples.extend(
                             [
-                                index_entries | {f"__{node}_idx": i} | call_sampling_fn(sampling_fn, variable_entries)
+                                index_entries | {f"__{node}_idx": i} | sampling_fn(**variable_entries)
                                 for i in range(1, reps + 1)
                             ]
                         )
 
-                self.graph.nodes[node]["samples"][batch_idx] = node_samples
+                samples_by_node[node][batch_idx] = node_samples
 
-        return {"a": np.zeros(3)}
+        output_dict = {}
+        for node in nx.topological_sort(self.graph):
+            output_dict.update(self._collect_output(samples_by_node[node]))
+
+        return output_dict
+
+    def _collect_output(self, samples):
+        output_dict = {}
+
+        index_entries = [k for k in samples.flat[0][0].keys() if k.startswith("__")]
+        node = index_entries[-1].removeprefix("__").removesuffix("_idx")
+        ancestors = non_root_ancestors(self.graph, node)
+        variable_names = self._variable_names(samples)
+
+        for variable in variable_names:
+            output_shape = self._output_shape(samples, variable)
+            output_dict[variable] = np.empty(output_shape)
+
+            for batch_idx in np.ndindex(samples.shape):
+                for sample in samples[batch_idx]:
+                    idx = tuple(
+                        [*batch_idx]
+                        + [sample[f"__{a}_idx"] - 1 for a in ancestors]
+                        + [sample[f"__{node}_idx"] - 1]  # - 1 for 0-based indexing
+                    )
+                    output_dict[variable][idx] = sample[variable]
+
+        return output_dict
+
+    def _variable_names(self, samples):
+        return [k for k in samples.flat[0][0].keys() if not k.startswith("__")]
+
+    def _output_shape(self, samples, variable):
+        index_entries = [k for k in samples.flat[0][0].keys() if k.startswith("__")]
+        node = index_entries[-1].removeprefix("__").removesuffix("_idx")
+
+        # start with batch shape
+        batch_shape = samples.shape
+        output_shape = [*batch_shape]
+        ancestors = non_root_ancestors(self.graph, node)
+
+        # add reps of non root ancestors
+        for ancestor in ancestors:
+            reps = max(s[f"__{ancestor}_idx"] for s in samples.flat[0])
+            output_shape.append(reps)
+
+        # add node reps
+        if not is_root_node(self.graph, node):
+            node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0])
+            output_shape.append(node_reps)
+
+        # add variable shape
+        variable_shape = np.atleast_1d(samples.flat[0][0][variable]).shape
+        output_shape.extend(variable_shape)
+
+        return tuple(output_shape)
+
+
+def non_root_ancestors(graph, node):
+    return [n for n in nx.topological_sort(graph) if n in nx.ancestors(graph, node) and not is_root_node(graph, n)]
+
+
+def is_root_node(graph, node):
+    return len(list(graph.predecessors(node))) == 0
 
 
 def merge_lists_of_dicts(nested_list: list[list[dict]]) -> list[dict]:
     """
     Merges all combinations of dictionaries from a list of lists.
     Equivalent to a Cartesian product of dicts, then flattening.
+
+    Examples:
+        >>> merge_lists_of_dicts([[{"a": 1, "b": 2}], [{"c": 3}, {"d": 4}]])
+        [{'a': 1, 'b': 2, 'c': 3}, {'a': 1, 'b': 2, 'd': 4}]
     """
 
     all_combinations = itertools.product(*nested_list)
     return [{k: v for d in combo for k, v in d.items()} for combo in all_combinations]
-
-
-def call_sampling_fn(sampling_fn: Callable, inputs: dict) -> dict[str, Any]:
-    num_args = len(inspect.signature(sampling_fn).parameters)
-    if num_args == 0:
-        return sampling_fn()
-    else:
-        return sampling_fn(**inputs)
-
-
-def filter_indices(d: dict) -> dict[str, Any]:
-    return {k: v for k, v in d.items() if k.startswith("__")}
-
-
-def filter_variables(d: dict) -> dict[str, Any]:
-    return {k: v for k, v in d.items() if not k.startswith("__")}
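
For context, a minimal usage sketch of the reworked sample method (not part of this commit). The import path is inferred from the file path above; the bare GraphicalSimulator() constructor call, the reps keyword, and a default of one rep for the root node are assumptions based on the signatures visible in the hunk headers; node names and distributions are invented:

    import numpy as np

    from bayesflow.experimental.graphical_simulator.graphical_simulator import GraphicalSimulator

    rng = np.random.default_rng(2024)

    simulator = GraphicalSimulator()

    # root node: draws one shared mean per dataset
    simulator.add_node("hyper", lambda: {"mu": rng.normal()})

    # child node: draws 5 group-level parameter vectors of length 3, conditioned on mu
    simulator.add_node("group", lambda mu: {"theta": rng.normal(loc=mu, size=3)}, reps=5)

    simulator.add_edge("hyper", "group")

    samples = simulator.sample((16,))

    # shapes expected under the new convention, if the assumptions above hold:
    #   samples["mu"].shape    == (16, 1)      batch shape + atleast_1d scalar
    #   samples["theta"].shape == (16, 5, 3)   batch shape + group reps + theta's shape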
