@@ -93,6 +93,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
9393
9494 samples_by_node [node ][batch_idx ] = node_samples
9595
96+ # collect outputs
9697 output_dict = {}
9798 for node in nx .topological_sort (self .graph ):
9899 output_dict .update (self ._collect_output (samples_by_node [node ]))
@@ -104,24 +105,37 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
104105 def _collect_output (self , samples ):
105106 output_dict = {}
106107
108+ # retrieve node and ancestors from internal sample representation
107109 index_entries = [k for k in samples .flat [0 ][0 ].keys () if k .startswith ("__" )]
108110 node = index_entries [- 1 ].removeprefix ("__" ).removesuffix ("_idx" )
109- node_reps = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ])
110- ancestors = non_root_ancestors (self .graph , node )
111+ ancestors = sorted_ancestors (self .graph , node )
112+
113+ # build dict of node repetitions
114+ reps = {}
115+ for ancestor in ancestors :
116+ reps [ancestor ] = max (s [f"__{ ancestor } _idx" ] for s in samples .flat [0 ])
117+ reps [node ] = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ])
118+
111119 variable_names = self ._variable_names (samples )
112120
121+ # collect output for each variable
113122 for variable in variable_names :
114123 output_shape = self ._output_shape (samples , variable )
115124 output_dict [variable ] = np .empty (output_shape )
116125
117126 for batch_idx in np .ndindex (samples .shape ):
118127 for sample in samples [batch_idx ]:
119128 idx = [* batch_idx ]
129+
130+ # add index elements for ancestors
120131 for ancestor in ancestors :
121- idx .append (sample [f"__{ ancestor } _idx" ] - 1 )
122- if not is_root_node (self .graph , node ):
123- if node_reps != 1 :
124- idx .append (sample [f"__{ node } _idx" ] - 1 ) # -1 for 0-based indexing
132+ if reps [ancestor ] != 1 :
133+ idx .append (sample [f"__{ ancestor } _idx" ] - 1 ) # -1 for 0-based indexing
134+
135+ # add index elements for node
136+ if reps [node ] != 1 :
137+ idx .append (sample [f"__{ node } _idx" ] - 1 ) # -1 for 0-based indexing
138+
125139 output_dict [variable ][tuple (idx )] = sample [variable ]
126140
127141 return output_dict
@@ -136,19 +150,19 @@ def _output_shape(self, samples, variable):
136150 # start with batch shape
137151 batch_shape = samples .shape
138152 output_shape = [* batch_shape ]
139- ancestors = non_root_ancestors (self .graph , node )
153+ ancestors = sorted_ancestors (self .graph , node )
140154
141- # add reps of non root ancestors
155+ # add ancestor reps
142156 for ancestor in ancestors :
143- reps = max (s [f"__{ ancestor } _idx" ] for s in samples .flat [0 ])
144- output_shape .append (reps )
145-
146- # add node reps
147- if not is_root_node (self .graph , node ):
148- node_reps = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ])
157+ node_reps = max (s [f"__{ ancestor } _idx" ] for s in samples .flat [0 ])
149158 if node_reps != 1 :
150159 output_shape .append (node_reps )
151160
161+ # add node reps
162+ node_reps = max (s [f"__{ node } _idx" ] for s in samples .flat [0 ])
163+ if node_reps != 1 :
164+ output_shape .append (node_reps )
165+
152166 # add variable shape
153167 variable_shape = np .atleast_1d (samples .flat [0 ][0 ][variable ]).shape
154168 output_shape .extend (variable_shape )
@@ -163,8 +177,8 @@ def _call_sampling_fn(self, sampling_fn, args):
163177 return sampling_fn (** accepted_args )
164178
165179
166- def non_root_ancestors (graph , node ):
167- return [n for n in nx .topological_sort (graph ) if n in nx .ancestors (graph , node ) and not is_root_node ( graph , n ) ]
180+ def sorted_ancestors (graph , node ):
181+ return [n for n in nx .topological_sort (graph ) if n in nx .ancestors (graph , node )]
168182
169183
170184def is_root_node (graph , node ):
0 commit comments