
Commit 7a85866

initial implementation GraphicalSimulator
1 parent c286962 commit 7a85866


3 files changed: +167 -0 lines changed


bayesflow/experimental/graphical_simulator/__init__.py

Whitespace-only changes.
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
import numpy as np
from .graphical_simulator import GraphicalSimulator
from bayesflow.utils import batched_call


def test_batched_call():
    return batched_call(sample_fn, (10, 2), flatten=True)


def sample_fn():
    return {"a": 3, "b": 6}


def twolevel_simulator():
    def sample_hypers():
        hyper_mean = np.random.normal()
        hyper_std = np.abs(np.random.normal())

        return {"hyper_mean": float(hyper_mean), "hyper_std": float(hyper_std)}

    def sample_locals(hyper_mean, hyper_std):
        local_mean = np.random.normal(hyper_mean, hyper_std)

        return {"local_mean": float(local_mean)}

    def sample_shared():
        shared_std = np.abs(np.random.normal())

        return {"shared_std": shared_std}

    def sample_y(local_mean, shared_std):
        y = np.random.normal(local_mean, shared_std)

        return {"y": float(y)}

    simulator = GraphicalSimulator()
    simulator.add_node("hypers", sampling_fn=sample_hypers, reps=1)

    simulator.add_node(
        "locals",
        sampling_fn=sample_locals,
        reps=6,
    )
    simulator.add_node("shared", sampling_fn=sample_shared, reps=1)
    simulator.add_node("y", sampling_fn=sample_y, reps=10)

    simulator.add_edge("hypers", "locals")
    simulator.add_edge("locals", "y")
    simulator.add_edge("shared", "y")

    return simulator
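
For orientation, the simulator built by `twolevel_simulator` can be exercised directly. A minimal usage sketch (note that `sample` in this initial implementation still returns a fixed placeholder dictionary, so the generated values live on the graph nodes rather than in the return value):

simulator = twolevel_simulator()
simulator.sample((2,))  # traverses the DAG for two batch elements

# Each node now holds an object array of shape (2,); each element is a list of
# sample dicts carrying "__*_idx" bookkeeping keys plus the sampled variables.
locals_samples = simulator.graph.nodes["locals"]["samples"]
print(len(locals_samples[0]))  # 6 local draws per batch element (reps=6)
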
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
import inspect
import itertools
from collections.abc import Callable
from typing import Any, Optional

import networkx as nx
import numpy as np

from bayesflow.simulators import Simulator
from bayesflow.types import Shape


class GraphicalSimulator(Simulator):
    """
    A graph-based simulator that generates samples by traversing a DAG
    and calling user-defined sampling functions at each node.

    Parameters
    ----------
    meta_fn : Optional[Callable[[], dict[str, Any]]]
        A callable that returns a dictionary of meta data.
        This meta data can be used to dynamically vary the number of sampling repetitions (`reps`)
        for nodes added via `add_node`.
    """

    def __init__(self, meta_fn: Optional[Callable[[], dict[str, Any]]] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.graph = nx.DiGraph()
        self.meta_fn = meta_fn

    def add_node(self, node: str, sampling_fn: Callable[..., dict[str, Any]], reps: int | str = 1):
        self.graph.add_node(node, sampling_fn=sampling_fn, reps=reps)

    def add_edge(self, from_node: str, to_node: str):
        self.graph.add_edge(from_node, to_node)

    def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
        """
        Generates samples by topologically traversing the DAG.
        For each node, the sampling function is called based on parent values.

        Parameters
        ----------
        batch_shape : Shape
            The shape of the batch to sample. Typically, a tuple indicating the number of samples,
            but an int can also be passed.
        **kwargs
            Unused.
        """
        _ = kwargs  # Simulator class requires **kwargs, which are unused here
        meta_dict = self.meta_fn() if self.meta_fn else {}

        # Initialize samples containers for each node
        for node in self.graph.nodes:
            self.graph.nodes[node]["samples"] = np.empty(batch_shape, dtype="object")

        for batch_idx in np.ndindex(batch_shape):
            for node in nx.topological_sort(self.graph):
                node_samples = []

                parent_nodes = list(self.graph.predecessors(node))
                sampling_fn = self.graph.nodes[node]["sampling_fn"]
                reps_field = self.graph.nodes[node]["reps"]
                reps = reps_field if isinstance(reps_field, int) else meta_dict[reps_field]

                if not parent_nodes:
                    # root node: generate independent samples
                    node_samples = [
                        {"__batch_idx": batch_idx, f"__{node}_idx": i} | sampling_fn() for i in range(1, reps + 1)
                    ]
                else:
                    # non-root node: depends on parent samples
                    parent_samples = [self.graph.nodes[p]["samples"][batch_idx] for p in parent_nodes]
                    merged_dicts = merge_lists_of_dicts(parent_samples)

                    for merged in merged_dicts:
                        index_entries = filter_indices(merged)
                        variable_entries = filter_variables(merged)

                        node_samples.extend(
                            [
                                index_entries | {f"__{node}_idx": i} | call_sampling_fn(sampling_fn, variable_entries)
                                for i in range(1, reps + 1)
                            ]
                        )

                self.graph.nodes[node]["samples"][batch_idx] = node_samples

return {"a": np.zeros(3)}


def merge_lists_of_dicts(nested_list: list[list[dict]]) -> list[dict]:
    """
    Merges all combinations of dictionaries from a list of lists.
    Equivalent to a Cartesian product of dicts, then flattening.
    """

    all_combinations = itertools.product(*nested_list)
    return [{k: v for d in combo for k, v in d.items()} for combo in all_combinations]


def call_sampling_fn(sampling_fn: Callable, inputs: dict) -> dict[str, Any]:
    num_args = len(inspect.signature(sampling_fn).parameters)
    if num_args == 0:
        return sampling_fn()
    else:
        return sampling_fn(**inputs)


def filter_indices(d: dict) -> dict[str, Any]:
    return {k: v for k, v in d.items() if k.startswith("__")}


def filter_variables(d: dict) -> dict[str, Any]:
    return {k: v for k, v in d.items() if not k.startswith("__")}
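
Parent samples are combined with `merge_lists_of_dicts`, which takes the Cartesian product across the parents' sample lists and merges each combination into one dict. A small concrete example of its behavior:

merge_lists_of_dicts([[{"a": 1}, {"a": 2}], [{"b": 3}]])
# -> [{"a": 1, "b": 3}, {"a": 2, "b": 3}]

The `meta_fn` mechanism described in the class docstring (passing a string as `reps` so the repetition count is looked up in the meta dictionary) is not exercised by the test file above. A minimal sketch of how it could be used; the node names and the "num_obs" key are purely illustrative:

# Assumes numpy (np) and GraphicalSimulator are in scope as above.
sim = GraphicalSimulator(meta_fn=lambda: {"num_obs": int(np.random.randint(5, 15))})
sim.add_node("theta", sampling_fn=lambda: {"theta": float(np.random.normal())}, reps=1)
sim.add_node("obs", sampling_fn=lambda theta: {"y": float(np.random.normal(theta))}, reps="num_obs")
sim.add_edge("theta", "obs")
sim.sample((4,))  # the number of "obs" repetitions is taken from meta_fn() on each call
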
