Skip to content

Commit f40e1f7

Browse files
authored
[https://nvbugs/5625972][fix] Add context manager to fix FakeTensorProp (#9047)
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent 50c4863 commit f40e1f7

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

tensorrt_llm/_torch/auto_deploy/utils/_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch.nn as nn
88
import torch.utils._pytree as pytree
99
from torch import fx
10+
from torch._dispatch.python import enable_python_dispatcher
1011
from torch._export.utils import _detect_fake_mode_from_gm
1112
from torch._prims_common import DeviceLikeType
1213
from torch._subclasses import FakeTensor, FakeTensorMode
@@ -210,7 +211,8 @@ def _run_shape_prop_single_gm(
210211

211212
# run shape propagation if we have all the fake tensors
212213
if all(inp is not _NO_VAL for inp in inps):
213 -        FakeTensorProp(gm, fake_mode).propagate(*inps)
214 +        with enable_python_dispatcher():
215 +            FakeTensorProp(gm, fake_mode).propagate(*inps)
214216
else:
215217
ad_logger.warning("No fake tensors and no args available for shape propagation")
216218

tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,14 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
114114
},
115115
},
116116
),
117-
pytest.param(
117+
(
118118
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
119119
{
120120
"transforms": {
121121
"insert_cached_attention": {"backend": "flashinfer"},
122122
"compile_model": {"backend": "torch-opt"},
123123
},
124124
},
125-
marks=pytest.mark.skip(reason="https://nvbugs/5625972"),
126125
),
127126
(
128127
"meta-llama/Llama-4-Scout-17B-16E-Instruct",

0 commit comments

Comments (0)