diff --git a/tensorrt_llm/_torch/modules/layer_norm.py b/tensorrt_llm/_torch/modules/layer_norm.py
index 8db0af7e97d..811067952c5 100644
--- a/tensorrt_llm/_torch/modules/layer_norm.py
+++ b/tensorrt_llm/_torch/modules/layer_norm.py
@@ -18,6 +18,8 @@
 import torch
 from torch import nn
 
+from ..utils import maybe_compile
+
 
 class LayerNorm(nn.Module):
     """Layer normalization module with configurable weight and bias parameters.
@@ -65,6 +67,7 @@ def __init__(
                              persistent=False)
         self.variance_epsilon = eps
 
+    @maybe_compile(dynamic=True)
     def forward(
         self,
         hidden_states: torch.Tensor,
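
For readers unfamiliar with the decorator being imported here: `maybe_compile(dynamic=True)` wraps `LayerNorm.forward` so that it can be JIT-compiled with dynamic shapes when compilation is enabled. The sketch below is a minimal illustration of how such a decorator could be structured; it is an assumption for context, not the actual `tensorrt_llm._torch.utils` implementation, and the env-var gate is likewise assumed.

```python
# Minimal sketch of a maybe_compile-style decorator (assumed implementation,
# not the actual tensorrt_llm _torch/utils code).
import os
from functools import wraps

import torch


def maybe_compile(**compile_kwargs):
    """Conditionally wrap a function with torch.compile, forwarding kwargs."""

    def decorator(fn):
        # Assumed opt-out switch; the real utility may gate compilation differently.
        if os.environ.get("TORCH_COMPILE_DISABLE", "0") == "1":
            return fn

        compiled = torch.compile(fn, **compile_kwargs)

        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Delegate to the compiled callable; kwargs like dynamic=True
            # were already passed through to torch.compile above.
            return compiled(*args, **kwargs)

        return wrapper

    return decorator
```

With a decorator of this shape, `@maybe_compile(dynamic=True)` either returns `forward` unchanged (when compilation is disabled) or a `torch.compile`-wrapped version that traces with dynamic shapes, avoiding recompiles as the sequence dimension of `hidden_states` varies.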