Commit 6e8037a

suyoggupta authored and Wanli-Jiang committed
fix prefill

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>

fix typo

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>

1 parent 98b10f7 commit 6e8037a

File tree

13 files changed: +155 -55 lines


tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
Lines changed: 1 addition & 1 deletion

@@ -175,7 +175,7 @@ def forward(self, *args, **kwargs) -> Any:

         # retrieve output from buffer, cut to batch size, and unflatten
         bs = args_batched[0].shape[0]
-        out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat]
+        out_flat = [o_b[:bs] for o_b in self._out_buffer_flat]
         return self._out_spec.unflatten(out_flat)

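For context: slicing the static CUDA-graph output buffer yields a view that shares storage with the buffer, whereas the previous detach().clone() allocated and copied a fresh tensor on every replay. A minimal standalone sketch of the difference (toy buffer and sizes, not the TRT-LLM class):

import torch

buf = torch.arange(32.0).reshape(8, 4)         # stands in for one persistent output buffer entry
bs = 3                                         # batch size of the current step

out_view = buf[:bs]                            # new behavior: a view, no allocation or copy
out_clone = buf[:bs].detach().clone()          # old behavior: fresh tensor plus a copy per step

print(out_view.data_ptr() == buf.data_ptr())   # True: shares storage with the buffer
print(out_clone.data_ptr() == buf.data_ptr())  # False: independent copy

The view is only valid until the next graph replay overwrites the buffer, so callers are presumably expected to consume or copy the output before the following step.
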
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
Lines changed: 4 additions & 3 deletions

@@ -116,6 +116,7 @@ def __init__(
         page_size: int = 0,
         max_num_tokens: Optional[int] = None,
         vocab_size_padded: Optional[int] = None,
+        chunk_size: Optional[int] = None,
     ):
         """Initialize the SequenceInfo object.

@@ -142,7 +143,7 @@ def __init__(
         self.max_batch_size = max_batch_size
         self.page_size = page_size if page_size > 0 else max_seq_len
         self.vocab_size_padded = vocab_size_padded
-
+        self.chunk_size = chunk_size
         # NOTE (lucaslie): WAR to address issue when using flashinfer attention with
         # (max_batch_size, max_seq_len) input in trtllm runtime.
         # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
@@ -193,7 +194,7 @@ def __init__(
             "input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
             "cache_loc": torch.empty(max_num_cache_loc_assignments, dtype=torch.int),
             "pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
-            "slot_idx": torch.empty(self.max_batch_size, dtype=torch.int),
+            "slot_idx": torch.empty(self.max_batch_size, dtype=torch.long),
             # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
             "_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
         }
@@ -203,7 +204,7 @@ def __init__(
         # NOTE: order of keys is relevant here!
         self._uncached_arg_names = ("input_ids", "position_ids")
         self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")
-        self._cached_constants = ("page_size",)
+        self._cached_constants = ("page_size", "chunk_size")
        ############################################################################################

        # EXTRA TENSOR FIELDS ######################################################################
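
For context: chunk_size joins page_size in _cached_constants, and constants are passed positionally after the cached tensor arguments, which is why every prepare-metadata op in this commit gains a trailing chunk_size: int parameter; the slot_idx host buffer also moves from int32 to int64, presumably so it can be used directly as an index tensor without per-step casts. A standalone sketch of that calling convention (toy function and made-up values, not the SequenceInfo class):

import torch

def toy_prepare_metadata(position_ids, seq_len, input_pos, cache_loc,
                         pages_per_seq, slot_idx, page_size, chunk_size):
    # mirrors the positional order: position_ids, then cached args, then constants
    return {"page_size": page_size, "chunk_size": chunk_size}

uncached = (torch.zeros(1, 4, dtype=torch.int),)                   # position_ids
cached = tuple(torch.zeros(2, dtype=torch.int) for _ in range(5))  # seq_len .. slot_idx
constants = (64, 256)                                               # (page_size, chunk_size)

print(toy_prepare_metadata(*uncached, *cached, *constants))         # {'page_size': 64, 'chunk_size': 256}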

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
Lines changed: 2 additions & 1 deletion

@@ -162,6 +162,7 @@ def prepare_flashinfer_metadata(
     pages_per_seq: torch.Tensor,
     slot_idx: torch.Tensor,
     page_size: int,
+    chunk_size: int,
 ) -> List[torch.Tensor]:
     """Prepare metadata for flashinfer attention.

@@ -213,7 +214,7 @@ def prepare_flashinfer_metadata(
 # As SequenceInfo._get_sanitized_num_sequences could break in fake mode
 @prepare_flashinfer_metadata.register_fake
 def prepare_flashinfer_metadata_fake(
-    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
 ):
     seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
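
A real custom op and its register_fake counterpart must expose the same schema, so chunk_size has to be threaded through both; the fake only fabricates output shapes for tracing. A self-contained toy example of the pattern (namespace and names are illustrative, not the TRT-LLM ops):

from typing import List

import torch

@torch.library.custom_op("toy::prepare_metadata", mutates_args=())
def toy_prepare_metadata(seq_len: torch.Tensor, page_size: int, chunk_size: int) -> List[torch.Tensor]:
    # real implementation: returns fresh tensors (no aliasing of inputs)
    return [seq_len.cumsum(0)]

@toy_prepare_metadata.register_fake
def _(seq_len, page_size, chunk_size):
    # fake implementation: same signature, only shapes/dtypes matter
    return [torch.empty_like(seq_len)]

out = toy_prepare_metadata(torch.tensor([3, 1, 1]), 64, 256)
print(out[0])  # tensor([3, 4, 5])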

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
Lines changed: 3 additions & 0 deletions

@@ -11,6 +11,8 @@
 import triton
 import triton.language as tl

+from tensorrt_llm._utils import nvtx_range
+

 @triton.jit
 def _write_zeros_to_output(
@@ -304,6 +306,7 @@ def _default_kernel_config(M: int, E: int, N: int, K: int, top_k: int) -> dict:
     }


+@nvtx_range("triton_moe_pack_routed_tokens")
 def _pack_routed_tokens(
     topk_ids: torch.Tensor,
     M: int,
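
nvtx_range from tensorrt_llm._utils wraps the token-packing step in an NVTX range so it appears as a named span in Nsight Systems timelines. A rough standalone equivalent of what such a decorator does (illustrative sketch, not the actual tensorrt_llm._utils implementation; the NVTX calls assume a CUDA build of PyTorch):

import functools

import torch

def nvtx_range_sketch(name: str):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            torch.cuda.nvtx.range_push(name)   # open a named range in the profiler timeline
            try:
                return fn(*args, **kwargs)
            finally:
                torch.cuda.nvtx.range_pop()    # close it even if the function raises
        return wrapper
    return decorator

@nvtx_range_sketch("triton_moe_pack_routed_tokens")
def pack_routed_tokens_demo(topk_ids: torch.Tensor) -> torch.Tensor:
    return topk_ids.flatten()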

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
Lines changed: 5 additions & 6 deletions

@@ -61,6 +61,7 @@ def cuda_causal_conv_prepare_metadata(
     pages_per_seq: torch.Tensor,
     slot_idx: torch.Tensor,
     page_size: int,
+    chunk_size: int,
 ) -> List[torch.Tensor]:
     """Prepare metadata for cached causal conv (CUDA backend).

@@ -81,7 +82,7 @@ def cuda_causal_conv_prepare_metadata(

 @cuda_causal_conv_prepare_metadata.register_fake
 def cuda_causal_conv_prepare_metadata_fake(
-    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
 ):
     seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     num_seq = len(seq_len_sanitized)
@@ -182,7 +183,7 @@ def _cuda_cached_causal_conv1d(

         # Scatter outputs back to y
         y_prefill = y_varlen.transpose(0, 1)  # [total_prefill_tokens, C_out]
-        y_flat[:total_prefill_tokens].copy_(y_prefill.to(y_flat.dtype))
+        y_flat[:total_prefill_tokens].copy_(y_prefill)

     # DECODE: batch update for single-token sequences
     if num_decode > 0:
@@ -203,12 +204,10 @@ def _cuda_cached_causal_conv1d(

         if y_dec.dim() == 3:
             y_dec = y_dec.squeeze(-1)
-        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
-            y_dec.to(y_flat.dtype)
-        )
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec)

     # Custom op must not return an alias of any input; return a fresh tensor
-    return y.contiguous().clone()
+    return y


 @_cuda_cached_causal_conv1d.register_fake
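
The dropped .to(y_flat.dtype) calls rely on Tensor.copy_ casting on the fly, so the explicit conversion only materialized an extra temporary; and since y is freshly allocated inside the op, returning it directly already satisfies the no-aliasing requirement without a final contiguous().clone(). A small sketch of the copy_ behavior (made-up shapes and dtypes):

import torch

y_flat = torch.empty(4, 8, dtype=torch.bfloat16)   # output buffer dtype
y_dec = torch.randn(4, 8, dtype=torch.float32)     # kernel result in a wider dtype

y_flat.copy_(y_dec)                     # casts float32 -> bfloat16 during the copy
# y_flat.copy_(y_dec.to(y_flat.dtype))  # old pattern: same result, one extra temporary tensor

print(y_flat.dtype)  # torch.bfloat16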

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
Lines changed: 2 additions & 1 deletion

@@ -120,6 +120,7 @@ def _torch_ssm_prepare_metadata(
     pages_per_seq: torch.Tensor,
     slot_idx: torch.Tensor,
     page_size: int,
+    chunk_size: int,
 ) -> List[torch.Tensor]:
     """Prepare metadata for cached SSM transform.

@@ -143,7 +144,7 @@ def _torch_ssm_prepare_metadata(

 @_torch_ssm_prepare_metadata.register_fake
 def _torch_ssm_prepare_metadata_fake(
-    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
 ):
     # Use the same sanitization logic to determine sizes in fake mode
     seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
Lines changed: 120 additions & 38 deletions

@@ -24,6 +24,94 @@
 )


+@torch.library.custom_op("auto_deploy::triton_ssm_prepare_metadata", mutates_args=())
+def _triton_ssm_prepare_metadata(
+    position_ids: torch.Tensor,
+    seq_len: torch.Tensor,
+    input_pos: torch.Tensor,
+    cache_loc: torch.Tensor,
+    pages_per_seq: torch.Tensor,
+    slot_idx: torch.Tensor,
+    page_size: int,
+    chunk_size: int,
+) -> List[torch.Tensor]:
+    """Prepare metadata for cached SSM transform.
+
+    Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
+    """
+    # Determine number of active sequences and compute seq_start boundaries
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
+    num_seq = len(seq_len_sanitized)
+
+    seq_start = torch.zeros_like(seq_len_sanitized)
+    if num_seq > 1:
+        seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)
+
+    # Truncate slot indices to match active sequences
+    slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
+    # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
+    # reference implementation to support chunked prefill.
+    use_initial_states = input_pos > 0
+
+    device = position_ids.device
+
+    chunk_indices = torch.zeros(num_seq, dtype=torch.int32, device=device)
+    chunk_offsets = torch.zeros(num_seq, dtype=torch.int32, device=device)
+    cu_seqlens = torch.zeros(num_seq + 1, dtype=torch.int32, device=device)
+    _, s = position_ids.shape[:2]
+    if s > 1:
+        # only compute chunk indices and offsets for prefill.
+        prefill_mask = seq_len_sanitized > 1
+        num_prefill = int(prefill_mask.sum().item())
+        num_prefill_tokens = int(seq_len_sanitized[:num_prefill].sum().item())
+        num_decode = num_seq - num_prefill
+        cu_seqlens = torch.cat(
+            [
+                torch.zeros(1, dtype=torch.int32, device=device),
+                torch.cumsum(seq_len_sanitized[:num_prefill].to(torch.int32), dim=0),
+            ],
+            dim=0,
+        )
+        chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
+    else:
+        num_prefill = 0
+        num_prefill_tokens = 0
+        num_decode = num_seq
+    batch_info_tensor = torch.tensor(
+        [num_prefill, num_prefill_tokens, num_decode], dtype=torch.int32
+    )  # host tensor
+
+    return (
+        seq_len_sanitized,
+        seq_start,
+        slot_idx_sanitized,
+        use_initial_states,
+        cu_seqlens,
+        chunk_indices,
+        chunk_offsets,
+        batch_info_tensor,
+    )
+
+
+@_triton_ssm_prepare_metadata.register_fake
+def _triton_ssm_prepare_metadata_fake(
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
+):
+    # Use the same sanitization logic to determine sizes in fake mode
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
+    num_seq = len(seq_len_sanitized)
+    return (
+        torch.empty_like(seq_len_sanitized),
+        torch.empty_like(seq_len_sanitized),
+        torch.empty(num_seq, dtype=torch.long, device=slot_idx.device),
+        torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device),
+        torch.empty(num_seq + 1, dtype=torch.int32, device=slot_idx.device),  # cu seqlens
+        torch.empty(num_seq, dtype=torch.int32, device=slot_idx.device),  # chunk indices
+        torch.empty(num_seq, dtype=torch.int32, device=slot_idx.device),  # chunk offsets
+        torch.empty(2, dtype=torch.int32),  # batch info tensor
+    )
+
+
 @torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
 def _triton_cached_ssm(
     # INPUTS (dense but may be flattened across sequences)
@@ -39,6 +127,10 @@ def _triton_cached_ssm(
     seq_start: torch.Tensor,  # [num_seq]
     slot_idx: torch.Tensor,  # [num_seq]
     use_initial_states: torch.Tensor,  # [num_seq]
+    cu_seqlens: torch.Tensor,  # [num_seq + 1]
+    chunk_indices: torch.Tensor,  # [num_seq + 1]
+    chunk_offsets: torch.Tensor,  # [num_seq + 1]
+    batch_info_tensor: torch.Tensor,  # [2]
     # CACHES
     ssm_state_cache: torch.Tensor,  # [max_batch_size, num_heads, head_dim, ssm_state_size]
     # CONSTANTS
@@ -51,8 +143,7 @@ def _triton_cached_ssm(
     - Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot.
     - Decode: batch single-token updates with selective_state_update and update states per slot.
     """
-    b, s = hidden_states.shape[:2]
-    num_seq = seq_len.shape[0]
+    b, s, num_heads, head_dim = hidden_states.shape
     # Flatten tokens for indexing/scatter
     bs = b * s
     device = hidden_states.device
@@ -64,39 +155,23 @@ def _triton_cached_ssm(
     y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
     y_flat = y.view(bs, *y.shape[2:])

-    num_heads = hidden_states.shape[2]
-    head_dim = hidden_states.shape[3]
     ssm_state_size = B.shape[3]

-    if s == 1:
-        num_prefill = 0
-        num_decode = num_seq
-    else:
-        prefill_mask = seq_len > 1
-        num_prefill = int(prefill_mask.sum().item())
-        num_decode = num_seq - num_prefill
+    [num_prefill, num_prefill_tokens, num_decode] = batch_info_tensor.tolist()

     # Prefill: concatenate tokens at the front and run combined scan
     if num_prefill > 0:
-        seq_len_prefill = seq_len[:num_prefill].to(torch.int32)
-        total_prefill_tokens = int(seq_len_prefill.sum().item())
+        seq_len_prefill = seq_len[:num_prefill]

-        hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, H, D]
-        B_prefill = B_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
-        C_prefill = C_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
-        dt_prefill = dt_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, H]
+        hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, H, D]
+        B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
+        C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
+        dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, H]

-        cu_seqlens = torch.cat(
-            [
-                torch.zeros(1, dtype=torch.int32, device=device),
-                torch.cumsum(seq_len_prefill, dim=0),
-            ],
-            dim=0,
-        )
         seq_ids = torch.arange(num_prefill, device=device, dtype=torch.int32)
         seq_idx_prefill = torch.repeat_interleave(seq_ids, seq_len_prefill).view(1, -1)

-        initial_states = chunk_indices = chunk_offsets = None
+        initial_states = None
         if torch.any(use_initial_states[:num_prefill]):
             initial_states = torch.where(
                 use_initial_states[:num_prefill, None, None, None],
@@ -106,6 +181,11 @@ def _triton_cached_ssm(
             chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
                 cu_seqlens, chunk_size
             )
+
+        else:
+            chunk_indices = None
+            chunk_offsets = None
+
         y_prefill, varlen_states = mamba_chunk_scan_combined(
             hs_prefill,
             dt_prefill,
@@ -128,20 +208,19 @@ def _triton_cached_ssm(
             mamba_ssm_cache_dtype=ssm_state_cache.dtype,
         )

-        y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
+        y_flat[:num_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
         ssm_state_cache.index_copy_(
-            0, slot_idx[:num_prefill].to(torch.long), varlen_states.to(ssm_state_cache.dtype)
+            0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
         )

     # Decode: batch single-token updates via selective_state_update
     if num_decode > 0:
-        total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())
-        slot_idx_decode = slot_idx[num_prefill:].to(torch.long)
+        slot_idx_decode = slot_idx[num_prefill:]

-        x_decode = hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, H, D]
-        B_decode = B_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, G, N]
-        C_decode = C_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, G, N]
-        dt_decode = dt_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, H]
+        x_decode = hs_flat[num_prefill_tokens : num_prefill_tokens + num_decode]  # [nd, H, D]
+        B_decode = B_flat[num_prefill_tokens : num_prefill_tokens + num_decode]  # [nd, G, N]
+        C_decode = C_flat[num_prefill_tokens : num_prefill_tokens + num_decode]  # [nd, G, N]
+        dt_decode = dt_flat[num_prefill_tokens : num_prefill_tokens + num_decode]  # [nd, H]

         dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
         dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
@@ -162,9 +241,7 @@ def _triton_cached_ssm(
             state_batch_indices=slot_idx_decode,
         )  # [nd, H, D]

-        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
-            y_dec.to(y_flat.dtype)
-        )
+        y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec.to(y_flat.dtype))

     return y

@@ -184,6 +261,10 @@ def _triton_cached_ssm_fake(
     seq_start: torch.Tensor,  # [num_seq]
     slot_idx: torch.Tensor,  # [num_seq]
     use_initial_states: torch.Tensor,  # [num_seq]
+    cu_seqlens: torch.Tensor,  # [num_seq + 1]
+    chunk_indices: torch.Tensor,  # [num_seq + 1]
+    chunk_offsets: torch.Tensor,  # [num_seq + 1]
+    batch_info_tensor: torch.Tensor,  # [2]
     # CACHES
     ssm_state_cache: torch.Tensor,  # [max_batch_size, num_heads, head_dim, ssm_state_size]
     # CONSTANTS
@@ -226,8 +307,9 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
-        return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4
+        # Returns: seq_len, seq_start, slot_idx, use_initial_states,
+        # cu_seqlens, chunk_indices, chunk_offsets, batch_info_tensor
+        return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 8

     @classmethod
     def get_cache_initializers(
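
The new triton_ssm_prepare_metadata op hoists per-step bookkeeping out of the cached-SSM op: it splits sequences into prefill (length > 1) and decode (length 1), builds cu_seqlens over the prefill tokens, precomputes the chunk index/offset tables, and ships the three counts in a small host-side int32 tensor, presumably so the per-layer SSM op avoids repeated .item() reads. A standalone sketch of that bookkeeping with made-up lengths (the chunk-table helper cu_seqlens_to_chunk_indices_offsets is not reproduced here):

import torch

seq_len = torch.tensor([5, 3, 1, 1], dtype=torch.int32)   # two prefill sequences, two decode
prefill_mask = seq_len > 1
num_prefill = int(prefill_mask.sum())
num_prefill_tokens = int(seq_len[:num_prefill].sum())
num_decode = seq_len.numel() - num_prefill

# exclusive prefix sum over the prefill lengths -> varlen boundaries for the combined scan
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seq_len[:num_prefill], 0, dtype=torch.int32)]
)
# counts travel as a tiny host tensor; the SSM op later unpacks them with .tolist()
batch_info_tensor = torch.tensor([num_prefill, num_prefill_tokens, num_decode], dtype=torch.int32)

print(cu_seqlens.tolist())         # [0, 5, 8]
print(batch_info_tensor.tolist())  # [2, 8, 2]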

tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py
Lines changed: 2 additions & 1 deletion

@@ -182,6 +182,7 @@ def prepare_fused_mla_metadata(
     pages_per_seq: torch.Tensor,
     slot_idx: torch.Tensor,
     page_size: int,
+    chunk_size: int,
 ) -> List[torch.Tensor]:
     num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
     seq_start = torch.zeros_like(seq_len[:num_seq])
@@ -196,7 +197,7 @@ def prepare_fused_mla_metadata(

 @prepare_fused_mla_metadata.register_fake
 def prepare_fused_mla_metadata_fake(
-    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
 ):
     return (
         torch.empty_like(seq_len),

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py
Lines changed: 1 addition & 0 deletions

@@ -363,6 +363,7 @@ def torch_backend_prepare_metadata(
     pages_per_seq: torch.Tensor,
     slot_idx: torch.Tensor,
     page_size: int,
+    chunk_size: int,
 ) -> List[torch.Tensor]:
     """Prepare metadata for torch backend attention (similar to triton backend)."""
     num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
