
Commit 49df731

ZhangGe6 and yuxianq authored
[#6507][fix] Fix precision issue due to KV layout mismatch for split/concat kernels (#6917)
Signed-off-by: ZhangGe6 <sjtu.zg123@gmail.com>
Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
1 parent 4fd93bd commit 49df731

File tree

6 files changed, +59 -19 lines


tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 6 additions & 2 deletions

@@ -56,7 +56,10 @@ class FlashInferWrappers:
 class FlashInferAttentionMetadata(AttentionMetadata):
     workspace_buffer: Optional[torch.Tensor] = None

-    kv_layout: Literal["NHD", "HND"] = "NHD"
+    # cache concat/split kernels when using PD disaggregation
+    # expects KV cache in [max_num_pages, 2, num_kv_heads, page_size, head_dim] layout,
+    # so set kv_layout as "HND" here
+    kv_layout: Literal["NHD", "HND"] = "HND"

     paged_kv_indptr_decode: torch.Tensor = field(init=False)
     paged_kv_indptr_prefill: torch.Tensor = field(init=False)

@@ -506,7 +509,8 @@ def forward_impl(
         q = q.view(-1, self.num_heads, self.head_dim)

         # Key and Value
-        kv_cache = metadata.kv_cache_manager.get_buffers(self.layer_idx)
+        kv_cache = metadata.kv_cache_manager.get_buffers(
+            self.layer_idx, kv_layout=metadata.kv_layout)

         if k is not None and v is not None:
             k = k.view(-1, self.num_kv_heads, self.head_dim)
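
For readers unfamiliar with the two layouts, the sketch below (not part of the commit; all shapes are made-up values) shows what "NHD" versus "HND" means for a paged KV cache and how a single transpose moves between them:

    import torch

    # Illustrative dimensions only; the real values come from the model config.
    max_num_pages, kv_factor = 4, 2            # kv_factor = 2 -> separate K and V planes
    num_kv_heads, page_size, head_dim = 8, 16, 64

    # "NHD": tokens (page_size) come before heads.
    nhd = torch.randn(max_num_pages, kv_factor, page_size, num_kv_heads, head_dim)

    # "HND": heads come before tokens -- the layout the cache concat/split kernels expect.
    hnd = nhd.transpose(2, 3).contiguous()
    assert hnd.shape == (max_num_pages, kv_factor, num_kv_heads, page_size, head_dim)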

tensorrt_llm/_torch/attention_backend/star_flashinfer.py

Lines changed: 2 additions & 1 deletion

@@ -331,7 +331,8 @@ def forward(self,
         num_ctx_tokens = metadata.num_ctx_tokens
         num_qry_tokens = metadata.num_qry_tokens

-        kv_cache = metadata.kv_cache_manager.get_buffers(self.layer_idx)
+        kv_cache = metadata.kv_cache_manager.get_buffers(
+            self.layer_idx, kv_layout=metadata.kv_layout)
         if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant(
         ):
             qc = self.quant_config

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 35 additions & 8 deletions

@@ -813,16 +813,43 @@ def get_num_available_tokens(self, max_num_draft_tokens: int = 0) -> int:
         return (self.get_num_free_blocks() * self.tokens_per_block -
                 self.num_extra_kv_tokens - max_num_draft_tokens)

-    def get_buffers(self, layer_idx: int) -> Optional[torch.Tensor]:
+    def get_buffers(self,
+                    layer_idx: int,
+                    kv_layout: str = "NHD") -> Optional[torch.Tensor]:
+        ''' Slice KV tensor for a specified layer and reshape it.
+
+        1. Slice:
+        [max_num_pages, num_layers, kv_factor, page_size * num_kv_heads * head_dim] ->
+        [max_num_pages, kv_factor, page_size * num_kv_heads * head_dim]
+
+        2. Reshape:
+        kv_layout = "NHD" -> [max_num_pages, kv_factor, page_size, num_kv_heads, head_dim]
+        kv_layout = "HND" -> [max_num_pages, kv_factor, num_kv_heads, page_size, head_dim]
+
+        Note that different attention backend/implementation can have different KV layouts,
+        "kv_layout" should be set accordingly to avoid surprises.
+        '''
         layer_offset = self.layer_offsets[layer_idx]
         result = self.impl.get_primary_pool_data(layer_offset)
-        return result.reshape(
-            result.shape[0],
-            self.kv_factor,
-            self.tokens_per_block,
-            self.num_kv_heads_per_layer[layer_offset],
-            self.head_dim,
-        )
+
+        assert kv_layout in ["NHD",
+                             "HND"], f"Unsupported kv_layout: {kv_layout}"
+        if kv_layout == "NHD":
+            return result.reshape(
+                result.shape[0],
+                self.kv_factor,
+                self.tokens_per_block,
+                self.num_kv_heads_per_layer[layer_offset],
+                self.head_dim,
+            )
+        else:
+            return result.reshape(
+                result.shape[0],
+                self.kv_factor,
+                self.num_kv_heads_per_layer[layer_offset],
+                self.tokens_per_block,
+                self.head_dim,
+            )

     def get_indexer_k_cache_pool_data(self, layer_idx: int) -> torch.Tensor:
         result = self.impl.get_indexer_k_cache_pool_data(layer_idx)
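
A rough sketch of the slice-and-reshape flow the new docstring describes (a random tensor stands in for the real pool; this is not the actual implementation):

    import torch

    # Made-up sizes; the real pool belongs to the KV cache manager.
    max_num_pages, num_layers, kv_factor = 4, 2, 2
    page_size, num_kv_heads, head_dim = 16, 8, 64

    pool = torch.randn(max_num_pages, num_layers, kv_factor,
                       page_size * num_kv_heads * head_dim)

    layer_offset = 0
    layer_slice = pool[:, layer_offset]    # step 1: slice out one layer

    # step 2: the same bytes reinterpreted under each layout
    nhd = layer_slice.reshape(max_num_pages, kv_factor, page_size, num_kv_heads, head_dim)
    hnd = layer_slice.reshape(max_num_pages, kv_factor, num_kv_heads, page_size, head_dim)

Because both reshapes reinterpret the same underlying buffer, picking the wrong kv_layout silently mixes up heads and tokens instead of raising an error, which is the precision issue this commit addresses.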

tests/unittest/_torch/attention/test_attention.py

Lines changed: 8 additions & 6 deletions

@@ -438,12 +438,13 @@ def test_attention_backend(s: Scenario):
     flashinfer_kv_cache = torch.randn(num_layers,
                                       s.max_num_pages,
                                       2,
-                                      page_size,
                                       num_kv_heads,
+                                      page_size,
                                       head_dim,
                                       device="cuda").to(s.kvcache_dtype)
-    ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).contiguous().view(
-        num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)
+    ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).transpose(
+        3, 4).contiguous().view(num_layers, 2, batch_size, kv_cache_len,
+                                num_kv_heads, head_dim)
     kv = torch.randn(num_layers,
                      2,
                      nnz_kv,

@@ -588,12 +589,13 @@ def test_attention_backend_ifb(s: PagedScenario):
     flashinfer_kv_cache = torch.randn(num_layers,
                                       s.max_num_pages,
                                       2,
-                                      page_size,
                                       num_kv_heads,
+                                      page_size,
                                       head_dim,
                                       device="cuda").to(s.kvcache_dtype)
-    ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).contiguous().view(
-        num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)
+    ref_kv_cache = flashinfer_kv_cache.transpose(1, 2).transpose(
+        3, 4).contiguous().view(num_layers, 2, batch_size, kv_cache_len,
+                                num_kv_heads, head_dim)
     vanilla_kv_cache = ref_kv_cache.transpose(1, 2).contiguous()
     kv = torch.randn(num_layers,
                      2,
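
A shape-only check of why the extra transpose(3, 4) is needed once the cache is allocated in HND order (not taken from the test; batch_size and kv_cache_len are made-up values with batch_size * kv_cache_len == max_num_pages * page_size, mirroring the test's assumption):

    import torch

    num_layers, max_num_pages, num_kv_heads, page_size, head_dim = 2, 8, 4, 16, 32
    batch_size, kv_cache_len = 2, 64       # 2 * 64 == 8 * 16 tokens

    # HND-layout cache as allocated above: [..., num_kv_heads, page_size, head_dim]
    cache = torch.randn(num_layers, max_num_pages, 2, num_kv_heads, page_size, head_dim)

    # transpose(1, 2) brings the K/V axis forward; transpose(3, 4) swaps heads and tokens
    # back so the flattened view lines up with [layers, 2, batch, seq, heads, dim].
    ref = cache.transpose(1, 2).transpose(3, 4).contiguous().view(
        num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)
    assert ref.shape == (num_layers, 2, batch_size, kv_cache_len, num_kv_heads, head_dim)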

tests/unittest/_torch/attention/test_flashinfer_attention.py

Lines changed: 4 additions & 1 deletion

@@ -227,7 +227,10 @@ def test_flashinfer_attention(self, scenario: Scenario):
                          sum(context_sequence_lengths) + num_gens)

         # validate kv cache was updated expectedly
-        cache_buf = kv_cache_manager.get_buffers(flashinfer_attn.layer_idx)
+        cache_buf = kv_cache_manager.get_buffers(
+            flashinfer_attn.layer_idx, kv_layout=attn_metadata.kv_layout)
+        if attn_metadata.kv_layout == "HND":
+            cache_buf = cache_buf.transpose(2, 3).contiguous()
         assert cache_buf is not None
         num_kv_heads = cache_buf.size(-2)

tests/unittest/_torch/attention/test_flashinfer_star_attn.py

Lines changed: 4 additions & 1 deletion

@@ -312,7 +312,10 @@ def test_flashinfer_star_attention(self, scenario: Scenario):
                         num_gens)

         # validate kv cache was updated expectedly
-        cache_buf = kv_cache_manager.get_buffers(star_attn.layer_idx)
+        cache_buf = kv_cache_manager.get_buffers(
+            star_attn.layer_idx, kv_layout=attn_metadata.kv_layout)
+        if attn_metadata.kv_layout == "HND":
+            cache_buf = cache_buf.transpose(2, 3).contiguous()
         assert cache_buf is not None
         num_kv_heads = cache_buf.size(-2)
