Commit cfe1c4c

fix test_dsa_indexer.

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>

1 parent: 3c63484

File tree: 1 file changed (+29, -9 lines)


tests/unittest/_torch/attention/sparse/test_dsa_indexer.py

Lines changed: 29 additions & 9 deletions
@@ -381,7 +381,8 @@ def _create_mock_metadata(request_ids,
                           cache_manager,
                           num_ctx_tokens,
                           num_tokens,
-                          indexer_max_chunk_size=8194):
+                          indexer_max_chunk_size=8194,
+                          max_draft_tokens=0):
     """Helper to create mock metadata for testing."""
 
     class MockKVCacheParams:
@@ -396,6 +397,7 @@ def __init__(self):
             self.request_ids = request_ids
             self.num_contexts = num_contexts
             self.num_generations = num_generations
+            self.max_draft_tokens = max_draft_tokens
             # Keep seq_lens on CPU for split_prefill_chunks and other CPU operations
             # CUDA kernels will convert to CUDA as needed
             self.seq_lens = seq_lens.cpu() if seq_lens.is_cuda else seq_lens
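
In short, the helper gains an optional max_draft_tokens keyword (default 0, i.e. no speculative drafts) and the mock metadata stores it for the code under test to read back. A minimal, self-contained sketch of that pattern, with simplified hypothetical names (the real helper in test_dsa_indexer.py takes many more arguments):

    # Simplified sketch of the mock pattern this commit extends; all names
    # here are hypothetical except max_draft_tokens itself.
    def _create_mock_metadata_sketch(num_generations, max_draft_tokens=0):

        class MockMetadata:

            def __init__(self):
                self.num_generations = num_generations
                # New field: how many speculative draft tokens each
                # generation request carries per step.
                self.max_draft_tokens = max_draft_tokens

        return MockMetadata()


    meta = _create_mock_metadata_sketch(num_generations=4, max_draft_tokens=1)
    assert meta.max_draft_tokens == 1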
@@ -826,6 +828,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
         cache_manager=cache_manager,
         num_ctx_tokens=total_context_tokens,
         num_tokens=total_context_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_context)
 
@@ -851,6 +854,7 @@ def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
         cache_manager=cache_manager,
         num_ctx_tokens=0,
         num_tokens=batch_size * num_gen_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_gen)
 
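A note on the value passed at every call site: assuming next_n is the number of tokens each generation request carries per decode step (one accepted token plus its speculative drafts), the draft count is next_n - 1, and next_n == 1 degenerates to plain autoregressive decoding. A worked check of that arithmetic:

    # Draft-token arithmetic behind max_draft_tokens=next_n - 1 (assumption:
    # next_n = 1 accepted token + the speculative drafts per request).
    for next_n in (1, 2, 4):
        num_gen_tokens = next_n        # tokens per generation request
        max_draft_tokens = next_n - 1  # value passed to _create_mock_metadata
        assert num_gen_tokens == max_draft_tokens + 1
        # next_n == 1 means no drafts at all, i.e. ordinary decoding.
        assert (max_draft_tokens == 0) == (next_n == 1)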

@@ -1418,6 +1422,7 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         cache_manager=cache_manager,
         num_ctx_tokens=total_context_tokens,
         num_tokens=total_context_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_context)
     indexer._update_k_cache(k_context_fp8, k_context_scale, metadata_context)
@@ -1450,16 +1455,24 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         cache_manager=cache_manager,
         num_ctx_tokens=0,
         num_tokens=num_gen_tokens,
+        max_draft_tokens=next_n - 1,
     )
     Indexer.prepare(metadata_gen_write)
     indexer._update_k_cache(k_fp8, k_scale, metadata_gen_write)
 
     # Test with custom CUDA kernel
-    metadata_custom = _create_mock_metadata(request_ids, batch_size, 0,
-                                            batch_size, seq_lens.clone(),
+    metadata_custom = _create_mock_metadata(request_ids,
+                                            batch_size,
+                                            0,
+                                            batch_size,
+                                            seq_lens.clone(),
                                             final_lens.clone(),
-                                            num_cached_tokens, cache_manager, 0,
-                                            num_gen_tokens, max_model_len)
+                                            num_cached_tokens,
+                                            cache_manager,
+                                            0,
+                                            num_gen_tokens,
+                                            max_model_len,
+                                            max_draft_tokens=next_n - 1)
 
     Indexer.prepare(metadata_custom)
     indexer._update_k_cache(k_fp8, k_scale, metadata_custom)
@@ -1476,11 +1489,18 @@ def test_indexer_decode_custom_vs_fallback(batch_size, next_n, index_topk,
         pytest.skip(f"Custom topk not available: {e}")
 
     # Test with PyTorch fallback
-    metadata_fallback = _create_mock_metadata(request_ids, batch_size, 0,
-                                              batch_size, seq_lens.clone(),
+    metadata_fallback = _create_mock_metadata(request_ids,
+                                              batch_size,
+                                              0,
+                                              batch_size,
+                                              seq_lens.clone(),
                                               final_lens.clone(),
-                                              num_cached_tokens, cache_manager,
-                                              0, num_gen_tokens, max_model_len)
+                                              num_cached_tokens,
+                                              cache_manager,
+                                              0,
+                                              num_gen_tokens,
+                                              max_model_len,
+                                              max_draft_tokens=next_n - 1)
 
     Indexer.prepare(metadata_fallback)
     indexer._update_k_cache(k_fp8, k_scale, metadata_fallback)
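
For orientation, the two hunks above feed a custom-vs-fallback comparison: the same inputs and metadata (now carrying max_draft_tokens) go through the custom CUDA kernel path and the PyTorch fallback, and the results must agree. A self-contained toy version of that comparison pattern, with entirely hypothetical stand-ins for both paths:

    import torch

    # Toy custom-vs-fallback agreement check; both functions below are
    # made-up stand-ins, not the real indexer code paths.
    def topk_fallback(scores: torch.Tensor, k: int) -> torch.Tensor:
        # Reference path: plain torch.topk, indices sorted so the
        # comparison is order-insensitive.
        return torch.topk(scores, k, dim=-1).indices.sort(dim=-1).values


    def topk_custom(scores: torch.Tensor, k: int) -> torch.Tensor:
        # Pretend "kernel": same selection via a different implementation.
        order = torch.argsort(scores, dim=-1, descending=True)
        return order[..., :k].sort(dim=-1).values


    scores = torch.randn(4, 128)
    assert torch.equal(topk_custom(scores, 16), topk_fallback(scores, 16))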
