Skip to content

Commit 3846705

Browse files
authored
Clear initializers in constant folding pass (#2668)
Clear unused initializers on the fly to prevent memory usage jump due to intermediate folded tensors. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent ee9a6e8 commit 3846705

File tree

3 files changed

+53
-7
lines changed

3 files changed

+53
-7
lines changed

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.5.5
1+
0.5.6

onnxscript/optimizer/_constant_folding.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,10 +1256,20 @@ def replace_node(
12561256

12571257
# Record the names of the values that have contributed to the replacement
12581258
_record_contributing_values(node, replacement)
1259+
1260+
# Obtain the list of non-None inputs to the node before it is cleared by
1261+
# replace_nodes_and_values to check for unused initializers later.
1262+
node_inputs = [v for v in node.inputs if v is not None]
1263+
12591264
ir.convenience.replace_nodes_and_values(
12601265
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
12611266
)
12621267

1268+
if isinstance(root, ir.Graph):
1269+
# The old node should now be detached from the graph
1270+
assert node.graph is None
1271+
_clear_unused_initializers(node_inputs)
1272+
12631273
self._modified = True
12641274

12651275
# TODO: what about new opset_imports?
@@ -1336,6 +1346,19 @@ def _sym_value_can_replace_graph_output(
13361346
return True
13371347

13381348

1349+
def _clear_unused_initializers(values: Sequence[ir.Value]) -> None:
1350+
# Detach all inputs to the node, then check for unused initializers
1351+
for value in values:
1352+
if value is None or not value.is_initializer():
1353+
continue
1354+
1355+
if not value.uses():
1356+
assert value.is_initializer()
1357+
assert value.graph is not None
1358+
assert value.name is not None
1359+
value.graph.initializers.pop(value.name)
1360+
1361+
13391362
@dataclasses.dataclass
13401363
class FoldConstantsResult(ir.passes.PassResult):
13411364
symbolic_value_map: dict[ir.Value, SymbolicValue]

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,20 @@
1414

1515

1616
class FoldConstantsTest(unittest.TestCase):
17-
def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs):
17+
def _fold(
18+
self,
19+
model: ir.Model | str,
20+
onnx_shape_inference: bool = False,
21+
dce: bool = True,
22+
**kwargs,
23+
):
1824
if isinstance(model, str):
1925
model = ir.from_onnx_text(model)
2026
_constant_folding.fold_constants(
2127
model, onnx_shape_inference=onnx_shape_inference, **kwargs
2228
)
23-
optimizer.remove_unused_nodes(model)
29+
if dce:
30+
optimizer.remove_unused_nodes(model)
2431
# Ensure the model is valid after optimization
2532
onnx.checker.check_model(ir.serde.serialize_model(model))
2633
return model
@@ -50,9 +57,16 @@ def test_fold_cast_like(self):
5057
}
5158
"""
5259

53-
optimized = self._fold(model)
54-
self.assertEqual(len(optimized.graph), 1)
60+
optimized = self._fold(model, dce=False)
5561
self.assertIn("four", optimized.graph.initializers)
62+
np.testing.assert_equal(
63+
optimized.graph.initializers["four"].const_value, np.array(4.0)
64+
)
65+
# Intermediates should be removed
66+
self.assertNotIn("two_float", optimized.graph.initializers)
67+
68+
optimized = self._fold(model, dce=True)
69+
self.assertEqual(len(optimized.graph), 1)
5670

5771
def test_fold_shape(self):
5872
model = """
@@ -66,9 +80,18 @@ def test_fold_shape(self):
6680
}
6781
"""
6882

69-
optimized = self._fold(model)
70-
self.assertEqual(len(optimized.graph), 1)
83+
optimized = self._fold(model, dce=False)
7184
self.assertIn("four", optimized.graph.initializers)
85+
np.testing.assert_equal(
86+
optimized.graph.initializers["four"].const_value, np.array(4.0)
87+
)
88+
# Intermediates should be removed
89+
self.assertNotIn("two_float", optimized.graph.initializers)
90+
self.assertNotIn("rank", optimized.graph.initializers)
91+
self.assertNotIn("shape", optimized.graph.initializers)
92+
93+
optimized = self._fold(model, dce=True)
94+
self.assertEqual(len(optimized.graph), 1)
7295

7396
def test_fold_shape_slice(self):
7497
model = """

0 commit comments

Comments
 (0)