|
13 | 13 | from vllm.platforms import current_platform |
14 | 14 | from vllm.utils.mem_utils import get_max_shared_memory_bytes |
15 | 15 |
|
16 | | -if not current_platform.is_rocm(): |
17 | | - from xformers import ops as xops |
18 | | - from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask |
19 | | - |
20 | | - from tests.kernels.utils import make_alibi_bias |
21 | | - |
22 | 16 | FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 |
23 | 17 | # This will change depending on the compute capability. |
24 | 18 | # - 512 as a buffer |
@@ -448,129 +442,6 @@ def ref_multi_query_kv_attention( |
448 | 442 | return torch.cat(ref_outputs, dim=0) |
449 | 443 |
|
450 | 444 |
|
451 | | -@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) |
452 | | -@pytest.mark.parametrize("num_heads", NUM_HEADS) |
453 | | -@pytest.mark.parametrize("head_size", HEAD_SIZES) |
454 | | -@pytest.mark.parametrize("dtype", DTYPES) |
455 | | -@pytest.mark.parametrize("seed", SEEDS) |
456 | | -@pytest.mark.parametrize("device", CUDA_DEVICES) |
457 | | -@pytest.mark.skipif( |
458 | | - current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." |
459 | | -) |
460 | | -@torch.inference_mode() |
461 | | -def test_multi_query_kv_attention( |
462 | | - num_seqs: int, |
463 | | - num_heads: tuple[int, int], |
464 | | - head_size: int, |
465 | | - dtype: torch.dtype, |
466 | | - seed: int, |
467 | | - device: str, |
468 | | - use_alibi: bool = False, |
469 | | -) -> None: |
470 | | - current_platform.seed_everything(seed) |
471 | | - torch.set_default_device(device) |
472 | | - # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. |
473 | | - # As the xformers library is already tested with its own tests, we can use |
474 | | - # a smaller MAX_SEQ_LEN here. |
475 | | - max_len = min(MAX_SEQ_LEN, 4096) |
476 | | - seq_lens = random.sample(range(1, max_len), num_seqs) |
477 | | - num_tokens = sum(seq_lens) |
478 | | - |
479 | | - scale = float(1.0 / (head_size**0.5)) |
480 | | - num_query_heads, num_kv_heads = num_heads |
481 | | - qkv = torch.empty( |
482 | | - num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype |
483 | | - ) |
484 | | - qkv.uniform_(-scale, scale) |
485 | | - query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1) |
486 | | - |
487 | | - num_queries_per_kv = num_query_heads // num_kv_heads |
488 | | - if num_queries_per_kv > 1: |
489 | | - # Handle MQA and GQA |
490 | | - key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) |
491 | | - value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) |
492 | | - alibi_bias = None |
493 | | - if use_alibi: |
494 | | - alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) |
495 | | - attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) |
496 | | - output = torch.empty_like(query) |
497 | | - start = 0 |
498 | | - # Dynamic sequence length not supported with custom attn_bias. |
499 | | - for i, seq_len in enumerate(seq_lens): |
500 | | - end = start + seq_len |
501 | | - out = xops.memory_efficient_attention_forward( |
502 | | - query[None, start:end], |
503 | | - key[None, start:end], |
504 | | - value[None, start:end], |
505 | | - attn_bias=attn_bias[i], |
506 | | - p=0.0, |
507 | | - scale=scale, |
508 | | - ) |
509 | | - output[start:end].copy_(out.view_as(query[start:end])) |
510 | | - start += seq_len |
511 | | - # xformers.AttentionBias to Tensor for use in reference impl. |
512 | | - alibi_bias = [ |
513 | | - b.materialize((1, num_query_heads, i, i), device=device).squeeze() |
514 | | - for b, i in zip(attn_bias, seq_lens) |
515 | | - ] |
516 | | - else: |
517 | | - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) |
518 | | - output = xops.memory_efficient_attention_forward( |
519 | | - query.unsqueeze(0), |
520 | | - key.unsqueeze(0), |
521 | | - value.unsqueeze(0), |
522 | | - attn_bias=attn_bias, |
523 | | - p=0.0, |
524 | | - scale=scale, |
525 | | - ) |
526 | | - output = output.squeeze(0) |
527 | | - |
528 | | - cu_seq_lens = [0] |
529 | | - for seq_len in seq_lens: |
530 | | - cu_seq_lens.append(cu_seq_lens[-1] + seq_len) |
531 | | - ref_output = ref_multi_query_kv_attention( |
532 | | - cu_seq_lens, |
533 | | - query, |
534 | | - key, |
535 | | - value, |
536 | | - scale, |
537 | | - alibi_bias, |
538 | | - dtype, |
539 | | - ) |
540 | | - atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 |
541 | | - rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 |
542 | | - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) |
543 | | - |
544 | | - |
545 | | -@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) |
546 | | -@pytest.mark.parametrize("num_heads", NUM_HEADS) |
547 | | -@pytest.mark.parametrize("head_size", [64]) |
548 | | -@pytest.mark.parametrize("dtype", DTYPES) |
549 | | -@pytest.mark.parametrize("seed", SEEDS) |
550 | | -@pytest.mark.parametrize("device", CUDA_DEVICES) |
551 | | -@pytest.mark.skipif( |
552 | | - current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." |
553 | | -) |
554 | | -@torch.inference_mode() |
555 | | -def test_multi_query_kv_attention_with_alibi( |
556 | | - num_seqs: int, |
557 | | - num_heads: tuple[int, int], |
558 | | - head_size: int, |
559 | | - dtype: torch.dtype, |
560 | | - seed: int, |
561 | | - device: str, |
562 | | -) -> None: |
563 | | - return test_multi_query_kv_attention( |
564 | | - num_seqs, |
565 | | - num_heads, |
566 | | - head_size, |
567 | | - dtype, |
568 | | - seed, |
569 | | - device, |
570 | | - use_alibi=True, |
571 | | - ) |
572 | | - |
573 | | - |
574 | 445 | @pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention]) |
575 | 446 | def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None: |
576 | 447 | head_size = 64 |