
Commit 53491ff

[#9023][feat] reduce AD graph optimization time for non-participating passes (#9024)
Shorten AD graph optimization by 30% (measured on Nemotron-6). A bug in the transformation interface marked all passes as not clean, regardless of what each transformation reported. This commit fixes how the optimization passes report the results of their actions: many passes reported the graph as not clean even when they did not participate in the optimization, and each graph-cleaning invocation can take several seconds.
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
1 parent cdde15b commit 53491ff

15 files changed: +86 −44 lines changed

tensorrt_llm/_torch/auto_deploy/transform/interface.py

Lines changed: 2 additions & 2 deletions
@@ -410,14 +410,14 @@ def _apply_per_gm_or_whole_model(
             return self._apply_to_full_model(mod, cm, factory, shared_config)
 
         # just run it on first graph module we are encountering for now...
-        info = TransformInfo()
+        info = None
         for k, graph_sub in named_graphmodules(mod):
             graph_sub, info_apply = self._apply(graph_sub, cm, factory, shared_config)
             if k == "":
                 mod = graph_sub
             else:
                 mod.set_submodule(k, graph_sub)
-            info = info & info_apply
+            info = info & info_apply if info is not None else info_apply
         return mod, info
 
     @final
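The interface change above is the core of the fix: per-submodule results were previously AND-folded into a default-constructed TransformInfo(), so the accumulated flags could never be better than that default. The snippet below is a minimal, self-contained sketch of that failure mode and of the new None-seeded accumulation; FakeTransformInfo and its pessimistic defaults are illustrative assumptions, not the real auto_deploy TransformInfo.

# Illustrative sketch only -- FakeTransformInfo stands in for the real TransformInfo,
# whose actual defaults and __and__ semantics are not shown in this diff.
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTransformInfo:
    skipped: bool = True
    num_matches: int = 0
    is_clean: bool = False          # assumed pessimistic default
    has_valid_shapes: bool = False  # assumed pessimistic default

    def __and__(self, other: "FakeTransformInfo") -> "FakeTransformInfo":
        # Combining two reports keeps the most conservative view of the graph.
        return FakeTransformInfo(
            skipped=self.skipped and other.skipped,
            num_matches=self.num_matches + other.num_matches,
            is_clean=self.is_clean and other.is_clean,
            has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes,
        )


clean_report = FakeTransformInfo(
    skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True
)

# Old accumulation: the default-constructed seed ANDs its pessimistic flags into the
# result, so even a pass that changed nothing looks dirty.
old = FakeTransformInfo() & clean_report
assert old.is_clean is False

# New accumulation: seed with None and take the first real report as-is.
new = None
for report in [clean_report]:
    new = new & report if new is not None else report
assert new.is_clean is True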

tensorrt_llm/_torch/auto_deploy/transform/library/attention.py

Lines changed: 10 additions & 10 deletions
@@ -304,8 +304,8 @@ def register_repeat_kv(patterns: ADPatternMatcherPass):
         info = TransformInfo(
             skipped=False,
             num_matches=num_kv_patterns,
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=num_kv_patterns == 0,
+            has_valid_shapes=num_kv_patterns == 0,
         )
 
         return gm, info
@@ -333,8 +333,8 @@ def register_eager_attention(patterns: ADPatternMatcherPass):
         info = TransformInfo(
             skipped=False,
             num_matches=num_eager_patterns,
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=num_eager_patterns == 0,
+            has_valid_shapes=num_eager_patterns == 0,
         )
 
         return gm, info
@@ -647,8 +647,8 @@ def register_sdpa_to_torch_attention(patterns: ADPatternMatcherPass):
         info = TransformInfo(
             skipped=False,
             num_matches=num_patterns,
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=num_patterns == 0,
+            has_valid_shapes=num_patterns == 0,
         )
         return gm, info
 
@@ -685,8 +685,8 @@ def register_repeat_kv_with_torch_attention(patterns: ADPatternMatcherPass):
         info = TransformInfo(
             skipped=False,
             num_matches=num_patterns,
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=num_patterns == 0,
+            has_valid_shapes=num_patterns == 0,
         )
         return gm, info
 
@@ -870,7 +870,7 @@ def _apply(
         info = TransformInfo(
             skipped=False,
             num_matches=num_matches,
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=num_matches == 0,
+            has_valid_shapes=num_matches == 0,
         )
         return gm, info
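The library passes in this commit (above and below) all adopt the same reporting convention: is_clean and has_valid_shapes are true exactly when the pass found zero matches, i.e. an untouched graph stays clean and keeps valid shape metadata. The hedged sketch below shows how a caller can then skip the multi-second cleanup step; count_matches and run_cleanup are hypothetical stand-ins, not auto_deploy APIs.

# Hypothetical stand-ins (count_matches, run_cleanup are not auto_deploy APIs); the
# reporting pattern in apply_pass mirrors the diffs in this commit.


def count_matches(graph) -> int:
    """Stand-in for a pattern matcher; returns how many rewrites were applied."""
    return 0  # pretend this pass did not participate


def run_cleanup(graph) -> None:
    """Stand-in for graph canonicalization / shape re-propagation (seconds per call)."""


def apply_pass(graph):
    num_matches = count_matches(graph)
    info = {
        "skipped": False,
        "num_matches": num_matches,
        "is_clean": num_matches == 0,          # an untouched graph stays clean
        "has_valid_shapes": num_matches == 0,  # shapes only invalidated by real rewrites
    }
    return graph, info


graph, info = apply_pass(object())
if not info["is_clean"]:
    run_cleanup(graph)  # only paid when the pass actually modified the graph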

tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py

Lines changed: 6 additions & 1 deletion
@@ -52,6 +52,11 @@ def _apply(
             object.__setattr__(vr, "upper", max_total)
 
         # store info object about the transform
-        info = TransformInfo(skipped=False, num_matches=len(vrs))
+        info = TransformInfo(
+            skipped=False,
+            num_matches=len(vrs),
+            is_clean=len(vrs) == 0,
+            has_valid_shapes=len(vrs) == 0,
+        )
 
         return gm, info

tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py

Lines changed: 6 additions & 1 deletion
@@ -51,6 +51,11 @@ def _apply(
             num_matches += 1
 
         # store info object about the transform
-        info = TransformInfo(skipped=False, num_matches=num_matches)
+        info = TransformInfo(
+            skipped=False,
+            num_matches=num_matches,
+            is_clean=num_matches == 0,
+            has_valid_shapes=num_matches == 0,
+        )
 
         return gm, info

tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py

Lines changed: 6 additions & 1 deletion
@@ -48,6 +48,11 @@ def _apply(
             num_matches += 1
 
         # store info object about the transform
-        info = TransformInfo(skipped=False, num_matches=num_matches)
+        info = TransformInfo(
+            skipped=False,
+            num_matches=num_matches,
+            is_clean=num_matches == 0,
+            has_valid_shapes=num_matches == 0,
+        )
 
         return gm, info

tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py

Lines changed: 4 additions & 1 deletion
@@ -112,6 +112,9 @@ def _apply(
         num_matches = patterns.apply(gm.graph)
 
         info = TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False
+            skipped=False,
+            num_matches=num_matches,
+            is_clean=num_matches == 0,
+            has_valid_shapes=num_matches == 0,
         )
         return gm, info

tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py

Lines changed: 2 additions & 2 deletions
@@ -117,8 +117,8 @@ def _apply(
         info = TransformInfo(
             skipped=False,
             num_matches=len(nodes_to_eliminate),
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=len(nodes_to_eliminate) == 0,
+            has_valid_shapes=len(nodes_to_eliminate) == 0,
         )
 
         return gm, info

tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py

Lines changed: 4 additions & 4 deletions
@@ -291,8 +291,8 @@ def _apply(
         info = TransformInfo(
             skipped=(cnt == 0),
             num_matches=cnt,
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=cnt == 0,
+            has_valid_shapes=cnt == 0,
         )
         return gm, info
 
@@ -333,7 +333,7 @@ def _apply(
         info = TransformInfo(
             skipped=(cnt == 0),
             num_matches=cnt,
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=(cnt == 0),
+            has_valid_shapes=(cnt == 0),
         )
         return gm, info

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py

Lines changed: 10 additions & 4 deletions
@@ -510,7 +510,10 @@ def _apply(
             num_moe_patterns += 1
 
         info = TransformInfo(
-            skipped=False, num_matches=num_moe_patterns, is_clean=False, has_valid_shapes=False
+            skipped=False,
+            num_matches=num_moe_patterns,
+            is_clean=num_moe_patterns == 0,
+            has_valid_shapes=num_moe_patterns == 0,
         )
         return gm, info
 
@@ -754,7 +757,10 @@ def _apply(
         fused_key_counter = _insert_fused_moe_ops(gm)
 
         info = TransformInfo(
-            skipped=False, num_matches=fused_key_counter, is_clean=False, has_valid_shapes=False
+            skipped=False,
+            num_matches=fused_key_counter,
+            is_clean=fused_key_counter == 0,
+            has_valid_shapes=fused_key_counter == 0,
         )
         return gm, info
 
@@ -779,7 +785,7 @@ def _apply(
         info = TransformInfo(
             skipped=(fused_key_counter == 0),
             num_matches=fused_key_counter,
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=fused_key_counter == 0,
+            has_valid_shapes=fused_key_counter == 0,
         )
         return gm, info

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py

Lines changed: 8 additions & 2 deletions
@@ -215,7 +215,10 @@ def _apply_fusion_pass(
 
         torch.cuda.empty_cache()
         return gm, TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False
+            skipped=False,
+            num_matches=num_matches,
+            is_clean=num_matches == 0,
+            has_valid_shapes=num_matches == 0,
         )
 
 
@@ -252,7 +255,10 @@ def _apply(
         torch.cuda.empty_cache()
 
         info = TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False
+            skipped=False,
+            num_matches=num_matches,
+            is_clean=num_matches == 0,
+            has_valid_shapes=num_matches == 0,
         )
         return gm, info
 
