 )


+@torch.library.custom_op("auto_deploy::triton_ssm_prepare_metadata", mutates_args=())
+def _triton_ssm_prepare_metadata(
+    position_ids: torch.Tensor,
+    seq_len: torch.Tensor,
+    input_pos: torch.Tensor,
+    cache_loc: torch.Tensor,
+    pages_per_seq: torch.Tensor,
+    slot_idx: torch.Tensor,
+    page_size: int,
+    chunk_size: int,
+) -> List[torch.Tensor]:
+    """Prepare metadata for the cached SSM transform.
+
+    Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states,
+    cu_seqlens, chunk_indices, chunk_offsets, batch_info_tensor).
+    """
+    # Determine number of active sequences and compute seq_start boundaries
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
+    num_seq = len(seq_len_sanitized)
+
+    seq_start = torch.zeros_like(seq_len_sanitized)
+    if num_seq > 1:
+        seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)
+
+    # Truncate slot indices to match active sequences
+    slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
+    # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
+    # reference implementation to support chunked prefill.
+    # Sequences whose start position is nonzero resume from a previously cached SSM state.
+    use_initial_states = input_pos > 0
+
+    device = position_ids.device
+
+    chunk_indices = torch.zeros(num_seq, dtype=torch.int32, device=device)
+    chunk_offsets = torch.zeros(num_seq, dtype=torch.int32, device=device)
+    cu_seqlens = torch.zeros(num_seq + 1, dtype=torch.int32, device=device)
+    _, s = position_ids.shape[:2]
+    if s > 1:
+        # Only compute chunk indices and offsets for prefill.
+        prefill_mask = seq_len_sanitized > 1
+        num_prefill = int(prefill_mask.sum().item())
+        num_prefill_tokens = int(seq_len_sanitized[:num_prefill].sum().item())
+        num_decode = num_seq - num_prefill
+        cu_seqlens = torch.cat(
+            [
+                torch.zeros(1, dtype=torch.int32, device=device),
+                torch.cumsum(seq_len_sanitized[:num_prefill].to(torch.int32), dim=0),
+            ],
+            dim=0,
+        )
+        chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
+    else:
+        num_prefill = 0
+        num_prefill_tokens = 0
+        num_decode = num_seq
+    batch_info_tensor = torch.tensor(
+        [num_prefill, num_prefill_tokens, num_decode], dtype=torch.int32
+    )  # host tensor: plain Python counts, so reading it downstream needs no device sync
+
+    return (
+        seq_len_sanitized,
+        seq_start,
+        slot_idx_sanitized,
+        use_initial_states,
+        cu_seqlens,
+        chunk_indices,
+        chunk_offsets,
+        batch_info_tensor,
+    )
+
+
+@_triton_ssm_prepare_metadata.register_fake
+def _triton_ssm_prepare_metadata_fake(
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
+):
+    # Use the same sanitization logic to determine sizes in fake mode
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
+    num_seq = len(seq_len_sanitized)
+    return (
+        torch.empty_like(seq_len_sanitized),
+        torch.empty_like(seq_len_sanitized),
+        torch.empty(num_seq, dtype=torch.long, device=slot_idx.device),
+        torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device),
+        torch.empty(num_seq + 1, dtype=torch.int32, device=slot_idx.device),  # cu_seqlens
+        torch.empty(num_seq, dtype=torch.int32, device=slot_idx.device),  # chunk indices
+        torch.empty(num_seq, dtype=torch.int32, device=slot_idx.device),  # chunk offsets
+        torch.empty(3, dtype=torch.int32),  # batch info tensor (3 host counts, matching the real op)
+    )
+
+
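Reviewer note: for intuition about what the new metadata op computes, here is a small standalone sketch in plain PyTorch. It mirrors the seq_start / cu_seqlens / batch-info logic above on a toy batch; it does not call the custom op itself, and the values are illustrative only.

    import torch

    # Toy batch: two prefill sequences (5 and 3 tokens) followed by one single-token decode sequence.
    seq_len = torch.tensor([5, 3, 1])

    # seq_start is the exclusive cumulative sum of sequence lengths -> [0, 5, 8]
    seq_start = torch.zeros_like(seq_len)
    seq_start[1:] = torch.cumsum(seq_len[:-1], 0)

    # Prefill/decode split, mirroring the s > 1 branch of the op.
    num_prefill = int((seq_len > 1).sum().item())                  # 2
    num_prefill_tokens = int(seq_len[:num_prefill].sum().item())   # 8
    num_decode = len(seq_len) - num_prefill                        # 1

    # cu_seqlens covers only the prefill portion -> [0, 5, 8]
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32), torch.cumsum(seq_len[:num_prefill].to(torch.int32), 0)]
    )

    print(seq_start.tolist(), cu_seqlens.tolist(), [num_prefill, num_prefill_tokens, num_decode])
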
 @torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
 def _triton_cached_ssm(
     # INPUTS (dense but may be flattened across sequences)
@@ -39,6 +127,10 @@ def _triton_cached_ssm(
     seq_start: torch.Tensor,  # [num_seq]
     slot_idx: torch.Tensor,  # [num_seq]
     use_initial_states: torch.Tensor,  # [num_seq]
+    cu_seqlens: torch.Tensor,  # cumulative prefill sequence lengths (as produced by the metadata op)
+    chunk_indices: torch.Tensor,  # chunk indices for the prefill scan
+    chunk_offsets: torch.Tensor,  # chunk offsets for the prefill scan
+    batch_info_tensor: torch.Tensor,  # [3]: num_prefill, num_prefill_tokens, num_decode (host)
     # CACHES
     ssm_state_cache: torch.Tensor,  # [max_batch_size, num_heads, head_dim, ssm_state_size]
     # CONSTANTS
@@ -51,8 +143,7 @@ def _triton_cached_ssm(
     - Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot.
     - Decode: batch single-token updates with selective_state_update and update states per slot.
     """
-    b, s = hidden_states.shape[:2]
-    num_seq = seq_len.shape[0]
+    b, s, num_heads, head_dim = hidden_states.shape
     # Flatten tokens for indexing/scatter
     bs = b * s
     device = hidden_states.device
@@ -64,39 +155,23 @@ def _triton_cached_ssm(
     y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
     y_flat = y.view(bs, *y.shape[2:])

-    num_heads = hidden_states.shape[2]
-    head_dim = hidden_states.shape[3]
     ssm_state_size = B.shape[3]

-    if s == 1:
-        num_prefill = 0
-        num_decode = num_seq
-    else:
-        prefill_mask = seq_len > 1
-        num_prefill = int(prefill_mask.sum().item())
-        num_decode = num_seq - num_prefill
+    num_prefill, num_prefill_tokens, num_decode = batch_info_tensor.tolist()

     # Prefill: concatenate tokens at the front and run combined scan
     if num_prefill > 0:
-        seq_len_prefill = seq_len[:num_prefill].to(torch.int32)
-        total_prefill_tokens = int(seq_len_prefill.sum().item())
+        seq_len_prefill = seq_len[:num_prefill]

-        hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, H, D]
-        B_prefill = B_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
-        C_prefill = C_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
-        dt_prefill = dt_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, H]
+        hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, H, D]
+        B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
+        C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
+        dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, H]

-        cu_seqlens = torch.cat(
-            [
-                torch.zeros(1, dtype=torch.int32, device=device),
-                torch.cumsum(seq_len_prefill, dim=0),
-            ],
-            dim=0,
-        )
         seq_ids = torch.arange(num_prefill, device=device, dtype=torch.int32)
         seq_idx_prefill = torch.repeat_interleave(seq_ids, seq_len_prefill).view(1, -1)

-        initial_states = chunk_indices = chunk_offsets = None
+        initial_states = None
         if torch.any(use_initial_states[:num_prefill]):
             initial_states = torch.where(
                 use_initial_states[:num_prefill, None, None, None],
@@ -106,6 +181,11 @@ def _triton_cached_ssm(
             chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
                 cu_seqlens, chunk_size
             )
+
+        else:
+            chunk_indices = None
+            chunk_offsets = None
+
         y_prefill, varlen_states = mamba_chunk_scan_combined(
             hs_prefill,
             dt_prefill,
@@ -128,20 +208,19 @@ def _triton_cached_ssm(
             mamba_ssm_cache_dtype=ssm_state_cache.dtype,
         )

-        y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
+        y_flat[:num_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
         ssm_state_cache.index_copy_(
-            0, slot_idx[:num_prefill].to(torch.long), varlen_states.to(ssm_state_cache.dtype)
+            0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
         )

     # Decode: batch single-token updates via selective_state_update
     if num_decode > 0:
-        total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())
-        slot_idx_decode = slot_idx[num_prefill:].to(torch.long)
+        slot_idx_decode = slot_idx[num_prefill:]

-        x_decode = hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, H, D]
-        B_decode = B_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, G, N]
-        C_decode = C_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, G, N]
-        dt_decode = dt_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, H]
+        x_decode = hs_flat[num_prefill_tokens : num_prefill_tokens + num_decode]  # [nd, H, D]
+        B_decode = B_flat[num_prefill_tokens : num_prefill_tokens + num_decode]  # [nd, G, N]
+        C_decode = C_flat[num_prefill_tokens : num_prefill_tokens + num_decode]  # [nd, G, N]
+        dt_decode = dt_flat[num_prefill_tokens : num_prefill_tokens + num_decode]  # [nd, H]

         dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
         dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
@@ -162,9 +241,7 @@ def _triton_cached_ssm(
             state_batch_indices=slot_idx_decode,
         )  # [nd, H, D]

-        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
-            y_dec.to(y_flat.dtype)
-        )
+        y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec.to(y_flat.dtype))

     return y

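Reviewer note: the slicing in the rewritten body relies on the batch layout this file already assumes, namely that all prefill tokens come first in the flattened buffer and the single-token decode entries follow, so num_prefill_tokens is the only offset needed. A minimal sketch of that layout with toy values (standard PyTorch only, not the actual kernel inputs):

    import torch

    seq_len = torch.tensor([4, 3, 1, 1])                           # two prefill, two decode sequences
    num_prefill = int((seq_len > 1).sum().item())                  # 2
    num_prefill_tokens = int(seq_len[:num_prefill].sum().item())   # 7
    num_decode = len(seq_len) - num_prefill                        # 2

    tokens = torch.arange(int(seq_len.sum().item()))               # stand-in for hs_flat
    prefill_tokens = tokens[:num_prefill_tokens]                   # indices 0..6 feed the combined varlen scan
    decode_tokens = tokens[num_prefill_tokens : num_prefill_tokens + num_decode]  # indices 7..8 feed selective_state_update
    print(prefill_tokens.tolist(), decode_tokens.tolist())
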
@@ -184,6 +261,10 @@ def _triton_cached_ssm_fake(
     seq_start: torch.Tensor,  # [num_seq]
     slot_idx: torch.Tensor,  # [num_seq]
     use_initial_states: torch.Tensor,  # [num_seq]
+    cu_seqlens: torch.Tensor,  # cumulative prefill sequence lengths (as produced by the metadata op)
+    chunk_indices: torch.Tensor,  # chunk indices for the prefill scan
+    chunk_offsets: torch.Tensor,  # chunk offsets for the prefill scan
+    batch_info_tensor: torch.Tensor,  # [3]: num_prefill, num_prefill_tokens, num_decode (host)
     # CACHES
     ssm_state_cache: torch.Tensor,  # [max_batch_size, num_heads, head_dim, ssm_state_size]
     # CONSTANTS
@@ -226,8 +307,9 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
-        return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4
+        # Returns: seq_len, seq_start, slot_idx, use_initial_states,
+        # cu_seqlens, chunk_indices, chunk_offsets, batch_info_tensor
+        return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 8

     @classmethod
     def get_cache_initializers(
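Reviewer note: the second element of the returned pair is the number of metadata tensors the runtime should expect from the prepare-metadata op (now 8 instead of 4). The toy example below is not the TRT-LLM op; it only illustrates the torch.library pattern this count relies on, where the real implementation and its registered fake must agree on the number and dtypes of outputs for the advertised count to hold under tracing.

    from typing import List

    import torch

    @torch.library.custom_op("demo::toy_prepare_metadata", mutates_args=())
    def toy_prepare_metadata(seq_len: torch.Tensor) -> List[torch.Tensor]:
        # Real implementation: seq_start is the exclusive cumsum of seq_len.
        seq_start = torch.zeros_like(seq_len)
        seq_start[1:] = torch.cumsum(seq_len[:-1], 0)
        return [seq_len.clone(), seq_start]

    @toy_prepare_metadata.register_fake
    def _(seq_len):
        # Fake/meta implementation: same output count and dtypes, no real compute.
        return [torch.empty_like(seq_len), torch.empty_like(seq_len)]

    outs = toy_prepare_metadata(torch.tensor([4, 2, 1]))
    assert len(outs) == 2  # consumers rely on this count, like the ", 8" above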