Skip to content

Commit b2d94fe

Browse files
authored
Add ort-specific passes to ort_fusion (#2532)
There are specific optimization needs from ort shipping models.
1 parent 2cc2502 commit b2d94fe

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,18 @@ def optimize_for_ort(
140140
)
141141
# Apply the ORT pattern rewrite rules.
142142
rewrite(model, ORT_PATTERN_REWRITE_RULES)
143+
144+
passes = ir.passes.Sequential(
145+
# Apply the ORT optimization passes.
146+
# https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172
147+
common_passes.ClearMetadataAndDocStringPass(),
148+
# https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139
149+
common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1),
150+
common_passes.RemoveInitializersFromInputsPass(),
151+
common_passes.ShapeInferencePass(),
152+
common_passes.CheckerPass(),
153+
)
154+
assert passes.in_place
155+
result = passes(model)
156+
assert result.model is model
143157
return model, fusion_count

0 commit comments

Comments
 (0)