Skip to content

Commit d016612

Browse files
committed
allow removing nodes w/o i or o
1 parent 702e4eb commit d016612

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

hls4ml/model/graph.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,6 @@ def remove_node(self, node, rewire=True):
533533
inputs = [inp for inp in node.inputs if inp]
534534
outputs = [outp for outp in node.outputs if outp]
535535

536-
inp_var = node.get_input_variable()
537-
out_var = node.get_output_variable()
538-
539-
assert np.prod(inp_var.shape) == np.prod(out_var.shape), f'Input and output shapes do not match for {node.name}'
540-
541536
if len(inputs) > 1 or len(outputs) > 1:
542537
raise Exception('Cannot delete a node with multiple inputs/outputs')
543538

@@ -547,6 +542,14 @@ def remove_node(self, node, rewire=True):
547542
self.outputs = [inputs[0] if name == node.name else name for name in self.outputs]
548543

549544
if len(outputs) == 1 and len(inputs) == 1:
545+
inp_var = node.get_input_variable()
546+
out_var = node.get_output_variable()
547+
548+
# fmt: off
549+
assert (np.prod(inp_var.shape) == np.prod(out_var.shape)), \
550+
f'Input and output shapes do not match for {node.name}: {inp_var.shape} -> {out_var.shape}'
551+
# fmt: on
552+
550553
next_nodes = [x for x in self.graph.values() if node.outputs[0] in x.inputs]
551554
for next_node in next_nodes:
552555
# Connect inputs -> next

0 commit comments

Comments
 (0)