1- import inspect
21import itertools
32from collections .abc import Callable
43from typing import Any , Optional
87
98from bayesflow .simulators import Simulator
109from bayesflow .types import Shape
10+ from bayesflow .utils .decorators import allow_batch_size
1111
1212
1313class GraphicalSimulator (Simulator ):
@@ -34,6 +34,7 @@ def add_node(self, node: str, sampling_fn: Callable[..., dict[str, Any]], reps:
3434 def add_edge (self , from_node : str , to_node : str ):
3535 self .graph .add_edge (from_node , to_node )
3636
37+ @allow_batch_size
3738 def sample (self , batch_shape : Shape , ** kwargs ) -> dict [str , np .ndarray ]:
3839 """
3940 Generates samples by topologically traversing the DAG.
@@ -49,10 +50,11 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
4950 """
5051 _ = kwargs # Simulator class requires **kwargs, which are unused here
5152 meta_dict = self .meta_fn () if self .meta_fn else {}
53+ samples_by_node = {}
5254
53- # Initialize samples containers for each node
55+ # Initialize samples container for each node
5456 for node in self .graph .nodes :
55- self . graph . nodes [node ][ "samples" ] = np .empty (batch_shape , dtype = "object" )
57+ samples_by_node [node ] = np .empty (batch_shape , dtype = "object" )
5658
5759 for batch_idx in np .ndindex (batch_shape ):
5860 for node in nx .topological_sort (self .graph ):
@@ -70,46 +72,97 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
7072 ]
7173 else :
7274 # non-root node: depends on parent samples
73- parent_samples = [self . graph . nodes [ p ][ "samples" ][batch_idx ] for p in parent_nodes ]
75+ parent_samples = [samples_by_node [ p ][batch_idx ] for p in parent_nodes ]
7476 merged_dicts = merge_lists_of_dicts (parent_samples )
7577
7678 for merged in merged_dicts :
77- index_entries = filter_indices ( merged )
78- variable_entries = filter_variables ( merged )
79+ index_entries = { k : v for k , v in merged . items () if k . startswith ( "__" )}
80+ variable_entries = { k : v for k , v in merged . items () if not k . startswith ( "__" )}
7981
8082 node_samples .extend (
8183 [
82- index_entries | {f"__{ node } _idx" : i } | call_sampling_fn ( sampling_fn , variable_entries )
84+ index_entries | {f"__{ node } _idx" : i } | sampling_fn ( ** variable_entries )
8385 for i in range (1 , reps + 1 )
8486 ]
8587 )
8688
87- self . graph . nodes [node ][ "samples" ][batch_idx ] = node_samples
89+ samples_by_node [node ][batch_idx ] = node_samples
8890
89- return {"a" : np .zeros (3 )}
91+ output_dict = {}
92+ for node in nx .topological_sort (self .graph ):
93+ output_dict .update (self ._collect_output (samples_by_node [node ]))
94+
95+ return output_dict
96+
97+ def _collect_output (self , samples ):
98+ output_dict = {}
99+
100+ index_entries = [k for k in samples .flat [0 ][0 ].keys () if k .startswith ("__" )]
101+ node = index_entries [- 1 ].removeprefix ("__" ).removesuffix ("_idx" )
102+ ancestors = non_root_ancestors (self .graph , node )
103+ variable_names = self ._variable_names (samples )
104+
105+ for variable in variable_names :
106+ output_shape = self ._output_shape (samples , variable )
107+ output_dict [variable ] = np .empty (output_shape )
108+
109+ for batch_idx in np .ndindex (samples .shape ):
110+ for sample in samples [batch_idx ]:
111+ idx = tuple (
112+ [* batch_idx ]
113+ + [sample [f"__{ a } _idx" ] - 1 for a in ancestors ]
114+ + [sample [f"__{ node } _idx" ] - 1 ] # - 1 for 0-based indexing
115+ )
116+ output_dict [variable ][idx ] = sample [variable ]
117+
118+ return output_dict
119+
120+ def _variable_names (self , samples ):
121+ return [k for k in samples .flat [0 ][0 ].keys () if not k .startswith ("__" )]
122+
123+ def _output_shape (self , samples , variable ):
124+ index_entries = [k for k in samples .flat [0 ][0 ].keys () if k .startswith ("__" )]
125+ node = index_entries [- 1 ].removeprefix ("__" ).removesuffix ("_idx" )
126+
127+ # start with batch shape
128+ batch_shape = samples .shape
129+ output_shape = [* batch_shape ]
130+ ancestors = non_root_ancestors (self .graph , node )
131+
132+ # add reps of non root ancestors
133+ for ancestor in ancestors :
134+ reps = max (s [f"__{ ancestor } _idx" ] for s in samples .flat [0 ])
135+ output_shape .append (reps )
136+
137+ # add node reps
138+ if not is_root_node (self .graph , node ):
139+ node_reps = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ])
140+ output_shape .append (node_reps )
141+
142+ # add variable shape
143+ variable_shape = np .atleast_1d (samples .flat [0 ][0 ][variable ]).shape
144+ output_shape .extend (variable_shape )
145+
146+ return tuple (output_shape )
147+
148+
149+ def non_root_ancestors (graph , node ):
150+ return [n for n in nx .topological_sort (graph ) if n in nx .ancestors (graph , node ) and not is_root_node (graph , n )]
151+
152+
153+ def is_root_node (graph , node ):
154+ return len (list (graph .predecessors (node ))) == 0
90155
91156
92157def merge_lists_of_dicts (nested_list : list [list [dict ]]) -> list [dict ]:
93158 """
94159 Merges all combinations of dictionaries from a list of lists.
95160 Equivalent to a Cartesian product of dicts, then flattening.
161+
162+ Examples:
163+ >>> merge_lists_of_dicts([[{"a": 1, "b": 2}], [{"c": 3}, {"d": 4}]])
164+ [{'a': 1, 'b': 2, 'c': 3}, {'a': 1, 'b': 2, 'd': 4}]
96165 """
97166
98167 all_combinations = itertools .product (* nested_list )
99168 return [{k : v for d in combo for k , v in d .items ()} for combo in all_combinations ]
100-
101-
102- def call_sampling_fn (sampling_fn : Callable , inputs : dict ) -> dict [str , Any ]:
103- num_args = len (inspect .signature (sampling_fn ).parameters )
104- if num_args == 0 :
105- return sampling_fn ()
106- else :
107- return sampling_fn (** inputs )
108-
109-
110- def filter_indices (d : dict ) -> dict [str , Any ]:
111- return {k : v for k , v in d .items () if k .startswith ("__" )}
112-
113-
114- def filter_variables (d : dict ) -> dict [str , Any ]:
115- return {k : v for k , v in d .items () if not k .startswith ("__" )}
0 commit comments