@@ -381,7 +381,8 @@ def _create_mock_metadata(request_ids,
                           cache_manager,
                           num_ctx_tokens,
                           num_tokens,
-                          indexer_max_chunk_size=8194):
+                          indexer_max_chunk_size=8194,
+                          max_draft_tokens=0):
     """Helper to create mock metadata for testing."""

     class MockKVCacheParams:
@@ -396,6 +397,7 @@ def __init__(self):
             self.request_ids = request_ids
             self.num_contexts = num_contexts
             self.num_generations = num_generations
+            self.max_draft_tokens = max_draft_tokens
             # Keep seq_lens on CPU for split_prefill_chunks and other CPU operations
             # CUDA kernels will convert to CUDA as needed
             self.seq_lens = seq_lens.cpu() if seq_lens.is_cuda else seq_lens
@@ -826,6 +828,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
         cache_manager=cache_manager,
         num_ctx_tokens=total_context_tokens,
         num_tokens=total_context_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_context)

@@ -851,6 +854,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
         cache_manager=cache_manager,
         num_ctx_tokens=0,
         num_tokens=batch_size * num_gen_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_gen)

@@ -1418,6 +1422,7 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         cache_manager=cache_manager,
         num_ctx_tokens=total_context_tokens,
         num_tokens=total_context_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_context)
     indexer._update_k_cache(k_context_fp8, k_context_scale, metadata_context)
@@ -1450,16 +1455,24 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         cache_manager=cache_manager,
         num_ctx_tokens=0,
         num_tokens=num_gen_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_gen_write)
     indexer._update_k_cache(k_fp8, k_scale, metadata_gen_write)

     # Test with custom CUDA kernel
-    metadata_custom = _create_mock_metadata(request_ids, batch_size, 0,
-                                            batch_size, seq_lens.clone(),
+    metadata_custom = _create_mock_metadata(request_ids,
+                                            batch_size,
+                                            0,
+                                            batch_size,
+                                            seq_lens.clone(),
                                             final_lens.clone(),
-                                            num_cached_tokens, cache_manager, 0,
-                                            num_gen_tokens, max_model_len)
+                                            num_cached_tokens,
+                                            cache_manager,
+                                            0,
+                                            num_gen_tokens,
+                                            max_model_len,
+                                            max_draft_tokens=next_n - 1)

     Indexer.prepare(metadata_custom)
     indexer._update_k_cache(k_fp8, k_scale, metadata_custom)
@@ -1476,11 +1489,18 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         pytest.skip(f"Custom topk not available: {e}")

     # Test with PyTorch fallback
-    metadata_fallback = _create_mock_metadata(request_ids, batch_size, 0,
-                                              batch_size, seq_lens.clone(),
+    metadata_fallback = _create_mock_metadata(request_ids,
+                                              batch_size,
+                                              0,
+                                              batch_size,
+                                              seq_lens.clone(),
                                               final_lens.clone(),
-                                              num_cached_tokens, cache_manager,
-                                              0, num_gen_tokens, max_model_len)
+                                              num_cached_tokens,
+                                              cache_manager,
+                                              0,
+                                              num_gen_tokens,
+                                              max_model_len,
+                                              max_draft_tokens=next_n - 1)

     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
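Note, separate from the diff above: every updated call site passes max_draft_tokens=next_n - 1. This matches the usual speculative-decoding accounting, where each generation step scores next_n tokens per request, one target token plus next_n - 1 draft tokens, so the mock metadata advertises next_n - 1 draft tokens. A minimal Python sketch of that relationship (the concrete next_n value here is hypothetical; the tests take it from pytest parameters):

# Hypothetical illustration of the max_draft_tokens = next_n - 1 relationship.
next_n = 4                                            # tokens scored per request per step
max_draft_tokens = next_n - 1                         # value the updated call sites pass
tokens_per_request_per_step = 1 + max_draft_tokens    # 1 target token + draft tokens
assert tokens_per_request_per_step == next_n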