Skip to content

Commit d8ac4fd

Browse files
allow root node repetitions
1 parent ae105f6 commit d8ac4fd

File tree

1 file changed

+30
-16
lines changed

1 file changed

+30
-16
lines changed

bayesflow/experimental/graphical_simulator/graphical_simulator.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
9393

9494
samples_by_node[node][batch_idx] = node_samples
9595

96+
# collect outputs
9697
output_dict = {}
9798
for node in nx.topological_sort(self.graph):
9899
output_dict.update(self._collect_output(samples_by_node[node]))
@@ -104,24 +105,37 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
104105
def _collect_output(self, samples):
105106
output_dict = {}
106107

108+
# retrieve node and ancestors from internal sample representation
107109
index_entries = [k for k in samples.flat[0][0].keys() if k.startswith("__")]
108110
node = index_entries[-1].removeprefix("__").removesuffix("_idx")
109-
node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0])
110-
ancestors = non_root_ancestors(self.graph, node)
111+
ancestors = sorted_ancestors(self.graph, node)
112+
113+
# build dict of node repetitions
114+
reps = {}
115+
for ancestor in ancestors:
116+
reps[ancestor] = max(s[f"__{ancestor}_idx"] for s in samples.flat[0])
117+
reps[node] = max(s[f"__{node}_idx"] for s in samples.flat[0])
118+
111119
variable_names = self._variable_names(samples)
112120

121+
# collect output for each variable
113122
for variable in variable_names:
114123
output_shape = self._output_shape(samples, variable)
115124
output_dict[variable] = np.empty(output_shape)
116125

117126
for batch_idx in np.ndindex(samples.shape):
118127
for sample in samples[batch_idx]:
119128
idx = [*batch_idx]
129+
130+
# add index elements for ancestors
120131
for ancestor in ancestors:
121-
idx.append(sample[f"__{ancestor}_idx"] - 1)
122-
if not is_root_node(self.graph, node):
123-
if node_reps != 1:
124-
idx.append(sample[f"__{node}_idx"] - 1) # -1 for 0-based indexing
132+
if reps[ancestor] != 1:
133+
idx.append(sample[f"__{ancestor}_idx"] - 1) # -1 for 0-based indexing
134+
135+
# add index elements for node
136+
if reps[node] != 1:
137+
idx.append(sample[f"__{node}_idx"] - 1) # -1 for 0-based indexing
138+
125139
output_dict[variable][tuple(idx)] = sample[variable]
126140

127141
return output_dict
@@ -136,19 +150,19 @@ def _output_shape(self, samples, variable):
136150
# start with batch shape
137151
batch_shape = samples.shape
138152
output_shape = [*batch_shape]
139-
ancestors = non_root_ancestors(self.graph, node)
153+
ancestors = sorted_ancestors(self.graph, node)
140154

141-
# add reps of non root ancestors
155+
# add ancestor reps
142156
for ancestor in ancestors:
143-
reps = max(s[f"__{ancestor}_idx"] for s in samples.flat[0])
144-
output_shape.append(reps)
145-
146-
# add node reps
147-
if not is_root_node(self.graph, node):
148-
node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0])
157+
node_reps = max(s[f"__{ancestor}_idx"] for s in samples.flat[0])
149158
if node_reps != 1:
150159
output_shape.append(node_reps)
151160

161+
# add node reps
162+
node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0])
163+
if node_reps != 1:
164+
output_shape.append(node_reps)
165+
152166
# add variable shape
153167
variable_shape = np.atleast_1d(samples.flat[0][0][variable]).shape
154168
output_shape.extend(variable_shape)
@@ -163,8 +177,8 @@ def _call_sampling_fn(self, sampling_fn, args):
163177
return sampling_fn(**accepted_args)
164178

165179

166-
def non_root_ancestors(graph, node):
167-
return [n for n in nx.topological_sort(graph) if n in nx.ancestors(graph, node) and not is_root_node(graph, n)]
180+
def sorted_ancestors(graph, node):
181+
return [n for n in nx.topological_sort(graph) if n in nx.ancestors(graph, node)]
168182

169183

170184
def is_root_node(graph, node):

0 commit comments

Comments
 (0)