
Commit 0b81173

Authored by chang-l (Chang Liu)

[TRTLLM-9259][perf] Use torch.compile to fuse copy + layernorm within the LayerNorm module (#9052)
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
1 parent aca5609 commit 0b81173
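
For readers unfamiliar with the optimization in the title, here is a minimal, hypothetical sketch of the pattern being fused. The function name and the fp32 upcast below stand in for whatever copy precedes the normalization inside the module; none of this is code from the commit:

    import torch

    # Hypothetical illustration, not code from this commit: an elementwise
    # copy (here an fp32 upcast) followed by layer norm is two memory-bound
    # kernels when run eagerly; under torch.compile the inductor backend can
    # fuse the copy into the normalization so the input is read only once.
    @torch.compile(dynamic=True)  # dynamic=True avoids recompiles as shapes vary
    def fused_copy_layernorm(x: torch.Tensor, weight: torch.Tensor,
                             bias: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        h = x.to(torch.float32)                        # the "copy"
        mean = h.mean(-1, keepdim=True)
        var = h.var(-1, keepdim=True, unbiased=False)
        h = (h - mean) * torch.rsqrt(var + eps)        # the layer norm
        return (h * weight.float() + bias.float()).to(x.dtype)

With eager execution the upcast and the normalization each make a full pass over the tensor; compiled, they can run as one fused kernel. Passing dynamic=True mirrors the decorator added in this commit and keeps the compiled graph shape-polymorphic, avoiding recompilation when the sequence length changes between calls.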

1 file changed: +3 −0


tensorrt_llm/_torch/modules/layer_norm.py

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,8 @@
 import torch
 from torch import nn
 
+from ..utils import maybe_compile
+
 
 class LayerNorm(nn.Module):
     """Layer normalization module with configurable weight and bias parameters.
@@ -65,6 +67,7 @@ def __init__(
             persistent=False)
         self.variance_epsilon = eps
 
+    @maybe_compile(dynamic=True)
     def forward(
         self,
         hidden_states: torch.Tensor,
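
The maybe_compile helper imported in the first hunk lives in tensorrt_llm/_torch/utils, but its body is not part of this diff. A plausible sketch of the gating pattern suggested by its name and the @maybe_compile(dynamic=True) usage; the environment-variable name and fallback behavior are assumptions:

    import os
    from typing import Callable

    import torch

    def maybe_compile(func: Callable = None, **compile_kwargs):
        """Hypothetical sketch of the helper imported above (real code not shown here).

        Applies torch.compile to ``func`` unless compilation is disabled; the
        environment-variable escape hatch is an assumption for illustration.
        """
        def decorator(f: Callable) -> Callable:
            # Assumed escape hatch: fall back to eager execution when disabled.
            if os.environ.get("TLLM_DISABLE_TORCH_COMPILE", "0") == "1":
                return f
            return torch.compile(f, **compile_kwargs)

        # Supports both bare @maybe_compile and parameterized @maybe_compile(dynamic=True).
        return decorator if func is None else decorator(func)

Keeping the decorator optional is a common design choice: torch.compile adds warm-up latency and can complicate debugging, so a helper like this lets the fusion be switched off without touching the module code.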
