Skip to content

Commit 8bd7791

Browse files
authored
[https://nvbugs/5631254][fix] avoid torch.compile for multiple times (#9135)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent e90dbaf commit 8bd7791

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensorrt_llm/_torch/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,12 @@ def maybe_compile(func=None, **compile_kwargs):
339339
"""
340340

341341
def decorator(f):
342+
compiled_func = torch.compile(f, **compile_kwargs)
342343

343344
def wrapper(*args, **kwargs):
344345
if is_piecewise_running():
345346
return f(*args, **kwargs)
346-
return torch.compile(f, **compile_kwargs)(*args, **kwargs)
347+
return compiled_func(*args, **kwargs)
347348

348349
return wrapper
349350

0 commit comments

Comments
 (0)