Commit ec9cf71

[None][feat] AutoDeploy: Perf improvement for mamba layers (#8991)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
1 parent ebdd1cc commit ec9cf71

File tree: 6 files changed (143 additions, 15 deletions)

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 2 additions & 0 deletions
@@ -165,6 +165,8 @@ transforms:
   ############################################################################################
   # COMPILE MODEL
   ############################################################################################
+  fuse_causal_conv_activation:
+    stage: compile
   compile_model:
     stage: compile
     run_per_gm: false

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py

Lines changed: 12 additions & 7 deletions
@@ -112,6 +112,7 @@ def _cuda_cached_causal_conv1d(
     dilation: int,
     groups: int,
     padding_mode: str,
+    activation: Optional[str],
 ) -> torch.Tensor:
     """Flattened cached causal conv that respects slot-indexed state caches (CUDA backend).

@@ -175,7 +176,7 @@ def _cuda_cached_causal_conv1d(
             cache_indices=cache_indices,
             has_initial_state=has_initial_state,
             conv_states=conv_state_cache,
-            activation=None,
+            activation=activation,
             pad_slot_id=PAD_SLOT_ID,
         )  # (dim, total_prefill_tokens)

@@ -185,24 +186,26 @@ def _cuda_cached_causal_conv1d(

     # DECODE: batch update for single-token sequences
     if num_decode > 0:
-        # Use true start offsets for decode tokens (tail after prefills)
-        decode_idx = seq_start[num_prefill:].to(torch.long)
-        x_decode = inp_flat.index_select(0, decode_idx)  # [num_decode, C_in]
+        x_decode = inp_flat[
+            total_prefill_tokens : total_prefill_tokens + num_decode
+        ]  # [num_decode, C_in]

         y_dec = causal_conv1d_update(
             x_decode,  # [batch, dim]
             conv_state_cache,
             w2d,
             bias,
-            activation=None,
+            activation=activation,
             cache_seqlens=None,
             conv_state_indices=slot_idx[num_prefill:].to(torch.int32),
             pad_slot_id=PAD_SLOT_ID,
         )

         if y_dec.dim() == 3:
             y_dec = y_dec.squeeze(-1)
-        y_flat.index_copy_(0, decode_idx, y_dec.to(y_flat.dtype))
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
+            y_dec.to(y_flat.dtype)
+        )

     # Custom op must not return an alias of any input; return a fresh tensor
     return y.contiguous().clone()

@@ -227,6 +230,7 @@ def _cuda_cached_causal_conv1d_fake(
     dilation: int,
     groups: int,
     padding_mode: str,
+    activation: Optional[str],
 ):
     return torch.empty(
         input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype

@@ -293,4 +297,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         stride, padding, dilation, groups, padding_mode = extract_op_args(
             source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
         )
-        return [stride, padding, dilation, groups, padding_mode]
+        # None is for activation parameter, which may not exist in the source node (added by fusion later)
+        return [stride, padding, dilation, groups, padding_mode, None]
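
Besides threading activation through to the kernels, the decode path drops the index_select / index_copy_ gather-scatter in favor of plain slicing, relying on the decode tokens forming one contiguous tail after the prefill tokens in the flattened batch. A minimal standalone sketch of that equivalence (illustrative shapes only, not the real op):

    import torch

    # Flattened tokens: all prefill tokens first, then one token per decode sequence.
    total_prefill_tokens, num_decode, C = 6, 3, 4
    inp_flat = torch.randn(total_prefill_tokens + num_decode, C)

    # Old path: gather decode rows via their start offsets, scatter results back later.
    decode_idx = torch.arange(total_prefill_tokens, total_prefill_tokens + num_decode)
    x_gather = inp_flat.index_select(0, decode_idx)

    # New path: decode rows are a contiguous tail, so a plain slice (a view, no copy) suffices.
    x_slice = inp_flat[total_prefill_tokens : total_prefill_tokens + num_decode]

    assert torch.equal(x_gather, x_slice)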

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py

Lines changed: 2 additions & 1 deletion
@@ -355,4 +355,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         stride, padding, dilation, groups, padding_mode = extract_op_args(
             source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
         )
-        return [stride, padding, dilation, groups, padding_mode]
+        # None is for activation parameter, which may not exist in the source node (added by fusion later)
+        return [stride, padding, dilation, groups, padding_mode, None]
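
In both backends, the trailing None keeps the cached op's argument list aligned with its new activation parameter; the fusion pass later rewrites only that last slot. A small illustrative sketch (toy values, not the real insert/fusion machinery):

    # Toy stand-ins: the cached op's positional args end with the extracted constants,
    # and the activation slot starts out as None.
    tensor_args = ("input", "weight", "bias")
    constants = [1, 0, 1, 1, "zeros", None]  # stride, padding, dilation, groups, padding_mode, activation
    cached_args = tuple(tensor_args) + tuple(constants)

    # fuse_causal_conv_activation swaps only the final slot when it finds a silu user.
    fused_args = cached_args[:-1] + ("silu",)
    assert cached_args[-1] is None and fused_args[-1] == "silu"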

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

Lines changed: 6 additions & 7 deletions
@@ -145,27 +145,26 @@ def _triton_cached_ssm(

         dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
         dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
-        dt_pre = torch.nn.functional.softplus(dt_hp + dt_bias_hp.to(dtype=dt_hp.dtype))
-        dt_pre = torch.clamp(dt_pre, time_step_limit[0], time_step_limit[1])
         A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
         D_full = D[..., None].expand(num_heads, head_dim)

-        dt_bias_zero = torch.zeros_like(dt_bias_hp)
         y_dec = selective_state_update(
             ssm_state_cache,
             x_decode,
-            dt_pre,
+            dt_hp,
             A_full,
             B_decode,
             C_decode,
             D=D_full,
             z=None,
-            dt_bias=dt_bias_zero,
-            dt_softplus=False,
+            dt_bias=dt_bias_hp,
+            dt_softplus=True,
             state_batch_indices=slot_idx_decode,
         )  # [nd, H, D]

-        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode] = y_dec.to(y_flat.dtype)
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
+            y_dec.to(y_flat.dtype)
+        )

     return y
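
On the decode path, the explicit softplus/bias pre-processing (and the clamp to time_step_limit) is removed from Python: dt and dt_bias are handed straight to selective_state_update with dt_softplus=True, so the bias-add and softplus happen inside the fused kernel. A minimal sketch of the elementwise work this removes (illustrative shapes, plain PyTorch, not the kernel itself):

    import torch
    import torch.nn.functional as F

    num_decode, num_heads, head_dim = 2, 8, 64
    dt_hp = torch.randn(num_decode, num_heads, head_dim)
    dt_bias_hp = torch.randn(num_heads, head_dim)

    # Previously computed eagerly before calling the kernel:
    dt_pre = F.softplus(dt_hp + dt_bias_hp.to(dtype=dt_hp.dtype))

    # With dt_bias=dt_bias_hp and dt_softplus=True, the same bias-add + softplus is
    # applied inside selective_state_update, so these intermediate tensors (and the
    # extra elementwise kernel launches) disappear from the decode hot path.
    assert torch.allclose(dt_pre, F.softplus(dt_hp + dt_bias_hp))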

New file (path not shown here): fusion transform for causal_conv1d activation

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+"""Fusion transform for fusing activation functions into causal_conv1d operations."""
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch.fx import GraphModule, Node
+
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
+from ...utils.node_utils import is_op
+from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
+
+
+def _match_causal_conv_activation_pattern(
+    graph: GraphModule,
+    target_op,
+) -> List[Tuple[Node, Node, str]]:
+    """
+    Match the causal_conv + activation pattern in the graph.
+
+    The pattern corresponds to:
+        conv_out = cuda_cached_causal_conv1d(...)
+        out = activation(conv_out)
+
+    Args:
+        graph: The graph module to search
+        target_op: The target causal conv op to match
+
+    Returns:
+        A list of tuples (conv_node, activation_node, activation_name) for each match
+    """
+    matches = []
+
+    for node in graph.nodes:
+        if not is_op(node, target_op):
+            continue
+
+        # Check if this node has exactly one user and it's an activation
+        if len(node.users) != 1:
+            continue
+
+        activation_node = list(node.users.keys())[0]
+        if activation_node.op != "call_function":
+            continue
+
+        # Detect activation type
+        activation_name: Optional[str] = None
+        if activation_node.target in (torch.ops.aten.silu.default, F.silu):
+            activation_name = "silu"
+        # Can extend to support more activations here:
+        # elif activation_node.target in (torch.ops.aten.gelu.default, F.gelu):
+        #     activation_name = "gelu"
+
+        if activation_name is not None:
+            matches.append((node, activation_node, activation_name))
+
+    return matches
+
+
+@TransformRegistry.register("fuse_causal_conv_activation")
+class FuseCausalConvActivation(BaseTransform):
+    """Fuses activation functions into cached CUDA causal_conv1d operations.
+
+    This transform detects patterns like:
+        conv_out = cuda_cached_causal_conv1d(...)
+        out = silu(conv_out)
+
+    And replaces them with:
+        out = cuda_cached_causal_conv1d(..., activation="silu")
+
+    This optimization allows the backend CUDA kernels to fuse the activation,
+    reducing memory bandwidth and improving performance.
+
+    Note: This runs AFTER insert_cached_causal_conv, so it operates on the
+    cached CUDA operations, not the uncached torch operations.
+    """
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        graph = gm.graph
+
+        # Step 1: Identify causal_conv + activation pattern
+        matches = _match_causal_conv_activation_pattern(
+            graph,
+            target_op=torch.ops.auto_deploy.cuda_cached_causal_conv1d,
+        )
+
+        # Step 2: Replace matched patterns with fused version
+        for conv_node, activation_node, activation_name in matches:
+            with graph.inserting_after(conv_node):
+                # Create new call with fused activation
+                # Replace the last arg (activation=None) with activation_name
+                new_args = list(conv_node.args[:-1]) + [activation_name]
+                fused_node = graph.call_function(
+                    torch.ops.auto_deploy.cuda_cached_causal_conv1d,
+                    args=tuple(new_args),
+                )
+
+            # Replace all uses of activation_node with fused_node
+            activation_node.replace_all_uses_with(fused_node)
+
+            # Remove the old nodes
+            graph.erase_node(activation_node)
+            graph.erase_node(conv_node)
+
+        gm.recompile()
+
+        info = TransformInfo(
+            skipped=False,
+            num_matches=len(matches),
+            is_clean=False,
+            has_valid_shapes=False,
+        )
+        return gm, info
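
For readers less familiar with torch.fx rewrites, the standalone sketch below reproduces the same mechanics (single-user match, inserting_after, call_function, replace_all_uses_with, erase_node) on a toy graph; conv_like is a hypothetical stand-in op, and none of the TensorRT-LLM ops or transform interfaces are used:

    import torch
    import torch.fx as fx
    import torch.nn.functional as F

    # Stand-in for a cached-conv op that can optionally apply its own activation.
    def conv_like(x, activation=None):
        y = x * 2.0  # placeholder compute
        return F.silu(y) if activation == "silu" else y

    fx.wrap("conv_like")  # keep conv_like as a single call_function node when tracing

    class Toy(torch.nn.Module):
        def forward(self, x):
            return F.silu(conv_like(x, None))

    gm = fx.symbolic_trace(Toy())

    # Same rewrite recipe as the transform: match op -> single silu user, re-emit the
    # op with activation="silu", rewire users, drop the old nodes.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is conv_like:
            users = list(node.users)
            if len(users) == 1 and users[0].target is F.silu:
                act = users[0]
                with gm.graph.inserting_after(node):
                    fused = gm.graph.call_function(conv_like, args=(node.args[0], "silu"))
                act.replace_all_uses_with(fused)
                gm.graph.erase_node(act)
                gm.graph.erase_node(node)

    gm.recompile()
    x = torch.randn(4)
    assert torch.allclose(gm(x), F.silu(x * 2.0))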

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ def test_generate_only_with_slot_mapping_cuda(conv_env):
         d,
         g,
         pm,
+        None,
     )

     assert y.shape == (batch, seq, c)
