1+ import inspect
12import itertools
23from collections .abc import Callable
34from typing import Any , Optional
@@ -68,7 +69,8 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
6869 if not parent_nodes :
6970 # root node: generate independent samples
7071 node_samples = [
71- {"__batch_idx" : batch_idx , f"__{ node } _idx" : i } | sampling_fn () for i in range (1 , reps + 1 )
72+ {"__batch_idx" : batch_idx , f"__{ node } _idx" : i } | self ._call_sampling_fn (sampling_fn , {})
73+ for i in range (1 , reps + 1 )
7274 ]
7375 else :
7476 # non-root node: depends on parent samples
@@ -79,9 +81,12 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
7981 index_entries = {k : v for k , v in merged .items () if k .startswith ("__" )}
8082 variable_entries = {k : v for k , v in merged .items () if not k .startswith ("__" )}
8183
84+ sampling_fn_input = variable_entries | meta_dict
8285 node_samples .extend (
8386 [
84- index_entries | {f"__{ node } _idx" : i } | sampling_fn (** variable_entries )
87+ index_entries
88+ | {f"__{ node } _idx" : i }
89+ | self ._call_sampling_fn (sampling_fn , sampling_fn_input )
8590 for i in range (1 , reps + 1 )
8691 ]
8792 )
@@ -92,13 +97,16 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
9297 for node in nx .topological_sort (self .graph ):
9398 output_dict .update (self ._collect_output (samples_by_node [node ]))
9499
100+ output_dict .update (meta_dict )
101+
95102 return output_dict
96103
97104 def _collect_output (self , samples ):
98105 output_dict = {}
99106
100107 index_entries = [k for k in samples .flat [0 ][0 ].keys () if k .startswith ("__" )]
101108 node = index_entries [- 1 ].removeprefix ("__" ).removesuffix ("_idx" )
109+ node_reps = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ])
102110 ancestors = non_root_ancestors (self .graph , node )
103111 variable_names = self ._variable_names (samples )
104112
@@ -108,12 +116,13 @@ def _collect_output(self, samples):
108116
109117 for batch_idx in np .ndindex (samples .shape ):
110118 for sample in samples [batch_idx ]:
111- idx = tuple (
112- [* batch_idx ]
113- + [sample [f"__{ a } _idx" ] - 1 for a in ancestors ]
114- + [sample [f"__{ node } _idx" ] - 1 ] # - 1 for 0-based indexing
115- )
116- output_dict [variable ][idx ] = sample [variable ]
119+ idx = [* batch_idx ]
120+ for ancestor in ancestors :
121+ idx .append (sample [f"__{ ancestor } _idx" ] - 1 )
122+ if not is_root_node (self .graph , node ):
123+ if node_reps != 1 :
124+ idx .append (sample [f"__{ node } _idx" ] - 1 ) # -1 for 0-based indexing
125+ output_dict [variable ][tuple (idx )] = sample [variable ]
117126
118127 return output_dict
119128
@@ -137,14 +146,22 @@ def _output_shape(self, samples, variable):
137146 # add node reps
138147 if not is_root_node (self .graph , node ):
139148 node_reps = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ])
140- output_shape .append (node_reps )
149+ if node_reps != 1 :
150+ output_shape .append (node_reps )
141151
142152 # add variable shape
143153 variable_shape = np .atleast_1d (samples .flat [0 ][0 ][variable ]).shape
144154 output_shape .extend (variable_shape )
145155
146156 return tuple (output_shape )
147157
158+ def _call_sampling_fn (self , sampling_fn , args ):
159+ signature = inspect .signature (sampling_fn )
160+ fn_args = signature .parameters
161+ accepted_args = {k : v for k , v in args .items () if k in fn_args }
162+
163+ return sampling_fn (** accepted_args )
164+
148165
149166def non_root_ancestors (graph , node ):
150167 return [n for n in nx .topological_sort (graph ) if n in nx .ancestors (graph , node ) and not is_root_node (graph , n )]
0 commit comments