From 528dd4509d5f5afc877b850d746ec496a1960faa Mon Sep 17 00:00:00 2001 From: MingYang119 Date: Sat, 29 Nov 2025 16:30:05 +0800 Subject: [PATCH] add lightning_indexer and sparse_flash_attention --- .github/Dockerfile.buildwheel | 2 +- .../_e2e_nightly_single_node_models.yaml | 11 +- .github/workflows/_e2e_test.yaml | 21 +- .github/workflows/format_pr_body.yaml | 2 +- .github/workflows/image_310p_openeuler.yml | 1 - .github/workflows/image_310p_ubuntu.yml | 1 - .github/workflows/image_a3_openeuler.yml | 1 - .github/workflows/image_a3_ubuntu.yml | 1 - .github/workflows/image_openeuler.yml | 1 - .github/workflows/image_ubuntu.yml | 1 - .github/workflows/pre-commit.yml | 2 +- .github/workflows/release_code.yml | 2 +- .github/workflows/release_whl.yml | 3 +- .github/workflows/vllm_ascend_test_310p.yaml | 1 - .../vllm_ascend_test_nightly_a2.yaml | 11 +- .../vllm_ascend_test_nightly_a3.yaml | 3 + .gitignore | 4 + .pre-commit-config.yaml | 2 +- CMakeLists.txt | 35 +- Dockerfile | 2 +- Dockerfile.310p | 2 +- Dockerfile.310p.openEuler | 2 +- Dockerfile.a3 | 2 +- Dockerfile.a3.openEuler | 2 +- Dockerfile.openEuler | 2 +- README.md | 2 +- README.zh.md | 2 +- .../op_host/batch_matmul_transpose.h | 123 ++ csrc/batch_matmul_transpose/op_host/common.h | 57 + .../op_host/common_tiling.h | 239 +++ .../op_host/tiling/tiling_data.cpp | 155 ++ .../op_host/tiling/tiling_data.h | 90 + .../batch_matmul_transpose_kernel.cpp | 824 ++++++++ csrc/build_aclnn.sh | 4 +- csrc/kernels/math_utils.h | 15 + csrc/lightning_indexer/op_host/CMakeLists.txt | 42 + .../op_host/lightning_indexer_def.cpp | 72 + .../op_host/lightning_indexer_proto.cpp | 96 + .../op_host/lightning_indexer_tiling.cpp | 733 +++++++ .../op_host/lightning_indexer_tiling.h | 222 ++ .../op_kernel/lightning_indexer.cpp | 58 + .../op_kernel/lightning_indexer_common.h | 142 ++ .../op_kernel/lightning_indexer_kernel.h | 646 ++++++ .../lightning_indexer_service_cube.h | 421 ++++ .../lightning_indexer_service_vector.h | 613 ++++++ .../lightning_indexer_template_tiling_key.h | 69 + .../op_kernel/lightning_indexer_vector.h | 391 ++++ csrc/ops.h | 9 + .../op_host/CMakeLists.txt | 39 + .../op_host/sparse_flash_attention_def.cpp | 90 + .../op_host/sparse_flash_attention_proto.cpp | 48 + .../op_host/sparse_flash_attention_tiling.cpp | 1876 +++++++++++++++++ .../op_host/sparse_flash_attention_tiling.h | 591 ++++++ .../op_kernel/sparse_flash_attention.cpp | 53 + .../op_kernel/sparse_flash_attention_common.h | 198 ++ .../sparse_flash_attention_kernel_mla.h | 987 +++++++++ .../sparse_flash_attention_service_cube_mla.h | 1125 ++++++++++ ...parse_flash_attention_service_vector_mla.h | 1377 ++++++++++++ ...arse_flash_attention_template_tiling_key.h | 57 + csrc/torch_binding.cpp | 152 ++ csrc/torch_binding_meta.cpp | 71 + docs/source/tutorials/DeepSeek-V3.2-Exp.md | 70 +- docs/source/tutorials/Qwen2.5-Omni.md | 206 ++ docs/source/tutorials/index.md | 1 + docs/source/tutorials/multi_npu_qwen3_next.md | 4 +- examples/offline_data_parallel.py | 9 + pyproject.toml | 4 +- requirements-dev.txt | 3 +- requirements.txt | 4 +- setup.py | 12 +- .../test_offline_inference_parallel_310p.py | 3 - tests/e2e/conftest.py | 8 +- .../models/configs/ERNIE-4.5-21B-A3B-PT.yaml | 9 + ...rnVL3_5-8B.yaml => InternVL3_5-8B-hf.yaml} | 0 tests/e2e/models/configs/Molmo-7B-D-0924.yaml | 13 + tests/e2e/models/configs/accuracy.txt | 8 +- tests/e2e/models/configs/gemma-2-9b-it.yaml | 11 + tests/e2e/models/configs/gemma-3-4b-it.yaml | 13 + tests/e2e/models/configs/internlm-7b.yaml | 13 + 
tests/e2e/models/configs/llava-1.5-7b-hf.yaml | 11 + tests/e2e/models/test_lm_eval_correctness.py | 3 +- tests/e2e/multicard/test_data_parallel.py | 13 +- tests/e2e/multicard/test_data_parallel_tp2.py | 52 + tests/e2e/multicard/test_expert_parallel.py | 21 +- tests/e2e/multicard/test_external_launcher.py | 10 +- .../multicard/test_fused_moe_allgather_ep.py | 16 +- .../test_offline_inference_distributed.py | 14 +- tests/e2e/multicard/test_prefix_caching.py | 65 - tests/e2e/multicard/test_qwen3_next.py | 6 - tests/e2e/multicard/test_shared_expert_dp.py | 93 + .../e2e/multicard/test_torchair_graph_mode.py | 13 +- tests/e2e/multicard/test_weight_loader.py | 4 +- .../test_mtpx_deepseek_r1_0528_w8a8.py | 6 +- ...test_prefix_cache_deepseek_r1_0528_w8a8.py | 3 - .../test_prefix_cache_qwen3_32b_int8.py | 7 +- .../test_qwen3_32b_int8_a3_feature_stack3.py | 3 +- .../models/test_deepseek_r1_0528_w8a8.py | 10 +- .../models/test_deepseek_r1_w8a8_eplb.py | 3 - .../models/test_deepseek_v3_2_exp_w8a8.py | 3 +- tests/e2e/nightly/models/test_glm4_5.py | 111 + .../e2e/nightly/models/test_qwen2_5_vl_32b.py | 5 +- .../models/test_qwen3_235b_a22b_w8a8_eplb.py | 6 +- .../nightly/models/test_qwen3_235b_w8a8.py | 6 - tests/e2e/nightly/models/test_qwq_32b.py | 2 - .../models/DeepSeek-R1-W8A8-A2-torchair.yaml | 4 +- .../config/models/DeepSeek-R1-W8A8-A2.yaml | 4 +- .../config/models/DeepSeek-R1-W8A8-EPLB.yaml | 8 +- .../config/models/DeepSeek-R1-W8A8.yaml | 8 +- .../config/models/DeepSeek-V3_2-Exp-bf16.yaml | 8 +- tests/e2e/nightly/multi_node/scripts/run.sh | 4 +- .../ops/test_batch_matmul_transpose.py | 141 ++ .../spec_decode_v1/test_v1_mtp_correctness.py | 41 +- .../spec_decode_v1/test_v1_spec_decode.py | 87 +- tests/e2e/singlecard/test_ascend_scheduler.py | 52 - tests/e2e/singlecard/test_bge_model.py | 2 +- tests/e2e/singlecard/test_chunked.py | 82 - tests/e2e/singlecard/test_embedding.py | 2 +- .../e2e/singlecard/test_embedding_aclgraph.py | 4 +- tests/e2e/singlecard/test_vlm.py | 35 - tests/ut/attention/test_attention_v1.py | 246 +-- tests/ut/distributed/test_parallel_state.py | 4 +- tests/ut/models/test_mla.py | 27 +- tests/ut/models/test_qwen2_vl.py | 200 -- tests/ut/ops/test_layernorm.py | 7 +- tests/ut/ops/test_moe_comm_method.py | 4 +- typos.toml | 2 +- vllm_ascend/ascend_config.py | 4 + vllm_ascend/attention/attention_v1.py | 321 +-- vllm_ascend/attention/mla_v1.py | 23 +- vllm_ascend/attention/sfa_v1.py | 8 +- .../kvpool/ascend_store_connector.py | 2 - vllm_ascend/distributed/kvpool/config_data.py | 10 + vllm_ascend/distributed/kvpool/kv_transfer.py | 51 +- .../distributed/kvpool/pool_scheduler.py | 9 + vllm_ascend/distributed/kvpool/pool_worker.py | 48 +- .../llmdatadist_c_mgr_connector.py | 1 + vllm_ascend/distributed/parallel_state.py | 33 +- vllm_ascend/eplb/adaptor/vllm_adaptor.py | 53 +- vllm_ascend/models/__init__.py | 19 - vllm_ascend/models/layers/__init__.py | 0 vllm_ascend/models/qwen2_vl.py | 373 ---- vllm_ascend/models/qwen3_next.py | 981 --------- vllm_ascend/models/qwen3_next_mtp.py | 109 - vllm_ascend/models/qwen3_vl.py | 264 --- vllm_ascend/ops/fused_moe/fused_moe.py | 43 +- vllm_ascend/ops/fused_moe/moe_comm_method.py | 8 +- vllm_ascend/ops/fused_moe/moe_mlp.py | 81 +- vllm_ascend/ops/layernorm.py | 7 +- vllm_ascend/ops/linear.py | 2 +- vllm_ascend/{models/layers => ops}/mla.py | 0 vllm_ascend/ops/register_custom_ops.py | 36 +- vllm_ascend/patch/platform/__init__.py | 1 - vllm_ascend/patch/platform/patch_config.py | 6 + .../patch/platform/patch_distributed.py | 22 - 
.../platform/patch_dynamo_vllm_backend.py | 16 - vllm_ascend/patch/worker/__init__.py | 2 + .../patch/worker/patch_qwen2_5_omni.py | 72 + vllm_ascend/patch/worker/patch_qwen2_5_vl.py | 227 +- vllm_ascend/patch/worker/patch_qwen3_vl.py | 251 +++ vllm_ascend/platform.py | 4 +- vllm_ascend/quantization/quant_config.py | 9 +- vllm_ascend/quantization/utils.py | 6 + vllm_ascend/quantization/w4a8_dynamic.py | 8 +- vllm_ascend/quantization/w8a8.py | 18 +- vllm_ascend/quantization/w8a8_dynamic.py | 64 +- vllm_ascend/quantization/w8a8_pdmix.py | 70 + vllm_ascend/sample/rejection_sampler.py | 329 ++- vllm_ascend/spec_decode/__init__.py | 3 + vllm_ascend/spec_decode/eagle_proposer.py | 4 +- vllm_ascend/spec_decode/interface.py | 3 +- vllm_ascend/spec_decode/mtp_proposer.py | 68 +- vllm_ascend/spec_decode/ngram_proposer.py | 3 +- vllm_ascend/spec_decode/suffix_proposer.py | 45 + .../torchair/ops/torchair_fused_moe.py | 19 +- vllm_ascend/torchair/torchair_mtp_proposer.py | 4 +- vllm_ascend/utils.py | 12 +- vllm_ascend/worker/model_runner_v1.py | 61 +- vllm_ascend/worker/worker_v1.py | 18 +- 178 files changed, 14286 insertions(+), 3200 deletions(-) create mode 100644 csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h create mode 100644 csrc/batch_matmul_transpose/op_host/common.h create mode 100644 csrc/batch_matmul_transpose/op_host/common_tiling.h create mode 100644 csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp create mode 100644 csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h create mode 100644 csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp create mode 100644 csrc/kernels/math_utils.h create mode 100644 csrc/lightning_indexer/op_host/CMakeLists.txt create mode 100644 csrc/lightning_indexer/op_host/lightning_indexer_def.cpp create mode 100644 csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp create mode 100644 csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp create mode 100644 csrc/lightning_indexer/op_host/lightning_indexer_tiling.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer.cpp create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_common.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_service_cube.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_service_vector.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h create mode 100644 csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h create mode 100644 csrc/sparse_flash_attention/op_host/CMakeLists.txt create mode 100644 csrc/sparse_flash_attention/op_host/sparse_flash_attention_def.cpp create mode 100644 csrc/sparse_flash_attention/op_host/sparse_flash_attention_proto.cpp create mode 100644 csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.cpp create mode 100644 csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.h create mode 100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention.cpp create mode 100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_common.h create mode 100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_kernel_mla.h create mode 100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_cube_mla.h create mode 100644 csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_vector_mla.h create mode 100644 
csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_template_tiling_key.h create mode 100644 docs/source/tutorials/Qwen2.5-Omni.md create mode 100644 tests/e2e/models/configs/ERNIE-4.5-21B-A3B-PT.yaml rename tests/e2e/models/configs/{InternVL3_5-8B.yaml => InternVL3_5-8B-hf.yaml} (100%) create mode 100644 tests/e2e/models/configs/Molmo-7B-D-0924.yaml create mode 100644 tests/e2e/models/configs/gemma-2-9b-it.yaml create mode 100644 tests/e2e/models/configs/gemma-3-4b-it.yaml create mode 100644 tests/e2e/models/configs/internlm-7b.yaml create mode 100644 tests/e2e/models/configs/llava-1.5-7b-hf.yaml create mode 100644 tests/e2e/multicard/test_data_parallel_tp2.py create mode 100644 tests/e2e/multicard/test_shared_expert_dp.py create mode 100644 tests/e2e/nightly/models/test_glm4_5.py create mode 100644 tests/e2e/nightly/ops/test_batch_matmul_transpose.py delete mode 100644 tests/e2e/singlecard/test_chunked.py delete mode 100644 tests/ut/models/test_qwen2_vl.py delete mode 100644 vllm_ascend/models/layers/__init__.py delete mode 100644 vllm_ascend/models/qwen2_vl.py delete mode 100644 vllm_ascend/models/qwen3_next.py delete mode 100644 vllm_ascend/models/qwen3_next_mtp.py delete mode 100644 vllm_ascend/models/qwen3_vl.py rename vllm_ascend/{models/layers => ops}/mla.py (100%) delete mode 100644 vllm_ascend/patch/platform/patch_dynamo_vllm_backend.py create mode 100644 vllm_ascend/patch/worker/patch_qwen2_5_omni.py create mode 100644 vllm_ascend/patch/worker/patch_qwen3_vl.py create mode 100644 vllm_ascend/quantization/w8a8_pdmix.py create mode 100644 vllm_ascend/spec_decode/suffix_proposer.py diff --git a/.github/Dockerfile.buildwheel b/.github/Dockerfile.buildwheel index abfd3b8de24..3374e8b9453 100644 --- a/.github/Dockerfile.buildwheel +++ b/.github/Dockerfile.buildwheel @@ -18,7 +18,7 @@ ARG PY_VERSION=3.11 FROM quay.io/ascend/manylinux:8.3.rc2-910b-manylinux_2_28-py${PY_VERSION} ARG COMPILE_CUSTOM_KERNELS=1 -ARG SOC_VERSION +ARG SOC_VERSION="ascend910b1" # Define environments ENV DEBIAN_FRONTEND=noninteractive diff --git a/.github/workflows/_e2e_nightly_single_node_models.yaml b/.github/workflows/_e2e_nightly_single_node_models.yaml index 1ce99fe3666..c587722bf85 100644 --- a/.github/workflows/_e2e_nightly_single_node_models.yaml +++ b/.github/workflows/_e2e_nightly_single_node_models.yaml @@ -59,7 +59,7 @@ jobs: name: ${{inputs.model_list}} accuracy test runs-on: ${{ inputs.runner }} container: - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-910b-ubuntu22.04-py3.11 + image: "${{ inputs.image }}" env: VLLM_USE_MODELSCOPE: True GHA_VLLM_ASCEND_VERSION: ${{ inputs.vllm-ascend }} @@ -109,7 +109,13 @@ jobs: shell: bash -l {0} run: | . 
/usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh - python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" + python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl" + + - name: Install tensorflow (for Molmo-7B-D-0924) + if: ${{ inputs.runner == 'linux-aarch64-a2-1' && contains(inputs.model_list, 'Molmo-7B-D-0924') }} + shell: bash -l {0} + run: | + pip install tensorflow --no-cache-dir - name: Resolve vllm-ascend version run: | @@ -172,6 +178,7 @@ jobs: id: report env: VLLM_WORKER_MULTIPROC_METHOD: spawn + HF_DATASETS_OFFLINE: True VLLM_USE_MODELSCOPE: True VLLM_CI_RUNNER: ${{ inputs.runner }} VLLM_VERSION: ${{ env.GHA_VLLM_VERSION }} diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 6906930ac61..c7e883a0375 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -94,11 +94,11 @@ jobs: pytest -sv tests/e2e/singlecard/test_ascend_scheduler.py pytest -sv tests/e2e/singlecard/test_bge_model.py pytest -sv tests/e2e/singlecard/test_camem.py - pytest -sv tests/e2e/singlecard/test_chunked.py pytest -sv tests/e2e/singlecard/test_embedding.py # pytest -sv tests/e2e/singlecard/test_embedding_aclgraph.py pytest -sv tests/e2e/singlecard/test_guided_decoding.py - pytest -sv tests/e2e/singlecard/test_ilama_lora.py + # torch 2.8 doesn't work with lora, fix me + #pytest -sv tests/e2e/singlecard/test_ilama_lora.py pytest -sv tests/e2e/singlecard/test_profile_execute_duration.py pytest -sv tests/e2e/singlecard/test_quantization.py pytest -sv tests/e2e/singlecard/test_sampler.py @@ -188,7 +188,8 @@ jobs: pytest -sv tests/e2e/multicard/test_external_launcher.py pytest -sv tests/e2e/multicard/test_single_request_aclgraph.py pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py - pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py + # torch 2.8 doesn't work with lora, fix me + #pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py # To avoid oom, we need to run the test in a single process. pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ @@ -266,17 +267,17 @@ jobs: VLLM_WORKER_MULTIPROC_METHOD: spawn VLLM_USE_MODELSCOPE: True run: | - pytest -sv \ - tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe \ - tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC - # tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP \ - # tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_W8A8_WITH_EP + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC + # pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP + # pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_W8A8_WITH_EP + pytest -sv tests/e2e/multicard/test_data_parallel_tp2.py - name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct) shell: bash -l {0} run: | . 
/usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh - python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27.whl" + python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl" - name: Run vllm-project/vllm-ascend Qwen3 Next test working-directory: ./vllm-ascend @@ -286,4 +287,4 @@ jobs: VLLM_USE_MODELSCOPE: True run: | . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh - pytest -sv tests/e2e/multicard/test_qwen3_next.py + #pytest -sv tests/e2e/multicard/test_qwen3_next.py diff --git a/.github/workflows/format_pr_body.yaml b/.github/workflows/format_pr_body.yaml index 71235e7b10e..58f0222bdfe 100644 --- a/.github/workflows/format_pr_body.yaml +++ b/.github/workflows/format_pr_body.yaml @@ -43,7 +43,7 @@ jobs: uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.2.2 - name: Set up Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 - name: Get vLLM release version run: | diff --git a/.github/workflows/image_310p_openeuler.yml b/.github/workflows/image_310p_openeuler.yml index 5a34889cf7f..b033cb47f8c 100644 --- a/.github/workflows/image_310p_openeuler.yml +++ b/.github/workflows/image_310p_openeuler.yml @@ -132,5 +132,4 @@ jobs: file: Dockerfile.310p.openEuler build-args: | PIP_INDEX_URL=https://pypi.org/simple - SOC_VERSION=ascend310p1 provenance: false diff --git a/.github/workflows/image_310p_ubuntu.yml b/.github/workflows/image_310p_ubuntu.yml index 56aafcf86e1..ddac1c1f3a3 100644 --- a/.github/workflows/image_310p_ubuntu.yml +++ b/.github/workflows/image_310p_ubuntu.yml @@ -128,5 +128,4 @@ jobs: tags: ${{ steps.meta.outputs.tags }} build-args: | PIP_INDEX_URL=https://pypi.org/simple - SOC_VERSION=ascend310p1 provenance: false \ No newline at end of file diff --git a/.github/workflows/image_a3_openeuler.yml b/.github/workflows/image_a3_openeuler.yml index b1c5772718f..6524c9e0b2e 100644 --- a/.github/workflows/image_a3_openeuler.yml +++ b/.github/workflows/image_a3_openeuler.yml @@ -131,6 +131,5 @@ jobs: file: Dockerfile.a3.openEuler build-args: | PIP_INDEX_URL=https://pypi.org/simple - SOC_VERSION=ascend910_9391 provenance: false diff --git a/.github/workflows/image_a3_ubuntu.yml b/.github/workflows/image_a3_ubuntu.yml index 473df8e51ca..baaab8da50c 100644 --- a/.github/workflows/image_a3_ubuntu.yml +++ b/.github/workflows/image_a3_ubuntu.yml @@ -127,6 +127,5 @@ jobs: tags: ${{ steps.meta.outputs.tags }} build-args: | PIP_INDEX_URL=https://pypi.org/simple - SOC_VERSION=ascend910_9391 provenance: false diff --git a/.github/workflows/image_openeuler.yml b/.github/workflows/image_openeuler.yml index 29ccb848085..ead1467d20a 100644 --- a/.github/workflows/image_openeuler.yml +++ b/.github/workflows/image_openeuler.yml @@ -131,5 +131,4 @@ jobs: file: Dockerfile.openEuler build-args: | PIP_INDEX_URL=https://pypi.org/simple - SOC_VERSION=ascend910b1 provenance: false diff --git a/.github/workflows/image_ubuntu.yml b/.github/workflows/image_ubuntu.yml index ab321304169..15960137d80 100644 --- a/.github/workflows/image_ubuntu.yml +++ b/.github/workflows/image_ubuntu.yml @@ -128,5 +128,4 @@ jobs: tags: ${{ steps.meta.outputs.tags }} build-args: | PIP_INDEX_URL=https://pypi.org/simple - SOC_VERSION=ascend910b1 provenance: false diff --git 
a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 27dadd6d7dc..212ee553df3 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -16,7 +16,7 @@ jobs: steps: - name: Checkout vllm-project/vllm-ascend repo uses: actions/checkout@v6 - - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: "3.11" - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" diff --git a/.github/workflows/release_code.yml b/.github/workflows/release_code.yml index be8b85f1c00..afda093e01f 100644 --- a/.github/workflows/release_code.yml +++ b/.github/workflows/release_code.yml @@ -50,7 +50,7 @@ jobs: lscpu - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/release_whl.yml b/.github/workflows/release_whl.yml index b095e696e84..f8a73ab3b5c 100644 --- a/.github/workflows/release_whl.yml +++ b/.github/workflows/release_whl.yml @@ -69,7 +69,6 @@ jobs: ls docker build -f ./.github/Dockerfile.buildwheel \ --build-arg PY_VERSION=${{ matrix.python-version }} \ - --build-arg SOC_VERSION=ascend910b1 \ -t wheel:v1 . docker run --rm \ -u $(id -u):$(id -g) \ @@ -80,7 +79,7 @@ jobs: - name: Set up Python ${{ matrix.python-version }} if: startsWith(github.ref, 'refs/tags/') - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/vllm_ascend_test_310p.yaml b/.github/workflows/vllm_ascend_test_310p.yaml index 9e14ddfb621..d2a0ff0e799 100644 --- a/.github/workflows/vllm_ascend_test_310p.yaml +++ b/.github/workflows/vllm_ascend_test_310p.yaml @@ -100,7 +100,6 @@ jobs: run: | export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib - export SOC_VERSION=ASCEND310P3 pip install -r requirements-dev.txt pip install -v -e . 
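Note: with SOC_VERSION now defaulted in the Dockerfiles (ascend910b1 for 910B, ascend310p1 for 310P, ascend910_9391 for A3), the SOC_VERSION build-args and exports removed in the workflows above become optional. A minimal sketch of overriding the default for a non-default SoC; the image tags below are illustrative, not part of the patch:
# wheel image for the default SoC (ascend910b1 from the Dockerfile ARG)
docker build -f ./.github/Dockerfile.buildwheel --build-arg PY_VERSION=3.11 -t wheel:v1 .
# override the default when targeting a 310P device
docker build -f Dockerfile.310p --build-arg SOC_VERSION=ascend310p1 -t vllm-ascend:310p .
# for a from-source install, SOC_VERSION can still be exported beforehand
export SOC_VERSION=ASCEND310P3
pip install -r requirements-dev.txt
pip install -v -e .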
diff --git a/.github/workflows/vllm_ascend_test_nightly_a2.yaml b/.github/workflows/vllm_ascend_test_nightly_a2.yaml index aaa0e1afcf6..54e33b48508 100644 --- a/.github/workflows/vllm_ascend_test_nightly_a2.yaml +++ b/.github/workflows/vllm_ascend_test_nightly_a2.yaml @@ -114,6 +114,15 @@ jobs: - Qwen3-VL-8B-Instruct - Qwen2.5-Omni-7B - Meta-Llama-3.1-8B-Instruct + - os: linux-aarch64-a2-1 + model_list: + - ERNIE-4.5-21B-A3B-PT + - gemma-2-9b-it + - gemma-3-4b-it + - internlm-7b + - InternVL3_5-8B-hf + - llava-1.5-7b-hf + - Molmo-7B-D-0924 - os: linux-aarch64-a2-2 model_list: - Qwen3-30B-A3B @@ -128,5 +137,5 @@ jobs: vllm: v0.11.2 runner: ${{ matrix.test_config.os }} model_list: ${{ toJson(matrix.test_config.model_list) }} - image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1-910b-ubuntu22.04-py3.11 + image: 'swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc2-910b-ubuntu22.04-py3.11' upload: false diff --git a/.github/workflows/vllm_ascend_test_nightly_a3.yaml b/.github/workflows/vllm_ascend_test_nightly_a3.yaml index 4abbdef4208..d0dc99c2ffa 100644 --- a/.github/workflows/vllm_ascend_test_nightly_a3.yaml +++ b/.github/workflows/vllm_ascend_test_nightly_a3.yaml @@ -134,6 +134,9 @@ jobs: - name: deepseek3_2-exp-w8a8 os: linux-aarch64-a3-16 tests: tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py + - name: glm-4-5 + os: linux-aarch64-a3-16 + tests: tests/e2e/nightly/models/test_glm4_5.py uses: ./.github/workflows/_e2e_nightly_single_node.yaml with: vllm: v0.11.2 diff --git a/.gitignore b/.gitignore index efbd43aef36..0341715eb71 100644 --- a/.gitignore +++ b/.gitignore @@ -203,5 +203,9 @@ kernel_meta/ # benchmark results generated by run-performance-benchmarks.sh /benchmarks/results/ +# _cann_ops_custom generated by build_aclnn.sh +/vllm_ascend/_cann_ops_custom/* +!/vllm_ascend/_cann_ops_custom/.gitkeep + # generated by CANN fusion_result.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2f42d5b159..4440bb5b75d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: args: [ --toml, pyproject.toml, '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml', - '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND' + '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND,ND' ] additional_dependencies: - tomli diff --git a/CMakeLists.txt b/CMakeLists.txt index f0136bc48e0..8868c5f59ec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,9 +22,9 @@ find_package(Torch REQUIRED) run_python(TORCH_VERSION "import torch; print(torch.__version__)" "Failed to locate torch path") -# check torch version is 2.7.1 -if(NOT ${TORCH_VERSION} VERSION_EQUAL "2.7.1") - message(FATAL_ERROR "Expected PyTorch version 2.7.1, but found ${TORCH_VERSION}") +# check torch version is 2.8.0 +if(NOT ${TORCH_VERSION} VERSION_EQUAL "2.8.0") + message(FATAL_ERROR "Expected PyTorch version 2.8.0, but found ${TORCH_VERSION}") endif() set(RUN_MODE "npu" CACHE STRING "cpu/sim/npu") @@ -55,16 +55,36 @@ include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp) -ascendc_library(vllm_ascend_kernels SHARED +set(VLLM_ASCEND_CUSTOM_OP ${KERNEL_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp + 
${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp +) + +set(VLLM_ASCEND_CUSTOM_OP_EXCLUDE + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp +) + +if(SOC_VERSION STREQUAL "ASCEND310P3") + list(REMOVE_ITEM VLLM_ASCEND_CUSTOM_OP ${VLLM_ASCEND_CUSTOM_OP_EXCLUDE}) +endif() + +ascendc_library(vllm_ascend_kernels SHARED + ${VLLM_ASCEND_CUSTOM_OP} ) message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}") -file(GLOB VLLM_ASCEND_SRC -${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp -${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp) +if(SOC_VERSION STREQUAL "ASCEND310P3") + file(GLOB VLLM_ASCEND_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp) +else() + file(GLOB VLLM_ASCEND_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp) +endif() include_directories( ${pybind11_INCLUDE_DIRS} @@ -74,6 +94,7 @@ include_directories( ${ASCEND_HOME_PATH}/include ${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform ${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_host ) set( diff --git a/Dockerfile b/Dockerfile index cc5605ee0bf..2ac67a4b8f0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,7 +20,7 @@ FROM quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 ARG MOONCAKE_TAG="v0.3.7.post2" -ARG SOC_VERSION +ARG SOC_VERSION="ascend910b1" # Define environments ENV DEBIAN_FRONTEND=noninteractive diff --git a/Dockerfile.310p b/Dockerfile.310p index 9d2032631c2..8063c8b1695 100644 --- a/Dockerfile.310p +++ b/Dockerfile.310p @@ -19,7 +19,7 @@ FROM quay.io/ascend/cann:8.3.rc2-310p-ubuntu22.04-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 -ARG SOC_VERSION +ARG SOC_VERSION="ascend310p1" # Define environments ENV DEBIAN_FRONTEND=noninteractive diff --git a/Dockerfile.310p.openEuler b/Dockerfile.310p.openEuler index 659a56c6f7c..866ae19f3cf 100644 --- a/Dockerfile.310p.openEuler +++ b/Dockerfile.310p.openEuler @@ -19,7 +19,7 @@ FROM quay.io/ascend/cann:8.3.rc2-310p-openeuler24.03-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 -ARG SOC_VERSION +ARG SOC_VERSION="ascend310p1" ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} ENV SOC_VERSION=$SOC_VERSION diff --git a/Dockerfile.a3 b/Dockerfile.a3 index de6f1a5aefa..dbd839940aa 100644 --- a/Dockerfile.a3 +++ b/Dockerfile.a3 @@ -20,7 +20,7 @@ FROM quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 ARG MOONCAKE_TAG=v0.3.7.post2 -ARG SOC_VERSION +ARG SOC_VERSION="ascend910_9391" COPY . 
/vllm-workspace/vllm-ascend/ # Define environments diff --git a/Dockerfile.a3.openEuler b/Dockerfile.a3.openEuler index 7761f341f91..d287dc4d9bb 100644 --- a/Dockerfile.a3.openEuler +++ b/Dockerfile.a3.openEuler @@ -20,7 +20,7 @@ FROM quay.io/ascend/cann:8.3.rc2-a3-openeuler24.03-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 ARG MOONCAKE_TAG="v0.3.7.post2" -ARG SOC_VERSION +ARG SOC_VERSION="ascend910_9391" ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} ENV SOC_VERSION=$SOC_VERSION diff --git a/Dockerfile.openEuler b/Dockerfile.openEuler index 9666dee487a..c1bd0362533 100644 --- a/Dockerfile.openEuler +++ b/Dockerfile.openEuler @@ -20,7 +20,7 @@ FROM quay.io/ascend/cann:8.3.rc2-910b-openeuler24.03-py3.11 ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" ARG COMPILE_CUSTOM_KERNELS=1 ARG MOONCAKE_TAG="v0.3.7.post2" -ARG SOC_VERSION +ARG SOC_VERSION="ascend910b1" ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} ENV SOC_VERSION=$SOC_VERSION diff --git a/README.md b/README.md index 0c3c27b135d..31adb9a01ea 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ By using vLLM Ascend plugin, popular open-source models, including Transformer-l - Software: * Python >= 3.10, < 3.12 * CANN >= 8.3.rc1 (Ascend HDK version refers to [here](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/releasenote/releasenote_0000.html)) - * PyTorch == 2.7.1, torch-npu == 2.7.1 + * PyTorch == 2.8.0, torch-npu == 2.8.0 * vLLM (the same version as vllm-ascend) ## Getting Started diff --git a/README.zh.md b/README.zh.md index 516c23a9afc..58d669bd9e2 100644 --- a/README.zh.md +++ b/README.zh.md @@ -44,7 +44,7 @@ vLLM 昇腾插件 (`vllm-ascend`) 是一个由社区维护的让vLLM在Ascend NP - 软件: * Python >= 3.10, < 3.12 * CANN >= 8.3.rc1 (Ascend HDK 版本参考[这里](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/releasenote/releasenote_0000.html)) - * PyTorch == 2.7.1, torch-npu == 2.7.1 + * PyTorch == 2.8.0, torch-npu == 2.8.0 * vLLM (与vllm-ascend版本一致) ## 开始使用 diff --git a/csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h b/csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h new file mode 100644 index 00000000000..597545872c3 --- /dev/null +++ b/csrc/batch_matmul_transpose/op_host/batch_matmul_transpose.h @@ -0,0 +1,123 @@ +#include +#include +#include "acl/acl.h" +#include "kernel_tiling/kernel_tiling.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/tiling_data.h" +#include "common_tiling.h" + + +namespace bmm_trans { +using namespace pp_matmul; + +std::unordered_map quantModeMap = { + {"per_channel_symm", 0}, + {"per_channel_asymm", 1}, + {"per_token_symm", 2}, +}; + +std::unordered_map formatModeMap = { + {"ND", 0}, + {"NZ", 1}, +}; + +std::unordered_map atType2tensorDType = { + {at::ScalarType::BFloat16, TensorDType::TENSOR_DTYPE_BF16}, + {at::ScalarType::Half, TensorDType::TENSOR_DTYPE_FLOAT16}}; + +// batch size -> memory index +constexpr uint32_t MAX_CAPTURE_NUM = 1024; + +template +inline int GetModeVal(const MapType &mode_map, c10::optional mode_opt, c10::string_view default_mode, + const char *mode_name) +{ + std::string modeStr(mode_name); + c10::string_view mode_str = mode_opt.value_or(default_mode); + auto it = mode_map.find(mode_str); + // if input mode is unsupported, use default value + TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str)); + return it->second; +} + +std::tuple batch_matmul_transpose_tiling(const at::Tensor &tensor_a, 
const at::Tensor &tensor_b, at::Tensor &tensor_c, + c10::optional format_mode, + c10::optional quant_mode) +{ + auto tensorAShape = tensor_a.sizes(); + auto tensorBShape = tensor_b.sizes(); + auto tensorCShape = tensor_c.sizes(); + uint32_t n; + uint32_t block_dim; + + //auto &platform = PlatformInfo::Instance(); + HardwareInfo hwInfo; + std::map dTypeMap = {{at::ScalarType::Half, 2.0}, {at::ScalarType::BFloat16, 2.0}}; + + at::ScalarType aType = tensor_a.scalar_type(); + at::ScalarType bType = tensor_b.scalar_type(); + at::ScalarType cType = tensor_c.scalar_type(); + TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same"); + TORCH_CHECK((aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half), + "tensor type only support half or bf16"); + + TensorFormat formatMode = static_cast(GetModeVal(formatModeMap, format_mode, "ND", "format_mode")); + MatMul::QuantMode quantMode = + static_cast(GetModeVal(quantModeMap, quant_mode, "per_channel_symm", "quant_mode")); + + TORCH_CHECK(tensorAShape.size() == 3, "batch size is not same between srcTensor and dstTensor"); + if (formatMode == TensorFormat::TENSOR_FORMAT_ND) { + TORCH_CHECK(tensorBShape.size() == 3, "tensor shape should be dim3 in ND format"); + TORCH_CHECK(tensorAShape[2] == tensorBShape[1], "tensor shape is wrong"); + n = tensorBShape[2]; + } else { + TORCH_CHECK(tensorBShape.size() == 4, "tensor shape should be dim4 in nz format"); + TORCH_CHECK(tensorAShape[2] == tensorBShape[2], "tensor shape is wrong"); + n = tensorBShape[1] * tensorBShape[3]; + } + TORCH_CHECK(tensorAShape[1] == tensorBShape[0], "tensor shape is wrong"); + + OpShape opShape = {.batchSize = static_cast(tensorAShape[1]), + .m = static_cast(tensorAShape[0]), + .k = static_cast(tensorAShape[2]), + .n = n}; + pp_matmul::PpMatmulTilingData matmulTilingData = { + .opShape = opShape, + }; + auto dType = atType2tensorDType[aType]; + MatMulInfo mmInfo = {.batchSize = opShape.batchSize, + .m = opShape.m, + .k = opShape.k, + .n = opShape.n, + .dtypeA = dType, + .dtypeB = dType, + .dtypeC = dType, + .formatB = formatMode, + .mmType = MatMul::MatMulType::MATMUL_EIN_SUM, + .inDtype = dTypeMap[aType], + .outDtype = dTypeMap[cType], + .quantMode = quantMode}; + GetPpMatmulTiling(mmInfo, hwInfo, block_dim, matmulTilingData); + host_utils::PpMatmulTilingCheck(matmulTilingData); + + // tiling + int32_t batchIdx = opShape.m - 1; + uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData); + static auto global_tiling_data = at::empty( + {tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device())); + if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) { + aclrtMemcpy(global_tiling_data.data_ptr() + (tilingSize * batchIdx), tilingSize, &matmulTilingData, + tilingSize, ACL_MEMCPY_HOST_TO_DEVICE); + } else { + // Handle the case where batchIdx is out of range + TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx); + } + at::Tensor tiling_tensor = + at::from_blob(global_tiling_data.data_ptr() + (tilingSize * batchIdx), tilingSize, at::kByte); + + return std::make_tuple(tiling_tensor, block_dim); + +} + +} + diff --git a/csrc/batch_matmul_transpose/op_host/common.h b/csrc/batch_matmul_transpose/op_host/common.h new file mode 100644 index 00000000000..82abd10e955 --- /dev/null +++ b/csrc/batch_matmul_transpose/op_host/common.h @@ -0,0 +1,57 @@ + +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef UTILS_COMMON_H +#define UTILS_COMMON_H + +namespace host_utils { + +constexpr uint32_t BLK_SIZE_ALIN_FOR_INT64 = 4; +constexpr uint32_t BLK_SIZE_ALIN_FOR_INT32 = 8; + +inline uint64_t alinInt64Count(uint64_t count) +{ + return (count + BLK_SIZE_ALIN_FOR_INT64 - 1) / BLK_SIZE_ALIN_FOR_INT64 * BLK_SIZE_ALIN_FOR_INT64; +} + +inline uint64_t alinInt32Count(uint64_t count) +{ + return (count + BLK_SIZE_ALIN_FOR_INT32 - 1) / BLK_SIZE_ALIN_FOR_INT32 * BLK_SIZE_ALIN_FOR_INT32; +} + +template +inline T CeilDiv(const T dividend, const T divisor) +{ + if (divisor == 0) { + return UINT32_MAX; + } + return (dividend + divisor - 1) / divisor; +} + +template +inline T RoundUp(const T val, const T align = 16) +{ + if (align == 0 || val + align - 1 < val) { + return 0; + } + return (val + align - 1) / align * align; +} + +template +inline T RoundDown(const T val, const T align = 16) +{ + if (align == 0) { + return 0; + } + return val / align * align; +} +} // namespace host_utils +#endif // UTILS_COMMON_H diff --git a/csrc/batch_matmul_transpose/op_host/common_tiling.h b/csrc/batch_matmul_transpose/op_host/common_tiling.h new file mode 100644 index 00000000000..4fac5c5bfa5 --- /dev/null +++ b/csrc/batch_matmul_transpose/op_host/common_tiling.h @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COMMMON_TILING_H +#define COMMMON_TILING_H + +#include +#include +#include "common.h" +#include "tiling/platform/platform_ascendc.h" + +namespace host_utils { + +constexpr uint32_t FP16_SIZE = 2; +constexpr uint32_t FP32_SIZE = 4; +constexpr uint32_t BLOCK_SIZE = 16; +constexpr uint32_t BLOCK_SIZE_INT8_K = 32; +constexpr uint32_t BASE_BLOCK_STEP = 2; +constexpr uint32_t AXES_ALIGN_SIZE = 512; +constexpr uint32_t AXES_ALIGN_SIZE_INT8 = 256; +constexpr uint32_t ND_SHAPE_SIZE = 2; +constexpr uint32_t NZ_SHAPE_SIZE = 4; +constexpr uint32_t CUBE_BLOCK_SIZE = 256; +constexpr uint32_t CUBE_BLOCK_SIZE_INT8 = 512; +constexpr uint32_t L1AB_PINGPONG_BUFFER_LEN = 262144; +constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN_INT8 = 131072 * 2; // 256 KB +constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN_FP16 = 131072; // 128 KB +constexpr uint32_t L1AB_PINGPONG_BUFFER_LEN_INT8_SPARSE = 160 * 1024; +constexpr uint32_t UB_LIMIT_SIZE_910A = 128 * 1024; + +enum class PlatformType { ASCEND_310P, ASCEND_910A, ASCEND_910B, ASCEND_910C, PLATFORM_INVALID }; + +struct PlatformInfo { +public: + static const PlatformInfo &Instance() + { + static PlatformInfo platformInfo; + return platformInfo; + } + + PlatformType socType; + uint32_t coreNum; + uint32_t coreNumAic; + uint32_t coreNumAiv; + uint64_t ubSize; + uint64_t l1Size; + uint64_t l2Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; + +private: + PlatformInfo() + { + auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance(); + // TODO Hard coding set to 910_93xx, parse using aclrtGetSocName is better + socType = PlatformType::ASCEND_910C; + coreNum = ascendcPlatform->GetCoreNum(); + coreNumAic = ascendcPlatform->GetCoreNumAic(); + coreNumAiv = ascendcPlatform->GetCoreNumAiv(); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L1, l1Size); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L2, l2Size); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, l0aSize); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, l0bSize); + ascendcPlatform->GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, l0cSize); + } + + PlatformInfo(const PlatformInfo &) = delete; + PlatformInfo &operator=(const PlatformInfo &) = delete; + PlatformInfo(PlatformInfo &&) = delete; + PlatformInfo &operator=(PlatformInfo &&) = delete; +}; + +inline __attribute__((always_inline)) uint32_t GetN0TilingLimit(bool compressFlag, uint32_t tilingN, + const PlatformType &platformType) +{ + if (compressFlag) { + return std::min(tilingN * BLOCK_SIZE, AXES_ALIGN_SIZE_INT8); + } else { + return (platformType == PlatformType::ASCEND_310P || platformType == PlatformType::ASCEND_910A) + ? AXES_ALIGN_SIZE + : AXES_ALIGN_SIZE_INT8; + } +} + +template +inline __attribute__((always_inline)) uint32_t GetN0TilingInit(const OpShareType &opShape, bool compressFlag, + uint32_t tilingN) +{ + const uint32_t rnd = 16; + return compressFlag + ? ((tilingN * BLOCK_SIZE > opShape.n) ? 
RoundUp(opShape.n, rnd) : tilingN * BLOCK_SIZE) + : BLOCK_SIZE; +} + +template +inline __attribute__((always_inline)) bool IsExceedTilingLimit(uint32_t axes0, uint32_t priAxes0, + uint32_t n0TilingLimit, PlatformType platformType, + uint32_t basicBlockSize) +{ + return (PRI_FLAG && axes0 > n0TilingLimit) || (!PRI_FLAG && priAxes0 > n0TilingLimit) || + (platformType == PlatformType::ASCEND_910A && basicBlockSize > UB_LIMIT_SIZE_910A); +} + +template +inline __attribute__((always_inline)) void SetOpShapeAxesInfo(OpShareType &opShape, uint32_t priAxes0, uint32_t axes0) +{ + opShape.m0 = PRI_FLAG ? priAxes0 : axes0; + opShape.n0 = PRI_FLAG ? axes0 : priAxes0; +} + +template +inline __attribute__((always_inline)) float CostFunc(const HardwareType &hwInfor, OpShapeType &shape) +{ + float aCoef = 1; + float bCoef = 1; + float bwCoef = static_cast(hwInfor.l2BandWidth) / static_cast(hwInfor.hbmBandWidth); + uint32_t mLoop = CeilDiv(shape.m, shape.m0); + uint32_t nLoop = CeilDiv(shape.n, shape.n0); + if (mLoop == 0 || nLoop == 0) { + return 1; + } + uint32_t coreNeed = shape.batchSize * mLoop * nLoop; + uint32_t blockDim = std::min(coreNeed, hwInfor.coreNum); + uint32_t mOnce = blockDim < nLoop ? shape.m0 : blockDim / nLoop * shape.m0; + uint32_t nOnce = blockDim < nLoop ? hwInfor.coreNum * shape.n0 : shape.n; + if (mOnce * shape.k * FP16_SIZE > hwInfor.l2Size) { + aCoef = bwCoef; + } + if (nOnce * shape.k * FP16_SIZE > hwInfor.l2Size) { + bCoef = bwCoef; + } + return 1 / (aCoef * static_cast(shape.n0)) + 1 / (bCoef * static_cast(shape.m0)); +} + +template +void TilingFunc(OpShareType &opShape, TilingType &tilingParam, const HardwareType &hwInfor, + const MatMulInfoType &mmInfo, bool compressFlag = false, const uint32_t tilingN = 1) +{ + float costMin = 1; + const float CONST_2 = 2.0; + const uint32_t ROUND_CONST_16 = 16; + uint32_t roundBase = static_cast( + pow(2, ceil(log(CeilDiv(PRI_FLAG ? opShape.n : opShape.m, ROUND_CONST_16)))) * ROUND_CONST_16); + uint32_t priAxes = RoundUp(PRI_FLAG ? opShape.m : opShape.n, ROUND_CONST_16); + uint32_t axes = RoundUp(PRI_FLAG ? opShape.n : opShape.m, roundBase); + float axes0Max = static_cast(AXES_ALIGN_SIZE) / mmInfo.inDtype; + auto platformType = PlatformInfo::Instance().socType; + if (mmInfo.isInt8 && (platformType == PlatformType::ASCEND_310P || platformType == PlatformType::ASCEND_910A)) { + axes0Max /= CONST_2; + } + + uint32_t n0TilingInit = GetN0TilingInit(opShape, compressFlag, tilingN); + uint32_t n0TilingLimit = GetN0TilingLimit(compressFlag, tilingN, platformType); + uint32_t priAxes0Init = PRI_FLAG ? BLOCK_SIZE : n0TilingInit; + uint32_t axes0Init = PRI_FLAG ? 
n0TilingInit : BLOCK_SIZE; + for (uint32_t priAxes0 = priAxes0Init; priAxes0 <= priAxes && priAxes0 <= axes0Max; priAxes0 *= BASE_BLOCK_STEP) { + for (uint32_t axes0 = axes0Init; axes0 <= axes && axes0 <= axes0Max; axes0 *= BASE_BLOCK_STEP) { + uint32_t basicBlockSize = priAxes0 * axes0 * FP32_SIZE; + if (basicBlockSize > hwInfor.l0cSize) { + continue; + } + if (mmInfo.isInt8 && + IsExceedTilingLimit(axes0, priAxes0, n0TilingLimit, platformType, basicBlockSize)) { + continue; + } + SetOpShapeAxesInfo(opShape, priAxes0, axes0); + float cost = CostFunc(hwInfor, opShape); + if (cost >= costMin) { + continue; + } + costMin = cost; + if constexpr (std::is_same::value) { + tilingParam.SetBaseOp(hwInfor.coreNum, opShape.m0, opShape.n0, mmInfo); + } else { + tilingParam.SetBaseOp(hwInfor.coreNum, opShape.m0, opShape.n0); + } + } + } +} + +template +uint32_t Swizzl(PpTilingDataType &tilingData) +{ + uint32_t swizzlDirect = 0; + uint32_t swizzlCount = 1; + float m0 = tilingData.opShape.m0; + float n0 = tilingData.opShape.n0; + float m = tilingData.opShape.m; + float k = tilingData.opShape.k; + float n = tilingData.opShape.n; + float mincost = m * k + k * n; + + for (uint32_t i = 1; i <= tilingData.blockDim; ++i) { + int c = static_cast((tilingData.blockDim + i - 1) / i); + float cost; + // B0 + A < A0 + B + if (i * n0 + m < m0 * c + n) { + swizzlDirect = 1; // Nz + cost = n0 * i + m0 * c; + if (cost <= mincost) { + mincost = cost; + swizzlCount = i; + } + } else { + swizzlDirect = 0; // Zn + cost = m0 * i + n0 * c; + if (cost < mincost) { + mincost = cost; + swizzlCount = i; + } + } + } + tilingData.swizzlDirect = swizzlDirect; + tilingData.swizzlCount = swizzlCount; + return swizzlDirect; +} + +template +inline __attribute__((always_inline)) void PpMatmulTilingCheck(const PpTilingDataType &tilingData) +{ + TORCH_CHECK(tilingData.opShape.m0 > 0, "m0 is invalid"); + TORCH_CHECK(tilingData.opShape.k0 > 0, "k0 is invalid"); + TORCH_CHECK(tilingData.opShape.n0 > 0, "n0 is invalid"); + TORCH_CHECK(tilingData.mLoop > 0, "mLoop is invalid"); + TORCH_CHECK(tilingData.kLoop > 0, "kLoop is invalid"); + TORCH_CHECK(tilingData.nLoop > 0, "nLoop is invalid"); + TORCH_CHECK(tilingData.blockDim > 0, "nLoop is invalid"); +} +} // namespace host_utils +#endif diff --git a/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp b/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp new file mode 100644 index 00000000000..ac8e047caec --- /dev/null +++ b/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp @@ -0,0 +1,155 @@ +#include +#include "tiling_data.h" +#include "common.h" +#include "common_tiling.h" + +namespace pp_matmul { + +constexpr uint32_t L1_DESCALE_BUFFER_LEN_MAX = 6144; +constexpr uint32_t CONST_3 = 3; +constexpr uint32_t CONST_4 = 4; +constexpr uint32_t CONST_16 = 16; +constexpr uint32_t CONST_32 = 32; +constexpr uint32_t CONST_256 = 256; +constexpr uint32_t CONST_512 = 512; + +const std::map G_DTYPE_MAP = {{TensorDType::TENSOR_DTYPE_FLOAT16, 1u}, + {TensorDType::TENSOR_DTYPE_BF16, 2u}}; +const std::map G_FORMAT_MAP = {{TensorFormat::TENSOR_FORMAT_ND, 0u}, + {TensorFormat::TENSOR_FORMAT_NZ, 1u}}; +using MmType = MatMul::MatMulType; +using QmType = MatMul::QuantMode; +using namespace host_utils; + +bool IsI8Bf16Kernel(const MatMulInfo &mmInfo) +{ + bool isI8Bf16 = mmInfo.isInt8 && mmInfo.dtypeC == TensorDType::TENSOR_DTYPE_BF16; + bool isI8Fp16 = mmInfo.isInt8 && mmInfo.dtypeC == TensorDType::TENSOR_DTYPE_FLOAT16 && + mmInfo.quantMode == QmType::PER_TOKEN_SYMM; + return isI8Bf16 || 
isI8Fp16; +} + +HardwareInfo::HardwareInfo() +{ + auto &platform = PlatformInfo::Instance(); + coreNum = platform.coreNumAic; + l2Size = platform.l2Size; + l1Size = platform.l1Size; + l0aSize = platform.l0aSize; + l0bSize = platform.l0bSize; + l0cSize = platform.l0cSize; + hbmBandWidth = 1; + l2BandWidth = 5; // 5x faster than hbm. +} + +void PpMatmulTilingData::SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n) +{ + opShape.batchSize = batchSize; + opShape.m = m; + opShape.k = k; + opShape.n = n; +} + +void PpMatmulTilingData::SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo) +{ + opShape.m0 = mBase; + opShape.n0 = nBase; + mLoop = CeilDiv(opShape.m, opShape.m0); + nLoop = CeilDiv(opShape.n, opShape.n0); + coreLoop = opShape.batchSize * mLoop * nLoop; + + if (mLoop == 1 && mmInfo.transB && coreLoop % coreNum < coreNum / CONST_4 * CONST_3) { + mBase = RoundUp(opShape.m, CONST_16); + opShape.m0 = mBase; + uint32_t maxN0 = PlatformInfo::Instance().l0cSize / (mBase * sizeof(float)); + if (mmInfo.isInt8 || mmInfo.mmType == MmType::MATMUL_WITH_BIAS) { + maxN0 = maxN0 < CONST_256 ? maxN0 : CONST_256; + } + uint32_t x = CeilDiv(opShape.n, coreNum); + uint32_t y = CeilDiv(x, maxN0); + nBase = RoundUp(CeilDiv(x, y), CONST_16); + uint32_t rqdL0CSize = mBase * nBase * sizeof(float); + if (rqdL0CSize < PlatformInfo::Instance().l0cSize && + (mBase + nBase) * CONST_256 * sizeof(uint16_t) < L1AB_PINGPONG_BUFFER_LEN) { + opShape.n0 = nBase; + nLoop = CeilDiv(opShape.n, opShape.n0); + coreLoop = opShape.batchSize * nLoop; + } + } + blockDim = std::min(coreLoop, coreNum); +} + +// transA transB quantMode [dtype] format +void PpMatmulTilingData::SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK) +{ + if (mmInfo.mmType == MmType::MATMUL_ACCUM_ATOMIC || mmInfo.mmType == MmType::MATMUL_WITH_BIAS || + mmInfo.mmType == MmType::MATMUL_EIN_SUM || mmInfo.mmType == MmType::MATMUL_DEQUANT || IsI8Bf16Kernel(mmInfo)) { + // SwizzleDir[1] TransA[1] TransB[1] DtypeA[3] DtypeB[3] DtypeC[3] FormatA[1] FormatB[1] FormatC[1] WithBias[1] + tilingKey = swizzleDirect; + tilingKey = (tilingKey << 1) + static_cast(mmInfo.transA); + tilingKey = (tilingKey << 1) + static_cast(mmInfo.transB); + tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeA); // 3bit for dtypeA. + tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeB); // 3bit for dtypeB. + tilingKey = (tilingKey << 3) + G_DTYPE_MAP.at(mmInfo.dtypeC); // 3bit for dtypeC. + tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatA); + tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatB); + tilingKey = (tilingKey << 1) + G_FORMAT_MAP.at(mmInfo.formatC); + tilingKey = (tilingKey << 1) + static_cast(mmInfo.biasFlag); + } else { + tilingKey = swizzleDirect; + tilingKey = (tilingKey << 1) + static_cast(mmInfo.transA); + tilingKey = (tilingKey << 1) + static_cast(mmInfo.transB); + tilingKey = (tilingKey << 1) + static_cast(mmInfo.isInt8); + tilingKey = (tilingKey << 1) + static_cast(mmInfo.biasFlag); + tilingKey = (tilingKey << 1) + enSplitK; + } +} + +uint32_t PpMatmulTilingData::End(const MatMulInfo &mmInfo) +{ + uint32_t cubeBlockSize = mmInfo.isInt8 ? CUBE_BLOCK_SIZE_INT8 : CUBE_BLOCK_SIZE; + uint32_t kBlockSize = mmInfo.isInt8 ? BLOCK_SIZE_INT8_K : BLOCK_SIZE; + uint32_t scaleBlockSize = mmInfo.isInt8 ? 
L1_DESCALE_BUFFER_LEN_MAX : 0; + uint32_t shapeSum = opShape.m0 + opShape.n0; + if (mmInfo.isInt8 && (mmInfo.transA || !mmInfo.transB)) { + shapeSum = RoundUp(opShape.m0, CONST_32) + RoundUp(opShape.n0, CONST_32); + } + uint32_t k0Max = shapeSum == 0 + ? L1AB_PINGPONG_BUFFER_LEN + : static_cast(static_cast(L1AB_PINGPONG_BUFFER_LEN - scaleBlockSize) / + (shapeSum * mmInfo.inDtype)); + if (mmInfo.mmType == MatMul::MatMulType::MATMUL_WITH_BIAS) { + uint32_t l1AbSize = L1AB_PINGPONG_BUFFER_LEN - opShape.n0 * sizeof(float); + k0Max = l1AbSize / (shapeSum * mmInfo.inDtype); + } + + opShape.k0 = + k0Max < cubeBlockSize ? RoundDown(k0Max, kBlockSize) : RoundDown(k0Max, cubeBlockSize); + if (opShape.k0 > CONST_512) { + opShape.k0 = RoundDown(opShape.k0, CONST_512); + } + kLoop = CeilDiv(opShape.k, opShape.k0); + return blockDim; +} + +void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim, + PpMatmulTilingData &tilingData) +{ + OpShape opShape; + opShape.batchSize = mmInfo.batchSize; + opShape.m = mmInfo.m; + opShape.n = mmInfo.n; + opShape.k = mmInfo.k; + tilingData.opShape = opShape; + tilingData.quantMode = static_cast(mmInfo.quantMode); + tilingData.SetTilingKey(mmInfo, 0, 0); // init tilingkey with transA transB. + if (opShape.m < opShape.n) { + TilingFunc(opShape, tilingData, hwInfo, mmInfo); + } else { + TilingFunc(opShape, tilingData, hwInfo, mmInfo); + } + uint32_t direct = Swizzl(tilingData); + blockDim = tilingData.End(mmInfo); + tilingData.SetTilingKey(mmInfo, direct, 0); +} +} // namespace pp_matmul diff --git a/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h b/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h new file mode 100644 index 00000000000..6713091e1d4 --- /dev/null +++ b/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.h @@ -0,0 +1,90 @@ +#ifndef PP_MATMUL_TILING_DATA +#define PP_MATMUL_TILING_DATA +#include + +namespace pp_matmul { +struct MatMul { + enum class MatMulType : uint32_t { + MATMUL_DEFAULT = 0, // C = op(A) * op(B) + MATMUL_DEQUANT, // + MATMUL_ACCUM_ATOMIC, // C += op(A) * op(B) + MATMUL_WITH_BIAS, // C = op(A) * op(B) + Bias, where Bias is a vector. 
+ MATMUL_EIN_SUM + }; + enum class QuantMode : uint32_t { PER_CHANNEL_SYMM = 0, PER_CHANNEL_ASYMM, PER_TOKEN_SYMM }; +}; + +enum class TensorDType : uint32_t { TENSOR_DTYPE_FLOAT16 = 0, TENSOR_DTYPE_BF16 }; + +enum class TensorFormat : uint32_t { TENSOR_FORMAT_ND = 0, TENSOR_FORMAT_NZ }; + +struct MatMulInfo { + uint32_t batchSize{0}; + uint32_t m{0}; // actual input m + uint32_t k{0}; // actual input k + uint32_t n{0}; // actual input n + TensorDType dtypeA{TensorDType::TENSOR_DTYPE_FLOAT16}; + TensorDType dtypeB{TensorDType::TENSOR_DTYPE_FLOAT16}; + TensorDType dtypeC{TensorDType::TENSOR_DTYPE_FLOAT16}; + TensorFormat formatA{TensorFormat::TENSOR_FORMAT_ND}; + TensorFormat formatB{TensorFormat::TENSOR_FORMAT_ND}; + TensorFormat formatC{TensorFormat::TENSOR_FORMAT_ND}; + MatMul::MatMulType mmType{MatMul::MatMulType::MATMUL_DEFAULT}; + bool transA{0}; // false: 0, true: 1 + bool transB{0}; // false: 0, true: 1 + bool biasFlag{0}; // false: 0, true: 1 + bool isInt8{0}; // false: 0, true: 1 + float inDtype{0}; + float outDtype{0}; + MatMul::QuantMode quantMode{MatMul::QuantMode::PER_CHANNEL_SYMM}; +}; + +struct OpShape { + uint32_t batchSize{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; +}; + +struct HardwareInfo { + uint32_t coreNum{0}; + uint32_t l2Size{0}; + uint32_t l1Size{0}; + uint32_t l0aSize{0}; + uint32_t l0bSize{0}; + uint32_t l0cSize{0}; + uint32_t hbmBandWidth{0}; + uint32_t l2BandWidth{0}; + + HardwareInfo(); +}; + +#pragma pack(push, 1) +struct PpMatmulTilingData { + OpShape opShape{}; + uint32_t mLoop{1}; + uint32_t kLoop{1}; + uint32_t nLoop{1}; + uint32_t coreLoop{1}; + uint32_t swizzlCount{1}; + uint32_t tilingKey{0}; + uint32_t blockDim{1}; + uint32_t swizzlDirect{0}; + uint32_t splitk{0}; + uint32_t enShuffleK{0}; + uint32_t quantMode{0}; + + void SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n); + void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo); + void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK); + uint32_t End(const MatMulInfo &mmInfo); +}; +#pragma pack(pop) + +void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim, + PpMatmulTilingData &tilingData); +} // namespace pp_matmul +#endif diff --git a/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp b/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp new file mode 100644 index 00000000000..81d987bae62 --- /dev/null +++ b/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp @@ -0,0 +1,824 @@ +// Adapted from +// https://gitee.com/ascend/ascend-transformer-boost +// +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +// This file is a part of the CANN Open Software. +// Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+// + +#define __aicore__ [aicore] +#include "kernel_operator.h" +#include "../op_host/tiling/tiling_data.h" +#include "../../mla_preprocess/op_kernel/kernel/common.h" +#include "../../mla_preprocess/op_kernel/kernel/hardware.h" +#include "../../mla_preprocess/op_kernel/kernel/mma.h" +#include "../../mla_preprocess/op_kernel/kernel/utils.h" +#include "../../mla_preprocess/op_kernel/kernel/iterator.h" +#include "../../kernels/math_utils.h" + +constexpr uint32_t L0_PINGPONG_BUFFER_LEN = 16384; +constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 131072; +constexpr uint32_t CONST_16 = 16; +constexpr uint32_t CONST_256 = 256; +constexpr uint64_t ND2NZ_STRIDE_LIMIT = 65536; +constexpr uint64_t BLOCK_SIZE_16 = 16; +constexpr uint64_t CONST_16UL = 16; +constexpr uint64_t CONST_256UL = 256; + +struct MatCoord { + uint64_t m{0}; + uint64_t k{0}; + uint64_t n{0}; +}; + +using namespace device_utils; + +template +class PpMatmulEinSum +{ + using LocalTensor = AscendC::LocalTensor; + template + using CopyGmToCbuf = gm_to_l1; + using LoadCbufToCa = l1_to_l0_a; + using LoadCbufToCb = l1_to_l0_b; + using Mad = mmad; + using CopyCcToGm = l0c_to_gm; + +public: + __aicore__ explicit PpMatmulEinSum(){}; + + __aicore__ __force_inline__ void Init(__gm__ uint8_t *__restrict__ a, __gm__ uint8_t *__restrict__ b, + __gm__ uint8_t *__restrict__ c, __gm__ uint8_t *__restrict__ tiling_data) + { + gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(a)); + gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(b)); + gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(c)); + auto gm_tiling_data = reinterpret_cast<__gm__ pp_matmul::PpMatmulTilingData *>(tiling_data); + + batch_size = gm_tiling_data->opShape.batchSize; + m = gm_tiling_data->opShape.m; + k = gm_tiling_data->opShape.k; + n = gm_tiling_data->opShape.n; + m0 = gm_tiling_data->opShape.m0; + k0 = gm_tiling_data->opShape.k0; + n0 = gm_tiling_data->opShape.n0; + tdim.m = gm_tiling_data->mLoop; + tdim.k = gm_tiling_data->kLoop; + tdim.n = gm_tiling_data->nLoop; + core_loop = gm_tiling_data->coreLoop; + swizzle_cnt = gm_tiling_data->swizzlCount; + en_shuffle_k = gm_tiling_data->enShuffleK; + + AsdopsBuffer buf; + l1_base_a = buf.template GetBuffer(0); + l1_base_b = buf.template GetBuffer( + RoundUp(m0 * k0 * sizeof(InDtype), CONST_256UL)); + l0a_base = buf.template GetBuffer(0); + l0b_base = buf.template GetBuffer(0); + num_core = AscendC::GetBlockNum(); + core_idx = AscendC::GetBlockIdx(); + ping_flag = 1; + } + + __aicore__ __force_inline__ void GetBlockIdx(uint64_t index, MatCoord &tidx) + { + uint64_t in_batch_idx = index % (tdim.m * tdim.n); + if constexpr (SwizzleDirect == 0) { // Zn + uint64_t tile_block_loop = (tdim.m + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.n); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.n); + + uint64_t n_row = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_row = tdim.m - swizzle_cnt * tile_block_idx; + } + tidx.m = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row; + tidx.n = in_tile_block_idx / n_row; + if (tile_block_idx % 2 != 0) { + tidx.n = tdim.n - tidx.n - 1; + } + } else if constexpr (SwizzleDirect == 1) { // Nz + uint64_t tile_block_loop = (tdim.n + swizzle_cnt - 1) / swizzle_cnt; + uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.m); + uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.m); + + uint64_t n_col = swizzle_cnt; + if (tile_block_idx == tile_block_loop - 1) { + n_col = tdim.n - swizzle_cnt * 
tile_block_idx; + } + tidx.m = in_tile_block_idx / n_col; + tidx.n = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col; + if (tile_block_idx % 2 != 0) { + tidx.m = tdim.m - tidx.m - 1; + } + } + } + + __aicore__ __force_inline__ void Process() + { + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + + for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += num_core) { + uint64_t batch_idx = loop_idx / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBlockIdx(loop_idx, tidx); + uint64_t offset_a = 0, offset_b = 0, offset_a_next = 0, offset_b_next = 0; + uint64_t offset_c = tidx.m * m0 * batch_size * n + batch_idx * n + tidx.n * n0; + uint64_t m_actual = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round = RoundUp(m_actual); + uint64_t n_round = RoundUp(n_actual); + uint64_t mn_max = m_round > n_round ? m_round : n_round; + uint64_t k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / CONST_16 * CONST_16; + uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0; + if (TA) { + offset_a = shuffle_k * k0 * m * batch_size + batch_idx * m + tidx.m * m0; + } else { + offset_a = tidx.m * m0 * batch_size * k + batch_idx * k + shuffle_k * k0; + } + + if (TB) { + if constexpr (FormatB != DataFormat::NZ) { + offset_b = batch_idx * k * n + tidx.n * n0 * k + shuffle_k * k0; + } else { + offset_b = batch_idx * RoundUp(k) * RoundUp(n) + + tidx.n * n0 * BLOCK_SIZE_16 + shuffle_k * k0 * RoundUp(n); + } + } else { + if constexpr (FormatB != DataFormat::NZ) { + offset_b = batch_idx * k * n + shuffle_k * k0 * n + tidx.n * n0; + } else { + offset_b = batch_idx * RoundUp(k) * RoundUp(n) + + shuffle_k * k0 * BLOCK_SIZE_16 + tidx.n * n0 * RoundUp(k); + } + } + + uint64_t k_actual = (shuffle_k == tdim.k - 1) ? k - shuffle_k * k0 : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l0a_buf = ping_flag ? l0a_base : l0a_base[L0_PINGPONG_BUFFER_LEN]; + LocalTensor l0b_buf = ping_flag ? l0b_base : l0b_base[L0_PINGPONG_BUFFER_LEN]; + event_t event_id = ping_flag ? 
EVENT_ID0 : EVENT_ID1; + + if (loop_idx == core_idx) { + WAIT_FLAG(MTE1, MTE2, event_id); + // *** load matrix A to L1 + if ((m == 1) || (m_actual == 1 && !TA)) { + CopyGmToCbuf(l1_buf_a, // dst + gm_a[offset_a], // src + 1, // nTileActual + 16, // nTileCeil + 1, // nVal + k_actual, // kTileActual + k_round, // kTileCeil + k); // dVal + } else { + if (TA) { + CopyGmToCbuf(l1_buf_a, // dst + gm_a[offset_a], // src + k_actual, // nTileActual + k_round, // nTileCeil + k, // nVal + m_actual, // dTileActual + m_round, // dTileCeil + m * batch_size); // dVal + } else { + CopyGmToCbuf(l1_buf_a, // dst + gm_a[offset_a], // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k * batch_size); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id); + // *** load matrix B to L1 + wait_flag(PIPE_MTE1, PIPE_MTE2, event_id + 2); + if constexpr (FormatB != DataFormat::NZ) { + if (TB) { + CopyGmToCbuf(l1_buf_b, // dst + gm_b[offset_b], // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(l1_buf_b, // dst + gm_b[offset_b], // src + k_actual, // nTileActual + k_round, // nTileCeil + k, // nVal + n_actual, // dTileActual + n_round, // dTileCeil + n); // dVal + } + } else { + if (TB) { + CopyGmToCbuf(l1_buf_b, // dst + gm_b[offset_b], // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp(n), // nVal + k_actual, // dTileActual + k_round, // dTileCeil + RoundUp(k)); // dVal + } else { + CopyGmToCbuf(l1_buf_b, // dst + gm_b[offset_b], // src + k_actual, // nTileActual + k_round, // nTileCeil + RoundUp(k), // nVal + n_actual, // dTileActual + n_round, // dTileCeil + RoundUp(n)); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id + 2); + } + + for (tidx.k = 0; tidx.k < tdim.k; ++tidx.k) { + shuffle_k = en_shuffle_k ? (tidx.k + core_idx) % tdim.k : tidx.k; + uint64_t k_actual = (shuffle_k == (tdim.k - 1)) ? (k - shuffle_k * k0) : k0; + uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; + fdim.k = (k_actual + k_part_len - 1) / k_part_len; + + LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1; + + if (tidx.k < tdim.k - 1) { + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + tidx.k + 1) % tdim.k : (tidx.k + 1); + if (TA) { + offset_a_next = shuffle_k_next * k0 * m * batch_size + batch_idx * m + tidx.m * m0; + } else { + offset_a_next = tidx.m * m0 * batch_size * k + batch_idx * k + shuffle_k_next * k0; + } + + if (TB) { + if constexpr (FormatB != DataFormat::NZ) { + offset_b_next = batch_idx * k * n + tidx.n * n0 * k + shuffle_k_next * k0; + } else { + offset_b_next = + batch_idx * RoundUp(k) * RoundUp(n) + + tidx.n * n0 * BLOCK_SIZE_16 + shuffle_k_next * k0 * RoundUp(n); + } + } else { + if constexpr (FormatB != DataFormat::NZ) { + offset_b_next = batch_idx * k * n + shuffle_k_next * k0 * n + tidx.n * n0; + } else { + offset_b_next = + batch_idx * RoundUp(k) * RoundUp(n) + + shuffle_k_next * k0 * BLOCK_SIZE_16 + tidx.n * n0 * RoundUp(k); + } + } + + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? 
l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + WAIT_FLAG(MTE1, MTE2, event_id_next); + // *** load matrix A to L1 + if ((m == 1) || (m_actual == 1 && !TA)) { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual_next, // kTileActual + k_round_next, // kTileCeil + k); // dVal + } else { + if (TA) { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + k, // nVal + m_actual, // dTileActual + m_round, // dTileCeil + m * batch_size); // dVal + } else { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + k * batch_size); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id_next); + + // *** load matrix B to L1 + wait_flag(PIPE_MTE1, PIPE_MTE2, event_id_next + 2); + if constexpr (FormatB != DataFormat::NZ) { + if (TB) { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + n_actual, // nTileActual + n_round, // nTileCeil + n, // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + k, // nVal + n_actual, // dTileActual + n_round, // dTileCeil + n); // dVal + } + } else { + if (TB) { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + n_actual, // nTileActual + n_round, // nTileCeil + RoundUp(n), // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + RoundUp(k)); // dVal + } else { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + RoundUp(k), // nVal + n_actual, // dTileActual + n_round, // dTileCeil + RoundUp(n)); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + if (tidx.k == tdim.k - 1 && loop_idx + num_core < core_loop) { + uint64_t b_idx_next = (loop_idx + num_core) / tdim.n / tdim.m; + MatCoord tidx{0}; + GetBlockIdx(loop_idx + num_core, tidx); + uint64_t shuffle_k_next = en_shuffle_k ? (core_idx % tdim.k) : 0; + uint64_t m_actual_next = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0; + uint64_t n_actual_next = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0; + uint64_t m_round_next = (m_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t n_round_next = (n_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? 
(k - shuffle_k_next * k0) : k0; + uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16; + if (TA) { + offset_a_next = shuffle_k_next * k0 * m * batch_size + b_idx_next * m + tidx.m * m0; + } else { + offset_a_next = tidx.m * m0 * batch_size * k + b_idx_next * k + shuffle_k_next * k0; + } + + if (TB) { + if constexpr (FormatB != DataFormat::NZ) { + offset_b_next = b_idx_next * k * n + tidx.n * n0 * k + shuffle_k_next * k0; + } else { + offset_b_next = + b_idx_next * RoundUp(k) * RoundUp(n) + + tidx.n * n0 * BLOCK_SIZE_16 + shuffle_k_next * k0 * RoundUp(n); + } + } else { + if constexpr (FormatB != DataFormat::NZ) { + offset_b_next = b_idx_next * k * n + shuffle_k_next * k0 * n + tidx.n * n0; + } else { + offset_b_next = + b_idx_next * RoundUp(k) * RoundUp(n) + + shuffle_k_next * k0 * BLOCK_SIZE_16 + tidx.n * n0 * RoundUp(k); + } + } + + LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]; + LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN]; + event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + WAIT_FLAG(MTE1, MTE2, event_id_next); + // *** load matrix A to L1 + if (m == 1 || m_actual_next == 1 && !TA) { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + m_actual_next, // nTileActual + m_round_next, // nTileCeil + m, // nVal + k_actual_next, // kTileActual + k_round_next, // kTileCeil + k); // dVal + } else { + if (TA) { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + k, // nVal + m_actual_next, // dTileActual + m_round_next, // dTileCeil + m * batch_size); // dVal + } else { + CopyGmToCbuf(l1_buf_a_next, // dst + gm_a[offset_a_next], // src + m_actual_next, // nTileActual + m_round_next, // nTileCeil + m, // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + k * batch_size); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id_next); + + // *** load matrix B to L1 + wait_flag(PIPE_MTE1, PIPE_MTE2, event_id_next + 2); + if constexpr (FormatB != DataFormat::NZ) { + if (TB) { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + n_actual_next, // nTileActual + n_round_next, // nTileCeil + n, // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + k); // dVal + } else { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + k, // nVal + n_actual_next, // dTileActual + n_round_next, // dTileCeil + n); // dVal + } + } else { + if (TB) { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + n_actual_next, // nTileActual + n_round_next, // nTileCeil + RoundUp(n), // nVal + k_actual_next, // dTileActual + k_round_next, // dTileCeil + RoundUp(k)); // dVal + } else { + CopyGmToCbuf(l1_buf_b_next, // dst + gm_b[offset_b_next], // src + k_actual_next, // nTileActual + k_round_next, // nTileCeil + RoundUp(k), // nVal + n_actual_next, // dTileActual + n_round_next, // dTileCeil + RoundUp(n)); // dVal + } + } + SET_FLAG(MTE2, MTE1, event_id_next + 2); + } + + MatCoord fidx{0}; + for (fidx.k = 0; fidx.k < fdim.k; ++fidx.k) { + uint32_t k0_round = (fidx.k < fdim.k - 1) ? k_part_len : k_round - fidx.k * k_part_len; + uint32_t k0_actual = (fidx.k < fdim.k - 1) ? k_part_len : k_actual - fidx.k * k_part_len; + + auto mte1_mad_ping_flag = 1 - fidx.k % 2; + auto mte1_mad_event_id = mte1_mad_ping_flag ? 
EVENT_ID0 : EVENT_ID1; + auto l0a_buf = l0a_base[(fidx.k % 2) * L0_PINGPONG_BUFFER_LEN]; + auto l0b_buf = l0b_base[(fidx.k % 2) * L0_PINGPONG_BUFFER_LEN]; + + // *** load matrix A from L1 to L0A + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id); + } + WAIT_FLAG(M, MTE1, mte1_mad_event_id); + if ((m == 1) || (m_actual == 1 && !TA)) { + l1_to_l0_a( + l0a_buf, // dst + l1_buf_a[fidx.k * k_part_len], // src + 0, // mTileCeil + CeilDiv(k0_round), // kPartCeil + 0, // mSrcStride + 1, // kSrcStride + 0, // mDstStride + 0); // kDstStride + } else { + if (TA) { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[fidx.k * k_part_len * CONST_16], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + k_round / CONST_16, // mSrcStride + 1, // kSrcStride + k0_round / CONST_16, // mDstStride + 1); // kDstStride + } else { + LoadCbufToCa(l0a_buf, // l0Tensor + l1_buf_a[fidx.k * k_part_len * m_round], // l1Tensor + m_round, // mTileCeil + k0_round, // kPartCeil + 1, // mSrcStride + m_round / CONST_16, // kSrcStride + k0_round / CONST_16, // mDstStride + 1); // kDstStride + } + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id); + } + + // *** load matrix B from L1 to L0B + if (fidx.k == 0) { + WAIT_FLAG(MTE2, MTE1, event_id + 2); + } + if (TB) { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * n_round], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + 1, // nSrcStride + n_round / CONST_16, // kSrcStride + 1, // nDstStride + k0_round / CONST_16); // kDstStride + } else { + LoadCbufToCb(l0b_buf, // l0Tensor + l1_buf_b[fidx.k * k_part_len * CONST_16], // l1Tensor + n_round, // nTileCeil + k0_round, // kPartCeil + k_round / CONST_16, // nSrcStride + 1, // kSrcStride + 1, // nDstStride + n_round / CONST_16); // kDstStride + } + if (fidx.k == fdim.k - 1) { + SET_FLAG(MTE1, MTE2, event_id + 2); + } + + SET_FLAG(MTE1, M, mte1_mad_event_id); + WAIT_FLAG(MTE1, M, mte1_mad_event_id); + + bool init_c = (tidx.k == 0 && fidx.k == 0); + if (init_c) { + WAIT_FLAG(FIX, M, EVENT_ID0); + } + + if (m != 1 && m_actual == 1 && TA) { + Mad(l0c_buf, // c + l0a_buf, // a + l0b_buf, // b + CONST_16, // mTileActual + n_actual, // nTileActual + k0_actual, // kTileActual + init_c); // initC + } else { + Mad(l0c_buf, // c + l0a_buf, // a + l0b_buf, // b + m_actual, // mTileActual + n_actual, // nTileActual + k0_actual, // kTileActual + init_c); // initC + } + + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, mte1_mad_event_id); + } + + ping_flag = 1 - ping_flag; + } + + SET_FLAG(M, FIX, EVENT_ID0); + WAIT_FLAG(M, FIX, EVENT_ID0); + + // copy from L0C to gm + CopyCcToGm(gm_c[offset_c], // dst + l0c_buf, // src + m_actual, // mTileActual + n_actual, // nTileActual + m_round, // mTileCeil + n * batch_size); // nActual + SET_FLAG(FIX, M, EVENT_ID0); + } + + WAIT_FLAG(M, MTE1, EVENT_ID0); + WAIT_FLAG(M, MTE1, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID0); + WAIT_FLAG(MTE1, MTE2, EVENT_ID1); + WAIT_FLAG(MTE1, MTE2, EVENT_ID2); + WAIT_FLAG(MTE1, MTE2, EVENT_ID3); + WAIT_FLAG(FIX, M, EVENT_ID0); + PIPE_BARRIER(ALL); + } + +private: + AscendC::GlobalTensor gm_a; + AscendC::GlobalTensor gm_b; + AscendC::GlobalTensor gm_c; + AscendC::LocalTensor l1_base_a; + AscendC::LocalTensor l1_base_b; + AscendC::LocalTensor l0a_base; + AscendC::LocalTensor l0b_base; + AscendC::LocalTensor l0c_buf; + + uint32_t num_core{0}; + uint32_t batch_size{0}; + uint32_t m{0}; + uint32_t k{0}; + uint32_t n{0}; + uint32_t m0{0}; + uint32_t k0{0}; + uint32_t n0{0}; + MatCoord tdim{0}; + MatCoord fdim{0}; + uint32_t 
core_loop{0}; + uint32_t swizzle_cnt{1}; + uint32_t core_idx{0}; + uint32_t en_shuffle_k{0}; + uint32_t ping_flag{0}; +}; + +extern "C" __global__ __aicore__ void batch_matmul_transpose(GM_ADDR gm_a, GM_ADDR gm_b, GM_ADDR gm_c, + GM_ADDR gm_tiling_data) +{ + PpMatmulEinSum<0, false, false, half, half, DataFormat::ND> + einsum_0_n_fp16_nd; // swizzleDir[0] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<1, false, false, half, half, DataFormat::ND> + einsum_1_n_fp16_nd; // swizzleDir[1] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<0, false, true, half, half, DataFormat::ND> + einsum_0_t_fp16_nd; // swizzleDir[0] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<1, false, true, half, half, DataFormat::ND> + einsum_1_t_fp16_nd; // swizzleDir[1] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<0, false, false, __bf16, __bf16, DataFormat::ND> + einsum_0_n_bf16_nd; // swizzleDir[0] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<1, false, false, __bf16, __bf16, DataFormat::ND> + einsum_1_n_bf16_nd; // swizzleDir[1] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<0, false, true, __bf16, __bf16, DataFormat::ND> + einsum_0_t_bf16_nd; // swizzleDir[0] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[0] + PpMatmulEinSum<1, false, true, __bf16, __bf16, DataFormat::ND> + einsum_1_t_bf16_nd; // swizzleDir[1] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[0] + + PpMatmulEinSum<0, false, false, half, half, DataFormat::NZ> + einsum_0_n_fp16_nz; // swizzleDir[0] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<1, false, false, half, half, DataFormat::NZ> + einsum_1_n_fp16_nz; // swizzleDir[1] transA[0] transB[0] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<0, false, true, half, half, DataFormat::NZ> + einsum_0_t_fp16_nz; // swizzleDir[0] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<1, false, true, half, half, DataFormat::NZ> + einsum_1_t_fp16_nz; // swizzleDir[1] transA[0] transB[1] DtypeA[001] DtypeB[001] DtypeC[001] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<0, false, false, __bf16, __bf16, DataFormat::NZ> + einsum_0_n_bf16_nz; // swizzleDir[0] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<1, false, false, __bf16, __bf16, DataFormat::NZ> + einsum_1_n_bf16_nz; // swizzleDir[1] transA[0] transB[0] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<0, false, true, __bf16, __bf16, DataFormat::NZ> + einsum_0_t_bf16_nz; // swizzleDir[0] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[1] + PpMatmulEinSum<1, false, true, __bf16, __bf16, DataFormat::NZ> + einsum_1_t_bf16_nz; // swizzleDir[1] transA[0] transB[1] DtypeA[010] DtypeB[010] DtypeC[010] DataFormatA[0] + // DataFormatB[1] + + SetPadding((uint64_t)0); + SetNdpara(1, 0, 0); + SetAtomicnone(); + + // get tiling args + auto tiling_data = reinterpret_cast<__gm__ pp_matmul::PpMatmulTilingData *>(gm_tiling_data); + uint32_t 
masked_key = tiling_data->tilingKey >> 2; + + switch (masked_key) { + case 0b00000100100100: + case 0b01000100100100: + einsum_0_n_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_n_fp16_nd.Process(); + break; + case 0b00100100100100: + case 0b01100100100100: + einsum_0_t_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_t_fp16_nd.Process(); + break; + case 0b10000100100100: + case 0b11000100100100: + einsum_1_n_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_n_fp16_nd.Process(); + break; + case 0b10100100100100: + case 0b11100100100100: + einsum_1_t_fp16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_t_fp16_nd.Process(); + break; + case 0b00001001001000: + case 0b01001001001000: + einsum_0_n_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_n_bf16_nd.Process(); + break; + case 0b00101001001000: + case 0b01101001001000: + einsum_0_t_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_t_bf16_nd.Process(); + break; + case 0b10001001001000: + case 0b11001001001000: + einsum_1_n_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_n_bf16_nd.Process(); + break; + case 0b10101001001000: + case 0b11101001001000: + einsum_1_t_bf16_nd.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_t_bf16_nd.Process(); + break; + + case 0b00000100100101: + case 0b01000100100101: + einsum_0_n_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_n_fp16_nz.Process(); + break; + case 0b00100100100101: + case 0b01100100100101: + einsum_0_t_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_t_fp16_nz.Process(); + break; + case 0b10000100100101: + case 0b11000100100101: + einsum_1_n_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_n_fp16_nz.Process(); + break; + case 0b10100100100101: + case 0b11100100100101: + einsum_1_t_fp16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_t_fp16_nz.Process(); + break; + case 0b00001001001001: + case 0b01001001001001: + einsum_0_n_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_n_bf16_nz.Process(); + break; + case 0b00101001001001: + case 0b01101001001001: + einsum_0_t_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_0_t_bf16_nz.Process(); + break; + case 0b10001001001001: + case 0b11001001001001: + einsum_1_n_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_n_bf16_nz.Process(); + break; + case 0b10101001001001: + case 0b11101001001001: + einsum_1_t_bf16_nz.Init(gm_a, gm_b, gm_c, gm_tiling_data); + einsum_1_t_bf16_nz.Process(); + break; + default: + break; + } +} + + +namespace vllm_ascend { + +extern void batch_matmul_transpose_impl( + void* stream, + void* gm_a, + void* gm_b, + void* gm_c, + void* gm_tiling_data, + const uint32_t block_dim) +{ + batch_matmul_transpose<<>>( + gm_a, + gm_b, + gm_c, + gm_tiling_data); +} + +} \ No newline at end of file diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 9dba287e3ae..0fa1a6ae366 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -11,11 +11,11 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then exit 0 elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then # ASCEND910B (A2) series - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention" 
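+    # Note (added comment): CUSTOM_OPS is a semicolon-separated list consumed by the AscendC
+    # custom-op build; lightning_indexer and sparse_flash_attention are built for this SoC as well.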
SOC_ARG="ascend910_93" else # others diff --git a/csrc/kernels/math_utils.h b/csrc/kernels/math_utils.h new file mode 100644 index 00000000000..62b46921c14 --- /dev/null +++ b/csrc/kernels/math_utils.h @@ -0,0 +1,15 @@ +#ifndef KERNEL_MATH_UTILS_H +#define KERNEL_MATH_UTILS_H +#include + +namespace device_utils { + +template +__aicore__ __force_inline__ T RoundUp(const T &val) +{ + return (val + roundVal - 1) / roundVal * roundVal; +} + +}; // namespace device_utils + +#endif diff --git a/csrc/lightning_indexer/op_host/CMakeLists.txt b/csrc/lightning_indexer/op_host/CMakeLists.txt new file mode 100644 index 00000000000..7922ba8e429 --- /dev/null +++ b/csrc/lightning_indexer/op_host/CMakeLists.txt @@ -0,0 +1,42 @@ +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME LightningIndexer + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror + -mllvm -cce-aicore-hoist-movemask=false + --op_relocatable_kernel_binary=true +) + +set(lightning_indexer_depends transformer/attention/lightning_indexer PARENT_SCOPE) + +target_sources(op_host_aclnn PRIVATE + lightning_indexer_def.cpp +) + +target_sources(optiling PRIVATE + lightning_indexer_tiling.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(opmaster_ct PRIVATE + lightning_indexer_tiling.cpp + ) +endif () + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + lightning_indexer_proto.cpp +) + diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_def.cpp b/csrc/lightning_indexer/op_host/lightning_indexer_def.cpp new file mode 100644 index 00000000000..262efe2da94 --- /dev/null +++ b/csrc/lightning_indexer/op_host/lightning_indexer_def.cpp @@ -0,0 +1,72 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_def.cpp + * \brief + */ +#include +#include "register/op_def_registry.h" + +namespace ops { +class LightningIndexer : public OpDef { +public: + explicit LightningIndexer(const char *name) : OpDef(name) + { + this->Input("query") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("key") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("weights") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_lengths_query") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_lengths_key") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("block_table") + .ParamType(OPTIONAL) + .DataTypeList({ge::DT_INT32}) + .FormatList({ge::FORMAT_ND}) + .AutoContiguous(); + this->Output("sparse_indices").ParamType(REQUIRED).DataTypeList({ge::DT_INT32}).FormatList({ge::FORMAT_ND}); + this->Attr("layout_query").AttrType(OPTIONAL).String("BSND"); + this->Attr("layout_key").AttrType(OPTIONAL).String("PA_BSND"); + this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048:默认值,筛选前2048 + this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3:默认值,只计算下三角 + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false"); + this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910_93", aicore_config); + } +}; +OP_ADD(LightningIndexer); +} // namespace ops \ No newline at end of file diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp b/csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp new file mode 100644 index 00000000000..cc1a793e4bf --- /dev/null +++ b/csrc/lightning_indexer/op_host/lightning_indexer_proto.cpp @@ -0,0 +1,96 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_proto.cpp + * \brief + */ +#include +#include +#include "error/ops_error.h" + + +using namespace ge; + +namespace ops { +constexpr uint32_t QUERY_INDEX = 0; +constexpr uint32_t KEY_INDEX = 1; +constexpr uint32_t ACTUAL_SEQ_K_INDEX = 4; +constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 0; +constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 1; +constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 2; + +static ge::graphStatus InferShapeLightningIndexer(gert::InferShapeContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferShapeContext is nullptr!"), + return ge::GRAPH_FAILED); + const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX); + OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED); + const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX); + OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED); + gert::Shape *outShape = context->GetOutputShape(0); + + auto attrs = context->GetAttrs(); + OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED); + const char *inputLayoutQueryPtr = attrs->GetAttrPointer(ATTR_QUERY_LAYOUT_INDEX); + OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED); + const char *inputLayoutKeyPtr = attrs->GetAttrPointer(ATTR_KEY_LAYOUT_INDEX); + OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED); + const int64_t *seleced_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX); + OPS_LOG_E_IF_NULL(context, seleced_count, return ge::GRAPH_FAILED); + std::string inputLayoutQueryPtrStr = std::string(inputLayoutQueryPtr); + std::string inputLayoutKeyPtrStr = std::string(inputLayoutKeyPtr); + OPS_ERR_IF( + inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND", + OPS_LOG_E(context, "The attr layout_query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str()), + return ge::GRAPH_FAILED); + + outShape->SetDimNum(queryShape->GetDimNum()); + if (inputLayoutQueryPtrStr == "BSND") { + OPS_ERR_IF( + queryShape->GetDimNum() != 4, + OPS_LOG_E(context, "Layout BSND, queryDims (%zu) must be 4!", queryShape->GetDimNum()), + return ge::GRAPH_FAILED); + outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim B + outShape->SetDim(1, queryShape->GetDim(1)); // 1:Dim S + outShape->SetDim(2, keyShape->GetDim(2)); // 2:Dim N + outShape->SetDim(3, *seleced_count); // 3:Dim K + } else { + OPS_ERR_IF( + queryShape->GetDimNum() != 3, + OPS_LOG_E(context, "Layout TND, queryDims (%zu) must be 3!", queryShape->GetDimNum()), + return ge::GRAPH_FAILED); + outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim T + int32_t nDimIndex = (inputLayoutKeyPtrStr == "PA_BSND") ? 
2 : 1; // 2:Key Dim N + outShape->SetDim(1, keyShape->GetDim(nDimIndex)); // 1:Dim N + outShape->SetDim(2, *seleced_count); // 2:Dim K + } + OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferShape end."); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferDataTypeContext is nullptr!"), + return ge::GRAPH_FAILED); + OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexer InferDataType impl."); + // default set q's dtype as fia's output type + ge::DataType outputType = ge::DT_INT32; + // attention_out, outidx:0 + context->SetOutputDataType(0, outputType); + OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferDataType end."); + return GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(LightningIndexer) + .InferShape(InferShapeLightningIndexer) + .InferDataType(InferDataTypeLightningIndexer); +} // namespace ops diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp b/csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp new file mode 100644 index 00000000000..2a9655b77d3 --- /dev/null +++ b/csrc/lightning_indexer/op_host/lightning_indexer_tiling.cpp @@ -0,0 +1,733 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file lightning_indexer_tiling.cpp + * \brief + */ + +#include "lightning_indexer_tiling.h" +#include "../op_kernel/lightning_indexer_template_tiling_key.h" + +using namespace ge; +using namespace AscendC; +using std::map; +using std::string; +namespace optiling { +// --------------------------LIInfoParser类成员函数定义------------------------------------- +ge::graphStatus LIInfoParser::CheckRequiredInOutExistence() const +{ + OPS_ERR_IF(opParamInfo_.query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor k is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor k is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.weights.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor value is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.weights.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor value is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.attenOut.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.attenOut.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::CheckRequiredAttrExistence() const +{ + OPS_ERR_IF(opParamInfo_.layOut == nullptr, OPS_LOG_E(opName_, "attr layout_query is nullptr"), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(opParamInfo_.layOutKey == nullptr, OPS_LOG_E(opName_, "attr layout_key is nullptr"), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(opParamInfo_.sparseCount == nullptr, OPS_LOG_E(opName_, "attr sparse_count is nullptr"), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(opParamInfo_.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparse_mode is nullptr"), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::CheckRequiredParaExistence() const +{ + if (CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS || CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetOpName() +{ + if (context_->GetNodeName() == nullptr) { + OPS_LOG_E("LightningIndexer", "opName got from TilingContext is nullptr"); + return ge::GRAPH_FAILED; + } + opName_ = context_->GetNodeName(); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetNpuInfo() +{ + platformInfo_ = context_->GetPlatformInfo(); + OPS_ERR_IF(platformInfo_ == nullptr, OPS_LOG_E(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo_); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); + OPS_ERR_IF(aicNum == 0 || aivNum == 0, OPS_LOG_E(opName_, "num of core obtained is 0."), return GRAPH_FAILED); + + socVersion_ = ascendcPlatform.GetSocVersion(); + if ((socVersion_ != platform_ascendc::SocVersion::ASCEND910B) && + (socVersion_ != platform_ascendc::SocVersion::ASCEND910_93)) { + OPS_LOG_E(opName_, "SOC Version[%d] is not support.", (int32_t)socVersion_); + return GRAPH_FAILED; + } + OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, OPS_LOG_E(opName_, "workSpaceSize got from 
ge is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(context_->GetRawTilingData() == nullptr, + OPS_LOG_E(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +void LIInfoParser::GetOptionalInputParaInfo() +{ + opParamInfo_.actualSeqLengthsQ.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_Q_INDEX); + opParamInfo_.actualSeqLengthsQ.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_Q_INDEX); + opParamInfo_.actualSeqLengths.tensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_K_INDEX); + opParamInfo_.actualSeqLengths.desc = context_->GetOptionalInputDesc(ACTUAL_SEQ_K_INDEX); + opParamInfo_.blockTable.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INDEX); + opParamInfo_.blockTable.desc = context_->GetOptionalInputDesc(BLOCK_TABLE_INDEX); +} + +void LIInfoParser::GetInputParaInfo() +{ + opParamInfo_.query.desc = context_->GetInputDesc(QUERY_INDEX); + opParamInfo_.query.shape = context_->GetInputShape(QUERY_INDEX); + opParamInfo_.key.desc = context_->GetInputDesc(KEY_INDEX); + opParamInfo_.key.shape = context_->GetInputShape(KEY_INDEX); + opParamInfo_.weights.desc = context_->GetInputDesc(WEIGTHS_INDEX); + opParamInfo_.weights.shape = context_->GetInputShape(WEIGTHS_INDEX); + GetOptionalInputParaInfo(); +} + +void LIInfoParser::GetOutputParaInfo() +{ + opParamInfo_.attenOut.desc = context_->GetOutputDesc(LIGHTNING_INDEXER); + opParamInfo_.attenOut.shape = context_->GetOutputShape(LIGHTNING_INDEXER); +} + +ge::graphStatus LIInfoParser::GetAndCheckAttrParaInfo() +{ + auto attrs = context_->GetAttrs(); + OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"), + return ge::GRAPH_FAILED); + + OPS_LOG_I(context_->GetNodeName(), "GetAndCheckAttrParaInfo start"); + opParamInfo_.layOut = attrs->GetStr(ATTR_QUERY_LAYOUT_INDEX); + opParamInfo_.layOutKey = attrs->GetStr(ATTR_KEY_LAYOUT_INDEX); + opParamInfo_.sparseCount = attrs->GetAttrPointer(ATTR_SPARSE_COUNT_INDEX); + opParamInfo_.sparseMode = attrs->GetAttrPointer(ATTR_SPARSE_MODE_INDEX); + + if (opParamInfo_.layOut != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "layout_query is:%s", opParamInfo_.layOut); + } + if (opParamInfo_.layOutKey != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "layout_key is:%s", opParamInfo_.layOutKey); + } + if (opParamInfo_.sparseCount != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "selscted count is:%d", *opParamInfo_.sparseCount); + } + if (opParamInfo_.sparseMode != nullptr) { + OPS_LOG_I(context_->GetNodeName(), "sparse mode is:%d", *opParamInfo_.sparseMode); + } + OPS_LOG_I(context_->GetNodeName(), "GetAndCheckAttrParaInfo end"); + + OPS_ERR_IF( + ((std::string(opParamInfo_.layOutKey) != "PA_BSND") + && (std::string(opParamInfo_.layOut) != std::string(opParamInfo_.layOutKey))), + OPS_LOG_E(opName_, "under non-PA conditions, layout_query and layout_key should be equal."), + return ge::GRAPH_FAILED); + OPS_ERR_IF( + ((std::string(opParamInfo_.layOutKey) != "PA_BSND") && (std::string(opParamInfo_.layOutKey) != "BSND") + && (std::string(opParamInfo_.layOutKey) != "TND")), + OPS_LOG_E(opName_, "input attr layout_key only supported PA_BSND, BSND or TND"), return ge::GRAPH_FAILED); + OPS_ERR_IF(((std::string(opParamInfo_.layOut) != "BSND") && (std::string(opParamInfo_.layOut) != "TND")), + OPS_LOG_E(opName_, "input attr layout_query only supported BSND or TND."), return ge::GRAPH_FAILED); + OPS_ERR_IF(!((*opParamInfo_.sparseCount > 0) && 
(*opParamInfo_.sparseCount <= SPARSE_LIMIT)), + OPS_LOG_E(opName_, "input attr sparse_count must > 0 and <= 2048."), return ge::GRAPH_FAILED); + OPS_ERR_IF(!((*opParamInfo_.sparseMode == 0) || (*opParamInfo_.sparseMode == SPARSE_MODE_LOWER)), + OPS_LOG_E(opName_, "input attr sparse_mode only supported 0 or 3."), return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetOpParaInfo() +{ + GetInputParaInfo(); + GetOutputParaInfo(); + if (ge::GRAPH_SUCCESS != GetAndCheckAttrParaInfo()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetAndCheckInOutDataType() +{ + inputQType_ = opParamInfo_.query.desc->GetDataType(); + inputKType_ = opParamInfo_.key.desc->GetDataType(); + weightsType_ = opParamInfo_.weights.desc->GetDataType(); + outputType_ = opParamInfo_.attenOut.desc->GetDataType(); + + bool inDTypeAllEqual = (inputQType_ == inputKType_) && (inputKType_ == weightsType_); + OPS_ERR_IF(!inDTypeAllEqual, + OPS_LOG_E(opName_, "The data types of the input query, key, and weights must be the same."), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(((inputQType_ != ge::DT_FLOAT16) && (inputQType_ != ge::DT_BF16)), + OPS_LOG_E(opName_, "The data types of the input query, key, and weights must be float16 or bfloat16."), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(outputType_ != ge::DT_INT32, + OPS_LOG_E(opName_, "The data types of the output sparse_indices must be int32."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetQueryKeyAndOutLayout() +{ + // 获取query,key的Layout基准值 + const map layoutMap = { + {"BSND", DataLayout::BSND}, + {"TND", DataLayout::TND}, + {"PA_BSND", DataLayout::BnBsND} + }; + + std::string layout(opParamInfo_.layOut); + auto it = layoutMap.find(layout); + if (it != layoutMap.end()) { + qLayout_ = it->second; + } + + std::string layoutKey(opParamInfo_.layOutKey); + auto itKey = layoutMap.find(layoutKey); + if (itKey != layoutMap.end()) { + kLayout_ = itKey->second; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetAndCheckOptionalInput() +{ + if (kLayout_ == DataLayout::BnBsND) { + OPS_ERR_IF(opParamInfo_.blockTable.tensor == nullptr, + OPS_LOG_E(opName_, "key layout only supported PA_BSND, input block_table must not be null"), + return ge::GRAPH_FAILED); + OPS_ERR_IF( + opParamInfo_.actualSeqLengths.tensor == nullptr, + OPS_LOG_E(opName_, "key layout only supported PA_BSND, input actual_seq_lengths_key must not be null"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.blockTable.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input block_table data type only support int32"), return ge::GRAPH_FAILED); + } else if (kLayout_ == DataLayout::TND) { + OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor == nullptr, + OPS_LOG_E(opName_, "when layout_key is TND, input actual_seq_lengths_key must not be null"), + return ge::GRAPH_FAILED); + } + OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor != nullptr && + opParamInfo_.actualSeqLengths.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.actualSeqLengths.tensor != nullptr && + opParamInfo_.actualSeqLengths.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input actual_seq_lengths_key data type only support int32"), + return ge::GRAPH_FAILED); + if (qLayout_ == DataLayout::TND) { + OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor == nullptr, + 
OPS_LOG_E(opName_, "when layout_query is TND, input actual_seq_lengths_query must not be null"), + return ge::GRAPH_FAILED); + } + OPS_ERR_IF(opParamInfo_.actualSeqLengthsQ.tensor != nullptr && + opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32, + OPS_LOG_E(opName_, "input actual_seq_lengths_query data type only support int32"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(kLayout_ != DataLayout::BnBsND && opParamInfo_.blockTable.tensor != nullptr, + OPS_LOG_E(opName_, "when key layout is not PA_BSND, input block_table must be null"), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::CheckShapeDim() +{ + OPS_ERR_IF((opParamInfo_.blockTable.tensor != nullptr) && + (opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum() != DIM_NUM_TWO), + OPS_LOG_E(opName_, "the dim num of block_table's shape should be 2"), return ge::GRAPH_FAILED); + + uint32_t kShapeDim = opParamInfo_.key.shape->GetStorageShape().GetDimNum(); + uint32_t qShapeDim = opParamInfo_.query.shape->GetStorageShape().GetDimNum(); + uint32_t weightsShapeDim = opParamInfo_.weights.shape->GetStorageShape().GetDimNum(); + uint32_t outShapeDim = opParamInfo_.attenOut.shape->GetStorageShape().GetDimNum(); + uint32_t qExpectShapeDim = DIM_NUM_FOUR; + uint32_t kExpectShapeDim = DIM_NUM_FOUR; + if (qLayout_ == DataLayout::TND) { + qExpectShapeDim = DIM_NUM_THREE; + } + if (kLayout_ == DataLayout::TND) { + kExpectShapeDim = DIM_NUM_THREE; + } + OPS_ERR_IF(kShapeDim != kExpectShapeDim, + OPS_LOG_E(opName_, "the dim num of key's shape should be %u, but now is %u", kExpectShapeDim, kShapeDim), + return ge::GRAPH_FAILED); + OPS_ERR_IF(qShapeDim != qExpectShapeDim, + OPS_LOG_E(opName_, "the dim num of query's shape should be %u, but now is %u", + qExpectShapeDim, qShapeDim), + return ge::GRAPH_FAILED); + OPS_ERR_IF(outShapeDim != qExpectShapeDim, + OPS_LOG_E(opName_, "the dim num of sparse_indices's shape should be %u, but now is %u", + qExpectShapeDim, outShapeDim), + return ge::GRAPH_FAILED); + OPS_ERR_IF(!(weightsShapeDim == qExpectShapeDim - 1), + OPS_LOG_E(opName_, "the dim num of weights's shape should be %u, but now is %u", qExpectShapeDim - 1, + weightsShapeDim), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetN1Size() +{ + if (qLayout_ == DataLayout::BSND) { + n1Size_ = static_cast(opParamInfo_.query.shape->GetStorageShape().GetDim(DIM_IDX_TWO)); + } else { + // TND + n1Size_ = static_cast(opParamInfo_.query.shape->GetStorageShape().GetDim(1)); + } + OPS_LOG_I(context_->GetNodeName(), "n1Size is %d", n1Size_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + const std::string &actualSeqLenName) +{ + size = static_cast(tensor->GetShapeSize()); + if (size <= 0) { + OPS_LOG_E(opName_, "%s's shape size is %u, it should be greater than 0.", actualSeqLenName.c_str(), size); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetAndCheckN2Size() +{ + uint32_t n2Index = (kLayout_ == DataLayout::TND) ? 
DIM_IDX_ONE : DIM_IDX_TWO; + n2Size_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(n2Index)); + OPS_LOG_I(context_->GetNodeName(), "n2Size_ is %d", n2Size_); + OPS_ERR_IF(n2Size_ != 1, OPS_LOG_E(opName_, "key shape[%u] is numhead, only support 1.", n2Index), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetGSize() +{ + if (n1Size_ % n2Size_ != 0) { + OPS_LOG_E(opName_, "input query's head_num %u can not be a multiple of key's head_num %u.", n1Size_, n2Size_); + return ge::GRAPH_FAILED; + } + gSize_ = n1Size_ / n2Size_; + OPS_ERR_IF(gSize_ != 64, OPS_LOG_E(opName_, "N1 is %u, N2 is %u, N1 divided by N2 must equal 64.", + n1Size_, n2Size_), return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetBatchSize() +{ + // 获取B基准值 + // 1、非TND/NTD时, 以query的batch_size维度为基准; + // 2、TND/NTD时, actual_seq_lens_q必须传入, 以actual_seq_lens_q数组的长度为B轴大小 + if ((qLayout_ == DataLayout::TND)) { + return GetActualSeqLenSize(bSize_, opParamInfo_.actualSeqLengthsQ.tensor, "input actual_seq_lengths_query"); + } else { // BSND + bSize_ = opParamInfo_.query.shape->GetStorageShape().GetDim(0); + return ge::GRAPH_SUCCESS; + } +} + +ge::graphStatus LIInfoParser::GetHeadDim() +{ + // 以query的D维度为基准 + uint32_t dIndex = DIM_IDX_TWO; + // 根据layout确定D维度在shape中的位置 + switch (qLayout_) { + case DataLayout::TND: + // TND格式: [Total, N, D] -> D是第2维(索引2) + dIndex = DIM_IDX_TWO; + break; + case DataLayout::BSND: + // BSND格式: [Batch, SeqLen, N, D] -> D是第3维(索引3) + dIndex = DIM_IDX_THREE; + break; + default: + OPS_LOG_E(opName_, "unsupported layout for getting head dim."); + return ge::GRAPH_FAILED; + } + headDim_ = opParamInfo_.query.shape->GetStorageShape().GetDim(dIndex); + OPS_ERR_IF(headDim_ != HEAD_DIM_LIMIT, OPS_LOG_E(opName_, "input query's last dim head_dim only support 128."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetS1Size() +{ + if (qLayout_ == DataLayout::BSND) { + s1Size_ = opParamInfo_.query.shape->GetStorageShape().GetDim(1); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetAndCheckBlockSize() +{ + blockSize_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(1)); + OPS_LOG_I(context_->GetNodeName(), "blockSize_ is %d", blockSize_); + + OPS_ERR_IF(((blockSize_ % 16 != 0) || (blockSize_ == 0) || (blockSize_ > 1024)), + OPS_LOG_E(opName_, "input key's block_size must be a multiple of 16 and belong to (0, 1024]."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::CheckBlockCount() +{ + int32_t blockCount_ = static_cast(opParamInfo_.key.shape->GetStorageShape().GetDim(0)); + OPS_ERR_IF((blockCount_ == 0), + OPS_LOG_E(opName_, "input key's block_count cannot be 0."), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetS2SizeForPageAttention() +{ + if (GetAndCheckBlockSize() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + if (CheckBlockCount() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + maxBlockNumPerBatch_ = opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1); + s2Size_ = maxBlockNumPerBatch_ * blockSize_; + OPS_LOG_I(context_->GetNodeName(), "maxBlockNumPerBatch_ is %d, blockSize_ is %d, s2Size_ is %d", + maxBlockNumPerBatch_, blockSize_, s2Size_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::GetS2Size() +{ + // 获取S2基准值 + // 1、BATCH_CONTINUOUS时, 从key的S轴获取 + // 3、PAGE_ATTENTION时, S2 = block_table.dim1 * 
block_size + if (kLayout_ == DataLayout::BnBsND) { + return GetS2SizeForPageAttention(); + } else if (kLayout_ == DataLayout::TND) { + s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(0); + } else if (kLayout_ == DataLayout::BSND) { + s2Size_ = opParamInfo_.key.shape->GetStorageShape().GetDim(1); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::ValidateInputShapesMatchQTnd() +{ + // -----------------------check BatchSize------------------- + // bSize_ 来源于act_seq_q + if (kLayout_ == DataLayout::TND) { + OPS_ERR_IF( + (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_), + OPS_LOG_E(opName_, + "TND case input actual_seq_lengths_query, actual_seq_lengths_key are %u, %ld respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize()), + return ge::GRAPH_FAILED); + } else { // kLayout_ PA_BSND + OPS_ERR_IF( + (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_) || + (opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_), + OPS_LOG_E( + opName_, + "TND case input actual_seq_lengths_query, actual_seq_lengths_key, block_table dim 0 are %u, %ld, %ld respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize(), + opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + } + // -----------------------check T------------------- + uint32_t qTsize = opParamInfo_.query.shape->GetStorageShape().GetDim(0); + OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != qTsize) || + (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != qTsize), + OPS_LOG_E(opName_, "TND case input query, weights, sparse_indices dim 0 are %u, %ld, %ld respectively, they must be same.", + qTsize, opParamInfo_.weights.shape->GetStorageShape().GetDim(0), + opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::ValidateInputShapesMatchQBsnd() +{ + // -----------------------check BatchSize------------------- + // bSize_ 来源于query + if (kLayout_ == DataLayout::BnBsND) { + OPS_ERR_IF((opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0) != bSize_) || + (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_), + OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_key, block_table dim 0 are %u, %ld, %ld respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize(), + opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + } else if (kLayout_ == DataLayout::BSND) { + OPS_ERR_IF(opParamInfo_.key.shape->GetStorageShape().GetDim(0) != bSize_, + OPS_LOG_E(opName_, "BSND case input query, key dim 0 are %u, %ld respectively, they must be same.", + bSize_, opParamInfo_.key.shape->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + OPS_ERR_IF((opParamInfo_.actualSeqLengths.tensor != nullptr) && + (opParamInfo_.actualSeqLengths.tensor->GetShapeSize() != bSize_), + OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_key dim 0 are %u, %ld respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengths.tensor->GetShapeSize()), + return ge::GRAPH_FAILED); + } + OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(0) != bSize_) || + (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0) != bSize_), + OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 0 are %u, %ld, %ld respectively, they must be same.", + 
bSize_, opParamInfo_.weights.shape->GetStorageShape().GetDim(0), + opParamInfo_.attenOut.shape->GetStorageShape().GetDim(0)), + return ge::GRAPH_FAILED); + OPS_ERR_IF((opParamInfo_.actualSeqLengthsQ.tensor != nullptr) && + (opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize() != bSize_), + OPS_LOG_E(opName_, "BSND case input query, actual_seq_lengths_query dim 0 are %u, %ld respectively, they must be same.", + bSize_, opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize()), + return ge::GRAPH_FAILED); + // -----------------------check S1------------------- + OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(1) != s1Size_) || + (opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1) != s1Size_), + OPS_LOG_E(opName_, "BSND case input query, weight, sparse_indices dim 1 are %u, %ld, %ld, they must be same.", + s1Size_, opParamInfo_.weights.shape->GetStorageShape().GetDim(1), + opParamInfo_.attenOut.shape->GetStorageShape().GetDim(1)), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus LIInfoParser::ValidateInputShapesMatch() +{ + /* + TND: + query [T,N1,D], + key [BlockNum,BlockSize,N2,D], + weight [T,N1], + block_table [BatchSize, BatchMaxBlockNum], + act_seq_k [BatchSize] + act_seq_q [BatchSize], + out [T,N2,topk] + ---------------------- + BSND: + query [BatchSize,S1,N1,D], + key [BlockNum,BlockSize,N2,D], + weight [BatchSize,S1,N1], + block_table [BatchSize, BatchMaxBlockNum], + act_seq_k [BatchSize] + act_seq_q [BatchSize] 可选 + out [BatchSize,S1,N2,topk] + */ + uint32_t queryWeightsN1Dim = 1; + uint32_t outN2Dim = 1; + if (qLayout_ == DataLayout::TND) { + if (ValidateInputShapesMatchQTnd() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + } else { + if (ValidateInputShapesMatchQBsnd() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + queryWeightsN1Dim = DIM_IDX_TWO; + outN2Dim = DIM_IDX_TWO; + } + // -----------------------check N1------------------- + OPS_ERR_IF((opParamInfo_.weights.shape->GetStorageShape().GetDim(queryWeightsN1Dim) != n1Size_), + OPS_LOG_E(opName_, "input query, weight shape dim N1 must be same."), return ge::GRAPH_FAILED); + // -----------------------check D------------------- + uint32_t keyDDim = kLayout_ == DataLayout::TND ? 
DIM_IDX_TWO : DIM_IDX_THREE; + OPS_ERR_IF((opParamInfo_.key.shape->GetStorageShape().GetDim(keyDDim) != headDim_), + OPS_LOG_E(opName_, "input query, key shape last dim must be same."), return ge::GRAPH_FAILED); + // -----------------------check N2------------------- + OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim) != n2Size_), + OPS_LOG_E(opName_, "input query and output sparse_indices shape n2 dim must be same."), + return ge::GRAPH_FAILED); + // -----------------------check sparse_count------------------- + OPS_ERR_IF((opParamInfo_.attenOut.shape->GetStorageShape().GetDim(outN2Dim + 1) != *opParamInfo_.sparseCount), + OPS_LOG_E(opName_, "output sparse_indices shape last dim must be same as attr sparse_count."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +void LIInfoParser::GenerateInfo(LITilingInfo &liInfo) +{ + liInfo.opName = opName_; + liInfo.platformInfo = platformInfo_; + liInfo.opParamInfo = opParamInfo_; + liInfo.socVersion = socVersion_; + + liInfo.bSize = bSize_; + liInfo.n1Size = n1Size_; + liInfo.n2Size = n2Size_; + liInfo.s1Size = s1Size_; + liInfo.s2Size = s2Size_; + liInfo.gSize = gSize_; + + liInfo.inputQType = inputQType_; + liInfo.inputKType = inputKType_; + liInfo.outputType = outputType_; + + liInfo.blockSize = blockSize_; + liInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_; + + std::string layOutKeyStr(opParamInfo_.layOutKey); + liInfo.pageAttentionFlag = layOutKeyStr == "PA_BSND" ? true : false; + liInfo.sparseMode = *opParamInfo_.sparseMode; + liInfo.sparseCount = *opParamInfo_.sparseCount; + + liInfo.inputQLayout = qLayout_; + liInfo.inputKLayout = kLayout_; +} + +ge::graphStatus LIInfoParser::ParseAndCheck(LITilingInfo &liInfo) +{ + if (ge::GRAPH_SUCCESS != GetOpName() || ge::GRAPH_SUCCESS != GetNpuInfo() || ge::GRAPH_SUCCESS != GetOpParaInfo() || + ge::GRAPH_SUCCESS != CheckRequiredParaExistence()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != GetAndCheckInOutDataType() || ge::GRAPH_SUCCESS != GetQueryKeyAndOutLayout() || + ge::GRAPH_SUCCESS != GetAndCheckOptionalInput()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != CheckShapeDim() || ge::GRAPH_SUCCESS != GetN1Size() || + ge::GRAPH_SUCCESS != GetAndCheckN2Size() || ge::GRAPH_SUCCESS != GetGSize()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != GetBatchSize() || ge::GRAPH_SUCCESS != GetS1Size() || ge::GRAPH_SUCCESS != GetHeadDim() || + ge::GRAPH_SUCCESS != GetS2Size()) { + return ge::GRAPH_FAILED; + } + if (ge::GRAPH_SUCCESS != ValidateInputShapesMatch()) { + return ge::GRAPH_FAILED; + } + + GenerateInfo(liInfo); + + return ge::GRAPH_SUCCESS; +} + +// --------------------------TilingPrepare函数定义------------------------------------- +static ge::graphStatus TilingPrepareForLightningIndexer(gert::TilingParseContext * /* context */) +{ + return ge::GRAPH_SUCCESS; +} + +// --------------------------LightningIndexerTiling类成员函数定义----------------------- +ge::graphStatus LightningIndexerTiling::DoTiling(LITilingInfo *tilingInfo) +{ + // -------------set blockdim----------------- + auto ascendcPlatform = platform_ascendc::PlatformAscendC(tilingInfo->platformInfo); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); + uint32_t blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum); + context_->SetBlockDim(blockDim); + + // -------------set workspacesize----------------- + constexpr uint32_t MM1_RES_ELEM_SIZE = 4; // 4: fp32 + constexpr uint32_t DOUBLE_BUFFER = 2; 
// 双Buffer + constexpr uint32_t M_BASE_SIZE = 512; // m轴基本块大小 + constexpr uint32_t S2_BASE_SIZE = 512; // S2轴基本块大小 + constexpr uint32_t V1_RES_ELEM_SIZE = 4; // 4: int32 + constexpr uint32_t V1_RES_ELEM_TYPE = 2; // 保留Index和Value 2种数据 + constexpr uint32_t V1_DECODE_PARAM_ELEM_SIZE = 8; // 8: int64 + constexpr uint32_t V1_DECODE_PARAM_NUM = 16; // Decode参数个数 + constexpr uint32_t V1_DECODE_DATA_NUM = 2; // Decode每个核需要存储头和尾部两块数据 + constexpr uint32_t S1_BASE_SIZE = 8; // S1轴基本块的大小 + constexpr uint32_t TOPK_MAX_SIZE = 2048; // TopK选取个数 + uint32_t workspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize(); + // 主流程需Workspace大小 + uint32_t mm1ResSize = M_BASE_SIZE * S2_BASE_SIZE; + workspaceSize += mm1ResSize * MM1_RES_ELEM_SIZE * DOUBLE_BUFFER * aicNum; + // Decode流程(LD)需要Workspace大小 + // 临时存储Decode中间结果大小: 2(头/尾)*8(s1Base)*2(idx/value)*2048(K)*sizeof(int32)*24=6M + workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_RES_ELEM_TYPE * TOPK_MAX_SIZE * V1_RES_ELEM_SIZE * aicNum; + // 临时存储Decode中间参数信息大小: 2(头/尾)*8(s1Base)*16(paramNum)*sizeof(int64_t)*24=48k + workspaceSize += V1_DECODE_DATA_NUM * S1_BASE_SIZE * V1_DECODE_PARAM_NUM * V1_DECODE_PARAM_ELEM_SIZE * aicNum; + size_t *workSpaces = context_->GetWorkspaceSizes(1); + workSpaces[0] = workspaceSize; + + // -------------set tilingdata----------------- + tilingData_.set_bSize(tilingInfo->bSize); + tilingData_.set_s2Size(tilingInfo->s2Size); + tilingData_.set_s1Size(tilingInfo->s1Size); + tilingData_.set_sparseCount(tilingInfo->sparseCount); + tilingData_.set_gSize(tilingInfo->gSize); + tilingData_.set_blockSize(tilingInfo->blockSize); + tilingData_.set_maxBlockNumPerBatch(tilingInfo->maxBlockNumPerBatch); + tilingData_.set_sparseMode(tilingInfo->sparseMode); + tilingData_.set_usedCoreNum(blockDim); + tilingData_.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity()); + context_->GetRawTilingData()->SetDataSize(tilingData_.GetDataSize()); + + // -------------set tilingkey----------------- + // DT_Q, DT_KV, DT_OUT, PAGE_ATTENTION, FLASH_DECODE, LAYOUT_T, KV_LAYOUT_T + uint32_t inputQType = static_cast(tilingInfo->inputQType); + uint32_t inputKType = static_cast(tilingInfo->inputKType); + uint32_t outputType = static_cast(tilingInfo->outputType); + uint32_t pageAttentionFlag = static_cast(tilingInfo->pageAttentionFlag); + uint32_t inputQLayout = static_cast(tilingInfo->inputQLayout); + uint32_t inputKLayout = static_cast(tilingInfo->inputKLayout); + uint32_t tilingKey = + GET_TPL_TILING_KEY(inputQType, inputKType, outputType, pageAttentionFlag, inputQLayout, inputKLayout); + context_->SetTilingKey(tilingKey); + + return ge::GRAPH_SUCCESS; +} + +// --------------------------Tiling函数定义--------------------------- +ge::graphStatus TilingForLightningIndexer(gert::TilingContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_REPORT_VECTOR_INNER_ERR("LightningIndexer", "Tiling context is null."), + return ge::GRAPH_FAILED); + LITilingInfo liInfo; + LIInfoParser LIInfoParser(context); + if (LIInfoParser.ParseAndCheck(liInfo) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + LightningIndexerTiling liTiling(context); + return liTiling.DoTiling(&liInfo); +} + +// --------------------------Tiling函数及TilingPrepare函数注册-------- +IMPL_OP_OPTILING(LightningIndexer) + .Tiling(TilingForLightningIndexer) + .TilingParse(TilingPrepareForLightningIndexer); + +} // namespace optiling diff --git a/csrc/lightning_indexer/op_host/lightning_indexer_tiling.h b/csrc/lightning_indexer/op_host/lightning_indexer_tiling.h new file 
mode 100644 index 00000000000..79db8c9328a --- /dev/null +++ b/csrc/lightning_indexer/op_host/lightning_indexer_tiling.h @@ -0,0 +1,222 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_tiling.h + * \brief + */ + +#ifndef LIGHTNING_INDEXER_TILING_H_ +#define LIGHTNING_INDEXER_TILING_H_ + +#include "exe_graph/runtime/tiling_context.h" +#include "tiling/platform/platform_ascendc.h" +#include "register/op_def_registry.h" +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" +#include "error/ops_error.h" +#include "platform/platform_info.h" + +namespace optiling { +// ------------------公共定义-------------------------- +struct TilingRequiredParaInfo { + const gert::CompileTimeTensorDesc *desc; + const gert::StorageShape *shape; +}; + +struct TilingOptionalParaInfo { + const gert::CompileTimeTensorDesc *desc; + const gert::Tensor *tensor; +}; + +enum class DataLayout : uint32_t { + BSND = 0, + TND = 1, + BnBsND = 2 +}; + +// ------------------算子原型索引常量定义---------------- +// Inputs Index +constexpr uint32_t QUERY_INDEX = 0; +constexpr uint32_t KEY_INDEX = 1; +constexpr uint32_t WEIGTHS_INDEX = 2; +constexpr uint32_t ACTUAL_SEQ_Q_INDEX = 3; +constexpr uint32_t ACTUAL_SEQ_K_INDEX = 4; +constexpr uint32_t BLOCK_TABLE_INDEX = 5; +constexpr uint32_t LIGHTNING_INDEXER = 0; +// Attributes Index +constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 0; +constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 1; +constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 2; +constexpr uint32_t ATTR_SPARSE_MODE_INDEX = 3; +// Dim Index +constexpr uint32_t DIM_IDX_ONE = 1; +constexpr uint32_t DIM_IDX_TWO = 2; +constexpr uint32_t DIM_IDX_THREE = 3; +// Dim Num +constexpr uint32_t DIM_NUM_TWO = 2; +constexpr uint32_t DIM_NUM_THREE = 3; +constexpr uint32_t DIM_NUM_FOUR = 4; +// 入参限制常量 +constexpr uint32_t HEAD_DIM_LIMIT = 128; +constexpr uint32_t SPARSE_LIMIT = 2048; +constexpr uint32_t SPARSE_MODE_LOWER = 3; + +// -----------算子TilingData定义--------------- +BEGIN_TILING_DATA_DEF(LITilingData) +TILING_DATA_FIELD_DEF(uint32_t, bSize) +TILING_DATA_FIELD_DEF(uint32_t, n2Size) +TILING_DATA_FIELD_DEF(uint32_t, gSize) +TILING_DATA_FIELD_DEF(uint32_t, s1Size) +TILING_DATA_FIELD_DEF(uint32_t, s2Size) +TILING_DATA_FIELD_DEF(uint32_t, sparseCount) +TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum) +TILING_DATA_FIELD_DEF(uint32_t, blockSize) +TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch) +TILING_DATA_FIELD_DEF(uint32_t, sparseMode) +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(LightningIndexer, LITilingData) + +// -----------算子CompileInfo定义------------------- +struct LICompileInfo {}; + +// -----------算子Tiling入参结构体定义--------------- +struct LiParaInfo { + TilingRequiredParaInfo query = {nullptr, nullptr}; + TilingRequiredParaInfo key = {nullptr, nullptr}; + TilingRequiredParaInfo weights = {nullptr, nullptr}; + TilingOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr}; + 
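+    // Note (based on how GetActualSeqLen consumes these tensors): when the matching layout is TND they
+    // carry cumulative (prefix-sum) sequence lengths, e.g. per-batch lengths {3, 5, 2} passed as
+    // {3, 8, 10} and recovered by differencing adjacent entries; for BSND they are read as absolute
+    // per-batch lengths.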
TilingOptionalParaInfo actualSeqLengths = {nullptr, nullptr}; + TilingOptionalParaInfo blockTable = {nullptr, nullptr}; + TilingRequiredParaInfo attenOut = {nullptr, nullptr}; + + const char *layOut = nullptr; + const char *layOutKey = nullptr; + const int32_t *blockSize = nullptr; + const int32_t *sparseMode = nullptr; + const int32_t *sparseCount = nullptr; +}; + +// -----------算子Tiling入参信息类--------------- +class LITilingInfo { +public: + const char *opName = nullptr; + fe::PlatFormInfos *platformInfo = nullptr; + LiParaInfo opParamInfo; + // Base Param + platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B; + uint32_t bSize = 0; + uint32_t n1Size = 0; + uint32_t n2Size = 0; + uint32_t s1Size = 0; + int64_t s2Size = 0; + uint32_t qkHeadDim = 0; + uint32_t gSize = 0; + // PageAttention + bool pageAttentionFlag = false; + int32_t blockSize = 0; + uint32_t maxBlockNumPerBatch = 0; + // Mask + int32_t sparseMode = 0; + // Others Flag + uint32_t sparseCount = 0; + // DType + ge::DataType inputQType = ge::DT_FLOAT16; + ge::DataType inputKType = ge::DT_FLOAT16; + ge::DataType outputType = ge::DT_INT32; + // Layout + DataLayout inputQLayout = DataLayout::BSND; + DataLayout inputKLayout = DataLayout::BnBsND; +}; + +// -----------算子Tiling入参信息解析及Check类--------------- +class LIInfoParser { +public: + explicit LIInfoParser(gert::TilingContext *context) : context_(context) + { + } + ~LIInfoParser() = default; + + ge::graphStatus CheckRequiredInOutExistence() const; + ge::graphStatus CheckRequiredAttrExistence() const; + ge::graphStatus CheckRequiredParaExistence() const; + ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + const std::string &actualSeqLenName); + ge::graphStatus GetOpName(); + ge::graphStatus GetNpuInfo(); + void GetOptionalInputParaInfo(); + void GetInputParaInfo(); + void GetOutputParaInfo(); + ge::graphStatus GetAndCheckAttrParaInfo(); + ge::graphStatus GetOpParaInfo(); + ge::graphStatus ValidateInputShapesMatchQBsnd(); + ge::graphStatus ValidateInputShapesMatchQTnd(); + ge::graphStatus ValidateInputShapesMatch(); + ge::graphStatus GetAndCheckInOutDataType(); + ge::graphStatus GetBatchSize(); + ge::graphStatus GetHeadDim(); + ge::graphStatus GetS1Size(); + ge::graphStatus GetAndCheckOptionalInput(); + ge::graphStatus CheckShapeDim(); + ge::graphStatus GetAndCheckBlockSize(); + ge::graphStatus CheckBlockCount(); + ge::graphStatus GetS2SizeForPageAttention(); + ge::graphStatus GetS2Size(); + ge::graphStatus GetQueryKeyAndOutLayout(); + ge::graphStatus GetN1Size(); + ge::graphStatus GetAndCheckN2Size(); + ge::graphStatus GetGSize(); + ge::graphStatus GetAttenMaskInfo(); + ge::graphStatus GetActualSeqInfo(); + void GenerateInfo(LITilingInfo &liInfo); + ge::graphStatus ParseAndCheck(LITilingInfo &liInfo); + +public: + gert::TilingContext *context_ = nullptr; + const char *opName_; + fe::PlatFormInfos *platformInfo_; + LiParaInfo opParamInfo_; + + // BaseParams + uint32_t bSize_ = 0; + uint32_t n1Size_ = 0; + uint32_t n2Size_ = 0; + uint32_t gSize_ = 0; + uint32_t s1Size_ = 0; + int64_t s2Size_ = 0; + uint32_t headDim_ = 0; + // Layout + DataLayout qLayout_ = DataLayout::BSND; + DataLayout kLayout_ = DataLayout::BnBsND; + // PageAttention + uint32_t maxBlockNumPerBatch_ = 0; + int32_t blockSize_ = 0; + platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B; + ge::DataType inputQType_ = ge::DT_FLOAT16; + ge::DataType inputKType_ = ge::DT_FLOAT16; + ge::DataType weightsType_ = ge::DT_FLOAT16; + 
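+    // Illustrative sizing example (assuming the head-grouping relation gSize_ = n1Size_ / n2Size_):
+    // a BSND query of shape [B=4, S1=1, N1=64, D=128] with a single key head gives bSize_=4, s1Size_=1,
+    // n1Size_=64, n2Size_=1, gSize_=64 and headDim_=128 (HEAD_DIM_LIMIT); for a TND query [T, N1, D]
+    // the B and S1 axes are folded into T and actualSeqLengthsQ supplies the per-batch split.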
ge::DataType blockTableType_ = ge::DT_FLOAT16; + ge::DataType inputKRopeType_ = ge::DT_FLOAT16; + ge::DataType outputType_ = ge::DT_FLOAT16; +}; + +// ---------------算子Tiling类--------------- +class LightningIndexerTiling { +public: + explicit LightningIndexerTiling(gert::TilingContext *context) : context_(context){}; + ge::graphStatus DoTiling(LITilingInfo *tilingInfo); + +private: + gert::TilingContext *context_ = nullptr; + LITilingData tilingData_; +}; + +} // namespace optiling +#endif // LIGHTNING_INDEXER_TILING_H_ \ No newline at end of file diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer.cpp b/csrc/lightning_indexer/op_kernel/lightning_indexer.cpp new file mode 100644 index 00000000000..fefa72e618e --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer.cpp @@ -0,0 +1,58 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer.cpp + * \brief + */ + +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "lightning_indexer_template_tiling_key.h" +#include "lightning_indexer_kernel.h" + +using namespace LIKernel; + +#define INVOKE_LI_NO_KFC_OP_IMPL(templateClass, ...) \ + do { \ + templateClass> op; \ + LI_COPY_TILING_DATA(LITilingData, tiling); \ + op.Init(query, key, weights, actualSeqLengthsQ, actualSeqLengths, blocktable, sparseIndices, user, \ + tiling_data, &tPipe); \ + op.Process(); \ + } while (0) + +#define LI_COPY_TILING_DATA(tilingDataStruct, tiling) \ + GET_TILING_DATA_WITH_STRUCT(tilingDataStruct, tiling_data_in, tiling); \ + const tilingDataStruct *__restrict tiling_data = &tiling_data_in; + + +template +__global__ __aicore__ void lightning_indexer(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, + __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths, + __gm__ uint8_t *blocktable, __gm__ uint8_t *sparseIndices, + __gm__ uint8_t *workspace, __gm__ uint8_t *tiling) +{ +#if (__CCE_AICORE__ == 310) || (defined __DAV_310R6__) || (__CCE_AICORE__ == 200) + +#else + TPipe tPipe; + __gm__ uint8_t *user = GetUserWorkspace(workspace); + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); + + if constexpr (DT_Q == LI_TPL_FP16 && DT_K == LI_TPL_FP16 && DT_OUT == LI_TPL_INT32) { + INVOKE_LI_NO_KFC_OP_IMPL(LIPreload, half, half, int32_t, PAGE_ATTENTION, + LI_LAYOUT(LAYOUT_T), LI_LAYOUT(K_LAYOUT_T)); + } else { + INVOKE_LI_NO_KFC_OP_IMPL(LIPreload, bfloat16_t, bfloat16_t, int32_t, PAGE_ATTENTION, + LI_LAYOUT(LAYOUT_T), LI_LAYOUT(K_LAYOUT_T)); + } +#endif +} diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_common.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_common.h new file mode 100644 index 00000000000..eb4086fb080 --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_common.h @@ -0,0 +1,142 @@ +/** + * This program is free software, you can redistribute it and/or modify it. 
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_common.h + * \brief + */ +#ifndef LIGHTNING_INDEXER_COMMON_H +#define LIGHTNING_INDEXER_COMMON_H + +namespace LICommon { + +// 与tiling的layout保持一致 +enum class LI_LAYOUT { + BSND = 0, + TND = 1, + PA_BSND = 2 +}; + +template +struct LIType { + using queryType = Q_T; + using keyType = K_T; + using outputType = OUT_T; + static constexpr bool pageAttention = PAGE_ATTENTION; + static constexpr LI_LAYOUT layout = LAYOUT_T; + static constexpr LI_LAYOUT keyLayout = K_LAYOUT_T; +}; + +struct RunInfo { + uint32_t loop; + uint32_t bN2Idx; + uint32_t bIdx; + uint32_t n2Idx = 0; + uint32_t gS1Idx; + uint32_t s2Idx; + + uint32_t actS1Size = 1; + uint32_t actS2Size = 1; + uint32_t actMBaseSize; + uint32_t actualSingleProcessSInnerSize; + uint32_t actualSingleProcessSInnerSizeAlign; + + uint64_t tensorQueryOffset; + uint64_t tensorKeyOffset; + uint64_t tensorWeightsOffset; + uint64_t indiceOutOffset; + + bool isFirstS2InnerLoop; + bool isLastS2InnerLoop; + bool isAllLoopEnd = false; +}; + +struct ConstInfo { + // CUBE与VEC核间同步的模式 + static constexpr uint32_t FIA_SYNC_MODE2 = 2; + // BUFFER的字节数 + static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32; + static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64; + static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256; + static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512; + static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024; + static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048; + static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096; + static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192; + static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384; + static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768; + // 无效索引 + static constexpr int INVALID_IDX = -1; + + // CUBE和VEC的核间同步EventID + uint32_t syncC1V1 = 0U; + uint32_t syncV1C1 = 0U; + + // 基本块大小 + uint32_t mBaseSize = 1ULL; + uint32_t s1BaseSize = 1ULL; + uint32_t s2BaseSize = 1ULL; + + uint64_t batchSize = 0ULL; + uint64_t gSize = 0ULL; + uint64_t qHeadNum = 0ULL; + uint64_t kHeadNum; + uint64_t headDim; + uint64_t sparseCount; // topK选取大小 + uint64_t kSeqSize = 0ULL; // kv最大S长度 + uint64_t qSeqSize = 1ULL; // q最大S长度 + uint32_t kCacheBlockSize = 0; // PA场景的block size + uint32_t maxBlockNumPerBatch = 0; // PA场景的最大单batch block number + LI_LAYOUT outputLayout; // 输出的格式 + bool attenMaskFlag = false; + + uint32_t actualLenQDims = 0U; // query的actualSeqLength 的维度 + uint32_t actualLenDims = 0U; // KV 的actualSeqLength 的维度 + bool isAccumSeqS1 = false; // 是否累加模式 + bool isAccumSeqS2 = false; // 是否累加模式 +}; + +struct SplitCoreInfo { + uint32_t s2Start = 0U; // S2的起始位置 + uint32_t s2End = 0U; // S2循环index上限 + uint32_t bN2Start = 0U; + uint32_t bN2End = 0U; + uint32_t gS1Start = 0U; + uint32_t gS1End = 0U; + bool isLD = false; // 当前核是否需要进行Decode归约任务 +}; + +template +__aicore__ inline T Align(T num, T rnd) +{ + return (((rnd) == 0) ? 
0 : (((num) + (rnd)-1) / (rnd) * (rnd))); +} + +template +__aicore__ inline T1 Min(T1 a, T2 b) +{ + return (a > b) ? (b) : (a); +} + +template +__aicore__ inline T1 Max(T1 a, T2 b) +{ + return (a > b) ? (a) : (b); +} + +template +__aicore__ inline T CeilDiv(T num, T rnd) +{ + return (((rnd) == 0) ? 0 : (((num) + (rnd)-1) / (rnd))); +} +} // namespace LICommon + +#endif // LIGHTNING_INDEXER_COMMON_H \ No newline at end of file diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h new file mode 100644 index 00000000000..62bb913b572 --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_kernel.h @@ -0,0 +1,646 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_kernel.h + * \brief + */ + +#ifndef LIGHTNING_INDEXER_KERNEL_H +#define LIGHTNING_INDEXER_KERNEL_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "lightning_indexer_common.h" +#include "lightning_indexer_service_vector.h" +#include "lightning_indexer_service_cube.h" + +namespace LIKernel { +using namespace LICommon; +using namespace LIServiceVec; +using namespace matmul; +using AscendC::CacheMode; +using AscendC::CrossCoreSetFlag; +using AscendC::CrossCoreWaitFlag; + +// 由于S2循环前,RunInfo还没有赋值,使用TempLoopInfo临时存放B、N、S1轴相关的信息;同时减少重复计算 +struct TempLoopInfo { + uint32_t bN2Idx = 0; + uint32_t bIdx = 0U; + uint32_t n2Idx = 0U; + uint32_t gS1Idx = 0U; + uint32_t gS1LoopEnd = 0U; // gS1方向循环的结束Idx + uint32_t s2LoopEnd = 0U; // S2方向循环的结束Idx + uint32_t actS1Size = 1ULL; // 当前Batch循环处理的S1轴的实际大小 + uint32_t actS2Size = 0ULL; + bool curActSeqLenIsZero = false; + bool needDealActS1LessThanS1 = false; // S1的实际长度小于shape的S1长度时,是否需要清理输出 + uint32_t actMBaseSize = 0U; // m轴(gS1)方向实际大小 + uint32_t mBasicSizeTail = 0U; // gS1方向循环的尾基本块大小 + uint32_t s2BasicSizeTail = 0U; // S2方向循环的尾基本块大小 +}; + +template +class LIPreload { +public: + __aicore__ inline LIPreload(){}; + __aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, + __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths, + __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices, __gm__ uint8_t *workspace, + const LITilingData *__restrict tiling, TPipe *tPipe); + __aicore__ inline void Process(); + + // =================================类型定义区================================= + using Q_T = typename LIT::queryType; + using K_T = typename LIT::keyType; + using OUT_T = typename LIT::outputType; + static constexpr bool PAGE_ATTENTION = LIT::pageAttention; + static constexpr LI_LAYOUT LAYOUT_T = LIT::layout; + static constexpr LI_LAYOUT K_LAYOUT_T = LIT::keyLayout; + + using MM1_OUT_T = float; + + LIMatmul matmulService; + LIVector vectorService; + + // 
=================================常量区================================= + static constexpr uint32_t SYNC_C1_V1_FLAG = 4; + static constexpr uint32_t SYNC_V1_C1_FLAG = 5; + + static constexpr uint32_t M_BASE_SIZE = 512; + static constexpr uint32_t S2_BASE_SIZE = 512; + static constexpr uint32_t HEAD_DIM = 128; + static constexpr uint32_t K_HEAD_NUM = 1; + static constexpr uint32_t GM_ALIGN_BYTES = 512; + + static constexpr int64_t LD_PREFETCH_LEN = 2; + // for workspace double + static constexpr uint32_t WS_DOBULE = 2; + +protected: + TPipe *pipe = nullptr; + + // offset + uint64_t queryCoreOffset = 0ULL; + uint64_t keyCoreOffset = 0ULL; + uint64_t weightsCoreOffset = 0ULL; + uint64_t indiceOutCoreOffset = 0ULL; + + // ================================Global Buffer区================================= + GlobalTensor queryGm; + GlobalTensor keyGm; + GlobalTensor weightsGm; + + GlobalTensor indiceOutGm; + GlobalTensor blockTableGm; + + GlobalTensor actualSeqLengthsGmQ; + GlobalTensor actualSeqLengthsGm; + // workspace + GlobalTensor mm1ResGm; // 存放S + GlobalTensor vec1ResGm; // 存放TopK计算中间结果 + GlobalTensor vec1ParamGm; // 存放LD参数信息 + + // ================================类成员变量==================================== + // aic、aiv核信息 + uint32_t tmpBlockIdx = 0U; + uint32_t aiCoreIdx = 0U; + uint32_t usedCoreNum = 0U; + + LICommon::ConstInfo constInfo{}; + TempLoopInfo tempLoopInfo{}; + LICommon::SplitCoreInfo splitCoreInfo{}; + + // ================================Init functions================================== + __aicore__ inline void InitTilingData(const LITilingData *__restrict tilingData); + __aicore__ inline void InitBuffers(); + __aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths); + // ================================Split Core================================ + __aicore__ inline void SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LICommon::SplitCoreInfo &info); + __aicore__ inline uint32_t GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, uint32_t actS2Size); + __aicore__ inline uint32_t GetTotalBaseBlockNum(); + // ================================Process functions================================ + __aicore__ inline void ProcessMain(); + __aicore__ inline void ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, LICommon::RunInfo &runInfo); + __aicore__ inline void ProcessDecode(); + __aicore__ inline void ProcessInvalid(); + // ================================Params Calc===================================== + __aicore__ inline void CalcGS1LoopParams(uint32_t bN2Idx); + __aicore__ inline void GetBN2Idx(uint32_t bN2Idx); + __aicore__ inline uint32_t GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq, + GlobalTensor &actualSeqLengthsGm, uint32_t defaultSeqLen); + __aicore__ inline void GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size); + __aicore__ inline void CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx); + __aicore__ inline void CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LICommon::RunInfo &runInfo); + __aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start); +}; + +template +__aicore__ inline void LIPreload::InitTilingData(const LITilingData *__restrict tilingData) +{ + usedCoreNum = tilingData->usedCoreNum; + constInfo.batchSize = tilingData->bSize; + constInfo.qHeadNum = constInfo.gSize = tilingData->gSize; + constInfo.kSeqSize = tilingData->s2Size; + constInfo.qSeqSize = tilingData->s1Size; + constInfo.attenMaskFlag = 
(tilingData->sparseMode == 3); + constInfo.kCacheBlockSize = tilingData->blockSize; + constInfo.maxBlockNumPerBatch = tilingData->maxBlockNumPerBatch; + constInfo.sparseCount = tilingData->sparseCount; + constInfo.outputLayout = LAYOUT_T; // 输出和输入形状一致 + if (LAYOUT_T == LI_LAYOUT::TND) { + constInfo.isAccumSeqS1 = true; + } + if (K_LAYOUT_T == LI_LAYOUT::TND) { + constInfo.isAccumSeqS2 = true; + } + + constInfo.kHeadNum = K_HEAD_NUM; + constInfo.headDim = HEAD_DIM; + + constInfo.mBaseSize = M_BASE_SIZE; + constInfo.s2BaseSize = S2_BASE_SIZE; + constInfo.s1BaseSize = (constInfo.mBaseSize + constInfo.gSize - 1) / constInfo.gSize; +} + +template +__aicore__ inline void LIPreload::InitBuffers() +{ + if ASCEND_IS_AIV { + vectorService.InitBuffers(pipe); + } else { + matmulService.InitBuffers(pipe); + } +} + +template +__aicore__ inline void LIPreload::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, + __gm__ uint8_t *actualSeqLengths) +{ + if (actualSeqLengthsQ == nullptr) { + constInfo.actualLenQDims = 0; + } else { + constInfo.actualLenQDims = constInfo.batchSize; + actualSeqLengthsGmQ.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengthsQ, constInfo.actualLenQDims); + } + if (actualSeqLengths == nullptr) { + constInfo.actualLenDims = 0; + } else { + constInfo.actualLenDims = constInfo.batchSize; + actualSeqLengthsGm.SetGlobalBuffer((__gm__ uint32_t *)actualSeqLengths, constInfo.actualLenDims); + } +} + +template +__aicore__ inline uint32_t LIPreload::GetActualSeqLen(uint32_t bIdx, uint32_t actualLenDims, bool isAccumSeq, + GlobalTensor &actualSeqLengthsGm, + uint32_t defaultSeqLen) +{ + if (actualLenDims == 0) { + return defaultSeqLen; + } else if (isAccumSeq && bIdx > 0) { + return actualSeqLengthsGm.GetValue(bIdx) - actualSeqLengthsGm.GetValue(bIdx - 1); + } else { + return actualSeqLengthsGm.GetValue(bIdx); + } +} + +template +__aicore__ inline void LIPreload::GetS1S2ActualSeqLen(uint32_t bIdx, uint32_t &actS1Size, uint32_t &actS2Size) +{ + actS1Size = GetActualSeqLen(bIdx, constInfo.actualLenQDims, constInfo.isAccumSeqS1, actualSeqLengthsGmQ, + constInfo.qSeqSize); + actS2Size = + GetActualSeqLen(bIdx, constInfo.actualLenDims, constInfo.isAccumSeqS2, actualSeqLengthsGm, constInfo.kSeqSize); +} + +template +__aicore__ inline uint32_t LIPreload::GetS2BaseBlockNumOnMask(uint32_t s1gIdx, uint32_t actS1Size, + uint32_t actS2Size) +{ + if (actS2Size == 0) { + return 0; + } + uint32_t s1Offset = constInfo.s1BaseSize * s1gIdx; + int32_t validS2LenBase = static_cast(actS2Size) - static_cast(actS1Size); + int32_t validS2Len = s1Offset + validS2LenBase + constInfo.s1BaseSize; + validS2Len = Min(validS2Len, static_cast(actS2Size)); + validS2Len = Max(validS2Len, 1); + return (validS2Len + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; +} + +template +__aicore__ inline uint32_t LIPreload::GetTotalBaseBlockNum() +{ + uint32_t totalBlockNum = 0; + uint32_t actS1Size, actS2Size; + uint32_t s1GBaseNum, s2BaseNum; + for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) { + GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size); + s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize); + if (!constInfo.attenMaskFlag) { + s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize); + totalBlockNum += s1GBaseNum * s2BaseNum * constInfo.kHeadNum; + continue; + } + for (uint32_t s1gIdx = 0; s1gIdx < s1GBaseNum; s1gIdx++) { + s2BaseNum = GetS2BaseBlockNumOnMask(s1gIdx, actS1Size, actS2Size); + totalBlockNum += s2BaseNum * constInfo.kHeadNum; + } + } + return totalBlockNum; +} + +// 多核版本,双闭区间 +template +__aicore__ 
void inline LIPreload::SplitCore(uint32_t curCoreIdx, uint32_t &coreNum, LICommon::SplitCoreInfo &info) +{ + // 计算每个核最少处理的块数, 剩余的部分前面的核每个核多处理一块 + uint32_t totalBlockNum = GetTotalBaseBlockNum(); + uint32_t minBlockPerCore = totalBlockNum / coreNum; + uint32_t deal1MoreBlockCoreNum = totalBlockNum % coreNum; + uint32_t coreIdx = 0; + uint32_t lastGS1RemainBlockCnt = 0; + uint32_t coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore; + coreNum = minBlockPerCore == 0 ? deal1MoreBlockCoreNum : coreNum; + + bool findLastCoreEnd = true; + uint32_t actS1Size, actS2Size; + uint32_t s1GBaseNum, s2BaseNum; + for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kHeadNum; bN2Idx++) { + uint32_t bIdx = bN2Idx / constInfo.kHeadNum; + if (bN2Idx % constInfo.kHeadNum == 0) { + GetS1S2ActualSeqLen(bIdx, actS1Size, actS2Size); + s1GBaseNum = CeilDiv(actS1Size, constInfo.s1BaseSize); + s2BaseNum = CeilDiv(actS2Size, constInfo.s2BaseSize); + } + if constexpr (LAYOUT_T == LI_LAYOUT::BSND) { + if (findLastCoreEnd && (s1GBaseNum == 0U || s2BaseNum == 0U)) { + info.bN2Start = bN2Idx; + info.gS1Start = 0; + info.s2Start = 0; + findLastCoreEnd = false; + } + } + for (uint32_t gS1Idx = 0; gS1Idx < s1GBaseNum; gS1Idx++) { + if (constInfo.attenMaskFlag) { + s2BaseNum = GetS2BaseBlockNumOnMask(gS1Idx, actS1Size, actS2Size); + } + if (findLastCoreEnd && s2BaseNum == 0U) { + info.bN2Start = bN2Idx; + info.gS1Start = gS1Idx; + info.s2Start = 0; + findLastCoreEnd = false; + } + for (uint32_t s2Idx = 0; s2Idx < s2BaseNum;) { + if (findLastCoreEnd) { + info.bN2Start = bN2Idx; + info.gS1Start = gS1Idx; + info.s2Start = s2Idx; + findLastCoreEnd = false; + } + uint32_t s2RemainBaseNum = s2BaseNum - s2Idx; + if (lastGS1RemainBlockCnt + s2RemainBaseNum >= coreDealBlockCnt) { + info.bN2End = bN2Idx; + info.gS1End = gS1Idx; + info.s2End = s2Idx + coreDealBlockCnt - lastGS1RemainBlockCnt - 1; + + if (coreIdx == curCoreIdx) { + // S2被切N核,那么只有第一个核需要处理LD,其他核不用 + if (s2Idx == 0 && info.s2End + 1 < s2BaseNum) { + info.isLD = true; + } + // 最后一个核处理的不是最后一个Batch,表明后面的Batch为空块(S2=0), 调整终点坐标以便清理输出 + if (coreIdx == coreNum - 1 && info.bN2End != constInfo.batchSize -1) { + info.bN2End = constInfo.batchSize -1; + info.gS1End = 0; + info.s2End = 0; + } + return; + } + coreIdx++; + findLastCoreEnd = true; + s2Idx = info.s2End + 1; + lastGS1RemainBlockCnt = 0; + coreDealBlockCnt = coreIdx < deal1MoreBlockCoreNum ? minBlockPerCore + 1 : minBlockPerCore; + } else { + lastGS1RemainBlockCnt += s2RemainBaseNum; + break; + } + } + } + } +} + +template +__aicore__ inline void LIPreload::DealActSeqLenIsZero(uint32_t bIdx, uint32_t n2Idx, uint32_t s1Start) +{ + if ASCEND_IS_AIV { + if (constInfo.outputLayout == LI_LAYOUT::TND) { + uint32_t tSize = actualSeqLengthsGmQ.GetValue(constInfo.batchSize - 1); + uint32_t tBase = bIdx == 0 ? 
0 : actualSeqLengthsGmQ.GetValue(bIdx - 1); + uint32_t s1Count = tempLoopInfo.actS1Size; + + for (uint32_t s1Idx = s1Start; s1Idx < s1Count; s1Idx++) { + uint64_t indiceOutOffset = + (tBase + s1Idx) * constInfo.kHeadNum * constInfo.sparseCount + // T轴、s1轴偏移 + n2Idx * constInfo.sparseCount; // N2轴偏移 + vectorService.CleanInvalidOutput(indiceOutOffset); + } + } else if (constInfo.outputLayout == LI_LAYOUT::BSND) { + for (uint32_t s1Idx = s1Start; s1Idx < constInfo.qSeqSize; s1Idx++) { + // B,S1,N2,K + uint64_t indiceOutOffset = bIdx * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount + + s1Idx * constInfo.kHeadNum * constInfo.sparseCount + // B轴、S1轴偏移 + n2Idx * constInfo.sparseCount; // N2轴偏移 + vectorService.CleanInvalidOutput(indiceOutOffset); + } + } + } +} + +template +__aicore__ inline void LIPreload::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *weights, + __gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths, + __gm__ uint8_t *blockTable, __gm__ uint8_t *sparseIndices, + __gm__ uint8_t *workspace, const LITilingData *__restrict tiling, + TPipe *tPipe) +{ + if ASCEND_IS_AIV { + tmpBlockIdx = GetBlockIdx(); // vec:0-47 + aiCoreIdx = tmpBlockIdx / 2; + } else { + tmpBlockIdx = GetBlockIdx(); // cube:0-23 + aiCoreIdx = tmpBlockIdx; + } + + InitTilingData(tiling); + InitActualSeqLen(actualSeqLengthsQ, actualSeqLengths); + + // 计算分核 + SplitCore(aiCoreIdx, usedCoreNum, splitCoreInfo); + + pipe = tPipe; + // workspace 内存排布 + // |mm1ResGm(存S)|vec1ResGm(存LD中间结果)|vec1ParamGm(存LD参数) + // |Core0_mm1ResDB0-Core0_mm1ResDB1-Core1_mm1ResDB0....Core23_mm1ResDB0-Core23_mm1ResDB1|Core0_vec1Res... + uint64_t offset = 0; + + // mm1开DoubleBuffer + uint64_t singleCoreMm1ResSize = WS_DOBULE * constInfo.mBaseSize * constInfo.s2BaseSize * sizeof(MM1_OUT_T); + mm1ResGm.SetGlobalBuffer((__gm__ MM1_OUT_T *)(workspace + offset + aiCoreIdx * singleCoreMm1ResSize)); + offset += GetBlockNum() * singleCoreMm1ResSize; + + // ld流程需要ws大小: [aicnum, 2, CeilDiv(constInfo.mBaseSize, constInfo.gSize), topkOut_*2] + // (aic, 8, 2, 2, 2048) + // (aic, s1_cube, 头尾, idx/value, K) + vec1ResGm.SetGlobalBuffer((__gm__ float *)(workspace + offset)); + offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * WS_DOBULE * BASE_TOPK * sizeof(float); + + // (aic, 8, 2, 16) + // (aic, s1_cube, 头尾,16ele) + vec1ParamGm.SetGlobalBuffer((__gm__ int64_t *)(workspace + offset)); + offset += GetBlockNum() * constInfo.s1BaseSize * WS_DOBULE * LD_PARAM_NUM * sizeof(int64_t); + + if ASCEND_IS_AIV { + vectorService.InitParams(constInfo, tiling); + indiceOutGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices); + weightsGm.SetGlobalBuffer((__gm__ K_T *)weights); + vectorService.InitVec1GlobalTensor(mm1ResGm, vec1ResGm, vec1ParamGm, weightsGm, indiceOutGm); + } else { + matmulService.InitParams(constInfo); + queryGm.SetGlobalBuffer((__gm__ Q_T *)query); + if constexpr (PAGE_ATTENTION) { + blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable); + } + keyGm.SetGlobalBuffer((__gm__ K_T *)key); + matmulService.InitMm1GlobalTensor(blockTableGm, keyGm, queryGm, mm1ResGm); + } + InitBuffers(); +} + +template +__aicore__ inline void LIPreload::GetBN2Idx(uint32_t bN2Idx) +{ + tempLoopInfo.bN2Idx = bN2Idx; + tempLoopInfo.bIdx = bN2Idx / constInfo.kHeadNum; + tempLoopInfo.n2Idx = bN2Idx % constInfo.kHeadNum; +} + +template +__aicore__ inline void LIPreload::CalcS2LoopParams(uint32_t bN2LoopIdx, uint32_t gS1LoopIdx) +{ + tempLoopInfo.gS1Idx = gS1LoopIdx; + tempLoopInfo.actMBaseSize = constInfo.mBaseSize; + uint32_t 
remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx * constInfo.mBaseSize; + if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) { + tempLoopInfo.actMBaseSize = tempLoopInfo.mBasicSizeTail; + } + + bool isEnd = (bN2LoopIdx == splitCoreInfo.bN2End) && (gS1LoopIdx == splitCoreInfo.gS1End); + uint32_t s2BlockNum; + if (constInfo.attenMaskFlag) { + s2BlockNum = GetS2BaseBlockNumOnMask(gS1LoopIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size); + } else { + s2BlockNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; + } + tempLoopInfo.s2LoopEnd = isEnd ? splitCoreInfo.s2End : s2BlockNum - 1; +} + +template +__aicore__ inline void LIPreload::CalcGS1LoopParams(uint32_t bN2LoopIdx) +{ + GetBN2Idx(bN2LoopIdx); + GetS1S2ActualSeqLen(tempLoopInfo.bIdx, tempLoopInfo.actS1Size, tempLoopInfo.actS2Size); + if ((tempLoopInfo.actS2Size == 0) || (tempLoopInfo.actS1Size == 0)) { + tempLoopInfo.curActSeqLenIsZero = true; + return; + } + tempLoopInfo.curActSeqLenIsZero = false; + tempLoopInfo.s2BasicSizeTail = tempLoopInfo.actS2Size % constInfo.s2BaseSize; + tempLoopInfo.s2BasicSizeTail = + (tempLoopInfo.s2BasicSizeTail == 0) ? constInfo.s2BaseSize : tempLoopInfo.s2BasicSizeTail; + tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize; + tempLoopInfo.mBasicSizeTail = + (tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail; + + uint32_t gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize; + tempLoopInfo.gS1LoopEnd = (bN2LoopIdx == splitCoreInfo.bN2End) ? splitCoreInfo.gS1End : gS1SplitNum - 1; + if constexpr (LAYOUT_T == LI_LAYOUT::BSND) { + if (tempLoopInfo.gS1LoopEnd == gS1SplitNum - 1 && constInfo.qSeqSize > tempLoopInfo.actS1Size) { + tempLoopInfo.needDealActS1LessThanS1 = true; + } + } +} + +template +__aicore__ inline void LIPreload::CalcRunInfo(uint32_t loop, uint32_t s2LoopIdx, LICommon::RunInfo &runInfo) +{ + runInfo.loop = loop; + runInfo.bIdx = tempLoopInfo.bIdx; + runInfo.gS1Idx = tempLoopInfo.gS1Idx; + runInfo.s2Idx = s2LoopIdx; + runInfo.bN2Idx = tempLoopInfo.bN2Idx; + + runInfo.actS1Size = tempLoopInfo.actS1Size; + runInfo.actS2Size = tempLoopInfo.actS2Size; + // 计算实际基本块size + runInfo.actMBaseSize = tempLoopInfo.actMBaseSize; + runInfo.actualSingleProcessSInnerSize = constInfo.s2BaseSize; + uint32_t s2SplitNum = (tempLoopInfo.actS2Size + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; + if (runInfo.s2Idx == s2SplitNum - 1) { + runInfo.actualSingleProcessSInnerSize = tempLoopInfo.s2BasicSizeTail; + } + runInfo.actualSingleProcessSInnerSizeAlign = + LICommon::Align((uint32_t)runInfo.actualSingleProcessSInnerSize, LICommon::ConstInfo::BUFFER_SIZE_BYTE_32B); + + runInfo.isFirstS2InnerLoop = s2LoopIdx == splitCoreInfo.s2Start; + runInfo.isLastS2InnerLoop = s2LoopIdx == tempLoopInfo.s2LoopEnd; + runInfo.isAllLoopEnd = (runInfo.bN2Idx == splitCoreInfo.bN2End) && (runInfo.gS1Idx == splitCoreInfo.gS1End) && + (runInfo.s2Idx == splitCoreInfo.s2End); + + if (runInfo.isFirstS2InnerLoop) { + uint64_t actualSeqQPrefixSum; + uint64_t actualSeqKPrefixSum; + if constexpr (LAYOUT_T == LI_LAYOUT::TND) { + actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGmQ.GetValue(runInfo.bIdx - 1); + actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : actualSeqLengthsGm.GetValue(runInfo.bIdx - 1); + } else { // BSND + actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 
0 : runInfo.bIdx * constInfo.qSeqSize; + actualSeqKPrefixSum = (runInfo.bIdx <= 0) ? 0 : runInfo.bIdx * constInfo.kSeqSize; + } + uint64_t tndBIdxOffset = actualSeqQPrefixSum * constInfo.qHeadNum * constInfo.headDim; + uint64_t tndKeyBIdxOffset = actualSeqKPrefixSum * constInfo.kHeadNum * constInfo.headDim; + // B,S1,N1(N2,G),D + queryCoreOffset = tndBIdxOffset + runInfo.gS1Idx * constInfo.mBaseSize * constInfo.headDim; + keyCoreOffset = tndKeyBIdxOffset + runInfo.n2Idx * constInfo.headDim; + // B,S1,N1(N2,G)/T,N1(N2,G) + weightsCoreOffset = actualSeqQPrefixSum * constInfo.qHeadNum + runInfo.n2Idx * constInfo.gSize; + // B,S1,N2,k/T,N2,k + indiceOutCoreOffset = actualSeqQPrefixSum * constInfo.kHeadNum * constInfo.sparseCount + + runInfo.n2Idx * constInfo.sparseCount; + } + runInfo.tensorQueryOffset = queryCoreOffset; + runInfo.tensorKeyOffset = keyCoreOffset + runInfo.s2Idx * constInfo.s2BaseSize * constInfo.kHeadNum + * constInfo.headDim; + runInfo.tensorWeightsOffset = weightsCoreOffset; + runInfo.indiceOutOffset = indiceOutCoreOffset; +} + +template +__aicore__ inline void LIPreload::Process() +{ + if (usedCoreNum == 0) { + // 没有计算任务,直接清理输出 + ProcessInvalid(); + return; + } + ProcessMain(); + ProcessDecode(); +} + +template +__aicore__ inline void LIPreload::ProcessInvalid() +{ + if ASCEND_IS_AIV { + uint32_t aivCoreNum = GetBlockNum() * 2; // 2 means c:v = 1:2 + uint64_t totalOutputSize = + constInfo.batchSize * constInfo.qSeqSize * constInfo.kHeadNum * constInfo.sparseCount; + uint64_t singleCoreSize = + LICommon::Align((totalOutputSize + aivCoreNum - 1) / aivCoreNum, GM_ALIGN_BYTES / sizeof(OUT_T)); + uint64_t baseSize = tmpBlockIdx * singleCoreSize; + if (baseSize < totalOutputSize) { + uint64_t dealSize = + (baseSize + singleCoreSize > totalOutputSize) ? 
                totalOutputSize - baseSize : singleCoreSize;
+            GlobalTensor output = indiceOutGm[baseSize];
+            AscendC::InitGlobalMemory(output, dealSize, constInfo.INVALID_IDX);
+        }
+    }
+}
+
+template <typename LIT>
+__aicore__ inline void LIPreload<LIT>::ProcessMain()
+{
+    if (aiCoreIdx >= usedCoreNum) {
+        // Cores with no assigned work return immediately
+        return;
+    }
+
+    if ASCEND_IS_AIV {
+        vectorService.AllocEventID();
+        CrossCoreSetFlag(constInfo.syncV1C1);
+        CrossCoreSetFlag(constInfo.syncV1C1);
+    } else {
+        matmulService.AllocEventID();
+    }
+
+    LICommon::RunInfo runInfo;
+    uint32_t gloop = 0;
+    for (uint32_t bN2LoopIdx = splitCoreInfo.bN2Start; bN2LoopIdx <= splitCoreInfo.bN2End; bN2LoopIdx++) {
+        CalcGS1LoopParams(bN2LoopIdx);
+        if (tempLoopInfo.curActSeqLenIsZero) {
+            DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, 0U);
+            continue;
+        }
+        for (uint32_t gS1LoopIdx = splitCoreInfo.gS1Start; gS1LoopIdx <= tempLoopInfo.gS1LoopEnd; gS1LoopIdx++) {
+            CalcS2LoopParams(bN2LoopIdx, gS1LoopIdx);
+            for (int s2LoopIdx = splitCoreInfo.s2Start; s2LoopIdx <= tempLoopInfo.s2LoopEnd; s2LoopIdx++) {
+                ProcessBaseBlock(gloop, s2LoopIdx, runInfo);
+                ++gloop;
+            }
+            splitCoreInfo.s2Start = 0;
+        }
+        if (tempLoopInfo.needDealActS1LessThanS1) {
+            DealActSeqLenIsZero(tempLoopInfo.bIdx, tempLoopInfo.n2Idx, tempLoopInfo.actS1Size);
+        }
+        splitCoreInfo.gS1Start = 0;
+    }
+
+    if ASCEND_IS_AIV {
+        vectorService.FreeEventID();
+    } else {
+        matmulService.FreeEventID();
+        CrossCoreWaitFlag(constInfo.syncV1C1);
+        CrossCoreWaitFlag(constInfo.syncV1C1);
+    }
+}
+
+template <typename LIT>
+__aicore__ inline void LIPreload<LIT>::ProcessBaseBlock(uint32_t loop, uint64_t s2LoopIdx, LICommon::RunInfo &runInfo)
+{
+    CalcRunInfo(loop, s2LoopIdx, runInfo);
+    if ASCEND_IS_AIC {
+        CrossCoreWaitFlag(constInfo.syncV1C1);
+        matmulService.ComputeMm1(runInfo);
+        CrossCoreSetFlag(constInfo.syncC1V1);
+    } else {
+        CrossCoreWaitFlag(constInfo.syncC1V1);
+        vectorService.ProcessVec(runInfo);
+        CrossCoreSetFlag(constInfo.syncV1C1);
+    }
+}
+
+template <typename LIT>
+__aicore__ inline void LIPreload<LIT>::ProcessDecode()
+{
+    if ASCEND_IS_AIV {
+        vectorService.InitLDBuffers(pipe);
+        ICachePreLoad(LD_PREFETCH_LEN);
+        SyncAll();
+        if (splitCoreInfo.isLD) {
+            vectorService.ProcessLD();
+        }
+    }
+}
+} // namespace LIKernel
+#endif // LIGHTNING_INDEXER_KERNEL_H
\ No newline at end of file
diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_service_cube.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_service_cube.h
new file mode 100644
index 00000000000..4a86389d416
--- /dev/null
+++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_service_cube.h
@@ -0,0 +1,421 @@
+/**
+ * This program is free software, you can redistribute it and/or modify it.
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file lightning_indexer_service_cube.h + * \brief use 5 buffer for matmul l1, better pipeline + */ +#ifndef LIGHTNING_INDEXER_SERVICE_CUBE_H +#define LIGHTNING_INDEXER_SERVICE_CUBE_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "lightning_indexer_common.h" + +namespace LIKernel { +using namespace LICommon; +template +class LIMatmul { +public: + using Q_T = typename LIT::queryType; + using K_T = typename LIT::keyType; + + __aicore__ inline LIMatmul(){}; + __aicore__ inline void InitBuffers(TPipe *pipe); + __aicore__ inline void InitMm1GlobalTensor(const GlobalTensor &blkTableGm, const GlobalTensor &keyGm, + const GlobalTensor &queryGm, const GlobalTensor &mm1ResGm); + __aicore__ inline void InitParams(const ConstInfo &constInfo); + __aicore__ inline void AllocEventID(); + __aicore__ inline void FreeEventID(); + __aicore__ inline void ComputeMm1(const LICommon::RunInfo &runInfo); + + static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding; + static constexpr uint64_t KEY_BUF_NUM = 3; + static constexpr uint64_t QUERY_BUF_NUM = 2; + static constexpr uint64_t L0_BUF_NUM = 2; + + static constexpr uint32_t KEY_MTE1_MTE2_EVENT = EVENT_ID2; + static constexpr uint32_t QUERY_MTE1_MTE2_EVENT = EVENT_ID5; // KEY_MTE1_MTE2_EVENT + KEY_BUF_NUM; + static constexpr uint32_t M_MTE1_EVENT = EVENT_ID3; + + static constexpr uint32_t MTE2_MTE1_EVENT = EVENT_ID2; + static constexpr uint32_t MTE1_M_EVENT = EVENT_ID2; + + static constexpr uint64_t M_BASIC_BLOCK = 256; + static constexpr uint64_t D_BASIC_BLOCK = 128; + static constexpr uint64_t S2_BASIC_BLOCK = 256; + + static constexpr uint64_t M_BASIC_BLOCK_L0 = 128; + static constexpr uint64_t D_BASIC_BLOCK_L0 = 128; + static constexpr uint64_t S2_BASIC_BLOCK_L0 = 128; + + static constexpr uint64_t QUERY_BUFFER_OFFSET = M_BASIC_BLOCK * D_BASIC_BLOCK; + static constexpr uint64_t KEY_BUFFER_OFFSET = S2_BASIC_BLOCK * D_BASIC_BLOCK; + static constexpr uint64_t L0AB_BUFFER_OFFSET = M_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0; + static constexpr uint64_t L0C_BUFFER_OFFSET = M_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0; + +protected: + __aicore__ inline void Fixp(uint64_t s1gGmOffset, uint64_t s2GmOffset, uint64_t s1gL0RealSize, + uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo); + __aicore__ inline void ComuteL0c(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo); + __aicore__ inline void LoadKeyToL0b(uint64_t s2L0Offset, uint64_t s2L1RealSize, uint64_t s2L0RealSize, + const LICommon::RunInfo &runInfo); + __aicore__ inline void LoadQueryToL0a(uint64_t s1gL1Offset, uint64_t s1gL0Offset, uint64_t s1gL1RealSize, + uint64_t s1gL0RealSize, const LICommon::RunInfo &runInfo); + __aicore__ inline void QueryNd2Nz(uint64_t s1gL1RealSize, uint64_t s1gL1Offset, const LICommon::RunInfo &runInfo); + __aicore__ inline void KeyNd2Nz(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo); + __aicore__ inline void KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, const LICommon::RunInfo &runInfo); + GlobalTensor blkTableGm_; + GlobalTensor keyGm_; + GlobalTensor queryGm_; + GlobalTensor mm1ResGm_; + + TBuf bufQL1_; + LocalTensor queryL1_; + TBuf bufKeyL1_; + LocalTensor keyL1_; + + TBuf bufQL0_; + LocalTensor queryL0_; + TBuf bufKeyL0_; + LocalTensor keyL0_; + + TBuf bufL0C_; + LocalTensor cL0_; + + uint64_t keyL1BufIdx_ = 0; + 
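+    // Buffer budget implied by the constants above, assuming 2-byte Q_T/K_T (half or bfloat16_t):
+    //   L1 query : QUERY_BUF_NUM(2) * 256 * 128 * 2 B = 128 KiB,  L1 key : KEY_BUF_NUM(3) * 256 * 128 * 2 B = 192 KiB
+    //   L0A/L0B  : L0_BUF_NUM(2) * 128 * 128 * 2 B = 64 KiB each, L0C : 2 * 128 * 128 * 4 B = 128 KiB
+    // These *BufIdx_ counters cycle modulo the corresponding BUF_NUM so that MTE2 loads of the next
+    // key/query tile overlap with MTE1/Cube consumption of the previous one, paced by the event IDs above.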
uint64_t queryL1Mte2BufIdx_ = 0; + uint64_t queryL1Mte1BufIdx_ = 0; + uint64_t l0BufIdx_ = 0; + + ConstInfo constInfo_; + +private: + static constexpr bool PAGE_ATTENTION = LIT::pageAttention; +}; + +template +__aicore__ inline void LIMatmul::InitParams(const ConstInfo &constInfo) +{ + constInfo_ = constInfo; +} + +template +__aicore__ inline void LIMatmul::InitBuffers(TPipe *pipe) +{ + pipe->InitBuffer(bufQL1_, QUERY_BUF_NUM * M_BASIC_BLOCK * D_BASIC_BLOCK * sizeof(Q_T)); + queryL1_ = bufQL1_.Get(); + pipe->InitBuffer(bufKeyL1_, KEY_BUF_NUM * S2_BASIC_BLOCK * D_BASIC_BLOCK * sizeof(K_T)); + keyL1_ = bufKeyL1_.Get(); + + pipe->InitBuffer(bufQL0_, L0_BUF_NUM * M_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 * sizeof(Q_T)); + queryL0_ = bufQL0_.Get(); + pipe->InitBuffer(bufKeyL0_, L0_BUF_NUM * D_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0 * sizeof(K_T)); + keyL0_ = bufKeyL0_.Get(); + + pipe->InitBuffer(bufL0C_, L0_BUF_NUM * M_BASIC_BLOCK_L0 * S2_BASIC_BLOCK_L0 * sizeof(float)); + cL0_ = bufL0C_.Get(); +} + +template +__aicore__ inline void +LIMatmul::InitMm1GlobalTensor(const GlobalTensor &blkTableGm, const GlobalTensor &keyGm, + const GlobalTensor &queryGm, const GlobalTensor &mm1ResGm) +{ + blkTableGm_ = blkTableGm; + keyGm_ = keyGm; + queryGm_ = queryGm; + mm1ResGm_ = mm1ResGm; +} + +template +__aicore__ inline void LIMatmul::ComputeMm1(const LICommon::RunInfo &runInfo) +{ + uint64_t s2GmBaseOffset = runInfo.s2Idx * constInfo_.s2BaseSize; + uint64_t s1gProcessSize = runInfo.actMBaseSize; + uint64_t s2ProcessSize = runInfo.actualSingleProcessSInnerSize; + for (uint64_t s2GmOffset = 0; s2GmOffset < s2ProcessSize; s2GmOffset += S2_BASIC_BLOCK) { + WaitFlag(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % KEY_BUF_NUM); + uint64_t s2L1RealSize = + s2GmOffset + S2_BASIC_BLOCK > s2ProcessSize ? s2ProcessSize - s2GmOffset : S2_BASIC_BLOCK; + if (PAGE_ATTENTION) { + KeyNd2NzForPA(s2L1RealSize, s2GmBaseOffset + s2GmOffset, runInfo); + }else { + KeyNd2Nz(s2L1RealSize, s2GmOffset, runInfo); + } + + SetFlag(MTE2_MTE1_EVENT); + WaitFlag(MTE2_MTE1_EVENT); + // s1gProcessSize当前必定不会超过2倍的s1g basic block + for (uint64_t s1gGmOffset = 0; s1gGmOffset < s1gProcessSize; s1gGmOffset += M_BASIC_BLOCK) { + uint64_t s1gL1RealSize = + s1gGmOffset + M_BASIC_BLOCK > s1gProcessSize ? s1gProcessSize - s1gGmOffset : M_BASIC_BLOCK; + if (runInfo.isFirstS2InnerLoop && s2GmOffset == 0) { + queryL1Mte2BufIdx_++; + queryL1Mte1BufIdx_ = queryL1Mte2BufIdx_; + WaitFlag(QUERY_MTE1_MTE2_EVENT + queryL1Mte2BufIdx_ % QUERY_BUF_NUM); + QueryNd2Nz(s1gL1RealSize, s1gGmOffset, runInfo); + SetFlag(MTE2_MTE1_EVENT); + WaitFlag(MTE2_MTE1_EVENT); + } else { + queryL1Mte1BufIdx_ = + queryL1Mte2BufIdx_ - (CeilDiv(s1gProcessSize, M_BASIC_BLOCK) - 1 - (s1gGmOffset > 0)); + } + for (uint64_t s2L1Offset = 0; s2L1Offset < s2L1RealSize; s2L1Offset += S2_BASIC_BLOCK_L0) { + uint64_t s2L0RealSize = + s2L1Offset + S2_BASIC_BLOCK_L0 > s2L1RealSize ? s2L1RealSize - s2L1Offset : S2_BASIC_BLOCK_L0; + for (uint64_t s1gL1Offset = 0; s1gL1Offset < s1gL1RealSize; s1gL1Offset += M_BASIC_BLOCK_L0) { + WaitFlag(M_MTE1_EVENT + l0BufIdx_ % L0_BUF_NUM); + uint64_t s1gL0RealSize = + s1gL1Offset + M_BASIC_BLOCK_L0 > s1gL1RealSize ? 
s1gL1RealSize - s1gL1Offset : M_BASIC_BLOCK_L0; + LoadQueryToL0a(s1gGmOffset, s1gL1Offset, s1gL1RealSize, s1gL0RealSize, runInfo); + LoadKeyToL0b(s2L1Offset, s2L1RealSize, s2L0RealSize, runInfo); + + SetFlag(MTE1_M_EVENT); + WaitFlag(MTE1_M_EVENT); + + ComuteL0c(s1gL0RealSize, s2L0RealSize, runInfo); + + SetFlag(M_MTE1_EVENT + l0BufIdx_ % L0_BUF_NUM); + + Fixp(s1gGmOffset + s1gL1Offset, s2GmOffset + s2L1Offset, s1gL0RealSize, s2L0RealSize, runInfo); + l0BufIdx_++; + } + } + if (s2GmOffset + S2_BASIC_BLOCK >= s2ProcessSize && runInfo.isLastS2InnerLoop) { + SetFlag(QUERY_MTE1_MTE2_EVENT + queryL1Mte1BufIdx_ % QUERY_BUF_NUM); + } + } + + SetFlag(KEY_MTE1_MTE2_EVENT + keyL1BufIdx_ % KEY_BUF_NUM); + keyL1BufIdx_++; + } +} + +template +__aicore__ inline void LIMatmul::KeyNd2Nz(uint64_t s2L1RealSize, uint64_t s2GmOffset, + const LICommon::RunInfo &runInfo) +{ + uint64_t s2L1Offset = 0; + while (s2L1Offset < s2L1RealSize) { + uint64_t keyGmOffset = runInfo.tensorKeyOffset + (s2GmOffset + s2L1Offset) * constInfo_.headDim; + // 搬运按照S2_BASIC_BLOCK_L0*D_BASIC_BLOCK_L0的方式在l1上排布, 方便后续mte1 + // 根据s2的offset判断当前属于前一个L0分型还是后一个L0分型,暂时只支持两个分型 + uint64_t s2Mte2Size = (s2L1RealSize <= S2_BASIC_BLOCK_L0 || s2L1Offset >= S2_BASIC_BLOCK_L0) ? + s2L1RealSize - s2L1Offset : + S2_BASIC_BLOCK_L0 - s2L1Offset; + + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = s2Mte2Size; // 行数 + nd2nzPara.dValue = constInfo_.headDim; + nd2nzPara.srcDValue = constInfo_.headDim; + nd2nzPara.dstNzC0Stride = s2L1Offset >= S2_BASIC_BLOCK_L0 ? + CeilAlign(s2L1RealSize - S2_BASIC_BLOCK_L0, (uint64_t)BLOCK_CUBE) : + (s2L1RealSize > S2_BASIC_BLOCK_L0 ? + S2_BASIC_BLOCK_L0 : + CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE)); // 对齐到16 单位block + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + + (s2L1Offset >= S2_BASIC_BLOCK_L0 ? + S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 + (s2L1Offset - S2_BASIC_BLOCK_L0) * BLOCK_CUBE : + s2L1Offset * BLOCK_CUBE)], + keyGm_[keyGmOffset], nd2nzPara); + + s2L1Offset += s2Mte2Size; + } +} + +// blkNum, blkSize, N2, D +template +__aicore__ inline void LIMatmul::KeyNd2NzForPA(uint64_t s2L1RealSize, uint64_t s2GmOffset, + const LICommon::RunInfo &runInfo) +{ + uint64_t s2L1Offset = 0; + while (s2L1Offset < s2L1RealSize) { + uint64_t s2BlkId = (s2L1Offset + s2GmOffset) / constInfo_.kCacheBlockSize; + uint64_t s2BlkOffset = (s2L1Offset + s2GmOffset) % constInfo_.kCacheBlockSize; + uint64_t keyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * constInfo_.maxBlockNumPerBatch + s2BlkId) * + constInfo_.kCacheBlockSize * constInfo_.kHeadNum * constInfo_.headDim + + s2BlkOffset * constInfo_.headDim; + // 搬运按照S2_BASIC_BLOCK_L0*D_BASIC_BLOCK_L0的方式在l1上排布, 方便后续mte1 + // 根据s2的offset判断当前属于前一个L0分型还是后一个L0分型,暂时只支持两个分型 + uint64_t s2Mte2Size = (s2L1RealSize <= S2_BASIC_BLOCK_L0 || s2L1Offset >= S2_BASIC_BLOCK_L0) ? + s2L1RealSize - s2L1Offset : + S2_BASIC_BLOCK_L0 - s2L1Offset; + s2Mte2Size = s2BlkOffset + s2Mte2Size >= constInfo_.kCacheBlockSize ? constInfo_.kCacheBlockSize - s2BlkOffset : + s2Mte2Size; + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = s2Mte2Size; // 行数 + nd2nzPara.dValue = constInfo_.headDim; + nd2nzPara.srcDValue = constInfo_.headDim; + nd2nzPara.dstNzC0Stride = s2L1Offset >= S2_BASIC_BLOCK_L0 ? + CeilAlign(s2L1RealSize - S2_BASIC_BLOCK_L0, (uint64_t)BLOCK_CUBE) : + (s2L1RealSize > S2_BASIC_BLOCK_L0 ? 
+ S2_BASIC_BLOCK_L0 : + CeilAlign(s2L1RealSize, (uint64_t)BLOCK_CUBE)); // 对齐到16 单位block + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + + (s2L1Offset >= S2_BASIC_BLOCK_L0 ? + S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 + (s2L1Offset - S2_BASIC_BLOCK_L0) * BLOCK_CUBE : + s2L1Offset * BLOCK_CUBE)], + keyGm_[keyGmOffset], nd2nzPara); + + s2L1Offset += s2Mte2Size; + } +} + +// batch, s1, n2, g, d +template +__aicore__ inline void LIMatmul::QueryNd2Nz(uint64_t s1gL1RealSize, uint64_t s1gGmOffset, + const LICommon::RunInfo &runInfo) +{ + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = s1gL1RealSize; // 行数 + nd2nzPara.dValue = constInfo_.headDim; + nd2nzPara.srcDValue = constInfo_.headDim; + nd2nzPara.dstNzC0Stride = CeilAlign(s1gL1RealSize, (uint64_t)BLOCK_CUBE); // 对齐到16 单位block + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + // 默认一块buf最多放两份 + DataCopy(queryL1_[(queryL1Mte2BufIdx_ % QUERY_BUF_NUM) * QUERY_BUFFER_OFFSET], + queryGm_[runInfo.tensorQueryOffset + s1gGmOffset * constInfo_.headDim], nd2nzPara); +} + +template +__aicore__ inline void LIMatmul::LoadQueryToL0a(uint64_t s1gGmOffset, uint64_t s1gL1Offset, uint64_t s1gL1RealSize, + uint64_t s1gL0RealSize, const LICommon::RunInfo &runInfo) +{ + LoadData3DParamsV2 loadData3DParams; + // SetFmatrixParams + loadData3DParams.l1H = CeilDiv(s1gL1RealSize, BLOCK_CUBE); // Hin=M1=8 + loadData3DParams.l1W = BLOCK_CUBE; // Win=M0 + loadData3DParams.channelSize = constInfo_.headDim; // Cin=K + + loadData3DParams.padList[0] = 0; + loadData3DParams.padList[1] = 0; + loadData3DParams.padList[2] = 0; + loadData3DParams.padList[3] = 255; // 尾部数据不影响滑窗的结果 + + // SetLoadToA0Params + loadData3DParams.mExtension = CeilAlign(s1gL0RealSize, BLOCK_CUBE); // M height维度目的 + loadData3DParams.kExtension = constInfo_.headDim; // K width维度目的 + loadData3DParams.mStartPt = s1gL1Offset; + loadData3DParams.kStartPt = 0; + loadData3DParams.strideW = 1; + loadData3DParams.strideH = 1; + loadData3DParams.filterW = 1; + loadData3DParams.filterSizeW = (1 >> 8) & 255; + loadData3DParams.filterH = 1; + loadData3DParams.filterSizeH = (1 >> 8) & 255; + loadData3DParams.dilationFilterW = 1; + loadData3DParams.dilationFilterH = 1; + loadData3DParams.enTranspose = 0; + loadData3DParams.fMatrixCtrl = 0; + + LoadData(queryL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], + queryL1_[(queryL1Mte1BufIdx_ % QUERY_BUF_NUM) * QUERY_BUFFER_OFFSET], + loadData3DParams); +} + +template +__aicore__ inline void LIMatmul::LoadKeyToL0b(uint64_t s2L1Offset, uint64_t s2L1RealSize, uint64_t s2L0RealSize, + const LICommon::RunInfo &runInfo) +{ + uint64_t keyL1Offset = s2L1Offset >= S2_BASIC_BLOCK_L0 ? 
S2_BASIC_BLOCK_L0 * D_BASIC_BLOCK_L0 : 0; + LoadData2DParams loadData2DParams; + loadData2DParams.startIndex = 0; + loadData2DParams.repeatTimes = CeilDiv(s2L0RealSize, BLOCK_CUBE) * CeilDiv(constInfo_.headDim, BLOCK_CUBE); + loadData2DParams.srcStride = 1; + loadData2DParams.dstGap = 0; + loadData2DParams.ifTranspose = false; + LoadData(keyL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], + keyL1_[(keyL1BufIdx_ % KEY_BUF_NUM) * KEY_BUFFER_OFFSET + keyL1Offset], loadData2DParams); +} + +template +__aicore__ inline void LIMatmul::ComuteL0c(uint64_t s1gL0RealSize, uint64_t s2L0RealSize, + const LICommon::RunInfo &runInfo) +{ + MmadParams mmadParams; + mmadParams.m = CeilAlign(s1gL0RealSize, BLOCK_CUBE); + mmadParams.n = s2L0RealSize; + mmadParams.k = constInfo_.headDim; + mmadParams.cmatrixInitVal = true; + mmadParams.cmatrixSource = false; + mmadParams.unitFlag = 0b11; + Mmad(cL0_[(l0BufIdx_ % L0_BUF_NUM) * L0C_BUFFER_OFFSET], queryL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], + keyL0_[(l0BufIdx_ % L0_BUF_NUM) * L0AB_BUFFER_OFFSET], mmadParams); + if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) { + PipeBarrier(); + } +} + +template +__aicore__ inline void LIMatmul::Fixp(uint64_t s1gGmOffset, uint64_t s2GmOffset, uint64_t s1gL0RealSize, + uint64_t s2L0RealSize, const LICommon::RunInfo &runInfo) +{ + AscendC::DataCopyCO12DstParams intriParams; + intriParams.mSize = CeilAlign(s1gL0RealSize, BLOCK_CUBE); + intriParams.nSize = s2L0RealSize; + intriParams.dstStride = runInfo.actualSingleProcessSInnerSizeAlign; + intriParams.srcStride = CeilAlign(s1gL0RealSize, BLOCK_CUBE); + // set mode according to dtype + intriParams.quantPre = QuantMode_t::NoQuant; + intriParams.nz2ndEn = true; + intriParams.unitFlag = 0b11; // 3 unitflag + intriParams.reluPre = 1; + AscendC::SetFixpipeNz2ndFlag(1, 1, 1); + AscendC::DataCopy(mm1ResGm_[(runInfo.loop % 2) * constInfo_.mBaseSize * constInfo_.s2BaseSize + + s1gGmOffset * intriParams.dstStride + s2GmOffset], + cL0_[(l0BufIdx_ % L0_BUF_NUM) * L0C_BUFFER_OFFSET], intriParams); +} + +template +__aicore__ inline void LIMatmul::AllocEventID() +{ + SetMMLayoutTransform(true); + SetFlag(KEY_MTE1_MTE2_EVENT + 0); + SetFlag(KEY_MTE1_MTE2_EVENT + 1); + SetFlag(KEY_MTE1_MTE2_EVENT + 2); + + SetFlag(QUERY_MTE1_MTE2_EVENT + 0); + SetFlag(QUERY_MTE1_MTE2_EVENT + 1); + + SetFlag(M_MTE1_EVENT + 0); + SetFlag(M_MTE1_EVENT + 1); +} + +template +__aicore__ inline void LIMatmul::FreeEventID() +{ + SetMMLayoutTransform(false); + WaitFlag(KEY_MTE1_MTE2_EVENT + 0); + WaitFlag(KEY_MTE1_MTE2_EVENT + 1); + WaitFlag(KEY_MTE1_MTE2_EVENT + 2); + + WaitFlag(QUERY_MTE1_MTE2_EVENT + 0); + WaitFlag(QUERY_MTE1_MTE2_EVENT + 1); + + WaitFlag(M_MTE1_EVENT + 0); + WaitFlag(M_MTE1_EVENT + 1); +} +} // namespace LIKernel +#endif \ No newline at end of file diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_service_vector.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_service_vector.h new file mode 100644 index 00000000000..df7dcdaafd6 --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_service_vector.h @@ -0,0 +1,613 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_service_vector.h + * \brief + */ +#ifndef LIGHTNING_INDEXER_SERVICE_VECTOR_H +#define LIGHTNING_INDEXER_SERVICE_VECTOR_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "lightning_indexer_common.h" +#include "lightning_indexer_vector.h" + +namespace LIKernel { +using namespace LICommon; +using namespace LIServiceVec; +constexpr uint32_t BASE_TOPK = 2048; +constexpr uint32_t LD_PARAM_NUM = 16; + +template +class LIVector { +public: + // =================================类型定义区================================= + // 中间计算数据类型为float,高精度模式 + using K_T = typename LIT::keyType; + static constexpr LI_LAYOUT LAYOUT_T = LIT::layout; + + // MM输出数据类型, 当前只支持float + using MM1_OUT_T = float; + + __aicore__ inline LIVector(){}; + __aicore__ inline void ProcessVec(const LICommon::RunInfo &info); + __aicore__ inline void ProcessLD(); + __aicore__ inline void InitBuffers(TPipe *pipe); + __aicore__ inline void InitParams(const struct LICommon::ConstInfo &constInfo, + const LITilingData *__restrict tilingData); + __aicore__ inline void InitVec1GlobalTensor(GlobalTensor mm1ResGm, GlobalTensor vec1ResGm, + GlobalTensor vec1ParamGm, GlobalTensor weightsGm, + GlobalTensor indiceOutGm); + __aicore__ inline void CleanInvalidOutput(int64_t invalidS1offset); + __aicore__ inline void AllocEventID(); + __aicore__ inline void FreeEventID(); + __aicore__ inline void InitLDBuffers(TPipe *pipe); + +protected: + GlobalTensor mm1ResGm; + GlobalTensor vec1ResGm; + GlobalTensor vec1ParamGm; + GlobalTensor weightsGm; + GlobalTensor indiceOutGm; + // =================================常量区================================= + +private: + // ================================Local Buffer区==================================== + // queue + TQue inQueue_; + TQue outQueue_; + + // tmp buff for vector + TBuf sortOutBuf_; + TBuf indexBuf_; + TBuf reduceOutBuf_; + TBuf brcBuf_; + TBuf paramBuf_; + + // tmp buff for LD + TBuf<> ldToBeMrgBuf_; + TBuf<> ldTmpBuf_; + TBuf<> ldOutValueBuf_; + TBuf<> ldOutIdxBuf_; + + LocalTensor globalTopkIndice_; + LocalTensor globalTopkUb_; + LocalTensor SortedBasicBlock_; + + int32_t blockId_ = -1; + // para for vector + int32_t groupInner_ = 0; + int32_t globalTopkNum_ = 0; + int64_t blockS2StartIdx_ = 0; + int32_t gSize_ = 0; + int32_t kHeadNum_ = 0; + int32_t s1BaseSize_ = 0; + int32_t s2BaseSize_ = 0; + + // para for LD + uint32_t mrgListNum_ = 4; + uint32_t paramNum_ = 16; + + constexpr static uint32_t REDUCE_BANK_CONFLICT_OFFSETS = 256; + constexpr static uint32_t REDUCE_BANK_CONFLICT_NUM = REDUCE_BANK_CONFLICT_OFFSETS / sizeof(float); + + struct LICommon::ConstInfo constInfo_; +}; + +template +__aicore__ inline void LIVector::InitBuffers(TPipe *pipe) +{ + uint32_t outNeedBufSize = (BASE_TOPK * 2) * 2 * sizeof(float); + uint32_t reduceCacheSize = REDUCE_BANK_CONFLICT_OFFSETS + groupInner_ * s2BaseSize_ * sizeof(float); + outNeedBufSize = reduceCacheSize > outNeedBufSize ? 
reduceCacheSize : outNeedBufSize; + + pipe->InitBuffer(inQueue_, 2, + groupInner_ * s2BaseSize_ * sizeof(float) + s2BaseSize_ * sizeof(float)); // 69KB mm_out_ub + pipe->InitBuffer(outQueue_, 1, outNeedBufSize); // 32KB extract + pipe->InitBuffer(sortOutBuf_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2 * sizeof(float)); // 64KB + pipe->InitBuffer(indexBuf_, s2BaseSize_ * sizeof(int32_t)); // 2KB + pipe->InitBuffer(reduceOutBuf_, s2BaseSize_ * 2 * sizeof(float)); // 4KB + pipe->InitBuffer(brcBuf_, groupInner_ * 8 * sizeof(float)); + pipe->InitBuffer(paramBuf_, LD_PARAM_NUM * sizeof(int64_t)); + + // + globalTopkIndice_ = indexBuf_.Get(); + globalTopkUb_ = sortOutBuf_.Get(); + SortedBasicBlock_ = globalTopkUb_[BASE_TOPK * 2 * 2]; + globalTopkNum_ = 0; + + // 基本块执行前初始化UB和GM + // step1. 初始化一个有序索引 0 - s2BaseSize_ + ArithProgression(globalTopkIndice_, 0, 1, s2BaseSize_); + // step2. globalTopkUb_ [CeilDiv(s1BaseSize_, 2), BASE_TOPK, 2] -inf,-1 + InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2); + + // step3. 初始化vec1ParamGm,是否进行LD的标志位设为-1(needFd=-1) + // vec1ResIn32Gm = [aic, 2, s1BaseSize_, 16] int32 + // ws清零 [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, ......] + LocalTensor tmpfBuff = outQueue_.AllocTensor(); + Duplicate(tmpfBuff.template ReinterpretCast(), -1, 2 * (s1BaseSize_ / 2) * paramNum_ * 2); + SetWaitFlag(HardEvent::V_MTE3); + int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ + // 2个AIV共同地址偏移 + (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_; // 每个AIV的地址偏移,S1方向 + DataCopyPad(vec1ParamGm[wsInfoOffset], tmpfBuff.template ReinterpretCast(), + {1, static_cast((s1BaseSize_ / 2) * 2 * paramNum_ * sizeof(int64_t)), 0, 0}); + SetWaitFlag(HardEvent::MTE3_V); + outQueue_.FreeTensor(tmpfBuff); +} + +template +__aicore__ inline void LIVector::InitLDBuffers(TPipe *pipe) +{ + pipe->Reset(); + pipe->InitBuffer(ldToBeMrgBuf_, 2 * BASE_TOPK * mrgListNum_ * sizeof(float)); // 2:value + index + pipe->InitBuffer(ldTmpBuf_, 2 * BASE_TOPK * mrgListNum_ * sizeof(float)); // 2:value + index + pipe->InitBuffer(ldOutValueBuf_, BASE_TOPK * sizeof(float)); + pipe->InitBuffer(ldOutIdxBuf_, BASE_TOPK * sizeof(int32_t)); +} + +template +__aicore__ inline void LIVector::InitParams(const struct LICommon::ConstInfo &constInfo, + const LITilingData *__restrict tilingData) +{ + this->constInfo_ = constInfo; + blockS2StartIdx_ = 0; + gSize_ = constInfo.gSize; + // define N2 para + kHeadNum_ = constInfo.kHeadNum; + // define MMBase para + s1BaseSize_ = constInfo.s1BaseSize; + s2BaseSize_ = constInfo.s2BaseSize; + + // group ub 切分因子当前按照UB空间强制为16 + groupInner_ = 16; + + blockId_ = GetBlockIdx(); +} + +template +__aicore__ inline void +LIVector::InitVec1GlobalTensor(GlobalTensor mm1ResGm, GlobalTensor vec1ResGm, + GlobalTensor vec1ParamGm, GlobalTensor weightsGm, + GlobalTensor indiceOutGm) +{ + this->mm1ResGm = mm1ResGm; + this->vec1ResGm = vec1ResGm; + this->vec1ParamGm = vec1ParamGm; + this->weightsGm = weightsGm; + this->indiceOutGm = indiceOutGm; +} + +template +__aicore__ inline void LIVector::AllocEventID() +{ +} + +template +__aicore__ inline void LIVector::FreeEventID() +{ +} + +template +__aicore__ inline void LIVector::CleanInvalidOutput(int64_t invalidS1offset) +{ + // init -1 and copy to output + LocalTensor valueULocal = outQueue_.AllocTensor(); + LocalTensor idxULocal1 = valueULocal.template ReinterpretCast(); + Duplicate(idxULocal1, constInfo_.INVALID_IDX, constInfo_.sparseCount); + outQueue_.EnQue(valueULocal); + valueULocal = outQueue_.DeQue(); + 
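+    // Note: the EnQue/DeQue round-trip on the VECOUT queue above serves to order the vector-side
+    // Duplicate against the MTE3 copy that follows; at this point idxULocal1 still aliases
+    // valueULocal reinterpreted as the index type and holds sparseCount copies of INVALID_IDX,
+    // so the whole output row written below is marked as invalid.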
LIServiceVec::CopyOut(indiceOutGm[invalidS1offset], idxULocal1, constInfo_.sparseCount); + outQueue_.FreeTensor(valueULocal); +} + +template +__aicore__ inline void LIVector::ProcessVec(const LICommon::RunInfo &info) +{ + int32_t cuBaseS1Idx = info.gS1Idx * s1BaseSize_; + int32_t cuBaseS2Idx = info.s2Idx * s2BaseSize_; + + // 计算基本块基地址偏移 偶数循环 -> 0 + aic_offset 奇数循环 -> 512*512 + aic_offset + int64_t mmGmOffset = (info.loop % 2) * ((s1BaseSize_ * gSize_) * s2BaseSize_); + // (B,S1,N1,1);(T,N1,1) -> (B,S1,N2,G,1) 当前只切分到S1轴 + int64_t weightGmOffset = info.tensorWeightsOffset + cuBaseS1Idx * kHeadNum_ * gSize_; + + PipeBarrier(); + // cuS1BeginIdxPerAiv: 每个AIV的S1起始偏移 + int32_t cuS1BeginIdxPerAiv = cuBaseS1Idx; + int32_t cuS1ProcNum = + cuS1BeginIdxPerAiv + s1BaseSize_ > info.actS1Size ? info.actS1Size % s1BaseSize_ : s1BaseSize_; + // cuS1ProcNumPerAiv: 每个AIv的S1计算量 + int32_t cuS1ProcNumPerAiv = blockId_ % 2 == 0 ? CeilDiv(cuS1ProcNum, 2) : (cuS1ProcNum / 2); + cuS1BeginIdxPerAiv += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2); + + // 基本块基地址偏移奇数核加一个S1地址偏移 + weightGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * kHeadNum_ * gSize_; + mmGmOffset += (blockId_ % 2) * CeilDiv(cuS1ProcNum, 2) * gSize_ * info.actualSingleProcessSInnerSizeAlign; + + // cut G + int32_t outerG = CeilDiv(gSize_, groupInner_); + + // 非首个基本块, M(S1)轴发生切换需要初始化 + if (info.loop != 0 && info.s2Idx == 0) { + // globalTopkUb_ value,index=-inf,-1 + InitSortOutBuf(globalTopkUb_, CeilDiv(s1BaseSize_, 2) * BASE_TOPK * 2); + blockS2StartIdx_ = 0; + } else if (info.loop == 0) { + blockS2StartIdx_ = info.s2Idx; + } + // cuRealAcSeq: 当前基本块S1对应的AcSeq + int32_t cuRealAcSeq = info.actS2Size; + if (constInfo_.attenMaskFlag) { + // attenMask true场景 + cuRealAcSeq = info.actS2Size - (info.actS1Size - cuS1BeginIdxPerAiv); + } + LocalTensor reduceOutBuff = reduceOutBuf_.Get(); + LocalTensor brcBuf = brcBuf_.Get(); + // LD输出S1方向偏移,保证2个Vector输出的内容连续 + uint32_t ldS1Offset = (blockId_ % 2 == 0) ? s1BaseSize_ / 2 - cuS1ProcNumPerAiv : 0; + for (int innerS1Idx = 0; innerS1Idx < cuS1ProcNumPerAiv; innerS1Idx++) { + if (constInfo_.attenMaskFlag) { + cuRealAcSeq += 1; + } + int32_t cuS2Len = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq ? cuRealAcSeq - cuBaseS2Idx : s2BaseSize_; + int32_t cuS1Idx = cuS1BeginIdxPerAiv + innerS1Idx; + if (cuRealAcSeq > 0 && cuS2Len > 0) { + int32_t cuS2LenVecAlign = CeilDiv(cuS2Len, s2BaseSize_) * s2BaseSize_; + int32_t mmUbStride = (cuS2LenVecAlign - info.actualSingleProcessSInnerSizeAlign) / B32_BLOCK_ALIGN_NUM; + LocalTensor reduceOutInner = reduceOutBuff[s2BaseSize_]; + PipeBarrier(); + LocalTensor reduceCacheBuf = outQueue_.AllocTensor(); + for (int outerGidx = 0; outerGidx < outerG; outerGidx++) { + int32_t procGnum = outerGidx != outerG - 1 ? 
groupInner_ : gSize_ - outerGidx * groupInner_; + LocalTensor mmInUb = inQueue_.AllocTensor(); + LocalTensor weightsInUb = mmInUb[procGnum * s2BaseSize_]; + LocalTensor weightsInTUb = weightsInUb.template ReinterpretCast(); + if constexpr (!IsSameType::value) { + weightsInTUb = weightsInTUb[groupInner_]; + } + LIServiceVec::CopyIn(mmInUb, weightsInTUb, mm1ResGm, weightsGm, + mmGmOffset + innerS1Idx * gSize_ * info.actualSingleProcessSInnerSizeAlign + + outerGidx * groupInner_ * info.actualSingleProcessSInnerSizeAlign, + weightGmOffset + innerS1Idx * gSize_ + outerGidx * groupInner_, procGnum, + info.actualSingleProcessSInnerSizeAlign, mmUbStride); + + inQueue_.EnQue(mmInUb); + mmInUb = inQueue_.DeQue(); + weightsInUb = mmInUb[procGnum * s2BaseSize_]; + LIServiceVec::DoScale(reduceCacheBuf[REDUCE_BANK_CONFLICT_NUM], mmInUb, weightsInUb, weightsInTUb, + brcBuf, procGnum, s2BaseSize_, outerGidx); + // confused reduceOp in DoScale + // neednot use LIServiceVec::doReduce(mmInUb, reduceOutInner, procGnum, (s2BaseSize_+8)); + inQueue_.FreeTensor(mmInUb); + } + + int32_t gRedCnt = groupInner_ > gSize_ ? gSize_ : groupInner_; + bool isS2End = cuBaseS2Idx + s2BaseSize_ >= cuRealAcSeq; + LIServiceVec::DoReduce(reduceCacheBuf[REDUCE_BANK_CONFLICT_NUM], reduceOutInner, gRedCnt, s2BaseSize_); + outQueue_.FreeTensor(reduceCacheBuf); + + LocalTensor sortScoreUb = reduceOutBuff; + LocalTensor sortIndiceUb = reduceOutBuff[cuS2LenVecAlign]; + PipeBarrier(); + Duplicate(sortScoreUb.template ReinterpretCast(), LIServiceVec::NEG_INF, cuS2LenVecAlign); + PipeBarrier(); + Adds(sortScoreUb, reduceOutInner, 0.0f, cuS2Len); + PipeBarrier(); + LocalTensor sortIndiceUbInt = sortIndiceUb.template ReinterpretCast(); + // 无效数据索引填充为-1 + if (cuS2LenVecAlign != cuS2Len) { + Duplicate(sortIndiceUbInt, -1, cuS2LenVecAlign); + } + PipeBarrier(); + Adds(sortIndiceUbInt, globalTopkIndice_, static_cast(cuBaseS2Idx), cuS2Len); + PipeBarrier(); + + LocalTensor tmpSortBuf = outQueue_.AllocTensor(); + if (info.actS1Size > 4) { + // info.actS1Size > 4 则单个vector核内处理的 s1>2,缓存方案无法处理 + LIServiceVec::SortAll(reduceOutBuff, tmpSortBuf, + cuS2LenVecAlign); // cuS2LenVecAlign <= s2BaseSize_, fill -inf + PipeBarrier(); + LIServiceVec::MergeSort(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], BASE_TOPK, reduceOutBuff, + cuS2LenVecAlign, tmpSortBuf); + } else { + int64_t globalTopkUbCacheIdx = (info.s2Idx - blockS2StartIdx_) % 4; + Sort( + SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2 + globalTopkUbCacheIdx * s2BaseSize_ * 2], + reduceOutBuff, sortIndiceUbInt.template ReinterpretCast(), tmpSortBuf, + cuS2LenVecAlign / 32); + // 缓存4块512或者S2结束, 需要进行精排 + if (globalTopkUbCacheIdx == 3 || isS2End || info.isAllLoopEnd) { + LocalTensor tt = SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2]; + // 前4块直接精排覆盖到globalTopkUb_ + if (info.s2Idx - blockS2StartIdx_ < 4) { + MrgBasicBlock(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], tt, + static_cast(globalTopkUbCacheIdx + 1), s2BaseSize_); + } else { // 后面缓存在 SortedBasicBlock_, 先精排, 再merge到globalTopkUb_ + if (globalTopkUbCacheIdx > 0) { + MrgBasicBlock(tmpSortBuf, tt, static_cast(globalTopkUbCacheIdx + 1), s2BaseSize_); + PipeBarrier(); + DataCopy(SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2], tmpSortBuf, + (globalTopkUbCacheIdx + 1) * s2BaseSize_ * 2); + } + PipeBarrier(); + SparseTopK(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], + SortedBasicBlock_[innerS1Idx * BASE_TOPK * 2], tmpSortBuf, BASE_TOPK, + s2BaseSize_ * (globalTopkUbCacheIdx + 1)); + } + } + } + + PipeBarrier(); + outQueue_.FreeTensor(tmpSortBuf); + + bool 
needCopyOutGm = blockS2StartIdx_ == 0 && isS2End; + + // 中间结果保存 + bool needCopyWsGm = info.isAllLoopEnd || isS2End; + + if (needCopyOutGm) { + LocalTensor valueULocal = outQueue_.AllocTensor(); + LocalTensor idxULocal = valueULocal.template ReinterpretCast()[BASE_TOPK]; + ExtractIndex(idxULocal, globalTopkUb_[innerS1Idx * BASE_TOPK * 2].template ReinterpretCast(), + BASE_TOPK); + PipeBarrier(); + InitSortOutBuf(globalTopkUb_[innerS1Idx * BASE_TOPK * 2], BASE_TOPK * 2); + outQueue_.EnQue(valueULocal); + valueULocal = outQueue_.DeQue(); + LocalTensor idxULocal1 = valueULocal.template ReinterpretCast()[BASE_TOPK]; + LIServiceVec::CopyOut(indiceOutGm[info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount], + idxULocal1, constInfo_.sparseCount); + outQueue_.FreeTensor(valueULocal); + } else if (needCopyWsGm) { + // vec1Res Gm = [aic, s1BaseSize_, 2, 2, topkOut_] float32 + // vec1Param Gm = [aic, s1BaseSize_, 2, 16] int64 + // 16 = [needFd, s2AcSeq, s2Start, s2End, isS2End, bn2idx, s1Idx, S1ProcNum, ......] + + int64_t wsOffset = (blockId_ / 2) * s1BaseSize_ * 2 * 2 * BASE_TOPK + // 2个AIV共同地址偏移 + (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * 2 * BASE_TOPK + // 每个AIV的地址偏移,S1方向 + (ldS1Offset + innerS1Idx) * 2 * 2 * BASE_TOPK; + int64_t wsInfoOffset = (blockId_ / 2) * s1BaseSize_ * 2 * paramNum_ + // 2个AIV共同地址偏移 + (blockId_ % 2) * (s1BaseSize_ / 2) * 2 * paramNum_ + // 每个AIV的地址偏移,S1方向 + (ldS1Offset + innerS1Idx) * 2 * paramNum_; + + LocalTensor tmpiBuff = paramBuf_.Get(); + SetWaitFlag(HardEvent::MTE3_S); + tmpiBuff.SetValue(0, static_cast(1)); + tmpiBuff.SetValue(1, static_cast(cuRealAcSeq)); + tmpiBuff.SetValue(2, static_cast(blockS2StartIdx_)); + tmpiBuff.SetValue(3, static_cast(cuBaseS2Idx + cuS2Len)); + tmpiBuff.SetValue(4, static_cast(isS2End)); + tmpiBuff.SetValue(5, static_cast(info.bN2Idx)); + tmpiBuff.SetValue(6, static_cast(cuS1Idx)); + tmpiBuff.SetValue(7, static_cast(cuS1ProcNum)); + tmpiBuff.SetValue(8, static_cast(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount)); + // 写入头尾判断 + // [head, tail] + // head: 与前面规约,与前后规约 + // tail: 与后面规约 + bool isTailReduce = blockS2StartIdx_ == 0; // 一定是isLastTile + // WS偏移规则 blockS2StartIdx_ != 0 + // 跟前面块做规约 写到0偏移 不用做计算 blockS2StartIdx_ == 0 and !isS2End + // 跟后面块做规约 写到1偏移 需要 + s1BaseSize_, BASE_TOPK*2 + if (isTailReduce) { // S2不是最后结束的数据就需要往后做规约,放入第二块ws + wsInfoOffset += paramNum_; + wsOffset += 2 * BASE_TOPK; + } + SetWaitFlag(HardEvent::S_MTE3); + LIServiceVec::CopyOut(vec1ParamGm[wsInfoOffset], tmpiBuff, 16); + SetWaitFlag(HardEvent::V_MTE3); + LIServiceVec::CopyOut(vec1ResGm[wsOffset], globalTopkUb_[innerS1Idx * BASE_TOPK * 2], 2 * BASE_TOPK); + SetWaitFlag(HardEvent::MTE3_V); + } + } else if (cuRealAcSeq <= 0) { + CleanInvalidOutput(info.indiceOutOffset + cuS1Idx * constInfo_.sparseCount); + } + } + + // BNSD场景无效S1 输出-1 + if (LAYOUT_T == LI_LAYOUT::BSND) { + // 最后一个S1的基本块, 需要 >= info.actS1Size + bool isS1LoopEnd = (cuBaseS1Idx + s1BaseSize_) >= info.actS1Size; + int32_t invalidS1Num = constInfo_.qSeqSize - info.actS1Size; + // blockS2StartIdx_ == 0 控制S2从开始的核去做冗余清理 + if (invalidS1Num > 0 && isS1LoopEnd && blockS2StartIdx_ == 0) { + int32_t s1NumPerAiv = blockId_ % 2 == 0 ? 
CeilDiv(invalidS1Num, 2) : (invalidS1Num / 2); + int32_t s1OffsetPerAiv = info.actS1Size + (blockId_ % 2) * CeilDiv(invalidS1Num, 2); + for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) { + CleanInvalidOutput(info.indiceOutOffset + (s1OffsetPerAiv + innerS1Idx) * constInfo_.sparseCount); + } + } + + int32_t invalidS1Num2 = info.actS1Size - info.actS2Size; + if (invalidS1Num2 > 0 && isS1LoopEnd && blockS2StartIdx_ == 0 && constInfo_.attenMaskFlag) { + int32_t s1NumPerAiv = blockId_ % 2 == 0 ? CeilDiv(invalidS1Num2, 2) : (invalidS1Num2 / 2); + int32_t s1OffsetPerAiv = (blockId_ % 2) * CeilDiv(invalidS1Num2, 2); + for (int innerS1Idx = 0; innerS1Idx < s1NumPerAiv; innerS1Idx++) { + CleanInvalidOutput((info.bN2Idx * constInfo_.qSeqSize + s1OffsetPerAiv + innerS1Idx) * + constInfo_.sparseCount); + } + } + } + + if (info.isLastS2InnerLoop) { + // S2最后一个Loop后, 下一个基本块初始从0开始 + blockS2StartIdx_ = 0; + } +} + +template +__aicore__ inline void LIVector::ProcessLD() +{ + int32_t curCubeId = blockId_ / 2; + int32_t tmpCubeId = curCubeId; + + int64_t s2ActSeq; + int64_t s2Start; + int64_t s2End; + int64_t isS2End; + int64_t bn2Idx; + int64_t s1Idx; + uint32_t acc_list_num = 0; + int64_t bIdx = 0; + int64_t needFd; + int64_t wsOffset; + int64_t wsInfoOffset = 0; + int64_t nextneedFd; + int64_t valueOffset = 0; + int64_t outOffset = 0; + + LocalTensor curValueIdxUb = ldToBeMrgBuf_.Get(); + LocalTensor tmpUb = ldTmpBuf_.Get(); + + // S2开头信息 + // 开始必然没有头规约,因此从尾规约开始处理,while循环读取下一个核的头规约 + // 存满4个list或者遇到S2结尾,则做merge,直到做完S2 + // 每个核都忽略自己的头规约,因为必然由前面的核做完 + uint32_t s1LdStartIdx = 0; + uint32_t s1ProcNum = 0; + uint64_t paramGmCoreOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_; + for (uint32_t innerS1Idx = 0; innerS1Idx < s1BaseSize_; innerS1Idx++) { + needFd = vec1ParamGm.GetValue(paramGmCoreOffset + innerS1Idx * 2 * paramNum_ + paramNum_); + if (needFd == 1) { + s1LdStartIdx = (s1ProcNum == 0) ? 
innerS1Idx : s1LdStartIdx; + s1ProcNum++; + } + } + + if (s1ProcNum == 0) { + return; + } + + // S1逐行计算 + uint32_t s1VecNum = CeilDiv(s1ProcNum, 2); + if (blockId_ % 2 == 1) { + s1LdStartIdx = s1LdStartIdx + s1VecNum; + s1VecNum = s1ProcNum - s1VecNum; + } + for (uint32_t innerS1Idx = s1LdStartIdx; innerS1Idx < s1LdStartIdx + s1VecNum; innerS1Idx++) { + // 重置偏移 + tmpCubeId = curCubeId; + acc_list_num = 0; + valueOffset = 0; + + // 搬入数据 + wsOffset = tmpCubeId * s1BaseSize_ * 2 * 2 * BASE_TOPK + // 2个AIV共同地址偏移 + innerS1Idx * 2 * 2 * BASE_TOPK + 2 * BASE_TOPK; + SetWaitFlag(HardEvent::V_MTE2); + SetWaitFlag(HardEvent::S_MTE2); + DataCopyPad(curValueIdxUb, vec1ResGm[wsOffset], + {1, static_cast(2 * BASE_TOPK * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0}); + acc_list_num++; + valueOffset += 2 * BASE_TOPK; + + // 获取下一个核规约信息 + tmpCubeId++; + wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_; + needFd = vec1ParamGm.GetValue(wsInfoOffset); + isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4); + s1Idx = vec1ParamGm.GetValue(wsInfoOffset + 6); + outOffset = vec1ParamGm.GetValue(wsInfoOffset + 8); + + while (needFd == 1) { + // 搬入头规约数据 + wsOffset = tmpCubeId * s1BaseSize_ * 2 * 2 * BASE_TOPK + // 2个AIV共同地址偏移 + innerS1Idx * 2 * 2 * BASE_TOPK; + SetWaitFlag(HardEvent::V_MTE2); + SetWaitFlag(HardEvent::S_MTE2); + DataCopyPad(curValueIdxUb[valueOffset], vec1ResGm[wsOffset], + {1, static_cast(2 * BASE_TOPK * sizeof(int32_t)), 0, 0}, {true, 0, 0, 0}); + valueOffset += 2 * BASE_TOPK; + acc_list_num++; + + // 每满4个list,聚合 前2K为mrg结果 + if (acc_list_num == mrgListNum_) { + // MrgSort 四条2048的队列,Mrg成一条 + AscendC::MrgSort4Info params; + params.elementLengths[0] = BASE_TOPK; + params.elementLengths[1] = BASE_TOPK; + params.elementLengths[2] = BASE_TOPK; + params.elementLengths[3] = BASE_TOPK; + params.ifExhaustedSuspension = true; + params.validBit = 0b1111; + params.repeatTimes = 1; + + AscendC::MrgSortSrcList srcList; + srcList.src1 = curValueIdxUb[0]; + srcList.src2 = curValueIdxUb[2 * BASE_TOPK]; + srcList.src3 = curValueIdxUb[4 * BASE_TOPK]; + srcList.src4 = curValueIdxUb[6 * BASE_TOPK]; + SetWaitFlag(HardEvent::MTE2_V); + MrgSort(tmpUb, srcList, params); + PipeBarrier(); + DataCopy(curValueIdxUb, tmpUb, 2 * BASE_TOPK); + PipeBarrier(); + acc_list_num = 1; + valueOffset = 2 * BASE_TOPK; + } + + // reduce到S2末尾,则跳出 + if (isS2End == 1) { + break; + } + + tmpCubeId++; + wsInfoOffset = tmpCubeId * s1BaseSize_ * 2 * paramNum_ + innerS1Idx * 2 * paramNum_; + needFd = vec1ParamGm.GetValue(wsInfoOffset); + isS2End = vec1ParamGm.GetValue(wsInfoOffset + 4); + } + + // mrg不足4个list的数据 + if (acc_list_num != 1) { + AscendC::MrgSort4Info params; + params.elementLengths[0] = BASE_TOPK; + params.elementLengths[1] = BASE_TOPK; + params.elementLengths[2] = BASE_TOPK; + params.elementLengths[3] = BASE_TOPK; + params.ifExhaustedSuspension = true; + if (acc_list_num == 2) { + params.validBit = 0b0011; + } else if (acc_list_num == 3) { + params.validBit = 0b0111; + } + params.repeatTimes = 1; + + AscendC::MrgSortSrcList srcList; + srcList.src1 = curValueIdxUb[0]; + srcList.src2 = curValueIdxUb[2 * BASE_TOPK]; + srcList.src3 = curValueIdxUb[4 * BASE_TOPK]; + srcList.src4 = curValueIdxUb[6 * BASE_TOPK]; + SetWaitFlag(HardEvent::MTE2_V); + MrgSort(tmpUb, srcList, params); + PipeBarrier(); + DataCopy(curValueIdxUb, tmpUb, 2 * BASE_TOPK); + PipeBarrier(); + } + + // 搬出 + LocalTensor outValueUb = ldOutValueBuf_.Get(); + LocalTensor outIdxUb = ldOutIdxBuf_.Get(); + + Extract(outValueUb, outIdxUb, curValueIdxUb, 
(BASE_TOPK / 32)); + LocalTensor idxULocal1 = outIdxUb.template ReinterpretCast(); + SetWaitFlag(HardEvent::V_MTE3); + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(indiceOutGm[outOffset], idxULocal1, + {1, static_cast(constInfo_.sparseCount * sizeof(int32_t)), 0, 0}); + SetWaitFlag(HardEvent::MTE3_V); + } +} +} // namespace LIKernel +#endif \ No newline at end of file diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h new file mode 100644 index 00000000000..2902be721a1 --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_template_tiling_key.h @@ -0,0 +1,69 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_template_tiling_key.h + * \brief + */ + +#ifndef TEMPLATE_TILING_KEY_LI_H_ +#define TEMPLATE_TILING_KEY_LI_H_ + +#include "ascendc/host_api/tiling/template_argument.h" + +#define LI_TPL_FP16 1 +#define LI_TPL_INT32 3 +#define LI_TPL_BF16 27 + +#define LI_LAYOUT_BSND 0 +#define LI_LAYOUT_TND 1 +#define LI_LAYOUT_PA_BSND 2 + +#define ASCENDC_TPL_4_BW 4 + +// 模板参数支持的范围定义 +ASCENDC_TPL_ARGS_DECL(LightningIndexer, // 算子OpType + ASCENDC_TPL_DTYPE_DECL(DT_Q, LI_TPL_FP16, LI_TPL_BF16), + ASCENDC_TPL_DTYPE_DECL(DT_K, LI_TPL_FP16, LI_TPL_BF16), + ASCENDC_TPL_DTYPE_DECL(DT_OUT, LI_TPL_INT32), ASCENDC_TPL_BOOL_DECL(PAGE_ATTENTION, 0, 1), + ASCENDC_TPL_UINT_DECL(LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, + LI_LAYOUT_TND), + ASCENDC_TPL_UINT_DECL(K_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, + LI_LAYOUT_PA_BSND, LI_LAYOUT_BSND, LI_LAYOUT_TND), ); + +// 支持的模板参数组合 +// 用于调用GET_TPL_TILING_KEY获取TilingKey时,接口内部校验TilingKey是否合法 +ASCENDC_TPL_SEL( + ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_FP16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_FP16), + ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), + ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1), + ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_PA_BSND), ), + + ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_BF16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_BF16), + ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), + ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 1), + ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_PA_BSND), ), + + ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_FP16), ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_FP16), + ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), + ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0), + ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, + LI_LAYOUT_BSND, LI_LAYOUT_TND), ), + + ASCENDC_TPL_ARGS_SEL(ASCENDC_TPL_DTYPE_SEL(DT_Q, LI_TPL_BF16), 
ASCENDC_TPL_DTYPE_SEL(DT_K, LI_TPL_BF16), + ASCENDC_TPL_DTYPE_SEL(DT_OUT, LI_TPL_INT32), + ASCENDC_TPL_BOOL_SEL(PAGE_ATTENTION, 0), + ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(K_LAYOUT_T, ASCENDC_TPL_UI_LIST, LI_LAYOUT_BSND, LI_LAYOUT_TND), ), ); + +#endif \ No newline at end of file diff --git a/csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h b/csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h new file mode 100644 index 00000000000..28436f2b453 --- /dev/null +++ b/csrc/lightning_indexer/op_kernel/lightning_indexer_vector.h @@ -0,0 +1,391 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file lightning_indexer_vector.h + * \brief + */ +#ifndef LIGHTNING_INDEXER_VECTOR_H +#define LIGHTNING_INDEXER_VECTOR_H + +#include "lightning_indexer_vector.h" +#include "kernel_operator.h" + +namespace LIServiceVec { +using namespace AscendC; + +constexpr int32_t NEG_INF = 0xFF800000; +constexpr int32_t INVALID_INDEX = -1; +constexpr uint8_t VEC_REPEAT_MAX = 255; +constexpr uint8_t B32_VEC_ELM_NUM = 64; +constexpr uint8_t B32_BLOCK_ALIGN_NUM = 8; +constexpr uint8_t B32_VEC_REPEAT_STRIDE = 8; +constexpr uint64_t VEC_REPEAT_BYTES = 256; +constexpr int32_t CONST_TWO = 2; +constexpr int64_t VALUE_AND_INDEX_NUM = 2; +constexpr int64_t BLOCK_BYTES = 32; +constexpr int64_t MRG_QUE_0 = 0; +constexpr int64_t MRG_QUE_1 = 1; +constexpr int64_t MRG_QUE_2 = 2; +constexpr int64_t MRG_QUE_3 = 3; +constexpr int64_t MRG_BLOCK_2 = 2; +constexpr int64_t MRG_BLOCK_3 = 3; +constexpr int64_t MRG_BLOCK_4 = 4; + +template +__aicore__ inline void CopyIn(LocalTensor &mmOutUb, LocalTensor &weightsUb, GlobalTensor &mMoutGm, + GlobalTensor &weightScaleGm, int64_t MMout_gmoffset, int64_t weights_gmoffset, + int64_t groupInner, int64_t s2Inner, int64_t mmUbStride) +{ + // 将MMout_gmoffset copy到UB上 + AscendC::DataCopyPadExtParams padParams{false, 0, 0, 0}; + AscendC::DataCopyExtParams dataCopymMoutParams; + dataCopymMoutParams.blockCount = groupInner; + dataCopymMoutParams.blockLen = s2Inner * sizeof(float); + dataCopymMoutParams.srcStride = 0; + dataCopymMoutParams.dstStride = mmUbStride; + dataCopymMoutParams.rsv = 0; + AscendC::DataCopyPad(mmOutUb, mMoutGm[MMout_gmoffset], dataCopymMoutParams, padParams); + + // 将weights_gmoffset copy到UB + AscendC::DataCopyPadExtParams padTParams{false, 0, 0, 0}; + AscendC::DataCopyExtParams dataCopyweightParams; + dataCopyweightParams.blockCount = 1; + dataCopyweightParams.blockLen = groupInner * sizeof(T); + dataCopyweightParams.srcStride = 0; + dataCopyweightParams.dstStride = 0; + dataCopyweightParams.rsv = 0; + AscendC::DataCopyPad(weightsUb, weightScaleGm[weights_gmoffset], dataCopyweightParams, padTParams); +} + + +template +__aicore__ inline void CopyOut(const GlobalTensor &dstGm, const LocalTensor &srcUb, int64_t copyCount) +{ + AscendC::DataCopyParams 
dataCopyOutyParams; + dataCopyOutyParams.blockCount = 1; + dataCopyOutyParams.blockLen = copyCount * sizeof(T); + dataCopyOutyParams.srcStride = 0; + dataCopyOutyParams.dstStride = 0; + AscendC::DataCopyPad(dstGm, srcUb, dataCopyOutyParams); +} + + +template +__aicore__ inline void DoScale(const LocalTensor &reduceCacheBuf, LocalTensor &mmOutUb, + LocalTensor &weightsUb, LocalTensor &weightsTUb, LocalTensor &tmpBuff, + int64_t groupInner, int64_t s2Inner, int32_t outerGidx) +{ + // cast bfloat16_t to float + if constexpr (!IsSameType::value) { + AscendC::Cast(weightsUb, weightsTUb, RoundMode::CAST_NONE, groupInner); + AscendC::PipeBarrier(); + } + + // weight broadcast: [groupInner, 1] -> [groupInner, 8] + AscendC::Brcb(tmpBuff, weightsUb, LICommon::CeilDiv(groupInner, static_cast(B32_BLOCK_ALIGN_NUM)), + {1, B32_VEC_REPEAT_STRIDE}); + AscendC::PipeBarrier(); + + // do scale: [groupInner, 8] * [groupInner, s2Inner] + uint64_t countPerRepeat = VEC_REPEAT_BYTES / sizeof(float); + uint64_t repeatTimes = s2Inner / countPerRepeat; + for (int32_t i = 0; i < groupInner; i++) { + if (outerGidx == 0) { + AscendC::Mul(reduceCacheBuf[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM], + countPerRepeat, repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0}); + } else { + AscendC::Mul(mmOutUb[i * s2Inner], mmOutUb[i * s2Inner], tmpBuff[i * B32_BLOCK_ALIGN_NUM], countPerRepeat, + repeatTimes, {1, 1, 0, B32_VEC_REPEAT_STRIDE, B32_VEC_REPEAT_STRIDE, 0}); + } + } + + if (outerGidx != 0) { + AscendC::PipeBarrier(); + AscendC::Add(reduceCacheBuf, mmOutUb, reduceCacheBuf, groupInner * s2Inner); + } + AscendC::PipeBarrier(); +} + + +__aicore__ inline uint64_t FindNearestPower2(uint64_t value) +{ + if (value <= CONST_TWO) { + return value; + } else { + const uint64_t pow = 63 - clz(value); // clz返回前导0的个数,对于64位整数,最大有效位位置 = 63 - 前导0个数 + return (1 << pow); + } +} + + +// dstTensor 需要初始化0 +__aicore__ inline void DoReduce(const LocalTensor &srcTensor, LocalTensor &dstTensor, int32_t rNum, + int32_t aNum) +{ + if (rNum == 1) { + AscendC::Adds(dstTensor, srcTensor, 0, aNum); + AscendC::PipeBarrier(); + return; + } + + uint32_t dichotomizeAddPow = FindNearestPower2(rNum); + uint32_t dichotomizeAddDiffSize = rNum - dichotomizeAddPow; + if (dichotomizeAddDiffSize != 0) { + AscendC::Add(srcTensor, srcTensor, srcTensor[dichotomizeAddPow * aNum], dichotomizeAddDiffSize * aNum); + AscendC::PipeBarrier(); + } + int32_t nowRows = dichotomizeAddPow; + while (nowRows > CONST_TWO) { + nowRows = nowRows / CONST_TWO; + AscendC::Add(srcTensor, srcTensor, srcTensor[nowRows * aNum], nowRows * aNum); + AscendC::PipeBarrier(); + } + AscendC::Add(dstTensor, srcTensor, srcTensor[aNum], aNum); + AscendC::PipeBarrier(); +} + + +/** + src: 传入的初始化空间 + eleNum: 需要初始化的元素个数需为64整数倍,元素将被初始化为交错排布的-inf,-1 + */ +__aicore__ inline void InitSortOutBuf(const LocalTensor &src, int64_t eleNum) +{ + uint64_t mask1[2] = {0x5555555555555555, 0}; + uint64_t mask0[2] = {0xaaaaaaaaaaaaaaaa, 0}; + int64_t repeatNum = eleNum / B32_VEC_ELM_NUM; + int64_t forLoop = repeatNum / VEC_REPEAT_MAX; + int64_t forRemain = repeatNum % VEC_REPEAT_MAX; + for (int i = 0; i < forLoop; i++) { + AscendC::Duplicate(src.template ReinterpretCast(), NEG_INF, mask1, VEC_REPEAT_MAX, 1, + B32_VEC_REPEAT_STRIDE); + AscendC::Duplicate(src.template ReinterpretCast(), INVALID_INDEX, mask0, VEC_REPEAT_MAX, 1, + B32_VEC_REPEAT_STRIDE); + } + if (forRemain > 0) { + AscendC::Duplicate(src.template ReinterpretCast()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], 
NEG_INF, + mask1, forRemain, 1, B32_VEC_REPEAT_STRIDE); + AscendC::Duplicate(src.template ReinterpretCast()[forLoop * VEC_REPEAT_MAX * B32_VEC_ELM_NUM], + INVALID_INDEX, mask0, forRemain, 1, B32_VEC_REPEAT_STRIDE); + } + AscendC::PipeBarrier(); +} + + +/** + src: logits和索引,前logitsNum为logits,后logitsNum为索引 + tmp: 计算使用到的临时空间,大小与src一致 + logitsNum: 排序的元素个数, 暂只支持[128,256,384,512,1024,2048] + */ +__aicore__ inline void SortAll(LocalTensor &src, LocalTensor &tmp, int64_t logitsNum) +{ + int64_t sort32Repeats = logitsNum / BLOCK_BYTES; + AscendC::Sort32(tmp, src, src[logitsNum].ReinterpretCast(), sort32Repeats); + AscendC::PipeBarrier(); + + int64_t mrgGroups = sort32Repeats; + int64_t mrgElements = BLOCK_BYTES; + int64_t i = 0; + AscendC::LocalTensor srcTensor; + AscendC::LocalTensor dstTensor; + while (true) { + if (i % CONST_TWO == 0) { + srcTensor = tmp; + dstTensor = src; + } else { + srcTensor = src; + dstTensor = tmp; + } + AscendC::MrgSort4Info params; + params.elementLengths[0] = mrgElements; + params.elementLengths[MRG_QUE_1] = mrgElements; + params.elementLengths[MRG_QUE_2] = mrgElements; + params.elementLengths[MRG_QUE_3] = mrgElements; + params.ifExhaustedSuspension = false; + params.validBit = 0b1111; + + AscendC::MrgSortSrcList srcList; + srcList.src1 = srcTensor[0]; + srcList.src2 = srcTensor[MRG_QUE_1 * VALUE_AND_INDEX_NUM * mrgElements]; + srcList.src3 = srcTensor[MRG_QUE_2 * VALUE_AND_INDEX_NUM * mrgElements]; + srcList.src4 = srcTensor[MRG_QUE_3 * VALUE_AND_INDEX_NUM * mrgElements]; + if (mrgGroups <= MRG_BLOCK_4) { + params.repeatTimes = 1; + if (mrgGroups == 1) { + break; + } else if (mrgGroups == MRG_BLOCK_2) { + params.validBit = 0b0011; + } else if (mrgGroups == MRG_BLOCK_3) { + params.validBit = 0b0111; + } else if (mrgGroups == MRG_BLOCK_4) { + params.validBit = 0b1111; + } + AscendC::MrgSort(dstTensor, srcList, params); + i += 1; + break; + } else { + params.repeatTimes = mrgGroups / MRG_BLOCK_4; + AscendC::MrgSort(dstTensor, srcList, params); + i += 1; + mrgElements = mrgElements * MRG_BLOCK_4; + mrgGroups = mrgGroups / MRG_BLOCK_4; + } + AscendC::PipeBarrier(); + } + if (i % CONST_TWO == 0) { + AscendC::DataCopy(src, tmp, logitsNum * VALUE_AND_INDEX_NUM); + AscendC::PipeBarrier(); + } +} + + +/** + dst: 输出全排序的结果,排布方式为value,index + srcValue:输入的待排序浮点数 + srcIndex:浮点数的索引 + tmp: 计算使用到的临时空间,大小为srcValue+srcIndex + logitsNum: 排序的元素个数 + */ +__aicore__ inline void SortAll(LocalTensor &dst, LocalTensor &srcValue, LocalTensor &srcIndex, + LocalTensor &tmpTensor, int64_t logitsNum) +{ + int64_t sort32Repeats = logitsNum / BLOCK_BYTES; + AscendC::Sort(dst, srcValue, srcIndex, tmpTensor, sort32Repeats); + AscendC::PipeBarrier(); +} + + +/** + mrgDst: 合并进的Tensor + mrgSrc: 待合并的Tensor + tmpTensor:空间为mrgDst+mrgSrc + */ +__aicore__ inline void MergeSort(const LocalTensor &mrgDst, int32_t mrgDstNum, LocalTensor &mrgSrc, + int32_t mrgSrcNum, LocalTensor &tmpTensor) +{ + AscendC::MrgSort4Info params; + params.elementLengths[0] = mrgDstNum; + params.elementLengths[1] = mrgSrcNum; + params.ifExhaustedSuspension = false; + params.validBit = 0b0011; + params.repeatTimes = 1; + + AscendC::MrgSortSrcList srcList; + srcList.src1 = mrgDst; + srcList.src2 = mrgSrc; + + AscendC::MrgSort(tmpTensor, srcList, params); + AscendC::PipeBarrier(); + AscendC::DataCopy(mrgDst, tmpTensor, mrgDstNum * VALUE_AND_INDEX_NUM); + AscendC::PipeBarrier(); +} + + +/** + * @brief 合并基础块函数 + * @param dst 归并后的输出, 大小为blockNum * basicBlockSize * 2 * sizeof(float) + * @param src 基本块输入 + * @param blockNum 基本块的数量 + * @param 
basicBlockSize 基础块的大小 + * @return 无 + */ +__aicore__ inline void MrgBasicBlock(const LocalTensor &dst, const LocalTensor &src, int64_t blockNum, + int64_t basicBlockSize) +{ + // 初始化合并排序参数 + AscendC::MrgSort4Info params; + params.elementLengths[MRG_QUE_0] = basicBlockSize; + params.elementLengths[MRG_QUE_1] = basicBlockSize; + params.elementLengths[MRG_QUE_2] = basicBlockSize; + params.elementLengths[MRG_QUE_3] = basicBlockSize; + params.ifExhaustedSuspension = false; + // 根据块的数量设置有效位 + if (blockNum == MRG_BLOCK_2) { + params.validBit = 0b0011; + } else if (blockNum == MRG_BLOCK_3) { + params.validBit = 0b0111; + } else if (blockNum == MRG_BLOCK_4) { + params.validBit = 0b1111; + } else { + AscendC::DataCopy(dst, src, basicBlockSize * VALUE_AND_INDEX_NUM); + return; + } + // 初始化源列表 + AscendC::MrgSortSrcList srcList; + srcList.src1 = src[0]; + srcList.src2 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_1]; + srcList.src3 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_2]; + srcList.src4 = src[basicBlockSize * VALUE_AND_INDEX_NUM * MRG_QUE_3]; + // 执行合并排序 + AscendC::MrgSort(dst, srcList, params); +} + + +/** + * @brief 从两个队列中选择topk + * @param dst 已经归并好的topk数据 + * @param needsMerging 需要合并的有序数据 + * @param tmp 临时空间 + * @param topk topk的元素个数 + * @param mergSize 待合并的元素个数 + * @return 无 + */ +template +__aicore__ inline void SparseTopK(const LocalTensor &dst, const LocalTensor &needsMerging, + const LocalTensor &tmp, int64_t topk, int64_t mergSize) +{ + // 如果不需要合并,则直接复制数据 + if (!needMrg) { + AscendC::DataCopy(dst, needsMerging, mergSize * VALUE_AND_INDEX_NUM); + return; + } + // 初始化合并排序参数 + AscendC::MrgSort4Info params; + params.elementLengths[0] = topk; + params.elementLengths[1] = mergSize; + params.ifExhaustedSuspension = (topk == mergSize); + params.validBit = 0b0011; + // 初始化源列表 + AscendC::MrgSortSrcList srcList; + srcList.src1 = dst; + srcList.src2 = needsMerging; + // 执行合并排序 + AscendC::MrgSort(tmp, srcList, params); + // 将结果复制到目标张量 + AscendC::DataCopy(dst, tmp, topk * VALUE_AND_INDEX_NUM); +} + + +__aicore__ inline void ExtractIndex(const LocalTensor &idxULocal, const LocalTensor &sortLocal, + int64_t extractNum) +{ + AscendC::GatherMaskParams gatherMaskParams; + gatherMaskParams.repeatTimes = Ceil(extractNum * sizeof(float) * VALUE_AND_INDEX_NUM, VEC_REPEAT_BYTES); + gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = B32_VEC_REPEAT_STRIDE; + gatherMaskParams.src1RepeatStride = 0; + uint64_t rsvdCnt = 0; // 用于保存筛选后保留下来的元素个数 + uint8_t src1Pattern = 2; // 固定模式2,表示筛选出奇数索引的数 + AscendC::GatherMask(idxULocal, sortLocal, src1Pattern, false, static_cast(0), gatherMaskParams, rsvdCnt); + AscendC::PipeBarrier(); +} + + +template +__aicore__ inline void SetWaitFlag(HardEvent evt) +{ + event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt)); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); +} + +} // namespace LIServiceVec +#endif // LIGHTNING_INDEXER_VECTOR_H \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index c249bb58750..2401792db0d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -158,4 +158,13 @@ namespace vllm_ascend { void* tiling, const uint32_t block_dim ); + + extern void batch_matmul_transpose_impl( + void* stream, + void* gm_a, + void* gm_b, + void* gm_c, + void* gm_tiling_data, + const uint32_t block_dim + ); } diff --git a/csrc/sparse_flash_attention/op_host/CMakeLists.txt b/csrc/sparse_flash_attention/op_host/CMakeLists.txt new file mode 100644 index 00000000000..ad24f34e6af --- /dev/null +++ 
b/csrc/sparse_flash_attention/op_host/CMakeLists.txt @@ -0,0 +1,39 @@ +# This program is free software, you can redistribute it and/or modify it. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME SparseFlashAttention + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror + -fpermissive +) + +set(sparse_flash_attention_depends transformer/attention/sparse_flash_attention PARENT_SCOPE) +target_sources(op_host_aclnn PRIVATE + sparse_flash_attention_def.cpp +) + +target_sources(optiling PRIVATE + sparse_flash_attention_tiling.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(opmaster_ct PRIVATE + sparse_flash_attention_tiling.cpp + ) +endif () + +target_sources(opsproto PRIVATE + sparse_flash_attention_proto.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) diff --git a/csrc/sparse_flash_attention/op_host/sparse_flash_attention_def.cpp b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_def.cpp new file mode 100644 index 00000000000..2378412155b --- /dev/null +++ b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_def.cpp @@ -0,0 +1,90 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file sparse_flash_attention_def.cpp + * \brief + */ + +#include "register/op_def_registry.h" + +namespace ops { +class SparseFlashAttention : public OpDef { +public: + explicit SparseFlashAttention(const char *name) : OpDef(name) + { + this->Input("query") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("key") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("value") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("sparse_indices") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("block_table") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_lengths_query") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("actual_seq_lengths_kv") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("query_rope") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("key_rope") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Output("attention_out") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("scale_value").AttrType(REQUIRED).Float(1.0); + this->Attr("sparse_block_size").AttrType(REQUIRED).Int(1); + this->Attr("layout_query").AttrType(OPTIONAL).String("BSND"); + this->Attr("layout_kv").AttrType(OPTIONAL).String("BSND"); + this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3:默认值,只计算下三角 + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn"); + this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910_93", aicore_config); + } +}; +OP_ADD(SparseFlashAttention); +} // namespace ops diff --git a/csrc/sparse_flash_attention/op_host/sparse_flash_attention_proto.cpp b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_proto.cpp new file mode 100644 index 00000000000..07d2091bb79 --- /dev/null +++ b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_proto.cpp @@ -0,0 +1,48 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file sparse_flash_attention_proto.cpp + * \brief + */ + +#include +#include +#include "error/ops_error.h" + +using namespace ge; + +namespace ops { +constexpr size_t QUERY_INPUT_INDEX = 0; + +ge::graphStatus InferShapeSparseFlashAttention(gert::InferShapeContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_LOG_E("SparseFlashAttention", "InferShapeContext is nullptr"), + return ge::GRAPH_FAILED); + const gert::Shape *queryShape = context->GetInputShape(QUERY_INPUT_INDEX); + OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED) + gert::Shape *attentionOutShape = context->GetOutputShape(0); + OPS_LOG_E_IF_NULL(context, attentionOutShape, return ge::GRAPH_FAILED) + *attentionOutShape = *queryShape; + return GRAPH_SUCCESS; +} + +ge::graphStatus InferDataTypeSparseFlashAttention(gert::InferDataTypeContext *context) +{ + OPS_ERR_IF(context == nullptr, OPS_LOG_E("SparseFlashAttention", "InferShapeContext is nullptr"), + return ge::GRAPH_FAILED); + const auto inputDataType = context->GetInputDataType(QUERY_INPUT_INDEX); + context->SetOutputDataType(0, inputDataType); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP(SparseFlashAttention).InferShape(InferShapeSparseFlashAttention).InferDataType(InferDataTypeSparseFlashAttention); +} // namespace ops + diff --git a/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.cpp b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.cpp new file mode 100644 index 00000000000..1728e73c6a7 --- /dev/null +++ b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.cpp @@ -0,0 +1,1876 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file sparse_flash_attention_tiling.cpp + * \brief + */ + +#include +#include +#include +#include +#include +#include "error/ops_error.h" +#include "register/op_def_registry.h" +#include "../op_kernel/sparse_flash_attention_template_tiling_key.h" +#include "sparse_flash_attention_tiling.h" + +using std::map; +using std::string; +using std::pair; + +using namespace ge; +using namespace AscendC; +namespace optiling { + +constexpr uint32_t PRE_LOAD_NUM = 2; +constexpr uint32_t BLOCK_TABLE_ELEM_BYTE = 4; +constexpr int32_t SPARSE_MODE_BAND = 4; + +static const std::string QUERY_NAME = "query"; +static const std::string KEY_NAME = "key"; +static const std::string VALUE_NAME = "value"; +static const std::string BLOCK_TABLE_NAME = "block_table"; +static const std::string SPARSE_INDICES_NAME = "sparse_indices"; +static const std::string QUERY_ROPE_NAME = "query_rope"; +static const std::string KEY_ROPE_NAME = "key_rope"; +static const std::string ATTEN_OUT_NAME = "attention_out"; + +const std::map> DTYPE_SUPPORT_MAP = { + {QUERY_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {KEY_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {VALUE_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {QUERY_ROPE_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {KEY_ROPE_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {ATTEN_OUT_NAME, {ge::DT_FLOAT16, ge::DT_BF16}}, + {SPARSE_INDICES_NAME, {ge::DT_INT32}} +}; + +const std::map> LAYOUT_SUPPORT_MAP = { + {QUERY_NAME, {SFALayout::BSND, SFALayout::TND}}, + {KEY_NAME, {SFALayout::BSND, SFALayout::TND, SFALayout::PA_BSND}}, + {VALUE_NAME, {SFALayout::BSND, SFALayout::TND, SFALayout::PA_BSND}}, + {ATTEN_OUT_NAME, {SFALayout::BSND, SFALayout::TND}}, +}; + +const std::map DATATYPE_TO_STRING_MAP = { + {ge::DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. 
+ {ge::DT_FLOAT, "DT_FLOAT"}, // float type + {ge::DT_FLOAT16, "DT_FLOAT16"}, // fp16 type + {ge::DT_INT8, "DT_INT8"}, // int8 type + {ge::DT_INT16, "DT_INT16"}, // int16 type + {ge::DT_UINT16, "DT_UINT16"}, // uint16 type + {ge::DT_UINT8, "DT_UINT8"}, // uint8 type + {ge::DT_INT32, "DT_INT32"}, // uint32 type + {ge::DT_INT64, "DT_INT64"}, // int64 type + {ge::DT_UINT32, "DT_UINT32"}, // unsigned int32 + {ge::DT_UINT64, "DT_UINT64"}, // unsigned int64 + {ge::DT_BOOL, "DT_BOOL"}, // bool type + {ge::DT_DOUBLE, "DT_DOUBLE"}, // double type + {ge::DT_DUAL, "DT_DUAL"}, // dual output type + {ge::DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type + {ge::DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type + {ge::DT_COMPLEX32, "DT_COMPLEX32"}, // complex32 type + {ge::DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type + {ge::DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type + {ge::DT_QINT8, "DT_QINT8"}, // qint8 type + {ge::DT_QINT16, "DT_QINT16"}, // qint16 type + {ge::DT_QINT32, "DT_QINT32"}, // qint32 type + {ge::DT_QUINT8, "DT_QUINT8"}, // quint8 type + {ge::DT_QUINT16, "DT_QUINT16"}, // quint16 type + {ge::DT_RESOURCE, "DT_RESOURCE"}, // resource type + {ge::DT_STRING_REF, "DT_STRING_REF"}, // string ref type + {ge::DT_STRING, "DT_STRING"}, // string type + {ge::DT_VARIANT, "DT_VARIANT"}, // dt_variant type + {ge::DT_BF16, "DT_BFLOAT16"}, // dt_bfloat16 type + {ge::DT_INT4, "DT_INT4"}, // dt_variant type + {ge::DT_UINT1, "DT_UINT1"}, // dt_variant type + {ge::DT_INT2, "DT_INT2"}, // dt_variant type + {ge::DT_UINT2, "DT_UINT2"} // dt_variant type +}; + +struct SparseFlashAttentionCompileInfo { + int64_t core_num; +}; + +static const std::map> SFA_LAYOUT_AXIS_MAP = { + {SFALayout::BSND, {SFAAxis::B, SFAAxis::S, SFAAxis::N, SFAAxis::D}}, + {SFALayout::TND, {SFAAxis::T, SFAAxis::N, SFAAxis::D}}, + {SFALayout::PA_BSND, {SFAAxis::Bn, SFAAxis::Bs, SFAAxis::N, SFAAxis::D}}, +}; + +static const std::map SFA_LAYOUT_DIM_MAP = { + {SFALayout::BSND, DIM_NUM_FOUR}, + {SFALayout::TND, DIM_NUM_THREE}, + {SFALayout::PA_BSND, DIM_NUM_FOUR}, +}; + +static std::string GetShapeStr(gert::Shape shape) +{ + std::ostringstream oss; + oss << "["; + if (shape.GetDimNum() > 0) { + for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { + oss << shape.GetDim(i) << ", "; + } + oss << shape.GetDim(shape.GetDimNum() - 1); + } + oss << "]"; + return oss.str(); +} + +static std::string SFADataTypeToSerialString(ge::DataType type) +{ + const auto it = DATATYPE_TO_STRING_MAP.find(type); + if (it != DATATYPE_TO_STRING_MAP.end()) { + return it->second; + } else { + OPS_LOG_E("SparseFlashAttention", "datatype %d not support", type); + return "UNDEFINED"; + } +} + +string SFATensorDesc2String(const gert::StorageShape *shape, const gert::CompileTimeTensorDesc *tensor) +{ + if (shape == nullptr || tensor == nullptr) { + return "nil "; + } + + std::ostringstream oss; + oss << "(dtype: " << ge::TypeUtils::DataTypeToAscendString(tensor->GetDataType()).GetString() << "),"; + oss << "(shape:" << SFAShape2String(shape->GetStorageShape()) << "),"; + oss << "(ori_shape:" << SFAShape2String(shape->GetOriginShape()) << "),"; + oss << "(format: " + << ge::TypeUtils::FormatToAscendString( + static_cast(ge::GetPrimaryFormat(tensor->GetStorageFormat()))) + .GetString() + << "),"; + oss << "(ori_format: " << ge::TypeUtils::FormatToAscendString(tensor->GetOriginFormat()).GetString() << ") "; + + return oss.str(); +} + +string SFADebugTilingContext(const gert::TilingContext *context) +{ + std::ostringstream oss; + for 
(size_t i = 0; i < context->GetComputeNodeInfo()->GetInputsNum(); ++i) { + oss << "input" << i << ": "; + oss << SFATensorDesc2String(context->GetInputShape(i), context->GetInputDesc(i)); + } + + for (size_t i = 0; i < context->GetComputeNodeInfo()->GetOutputsNum(); ++i) { + oss << "output" << i << ": "; + oss << SFATensorDesc2String(context->GetOutputShape(i), context->GetOutputDesc(i)); + } + return oss.str(); +} + +std::string SFALayoutToSerialString(SFALayout layout) +{ + switch (layout) { + case SFALayout::BSND: return "BSND"; + case SFALayout::TND: return "TND"; + case SFALayout::PA_BSND: return "PA_BSND"; + default: return "UNKNOWN"; + } +} + +ge::graphStatus SFAMlaTiling::SetBlockDim(uint32_t blockDim) +{ + context_->SetBlockDim(blockDim); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAMlaTiling::SetTilingKey(uint64_t tilingKey) +{ + context_->SetTilingKey(tilingKey); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAMlaTiling::SetWorkspaceSize(uint64_t workspaceSize) +{ + OPS_ERR_IF(context_->GetWorkspaceSizes(1) == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "workSpaceSize got from ge is nullptr"), + return ge::GRAPH_FAILED); + size_t *workSpaces = context_->GetWorkspaceSizes(1); + workSpaces[0] = workspaceSize; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAMlaTiling::SetTilingData(TilingDef &tilingData) +{ + OPS_ERR_IF(context_->GetRawTilingData() == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "RawTilingData got from GE context is nullptr."), + return ge::GRAPH_FAILED); + + tilingData.SaveToBuffer(context_->GetRawTilingData()->GetData(), context_->GetRawTilingData()->GetCapacity()); + context_->GetRawTilingData()->SetDataSize(tilingData.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAMlaTiling::GetPlatformInfo() +{ + OPS_ERR_IF(sfaInfo_->platformInfo == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(sfaInfo_->opName, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(sfaInfo_->platformInfo); + libapiSize_ = ascendcPlatform.GetLibApiWorkSpaceSize(); + aivNum_ = ascendcPlatform.GetCoreNumAiv(); + aicNum_ = ascendcPlatform.GetCoreNumAic(); + + OPS_ERR_IF(aicNum_ == 0 || aivNum_ == 0, + OPS_REPORT_VECTOR_INNER_ERR(sfaInfo_->opName, "num of core obtained is 0."), return GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +void SFAMlaTiling::GenTilingKey() +{ + uint32_t inputQType = static_cast(sfaInfo_->inputQType); + uint32_t inputKvType = static_cast(sfaInfo_->inputKvType); + uint32_t outputType = static_cast(sfaInfo_->outputType); + uint32_t layoutQuery = static_cast(sfaInfo_->qLayout); + uint32_t layoutKV = static_cast(sfaInfo_->kvLayout); + + tilingKey_ = GET_TPL_TILING_KEY(0U, layoutQuery, layoutKV, perfMode_ == SFAPerfMode::V_TEMPLATE_MODE); + + OPS_LOG_I(sfaInfo_->opName, "SFA tilingKey_: %lu.", tilingKey_); +} + +void SFAMlaTiling::ZeroTensorProcess() +{ + if (sfaInfo_->s2Size == 0) { + /* + * 1024,空tensor场景下,作为默认值完成后续计算 + * 避免matmal tiling softmax tiling异常 + * kernel计算使用真实的seqSize=0, 与actuseq_len流程归一 + */ + sfaInfo_->s2Size = 1024; + } +} + +void SFAMlaTiling::InitParams() +{ + if (sfaInfo_->s2Size != 0 && sfaInfo_->sparseBlockSize <= 4) { // 4:当前支持范围 + perfMode_ = SFAPerfMode::V_TEMPLATE_MODE; + } else { + perfMode_ = SFAPerfMode::C_TEMPLATE_MODE; + } + + coreNum_ = aicNum_; + + headDimAlign_ = Align(sfaInfo_->qkHeadDim, BYTE_BLOCK); // 元素个数按照基本块大小对齐 + ZeroTensorProcess(); +} + +void SFAMlaTiling::CalcUbBmm() +{ + uint32_t cubeMSize = 
sfaInfo_->gSize * sfaInfo_->s1Size; + uint32_t maxMSize = mBaseSize_; + if (cubeMSize > maxMSize) { + cubeMSize = maxMSize; + } + mmResUbSize_ = sInnerSizeAlign_ * Align(cubeMSize, 16U);// kernel按照16对齐写出,tiling按照这个原则分配内存 + bmm2ResUbSize_ = headDimAlign_ * Align(cubeMSize, 16U);// kernel按照16对齐写出,tiling按照这个原则分配内存 + + qPreSizeMla_ = sfaInfo_->gSize * (headDimAlign_ + 64U) * sfaInfo_->s1Size; +} + +void SFAMlaTiling::CheckUbSpace() +{ + CalcUbBmm(); +} + +void SFAMlaTiling::CalcInnerSize(uint32_t s2Size) +{ + sInnerSize_ = 512; // 512:s2默认切分大小 + // FlashDecode时,如果S2的计算量>=256(确保切分后不小于128)但又不足以分2次计算时,则修改sInnerSize_,均分为2份进行计算,确保Nbuffer=2 + if (splitKVFlag_ && sfaInfo_->qLayout != SFALayout::TND) { + if (s2Size == 256) { // 256:s2Size的阈值,判断sInnerSize_是否切分 + sInnerSize_ = 128; // 128:sInnerSize_值为s2Size的一半,均分为2份进行计算, + } else if (s2Size > 256 && s2Size <= sInnerSize_) { // 256:s2Size的阈值,判断sInnerSize_是否切分 + sInnerSize_ = (sInnerSize_ + 1) / 2; // 2:减半 + } + } + + sInnerLoopTimes_ = (s2Size + sInnerSize_ - 1) / sInnerSize_; + sInnerSizeTail_ = s2Size - (sInnerLoopTimes_ - 1) * sInnerSize_; + if (sInnerSize_ > s2Size) { + sInnerSize_ = s2Size; + } + sInnerSizeAlign_ = Align(sInnerSize_, BYTE_BLOCK); // 元素个数按照基本块大小对齐 + + CheckUbSpace(); +} + +void SFAMlaTiling::SplitBalanced() +{ + CalcInnerSize(sfaInfo_->s2Size); + + InnerSplitParams innerSplitParams; + innerSplitParams.s1GBaseSize = sfaInfo_->gSize; + innerSplitParams.s2BaseSize = sInnerSize_; + tilingData_.innerSplitParams.set_mBaseSize(innerSplitParams.s1GBaseSize); + tilingData_.innerSplitParams.set_s2BaseSize(innerSplitParams.s2BaseSize); + + usedCoreNum_ = aicNum_; +} + +void SFAMlaTiling::Split() +{ + SplitBalanced(); +} + +void SFAMlaTiling::FillTilingBaseParamsMla() +{ + tilingData_.baseParams.set_batchSize(sfaInfo_->bSize); + tilingData_.baseParams.set_seqSize(sfaInfo_->s2Size); + tilingData_.baseParams.set_qSeqSize(sfaInfo_->s1Size); + tilingData_.baseParams.set_blockSize(sfaInfo_->blockSize); + tilingData_.baseParams.set_maxBlockNumPerBatch(sfaInfo_->maxBlockNumPerBatch); + tilingData_.baseParams.set_scaleValue(sfaInfo_->scaleValue); + tilingData_.baseParams.set_nNumOfQInOneGroup(sfaInfo_->n1Size / sfaInfo_->n2Size); + tilingData_.baseParams.set_actualLenDimsQ(sfaInfo_->actualLenDimsQ); + tilingData_.baseParams.set_actualLenDimsKV(sfaInfo_->actualLenDimsKV); + tilingData_.baseParams.set_outputLayout(static_cast(sfaInfo_->outLayout)); + tilingData_.baseParams.set_sparseMode(sfaInfo_->sparseMode); + tilingData_.baseParams.set_sparseBlockSize(sfaInfo_->sparseBlockSize); + tilingData_.baseParams.set_sparseBlockCount(sfaInfo_->sparseBlockCount); +} + +// for flash decode +void SFAMlaTiling::FillTilingSplitKVMla() +{ + tilingData_.splitKVParams.set_s2(kvSplitPart_); + + tilingData_.splitKVParams.set_accumOutSize(aicNum_ * 2 * sfaInfo_->n2Size * mBaseSize_ * headDimAlign_); // 2:每个核可能有头规约和尾规约,一共两份规约信息 + tilingData_.splitKVParams.set_logSumExpSize(2 * aicNum_ * 2 * sfaInfo_->n2Size * mBaseSize_ * // 2:每个核可能有头规约和尾规约,一共两份规约信息;sum + max + (BYTE_BLOCK / BLOCK_TABLE_ELEM_BYTE)); + + if (!splitKVFlag_) { + tilingData_.splitKVParams.set_s2(0); + } +} + +void SFAMlaTiling::FillTilingSingleCoreParamsMla() +{ + tilingData_.singleCoreParams.set_usedCoreNum(usedCoreNum_); +} + +void SFAMlaTiling::FillTilingSingleCoreTensorSizeMla() +{ + tilingData_.singleCoreTensorSize.set_mmResUbSize(mmResUbSize_); + tilingData_.singleCoreTensorSize.set_bmm2ResUbSize(bmm2ResUbSize_); +} + +void SFAMlaTiling::FillTiling() +{ + FillTilingBaseParamsMla(); + 
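+    // Note: FillTilingSplitKVMla() below zeroes splitKVParams.s2 again when splitKVFlag_ is
+    // false, so a non-zero s2 in the tiling data indicates that flash decode (split KV) is active.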
FillTilingSplitKVMla(); + FillTilingSingleCoreParamsMla(); + FillTilingSingleCoreTensorSizeMla(); +} + +uint32_t SFAMlaTiling::CalcBalanceFDParamNums(const uint32_t actCoreNum) +{ + return actCoreNum * 2 * sfaInfo_->n2Size * mBaseSize_; // 2:每个核可能有头规约和尾规约,一共两份规约信息 +} + +void SFAMlaTiling::NormalCalcFDWorkSpace(const uint32_t actCoreNum) +{ + if (splitKVFlag_) { + uint32_t accumOutSize = 0; + uint32_t logSumExpSize = 0; + uint32_t FDParamNums = CalcBalanceFDParamNums(actCoreNum); //balanceModeFlag_ ? CalcBalanceFDParamNums(actCoreNum) : CalcUnbalanceFDParamNums(); + accumOutSize = FDParamNums * headDimAlign_; + logSumExpSize = 2 * FDParamNums * (BYTE_BLOCK / sfaInfo_->blockTypeSize); // log和sum的存储空间一致,共需要2份内存 + workspaceSize_ += (accumOutSize + logSumExpSize) * sfaInfo_->blockTypeSize; + if (sfaInfo_->socVersion == platform_ascendc::SocVersion::ASCEND310P) { + workspaceSize_ += static_cast(actCoreNum) * 32; // 每个核SyncAll软同步需要32Byte记录状态 + } + } +} + +void SFAMlaTiling::CalcFDWorkSpace(const uint32_t actCoreNum) +{ + NormalCalcFDWorkSpace(actCoreNum); +} + +void SFAMlaTiling::GetWorkspaceSize() +{ + uint32_t mmResElemSize = 4; // 4:fp32 + uint32_t vec1ResElemSize = 2; // 2:fp16/bf16 + uint32_t bmm2ResElemSize = 4; // 4:fp32 + uint32_t qPreProcResElemSize = 0; // 普通场景不涉及Q预处理 + uint32_t nUpdateElemSize = 4; // 4:int32 + uint32_t softmaxSumElemSize = 4; // 4:int32 + float kvDtypeRatio = 1.0; + + workspaceSize_ = libapiSize_; + uint32_t preLoadNum = 1; + uint32_t actCoreNum = coreNum_; + preLoadNum = PRE_LOAD_NUM; + + workspaceSize_ += preLoadNum * (mmResUbSize_ * actCoreNum * mmResElemSize); + workspaceSize_ += preLoadNum * static_cast(static_cast(mmResUbSize_ * actCoreNum * vec1ResElemSize) * kvDtypeRatio); + workspaceSize_ += preLoadNum * bmm2ResUbSize_ * actCoreNum * bmm2ResElemSize; + workspaceSize_ += preLoadNum * static_cast(static_cast(qPreSizeMla_ * actCoreNum * qPreProcResElemSize) * kvDtypeRatio); + workspaceSize_ += preLoadNum * mBaseSize_ * actCoreNum * nUpdateElemSize; + workspaceSize_ += preLoadNum * mBaseSize_ * actCoreNum * softmaxSumElemSize; + // topk BlkSize == 1场景, 需要额外空间缓存离散聚合的值 + // bufNum s2Base D dRope sizeOf(half) + workspaceSize_ += 4 * 512 * (512 + 64) * 2 * actCoreNum; // 4:bufNum 512:s2Base 512:D 64:dRope 2:sizeOf(half) + // 缓存有效mte2 size的长度 份数 512B对齐的长度 sizeof(int32_t) aiv核数 + workspaceSize_ += 4 * 128 * 4 * (2 * actCoreNum); // 4:缓存有效mte2 size的长度 128:份数 4:512B对齐的长度 2:aiv核数 + + CalcFDWorkSpace(actCoreNum); +} + +void SFAMlaTiling::CalcBlockDim() +{ + auto ascendcPlatform = platform_ascendc::PlatformAscendC(sfaInfo_->platformInfo); + auto aicNum = usedCoreNum_; + auto aivNum = 2 * usedCoreNum_; + + blockDim_ = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum); + OPS_LOG_I(sfaInfo_->opName, "SFA block dim: %u aiv Num: %u aic Num: %u.", blockDim_, aivNum, aicNum); +} + +ge::graphStatus SFAMlaTiling::DoOpTiling(SFATilingInfo *sfaInfo) +{ + sfaInfo_ = sfaInfo; + if (GetPlatformInfo() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + InitParams(); + Split(); + FillTiling(); + CalcBlockDim(); + GetWorkspaceSize(); + GenTilingKey(); + + if ((SetBlockDim(blockDim_) != ge::GRAPH_SUCCESS) || + (SetTilingKey(tilingKey_) != ge::GRAPH_SUCCESS) || + (SetWorkspaceSize(workspaceSize_) != ge::GRAPH_SUCCESS) || + (SetTilingData(tilingData_) != ge::GRAPH_SUCCESS)) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus TilingSparseFlashAttention(gert::TilingContext *context) +{ + SFATilingInfo sfaInfo; + SFAInfoParser sfaInfoParser(context); + 
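+    // Tiling entry point: SFAInfoParser extracts shapes/attrs from the TilingContext into
+    // SFATilingInfo, SFATilingCheck validates them, and SFAMlaTiling then computes the block
+    // dim, tiling key, workspace size and tiling data.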
+    if (sfaInfoParser.Parse(sfaInfo) != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+
+    SFATilingCheck tilingChecker(sfaInfo);
+    if (tilingChecker.Process() != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+
+    SFAMlaTiling tiling(context);
+    return tiling.DoOpTiling(&sfaInfo);
+}
+
+ge::graphStatus TilingPrepareForSparseFlashAttention(gert::TilingParseContext *context)
+{
+    (void)context;
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::GetExpectedShape(gert::Shape &shapeExpected,
+    const SFATilingShapeCompareParam &param, const SFALayout &layout) const
+{
+    if (layout == SFALayout::BSND) {
+        shapeExpected = gert::Shape({param.B, param.S, param.N, param.D});
+    } else if (layout == SFALayout::TND) {
+        shapeExpected = gert::Shape({param.T, param.N, param.D});
+    } else if (layout == SFALayout::PA_BSND) {
+        shapeExpected = gert::Shape({param.Bn, param.Bs, param.N, param.D});
+    } else {
+        OPS_LOG_E(opName_, "layout %s is unsupported", SFALayoutToSerialString(layout).c_str());
+        return ge::GRAPH_FAILED;
+    }
+    if (shapeExpected.GetDim(0) == 0) {
+        OPS_LOG_E(opName_, "expected shape is %s, the first dim should not be 0.", GetShapeStr(shapeExpected).c_str());
+        return ge::GRAPH_PARAM_INVALID;
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CompareShape(SFATilingShapeCompareParam &param,
+    const gert::Shape &shape, const SFALayout &layout, const std::string &name) const
+{
+    gert::Shape shapeExpected;
+    if (GetExpectedShape(shapeExpected, param, layout) != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+
+    if (shape.GetDimNum() != shapeExpected.GetDimNum()) {
+        OPS_LOG_E(opName_,
+                  "%s dimension is %zu, expected dimension is %zu.",
+                  name.c_str(), shape.GetDimNum(), shapeExpected.GetDimNum());
+        return ge::GRAPH_FAILED;
+    }
+
+    for (size_t i = 0; i < shape.GetDimNum(); i++) {
+        if (shape.GetDim(i) != shapeExpected.GetDim(i)) {
+            OPS_LOG_E(opName_, "%s layout is %s, shape is %s, expected shape is %s.",
+                      name.c_str(), SFALayoutToSerialString(layout).c_str(),
+                      GetShapeStr(shape).c_str(), GetShapeStr(shapeExpected).c_str());
+            return ge::GRAPH_FAILED;
+        }
+    }
+
+    return ge::GRAPH_SUCCESS;
+}
+
+void SFATilingCheck::LogErrorDtypeSupport(const std::vector<ge::DataType> &expectDtypeList,
+    const ge::DataType &actualDtype, const std::string &name) const
+{
+    std::ostringstream oss;
+    for (size_t i = 0; i < expectDtypeList.size(); ++i) {
+        oss << SFADataTypeToSerialString(expectDtypeList[i]);
+        if (i < expectDtypeList.size() - 1) {
+            oss << ", ";
+        }
+    }
+    OPS_LOG_E(opName_, "Tensor %s only supports dtype %s, but got %s",
+              name.c_str(), oss.str().c_str(), SFADataTypeToSerialString(actualDtype).c_str());
+}
+
+ge::graphStatus SFATilingCheck::CheckDtypeSupport(const gert::CompileTimeTensorDesc *desc,
+    const std::string &name) const
+{
+    if (desc != nullptr) {
+        const auto& it = DTYPE_SUPPORT_MAP.find(name);
+        OPS_ERR_IF(it == DTYPE_SUPPORT_MAP.end(),
+                   OPS_LOG_E(opName_, "%s datatype support list should be specified in DTYPE_SUPPORT_MAP", name.c_str()),
+                   return ge::GRAPH_FAILED);
+        auto &expectDtypeList = it->second;
+        OPS_ERR_IF(std::find(
+                       expectDtypeList.begin(), expectDtypeList.end(), desc->GetDataType()) == expectDtypeList.end(),
+                   LogErrorDtypeSupport(expectDtypeList, desc->GetDataType(), name),
+                   return ge::GRAPH_FAILED);
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+template <typename T>
+void SFATilingCheck::LogErrorNumberSupport(const std::vector<T> &expectNumberList,
+    const T &actualValue, const std::string &name, const std::string subName) const
+{
+    std::ostringstream oss;
+    for
(size_t i = 0; i < expectNumberList.size(); ++i) { + oss << std::to_string(expectNumberList[i]); + if (i < expectNumberList.size() - 1) { + oss << ", "; + } + } + + OPS_LOG_E(opName_, "%s %s only supports %s, but got %s", + name.c_str(), subName.c_str(), oss.str().c_str(), std::to_string(actualValue).c_str()); +} + +template +void SFATilingCheck::LogErrorDimNumSupport(const std::vector &expectNumberList, + const T &actualValue, const std::string &name) const +{ + LogErrorNumberSupport(expectNumberList, actualValue, name, "dimension"); +} + +ge::graphStatus SFATilingCheck::CheckDimNumInLayoutSupport(const SFALayout &layout, + const gert::StorageShape *shape, const std::string &name) const +{ + const auto& dimIt = SFA_LAYOUT_DIM_MAP.find(layout); + OPS_ERR_IF(shape->GetStorageShape().GetDimNum() != dimIt->second, + OPS_LOG_E(opName_, "When layout is %s, %s dimension should be %zu, but it's %zu", + SFALayoutToSerialString(layout).c_str(), name.c_str(), dimIt->second, + shape->GetStorageShape().GetDimNum()), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckDimNumSupport(const gert::StorageShape *shape, + const std::vector &expectDimNumList, const std::string &name) const +{ + if (shape == nullptr) { + return ge::GRAPH_SUCCESS; + } + + if (std::find(expectDimNumList.begin(), expectDimNumList.end(), + shape->GetStorageShape().GetDimNum()) == expectDimNumList.end()) { + LogErrorDimNumSupport(expectDimNumList, shape->GetStorageShape().GetDimNum(), name); + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + + +void SFATilingCheck::LogErrorLayoutSupport(const std::vector &expectLayoutList, + const SFALayout &actualLayout, const std::string &name) const +{ + std::ostringstream oss; + for (size_t i = 0; i < expectLayoutList.size(); ++i) { + oss << SFALayoutToSerialString(expectLayoutList[i]); + if (i < expectLayoutList.size() - 1) { + oss << ", "; + } + } + OPS_LOG_E(opName_, "Tensor %s only supports layout %s, but got %s", + name.c_str(), oss.str().c_str(), SFALayoutToSerialString(actualLayout).c_str()); +} + +ge::graphStatus SFATilingCheck::CheckLayoutSupport(const SFALayout &actualLayout, const std::string &name) const +{ + const auto& it = LAYOUT_SUPPORT_MAP.find(name); + OPS_ERR_IF(it == LAYOUT_SUPPORT_MAP.end(), + OPS_LOG_E(opName_, "%s layout support list should be specify in LAYOUT_SUPPORT_MAP", name.c_str()), + return ge::GRAPH_FAILED); + auto &expectLayoutList = it->second; + OPS_ERR_IF(std::find( + expectLayoutList.begin(), expectLayoutList.end(), actualLayout) == expectLayoutList.end(), + LogErrorLayoutSupport(expectLayoutList, actualLayout, name), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckSingleParaQuery() const +{ + const std::vector queryDimNumList = {DIM_NUM_THREE, DIM_NUM_FOUR}; + if (ge::GRAPH_SUCCESS != CheckDtypeSupport(opParamInfo_.query.desc, QUERY_NAME) || + ge::GRAPH_SUCCESS != CheckLayoutSupport(qLayout_, QUERY_NAME) || + ge::GRAPH_SUCCESS != CheckDimNumSupport(opParamInfo_.query.shape, queryDimNumList, QUERY_NAME) || + ge::GRAPH_SUCCESS != CheckDimNumInLayoutSupport(qLayout_, opParamInfo_.query.shape, QUERY_NAME)) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckSingleParaKey() const +{ + const std::vector keyDimNumList = {DIM_NUM_FOUR, DIM_NUM_THREE}; + if (ge::GRAPH_SUCCESS != CheckDtypeSupport(opParamInfo_.key.desc, KEY_NAME) || + ge::GRAPH_SUCCESS != CheckLayoutSupport(kvLayout_, KEY_NAME) || + 
        ge::GRAPH_SUCCESS != CheckDimNumSupport(opParamInfo_.key.shape, keyDimNumList, KEY_NAME) ||
+        ge::GRAPH_SUCCESS != CheckDimNumInLayoutSupport(kvLayout_, opParamInfo_.key.shape, KEY_NAME)) {
+        return ge::GRAPH_FAILED;
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckSingleParaNumHeads() const
+{
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckSingleParaKvHeadNums() const
+{
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckSingleParaSparseMode() const
+{
+    OPS_ERR_IF((*opParamInfo_.sparseMode != 3 && *opParamInfo_.sparseMode != 0),
+               OPS_LOG_E(opName_, "sparseMode must be 0 or 3, but got: %ld.", *opParamInfo_.sparseMode),
+               return ge::GRAPH_FAILED);
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckSingleParaSparseBlockSize() const
+{
+    OPS_ERR_IF((*opParamInfo_.sparseBlockSize <= 0),
+               OPS_LOG_E(opName_, "sparseBlockSize should be greater than 0, but got: %ld.", *opParamInfo_.sparseBlockSize),
+               return ge::GRAPH_FAILED);
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckSingleParaSparseIndices() const
+{
+    if (ge::GRAPH_SUCCESS != CheckDtypeSupport(opParamInfo_.sparseIndices.desc, SPARSE_INDICES_NAME)) {
+        return ge::GRAPH_FAILED;
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckSinglePara() const
+{
+    if (ge::GRAPH_SUCCESS != CheckSingleParaQuery() ||
+        ge::GRAPH_SUCCESS != CheckSingleParaKey() ||
+        ge::GRAPH_SUCCESS != CheckSingleParaSparseIndices() ||
+        ge::GRAPH_SUCCESS != CheckSingleParaNumHeads() ||
+        ge::GRAPH_SUCCESS != CheckSingleParaKvHeadNums() ||
+        ge::GRAPH_SUCCESS != CheckSingleParaSparseMode() ||
+        ge::GRAPH_SUCCESS != CheckSingleParaSparseBlockSize()) {
+        return ge::GRAPH_FAILED;
+    }
+
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckRopeExistence()
+{
+    OPS_ERR_IF((opParamInfo_.queryRope.tensor != nullptr && opParamInfo_.keyRope.tensor == nullptr),
+               OPS_LOG_E(opName_, "KeyRope is null, but queryRope exists; they must both be null or both exist."),
+               return ge::GRAPH_FAILED);
+    OPS_ERR_IF((opParamInfo_.queryRope.tensor == nullptr && opParamInfo_.keyRope.tensor != nullptr),
+               OPS_LOG_E(opName_, "QueryRope is null, but keyRope exists; they must both be null or both exist."),
+               return ge::GRAPH_FAILED);
+    OPS_ERR_IF(opParamInfo_.keyRope.desc == nullptr || opParamInfo_.queryRope.desc == nullptr,
+               OPS_LOG_E(opName_, "In Mla situation, desc of keyRope and queryRope should not be null"),
+               return ge::GRAPH_FAILED);
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckExists(const void *pointer, const std::string &name) const
+{
+    OPS_ERR_IF(pointer == nullptr,
+               OPS_LOG_E(opName_, "%s should not be null", name.c_str()),
+               return ge::GRAPH_FAILED);
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckNotExists(const void *pointer, const std::string &name) const
+{
+    OPS_ERR_IF(pointer != nullptr,
+               OPS_LOG_E(opName_, "%s should be null", name.c_str()),
+               return ge::GRAPH_FAILED);
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckExistsByMap(const std::map<std::string, const void *> &paramMap) const
+{
+    for (const auto& kv : paramMap) {
+        if (CheckExists(kv.second, kv.first) != ge::GRAPH_SUCCESS) {
+            return ge::GRAPH_FAILED;
+        }
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckNotExistsByMap(const std::map<std::string, const void *> &paramMap) const
+{
+    for (const auto& kv : paramMap) {
+        if (CheckNotExists(kv.second, kv.first) != ge::GRAPH_SUCCESS) {
+            return ge::GRAPH_FAILED;
+        }
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckExistenceByMap(std::map<std::string, const void *> &existMap,
+    std::map<std::string, const void *> &notExistMap) const
+{
+    if (CheckExistsByMap(existMap) != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    if (CheckNotExistsByMap(notExistMap) != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+template <typename T>
+ge::graphStatus SFATilingCheck::CheckAttrValueByMap(std::map<std::string, std::pair<const T *, T>> &attrMap) const
+{
+    for (auto const &kv : attrMap) {
+        const std::string &name = kv.first;
+        const std::pair<const T *, T> &pointerValuePair = kv.second;
+        if (pointerValuePair.first == nullptr) {
+            OPS_LOG_E(opName_, "Attr %s should not be nullptr", name.c_str());
+            return ge::GRAPH_FAILED;
+        }
+
+        if (*(pointerValuePair.first) != pointerValuePair.second) {
+            std::ostringstream ossExpect;
+            ossExpect << std::to_string(pointerValuePair.second);
+            std::ostringstream ossActual;
+            ossActual << std::to_string(*(pointerValuePair.first));
+            OPS_LOG_E(opName_,
+                      "%s value should be %s, but got %s",
+                      name.c_str(),
+                      ossExpect.str().c_str(),
+                      ossActual.str().c_str());
+            return ge::GRAPH_FAILED;
+        }
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckParaExistenceMlaNoquant() const
+{
+    if (kvStorageMode_ != KvStorageMode::PAGE_ATTENTION) {
+        return ge::GRAPH_SUCCESS;
+    }
+    std::map<std::string, const void *> mlaNoquantParamExistMap = {
+        {"actualSeqLengths", opParamInfo_.actualSeqLengths.tensor},
+        {"blockTable", opParamInfo_.blockTable.tensor},
+    };
+    std::map<std::string, const void *> mlaNoquantParamNotExistMap = {};
+    if (CheckExistenceByMap(mlaNoquantParamExistMap, mlaNoquantParamNotExistMap) != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus SFATilingCheck::CheckParaExistenceMla() const
+{
+    return CheckParaExistenceMlaNoquant();
+}
+
+ge::graphStatus SFATilingCheck::CheckParaExistence()
+{
+    if (ge::GRAPH_SUCCESS != CheckRopeExistence()) {
+        return ge::GRAPH_FAILED;
+    }
+
+    return CheckParaExistenceMla();
+}
+
+ge::graphStatus SFATilingCheck::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor,
+    const SFALayout &layoutQuery, const std::string &name)
+{
+    if (tensor == nullptr) {
+        OPS_LOG_E(opName_, "when layout of query is %s, %s must be provided.",
+                  SFALayoutToSerialString(layoutQuery).c_str(), name.c_str());
+        return ge::GRAPH_FAILED;
+    }
+    int64_t shapeSize = tensor->GetShapeSize();
+    if (shapeSize <= 0) {
+        OPS_LOG_E(opName_, "the shape size of %s is %ld, it should be greater than 0.",
+                  name.c_str(), shapeSize);
+        return ge::GRAPH_FAILED;
+    }
+    size = static_cast<uint32_t>(shapeSize);
+    return ge::GRAPH_SUCCESS;
+}
+
+void SFATilingCheck::SetSFAShapeCompare()
+{
+    queryShapeCmp_ = opParamInfo_.query.shape->GetStorageShape();
+    topkShapeCmp_ = opParamInfo_.sparseIndices.shape->GetStorageShape();
+    keyShapeCmp_ = opParamInfo_.key.shape->GetStorageShape();
+    valueShapeCmp_ = opParamInfo_.value.shape->GetStorageShape();
+    attenOutShapeCmp_ = opParamInfo_.attenOut.shape->GetStorageShape();
+    queryRopeShapeCmp_ = opParamInfo_.queryRope.tensor->GetStorageShape();
+    keyRopeShapeCmp_ = opParamInfo_.keyRope.tensor->GetStorageShape();
+}
+
+ge::graphStatus SFATilingCheck::CheckBlockTable() const
+{
+    if (kvStorageMode_ != KvStorageMode::PAGE_ATTENTION) {
+        OPS_ERR_IF(opParamInfo_.blockTable.tensor != nullptr,
+                   OPS_LOG_E(opName_, "when the layout_kv is %s, %s should be null",
+                             SFALayoutToSerialString(kvLayout_).c_str(), BLOCK_TABLE_NAME.c_str()),
+                   return ge::GRAPH_FAILED);
+        return ge::GRAPH_SUCCESS;
+    }
+
+    uint32_t blockTableBatch =
opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(0); + OPS_ERR_IF(blockTableBatch != bSize_, + OPS_LOG_E(opName_, "%s's first dimension(%u) should be equal to batch size(%u)", + BLOCK_TABLE_NAME.c_str(), blockTableBatch, bSize_), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckDTypeConsistency(const ge::DataType &actualDtype, + const ge::DataType &expectDtype, const std::string &name) const +{ + if (actualDtype != expectDtype) { + OPS_LOG_E(opName_, "%s dtype should be %s, but it's %s.", name.c_str(), + SFADataTypeToSerialString(expectDtype).c_str(), + SFADataTypeToSerialString(actualDtype).c_str()); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckQRopeShape() +{ + SFATilingShapeCompareParam shapeParams; + shapeParams.B = bSize_; + shapeParams.N = n1Size_; + shapeParams.S = s1Size_; + shapeParams.D = ropeHeadDim_; + shapeParams.T = qTSize_; + return CompareShape(shapeParams, queryRopeShapeCmp_, qLayout_, QUERY_ROPE_NAME); +} + +ge::graphStatus SFATilingCheck::CheckTopkShape() +{ + SFATilingShapeCompareParam shapeParams; + shapeParams.B = bSize_; + shapeParams.N = n2Size_; + shapeParams.S = s1Size_; + shapeParams.D = sparseBlockCount_; + shapeParams.T = qTSize_; + return CompareShape(shapeParams, topkShapeCmp_, topkLayout_, SPARSE_INDICES_NAME); +} + +ge::graphStatus SFATilingCheck::CheckAttenOutShape() +{ + SFATilingShapeCompareParam shapeParams; + shapeParams.B = bSize_; + shapeParams.N = n1Size_; + shapeParams.S = s1Size_; + shapeParams.D = vHeadDim_; + shapeParams.T = qTSize_; + if (CompareShape(shapeParams, attenOutShapeCmp_, outLayout_, ATTEN_OUT_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckAttenOut() +{ + if (ge::GRAPH_SUCCESS != CheckDTypeConsistency(opParamInfo_.attenOut.desc->GetDataType(), + inputQType_, ATTEN_OUT_NAME) || + ge::GRAPH_SUCCESS != CheckAttenOutShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckQRope() +{ + if (ge::GRAPH_SUCCESS != CheckDTypeConsistency(opParamInfo_.queryRope.desc->GetDataType(), + inputQType_, QUERY_ROPE_NAME) || + ge::GRAPH_SUCCESS != CheckQRopeShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckTopK() +{ + if (ge::GRAPH_SUCCESS != CheckTopkShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckVAndKRopeShapeForBatchContinuous() +{ + SFATilingShapeCompareParam shapeParams; + shapeParams.B = bSize_; + shapeParams.N = n2Size_; + shapeParams.S = s2Size_; + shapeParams.T = kvTSize_; + shapeParams.D = qkHeadDim_; + if (CompareShape(shapeParams, keyShapeCmp_, kvLayout_, KEY_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + shapeParams.D = vHeadDim_; + if (CompareShape(shapeParams, valueShapeCmp_, kvLayout_, VALUE_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + shapeParams.D = ropeHeadDim_; + if (CompareShape(shapeParams, keyRopeShapeCmp_, kvLayout_, KEY_ROPE_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +uint32_t SFATilingCheck::GetTypeSize(ge::DataType dtype) const +{ + uint32_t typeSize = NUM_BYTES_FLOAT16; + switch (dtype) { + case ge::DT_FLOAT16: + typeSize = NUM_BYTES_FLOAT16; + break; + case ge::DT_BF16: + typeSize = NUM_BYTES_BF16; + break; + default: + typeSize = NUM_BYTES_FLOAT16; + } + 
return typeSize; +} + +ge::graphStatus SFATilingCheck::CheckVAndKRopeShapeForPageAttention() +{ + int64_t blockNum = keyShapeCmp_.GetDim(0); + OPS_ERR_IF(blockNum <= 0, + OPS_LOG_E(opName_, "The first dim(%ld) of key should be greater than 0", blockNum), + return ge::GRAPH_FAILED); + SFATilingShapeCompareParam shapeParams; + shapeParams.Bn = blockNum; + shapeParams.N = n2Size_; + shapeParams.Bs = blockSize_; + shapeParams.D = vHeadDim_; + shapeParams.T = kvTSize_; + if (CompareShape(shapeParams, valueShapeCmp_, kvLayout_, VALUE_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + shapeParams.D = ropeHeadDim_; + if (CompareShape(shapeParams, keyRopeShapeCmp_, kvLayout_, KEY_ROPE_NAME) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckVAndKRopeShape() +{ + if (kvStorageMode_ == KvStorageMode::BATCH_CONTINUOUS) { + return CheckVAndKRopeShapeForBatchContinuous(); + } + + if (kvStorageMode_ == KvStorageMode::PAGE_ATTENTION) { + return CheckVAndKRopeShapeForPageAttention(); + } + + OPS_LOG_E(opName_, "storage mode of key and value is %u, it is incorrect.", static_cast(kvStorageMode_)); + return ge::GRAPH_FAILED; +} + +ge::graphStatus SFATilingCheck::CheckVAndKRope() +{ + if (ge::GRAPH_SUCCESS != CheckDTypeConsistency(opParamInfo_.value.desc->GetDataType(), + inputKvType_, VALUE_NAME) || + ge::GRAPH_SUCCESS != CheckDTypeConsistency(opParamInfo_.keyRope.desc->GetDataType(), + inputKvType_, KEY_ROPE_NAME) || ge::GRAPH_SUCCESS != CheckVAndKRopeShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLensQ() +{ + if (ge::GRAPH_SUCCESS != CheckActualSeqLensQDType() || + ge::GRAPH_SUCCESS != CheckActualSeqLensQShape()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLensQDType() +{ + if (opParamInfo_.actualSeqLengthsQ.tensor == nullptr) { + return ge::GRAPH_SUCCESS; + } + if (opParamInfo_.actualSeqLengthsQ.desc == nullptr) { + OPS_LOG_E(opName_, "actualSeqLengthsQ is not empty," + "but actualSeqLengthsQ's dtype is nullptr."); + return ge::GRAPH_FAILED; + } + if (opParamInfo_.actualSeqLengthsQ.desc->GetDataType() != ge::DT_INT32) { + OPS_LOG_E(opName_, "actualSeqLengthsQ's dtype is %s, it should be DT_INT32.", + SFADataTypeToSerialString(opParamInfo_.actualSeqLengthsQ.desc->GetDataType()).c_str()); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLensQShape() +{ + if (opParamInfo_.actualSeqLengthsQ.tensor == nullptr) { + return ge::GRAPH_SUCCESS; + } + uint32_t shapeSize = 0; + if (GetActualSeqLenSize(shapeSize, opParamInfo_.actualSeqLengthsQ.tensor, qLayout_, "actualSeqLengthsQ") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + if (shapeSize != bSize_) { + OPS_LOG_E(opName_, "actualSeqLengthsQ shape size is %u, it should be equal to batch size[%u]", + shapeSize, bSize_); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLens() +{ + if (std::string(opParamInfo_.layoutKV) == "TND" && opParamInfo_.actualSeqLengths.tensor == nullptr) { + OPS_LOG_E(opName_, + "when the layout of key and value is TND, " + "the actualSeqLengths of key and value shoule not be empty."); + return ge::GRAPH_PARAM_INVALID; + } + if (ge::GRAPH_SUCCESS != CheckActualSeqLensDType() || + ge::GRAPH_SUCCESS != CheckActualSeqLensShape()) { + return ge::GRAPH_FAILED; + } + return 
ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLensDType() +{ + if (opParamInfo_.actualSeqLengths.tensor == nullptr) { + return ge::GRAPH_SUCCESS; + } + if (opParamInfo_.actualSeqLengths.desc == nullptr) { + OPS_LOG_E(opName_, "actualSeqLengths is not empty," + "but actualSeqLengths's dtype is nullptr."); + return ge::GRAPH_FAILED; + } + if (opParamInfo_.actualSeqLengths.desc->GetDataType() != ge::DT_INT32) { + OPS_LOG_E(opName_, "actualSeqLengths's dtype is %s, it should be DT_INT32.", + SFADataTypeToSerialString(opParamInfo_.actualSeqLengths.desc->GetDataType()).c_str()); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckActualSeqLensShape() +{ + if (opParamInfo_.actualSeqLengths.tensor == nullptr) { + return ge::GRAPH_SUCCESS; + } + uint32_t shapeSize = 0; + if(GetActualSeqLenSize(shapeSize, opParamInfo_.actualSeqLengths.tensor, kvLayout_, "actualSeqLengths") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + if (shapeSize != bSize_) { + OPS_LOG_E(opName_, "actualSeqLengths shape size is %u, it should be equal to batch size[%u].", + shapeSize, bSize_); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckMultiParaConsistency() +{ + SetSFAShapeCompare(); + if (ge::GRAPH_SUCCESS != CheckVAndKRope() || + ge::GRAPH_SUCCESS != CheckQRope() || + ge::GRAPH_SUCCESS != CheckTopK() || + ge::GRAPH_SUCCESS != CheckAttenOut() || + ge::GRAPH_SUCCESS != CheckActualSeqLensQ() || + ge::GRAPH_SUCCESS != CheckActualSeqLens() || + ge::GRAPH_SUCCESS != CheckBlockTable()) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMlaNoQuantShape() const +{ + OPS_ERR_IF(bSize_ <= 0, + OPS_LOG_E(opName_, "batch_size should be greater than 0, but got %u", bSize_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(qTSize_ <= 0 && (qLayout_ == SFALayout::TND), + OPS_LOG_E(opName_, "T_size of query should be greater than 0, but got %u", qTSize_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(n1Size_ <= 0, + OPS_LOG_E(opName_, "q_head_num should be greater than 0, but got %u", n1Size_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(n2Size_ != 1, + OPS_LOG_E(opName_, "kv_head_num should be 1, but got %u", n2Size_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(n1Size_ % n2Size_ != 0, + OPS_LOG_E(opName_, "q_head_num(%u) must be divisible by kv_head_num(%u)", n1Size_, n2Size_), + return ge::GRAPH_FAILED); + + std::vector gSizeSupportList = {1, 2, 4, 8, 16, 32, 64, 128}; + OPS_ERR_IF(std::find(gSizeSupportList.begin(), gSizeSupportList.end(), gSize_) == gSizeSupportList.end(), + OPS_LOG_E(opName_, "group num should be in 1, 2, 4, 8, 16, 32, 64, 128, but got %u", gSize_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(qkHeadDim_ != 512, + OPS_LOG_E(opName_, "qk_head_dim only support 512, but got %u", qkHeadDim_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(qkHeadDim_ != vHeadDim_, + OPS_LOG_E(opName_, "qk_head_dim[%u] should be equal to v_head_dim[%u]", qkHeadDim_, vHeadDim_), + return ge::GRAPH_FAILED); + + OPS_ERR_IF(ropeHeadDim_ != 64, + OPS_LOG_E(opName_, "rope_head_dim should be 64, but got %u", ropeHeadDim_), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMlaNoQuantLayout() const +{ + const std::vector layoutSupportList = { + "BSND", + "TND" + }; + std::string layoutQuery = opParamInfo_.layoutQuery; + OPS_ERR_IF(std::find(layoutSupportList.begin(), layoutSupportList.end(), layoutQuery) == 
layoutSupportList.end(), + OPS_LOG_E(opName_, "layoutQuery only supports BSND/TND, but got %s", layoutQuery.c_str()), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMlaNoQuantDtype() const +{ + OPS_ERR_IF(inputQType_ != ge::DT_BF16 && inputQType_ != ge::DT_FLOAT16, + OPS_LOG_E(opName_, "query dtype only support %s and %s, but got %s", + SFADataTypeToSerialString(ge::DT_BF16).c_str(), SFADataTypeToSerialString(ge::DT_FLOAT16).c_str(), + SFADataTypeToSerialString(inputQType_).c_str()), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMlaNoquantPa() const +{ + if (kvStorageMode_ != KvStorageMode::PAGE_ATTENTION) { + return ge::GRAPH_SUCCESS; + } + + OPS_ERR_IF(blockSize_ <= 0 || blockSize_ > static_cast(MAX_BLOCK_SIZE), + OPS_LOG_E(opName_, "when page attention is enabled, block_size(%d) should be in range (0, %u].", + blockSize_, MAX_BLOCK_SIZE), return ge::GRAPH_FAILED); + + OPS_ERR_IF(blockSize_ % 16 > 0, + OPS_LOG_E(opName_, "when page attention is enabled, block_size(%d) should be 16-aligned.", + blockSize_), return ge::GRAPH_FAILED); + + OPS_ERR_IF(blockSize_ % sparseBlockSize_ > 0, + OPS_LOG_E(opName_, "when page attention is enabled, block_size(%d) must be divided by sparse_block_size(%d), but now the remainder is %d.", + blockSize_, sparseBlockSize_, blockSize_ % sparseBlockSize_), return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMlaNoquant() const +{ + if (ge::GRAPH_SUCCESS != CheckFeatureMlaNoQuantShape() || + ge::GRAPH_SUCCESS != CheckFeatureMlaNoQuantLayout() || + ge::GRAPH_SUCCESS != CheckFeatureMlaNoQuantDtype() || + ge::GRAPH_SUCCESS != CheckFeatureMlaNoquantPa()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFATilingCheck::CheckFeatureMla() const +{ + return CheckFeatureMlaNoquant(); +} + +ge::graphStatus SFATilingCheck::CheckFeature() const +{ + return CheckFeatureMla(); +} + +void SFATilingCheck::Init() +{ + opName_ = sfaInfo_.opName; + platformInfo_ = sfaInfo_.platformInfo; + opParamInfo_ = sfaInfo_.opParamInfo; + socVersion_ = sfaInfo_.socVersion; + + bSize_ = sfaInfo_.bSize; + n1Size_ = sfaInfo_.n1Size; + n2Size_ = sfaInfo_.n2Size; + s1Size_ = sfaInfo_.s1Size; + s2Size_ = sfaInfo_.s2Size; + gSize_ = sfaInfo_.gSize; + qkHeadDim_ = sfaInfo_.qkHeadDim; + vHeadDim_ = sfaInfo_.vHeadDim; + ropeHeadDim_ = sfaInfo_.ropeHeadDim; + maxBlockNumPerBatch_ = sfaInfo_.maxBlockNumPerBatch; + qTSize_ = sfaInfo_.qTSize; + kvTSize_ = sfaInfo_.kvTSize; + blockSize_ = sfaInfo_.blockSize; + sparseBlockCount_ = sfaInfo_.sparseBlockCount; + sparseBlockSize_ = sfaInfo_.sparseBlockSize; + + inputQType_ = sfaInfo_.inputQType; + inputKvType_ = sfaInfo_.inputKvType; + inputQRopeType_ = sfaInfo_.inputQRopeType; + inputKRopeType_ = sfaInfo_.inputKRopeType; + outputType_ = sfaInfo_.outputType; + + qLayout_ = sfaInfo_.qLayout; + topkLayout_ = sfaInfo_.topkLayout; + kvLayout_ = sfaInfo_.kvLayout; + outLayout_ = sfaInfo_.outLayout; + + kvStorageMode_ = sfaInfo_.kvStorageMode; + l2CacheSize_ = sfaInfo_.l2CacheSize; +} + +ge::graphStatus SFATilingCheck::Process() +{ + Init(); + if (CheckSinglePara() != ge::GRAPH_SUCCESS || + CheckParaExistence() != ge::GRAPH_SUCCESS || + CheckFeature() != ge::GRAPH_SUCCESS || + CheckMultiParaConsistency() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +bool SFAInfoParser::HasAxis(const SFAAxis &axis, const SFALayout &layout, 
const gert::Shape &shape) const +{ + const auto& layoutIt = SFA_LAYOUT_AXIS_MAP.find(layout); + if (layoutIt == SFA_LAYOUT_AXIS_MAP.end()) { + return false; + } + + const std::vector& axes = layoutIt->second; + const auto& axisIt = std::find(axes.begin(), axes.end(), axis); + if (axisIt == axes.end()) { + return false; + } + const auto& dimIt = SFA_LAYOUT_DIM_MAP.find(layout); + if (dimIt == SFA_LAYOUT_DIM_MAP.end() || dimIt->second != shape.GetDimNum()) { + return false; + } + return true; +} + +size_t SFAInfoParser::GetAxisIdx(const SFAAxis &axis, const SFALayout &layout) const +{ + const std::vector& axes = SFA_LAYOUT_AXIS_MAP.find(layout)->second; + const auto& axisIt = std::find(axes.begin(), axes.end(), axis); + return std::distance(axes.begin(), axisIt); +} + +uint32_t SFAInfoParser::GetAxisNum(const gert::Shape &shape, const SFAAxis &axis,const SFALayout &layout) const +{ + return HasAxis(axis, layout, shape) ? shape.GetDim(GetAxisIdx(axis, layout)) : invalidDimValue_; +} + +ge::graphStatus SFAInfoParser::CheckRequiredInOutExistence() const +{ + OPS_ERR_IF(opParamInfo_.query.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor query is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.query.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor query is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor k is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.key.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor k is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.value.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor value is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.value.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor value is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.sparseIndices.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor sparseIndices is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.sparseIndices.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor sparseIndices is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.attenOut.shape == nullptr, OPS_LOG_E(opName_, "Shape of tensor output is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.attenOut.desc == nullptr, OPS_LOG_E(opName_, "Desc of tensor output is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.queryRope.tensor == nullptr, OPS_LOG_E(opName_, "Shape of queryRope is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.queryRope.desc == nullptr, OPS_LOG_E(opName_, "Desc of queryRope is nullptr"), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::CheckRequiredAttrExistence() const +{ + OPS_ERR_IF(opParamInfo_.layoutQuery == nullptr, OPS_LOG_E(opName_, "attr layoutQuery is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.layoutKV == nullptr, OPS_LOG_E(opName_, "attr layoutKV is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.sparseBlockSize == nullptr, OPS_LOG_E(opName_, "attr sparseBlockSize is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.scaleValue == nullptr, OPS_LOG_E(opName_, "attr scaleValue is nullptr"), + return ge::GRAPH_FAILED); + OPS_ERR_IF(opParamInfo_.sparseMode == nullptr, OPS_LOG_E(opName_, "attr sparseMode is nullptr"), + return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::CheckRequiredParaExistence() const +{ + if 
(CheckRequiredInOutExistence() != ge::GRAPH_SUCCESS || + CheckRequiredAttrExistence() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + SFALayout &layout, const std::string &name) +{ + if ((tensor == nullptr)) { + OPS_LOG_E(opName_, "when layout of query is %s, %s must be provided.", + SFALayoutToSerialString(layout).c_str(), name.c_str()); + return ge::GRAPH_FAILED; + } + int64_t shapeSize = tensor->GetShapeSize(); + if (shapeSize <= 0) { + OPS_LOG_E(opName_, "the shape size of %s is %ld, it should be greater than 0.", + name.c_str(), shapeSize); + return ge::GRAPH_FAILED; + } + size = static_cast(shapeSize); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetActualSeqLenQSize(uint32_t &size) +{ + return GetActualSeqLenSize(size, opParamInfo_.actualSeqLengthsQ.tensor, qLayout_, "actualSeqLengthsQ"); +} + +ge::graphStatus SFAInfoParser::GetOpName() +{ + if (context_->GetNodeName() == nullptr) { + OPS_LOG_E("SparseFlashAttention", "opName got from TilingContext is nullptr"); + return ge::GRAPH_FAILED; + } + opName_ = context_->GetNodeName(); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetNpuInfo() +{ + platformInfo_ = context_->GetPlatformInfo(); + OPS_ERR_IF(platformInfo_ == nullptr, + OPS_REPORT_VECTOR_INNER_ERR(opName_, "GetPlatformInfo is nullptr."), return ge::GRAPH_FAILED); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo_); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint32_t aicNum = ascendcPlatform.GetCoreNumAic(); + OPS_ERR_IF(aicNum == 0 || aivNum == 0, + OPS_REPORT_VECTOR_INNER_ERR(opName_, "num of core obtained is 0."), return GRAPH_FAILED); + + socVersion_ = ascendcPlatform.GetSocVersion(); + if (socVersion_ != platform_ascendc::SocVersion::ASCEND910B) { + OPS_REPORT_VECTOR_INNER_ERR(opName_, "SOC Version[%d] is not support.", (int32_t)socVersion_); + return GRAPH_FAILED; + } + + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L2, l2CacheSize_); + + return ge::GRAPH_SUCCESS; +} + +void SFAInfoParser::GetOptionalInputParaInfo() +{ + opParamInfo_.blockTable.tensor = context_->GetOptionalInputTensor(BLOCK_TABLE_INPUT_INDEX); + opParamInfo_.actualSeqLengthsQ.tensor = context_->GetOptionalInputTensor(ACT_SEQ_LEN_Q_INPUT_INDEX); + opParamInfo_.actualSeqLengthsQ.desc = context_->GetOptionalInputDesc(ACT_SEQ_LEN_Q_INPUT_INDEX); + opParamInfo_.actualSeqLengths.tensor = context_->GetOptionalInputTensor(ACT_SEQ_LEN_KV_INPUT_INDEX); + opParamInfo_.actualSeqLengths.desc = context_->GetOptionalInputDesc(ACT_SEQ_LEN_KV_INPUT_INDEX); + opParamInfo_.queryRope.tensor = context_->GetOptionalInputTensor(QUERY_ROPE_INPUT_INDEX); + opParamInfo_.queryRope.desc = context_->GetOptionalInputDesc(QUERY_ROPE_INPUT_INDEX); + opParamInfo_.keyRope.tensor = context_->GetOptionalInputTensor(KEY_ROPE_INPUT_INDEX); + opParamInfo_.keyRope.desc = context_->GetOptionalInputDesc(KEY_ROPE_INPUT_INDEX); +} + +void SFAInfoParser::GetInputParaInfo() +{ + opParamInfo_.query.desc = context_->GetInputDesc(QUERY_INPUT_INDEX); + opParamInfo_.query.shape = context_->GetInputShape(QUERY_INPUT_INDEX); + opParamInfo_.key.desc = context_->GetInputDesc(KEY_INPUT_INDEX); + opParamInfo_.key.shape = context_->GetInputShape(KEY_INPUT_INDEX); + opParamInfo_.value.desc = context_->GetInputDesc(VALUE_INPUT_INDEX); + opParamInfo_.value.shape = context_->GetInputShape(VALUE_INPUT_INDEX); + 
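+    // sparse_indices (input index 3) is fetched here as a required input; block_table, the actual
+    // sequence lengths and the rope tensors are gathered as optional inputs in GetOptionalInputParaInfo().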
opParamInfo_.sparseIndices.desc = context_->GetInputDesc(SPARSE_INDICES_INPUT_INDEX); + opParamInfo_.sparseIndices.shape = context_->GetInputShape(SPARSE_INDICES_INPUT_INDEX); + GetOptionalInputParaInfo(); +} + +void SFAInfoParser::GetOutputParaInfo() +{ + opParamInfo_.attenOut.desc = context_->GetOutputDesc(OUTPUT_INDEX); + opParamInfo_.attenOut.shape = context_->GetOutputShape(OUTPUT_INDEX); +} + +ge::graphStatus SFAInfoParser::GetAttrParaInfo() +{ + auto attrs = context_->GetAttrs(); + OPS_ERR_IF(attrs == nullptr, OPS_REPORT_VECTOR_INNER_ERR(context_->GetNodeName(), "attrs got from ge is nullptr"), + return ge::GRAPH_FAILED); + + opParamInfo_.layoutQuery = attrs->GetStr(LAYOUT_QUERY_ATTR_INDEX); + opParamInfo_.layoutKV = attrs->GetStr(LAYOUT_KV_ATTR_INDEX); + opParamInfo_.sparseBlockSize = attrs->GetAttrPointer(SPARSE_BLOCK_SIZE_ATTR_INDEX); + opParamInfo_.scaleValue = attrs->GetAttrPointer(SCALE_VALUE_ATTR_INDEX); + opParamInfo_.sparseMode = attrs->GetAttrPointer(SPARSE_MODE_ATTR_INDEX); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetOpParaInfo() +{ + GetInputParaInfo(); + GetOutputParaInfo(); + if (ge::GRAPH_SUCCESS != GetAttrParaInfo()) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetInOutDataType() +{ + inputQType_ = opParamInfo_.query.desc->GetDataType(); + inputKvType_ = opParamInfo_.key.desc->GetDataType(); + outputType_ = opParamInfo_.attenOut.desc->GetDataType(); + if (opParamInfo_.queryRope.desc != nullptr) { + inputQRopeType_ = opParamInfo_.queryRope.desc->GetDataType(); + } + if (opParamInfo_.keyRope.desc != nullptr) { + inputKRopeType_ = opParamInfo_.keyRope.desc->GetDataType(); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetBatchSize() +{ + // 获取B基准值 + // 1、非TND时, 以query的batch_size维度为基准; + // 2、TND时, actual_seq_lens_q必须传入, 以actual_seq_lens_q数组的长度为B轴大小 + if (qLayout_ == SFALayout::TND) { + return GetActualSeqLenQSize(bSize_); + } else { // BSND + bSize_ = GetAxisNum(queryShape_, SFAAxis::B, qLayout_); + return ge::GRAPH_SUCCESS; + } +} + +ge::graphStatus SFAInfoParser::GetQTSize() +{ + // 获取query的T基准值 + // 1、非TND时, 以query的batch_size维度为基准; + // 2、TND时, actual_seq_lens_q必须传入, 以actual_seq_lens_q数组的长度为B轴大小 + qTSize_ = (qLayout_ == SFALayout::TND) ? GetAxisNum(queryShape_, SFAAxis::T, qLayout_) : 0; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetKVTSize() +{ + // 获取query的T基准值 + // 1、非TND时, 以key的batch_size维度为基准; + // 2、TND时, actual_seq_lens_q必须传入, 以actual_seq_lens_q数组的长度为B轴大小 + kvTSize_ = (kvLayout_ == SFALayout::TND) ? 
GetAxisNum(keyShape_, SFAAxis::T, kvLayout_) : 0; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetQkHeadDim() +{ + // 获取qkHeadDim基准值 + // 以query的D维度为基准 + qkHeadDim_ = GetAxisNum(queryShape_, SFAAxis::D, qLayout_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetS1Size() +{ + // 获取S1基准值 + // 1、非TND时, 以query的S维度为基准; + // 2、TND时, actual_seq_lens_q必须传入, 以actual_seq_lens_q数组中的最大值为基准 + if (qLayout_ == SFALayout::TND) { + s1Size_ = GetAxisNum(queryShape_, SFAAxis::T, qLayout_); + return ge::GRAPH_SUCCESS; + } else { // BSND + s1Size_ = GetAxisNum(queryShape_, SFAAxis::S, qLayout_); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetKvStorageMode() +{ + if (kvLayout_ == SFALayout::PA_BSND) { + kvStorageMode_ = KvStorageMode::PAGE_ATTENTION; + } else { + kvStorageMode_ = KvStorageMode::BATCH_CONTINUOUS; + } + // kv存储模式基准值 + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetKvLayout() +{ + const map layoutKVMap = { + {"BSND", SFALayout::BSND}, + {"PA_BSND", SFALayout::PA_BSND}, + {"TND", SFALayout::TND} + }; + + std::string layout(opParamInfo_.layoutKV); + auto it = layoutKVMap.find(layout); + if (it != layoutKVMap.end()) { + kvLayout_ = it->second; + } else { + OPS_LOG_E(opName_, "layoutKV is %s, it is unsupported.", layout.c_str()); + return ge::GRAPH_FAILED; + } + if (kvLayout_ != SFALayout::PA_BSND && qLayout_ != kvLayout_) { + OPS_LOG_E(opName_, "When layoutKV is not PA_BSND, layoutKV must be the same as layoutQ."); + return ge::GRAPH_FAILED; + } + uint32_t keyDimNum = opParamInfo_.key.shape->GetStorageShape().GetDimNum(); + if (kvLayout_ == SFALayout::PA_BSND && keyDimNum != 4U) { + OPS_LOG_E(opName_, "When layoutKV is PA_BSND, kvDimNum must be 4, but now is %d.", keyDimNum); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetS2SizeForBatchContinuous() +{ + if (kvLayout_ == SFALayout::BSND) { // BSND + s2Size_ = GetAxisNum(keyShape_, SFAAxis::S, kvLayout_); + } else if (kvLayout_ == SFALayout::TND) { + s2Size_ = GetAxisNum(keyShape_, SFAAxis::T, kvLayout_); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetMaxBlockNumPerBatch() +{ + if (opParamInfo_.blockTable.tensor == nullptr) { + OPS_LOG_E(opName_, "the layout_kv is %s, blockTable must be provided.", SFALayoutToSerialString(kvLayout_).c_str()); + return ge::GRAPH_FAILED; + } + uint32_t dimNum = opParamInfo_.blockTable.tensor->GetStorageShape().GetDimNum(); + if (dimNum != DIM_NUM_TWO) { + OPS_LOG_E(opName_, "the dim num of block_table is %u, it should be %u.", dimNum, DIM_NUM_TWO); + return ge::GRAPH_FAILED; + } + if (opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1) <= 0) { + OPS_LOG_E(opName_, "%s's second dimension(%ld) should be greater than 0", + BLOCK_TABLE_NAME.c_str(), opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1)); + return ge::GRAPH_FAILED; + } + maxBlockNumPerBatch_ = opParamInfo_.blockTable.tensor->GetStorageShape().GetDim(1); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetBlockSize() +{ + blockSize_ = GetAxisNum(keyShape_, SFAAxis::Bs, kvLayout_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetSparseBlockCount() +{ + sparseBlockCount_ = GetAxisNum(sparseIndicesShape_, SFAAxis::K, qLayout_); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetS2SizeForPageAttention() +{ + if (GetMaxBlockNumPerBatch() != ge::GRAPH_SUCCESS || GetBlockSize() != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + s2Size_ 
= maxBlockNumPerBatch_ * blockSize_; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetS2Size() +{ + // 获取S2基准值 + // 1、BATCH_CONTINUOUS时, 从key的S轴获取 + // 2、PAGE_ATTENTION时, S2 = block_table.dim1 * block_size + if (kvStorageMode_ == KvStorageMode::BATCH_CONTINUOUS) { + return GetS2SizeForBatchContinuous(); + } + return GetS2SizeForPageAttention(); +} + +ge::graphStatus SFAInfoParser::GetValueHeadDim() +{ + // 获取vHeadDim基准值 + // 以value的D维度为基准 + vHeadDim_ = GetAxisNum(valueShape_, SFAAxis::D, kvLayout_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetRopeHeadDim() +{ + ropeHeadDim_ = GetAxisNum(queryRopeShape_, SFAAxis::D, qLayout_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetQueryAndOutLayout() +{ + // 获取query和attentionOut的Layout基准值 + // layoutQuery: {qLayout, outLayout} + const map> layoutMap = { + {"BSND", {SFALayout::BSND, SFALayout::BSND}}, + {"TND", {SFALayout::TND, SFALayout::TND }}, + }; + + std::string layout(opParamInfo_.layoutQuery); + auto it = layoutMap.find(layout); + if (it != layoutMap.end()) { + qLayout_ = it->second.first; + outLayout_ = it->second.second; + } else { + OPS_LOG_E(opName_, "layoutQuery is %s, it is unsupported.", layout.c_str()); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetTopkLayout() +{ + topkLayout_ = qLayout_; + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetN1Size() +{ + n1Size_ = GetAxisNum(queryShape_, SFAAxis::N, qLayout_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetN2Size() +{ + n2Size_ = GetAxisNum(keyShape_, SFAAxis::N, kvLayout_); + return ge::GRAPH_SUCCESS; +} + +void SFAInfoParser::SetSFAShape() +{ + queryShape_ = opParamInfo_.query.shape->GetStorageShape(); + keyShape_ = opParamInfo_.key.shape->GetStorageShape(); + valueShape_ = opParamInfo_.value.shape->GetStorageShape(); + sparseIndicesShape_ = opParamInfo_.sparseIndices.shape->GetStorageShape(); + queryRopeShape_ = opParamInfo_.queryRope.tensor->GetStorageShape(); +} + +ge::graphStatus SFAInfoParser::GetGSize() +{ + if (n2Size_ != 0) { + gSize_ = n1Size_ / n2Size_; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus SFAInfoParser::GetActualseqInfo() +{ + maxActualseq_ = static_cast(s2Size_); + if (opParamInfo_.actualSeqLengths.tensor != nullptr) { + actualLenDimsKV_ = opParamInfo_.actualSeqLengths.tensor->GetShapeSize(); + } + if (opParamInfo_.actualSeqLengthsQ.tensor != nullptr) { + actualLenDimsQ_ = opParamInfo_.actualSeqLengthsQ.tensor->GetShapeSize(); + } + return ge::GRAPH_SUCCESS; +} + +void SFAInfoParser::GenerateInfo(SFATilingInfo &sfaInfo) +{ + sfaInfo.opName = opName_; + sfaInfo.platformInfo = platformInfo_; + sfaInfo.opParamInfo = opParamInfo_; + sfaInfo.socVersion = socVersion_; + + sfaInfo.bSize = bSize_; + sfaInfo.n1Size = n1Size_; + sfaInfo.n2Size = n2Size_; + sfaInfo.s1Size = s1Size_; + sfaInfo.s2Size = s2Size_; + sfaInfo.gSize = gSize_; + sfaInfo.qkHeadDim = qkHeadDim_; + sfaInfo.vHeadDim = vHeadDim_; + sfaInfo.ropeHeadDim = ropeHeadDim_; + sfaInfo.qTSize = qTSize_; + sfaInfo.kvTSize = kvTSize_; + sfaInfo.sparseBlockSize = *opParamInfo_.sparseBlockSize; + sfaInfo.sparseBlockCount = sparseBlockCount_; + + sfaInfo.inputQType = inputQType_; + sfaInfo.inputKvType = inputKvType_; + sfaInfo.inputQRopeType = inputQRopeType_; + sfaInfo.inputKRopeType = inputKRopeType_; + sfaInfo.outputType = outputType_; + + sfaInfo.kvStorageMode = kvStorageMode_; + sfaInfo.l2CacheSize = l2CacheSize_; + + sfaInfo.totalBlockNum = 
opParamInfo_.key.shape->GetStorageShape().GetDim(0); + sfaInfo.scaleValue = *opParamInfo_.scaleValue; + sfaInfo.pageAttentionFlag = (kvStorageMode_ == KvStorageMode::PAGE_ATTENTION); + sfaInfo.blockSize = blockSize_; + sfaInfo.blockTypeSize = sizeof(float); + sfaInfo.maxBlockNumPerBatch = maxBlockNumPerBatch_; + + sfaInfo.actualLenDimsQ = actualLenDimsQ_; + sfaInfo.actualLenDimsKV = actualLenDimsKV_; + sfaInfo.maxActualseq = maxActualseq_; + sfaInfo.actualSeqLenFlag = (opParamInfo_.actualSeqLengths.tensor != nullptr); + sfaInfo.isSameSeqAllKVTensor = isSameSeqAllKVTensor_; + sfaInfo.isSameActualseq = isSameActualseq_; + + sfaInfo.sparseMode = *opParamInfo_.sparseMode; + + sfaInfo.qLayout = qLayout_; + sfaInfo.topkLayout = topkLayout_; + sfaInfo.kvLayout = kvLayout_; + sfaInfo.outLayout = outLayout_; +} + +ge::graphStatus SFAInfoParser::Parse(SFATilingInfo &sfaInfo) +{ + if (context_ == nullptr) { + OPS_LOG_E("SparseFlashAttention", "tiling context is nullptr!"); + return ge::GRAPH_FAILED; + } + OPS_LOG_FULL(DLOG_INFO, "SparseFlashAttention", "TilingContext: %s", SFADebugTilingContext(context_).c_str()); + if (ge::GRAPH_SUCCESS != GetOpName() || + ge::GRAPH_SUCCESS != GetNpuInfo() || + ge::GRAPH_SUCCESS != GetOpParaInfo() || + ge::GRAPH_SUCCESS != CheckRequiredParaExistence()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != GetInOutDataType() || + ge::GRAPH_SUCCESS != GetQueryAndOutLayout() || + ge::GRAPH_SUCCESS != GetTopkLayout() || + ge::GRAPH_SUCCESS != GetKvLayout() || + ge::GRAPH_SUCCESS != GetKvStorageMode()) { + return ge::GRAPH_FAILED; + } + + SetSFAShape(); + if ( + ge::GRAPH_SUCCESS != GetN1Size() || + ge::GRAPH_SUCCESS != GetN2Size() || + ge::GRAPH_SUCCESS != GetGSize() || + ge::GRAPH_SUCCESS != GetBatchSize() || + ge::GRAPH_SUCCESS != GetQTSize() || + ge::GRAPH_SUCCESS != GetKVTSize() || + ge::GRAPH_SUCCESS != GetS1Size() || + ge::GRAPH_SUCCESS != GetQkHeadDim() || + ge::GRAPH_SUCCESS != GetS2Size() || + ge::GRAPH_SUCCESS != GetValueHeadDim() || + ge::GRAPH_SUCCESS != GetRopeHeadDim() || + ge::GRAPH_SUCCESS != GetSparseBlockCount()) { + return ge::GRAPH_FAILED; + } + + if (ge::GRAPH_SUCCESS != GetActualseqInfo()) { + return ge::GRAPH_FAILED; + } + + GenerateInfo(sfaInfo); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(SparseFlashAttention) + .Tiling(TilingSparseFlashAttention) + .TilingParse(TilingPrepareForSparseFlashAttention); +} // namespace optiling diff --git a/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.h b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.h new file mode 100644 index 00000000000..2afeffc5da1 --- /dev/null +++ b/csrc/sparse_flash_attention/op_host/sparse_flash_attention_tiling.h @@ -0,0 +1,591 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file sparse_flash_attention_tiling.h + * \brief + */ +#ifndef SPARSE_FLASH_ATTENTION_TILING_H +#define SPARSE_FLASH_ATTENTION_TILING_H + +#include +#include +#include +#include +#include "register/tilingdata_base.h" +#include "exe_graph/runtime/tiling_context.h" + +namespace optiling { +// ------------------算子原型索引常量定义---------------- +// Inputs Index +constexpr uint32_t QUERY_INPUT_INDEX = 0; +constexpr uint32_t KEY_INPUT_INDEX = 1; +constexpr uint32_t VALUE_INPUT_INDEX = 2; +constexpr uint32_t SPARSE_INDICES_INPUT_INDEX = 3; +constexpr uint32_t BLOCK_TABLE_INPUT_INDEX = 4; +constexpr uint32_t ACT_SEQ_LEN_Q_INPUT_INDEX = 5; +constexpr uint32_t ACT_SEQ_LEN_KV_INPUT_INDEX = 6; +constexpr uint32_t QUERY_ROPE_INPUT_INDEX = 7; +constexpr uint32_t KEY_ROPE_INPUT_INDEX = 8; +// Outputs Index +constexpr uint32_t OUTPUT_INDEX = 0; +// Attributes Index +constexpr uint32_t SCALE_VALUE_ATTR_INDEX = 0; +constexpr uint32_t SPARSE_BLOCK_SIZE_ATTR_INDEX = 1; +constexpr uint32_t LAYOUT_QUERY_ATTR_INDEX = 2; +constexpr uint32_t LAYOUT_KV_ATTR_INDEX = 3; +constexpr uint32_t SPARSE_MODE_ATTR_INDEX = 4; +// Dim Num +constexpr size_t DIM_NUM_TWO = 2; +constexpr size_t DIM_NUM_THREE = 3; +constexpr size_t DIM_NUM_FOUR = 4; +// 常量 +constexpr uint32_t MAX_BLOCK_SIZE = 1024; +constexpr uint32_t COPYND2NZ_SRC_STRIDE_LIMITATION = 65535; +constexpr uint32_t NUM_BYTES_FLOAT = 4; +constexpr uint32_t NUM_BYTES_FLOAT16 = 2; +constexpr uint32_t NUM_BYTES_BF16 = 2; +constexpr uint32_t BYTE_BLOCK = 32; +const uint32_t SFA_MAX_AIC_CORE_NUM = 26; // 25 + 1 保证数组8字节对齐 + +// ------------------公共定义-------------------------- +enum class SFALayout : uint32_t { + BSND = 0, + TND = 1, + PA_BSND = 2 +}; + +struct SFATilingShapeCompareParam { + int64_t B = 1; + int64_t S = 1; + int64_t N = 1; + int64_t D = 1; + int64_t T = 1; + // PA + int64_t Bs = 1; + int64_t Bn = 1; +}; + +enum class KvStorageMode : uint32_t { + BATCH_CONTINUOUS = 0, + PAGE_ATTENTION = 1 +}; + +enum class SFAPerfMode : uint32_t { + C_TEMPLATE_MODE = 0, + V_TEMPLATE_MODE +}; + +enum class SFAAxis : uint32_t { + B = 0, + S = 1, + N = 2, + D = 3, + K = 3, // sparse_indices的K和key的D枚举值相同,表达相同位置, 最后一维 + T = 5, + Bn = 6, // block number + Bs = 7, // block size +}; + +struct SFARequiredParaInfo { + const gert::CompileTimeTensorDesc *desc; + const gert::StorageShape *shape; +}; + +struct SFAOptionalParaInfo { + const gert::CompileTimeTensorDesc *desc; + const gert::Tensor *tensor; +}; + +// -----------算子Tiling入参结构体定义--------------- +struct SFAParaInfo { + SFARequiredParaInfo query = {nullptr, nullptr}; + SFARequiredParaInfo key = {nullptr, nullptr}; + SFARequiredParaInfo value = {nullptr, nullptr}; + SFARequiredParaInfo sparseIndices = {nullptr, nullptr}; + SFAOptionalParaInfo blockTable = {nullptr, nullptr}; + SFAOptionalParaInfo actualSeqLengthsQ = {nullptr, nullptr}; + SFAOptionalParaInfo actualSeqLengths = {nullptr, nullptr}; + SFAOptionalParaInfo queryRope = {nullptr, nullptr}; + SFAOptionalParaInfo keyRope = {nullptr, nullptr}; + SFARequiredParaInfo attenOut = {nullptr, nullptr}; + + const char *layoutQuery = nullptr; + const char *layoutKV = nullptr; + const int64_t *sparseBlockSize = nullptr; + const float *scaleValue = nullptr; + const int64_t *sparseMode = nullptr; +}; + +struct InnerSplitParams { + uint32_t s1GBaseSize = 1; + uint32_t s2BaseSize = 1; +}; + +// -----------算子TilingData定义--------------- +BEGIN_TILING_DATA_DEF(SparseFlashAttentionBaseParamsMla) +TILING_DATA_FIELD_DEF(uint32_t, batchSize) +TILING_DATA_FIELD_DEF(uint32_t, seqSize) 
+TILING_DATA_FIELD_DEF(uint32_t, qSeqSize) +TILING_DATA_FIELD_DEF(int64_t, blockSize) +TILING_DATA_FIELD_DEF(uint32_t, maxBlockNumPerBatch) +TILING_DATA_FIELD_DEF(float, scaleValue) +TILING_DATA_FIELD_DEF(uint32_t, nNumOfQInOneGroup) +TILING_DATA_FIELD_DEF(uint32_t, actualLenDimsQ) +TILING_DATA_FIELD_DEF(uint32_t, actualLenDimsKV) +TILING_DATA_FIELD_DEF(uint32_t, outputLayout) +TILING_DATA_FIELD_DEF(uint32_t, sparseMode) +TILING_DATA_FIELD_DEF(int64_t, sparseBlockSize) +TILING_DATA_FIELD_DEF(uint32_t, sparseBlockCount) +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttentionBaseParamsMlaOp, SparseFlashAttentionBaseParamsMla) + +BEGIN_TILING_DATA_DEF(SparseFlashAttentionSingleCoreParamsMla) +TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum); +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSingleCoreParamsMlaOp, SparseFlashAttentionSingleCoreParamsMla) + +BEGIN_TILING_DATA_DEF(SparseFlashAttentionSingleCoreTensorSizeMla) +TILING_DATA_FIELD_DEF(uint32_t, mmResUbSize); +TILING_DATA_FIELD_DEF(uint32_t, bmm2ResUbSize); +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSingleCoreTensorSizeMlaOp, SparseFlashAttentionSingleCoreTensorSizeMla) + +BEGIN_TILING_DATA_DEF(SparseFlashAttentionSplitKVParamsMla) +TILING_DATA_FIELD_DEF(uint32_t, s2) // S2切分份数 +TILING_DATA_FIELD_DEF(uint32_t, accumOutSize) // FD workspace +TILING_DATA_FIELD_DEF(uint32_t, logSumExpSize) // FD workspace +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttentionSplitKVParamsMlaOp, SparseFlashAttentionSplitKVParamsMla) + +// 内切基本块参数 +BEGIN_TILING_DATA_DEF(SparseFlashAttentionInnerSplitParams) +TILING_DATA_FIELD_DEF(uint32_t, mBaseSize) +TILING_DATA_FIELD_DEF(uint32_t, s2BaseSize) +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttentionInnerSplitParamsOp, SparseFlashAttentionInnerSplitParams) + +BEGIN_TILING_DATA_DEF(SparseFlashAttentionTilingDataMla) +TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionBaseParamsMla, baseParams); +TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSplitKVParamsMla, splitKVParams); +TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSingleCoreParamsMla, singleCoreParams); +TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionSingleCoreTensorSizeMla, singleCoreTensorSize); +TILING_DATA_FIELD_DEF_STRUCT(SparseFlashAttentionInnerSplitParams, innerSplitParams); +END_TILING_DATA_DEF +REGISTER_TILING_DATA_CLASS(SparseFlashAttention, SparseFlashAttentionTilingDataMla) + +template inline T Align(T num, T rnd) +{ + return (((rnd) == 0) ? 
0 : (((num) + (rnd) - 1) / (rnd) * (rnd))); +} + +template +std::string SFAShape2String(const T &shape) +{ + std::ostringstream oss; + oss << "["; + if (shape.GetDimNum() > 0) { + for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { + oss << shape.GetDim(i) << ", "; + } + oss << shape.GetDim(shape.GetDimNum() - 1); + } + oss << "]"; + return oss.str(); +} + +static std::string GetShapeStr(gert::Shape shape); +static std::string SFADataTypeToSerialString(ge::DataType type); +std::string SFATensorDesc2String(const gert::StorageShape *shape, const gert::CompileTimeTensorDesc *tensor); +std::string SFADebugTilingContext(const gert::TilingContext *context); +std::string SFALayoutToSerialString(SFALayout layout); + +// -----------算子Tiling入参信息类--------------- +struct SFATilingInfo { + const char *opName = nullptr; + fe::PlatFormInfos *platformInfo = nullptr; + SFAParaInfo opParamInfo; + + // Base Param + platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B; + uint32_t bSize = 0; + uint32_t n1Size = 0; + uint32_t n2Size = 0; + uint32_t s1Size = 0; + int64_t s2Size = 0; + uint32_t qkHeadDim = 0; + uint32_t vHeadDim = 0; + uint32_t gSize = 0; + uint32_t ropeHeadDim = 0; + uint32_t qTSize = 0; // 仅TND时生效 + uint32_t kvTSize = 0; // 仅TND时生效 + float scaleValue = 0; + uint32_t innerPrecise = 0; + uint32_t l2CacheOffFlag = 0; + int64_t sparseBlockSize = 0; + int64_t sparseBlockCount = 0; + + bool pageAttentionFlag = false; + int64_t blockSize = 0; + uint32_t blockTypeSize = 0; + uint32_t maxBlockNumPerBatch = 0; + uint32_t totalBlockNum = 0; + + uint32_t actualLenDimsQ = 0; + uint32_t maxActualseq = 0; + + bool actualSeqLenFlag = false; + bool isSameSeqAllKVTensor = true; + bool isSameActualseq = true; + uint32_t actualLenDimsKV = 0; + std::vector kvListSeqLens {}; + + uint32_t sparseMode = 0; + + ge::DataType inputQType = ge::DT_FLOAT16; + ge::DataType inputKvType = ge::DT_FLOAT16; + ge::DataType outputType = ge::DT_FLOAT16; + + KvStorageMode kvStorageMode = KvStorageMode::BATCH_CONTINUOUS; + + SFALayout qLayout = SFALayout::BSND; + SFALayout topkLayout = SFALayout::BSND; + SFALayout outLayout = SFALayout::BSND; + SFALayout kvLayout = SFALayout::BSND; + + ge::DataType inputQRopeType = ge::DT_FLOAT16; + ge::DataType inputKRopeType = ge::DT_FLOAT16; + + uint64_t l2CacheSize = 0; +}; + +// ---------------算子Tiling类--------------- +class SFAMlaTiling { +public: + explicit SFAMlaTiling(gert::TilingContext *context) : context_(context) {} + ge::graphStatus DoOpTiling(SFATilingInfo *sfaInfo); + +private: + ge::graphStatus SetBlockDim(uint32_t blockDim); + ge::graphStatus SetTilingKey(uint64_t tilingKey); + ge::graphStatus SetWorkspaceSize(uint64_t workspaceSize); + ge::graphStatus SetTilingData(TilingDef &tilingData); + gert::TilingContext *context_ = nullptr; + ge::graphStatus GetPlatformInfo(); + void GenTilingKey(); + bool DealSameSeqEachBatch(); + + void ZeroTensorProcess(); + void InitParams(); + + void Split(); + bool IsBalanceSplitCore(); + + void SplitBalanced(); + void CalcInnerSize(uint32_t s2Size); + + bool IsFlashDecode(uint32_t coreNum); + + void FillTilingBaseParamsMla(); + void FillTilingSplitKVMla(); + + void FillTilingSingleCoreParamsMla(); + void FillTilingSingleCoreTensorSizeMla(); + void FillTiling(); + + void CalcUbBmm(); + void CheckUbSpace(); + void NormalCalcFDWorkSpace(const uint32_t actCoreNum); + void CalcFDWorkSpace(const uint32_t actCoreNum); + void GetWorkspaceSize(); + + uint32_t CalcBalanceFDParamNums(const uint32_t actCoreNum); + + void 
CalcBlockDim(); + + bool balanceModeFlag_ = false; + bool splitKVFlag_ = false; + + uint32_t coreNum_ = 0; + SFAPerfMode perfMode_ = SFAPerfMode::V_TEMPLATE_MODE; + uint32_t kvSplitPart_ = 1; + size_t mmResUbSize_ = 0; + size_t bmm2ResUbSize_ = 0; + size_t qPreSizeMla_= 0; + uint32_t sInnerLoopTimes_ = 0; + uint32_t sInnerSize_ = 0; + uint32_t sInnerSizeTail_ = 0; + uint32_t sInnerSizeAlign_ = 0; + uint32_t kvSplit_ = 0; + uint32_t usedCoreNum_ = 0; + uint32_t formerCoreNum_ = 0; + uint32_t blockSplitBn2Range_ = 0; + uint32_t tailSplitedBatchRange_ = 0; + + uint32_t aicNum_ = 0; + uint32_t aivNum_ = 0; + size_t libapiSize_ = 0; + + SparseFlashAttentionTilingDataMla tilingData_; + uint32_t blockDim_{0}; + uint64_t workspaceSize_{0}; + uint64_t tilingKey_{0}; + + uint32_t headDimAlign_ = 0; + uint32_t mBaseSize_ = 128; + uint32_t mFdBaseSize_ = 8; + + SFATilingInfo *sfaInfo_ = nullptr; +}; + +// -----------算子Tiling入参信息解析及Check类--------------- +class SFATilingCheck { +public: + explicit SFATilingCheck(const SFATilingInfo &sfaInfo) : sfaInfo_(sfaInfo) {}; + ~SFATilingCheck() = default; + virtual ge::graphStatus Process(); +private: + void Init(); + void LogErrorDtypeSupport(const std::vector &expectDtypeList, + const ge::DataType &actualDtype, const std::string &name) const; + ge::graphStatus CheckDtypeSupport(const gert::CompileTimeTensorDesc *desc, + const std::string &name) const; + template void LogErrorNumberSupport(const std::vector &expectNumberList, + const T &actualValue, const std::string &name, const std::string subName) const; + template void LogErrorDimNumSupport(const std::vector &expectNumberList, + const T &actualValue, const std::string &name) const; + ge::graphStatus CheckDimNumSupport(const gert::StorageShape *shape, + const std::vector &expectDimNumList, const std::string &name) const; + ge::graphStatus CheckDimNumInLayoutSupport(const SFALayout &layout, + const gert::StorageShape *shape, const std::string &name) const; + void LogErrorLayoutSupport(const std::vector &expectLayoutList, + const SFALayout &actualLayout, const std::string &name) const; + ge::graphStatus GetExpectedShape(gert::Shape &shapeExpected, + const SFATilingShapeCompareParam ¶m, const SFALayout &layout) const; + ge::graphStatus CompareShape(SFATilingShapeCompareParam ¶m, + const gert::Shape &shape, const SFALayout &layout, const std::string &name) const; + ge::graphStatus CheckLayoutSupport(const SFALayout &actualLayout, const std::string &name) const; + ge::graphStatus CheckSingleParaQuery() const; + ge::graphStatus CheckSingleParaKey() const; + ge::graphStatus CheckSingleParaValue() const; + ge::graphStatus CheckSingleParaQueryRope() const; + ge::graphStatus CheckSingleParaKeyRope() const; + ge::graphStatus CheckSingleParaAttenOut() const; + ge::graphStatus CheckSingleParaNumHeads() const; + ge::graphStatus CheckSingleParaKvHeadNums() const; + ge::graphStatus CheckSingleParaLayout() const; + ge::graphStatus CheckSingleParaSparseMode() const; + ge::graphStatus CheckSingleParaSparseBlockSize() const; + ge::graphStatus CheckSingleParaSparseIndices() const; + ge::graphStatus CheckSinglePara() const; + ge::graphStatus CheckMultiParaConsistency() const; + ge::graphStatus CheckRopeExistence(); + ge::graphStatus CheckExists(const void *pointer, const std::string &name) const; + ge::graphStatus CheckNotExists(const void *pointer, const std::string &name) const; + ge::graphStatus CheckExistsByMap(const std::map ¶mMap) const; + ge::graphStatus CheckNotExistsByMap(const std::map ¶mMap) const; + ge::graphStatus 
CheckExistenceByMap(std::map &existMap, + std::map ¬ExistMap) const; + template ge::graphStatus CheckAttrValueByMap( + std::map> &attrMap) const; + ge::graphStatus CheckParaExistenceMlaNoquant() const; + ge::graphStatus CheckParaExistenceGqaNoquant() const; + ge::graphStatus CheckParaExistenceMla() const; + ge::graphStatus CheckParaExistence(); + ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + const SFALayout &layout, const std::string &name); + void SetSFAShapeCompare(); + ge::graphStatus CheckQRope(); + ge::graphStatus CheckQRopeShape(); + ge::graphStatus CheckVAndKRopeShapeForBatchContinuous(); + uint32_t GetTypeSize(ge::DataType dtype) const; + ge::graphStatus CheckVAndKRopeShapeForPageAttention(); + ge::graphStatus CheckVAndKRopeShape(); + ge::graphStatus CheckVAndKRope(); + ge::graphStatus CheckTopK(); + ge::graphStatus CheckTopkShape(); + ge::graphStatus CheckBlockTable() const; + ge::graphStatus CheckDTypeConsistency(const ge::DataType &actualDtype, + const ge::DataType &expectDtype, const std::string &name) const; + + ge::graphStatus CheckAttenOut(); + ge::graphStatus CheckAttenOutShape(); + ge::graphStatus CheckActualSeqLensQ(); + ge::graphStatus CheckActualSeqLensQShape(); + ge::graphStatus CheckActualSeqLensQDType(); + ge::graphStatus CheckActualSeqLens(); + ge::graphStatus CheckActualSeqLensDType(); + ge::graphStatus CheckActualSeqLensShape(); + ge::graphStatus CheckMultiParaConsistency(); + + ge::graphStatus CheckFeatureMlaNoQuantShape() const; + ge::graphStatus CheckFeatureMlaNoQuantLayout() const; + ge::graphStatus CheckFeatureMlaNoQuantDtype() const; + ge::graphStatus CheckFeatureMlaNoquantPa() const; + ge::graphStatus CheckFeatureMlaNoquant() const; + ge::graphStatus CheckFeatureMla() const; + ge::graphStatus CheckFeature() const; + +private: + const char *opName_; + fe::PlatFormInfos *platformInfo_; + SFAParaInfo opParamInfo_; + const SFATilingInfo &sfaInfo_; + + uint32_t bSize_ = 0; + uint32_t n1Size_ = 0; + uint32_t n2Size_ = 0; + uint32_t gSize_ = 0; + uint32_t s1Size_ = 0; + int64_t s2Size_ = 0; + uint32_t qkHeadDim_ = 0; + uint32_t vHeadDim_ = 0; + uint32_t ropeHeadDim_ = 0; + uint32_t qTSize_ = 0; // 仅TND时生效 + uint32_t kvTSize_ = 0; // 仅TND时生效 + KvStorageMode kvStorageMode_ = KvStorageMode::BATCH_CONTINUOUS; + uint32_t sparseBlockCount_ = 0; + int64_t sparseBlockSize_ = 0; + + SFALayout qLayout_ = SFALayout::BSND; + SFALayout topkLayout_ = SFALayout::BSND; + SFALayout outLayout_ = SFALayout::BSND; + SFALayout kvLayout_ = SFALayout::BSND; + + uint32_t maxBlockNumPerBatch_ = 0; + int64_t blockSize_ = 0; + + uint32_t aicNum_ = 0; + uint32_t aivNum_ = 0; + platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B; + uint64_t l2CacheSize_ = 0; + + ge::DataType inputQType_ = ge::DT_FLOAT16; + ge::DataType inputKvType_ = ge::DT_FLOAT16; + ge::DataType outputType_ = ge::DT_FLOAT16; + ge::DataType inputQRopeType_ = ge::DT_FLOAT16; + ge::DataType inputKRopeType_ = ge::DT_FLOAT16; + + gert::Shape queryShapeCmp_{}; + gert::Shape keyShapeCmp_{}; + gert::Shape valueShapeCmp_{}; + gert::Shape topkShapeCmp_{}; + gert::Shape queryRopeShapeCmp_{}; + gert::Shape keyRopeShapeCmp_{}; + gert::Shape attenOutShapeCmp_{}; +}; + +class SFAInfoParser { +public: + explicit SFAInfoParser(const gert::TilingContext *context) : context_(context) {} + ~SFAInfoParser() = default; + + ge::graphStatus CheckRequiredInOutExistence() const; + ge::graphStatus CheckRequiredAttrExistence() const; + ge::graphStatus CheckRequiredParaExistence() 
const; + + ge::graphStatus GetActualSeqLenSize(uint32_t &size, const gert::Tensor *tensor, + SFALayout &layout, const std::string &name); + ge::graphStatus GetActualSeqLenQSize(uint32_t &size); + ge::graphStatus GetOpName(); + ge::graphStatus GetNpuInfo(); + void GetOptionalInputParaInfo(); + void GetInputParaInfo(); + void GetOutputParaInfo(); + ge::graphStatus GetAttrParaInfo(); + ge::graphStatus GetKvCache(); + ge::graphStatus GetOpParaInfo(); + + ge::graphStatus GetInOutDataType(); + ge::graphStatus GetBatchSize(); + ge::graphStatus GetQTSize(); + ge::graphStatus GetKVTSize(); + ge::graphStatus GetQkHeadDim(); + ge::graphStatus GetS1Size(); + ge::graphStatus GetKvStorageMode(); + ge::graphStatus GetKvLayout(); + void SetSFAShape(); + ge::graphStatus GetS2SizeForBatchContinuous(); + ge::graphStatus GetMaxBlockNumPerBatch(); + ge::graphStatus GetBlockSize(); + ge::graphStatus GetS2SizeForPageAttention(); + ge::graphStatus GetS2Size(); + ge::graphStatus GetValueHeadDim(); + ge::graphStatus GetRopeHeadDim(); + ge::graphStatus GetQueryAndOutLayout(); + ge::graphStatus GetTopkLayout(); + ge::graphStatus GetN1Size(); + ge::graphStatus GetN2Size(); + ge::graphStatus GetGSize(); + ge::graphStatus GetSparseBlockCount(); + ge::graphStatus GetActualseqInfo(); + void GenerateInfo(SFATilingInfo &sfaInfo); + ge::graphStatus Parse(SFATilingInfo &sfaInfo); + +public: + bool HasAxis(const SFAAxis &axis, const SFALayout &layout, const gert::Shape &shape) const; + size_t GetAxisIdx(const SFAAxis &axis, const SFALayout &layout) const; + uint32_t GetAxisNum(const gert::Shape &shape, const SFAAxis &axis,const SFALayout &layout) const; + + const gert::TilingContext *context_ = nullptr; + + const char *opName_; + fe::PlatFormInfos *platformInfo_; + SFAParaInfo opParamInfo_; + static constexpr int64_t invalidDimValue_ = std::numeric_limits::min(); + + uint32_t bSize_ = 0; + uint32_t n1Size_ = 0; + uint32_t n2Size_ = 0; + uint32_t gSize_ = 0; + uint32_t s1Size_ = 0; + int64_t s2Size_ = 0; + uint32_t qkHeadDim_ = 0; + uint32_t vHeadDim_ = 0; + uint32_t ropeHeadDim_ = 0; + uint32_t qTSize_ = 0; // 仅TND时生效 + uint32_t kvTSize_ = 0; // 仅TND时生效 + KvStorageMode kvStorageMode_ = KvStorageMode::BATCH_CONTINUOUS; + uint32_t sparseBlockCount_ = 0; + + SFALayout qLayout_ = SFALayout::BSND; + SFALayout topkLayout_ = SFALayout::BSND; + SFALayout outLayout_ = SFALayout::BSND; + SFALayout kvLayout_ = SFALayout::BSND; + + uint32_t maxBlockNumPerBatch_ = 0; + uint32_t blockSize_ = 0; + + platform_ascendc::SocVersion socVersion_ = platform_ascendc::SocVersion::ASCEND910B; + + ge::DataType inputQType_ = ge::DT_FLOAT16; + ge::DataType inputKvType_ = ge::DT_FLOAT16; + ge::DataType outputType_ = ge::DT_FLOAT16; + ge::DataType inputQRopeType_ = ge::DT_FLOAT16; + ge::DataType inputKRopeType_ = ge::DT_FLOAT16; + + uint64_t l2CacheSize_ = 0; + + bool isSameSeqAllKVTensor_ = true; + bool isSameActualseq_ = true; + uint32_t maxActualseq_ = 0; + + uint32_t actualLenDimsQ_ = 0; + uint32_t actualLenDimsKV_ = 0; + + gert::Shape queryShape_{}; + gert::Shape keyShape_{}; + gert::Shape valueShape_{}; + gert::Shape sparseIndicesShape_{}; + gert::Shape queryRopeShape_{}; + gert::Shape keyRopeShape_{}; +}; +} // namespace optiling +#endif // SPARSE_FLASH_ATTENTION_TILING_H diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention.cpp b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention.cpp new file mode 100644 index 00000000000..a71306b57f4 --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention.cpp 
@@ -0,0 +1,53 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + + /*! + * \file sparse_flash_attention.cpp + * \brief + */ + +#include "kernel_operator.h" +#include "sparse_flash_attention_template_tiling_key.h" +#include "sparse_flash_attention_kernel_mla.h" + +using namespace AscendC; + +#define SFA_OP_IMPL(templateClass, tilingdataClass, ...) \ + do { \ + templateClass> op; \ + GET_TILING_DATA_WITH_STRUCT(tilingdataClass, tiling_data_in, tiling); \ + const tilingdataClass *__restrict tiling_data = &tiling_data_in; \ + op.Init(query, key, value, sparseIndices, actualSeqLengthsQuery, actualSeqLengthsKV, \ + blocktable, queryRope, keyRope, attentionOut, user, tiling_data, tiling, &tPipe); \ + op.Process(); \ + } while (0) + +template + __global__ __aicore__ void +sparse_flash_attention(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value, + __gm__ uint8_t *sparseIndices, __gm__ uint8_t *blocktable, + __gm__ uint8_t *actualSeqLengthsQuery, __gm__ uint8_t *actualSeqLengthsKV, + __gm__ uint8_t* queryRope, __gm__ uint8_t* keyRope, + __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace, __gm__ uint8_t *tiling) +{ + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); + + TPipe tPipe; + __gm__ uint8_t *user = GetUserWorkspace(workspace); + + if constexpr (ORIG_DTYPE_QUERY == DT_FLOAT16 && ORIG_DTYPE_KEY == DT_FLOAT16 && + ORIG_DTYPE_ATTENTION_OUT == DT_FLOAT16) { + SFA_OP_IMPL(SparseFlashAttentionMla, SparseFlashAttentionTilingDataMla, half, half, half, + FLASH_DECODE, static_cast(LAYOUT_T), static_cast(KV_LAYOUT_T), TEMPLATE_MODE); + } else { // bf16 + SFA_OP_IMPL(SparseFlashAttentionMla, SparseFlashAttentionTilingDataMla, bfloat16_t, bfloat16_t, bfloat16_t, + FLASH_DECODE, static_cast(LAYOUT_T), static_cast(KV_LAYOUT_T), TEMPLATE_MODE); + } +} \ No newline at end of file diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_common.h b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_common.h new file mode 100644 index 00000000000..91916ba8cab --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_common.h @@ -0,0 +1,198 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file sparse_flash_attention_common.h + * \brief + */ + +#ifndef SPARSE_FLASH_ATTENTION_COMMON_H +#define SPARSE_FLASH_ATTENTION_COMMON_H + +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" + +using namespace AscendC; +// 将isCheckTiling设置为false, 输入输出的max&sum&exp的shape为(m, 1) +constexpr SoftmaxConfig SFA_SOFTMAX_FLASHV2_CFG_WITHOUT_BRC = {false, 0, 0, SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC}; + +enum class SFA_LAYOUT +{ + BSND = 0, + TND = 1, + PA_BSND = 2, +}; + +template +struct SFAType { + using queryType = Q_T; + using kvType = KV_T; + using outputType = OUT_T; + static constexpr bool flashDecode = FLASH_DECODE; + static constexpr SFA_LAYOUT layout = LAYOUT_T; + static constexpr SFA_LAYOUT kvLayout = KV_LAYOUT_T; + static constexpr int templateMode = TEMPLATE_MODE; + static constexpr bool pageAttention = (KV_LAYOUT_T == SFA_LAYOUT::PA_BSND); +}; + +// ================================Util functions================================== +template __aicore__ inline T SFAAlign(T num, T rnd) +{ + return (((rnd) == 0) ? 0 : (((num) + (rnd) - 1) / (rnd) * (rnd))); +} + +template __aicore__ inline T1 Min(T1 a, T2 b) +{ + return (a > b) ? (b) : (a); +} + +template __aicore__ inline size_t BlockAlign(size_t s) +{ + if constexpr (IsSameType::value) { + return (s + 63) / 64 * 64; + } + size_t n = (32 / sizeof(T)); + return (s + n - 1) / n * n; +} + +struct RunInfo { + uint32_t loop; + uint32_t bIdx; + uint32_t gIdx; + uint32_t s1Idx; + uint32_t s2Idx; + uint32_t bn2IdxInCurCore; + uint32_t curSInnerLoopTimes; + uint64_t tndBIdxOffsetForQ; + uint64_t tndBIdxOffsetForKV; + uint64_t tensorAOffset; + uint64_t tensorBOffset; + uint64_t tensorARopeOffset; + uint64_t tensorBRopeOffset; + uint64_t attenOutOffset; + uint64_t attenMaskOffset; + uint64_t topKBaseOffset; + uint32_t actualSingleProcessSInnerSize; + uint32_t actualSingleProcessSInnerSizeAlign; + bool isFirstSInnerLoop; + bool isChangeBatch; + uint32_t s2BatchOffset; + uint32_t gSize; + uint32_t s1Size; + uint32_t s2Size; + uint32_t mSize; + uint32_t mSizeV; + uint32_t mSizeVStart; + uint32_t tndIsS2SplitCore; + uint32_t tndCoreStartKVSplitPos; + bool isBmm2Output; + bool isValid = false; + + static constexpr uint32_t n2Idx = 0; + uint64_t actS1Size = 1; + uint64_t curActualSeqLenOri = 0ULL; + + uint32_t gS1Idx; + uint64_t actS2Size = 1; + uint32_t actMBaseSize; + bool isLastS2Loop; + int32_t nextTokensPerBatch = 0; + int64_t threshold; + uint32_t curTopKIdx = 0; + uint64_t curOffsetInSparseBlock = 0; +}; + +struct ConstInfo { + // CUBE与VEC核间同步的模式 + static constexpr uint32_t SFA_SYNC_MODE2 = 2; + // BUFFER的字节数 + static constexpr uint32_t BUFFER_SIZE_BYTE_32B = 32; + static constexpr uint32_t BUFFER_SIZE_BYTE_64B = 64; + static constexpr uint32_t BUFFER_SIZE_BYTE_256B = 256; + static constexpr uint32_t BUFFER_SIZE_BYTE_512B = 512; + static constexpr uint32_t BUFFER_SIZE_BYTE_1K = 1024; + static constexpr uint32_t BUFFER_SIZE_BYTE_2K = 2048; + static constexpr uint32_t BUFFER_SIZE_BYTE_4K = 4096; + static constexpr uint32_t BUFFER_SIZE_BYTE_8K = 8192; + static constexpr uint32_t BUFFER_SIZE_BYTE_16K = 16384; + static constexpr uint32_t BUFFER_SIZE_BYTE_32K = 32768; + // FP32的0值和极大值 + static constexpr float FLOAT_ZERO = 0; + static constexpr float FLOAT_MAX = 3.402823466e+38F; + + // preLoad的总次数 + uint32_t preLoadNum = 0U; + uint32_t nBufferMBaseSize = 0U; + // CUBE和VEC的核间同步EventID + uint32_t syncV1NupdateC2 = 0U; + uint32_t syncV0C1 = 0U; + uint32_t syncC1V1 = 0U; + uint32_t syncV1C2 = 0U; + uint32_t 
syncC2V2 = 0U; + uint32_t syncC2V1 = 0U; + + uint32_t mmResUbSize = 0U; // Matmul1输出结果GM上的大小 + uint32_t vec1ResUbSize = 0U; // Vector1输出结果GM上的大小 + uint32_t bmm2ResUbSize = 0U; // Matmul2输出结果GM上的大小 + uint64_t batchSize = 0ULL; + uint64_t gSize = 0ULL; + uint64_t qHeadNum = 0ULL; + uint64_t kvHeadNum; + uint64_t headDim; + uint64_t headDimRope; + uint64_t kvSeqSize = 0ULL; // kv最大S长度 + uint64_t qSeqSize = 1ULL; // q最大S长度 + int64_t kvCacheBlockSize = 0; // PA场景的block size + uint32_t maxBlockNumPerBatch = 0; // PA场景的最大单batch block number + uint32_t splitKVNum = 0U; // S2核间切分的切分份数 + SFA_LAYOUT outputLayout; // 输出的Transpose格式 + uint32_t sparseMode = 0; + bool needInit = false; + + // FlashDecoding + uint32_t actualCombineLoopSize = 0U; // FlashDecoding场景, S2在核间切分的最大份数 + uint64_t combineLseOffset = 0ULL; + uint64_t combineAccumOutOffset = 0ULL; + + uint32_t actualLenDimsQ = 0U; // query的actualSeqLength 的维度 + uint32_t actualLenDimsKV = 0U; // KV 的actualSeqLength 的维度 + + // TND + uint32_t s2Start = 0U; // TND场景下,S2的起始位置 + uint32_t s2End = 0U; // 单核TND场景下S2循环index上限 + + uint32_t bN2Start = 0U; + uint32_t bN2End = 0U; + uint32_t gS1Start = 0U; + uint32_t gS1End = 0U; + + uint32_t tndFDCoreArrLen = 0U; // TNDFlashDecoding相关分核信息array的长度 + uint32_t coreStartKVSplitPos = 0U; // TNDFlashDecoding kv起始位置 + + uint32_t mBaseSize = 1ULL; + uint32_t s2BaseSize = 1ULL; + + // sparse attr + int64_t sparseBlockSize = 0; + uint32_t sparseBlockCount = 0; +}; + +struct MSplitInfo { + uint32_t nBufferIdx = 0U; + uint32_t nBufferStartM = 0U; + uint32_t nBufferDealM = 0U; + uint32_t vecStartM = 0U; + uint32_t vecDealM = 0U; +}; + +#endif // SPARSE_FLASH_ATTENTION_COMMON_H \ No newline at end of file diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_kernel_mla.h b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_kernel_mla.h new file mode 100644 index 00000000000..62de0295594 --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_kernel_mla.h @@ -0,0 +1,987 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file sparse_flash_attention_kernel_mla.h + * \brief + */ + +#ifndef SPARSE_FLASH_ATTENTION_KERNEL_MLA_H +#define SPARSE_FLASH_ATTENTION_KERNEL_MLA_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "sparse_flash_attention_common.h" +#include "sparse_flash_attention_service_cube_mla.h" +#include "sparse_flash_attention_service_vector_mla.h" + +using namespace matmul; +using AscendC::CacheMode; +using AscendC::CrossCoreSetFlag; +using AscendC::CrossCoreWaitFlag; + +// 由于S2循环前,RunInfo还没有赋值,使用Bngs1Param临时存放B、N、S1轴相关的信息;同时减少重复计算 +struct TempLoopInfo { + uint32_t bn2IdxInCurCore = 0; + uint32_t bIdx = 0U; + uint32_t n2Idx = 0U; + uint64_t s2BasicSizeTail = 0U; // S2方向循环的尾基本块大小 + uint32_t s2LoopTimes = 0U; // S2方向循环的总次数,无论TND还是BXXD都是等于实际次数,不用减1 + uint64_t curActualSeqLen = 0ULL; + uint64_t curActualSeqLenOri = 0ULL; + bool curActSeqLenIsZero = false; + int32_t nextTokensPerBatch = 0; + + uint64_t actS1Size = 1ULL; // TND场景下当前Batch循环处理的S1轴的大小 + uint32_t tndCoreStartKVSplitPos; + bool tndIsS2SplitCore; + + uint32_t gS1Idx = 0U; + uint64_t mBasicSizeTail = 0U; // gS1方向循环的尾基本块大小 +}; + +template class SparseFlashAttentionMla { +public: + // 中间计算数据类型为float,高精度模式 + using T = float; + using Q_T = typename SFAT::queryType; + using KV_T = typename SFAT::kvType; + using OUT_T = typename SFAT::outputType; + using Q_ROPE_T = Q_T; + using K_ROPE_T = KV_T; + using UPDATE_T = T; + using MM1_OUT_T = T; + using MM2_OUT_T = T; + + __aicore__ inline SparseFlashAttentionMla(){}; + __aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value, + __gm__ uint8_t *sparseIndices, __gm__ uint8_t *actualSeqLengthsQ, + __gm__ uint8_t *actualSeqLengths, __gm__ uint8_t *blockTable, + __gm__ uint8_t *queryRope, __gm__ uint8_t *keyRope, + __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace, + const SparseFlashAttentionTilingDataMla *__restrict tiling, + __gm__ uint8_t *gmTiling, TPipe *tPipe); + + __aicore__ inline void Process(); + +private: + static constexpr bool PAGE_ATTENTION = SFAT::pageAttention; + static constexpr int TEMPLATE_MODE = SFAT::templateMode; + static constexpr bool FLASH_DECODE = SFAT::flashDecode; + static constexpr SFA_LAYOUT LAYOUT_T = SFAT::layout; + static constexpr SFA_LAYOUT KV_LAYOUT_T = SFAT::kvLayout; + + static constexpr uint32_t PRELOAD_NUM = 2; + static constexpr uint32_t N_BUFFER_M_BASIC_SIZE = 256; + static constexpr uint32_t SFA_PRELOAD_TASK_CACHE_SIZE = 3; + + static constexpr uint32_t SYNC_V0_C1_FLAG = 6; + static constexpr uint32_t SYNC_C1_V1_FLAG = 7; + static constexpr uint32_t SYNC_V1_C2_FLAG = 8; + static constexpr uint32_t SYNC_C2_V2_FLAG = 9; + static constexpr uint32_t SYNC_C2_V1_FLAG = 4; + static constexpr uint32_t SYNC_V1_NUPDATE_C2_FLAG = 5; + + static constexpr uint64_t SYNC_MM2RES_BUF1_FLAG = 10; + static constexpr uint64_t SYNC_MM2RES_BUF2_FLAG = 11; + static constexpr uint64_t SYNC_FDOUTPUT_BUF_FLAG = 12; + + static constexpr uint32_t BLOCK_ELEMENT_NUM = SFAVectorService::BYTE_BLOCK / sizeof(T); + + static constexpr uint64_t kvHeadNum = 1ULL; + static constexpr uint64_t headDim = 512ULL; + static constexpr uint64_t headDimAlign = 512ULL; + static constexpr uint64_t headDimRope = 64ULL; + static constexpr uint32_t msdIterNum = 2U; + + static constexpr uint32_t dbWorkspaceRatio = PRELOAD_NUM; + + const SparseFlashAttentionTilingDataMla *__restrict tilingData = nullptr; + + TPipe *pipe = 
nullptr; + + uint64_t mSizeVStart = 0ULL; + int64_t threshold = 0; + uint64_t topKBaseOffset = 0ULL; + uint64_t s2BatchBaseOffset = 0; + uint64_t tensorACoreOffset = 0ULL; + uint64_t tensorBCoreOffset = 0ULL; + uint64_t tensorARopeCoreOffset = 0ULL; + uint64_t tensorBRopeCoreOffset = 0ULL; + uint64_t tensorBOffset = 0ULL; + uint64_t attenOutOffset = 0ULL; + + uint32_t tmpBlockIdx = 0U; + uint32_t aiCoreIdx = 0U; + uint32_t usedCoreNum = 0U; + + __gm__ uint8_t *keyPtr = nullptr; + __gm__ uint8_t *valuePtr = nullptr; + + ConstInfo constInfo{}; + TempLoopInfo tempLoopInfo{}; + + SFAMatmulService matmulService; + SFAVectorService vectorService; + + GlobalTensor queryGm; + GlobalTensor keyGm; + GlobalTensor valueGm; + GlobalTensor qRopeGm; + GlobalTensor kRopeGm; + + GlobalTensor attentionOutGm; + GlobalTensor blockTableGm; + GlobalTensor topKGm; + + GlobalTensor actualSeqLengthsQGm; + GlobalTensor actualSeqLengthsKVGm; + + // workspace + GlobalTensor mm1ResGm; + GlobalTensor vec1ResGm; + GlobalTensor mm2ResGm; + GlobalTensor kvMergeGm_; + GlobalTensor kvValidSizeGm_; + + GlobalTensor mm2ResInt32Gm; + GlobalTensor vec2ResGm; + + GlobalTensor accumOutGm; + GlobalTensor lseSumFdGm; + GlobalTensor lseMaxFdGm; + + // ================================Init functions=================================== + __aicore__ inline void InitTilingData(); + __aicore__ inline void InitCalcParamsEach(); + __aicore__ inline void InitBuffers(); + __aicore__ inline void InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, __gm__ uint8_t *actualSeqLengths); + __aicore__ inline void InitOutputSingleCore(); + // ================================Process functions================================ + __aicore__ inline void ProcessBalance(); + __aicore__ inline void PreloadPipeline(uint32_t loop, uint64_t s2Start, uint64_t s2LoopIdx, + RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE], uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock); + // ================================Offset Calc===================================== + __aicore__ inline void GetActualSeqLen(uint32_t bIdx, uint32_t s1Idx = 0); + __aicore__ inline void GetSparseActualSeqLen(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx); + __aicore__ inline void CalcSinnerTopKBegin(RunInfo &info, uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock); + __aicore__ inline void UpdateInnerLoopCond(); + __aicore__ inline void DealActSeqLenIsZero(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx); + __aicore__ inline void CalcParams(uint32_t loop, uint64_t s2Start, uint32_t s2LoopIdx, RunInfo &info); + __aicore__ inline void GetAxisStartIdx(uint32_t bN2EndPrev, uint32_t gS1EndPrev, uint32_t s2EndPrev); + __aicore__ inline uint64_t GetBalanceActualSeqLengths(GlobalTensor &actualSeqLengths, uint32_t bIdx); + __aicore__ inline uint32_t GetActualSeqLenKV(uint32_t bIdx); + __aicore__ inline void GetBN2Idx(uint32_t bN2Idx, uint32_t &bIdx, uint32_t &n2Idx); + __aicore__ inline void UpdateInner(uint32_t &s2End, uint32_t &curS2End, uint32_t s1Idx, bool isEnd); + __aicore__ inline void GetPreNextTokensLeftUp(); + // ================================Mm1============================================== + __aicore__ inline void ComputeMm1(const RunInfo &info); + // ================================Mm2============================================== + __aicore__ inline void ComputeMm2(const RunInfo &info); + __aicore__ inline void Bmm2DataCopyOut(uint64_t attenOutOffset, LocalTensor &attenOutUb, uint32_t startRow, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ 
inline void InitAllZeroOutput(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx); +}; + +template __aicore__ inline void SparseFlashAttentionMla::InitTilingData() +{ + usedCoreNum = tilingData->singleCoreParams.usedCoreNum; + constInfo.splitKVNum = tilingData->splitKVParams.s2; + constInfo.mmResUbSize = tilingData->singleCoreTensorSize.mmResUbSize; + constInfo.bmm2ResUbSize = tilingData->singleCoreTensorSize.bmm2ResUbSize; + constInfo.vec1ResUbSize = constInfo.mmResUbSize * msdIterNum; + + constInfo.batchSize = tilingData->baseParams.batchSize; + constInfo.qHeadNum = constInfo.gSize = tilingData->baseParams.nNumOfQInOneGroup; + constInfo.kvSeqSize = tilingData->baseParams.seqSize; + constInfo.qSeqSize = tilingData->baseParams.qSeqSize; + constInfo.maxBlockNumPerBatch = tilingData->baseParams.maxBlockNumPerBatch; + constInfo.kvCacheBlockSize = tilingData->baseParams.blockSize; + constInfo.outputLayout = static_cast(tilingData->baseParams.outputLayout); + constInfo.mBaseSize = tilingData->innerSplitParams.mBaseSize; + constInfo.s2BaseSize = tilingData->innerSplitParams.s2BaseSize; + constInfo.kvHeadNum = kvHeadNum; + constInfo.headDim = headDim; + constInfo.headDimRope = headDimRope; + constInfo.sparseBlockSize = tilingData->baseParams.sparseBlockSize; + constInfo.sparseBlockCount = tilingData->baseParams.sparseBlockCount; + constInfo.sparseMode = tilingData->baseParams.sparseMode; + + constInfo.preLoadNum = PRELOAD_NUM; + constInfo.nBufferMBaseSize = N_BUFFER_M_BASIC_SIZE; + constInfo.syncV0C1 = SYNC_V0_C1_FLAG; + constInfo.syncC1V1 = SYNC_C1_V1_FLAG; + constInfo.syncV1C2 = SYNC_V1_C2_FLAG; + constInfo.syncC2V2 = SYNC_C2_V2_FLAG; + constInfo.syncC2V1 = SYNC_C2_V1_FLAG; + constInfo.syncV1NupdateC2 = SYNC_V1_NUPDATE_C2_FLAG; +} + +template __aicore__ inline void SparseFlashAttentionMla::InitBuffers() +{ + if ASCEND_IS_AIV { + vectorService.InitBuffers(pipe); + } else { + matmulService.InitBuffers(pipe); + } +} + +template +__aicore__ inline void +SparseFlashAttentionMla::InitActualSeqLen(__gm__ uint8_t *actualSeqLengthsQ, + __gm__ uint8_t *actualSeqLengths) +{ + constInfo.actualLenDimsQ = tilingData->baseParams.actualLenDimsQ; + constInfo.actualLenDimsKV = tilingData->baseParams.actualLenDimsKV; + if (constInfo.actualLenDimsKV != 0) { + actualSeqLengthsKVGm.SetGlobalBuffer((__gm__ int32_t *)actualSeqLengths, constInfo.actualLenDimsKV); + } + if (constInfo.actualLenDimsQ != 0) { + actualSeqLengthsQGm.SetGlobalBuffer((__gm__ int32_t *)actualSeqLengthsQ, constInfo.actualLenDimsQ); + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::InitAllZeroOutput(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx) +{ + if (constInfo.outputLayout == SFA_LAYOUT::TND) { + uint32_t tBase = bIdx == 0 ? 
0 : actualSeqLengthsQGm.GetValue(bIdx - 1); + uint32_t s1Count = tempLoopInfo.actS1Size; + + uint64_t attenOutOffset = (tBase + s1Idx) * kvHeadNum * constInfo.gSize * headDim + // T轴、s1轴偏移 + n2Idx * constInfo.gSize * headDim; // N2轴偏移 + matmul::InitOutput(attentionOutGm[attenOutOffset], constInfo.gSize * headDim, 0); + } else if (constInfo.outputLayout == SFA_LAYOUT::BSND) { + uint64_t attenOutOffset = bIdx * constInfo.qSeqSize * kvHeadNum * constInfo.gSize * headDim + + s1Idx * kvHeadNum * constInfo.gSize * headDim + // B轴、S1轴偏移 + n2Idx * constInfo.gSize * headDim; // N2轴偏移 + matmul::InitOutput(attentionOutGm[attenOutOffset], constInfo.gSize * headDim, 0); + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::InitOutputSingleCore() +{ + uint32_t coreNum = GetBlockNum(); + if (coreNum != 0) { + uint64_t totalOutputSize = constInfo.batchSize * constInfo.qHeadNum * constInfo.qSeqSize * constInfo.headDim; + uint64_t singleCoreSize = (totalOutputSize + (2 * coreNum) - 1) / (2 * coreNum); // 2 means c:v = 1:2 + uint64_t tailSize = totalOutputSize - tmpBlockIdx * singleCoreSize; + uint64_t singleInitOutputSize = tailSize < singleCoreSize ? tailSize : singleCoreSize; + if (singleInitOutputSize > 0) { + matmul::InitOutput(attentionOutGm[tmpBlockIdx * singleCoreSize], singleInitOutputSize, 0); + } + SyncAll(); + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::GetActualSeqLen(uint32_t bIdx, uint32_t s1Idx) +{ + tempLoopInfo.curActualSeqLenOri = GetActualSeqLenKV(bIdx); + tempLoopInfo.actS1Size = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx); +} + +template +__aicore__ inline void SparseFlashAttentionMla::GetSparseActualSeqLen(uint32_t bIdx, uint32_t s1Idx, + uint32_t n2Idx) +{ + if (tempLoopInfo.nextTokensPerBatch < 0 && s1Idx < (-tempLoopInfo.nextTokensPerBatch)) { //存在行无效 + tempLoopInfo.curActualSeqLen = 0; + return; + } + int64_t threshold = tempLoopInfo.curActualSeqLenOri; + if (constInfo.sparseMode == 3) { + threshold = static_cast(tempLoopInfo.nextTokensPerBatch) + s1Idx + 1; + } + + tempLoopInfo.curActualSeqLen = (constInfo.sparseBlockCount * constInfo.sparseBlockSize > threshold) ? 
+ threshold : + constInfo.sparseBlockCount * constInfo.sparseBlockSize; +} + +template +__aicore__ inline uint32_t SparseFlashAttentionMla::GetActualSeqLenKV(uint32_t bIdx) +{ + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::TND) { + if (bIdx > 0) { + return actualSeqLengthsKVGm.GetValue(bIdx) - actualSeqLengthsKVGm.GetValue(bIdx - 1); + } else if (bIdx == 0) { + return actualSeqLengthsKVGm.GetValue(0); + } else { + return 0; + } + } else { + if (constInfo.actualLenDimsKV == 0) { + return constInfo.kvSeqSize; + } else if (constInfo.actualLenDimsKV == 1) { + return actualSeqLengthsKVGm.GetValue(0); + } else { + return actualSeqLengthsKVGm.GetValue(bIdx); + } + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::DealActSeqLenIsZero(uint32_t bIdx, uint32_t s1Idx, uint32_t n2Idx) +{ + if ASCEND_IS_AIV { + InitAllZeroOutput(bIdx, s1Idx, n2Idx); + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::GetPreNextTokensLeftUp() +{ + if (constInfo.sparseMode == 3) { + tempLoopInfo.nextTokensPerBatch = + static_cast(tempLoopInfo.curActualSeqLenOri) - static_cast(tempLoopInfo.actS1Size); + } +} + +template __aicore__ inline void SparseFlashAttentionMla::UpdateInnerLoopCond() +{ + if ((tempLoopInfo.curActualSeqLen == 0) || (tempLoopInfo.actS1Size == 0)) { + tempLoopInfo.curActSeqLenIsZero = true; + return; + } + tempLoopInfo.curActSeqLenIsZero = false; + tempLoopInfo.mBasicSizeTail = (tempLoopInfo.actS1Size * constInfo.gSize) % constInfo.mBaseSize; + tempLoopInfo.mBasicSizeTail = + (tempLoopInfo.mBasicSizeTail == 0) ? constInfo.mBaseSize : tempLoopInfo.mBasicSizeTail; + tempLoopInfo.s2LoopTimes = 0; +} + +template +__aicore__ inline void SparseFlashAttentionMla::UpdateInner(uint32_t &s2End, uint32_t &curS2End, + uint32_t s1Idx, bool isEnd) +{ + uint32_t s1BaseSize = 1; + int64_t s1Offset = s1BaseSize * s1Idx; + int64_t s2LastToken = Min(s1Offset + tempLoopInfo.nextTokensPerBatch + s1BaseSize,tempLoopInfo.curActualSeqLenOri); + s2LastToken = Min(constInfo.sparseBlockSize * constInfo.sparseBlockCount, s2LastToken); + curS2End = (s2LastToken + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; + tempLoopInfo.s2LoopTimes = isEnd ? 
constInfo.s2End + 1 : curS2End; +} + +template +__aicore__ inline void SparseFlashAttentionMla::Init(__gm__ uint8_t *query, + __gm__ uint8_t *key, __gm__ uint8_t *value, + __gm__ uint8_t *sparseIndices, __gm__ uint8_t *actualSeqLengthsQ, + __gm__ uint8_t *actualSeqLengths, __gm__ uint8_t *blockTable, + __gm__ uint8_t *queryRope, __gm__ uint8_t *keyRope, + __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace, + const SparseFlashAttentionTilingDataMla *__restrict tiling, + __gm__ uint8_t *gmTiling, TPipe *tPipe) +{ + if ASCEND_IS_AIV { + tmpBlockIdx = GetBlockIdx(); // vec:0-47 + aiCoreIdx = tmpBlockIdx / 2; + } else { + tmpBlockIdx = GetBlockIdx(); // cube:0-23 + aiCoreIdx = tmpBlockIdx; + } + + // init tiling data + tilingData = tiling; + + InitTilingData(); + InitActualSeqLen(actualSeqLengthsQ, actualSeqLengths); + + // 初始化计算参数 + InitCalcParamsEach(); + pipe = tPipe; + keyPtr = key; + valuePtr = value; + + // init global buffer + queryGm.SetGlobalBuffer((__gm__ Q_T *)query); + keyGm.SetGlobalBuffer((__gm__ KV_T *)keyPtr); + valueGm.SetGlobalBuffer((__gm__ KV_T *)valuePtr); + qRopeGm.SetGlobalBuffer((__gm__ Q_ROPE_T *)queryRope); + kRopeGm.SetGlobalBuffer((__gm__ K_ROPE_T *)keyRope); + + attentionOutGm.SetGlobalBuffer((__gm__ OUT_T *)attentionOut); + + if ASCEND_IS_AIV { + if (constInfo.needInit && LAYOUT_T != SFA_LAYOUT::TND) { + InitOutputSingleCore(); + } + } + + if constexpr (PAGE_ATTENTION) { + blockTableGm.SetGlobalBuffer((__gm__ int32_t *)blockTable); + } + topKGm.SetGlobalBuffer((__gm__ int32_t *)sparseIndices); + + // workspace 内存排布 + // |Q--|mm1ResGm(存S)|vec1ResGm(存A1,A2)|mm2ResGm(存O)|vec2ResGm + // |Core0_Q1-Core0_Q2-Core1_Q1-Core1_Q2....Core32_Q1-Core32_Q2|Core0_mmRes + uint64_t offset = 0; + mm1ResGm.SetGlobalBuffer( + (__gm__ MM1_OUT_T *)(workspace + offset + + aiCoreIdx * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(MM1_OUT_T))); + offset += GetBlockNum() * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(MM1_OUT_T); + + vec1ResGm.SetGlobalBuffer( + (__gm__ KV_T *)(workspace + offset + aiCoreIdx * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(KV_T))); + offset += GetBlockNum() * dbWorkspaceRatio * constInfo.mmResUbSize * sizeof(KV_T); + + mm2ResGm.SetGlobalBuffer( + (__gm__ MM2_OUT_T *)(workspace + offset + + aiCoreIdx * dbWorkspaceRatio * constInfo.bmm2ResUbSize * sizeof(MM2_OUT_T))); + offset += GetBlockNum() * dbWorkspaceRatio * constInfo.bmm2ResUbSize * sizeof(MM2_OUT_T); + mm2ResInt32Gm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(mm2ResGm.GetPhyAddr(0))); + + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + // s2 d+rope bufNum + kvMergeGm_.SetGlobalBuffer((__gm__ KV_T *)(workspace + offset + aiCoreIdx * 512 * 576 * 4 * sizeof(KV_T))); + offset += GetBlockNum() * 512 * 576 * 4 * sizeof(KV_T); + + kvValidSizeGm_.SetGlobalBuffer( + (__gm__ int32_t *)(workspace + offset + (aiCoreIdx * 2) * 128 * 4 * sizeof(int32_t))); + } + + if constexpr (FLASH_DECODE) { + accumOutGm.SetGlobalBuffer((__gm__ float *)(workspace + offset)); + offset = offset + tilingData->splitKVParams.accumOutSize * sizeof(float); + lseSumFdGm.SetGlobalBuffer((__gm__ float *)(workspace + offset)); + lseMaxFdGm.SetGlobalBuffer((__gm__ float *)(workspace + offset) + tilingData->splitKVParams.logSumExpSize / 2); + offset = offset + tilingData->splitKVParams.logSumExpSize * sizeof(float); + } + + if ASCEND_IS_AIV { + vectorService.InitParams(constInfo, tilingData); + vectorService.InitMm2ResInt32GmGlobalTensor(mm2ResInt32Gm); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + 
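+            // V-template path, judging by the workspace comment above ("s2 d+rope bufNum", 512 * 576 * 4):
+            // the vector cores appear to gather the top-k selected KV rows (headDim 512 nope + headDimRope 64
+            // rope elements = 576 per row, up to 512 rows, 4 staging buffers) into the contiguous kvMergeGm_
+            // buffer and publish the valid row counts through kvValidSizeGm_, so the cube cores can consume
+            // merged KV instead of walking the block table per sparse block (see vectorService.MergeKv in
+            // PreloadPipeline).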
vectorService.InitVec0GlobalTensor(kvValidSizeGm_, kvMergeGm_, kRopeGm, keyGm, blockTableGm); + } + vectorService.InitVec1GlobalTensor(mm1ResGm, vec1ResGm, actualSeqLengthsQGm, + actualSeqLengthsKVGm, lseMaxFdGm, lseSumFdGm, topKGm); + vectorService.InitVec2GlobalTensor(accumOutGm, vec2ResGm, mm2ResGm, attentionOutGm); + } + + if ASCEND_IS_AIC { + matmulService.InitParams(constInfo); + matmulService.InitMm1GlobalTensor(queryGm, qRopeGm, keyGm, kRopeGm, mm1ResGm); + matmulService.InitMm2GlobalTensor(vec1ResGm, valueGm, mm2ResGm, attentionOutGm); + matmulService.InitPageAttentionInfo(kvMergeGm_, blockTableGm, topKGm, + constInfo.kvCacheBlockSize, constInfo.maxBlockNumPerBatch); + } + // 要在InitParams之后执行 + if (pipe != nullptr) { + InitBuffers(); + } +} + +template __aicore__ inline void SparseFlashAttentionMla::InitCalcParamsEach() +{ + //计算总的基本块 + uint32_t totalBaseNum = 0; + uint32_t s1GBaseSize = constInfo.gSize; + uint32_t actBatchS2 = 1; + uint32_t coreNum = GetBlockNum(); + uint32_t currCoreIdx = aiCoreIdx; + uint32_t actBatchS1 = 1; + for (uint32_t bIdx = 0; bIdx < constInfo.batchSize; bIdx++) { + uint32_t actBatchS1 = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx); + if (actBatchS1 < constInfo.qSeqSize) { + constInfo.needInit = true; + } + totalBaseNum += actBatchS1*actBatchS2 ; + } + uint32_t avgBaseNum = 1; + if (totalBaseNum > coreNum) { + avgBaseNum = (totalBaseNum + coreNum - 1) / coreNum; + }else { + usedCoreNum = totalBaseNum; + } + if(aiCoreIdx>=usedCoreNum){ + return; + } + //计算当前核的基本块 + uint32_t accumBaseNum = 0; // 当前累积的基本块数 + uint32_t targetBaseNum = 0; + uint32_t lastValidBIdx = 0; + uint32_t lastValidactBatchS1=0; + bool setStart=false; + targetBaseNum = (currCoreIdx + 1) * avgBaseNum; // 计算当前的目标权重 + uint32_t targetStartBaseNum = targetBaseNum-avgBaseNum; + for (uint32_t bN2Idx = 0; bN2Idx < constInfo.batchSize * constInfo.kvHeadNum; bN2Idx++) { + uint32_t bIdx = bN2Idx / constInfo.kvHeadNum; + actBatchS1 = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bIdx); + for (uint32_t s1GIdx = 0; s1GIdx < actBatchS1; s1GIdx++) { + accumBaseNum += 1; + if(!setStart && accumBaseNum >= targetStartBaseNum){ + constInfo.bN2Start = bN2Idx; + constInfo.gS1Start = s1GIdx; + setStart=true; + } + if (accumBaseNum >= targetBaseNum) { + // 更新当前核的End分核信息 + constInfo.bN2End = bN2Idx; + constInfo.gS1End = s1GIdx; + constInfo.s2End = 0; + constInfo.coreStartKVSplitPos = 0; + if (aiCoreIdx != 0) { + GetAxisStartIdx(constInfo.bN2Start, constInfo.gS1Start, 0); + } + return; + } + } + if ((actBatchS1 > 0) && (actBatchS2 > 0)) { + lastValidBIdx = bIdx; + lastValidactBatchS1 = actBatchS1; + } + } + if (!setStart){ + constInfo.bN2Start = lastValidBIdx; + constInfo.gS1Start = lastValidactBatchS1-1; + } + if (accumBaseNum < targetBaseNum) { + // 更新最后一个核的End分核信息 + constInfo.bN2End = lastValidBIdx; + constInfo.gS1End = lastValidactBatchS1-1; + constInfo.s2End = 0; + constInfo.coreStartKVSplitPos = 0; + if (aiCoreIdx != 0) { + GetAxisStartIdx(constInfo.bN2Start, constInfo.gS1Start, 0); + } + return; + } +} + +template +__aicore__ inline void +SparseFlashAttentionMla::Bmm2DataCopyOut(uint64_t attenOutOffset, LocalTensor &attenOutUb, + uint32_t startRow, uint32_t dealRowCount, + uint32_t columnCount, uint32_t actualColumnCount) +{ + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = dealRowCount; + dataCopyParams.blockLen = actualColumnCount * sizeof(OUT_T); + dataCopyParams.srcStride = (columnCount - actualColumnCount) / (SFAVectorService::BYTE_BLOCK / sizeof(OUT_T)); + 
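+    // srcStride above skips the (columnCount - actualColumnCount) padding elements between rows of the UB
+    // tensor, expressed in 32-byte blocks; dstStride stays 0 because rows are written back to GM densely,
+    // actualColumnCount elements each.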
dataCopyParams.dstStride = 0; + DataCopyPad(attentionOutGm[attenOutOffset + (mSizeVStart + startRow) * actualColumnCount], attenOutUb, + dataCopyParams); +} + + +template +__aicore__ inline void SparseFlashAttentionMla::CalcParams(uint32_t loop, uint64_t s2Start, + uint32_t s2LoopIdx, RunInfo &info) +{ + info.loop = loop; + info.bIdx = tempLoopInfo.bIdx; + info.gS1Idx = tempLoopInfo.gS1Idx; + info.s2Idx = s2LoopIdx; + info.curSInnerLoopTimes = tempLoopInfo.s2LoopTimes; + + info.tndIsS2SplitCore = tempLoopInfo.tndIsS2SplitCore; + info.tndCoreStartKVSplitPos = tempLoopInfo.tndCoreStartKVSplitPos; + info.isBmm2Output = false; + + info.actS1Size = tempLoopInfo.actS1Size; + + + info.actMBaseSize = constInfo.mBaseSize; + uint32_t remainedGS1Size = tempLoopInfo.actS1Size * constInfo.gSize - tempLoopInfo.gS1Idx; + if (remainedGS1Size <= constInfo.mBaseSize && remainedGS1Size > 0) { + info.actMBaseSize = tempLoopInfo.mBasicSizeTail; + } + + info.isValid = s2LoopIdx < tempLoopInfo.s2LoopTimes; + + if ASCEND_IS_AIV { + info.mSize = info.actMBaseSize; + info.mSizeV = (info.mSize <= 16) ? info.mSize : (((info.mSize + 15) / 16 + 1) / 2 * 16); + info.mSizeVStart = 0; + if (tmpBlockIdx % 2 == 1) { + info.mSizeVStart = info.mSizeV; + info.mSizeV = info.mSize - info.mSizeV; + } + } + + info.isChangeBatch = false; + + info.isFirstSInnerLoop = s2LoopIdx == s2Start; + if (info.isFirstSInnerLoop) { + tempLoopInfo.bn2IdxInCurCore++; + } + info.isLastS2Loop = s2LoopIdx == tempLoopInfo.s2LoopTimes - 1; + info.bn2IdxInCurCore = tempLoopInfo.bn2IdxInCurCore - 1; + uint64_t actualSeqQPrefixSum; + if constexpr (LAYOUT_T == SFA_LAYOUT::TND) { + actualSeqQPrefixSum = (info.bIdx <= 0) ? 0 : actualSeqLengthsQGm.GetValue(info.bIdx - 1); + } else { + actualSeqQPrefixSum = (info.bIdx <= 0) ? 0 : info.bIdx * constInfo.qSeqSize; + } + info.tndBIdxOffsetForQ = actualSeqQPrefixSum * constInfo.qHeadNum * headDim; + + uint64_t actualSeqKVPrefixSum; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::TND) { + actualSeqKVPrefixSum = (info.bIdx <= 0) ? 0 : actualSeqLengthsKVGm.GetValue(info.bIdx - 1); + } else { + actualSeqKVPrefixSum = (info.bIdx <= 0) ? 
0 : info.bIdx * constInfo.kvSeqSize; + } + info.tndBIdxOffsetForKV = actualSeqKVPrefixSum * constInfo.kvHeadNum * headDim; + + if (info.isFirstSInnerLoop) { + uint64_t tndBIdxRopeOffsetForQ = actualSeqQPrefixSum * constInfo.qHeadNum * headDimRope; + tensorACoreOffset = info.tndBIdxOffsetForQ + info.gS1Idx * headDim; + tensorARopeCoreOffset = tndBIdxRopeOffsetForQ + info.gS1Idx * headDimRope; + + uint64_t tndBIdxRopeOffsetForK = actualSeqKVPrefixSum * constInfo.kvHeadNum * headDimRope; + tensorBCoreOffset = info.tndBIdxOffsetForKV + info.n2Idx * headDim; + tensorBRopeCoreOffset = tndBIdxRopeOffsetForK + info.n2Idx * headDimRope; + if (constInfo.sparseMode == 3) { + threshold = static_cast(tempLoopInfo.nextTokensPerBatch) + info.gS1Idx / constInfo.gSize + 1; + } else { + threshold = tempLoopInfo.curActualSeqLenOri; + } + if constexpr(LAYOUT_T == SFA_LAYOUT::BSND) { // B,S1,N2 K + topKBaseOffset = info.bIdx * constInfo.qSeqSize * constInfo.kvHeadNum * constInfo.sparseBlockCount + + info.gS1Idx / constInfo.gSize * constInfo.kvHeadNum * constInfo.sparseBlockCount + + info.n2Idx * constInfo.sparseBlockCount; + } else if (LAYOUT_T == SFA_LAYOUT::TND) { // T N2 K + topKBaseOffset = info.tndBIdxOffsetForQ / constInfo.gSize / constInfo.headDim * constInfo.kvHeadNum * + constInfo.sparseBlockCount + info.n2Idx * constInfo.sparseBlockCount + + info.gS1Idx / constInfo.gSize * constInfo.kvHeadNum * constInfo.sparseBlockCount; + } else { // B N2 S1 K + topKBaseOffset = info.bIdx * constInfo.kvHeadNum * constInfo.qSeqSize * constInfo.sparseBlockCount + + info.n2Idx * constInfo.qSeqSize * constInfo.sparseBlockCount + + info.gS1Idx / constInfo.gSize * constInfo.sparseBlockCount; + } + } + info.topKBaseOffset = topKBaseOffset; + info.threshold = threshold; + info.tensorAOffset = tensorACoreOffset; + info.tensorARopeOffset = tensorARopeCoreOffset; + info.tensorBOffset = tensorBCoreOffset; + info.tensorBRopeOffset = tensorBRopeCoreOffset; + info.attenOutOffset = tensorACoreOffset; + + uint64_t sInnerOffsetDataSize = info.s2Idx * constInfo.s2BaseSize; + info.s2BatchOffset = s2BatchBaseOffset + sInnerOffsetDataSize; + + info.curActualSeqLenOri = tempLoopInfo.curActualSeqLenOri; + //计算实际基本块size + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + if (tempLoopInfo.curActualSeqLen > sInnerOffsetDataSize) { + info.actualSingleProcessSInnerSize = tempLoopInfo.curActualSeqLen - sInnerOffsetDataSize; + info.actualSingleProcessSInnerSize = info.actualSingleProcessSInnerSize > constInfo.s2BaseSize ? + constInfo.s2BaseSize : info.actualSingleProcessSInnerSize; + info.actualSingleProcessSInnerSize = + SFAAlign((int64_t)info.actualSingleProcessSInnerSize, (int64_t)constInfo.sparseBlockSize); + } else { + info.actualSingleProcessSInnerSize = 0; + } + info.actualSingleProcessSInnerSizeAlign = + SFAAlign((uint32_t)info.actualSingleProcessSInnerSize, (uint32_t)SFAVectorService::BYTE_BLOCK); + } + +} + +template +__aicore__ inline void SparseFlashAttentionMla::ComputeMm1(const RunInfo &info) +{ + uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize; + uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize; + for (uint32_t i = 0; i < nBufferLoopTimes; i++) { + MSplitInfo mSplitInfo; + mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize; + mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? 
constInfo.nBufferMBaseSize : nBufferTail; + matmulService.ComputeMm1(info, mSplitInfo); + CrossCoreSetFlag(constInfo.syncC1V1); + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::ComputeMm2(const RunInfo &info) +{ + uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize; + uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize; + for (uint32_t i = 0; i < nBufferLoopTimes; i++) { + MSplitInfo mSplitInfo; + mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize; + mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? constInfo.nBufferMBaseSize : nBufferTail; + CrossCoreWaitFlag(constInfo.syncV1C2); + matmulService.ComputeMm2(info, mSplitInfo); + CrossCoreSetFlag(constInfo.syncC2V2); + CrossCoreSetFlag(constInfo.syncC2V1); + } +} + +template __aicore__ inline void SparseFlashAttentionMla::Process() +{ + if (aiCoreIdx < usedCoreNum) { + if ASCEND_IS_AIV { + vectorService.AllocEventID(); + vectorService.InitSoftmaxDefaultBuffer(); + } else { + matmulService.AllocEventID(); + } + ProcessBalance(); + + if ASCEND_IS_AIV { + vectorService.FreeEventID(); + } else { + matmulService.FreeEventID(); + } + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::GetBN2Idx(uint32_t bN2Idx, uint32_t &bIdx, + uint32_t &n2Idx) +{ + bIdx = bN2Idx / kvHeadNum; + n2Idx = bN2Idx % kvHeadNum; +} + +template __aicore__ inline void SparseFlashAttentionMla::ProcessBalance() +{ + RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE]; + uint32_t gloop = 0; + int gS1LoopEnd; + bool globalLoopStart = true; + if ASCEND_IS_AIC { + CrossCoreSetFlag(constInfo.syncC2V1); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + CrossCoreSetFlag(3); + CrossCoreSetFlag(3); + CrossCoreSetFlag(3); + CrossCoreSetFlag(3); + } + } + for (uint32_t bN2LoopIdx = constInfo.bN2Start; bN2LoopIdx <= constInfo.bN2End; bN2LoopIdx++) { + GetBN2Idx(bN2LoopIdx, tempLoopInfo.bIdx, tempLoopInfo.n2Idx); + GetActualSeqLen(tempLoopInfo.bIdx); // 获取actualSeqLength及ActualSeqLengthKV + GetPreNextTokensLeftUp(); + if (tempLoopInfo.actS1Size == 0) { + continue; + } + int gS1SplitNum = (tempLoopInfo.actS1Size * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize; + gS1LoopEnd = (bN2LoopIdx == constInfo.bN2End) ? constInfo.gS1End : gS1SplitNum - 1; + for (uint32_t gS1LoopIdx = constInfo.gS1Start; gS1LoopIdx <= gS1LoopEnd; gS1LoopIdx++) { + tempLoopInfo.gS1Idx = gS1LoopIdx * constInfo.mBaseSize; + GetSparseActualSeqLen(tempLoopInfo.bIdx, gS1LoopIdx, tempLoopInfo.n2Idx); // TopK值sparse完后的ActualSeqLengthKV + UpdateInnerLoopCond(); + + if (tempLoopInfo.curActSeqLenIsZero) { + DealActSeqLenIsZero(tempLoopInfo.bIdx, gS1LoopIdx, tempLoopInfo.n2Idx); + } + int s2SplitNum = + (tempLoopInfo.curActualSeqLen + constInfo.s2BaseSize - 1) / constInfo.s2BaseSize; // S2切分份数 + bool isEnd = (bN2LoopIdx == constInfo.bN2End) && (gS1LoopIdx == constInfo.gS1End); + tempLoopInfo.s2LoopTimes = s2SplitNum; + // 分核修改后需要打开 + // 当前s2是否被切,决定了输出是否要写到attenOut上 + tempLoopInfo.tndIsS2SplitCore = + ((constInfo.s2Start == 0) && (tempLoopInfo.s2LoopTimes == s2SplitNum)) ? false : true; + tempLoopInfo.tndCoreStartKVSplitPos = globalLoopStart ? constInfo.coreStartKVSplitPos : 0; + uint32_t extraLoop = isEnd ? 
2 : 0; + + uint32_t curTopKIdx = 0; + uint64_t curOffsetInSparseBlock = 0; + for (int s2LoopIdx = constInfo.s2Start; s2LoopIdx < (tempLoopInfo.s2LoopTimes + extraLoop); s2LoopIdx++) { + // PreloadPipeline loop初始值要求为 PRELOAD_NUM + PreloadPipeline(gloop, constInfo.s2Start, s2LoopIdx, extraInfo, curTopKIdx, curOffsetInSparseBlock); + ++gloop; + } + globalLoopStart = false; + constInfo.s2Start = 0; + } + constInfo.gS1Start = 0; + } + if ASCEND_IS_AIV { + CrossCoreWaitFlag(constInfo.syncC2V1); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + CrossCoreWaitFlag(3); + CrossCoreWaitFlag(3); + CrossCoreWaitFlag(3); + CrossCoreWaitFlag(3); + } + } +} + +template +__aicore__ inline void +SparseFlashAttentionMla::PreloadPipeline(uint32_t loop, uint64_t s2Start, uint64_t s2LoopIdx, + RunInfo extraInfo[SFA_PRELOAD_TASK_CACHE_SIZE], uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock) +{ + RunInfo &extraInfo0 = extraInfo[loop % SFA_PRELOAD_TASK_CACHE_SIZE]; // 本轮任务 + RunInfo &extraInfo2 = extraInfo[(loop + 2) % SFA_PRELOAD_TASK_CACHE_SIZE]; // 上一轮任务 + RunInfo &extraInfo1 = extraInfo[(loop + 1) % SFA_PRELOAD_TASK_CACHE_SIZE]; // 上两轮任务 + + CalcParams(loop, s2Start, s2LoopIdx, extraInfo0); + CalcSinnerTopKBegin(extraInfo0, curTopKIdx, curOffsetInSparseBlock); + + if (extraInfo0.isValid) { + if ASCEND_IS_AIC { + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + CrossCoreWaitFlag(constInfo.syncV0C1); + } + ComputeMm1(extraInfo0); + } else { + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + CrossCoreWaitFlag(3); + vectorService.MergeKv(extraInfo0); + CrossCoreSetFlag(constInfo.syncV0C1); + } + } + } + if (extraInfo2.isValid) { + if ASCEND_IS_AIV { + vectorService.ProcessVec1L(extraInfo2); + } + if ASCEND_IS_AIC { + ComputeMm2(extraInfo2); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + CrossCoreSetFlag(3); + } + } + } + if (extraInfo1.isValid) { + if ASCEND_IS_AIV { + vectorService.ProcessVec2L(extraInfo1); + } + extraInfo1.isValid = false; + } +} + +template +__aicore__ inline uint64_t +SparseFlashAttentionMla::GetBalanceActualSeqLengths(GlobalTensor &actualSeqLengths, + uint32_t bIdx) +{ + if constexpr (LAYOUT_T == SFA_LAYOUT::TND) { + if (bIdx > 0) { + return actualSeqLengths.GetValue(bIdx) - actualSeqLengths.GetValue(bIdx - 1); + } else if (bIdx == 0) { + return actualSeqLengths.GetValue(0); + } else { + return 0; + } + } else { + if (constInfo.actualLenDimsQ == 0) { + return constInfo.qSeqSize; + } else if (constInfo.actualLenDimsQ == 1) { + return actualSeqLengths.GetValue(0); + } else { + return actualSeqLengths.GetValue(bIdx); + } + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::GetAxisStartIdx(uint32_t bN2EndPrev, + uint32_t s1GEndPrev, + uint32_t s2EndPrev) +{ + uint32_t bEndPrev = bN2EndPrev / kvHeadNum; + uint32_t actualSeqQPrev = GetBalanceActualSeqLengths(actualSeqLengthsQGm, bEndPrev); + uint32_t s1GPrevBaseNum = (actualSeqQPrev * constInfo.gSize + constInfo.mBaseSize - 1) / constInfo.mBaseSize; + constInfo.bN2Start = bN2EndPrev; + constInfo.gS1Start = s1GEndPrev; + + constInfo.s2Start = 0; + if (s1GEndPrev >= s1GPrevBaseNum - 1) { // 上个核把S1G处理完了 + constInfo.gS1Start = 0; + constInfo.bN2Start++; + } else { + constInfo.gS1Start++; + } +} + +template +__aicore__ inline void SparseFlashAttentionMla::CalcSinnerTopKBegin(RunInfo &info, uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock) + +{ + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + return; + } + + uint64_t thresholdSparseCount = (info.threshold + constInfo.sparseBlockSize - 1) / constInfo.sparseBlockSize; + uint64_t 
validCount = (constInfo.sparseBlockCount > thresholdSparseCount) ? thresholdSparseCount : constInfo.sparseBlockCount; + + int32_t sparseIndices = topKGm.GetValue(info.topKBaseOffset + curTopKIdx); + if (sparseIndices == -1 || curTopKIdx == validCount) { + info.actualSingleProcessSInnerSize = 0; + info.actualSingleProcessSInnerSizeAlign = 0; + tempLoopInfo.s2BasicSizeTail = 0; + if (curTopKIdx == 0) { + DealActSeqLenIsZero(info.bIdx, info.gS1Idx / constInfo.gSize, tempLoopInfo.n2Idx); + } + return; + } + + uint32_t sparseLen = 0; + uint64_t blockBegin = sparseIndices * constInfo.sparseBlockSize; + uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? info.threshold : blockBegin + constInfo.sparseBlockSize; + int32_t blockLen = blockEnd - blockBegin; + sparseLen += (blockLen > static_cast(curOffsetInSparseBlock)) ? blockLen - curOffsetInSparseBlock : 0; + + bool firstVaildFlag = false; + if (curTopKIdx > 0) { + info.curTopKIdx = curTopKIdx; + info.curOffsetInSparseBlock = curOffsetInSparseBlock; + } else if (curTopKIdx == 0 && sparseLen > 0) { + info.curTopKIdx = curTopKIdx; + info.curOffsetInSparseBlock = 0; + firstVaildFlag = true; + } + + for (uint64_t topkIdx = curTopKIdx + 1; topkIdx < validCount; topkIdx++) { + int32_t sparseIndices = topKGm.GetValue(info.topKBaseOffset + topkIdx); + if (sparseIndices == -1) { + curTopKIdx = topkIdx; + curOffsetInSparseBlock = 0; + break; + } + uint64_t blockBegin = sparseIndices * constInfo.sparseBlockSize; + if (blockBegin >= info.threshold) { + continue; + } + if (firstVaildFlag == false && curTopKIdx == 0) { + info.curTopKIdx = topkIdx; + info.curOffsetInSparseBlock = 0; + firstVaildFlag = true; + } + uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? info.threshold : blockBegin + constInfo.sparseBlockSize; + uint64_t blockLen = blockEnd - blockBegin; + sparseLen += blockLen; + if (sparseLen >= constInfo.s2BaseSize) { + curTopKIdx = topkIdx; + curOffsetInSparseBlock = blockLen - (sparseLen - constInfo.s2BaseSize); + sparseLen = constInfo.s2BaseSize; + break; + } + + if (topkIdx == validCount - 1) { + curTopKIdx = validCount; + curOffsetInSparseBlock = 0; + } + } + + info.actualSingleProcessSInnerSize = sparseLen; + info.actualSingleProcessSInnerSizeAlign = SFAAlign((uint32_t)info.actualSingleProcessSInnerSize, (uint32_t)SFAVectorService::BYTE_BLOCK); + tempLoopInfo.s2BasicSizeTail = (sparseLen == constInfo.s2BaseSize) ? 0 : sparseLen; + if (curTopKIdx == 0 && sparseLen == 0) { + DealActSeqLenIsZero(info.bIdx, info.gS1Idx / constInfo.gSize, tempLoopInfo.n2Idx); + } +} + + + +#endif // SPARSE_FLASH_ATTENTION_KERNEL_MLA_H \ No newline at end of file diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_cube_mla.h b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_cube_mla.h new file mode 100644 index 00000000000..53b65c0bb8a --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_cube_mla.h @@ -0,0 +1,1125 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file sparse_flash_attention_service_cube_mla.h
+ * \brief use 7 buffers for matmul L1, for a better pipeline
+ */
+#ifndef SPARSE_FLASH_ATTENTION_SERVICE_CUBE_MLA_H
+#define SPARSE_FLASH_ATTENTION_SERVICE_CUBE_MLA_H
+
+#include "kernel_operator.h"
+#include "kernel_operator_list_tensor_intf.h"
+#include "kernel_tiling/kernel_tiling.h"
+#include "lib/matmul_intf.h"
+#include "lib/matrix/matmul/tiling.h"
+#include "sparse_flash_attention_common.h"
+
+struct PAShape {
+    uint32_t blockSize;
+    uint32_t headNum;             // usually the KV head num; corresponds to n2
+    uint32_t headDim;             // under MLA the rope part is 64 and the nope part is 512; corresponds to d
+    uint32_t maxblockNumPerBatch; // max number of entries in one block-table row
+    uint32_t actHeadDim;          // actual number of columns copied (s*d with N-direction tiling considered); corresponds to d
+    uint32_t copyRowNum;          // total number of rows to copy
+    uint32_t copyRowNumAlign;
+};
+
+struct Position {
+    uint32_t bIdx;
+    uint32_t n2Idx;
+    uint32_t s2Idx;
+    uint32_t dIdx;
+};
+
+// Scenario: copy query, queryRope, key, value from GM to L1
+// GM is stored in ND format
+// L1 is stored in NZ format
+template
+__aicore__ inline void DataCopyGmNDToL1(LocalTensor &l1Tensor, GlobalTensor &gmTensor,
+                                        uint32_t rowAct,
+                                        uint32_t rowAlign,
+                                        uint32_t col,       // D
+                                        uint32_t colStride) // D or N*D
+{
+    Nd2NzParams nd2nzPara;
+    nd2nzPara.ndNum = 1;
+    nd2nzPara.nValue = rowAct; // number of rows of the ND matrix
+    // when T is int4, dValue = col / 2 and srcDValue = colStride / 2
+    nd2nzPara.dValue = col; // number of columns of the ND matrix
+    nd2nzPara.srcDValue = colStride; // offset between the start addresses of adjacent rows of the same ND matrix
+    nd2nzPara.dstNzC0Stride = rowAlign;
+    nd2nzPara.dstNzNStride = 1;
+    nd2nzPara.srcNdMatrixStride = 0;
+    nd2nzPara.dstNzMatrixStride = 0;
+    DataCopy(l1Tensor, gmTensor, nd2nzPara);
+}
+
+/*
+    Copies page-attention (PA) data from GM to L1; supports ND and NZ data.
+    The PA layout is either BNBD (blockNum, N, blockSize, D) or BBH (blockNum, blockSize, N*D);
+    BSH / BSND / TND are handled as BBH.
+    shape.copyRowNumAlign must be aligned to 16, e.g. when copying the K matrix 128*512 at a time,
+    a 10*512 tail block must be padded to 16*512.
+*/
+template
+__aicore__ inline void DataCopyPA(LocalTensor &dstTensor,       // L1
+                                  GlobalTensor &srcTensor,      // GM
+                                  GlobalTensor &blockTableGm,
+                                  const PAShape &shape,         // blockSize, headNum, headDim
+                                  const Position &startPos)     // batchIdx nIdx curSeqIdx
+{
+    uint32_t copyFinishRowCnt = 0;
+    uint64_t blockTableBaseOffset = startPos.bIdx * shape.maxblockNumPerBatch;
+    uint32_t curS2Idx = startPos.s2Idx;
+    uint32_t blockElementCnt = 32 / sizeof(T);
+    while (copyFinishRowCnt < shape.copyRowNum) {
+        uint64_t blockIdOffset = curS2Idx / shape.blockSize; // index into the block table
+        uint64_t remainRowCnt = curS2Idx % shape.blockSize; // row offset inside the current block
+        uint64_t idInBlockTable = blockTableGm.GetValue(blockTableBaseOffset + blockIdOffset); // physical block id read from the block table
+        // compute how many rows can be copied in this pass
+        uint32_t copyRowCnt = shape.blockSize - remainRowCnt; // only one block can be handled per pass
+        if (copyFinishRowCnt + copyRowCnt > shape.copyRowNum) {
+            copyRowCnt = shape.copyRowNum - copyFinishRowCnt; // the last block is only partially copied
+        }
+        uint64_t offset = idInBlockTable * shape.blockSize * shape.headNum * shape.headDim; // offset into the PA storage
+
+        uint64_t dStride = shape.headDim;
+        if constexpr (SRC_LAYOUT == SFA_LAYOUT::BSND || SRC_LAYOUT == SFA_LAYOUT::TND) {
+            offset += (uint64_t)(startPos.n2Idx * shape.headDim) +
+                      remainRowCnt * shape.headDim * shape.headNum + startPos.dIdx;
+            dStride = shape.headDim * shape.headNum;
+        } else {
+            offset += (uint64_t)(startPos.n2Idx * shape.headDim * shape.blockSize) +
+                      remainRowCnt * shape.headDim + startPos.dIdx;
+        }
+
uint32_t dValue = shape.actHeadDim; + uint32_t srcDValue = dStride; + LocalTensor tmpDstTensor = dstTensor[copyFinishRowCnt * blockElementCnt]; + GlobalTensor tmpSrcTensor = srcTensor[offset]; + + DataCopyGmNDToL1(tmpDstTensor, tmpSrcTensor, copyRowCnt, shape.copyRowNumAlign, dValue, srcDValue); + copyFinishRowCnt += copyRowCnt; + curS2Idx += copyRowCnt; + } +} + +template class SFAMatmulService { +public: + // 中间计算数据类型为float, 高精度模式 + using T = float; + using Q_T = typename SFAT::queryType; + using KV_T = typename SFAT::kvType; + using OUT_T = typename SFAT::outputType; + using MM_OUT_T = T; + + __aicore__ inline SFAMatmulService(){}; + __aicore__ inline void InitParams(const ConstInfo &constInfo); + __aicore__ inline void InitMm1GlobalTensor(GlobalTensor queryGm, GlobalTensor qRopeGm, + GlobalTensor keyGm, GlobalTensor kRopeGm, + GlobalTensor mm1ResGm); + __aicore__ inline void InitMm2GlobalTensor(GlobalTensor vec1ResGm, GlobalTensor valueGm, + GlobalTensor mm2ResGm, GlobalTensor attentionOutGm); + __aicore__ inline void InitPageAttentionInfo(const GlobalTensor& kvMergeGm, + GlobalTensor blockTableGm, GlobalTensor topKGm, + uint32_t blockSize, uint32_t maxBlockNumPerBatch); + __aicore__ inline void InitBuffers(TPipe *pipe); + __aicore__ inline void UpdateKey(GlobalTensor keyGm); + __aicore__ inline void UpdateValue(GlobalTensor valueGm); + + __aicore__ inline void AllocEventID(); + __aicore__ inline void FreeEventID(); + __aicore__ inline void CalcTopKBlockInfo(const RunInfo &info, uint32_t &curTopKIdx, + uint64_t &curOffsetInSparseBlock, uint32_t curSeqIdx, + uint32_t ©RowCnt, int64_t &idInTopK); + __aicore__ inline void ComputeMm1(const RunInfo &info, const MSplitInfo mSplitInfo); + __aicore__ inline void ComputeMm2(const RunInfo &info, const MSplitInfo mSplitInfo); + +private: + static constexpr bool PAGE_ATTENTION = SFAT::pageAttention; + static constexpr int TEMPLATE_MODE = SFAT::templateMode; + static constexpr bool FLASH_DECODE = SFAT::flashDecode; + static constexpr SFA_LAYOUT LAYOUT_T = SFAT::layout; + static constexpr SFA_LAYOUT KV_LAYOUT_T = SFAT::kvLayout; + + static constexpr uint32_t M_SPLIT_SIZE = 128; // m方向切分 + static constexpr uint32_t N_SPLIT_SIZE = 128; // n方向切分 + static constexpr uint32_t N_WORKSPACE_SIZE = 512; // n方向切分 + + static constexpr uint32_t L1_BLOCK_SIZE = (64 * (512 + 64) * sizeof(Q_T)); + static constexpr uint32_t L1_BLOCK_OFFSET = 64 * (512 + 64); // 72K的元素个数 + + static constexpr uint32_t L0A_PP_SIZE = (32 * 1024); + static constexpr uint32_t L0B_PP_SIZE = (32 * 1024); + static constexpr uint32_t L0C_PP_SIZE = (64 * 1024); + + // mte2 <> mte1 EventID + // L1 3buf, 使用3个eventId + static constexpr uint32_t L1_EVENT0 = EVENT_ID2; + static constexpr uint32_t L1_EVENT1 = EVENT_ID3; + static constexpr uint32_t L1_EVENT2 = EVENT_ID4; + static constexpr uint32_t L1_EVENT3 = EVENT_ID5; + static constexpr uint32_t L1_EVENT4 = EVENT_ID6; + static constexpr uint32_t L1_EVENT5 = EVENT_ID7; + static constexpr uint32_t L1_EVENT6 = EVENT_ID1; + + // m <> mte1 EventID + static constexpr uint32_t L0AB_EVENT0 = EVENT_ID3; + static constexpr uint32_t L0AB_EVENT1 = EVENT_ID4; + + static constexpr IsResetLoad3dConfig LOAD3DV2_CONFIG = {true, true}; // isSetFMatrix isSetPadding; + static constexpr uint32_t mte21QPIds[4] = {L1_EVENT0, L1_EVENT1, L1_EVENT2, L1_EVENT3}; // mte12复用 + static constexpr uint32_t mte21KVIds[3] = {L1_EVENT4, L1_EVENT5, L1_EVENT6}; + + uint32_t kvCacheBlockSize = 0; + uint32_t maxBlockNumPerBatch = 0; + ConstInfo constInfo{}; + + // L1分成3块buf, 用于记录 + 
uint32_t qpL1BufIter = 0; + uint32_t kvL1BufIter = -1; + uint32_t abL0BufIter = 0; + uint32_t cL0BufIter = 0; + + // mm1 + GlobalTensor queryGm; + GlobalTensor qRopeGm; + GlobalTensor keyGm; + GlobalTensor kRopeGm; + GlobalTensor mm1ResGm; + GlobalTensor kvMergeGm_; + + // mm2 + GlobalTensor vec1ResGm; + GlobalTensor valueGm; + GlobalTensor mm2ResGm; + GlobalTensor attentionOutGm; + + // block_table + GlobalTensor blockTableGm; + GlobalTensor topKGm; + + TBuf bufQPL1; + TBuf bufKVL1; + TBuf tmpBufL0A; + TBuf tmpBufL0B; + TBuf tmpBufL0C; + + LocalTensor l1QPTensor; + LocalTensor l1KVTensor; + LocalTensor aL0TensorPingPong; + LocalTensor bL0TensorPingPong; + LocalTensor cL0TensorPingPong; + + // L0AB m <> mte1 EventID + __aicore__ inline uint32_t Mte1MmABEventId(uint32_t idx) + { + return (L0AB_EVENT0 + idx); + } + + __aicore__ inline uint32_t GetQPL1RealIdx(uint32_t mIdx, uint32_t k1Idx) + { + uint32_t idxMap[] = {0, 2}; // 确保0块和1块连在一起, 2和3块连在一起, 来保证同一m块的地址相连 + return idxMap[mIdx % 2] + k1Idx; + } + + __aicore__ inline void CopyGmToL1(LocalTensor &l1Tensor, GlobalTensor &gmSrcTensor, uint32_t srcN, + uint32_t srcD, uint32_t srcDstride); + __aicore__ inline void CopyInMm1AToL1(LocalTensor &aL1Tensor, const RunInfo &info, uint32_t mSeqIdx, + uint32_t mSizeAct, uint32_t headSize, uint32_t headOffset); + __aicore__ inline void CopyInMm1ARopeToL1(LocalTensor &aL1Tensor, const RunInfo &info, uint32_t mSeqIdx, + uint32_t mSizeAct); + __aicore__ inline void CopyInMm1BToL1(LocalTensor &bL1Tensor, const uint64_t keyGmBaseOffset, + uint32_t copyTotalRowCntAlign, uint32_t copyStartRowCnt, + uint32_t nActCopyRowCount, uint32_t headSize); + __aicore__ inline void CopyInMm1BRopeToL1(LocalTensor &bL1Tensor, const uint64_t keyGmBaseOffset, + uint32_t copyTotalRowCntAlign, uint32_t copyStartRowCnt, + uint32_t nActCopyRowCount, uint32_t headSize); + __aicore__ inline void CopyInMm2AToL1(LocalTensor &aL1Tensor, const RunInfo &info, uint32_t mSeqIdx, + uint32_t subMSizeAct, uint32_t nSize, uint32_t nOffset); + __aicore__ inline void CopyInMm2BToL1(LocalTensor &bL1Tensor, const uint64_t valueGmBaseOffset, + uint32_t copyTotalRowCntAlign, uint32_t copyStartRowCnt, + uint32_t nActCopyRowCount, uint32_t copyStartColumnCount, + uint32_t copyColumnCount); + __aicore__ inline void LoadDataMm1A(LocalTensor &aL0Tensor, LocalTensor &aL1Tensor, uint32_t idx, + uint32_t kSplitSize, uint32_t mSize, uint32_t kSize); + __aicore__ inline void LoadDataMm1B(LocalTensor &bL0Tensor, LocalTensor &bL1Tensor, uint32_t idx, + uint32_t kSplitSize, uint32_t kSize, uint32_t nSize); +}; + +template __aicore__ inline void SFAMatmulService::InitParams(const ConstInfo &constInfo) +{ + this->constInfo = constInfo; +} + +template +__aicore__ inline void +SFAMatmulService::InitMm1GlobalTensor(GlobalTensor queryGm, GlobalTensor qRopeGm, + GlobalTensor keyGm, GlobalTensor kRopeGm, + GlobalTensor mm1ResGm) +{ + // mm1 + this->queryGm = queryGm; + this->qRopeGm = qRopeGm; + this->keyGm = keyGm; + this->kRopeGm = kRopeGm; + this->mm1ResGm = mm1ResGm; +} + +template +__aicore__ inline void +SFAMatmulService::InitMm2GlobalTensor(GlobalTensor vec1ResGm, GlobalTensor valueGm, + GlobalTensor mm2ResGm, GlobalTensor attentionOutGm) +{ + // mm2 + this->vec1ResGm = vec1ResGm; + this->valueGm = valueGm; + this->mm2ResGm = mm2ResGm; + this->attentionOutGm = attentionOutGm; +} + +template +__aicore__ inline void +SFAMatmulService::InitPageAttentionInfo(const GlobalTensor& kvMergeGm, GlobalTensor blockTableGm, + GlobalTensor topKGm, uint32_t blockSize, uint32_t 
maxBlockNumPerBatch) +{ + this->blockTableGm = blockTableGm; + this->topKGm = topKGm; + this->kvCacheBlockSize = blockSize; + this->maxBlockNumPerBatch = maxBlockNumPerBatch; + this->kvMergeGm_ = kvMergeGm; +} + +template __aicore__ inline void SFAMatmulService::InitBuffers(TPipe *pipe) +{ + pipe->InitBuffer(bufQPL1, L1_BLOCK_SIZE * 4); // (64K + 8K) * 4 + l1QPTensor = bufQPL1.Get(); + pipe->InitBuffer(bufKVL1, L1_BLOCK_SIZE * 3); // (64K + 8K) * 3 + l1KVTensor = bufKVL1.Get(); + + // L0A + pipe->InitBuffer(tmpBufL0A, L0A_PP_SIZE * 2); // 64K + aL0TensorPingPong = tmpBufL0A.Get(); + // L0B + pipe->InitBuffer(tmpBufL0B, L0B_PP_SIZE * 2); // 64K + bL0TensorPingPong = tmpBufL0B.Get(); + // L0C + pipe->InitBuffer(tmpBufL0C, L0C_PP_SIZE * 2); // 128K + cL0TensorPingPong = tmpBufL0C.Get(); +} + +template __aicore__ inline void SFAMatmulService::UpdateKey(GlobalTensor keyGm) +{ + this->keyGm = keyGm; +} + +template __aicore__ inline void SFAMatmulService::UpdateValue(GlobalTensor valueGm) +{ + this->valueGm = valueGm; +} + +template __aicore__ inline void SFAMatmulService::AllocEventID() +{ + SetFlag(L1_EVENT0); + SetFlag(L1_EVENT1); + SetFlag(L1_EVENT2); + SetFlag(L1_EVENT3); + SetFlag(L1_EVENT4); + SetFlag(L1_EVENT5); + SetFlag(L1_EVENT6); + SetFlag(L0AB_EVENT0); + SetFlag(L0AB_EVENT1); +} + +template __aicore__ inline void SFAMatmulService::FreeEventID() +{ + WaitFlag(L1_EVENT0); + WaitFlag(L1_EVENT1); + WaitFlag(L1_EVENT2); + WaitFlag(L1_EVENT3); + WaitFlag(L1_EVENT4); + WaitFlag(L1_EVENT5); + WaitFlag(L1_EVENT6); + WaitFlag(L0AB_EVENT0); + WaitFlag(L0AB_EVENT1); +} + +template +__aicore__ inline void SFAMatmulService::CopyGmToL1(LocalTensor &l1Tensor, + GlobalTensor &gmSrcTensor, uint32_t srcN, + uint32_t srcD, uint32_t srcDstride) +{ + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = srcN; // 行数 + nd2nzPara.dValue = srcD; + nd2nzPara.srcDValue = srcDstride; + nd2nzPara.dstNzC0Stride = (srcN + 15) / 16 * 16; // 对齐到16 单位block + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(l1Tensor, gmSrcTensor, nd2nzPara); +} + +template +__aicore__ inline void SFAMatmulService::CopyInMm1AToL1(LocalTensor &l1Tensor, const RunInfo &info, + uint32_t mSeqIdx, uint32_t mSizeAct, + uint32_t headSize, uint32_t headOffset) +{ + auto srcGm = queryGm[info.tensorAOffset + mSeqIdx * constInfo.headDim + headOffset]; + CopyGmToL1(l1Tensor, srcGm, mSizeAct, headSize, constInfo.headDim); +} + +template +__aicore__ inline void SFAMatmulService::CopyInMm1ARopeToL1(LocalTensor &l1Tensor, + const RunInfo &info, uint32_t mSeqIdx, + uint32_t mSizeAct) +{ + auto srcGm = qRopeGm[info.tensorARopeOffset + mSeqIdx * constInfo.headDimRope]; + CopyGmToL1(l1Tensor, srcGm, mSizeAct, constInfo.headDimRope, constInfo.headDimRope); +} + +template +__aicore__ inline void +SFAMatmulService::CopyInMm1BToL1(LocalTensor &bL1Tensor, const uint64_t keyGmBaseOffset, + uint32_t copyTotalRowCntAlign, uint32_t copyStartRowCnt, + uint32_t nActCopyRowCount, uint32_t headSize) +{ + uint64_t dStride = constInfo.headDim; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::BSND || KV_LAYOUT_T == SFA_LAYOUT::TND) { + dStride = constInfo.headDim * constInfo.kvHeadNum; + } + + uint32_t blockElementCnt = 32 / sizeof(KV_T); + + Nd2NzParams mm1Nd2NzParamsForB; + mm1Nd2NzParamsForB.ndNum = 1; + mm1Nd2NzParamsForB.nValue = nActCopyRowCount; + mm1Nd2NzParamsForB.dValue = headSize; + mm1Nd2NzParamsForB.srcDValue = dStride; + mm1Nd2NzParamsForB.dstNzC0Stride = copyTotalRowCntAlign; + 
mm1Nd2NzParamsForB.dstNzNStride = 1; + mm1Nd2NzParamsForB.srcNdMatrixStride = 0; + mm1Nd2NzParamsForB.dstNzMatrixStride = 0; + DataCopy(bL1Tensor[copyStartRowCnt * blockElementCnt], keyGm[keyGmBaseOffset], mm1Nd2NzParamsForB); +} + +template +__aicore__ inline void +SFAMatmulService::CopyInMm1BRopeToL1(LocalTensor &bL1Tensor, const uint64_t kRopeGmBaseOffset, + uint32_t copyTotalRowCntAlign, uint32_t copyStartRowCnt, + uint32_t nActCopyRowCount, uint32_t headSize) +{ + uint64_t dStride = constInfo.headDimRope; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::BSND || KV_LAYOUT_T == SFA_LAYOUT::TND) { + dStride = constInfo.headDimRope * constInfo.kvHeadNum; + } + + uint32_t blockElementCnt = 32 / sizeof(KV_T); + + Nd2NzParams mm1Nd2NzParamsForB; + mm1Nd2NzParamsForB.ndNum = 1; + mm1Nd2NzParamsForB.nValue = nActCopyRowCount; + mm1Nd2NzParamsForB.dValue = headSize; + mm1Nd2NzParamsForB.srcDValue = dStride; + mm1Nd2NzParamsForB.dstNzC0Stride = copyTotalRowCntAlign; + mm1Nd2NzParamsForB.dstNzNStride = 1; + mm1Nd2NzParamsForB.srcNdMatrixStride = 0; + mm1Nd2NzParamsForB.dstNzMatrixStride = 0; + DataCopy(bL1Tensor[copyStartRowCnt * blockElementCnt], kRopeGm[kRopeGmBaseOffset], mm1Nd2NzParamsForB); +} + +template +__aicore__ inline void SFAMatmulService::LoadDataMm1A(LocalTensor &aL0Tensor, + LocalTensor &aL1Tensor, uint32_t idx, + uint32_t kSplitSize, uint32_t mSize, uint32_t kSize) +{ + LocalTensor srcTensor = aL1Tensor[mSize * kSplitSize * idx]; + LoadData3DParamsV2 loadData3DParams; + // SetFmatrixParams + loadData3DParams.l1H = mSize / 16; // Hin=M1=8 + loadData3DParams.l1W = 16; // Win=M0 + loadData3DParams.padList[0] = 0; + loadData3DParams.padList[1] = 0; + loadData3DParams.padList[2] = 0; + loadData3DParams.padList[3] = 255; // 尾部数据不影响滑窗的结果 + + // SetLoadToA0Params + loadData3DParams.mExtension = mSize; // M + loadData3DParams.kExtension = kSize; // K + loadData3DParams.mStartPt = 0; + loadData3DParams.kStartPt = 0; + loadData3DParams.strideW = 1; + loadData3DParams.strideH = 1; + loadData3DParams.filterW = 1; + loadData3DParams.filterSizeW = (1 >> 8) & 255; + loadData3DParams.filterH = 1; + loadData3DParams.filterSizeH = (1 >> 8) & 255; + loadData3DParams.dilationFilterW = 1; + loadData3DParams.dilationFilterH = 1; + loadData3DParams.enTranspose = 0; + loadData3DParams.fMatrixCtrl = 0; + loadData3DParams.channelSize = kSize; // Cin=K + LoadData(aL0Tensor, srcTensor, loadData3DParams); +} + +template +__aicore__ inline void SFAMatmulService::LoadDataMm1B(LocalTensor &l0Tensor, + LocalTensor &l1Tensor, uint32_t idx, + uint32_t kSplitSize, uint32_t kSize, uint32_t nSize) +{ + // N 方向全载 + LocalTensor srcTensor = l1Tensor[nSize * kSplitSize * idx]; + + LoadData2DParams loadData2DParams; + loadData2DParams.startIndex = 0; + loadData2DParams.repeatTimes = (nSize + 15) / 16 * kSize / (32 / sizeof(KV_T)); + loadData2DParams.srcStride = 1; + loadData2DParams.dstGap = 0; + loadData2DParams.ifTranspose = false; + LoadData(l0Tensor, srcTensor, loadData2DParams); +} + +template +__aicore__ inline void SFAMatmulService::CopyInMm2AToL1(LocalTensor &aL1Tensor, const RunInfo &info, + uint32_t mSeqIdx, uint32_t subMSizeAct, + uint32_t nSize, uint32_t nOffset) +{ + auto srcGm = vec1ResGm[(info.loop % constInfo.preLoadNum) * constInfo.mmResUbSize + + mSeqIdx * info.actualSingleProcessSInnerSizeAlign + nOffset]; + CopyGmToL1(aL1Tensor, srcGm, subMSizeAct, nSize, info.actualSingleProcessSInnerSizeAlign); +} + +template +__aicore__ inline void SFAMatmulService::CopyInMm2BToL1( + LocalTensor &bL1Tensor, const uint64_t 
valueGmBaseOffset, uint32_t copyTotalRowCntAlign, + uint32_t copyStartRowCnt, uint32_t nActCopyRowCount, uint32_t copyStartColumnCount, uint32_t copyColumnCount) +{ + uint64_t step = constInfo.headDim; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::BSND || KV_LAYOUT_T == SFA_LAYOUT::TND) { + step = constInfo.headDim * constInfo.kvHeadNum; + } + + uint32_t blockElementCnt = 32 / sizeof(KV_T); + + Nd2NzParams mm1Nd2NzParamsForB; + mm1Nd2NzParamsForB.ndNum = 1; + mm1Nd2NzParamsForB.nValue = nActCopyRowCount; + mm1Nd2NzParamsForB.dValue = copyColumnCount; + mm1Nd2NzParamsForB.srcDValue = step; + mm1Nd2NzParamsForB.dstNzC0Stride = copyTotalRowCntAlign; + mm1Nd2NzParamsForB.dstNzNStride = 1; + mm1Nd2NzParamsForB.srcNdMatrixStride = 0; + mm1Nd2NzParamsForB.dstNzMatrixStride = 0; + DataCopy(bL1Tensor[copyStartRowCnt * blockElementCnt], valueGm[valueGmBaseOffset + copyStartColumnCount], + mm1Nd2NzParamsForB); +} + +template +__aicore__ inline void SFAMatmulService::CalcTopKBlockInfo( + const RunInfo &info, uint32_t &curTopKIdx, uint64_t &curOffsetInSparseBlock, uint32_t curSeqIdx, uint32_t ©RowCnt, int64_t &idInTopK) +{ + uint64_t blockBegin = idInTopK * constInfo.sparseBlockSize; + uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? + info.threshold : blockBegin + constInfo.sparseBlockSize; + uint64_t blockLen = blockEnd - blockBegin; + if (curOffsetInSparseBlock + copyRowCnt < blockLen) { + curOffsetInSparseBlock += copyRowCnt; + copyRowCnt = blockLen - curOffsetInSparseBlock; + } else { + for (uint64_t topkidx = curTopKIdx + 1; topkidx < constInfo.sparseBlockCount; topkidx++) { + int64_t sparseIndices = topKGm.GetValue(info.topKBaseOffset + topkidx); + if (sparseIndices == -1) { + break; + } + + uint64_t blockBegin = sparseIndices * constInfo.sparseBlockSize; + if (blockBegin >= info.threshold) { + continue; + } + uint64_t blockEnd = (blockBegin + constInfo.sparseBlockSize > info.threshold) ? 
+ info.threshold : blockBegin + constInfo.sparseBlockSize; + uint64_t blockLen = blockEnd - blockBegin; + curTopKIdx = topkidx; + idInTopK = sparseIndices; + curOffsetInSparseBlock = 0; + copyRowCnt = blockLen; + break; + } + } +} + +template +__aicore__ inline void SFAMatmulService::ComputeMm1(const RunInfo &info, const MSplitInfo mSplitInfo) +{ + // 最外层还需要一层m的循环 + uint32_t mSize = mSplitInfo.nBufferDealM; + uint32_t mL1Size = M_SPLIT_SIZE; + uint32_t mL1SizeAlign = SFAAlign(M_SPLIT_SIZE, 16U); + uint32_t mL1Loops = (mSize + M_SPLIT_SIZE - 1) / M_SPLIT_SIZE; + + uint32_t nSize = info.actualSingleProcessSInnerSize; + uint32_t nL1Size = N_SPLIT_SIZE; + uint32_t nL1SizeAlign = SFAAlign(N_SPLIT_SIZE, 16U); + uint32_t nL1Loops = (nSize + N_SPLIT_SIZE - 1) / N_SPLIT_SIZE; + + uint32_t kSize = 576; + uint32_t kL1Size = 288; + uint32_t kL1Loops = 2; // 2 : 576/288, mla专用 这里不考虑d泛化 + + uint32_t kL0Size = 96; + uint32_t kL0Loops = (kL1Size + kL0Size - 1) / kL0Size; // 288 / 96 = 3 kloops + + LocalTensor bL1Tensor; + LocalTensor kRopeTensor; + LocalTensor kTensor; + // ka表示左矩阵4buf选择哪一块buf, kb表示右矩阵3buf选择哪一块buf + uint32_t ka = 0, kb = 0; + + uint32_t curTopKIdx = info.curTopKIdx; + uint64_t curOffsetInSparseBlock = info.curOffsetInSparseBlock; //sparse Block块内偏移 + uint32_t copyRowCnt = 0; + int64_t idInTopK = topKGm.GetValue(info.topKBaseOffset + curTopKIdx); + + uint32_t curTopKIdxTmp = 0; + uint64_t curOffsetInSparseBlockTmp = 0; + uint32_t copyRowCntTmp = 0; + int64_t idInTopKTmp = 0; + + // L1 切n切k切m + for (uint32_t nL1 = 0; nL1 < nL1Loops; nL1++) { // L1切n, 512/128=4 + if (nL1 == (nL1Loops - 1)) { + // 尾块重新计算size + nL1Size = nSize - (nL1Loops - 1) * N_SPLIT_SIZE; + nL1SizeAlign = SFAAlign(nL1Size, 16U); + } + curTopKIdxTmp = curTopKIdx; + curOffsetInSparseBlockTmp = curOffsetInSparseBlock; + copyRowCntTmp = copyRowCnt; + idInTopKTmp = idInTopK; + + for (uint32_t kL1 = 0; kL1 < kL1Loops; kL1++) { // L1切k, 576/288, 这里不考虑d泛化 + kvL1BufIter++; + uint32_t kb = kvL1BufIter % 3; + WaitFlag(mte21KVIds[kb]); + // 从k当中取当前的块 + bL1Tensor = l1KVTensor[kb * L1_BLOCK_OFFSET]; + // mm1拷贝主流程 + + uint32_t curSeqIdx = info.s2BatchOffset + nL1 * N_SPLIT_SIZE; + uint32_t copyFinishRowCnt = 0; + curTopKIdx = curTopKIdxTmp; + curOffsetInSparseBlock = curOffsetInSparseBlockTmp; + copyRowCnt = copyRowCntTmp; + idInTopK = idInTopKTmp; + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + if (kL1 == 0) { + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = nL1Size; // 行数 + nd2nzPara.dValue = constInfo.headDim >> 1; // constInfo.headDim; + nd2nzPara.srcDValue = constInfo.headDim; + nd2nzPara.dstNzC0Stride = nL1SizeAlign; + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(bL1Tensor, + kvMergeGm_[info.loop % 4 * N_WORKSPACE_SIZE * kSize + + nL1 * N_SPLIT_SIZE * constInfo.headDim], + nd2nzPara); + nd2nzPara.dValue = constInfo.headDimRope >> 1; + nd2nzPara.srcDValue = constInfo.headDimRope; + DataCopy( + bL1Tensor[nL1SizeAlign * (constInfo.headDim >> 1)], + kvMergeGm_[info.loop % 4 * N_WORKSPACE_SIZE * kSize + N_WORKSPACE_SIZE * constInfo.headDim + + nL1 * N_SPLIT_SIZE * constInfo.headDimRope], + nd2nzPara); + } else { + LocalTensor kTmpTensor = bL1Tensor[(constInfo.headDimRope >> 1) * nL1SizeAlign]; + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = nL1Size; // 行数 + nd2nzPara.dValue = constInfo.headDim >> 1; // constInfo.headDim; + nd2nzPara.srcDValue = constInfo.headDim; + nd2nzPara.dstNzC0Stride = nL1SizeAlign; + nd2nzPara.dstNzNStride = 
1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(kTmpTensor, + kvMergeGm_[info.loop % 4 * N_WORKSPACE_SIZE * kSize + (constInfo.headDim >> 1) + + nL1 * N_SPLIT_SIZE * constInfo.headDim], + nd2nzPara); + nd2nzPara.dValue = constInfo.headDimRope >> 1; + nd2nzPara.srcDValue = constInfo.headDimRope; + DataCopy( + bL1Tensor, + kvMergeGm_[info.loop % 4 * N_WORKSPACE_SIZE * kSize + N_WORKSPACE_SIZE * constInfo.headDim + + (constInfo.headDimRope >> 1) + nL1 * N_SPLIT_SIZE * constInfo.headDimRope], + nd2nzPara); + } + } else { + while (copyFinishRowCnt < nL1Size) { + CalcTopKBlockInfo(info, curTopKIdx, curOffsetInSparseBlock, curSeqIdx, copyRowCnt, idInTopK); + if (copyFinishRowCnt + copyRowCnt > nL1Size) { + copyRowCnt = nL1Size - copyFinishRowCnt; + } + + // BN2轴偏移 + if constexpr (PAGE_ATTENTION) { + Position startPos; + startPos.bIdx = info.bIdx; + startPos.n2Idx = info.n2Idx; + startPos.s2Idx = idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock; + // 256、32等待7buf命名更改 + startPos.dIdx = kL1 * 256; // mm1 右矩阵 bn2s2d, d为k轴不切; mm2 右矩阵, s2为k轴, d轴切分 + Position ropeStartPos = startPos; + ropeStartPos.dIdx = kL1 * 32; + PAShape shape; + shape.blockSize = kvCacheBlockSize; + shape.headNum = constInfo.kvHeadNum; + shape.headDim = constInfo.headDim; + shape.actHeadDim = 256; + shape.maxblockNumPerBatch = maxBlockNumPerBatch; + shape.copyRowNum = copyRowCnt; + shape.copyRowNumAlign = nL1SizeAlign; + PAShape ropeShape = shape; + ropeShape.headDim = constInfo.headDimRope; + ropeShape.actHeadDim = 32; + if (kL1 == 0) { + kTensor = bL1Tensor[copyFinishRowCnt * 16]; + DataCopyPA(kTensor, keyGm, blockTableGm, shape, startPos); + kRopeTensor = bL1Tensor[(nL1SizeAlign * (BlockAlign(constInfo.headDim) >> 1)) + + copyFinishRowCnt * 16]; + DataCopyPA(kRopeTensor, kRopeGm, blockTableGm, ropeShape, + ropeStartPos); + } else { + kRopeTensor = bL1Tensor[copyFinishRowCnt * 16]; + DataCopyPA(kRopeTensor, kRopeGm, blockTableGm, ropeShape, + ropeStartPos); + LocalTensor kTmpTensor = bL1Tensor[32 * nL1SizeAlign + copyFinishRowCnt * 16]; + DataCopyPA(kTmpTensor, keyGm, blockTableGm, shape, startPos); + } + } else { + uint64_t keyOffset = info.tensorBOffset; + uint64_t kRopeOffset = info.tensorBRopeOffset; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::BSND || KV_LAYOUT_T == SFA_LAYOUT::TND) { + keyOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.kvHeadNum * constInfo.headDim; + kRopeOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.kvHeadNum * constInfo.headDimRope; + } else { + keyOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.headDim; + kRopeOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.headDimRope; + } + + if (kL1 == 0) { + CopyInMm1BToL1(bL1Tensor, keyOffset, nL1SizeAlign, copyFinishRowCnt, copyRowCnt, 256); + kRopeTensor = bL1Tensor[nL1SizeAlign * (BlockAlign(constInfo.headDim) >> 1)]; + CopyInMm1BRopeToL1(kRopeTensor, kRopeOffset, nL1SizeAlign, copyFinishRowCnt, copyRowCnt, + 32); + } else { + kRopeTensor = bL1Tensor; + CopyInMm1BRopeToL1(kRopeTensor, kRopeOffset + 32, nL1SizeAlign, copyFinishRowCnt, + copyRowCnt, 32); + LocalTensor kTmpTensor = bL1Tensor[nL1SizeAlign * 32]; + CopyInMm1BToL1(kTmpTensor, keyOffset + 256, nL1SizeAlign, copyFinishRowCnt, copyRowCnt, + 256); + } + } + + // 更新循环变量 + copyFinishRowCnt += copyRowCnt; + curSeqIdx += copyRowCnt; + } + } + + SetFlag(mte21KVIds[kb]); + WaitFlag(mte21KVIds[kb]); + 
mL1Size = M_SPLIT_SIZE; + mL1SizeAlign = SFAAlign(M_SPLIT_SIZE, 16U); + for (uint32_t mL1 = 0; mL1 < mL1Loops; mL1++) { + uint32_t aL1PaddingSize = 0; // 用于使左矩阵对齐到尾部, 以保证两块32K内存连续 + if (mL1 == (mL1Loops - 1)) { + // 尾块重新计算size + mL1Size = mSize - (mL1Loops - 1) * M_SPLIT_SIZE; + mL1SizeAlign = SFAAlign(mL1Size, 16U); + // mL1SizeAlign<128 kL1=0时需要偏移, 确保qRope能一半拷贝到当前tensor, 一半拷贝到下一个tensor + aL1PaddingSize = (M_SPLIT_SIZE - mL1SizeAlign) * 288; + } + + // 左矩阵L1选择12块还是34块的index, 由m l1 index决定 + // 左矩阵L1选择12块或34块的前一块还是后一块, 由k l1 index决定 + uint32_t mIdx = qpL1BufIter + mL1; + ka = GetQPL1RealIdx(mIdx, kL1); + LocalTensor aL1Tensor = + l1QPTensor[ka * L1_BLOCK_OFFSET + (1 - kL1) * aL1PaddingSize]; // kL1=0时需要偏移 + if (nL1 == 0) { // mL1=0, mL1=1两次 + if (kL1 == 0) { + WaitFlag(mte21QPIds[ka]); + WaitFlag(mte21QPIds[ka + 1]); + CopyInMm1AToL1(aL1Tensor, info, mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE, mL1Size, 256, 0); + // 由于L1里面是NZ, 这里q rope的偏移为整块q nope切k的后大小, 256为headDim的一半 + LocalTensor qRopeTensor = + aL1Tensor[mL1SizeAlign * + 256]; + CopyInMm1ARopeToL1(qRopeTensor, info, mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE, mL1Size); + } else { + // 32为rope headDim的一半 + LocalTensor qTmpTensor = aL1Tensor[mL1SizeAlign * 32]; + CopyInMm1AToL1(qTmpTensor, info, mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE, mL1Size, 256, + 256); + } + SetFlag(mte21QPIds[ka]); + WaitFlag(mte21QPIds[ka]); + } + + // 使用unitflag同步 + LocalTensor cL0Tensor = + cL0TensorPingPong[(cL0BufIter % 2) * + (L0C_PP_SIZE / sizeof(MM_OUT_T))]; // 需要保证cL0BufIter和m步调一致 + for (uint32_t kL0 = 0; kL0 < kL0Loops; kL0++) { + WaitFlag(Mte1MmABEventId(abL0BufIter % 2)); + LocalTensor aL0Tensor = aL0TensorPingPong[(abL0BufIter % 2) * (L0A_PP_SIZE / sizeof(KV_T))]; + LoadDataMm1A(aL0Tensor, aL1Tensor, kL0, kL0Size, mL1SizeAlign, kL0Size); + LocalTensor bL0Tensor = bL0TensorPingPong[(abL0BufIter % 2) * (L0B_PP_SIZE / sizeof(KV_T))]; + LoadDataMm1B(bL0Tensor, bL1Tensor, kL0, kL0Size, kL0Size, nL1SizeAlign); + SetFlag(Mte1MmABEventId(abL0BufIter % 2)); + WaitFlag(Mte1MmABEventId(abL0BufIter % 2)); + + // m == 1的时候需要特殊处理 + MmadParams mmadParams; + mmadParams.m = mL1SizeAlign; + mmadParams.n = nL1SizeAlign; + mmadParams.k = kL0Size; + mmadParams.cmatrixInitVal = (kL1 == 0 && kL0 == 0); + mmadParams.cmatrixSource = false; + mmadParams.unitFlag = + (kL1 == 1 && kL0 == (kL0Loops - 1)) ? 
0b11 : 0b10; // 累加最后一次翻转flag, 表示可以搬出 + Mmad(cL0Tensor, aL0Tensor, bL0Tensor, mmadParams); + + if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) { + PipeBarrier(); + } + SetFlag(Mte1MmABEventId(abL0BufIter % 2)); + abL0BufIter++; + } + + if (nL1 == (nL1Loops - 1)) { + SetFlag(mte21QPIds[ka]); // 反向同步, 表示L1中的A已经被mte1消费完 + } + + if (kL1 == 1) { // 最后一轮kL1循环 + FixpipeParamsV220 fixParams; + fixParams.nSize = nL1SizeAlign; + fixParams.mSize = mL1SizeAlign; + fixParams.srcStride = mL1SizeAlign; + // 改成nSizeAlign + fixParams.dstStride = info.actualSingleProcessSInnerSizeAlign; // mm1ResGm两行之间的间隔 + fixParams.unitFlag = 0b11; + fixParams.ndNum = 1; // 输出ND + + // 输出偏移info.loop % (constInfo.preLoadNum)) * mmResUbSize是否在matmul里计算 + Fixpipe(mm1ResGm[(info.loop % (constInfo.preLoadNum)) * constInfo.mmResUbSize + nL1 * N_SPLIT_SIZE + + (mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE) * + info.actualSingleProcessSInnerSizeAlign], + cL0Tensor, fixParams); + } + if (mL1Loops == 2) { + cL0BufIter++; + } + } + SetFlag(mte21KVIds[kb]); // 反向同步, 表示L1已经被mte1消费完 + } + if (mL1Loops == 1) { + cL0BufIter++; + } + } + qpL1BufIter += mL1Loops; +} + +template +__aicore__ inline void SFAMatmulService::ComputeMm2(const RunInfo &info, const MSplitInfo mSplitInfo) +{ + uint32_t mSize = mSplitInfo.nBufferDealM; + uint32_t mSizeAlign = (mSize + 16 - 1) / 16; + uint32_t mL1Loops = (mSize + M_SPLIT_SIZE - 1) / M_SPLIT_SIZE; + uint32_t mL1SizeAlign = M_SPLIT_SIZE; // 16对齐 + uint32_t mL1Size = M_SPLIT_SIZE; // m的实际大小 + + uint32_t nSize = BlockAlign(constInfo.headDim); + uint32_t nL1Loops = (nSize + N_SPLIT_SIZE - 1) / N_SPLIT_SIZE; + uint32_t nL1SizeAlign = N_SPLIT_SIZE; // 16对齐 + uint32_t nL1Size = N_SPLIT_SIZE; // n的实际大小 + + uint32_t kSize = info.actualSingleProcessSInnerSize; + uint32_t kL1Size = 256; + uint32_t kL1SizeAlign = SFAAlign(kL1Size, 16U); + uint32_t kL1Loops = (kSize + kL1Size - 1) / kL1Size; + uint32_t kL0Size = 128; + uint32_t kL0Loops = (kL1Size + kL0Size - 1) / kL0Size; + uint32_t kL0SizeAlign = kL0Size; + LocalTensor bL1Tensor; + LocalTensor subvTensor; + + // ka表示左矩阵4buf选择哪一块buf, kb表示右矩阵3buf选择哪一块buf + uint32_t ka = 0, kb = 0; + uint32_t mBaseIdx = qpL1BufIter; + for (uint32_t nL1 = 0; nL1 < nL1Loops; nL1++) { // n切L1 + if (nL1 == (nL1Loops - 1)) { + // 尾块 + nL1Size = nSize - (nL1Loops - 1) * N_SPLIT_SIZE; + nL1SizeAlign = SFAAlign(nL1Size, 16U); + } + + // k l1写成一个循环, 和mm1保持一致 + kL1Size = 256; + kL1SizeAlign = SFAAlign(kL1Size, 16U); + + uint32_t curTopKIdx = info.curTopKIdx; + uint64_t curOffsetInSparseBlock = info.curOffsetInSparseBlock; + uint32_t copyRowCnt = 0; + int64_t idInTopK = topKGm.GetValue(info.topKBaseOffset + curTopKIdx); + + for (uint32_t k1 = 0; k1 < kL1Loops; k1++) { // k切L1, 这里套了一层l0来操作 + if (k1 == (kL1Loops - 1)) { + // 尾块 + kL1Size = kSize - (kL1Loops - 1) * 256; + kL1SizeAlign = SFAAlign(kL1Size, 16U); + } + kvL1BufIter++; + uint32_t kb = kvL1BufIter % 3; + WaitFlag(mte21KVIds[kb]); + bL1Tensor = l1KVTensor[kb * L1_BLOCK_OFFSET]; + uint32_t kOffset = k1 * kL0Loops; + kL0Size = 128; + // 此处必须先初始化kL0Size, 再求kL0Loops, 否则由于循环会改变kL0Size大小, 导致kL0Loops错误 + kL0Loops = (kL1Size + kL0Size - 1) / kL0Size; + kL0SizeAlign = kL0Size; + for (uint32_t kL1 = kOffset; kL1 < kL0Loops + kOffset; kL1++) { // 128 循环搬pa + if (kL1 == kOffset + kL0Loops - 1) { + // 尾块 + kL0Size = kL1Size - (kL0Loops - 1) * kL0Size; + kL0SizeAlign = SFAAlign(kL0Size, 16U); + } + + uint32_t curSeqIdx = info.s2BatchOffset + (kL1 - kOffset) * 128 + k1 * 256; + uint32_t copyFinishRowCnt = 0; + if constexpr (TEMPLATE_MODE == 
V_TEMPLATE) { + Nd2NzParams nd2nzPara; + nd2nzPara.ndNum = 1; + nd2nzPara.nValue = kL0Size; // 行数 + nd2nzPara.dValue = N_SPLIT_SIZE; // constInfo.headDim; + nd2nzPara.srcDValue = constInfo.headDim; + nd2nzPara.dstNzC0Stride = kL0SizeAlign; + nd2nzPara.dstNzNStride = 1; + nd2nzPara.srcNdMatrixStride = 0; + nd2nzPara.dstNzMatrixStride = 0; + DataCopy(bL1Tensor[(kL1 - kOffset) * 128 * N_SPLIT_SIZE], + kvMergeGm_[info.loop % 4 * N_WORKSPACE_SIZE * 576 + kL1 * 128 * constInfo.headDim + + nL1 * N_SPLIT_SIZE], + nd2nzPara); + } else { + while (copyFinishRowCnt < kL0Size) { + CalcTopKBlockInfo(info, curTopKIdx, curOffsetInSparseBlock, curSeqIdx, copyRowCnt, idInTopK); + + if (copyFinishRowCnt + copyRowCnt > kL0Size) { + copyRowCnt = kL0Size - copyFinishRowCnt; + } + + if constexpr (PAGE_ATTENTION) { + Position startPos; + startPos.bIdx = info.bIdx; + startPos.n2Idx = info.n2Idx; + startPos.s2Idx = idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock; + startPos.dIdx = + nL1 * N_SPLIT_SIZE; // mm1 右矩阵 bn2s2d, d为k轴不切; mm2 右矩阵, s2为k轴, d轴切分 + PAShape shape; + shape.blockSize = kvCacheBlockSize; + shape.headNum = constInfo.kvHeadNum; + shape.headDim = constInfo.headDim; + shape.actHeadDim = nL1Size; + shape.maxblockNumPerBatch = maxBlockNumPerBatch; + shape.copyRowNum = copyRowCnt; + shape.copyRowNumAlign = kL0SizeAlign; + subvTensor = bL1Tensor[(kL1 - kOffset) * 128 * N_SPLIT_SIZE + copyFinishRowCnt * 16]; + DataCopyPA(subvTensor, valueGm, blockTableGm, shape, startPos); + } else { + uint64_t valueOffset = info.tensorBOffset; + if constexpr (KV_LAYOUT_T == SFA_LAYOUT::BSND || KV_LAYOUT_T == SFA_LAYOUT::TND) { + valueOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.kvHeadNum * constInfo.headDim; + } else { + valueOffset += (idInTopK * constInfo.sparseBlockSize + curOffsetInSparseBlock) * + constInfo.headDim; + } + + subvTensor = bL1Tensor[(kL1 - kOffset) * 128 * N_SPLIT_SIZE]; + CopyInMm2BToL1(subvTensor, valueOffset, kL0SizeAlign, copyFinishRowCnt, copyRowCnt, + nL1 * N_SPLIT_SIZE, nL1Size); + } + // 更新循环变量 + copyFinishRowCnt += copyRowCnt; + curSeqIdx += copyRowCnt; + } + } + } + SetFlag(mte21KVIds[kb]); + WaitFlag(mte21KVIds[kb]); + mL1SizeAlign = M_SPLIT_SIZE; + mL1Size = M_SPLIT_SIZE; // m的实际大小 + for (uint32_t mL1 = 0; mL1 < mL1Loops; mL1++) { + if (mL1 == (mL1Loops - 1)) { + // 尾块 + mL1Size = mSize - (mL1Loops - 1) * M_SPLIT_SIZE; + mL1SizeAlign = SFAAlign(mL1Size, 16U); + } + + uint32_t mIdx = mBaseIdx + mL1; + ka = GetQPL1RealIdx(mIdx, k1); + LocalTensor aL1Tensor = l1QPTensor[ka * L1_BLOCK_OFFSET]; + if (nL1 == 0) { + WaitFlag(mte21QPIds[ka]); + CopyInMm2AToL1(aL1Tensor, info, mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE, mL1Size, kL1Size, + 256 * k1); + SetFlag(mte21QPIds[ka]); + WaitFlag(mte21QPIds[ka]); + } + + LocalTensor cL0Tensor = + cL0TensorPingPong[(cL0BufIter % 2) * + (L0C_PP_SIZE / sizeof(MM_OUT_T))]; // 需要保证cL0BufIter和m步调一致 + uint32_t baseK = 128; + uint32_t baseN = 128; + kL0Size = 128; + kL0SizeAlign = kL0Size; + for (uint32_t kL0 = 0; kL0 < kL0Loops; kL0++) { + if (kL0 + 1 == kL0Loops) { + kL0Size = kL1Size - (kL0Loops - 1) * kL0Size; + kL0SizeAlign = SFAAlign(kL0Size, 16U); + } + WaitFlag(Mte1MmABEventId(abL0BufIter % 2)); + LocalTensor bL0Tensor = bL0TensorPingPong[(abL0BufIter % 2) * (L0B_PP_SIZE / sizeof(KV_T))]; + LoadData3DParamsV2 loadData3DParamsForB; + loadData3DParamsForB.l1H = kL0SizeAlign / 16; // 源操作数height + loadData3DParamsForB.l1W = 16; // 源操作数weight=16,目的height=l1H*L1W + loadData3DParamsForB.padList[0] = 0; + 
loadData3DParamsForB.padList[1] = 0; + loadData3DParamsForB.padList[2] = 0; + loadData3DParamsForB.padList[3] = 255; // 尾部数据不影响滑窗的结果 + + loadData3DParamsForB.mExtension = kL0SizeAlign; // 在目的操作数height维度的传输长度 + loadData3DParamsForB.kExtension = nL1SizeAlign; // 在目的操作数width维度的传输长度 + loadData3DParamsForB.mStartPt = 0; // 卷积核在目的操作数width维度的起点 + loadData3DParamsForB.kStartPt = 0; // 卷积核在目的操作数height维度的起点 + loadData3DParamsForB.strideW = 1; + loadData3DParamsForB.strideH = 1; + loadData3DParamsForB.filterW = 1; + loadData3DParamsForB.filterSizeW = false; // 是否在filterW的基础上将卷积核width增加256个元素 + loadData3DParamsForB.filterH = 1; + loadData3DParamsForB.filterSizeH = false; // 是否在filterH的基础上将卷积核height增加256个元素 + loadData3DParamsForB.dilationFilterW = 1; // 卷积核width膨胀系数 + loadData3DParamsForB.dilationFilterH = 1; // 卷积核height膨胀系数 + loadData3DParamsForB.enTranspose = 1; // 是否启用转置功能 + loadData3DParamsForB.fMatrixCtrl = 0; // 使用FMATRIX_LEFT还是使用FMATRIX_RIGHT,=0使用FMATRIX_LEFT,=1使用FMATRIX_RIGHT 1 + loadData3DParamsForB.channelSize = nL1SizeAlign; // 源操作数的通道数。膨胀系数为1时,目的weight为filterW*filterH*channelSize + LoadData(bL0Tensor, bL1Tensor[kL0 * baseK * baseN], loadData3DParamsForB); + + LocalTensor aL0Tensor = aL0TensorPingPong[(abL0BufIter % 2) * (L0A_PP_SIZE / sizeof(KV_T))]; + LoadData3DParamsV2 loadData3DParamsForA; + loadData3DParamsForA.l1H = mL1SizeAlign / 16; // 源操作数height + loadData3DParamsForA.l1W = 16; // 源操作数weight + loadData3DParamsForA.padList[0] = 0; + loadData3DParamsForA.padList[1] = 0; + loadData3DParamsForA.padList[2] = 0; + loadData3DParamsForA.padList[3] = 255; // 尾部数据不影响滑窗的结果 + + loadData3DParamsForA.mExtension = mL1SizeAlign; // 在目的操作数height维度的传输长度 + loadData3DParamsForA.kExtension = kL0SizeAlign; // 在目的操作数width维度的传输长度 + loadData3DParamsForA.mStartPt = 0; // 卷积核在目的操作数width维度的起点 + loadData3DParamsForA.kStartPt = 0; // 卷积核在目的操作数height维度的起点 + loadData3DParamsForA.strideW = 1; // 卷积核在源操作数width维度滑动的步长 + loadData3DParamsForA.strideH = 1; // 卷积核在源操作数height维度滑动的步长 + loadData3DParamsForA.filterW = 1; // 卷积核width + loadData3DParamsForA.filterSizeW = false; // 是否在filterW的基础上将卷积核width增加256个元素 + loadData3DParamsForA.filterH = 1; // 卷积核height + loadData3DParamsForA.filterSizeH = false; // 是否在filterH的基础上将卷积核height增加256个元素 + loadData3DParamsForA.dilationFilterW = 1; // 卷积核width膨胀系数 + loadData3DParamsForA.dilationFilterH = 1; // 卷积核height膨胀系数 + loadData3DParamsForA.enTranspose = 0; // 是否启用转置功能,对整个目标矩阵进行转置 + loadData3DParamsForA.fMatrixCtrl = 0; + loadData3DParamsForA.channelSize = kL0SizeAlign; // 源操作数的通道数。膨胀系数为1时,目的weight为filterW*filterH*channelSize + LoadData(aL0Tensor, aL1Tensor[kL0 * baseK * mL1SizeAlign], + loadData3DParamsForA); + SetFlag(Mte1MmABEventId(abL0BufIter % 2)); + WaitFlag(Mte1MmABEventId(abL0BufIter % 2)); + + MmadParams mmadParams; + mmadParams.m = mL1SizeAlign; + mmadParams.n = nL1SizeAlign; + mmadParams.k = kL0Size; + mmadParams.cmatrixInitVal = (kL0 == 0 && k1 == 0); + mmadParams.cmatrixSource = false; + mmadParams.unitFlag = ((k1 == (kL1Loops - 1)) && (kL0 == (kL0Loops - 1))) ? 0b11 : 0b10; + + Mmad(cL0Tensor, aL0Tensor, bL0Tensor, mmadParams); + if ((mmadParams.m / 16) * (mmadParams.n / 16) < 10) { + PipeBarrier(); + } + SetFlag(Mte1MmABEventId(abL0BufIter % 2)); + abL0BufIter++; + } + + if (nL1 == (nL1Loops - 1)) { // nL1最后一轮, 需要将B驻留在L1中, 用于下一轮的计算? 
+                SetFlag(mte21QPIds[ka]); // reverse sync: the A tile in L1 has been fully consumed by MTE1
+            }
+
+            if (k1 == (kL1Loops - 1)) {
+                if (nL1 == 0 && mL1 == 0) { // wait before the first Fixpipe
+                    CrossCoreWaitFlag(constInfo.syncV1NupdateC2);
+                }
+
+                if (!info.isFirstSInnerLoop) {
+                    // not the first S2 inner iteration: accumulate the partial MM2 result into mm2ResGm via atomic add
+                    SetAtomicAdd();
+                }
+                // ND output
+                FixpipeParamsV220 fixParams;
+                fixParams.nSize = nL1SizeAlign;
+                fixParams.mSize = mL1SizeAlign;
+                fixParams.srcStride = mL1SizeAlign;
+                fixParams.dstStride = nSize; // stride between two rows of mm2ResGm
+                fixParams.ndNum = 1;         // output in ND format
+                fixParams.unitFlag = 0b11;
+
+                uint64_t mm2Offset = (mSplitInfo.nBufferStartM + mL1 * M_SPLIT_SIZE) * nSize + nL1 * N_SPLIT_SIZE;
+                Fixpipe(mm2ResGm[(info.bn2IdxInCurCore % (constInfo.preLoadNum)) *
+                        constInfo.bmm2ResUbSize + mm2Offset], cL0Tensor, fixParams);
+                if (!info.isFirstSInnerLoop) {
+                    SetAtomicNone();
+                }
+            }
+
+            if (mL1Loops == 2) {
+                cL0BufIter++;
+            }
+        }
+        SetFlag(mte21KVIds[kb]); // reverse sync: the KV tile in L1 has been fully consumed by MTE1
+    }
+    // cL0BufIter is no longer used beyond this point
+    if (mL1Loops == 1) {
+        cL0BufIter++;
+    }
+    }
+    qpL1BufIter += mL1Loops;
+}
+
+#endif // SPARSE_FLASH_ATTENTION_SERVICE_CUBE_MLA_H
\ No newline at end of file
diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_vector_mla.h b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_vector_mla.h
new file mode 100644
index 00000000000..baea90905a8
--- /dev/null
+++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_service_vector_mla.h
@@ -0,0 +1,1377 @@
+/**
+ * This program is free software, you can redistribute it and/or modify it.
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file sparse_flash_attention_service_vector_mla.h + * \brief + */ +#ifndef SPARSE_FLASH_ATTENTION_SERVICE_VECTOR_MLA_H +#define SPARSE_FLASH_ATTENTION_SERVICE_VECTOR_MLA_H + +#include "kernel_operator.h" +#include "kernel_operator_list_tensor_intf.h" +#include "kernel_tiling/kernel_tiling.h" +#include "lib/matmul_intf.h" +#include "lib/matrix/matmul/tiling.h" +#include "sparse_flash_attention_common.h" + +using AscendC::CrossCoreSetFlag; +using AscendC::CrossCoreWaitFlag; + +template class SFAVectorService { +public: + // 中间计算数据类型为float,高精度模式 + using T = float; + using KV_T = typename SFAT::kvType; + using OUT_T = typename SFAT::outputType; + using UPDATE_T = T; + using MM1_OUT_T = float; + using MM2_OUT_T = float; + + __aicore__ inline SFAVectorService(){}; + __aicore__ inline void ProcessVec1L(const RunInfo &info); + __aicore__ inline void ProcessVec2L(const RunInfo &info); + __aicore__ inline void InitBuffers(TPipe *pipe); + __aicore__ inline void InitParams(const struct ConstInfo &constInfo, + const SparseFlashAttentionTilingDataMla *__restrict tilingData); + __aicore__ inline void InitMm2ResInt32GmGlobalTensor(GlobalTensor mm2ResInt32Gm); + __aicore__ inline void InitVec0GlobalTensor(const GlobalTensor &kvValidSizeGm, + const GlobalTensor &kvMergeGm, + const GlobalTensor &keyRopeGm, const GlobalTensor &keyGm, + const GlobalTensor &blkTableGm); + __aicore__ inline void InitVec1GlobalTensor(GlobalTensor mm1ResGm, GlobalTensor vec1ResGm, + GlobalTensor actualSeqLengthsQGm, + GlobalTensor actualSeqLengthsKVGm, GlobalTensor lseMaxFdGm, + GlobalTensor lseSumFdGm, GlobalTensor topKGm); + __aicore__ inline void InitVec2GlobalTensor(GlobalTensor accumOutGm, GlobalTensor vec2ResGm, + GlobalTensor mm2ResGm, GlobalTensor attentionOutGm); + __aicore__ inline void AllocEventID(); + __aicore__ inline void FreeEventID(); + __aicore__ inline void InitSoftmaxDefaultBuffer(); + // ================================Base Vector========================================== + __aicore__ inline void RowDivs(LocalTensor dstUb, LocalTensor src0Ub, LocalTensor src1Ub, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline void RowMuls(LocalTensor dstUb, LocalTensor src0Ub, LocalTensor src1Ub, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + // ================================Vector0========================================== + __aicore__ inline void MergeKv(const RunInfo &runInfo); + __aicore__ inline int64_t GetKeyGmOffset(int64_t realS2Idx, const RunInfo &runInfo, int64_t s2IdLimit); + __aicore__ inline int64_t GetKeyRopeGmOffset(int64_t realS2Idx, const RunInfo &runInfo, int64_t s2IdLimit); + __aicore__ inline void GetRealS2Idx(int64_t s2GmOffset, int64_t &realS2Idx, int64_t topkGmBaseOffset, + const RunInfo &runInfo); + __aicore__ inline void CopyInKv(int64_t &mte2Size, int64_t mte3Size, int64_t mergeMte3Idx, int64_t realS2Idx1, + int64_t realS2Idx2, const RunInfo &runInfo); + __aicore__ inline void CopyOutMrgeResult(int64_t mte2Size, int64_t mte3Size, int64_t s2StartGmOffset, + int64_t mergeMte3Idx, const RunInfo &runInfo); + __aicore__ inline void SetInfInBlk(const LocalTensor &mmResUb, uint32_t dealRowCount, uint32_t columnCount, + uint64_t startId, uint64_t endId); + __aicore__ inline void SetMidInf(const LocalTensor &mmResUb, uint32_t dealRowCount, uint32_t columnCount, + uint64_t startId, uint64_t endId); + __aicore__ inline void CopyInSingleKv(int64_t &mte2Size, int64_t mte3Size, int64_t mergeMte3Idx, int64_t realS2Idx, + int64_t 
keyBNBOffset,int64_t s2IdLimit, const RunInfo &runInfo); + // ================================Vector1========================================== + __aicore__ inline void ProcessVec1SingleBuf(const RunInfo &info, const MSplitInfo &mSplitInfo); + __aicore__ inline void DealBmm1ResBaseBlock(const RunInfo &info, const MSplitInfo &mSplitInfo, uint32_t startRow, + uint32_t dealRowCount, uint32_t columnCount, uint32_t loopId); + __aicore__ inline void SoftmaxFlashV2Compute(const RunInfo &info, const MSplitInfo &mSplitInfo, + LocalTensor &mmResUb, LocalTensor &softmaxTmpUb, + uint32_t startRow, uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount); + __aicore__ inline void AmlaVecCompute(const RunInfo &info, const MSplitInfo &mSplitInfo, LocalTensor &mmResUb, + LocalTensor &softmaxTmpUb, uint32_t startRow, uint32_t dealRowCount, + uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline void ElewiseCompute(const RunInfo &info, const LocalTensor &mmResUb, uint32_t dealRowCount, + uint32_t columnCount); + __aicore__ inline void ProcessAmlaNupdate(const RunInfo &info, const MSplitInfo &mSplitInfo); + __aicore__ inline void ComputeLogSumExpAndCopyToGm(const RunInfo &info, const MSplitInfo &mSplitInfo, + LocalTensor &softmaxSumUb, LocalTensor &softmaxMaxUb); + // ================================Vecotr2========================================== + __aicore__ inline void ProcessVec2SingleBuf(const RunInfo &info, const MSplitInfo &mSplitInfo); + __aicore__ inline void DealBmm2ResBaseBlock(const RunInfo &info, const MSplitInfo &mSplitInfo, uint32_t startRow, + uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount); + __aicore__ inline void ProcessVec2Inner(const RunInfo &info, const MSplitInfo &mSplitInfo, uint32_t mStartRow, + uint32_t mDealSize); + __aicore__ inline void Bmm2DataCopyOutTrans(const RunInfo &info, LocalTensor &attenOutUb, uint32_t wsMStart, + uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount); + __aicore__ inline void Bmm2ResCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, uint32_t wsMStart, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline void Bmm2CastAndCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, uint32_t wsMStart, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline void Bmm2FDDataCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, uint32_t wsMStart, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount); + __aicore__ inline uint64_t CalcAccumOffset(uint32_t bN2Idx, uint32_t gS1Idx); + __aicore__ inline void GetConfusionTransposeTiling(int64_t numR, int64_t numC, const uint32_t stackBufferSize, + const uint32_t typeSize, ConfusionTransposeTiling &tiling); + + // BLOCK和REPEAT的字节数 + static constexpr uint64_t BYTE_BLOCK = 32UL; + static constexpr uint32_t REPEAT_BLOCK_BYTE = 256U; + // BLOCK和REPEAT的FP32元素数 + static constexpr uint32_t FP32_BLOCK_ELEMENT_NUM = BYTE_BLOCK / sizeof(float); + static constexpr uint32_t FP32_REPEAT_ELEMENT_NUM = REPEAT_BLOCK_BYTE / sizeof(float); + // repeat stride不能超过256 + static constexpr uint32_t REPEATE_STRIDE_UP_BOUND = 256; + +private: + static constexpr bool PAGE_ATTENTION = SFAT::pageAttention; + static constexpr int TEMPLATE_MODE = SFAT::templateMode; + static constexpr bool FLASH_DECODE = SFAT::flashDecode; + static constexpr SFA_LAYOUT LAYOUT_T = SFAT::layout; + static constexpr SFA_LAYOUT KV_LAYOUT_T = SFAT::kvLayout; + + static constexpr uint64_t 
MERGE_CACHE_GM_BUF_NUM = 4; + static constexpr uint64_t SYNC_INPUT_BUF1_FLAG = 2; + static constexpr uint64_t SYNC_INPUT_BUF1_PONG_FLAG = 3; + static constexpr uint64_t SYNC_INPUT_BUF2_FLAG = 4; + static constexpr uint64_t SYNC_INPUT_BUF2_PONG_FLAG = 5; + static constexpr uint64_t SYNC_OUTPUT_BUF1_FLAG = 4; + static constexpr uint64_t SYNC_OUTPUT_BUF2_FLAG = 5; + static constexpr uint32_t INPUT1_BUFFER_OFFSET = ConstInfo::BUFFER_SIZE_BYTE_32K; + static constexpr uint32_t SOFTMAX_TMP_BUFFER_OFFSET = ConstInfo::BUFFER_SIZE_BYTE_1K; + static constexpr uint32_t BASE_BLOCK_MAX_ELEMENT_NUM = ConstInfo::BUFFER_SIZE_BYTE_32K / sizeof(T); // 32768/4=8096 + static constexpr uint32_t BLOCK_ELEMENT_NUM = BYTE_BLOCK / sizeof(T); // 32/4=8 + static constexpr T FLOAT_E_SCALAR = 8388608; + static constexpr T LN2 = 0.6931471805599453094172; + static constexpr T RECIP_OF_LN2 = 1 / LN2; + static constexpr T SOFTMAX_MIN_NUM = -2e38; + + const SparseFlashAttentionTilingDataMla *__restrict tilingData; + + uint32_t pingpongFlag = 0U; + ConstInfo constInfo = {}; + + GlobalTensor mm2ResInt32Gm; + GlobalTensor mm1ResGm; + GlobalTensor vec1ResGm; + GlobalTensor lseSumFdGm; + GlobalTensor lseMaxFdGm; + + GlobalTensor actualSeqLengthsQGm; + GlobalTensor actualSeqLengthsKVGm; + GlobalTensor vec2ResGm; + GlobalTensor mm2ResGm; + GlobalTensor accumOutGm; + GlobalTensor attentionOutGm; + GlobalTensor blkTableGm_; + + GlobalTensor kvMergeGm_; + GlobalTensor keyRopeGm_; + GlobalTensor keyGm_; + GlobalTensor topkGm_; + GlobalTensor kvValidSizeGm_; + + // ================================Local Buffer区==================================== + TBuf<> inputBuff1; // 32K + TBuf<> inputBuff2; // 16K + TBuf<> outputBuff1; // 32K + TBuf<> outputBuff2; // 4K + + TBuf<> tmpBuff1; // 32K + TBuf<> v0ValidSizeBuff; // 8K + + TBuf<> nValueBuff; + TBuf<> cofValueBuff; + TBuf<> aMlaSumBuff; + TBuf<> softmaxMaxBuff; // PRE_LOAD_NUM * 2K + TBuf<> softmaxExpBuff; // PRE_LOAD_NUM * 2K + TBuf<> softmaxSumBuff; // PRE_LOAD_NUM * 2K + TBuf<> softmaxMaxDefaultBuff; // 2K + TBuf<> softmaxSumDefaultBuff; // 2K + + LocalTensor softmaxMaxDefaultUb; + LocalTensor softmaxSumDefaultUb; + + LocalTensor nValueUb; + LocalTensor cofValueUb; + LocalTensor aMlaSumUb; + LocalTensor softmaxMaxUb; + LocalTensor softmaxSumUb; + LocalTensor softmaxExpUb; + LocalTensor kvMergUb_; + LocalTensor ropeMergUb_; + LocalTensor v0ValidSizeUb_; +}; + +template __aicore__ inline void SFAVectorService::InitBuffers(TPipe *pipe) +{ + pipe->InitBuffer(inputBuff1, ConstInfo::BUFFER_SIZE_BYTE_32K * 2); // 2:pingpong + pipe->InitBuffer(inputBuff2, ConstInfo::BUFFER_SIZE_BYTE_8K * 2); // 2:pingpong + pipe->InitBuffer(outputBuff1, ConstInfo::BUFFER_SIZE_BYTE_32K); + pipe->InitBuffer(outputBuff2, ConstInfo::BUFFER_SIZE_BYTE_4K); + + pipe->InitBuffer(tmpBuff1, ConstInfo::BUFFER_SIZE_BYTE_32K); + pipe->InitBuffer(v0ValidSizeBuff, ConstInfo::BUFFER_SIZE_BYTE_8K); + + // M_MAX = 512/2vector = 256, 256 * sizeof(T) * N_Buffer + pipe->InitBuffer(nValueBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + pipe->InitBuffer(cofValueBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + pipe->InitBuffer(aMlaSumBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + + pipe->InitBuffer(softmaxMaxBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + pipe->InitBuffer(softmaxExpBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + pipe->InitBuffer(softmaxSumBuff, ConstInfo::BUFFER_SIZE_BYTE_1K * constInfo.preLoadNum); + + 
pipe->InitBuffer(softmaxMaxDefaultBuff, ConstInfo::BUFFER_SIZE_BYTE_1K); + pipe->InitBuffer(softmaxSumDefaultBuff, ConstInfo::BUFFER_SIZE_BYTE_1K); + + nValueUb = nValueBuff.Get(); + cofValueUb = cofValueBuff.Get(); + aMlaSumUb = aMlaSumBuff.Get(); + + softmaxMaxUb = softmaxMaxBuff.Get(); + softmaxSumUb = softmaxSumBuff.Get(); + softmaxExpUb = softmaxExpBuff.Get(); + + softmaxMaxDefaultUb = softmaxMaxDefaultBuff.Get(); + softmaxSumDefaultUb = softmaxSumDefaultBuff.Get(); + + kvMergUb_ = inputBuff1.Get(); + ropeMergUb_ = inputBuff2.Get(); + + v0ValidSizeUb_ = v0ValidSizeBuff.Get(); +} + +template +__aicore__ inline void +SFAVectorService::InitParams(const struct ConstInfo &constInfo, + const SparseFlashAttentionTilingDataMla *__restrict tilingData) +{ + this->constInfo = constInfo; + this->tilingData = tilingData; +} + +template +__aicore__ inline void +SFAVectorService::InitMm2ResInt32GmGlobalTensor(GlobalTensor mm2ResInt32Gm) +{ + this->mm2ResInt32Gm = mm2ResInt32Gm; +} + +template +__aicore__ inline void SFAVectorService::InitVec0GlobalTensor( + const GlobalTensor &kvValidSizeGm, const GlobalTensor &kvMergeGm, + const GlobalTensor &keyRopeGm, const GlobalTensor &keyGm, const GlobalTensor &blkTableGm) +{ + this->kvMergeGm_ = kvMergeGm; + this->keyRopeGm_ = keyRopeGm; + this->keyGm_ = keyGm; + this->blkTableGm_ = blkTableGm; + this->kvValidSizeGm_ = kvValidSizeGm; +} + +template +__aicore__ inline void SFAVectorService::InitVec1GlobalTensor( + GlobalTensor mm1ResGm, GlobalTensor vec1ResGm, + GlobalTensor actualSeqLengthsQGm, GlobalTensor actualSeqLengthsKVGm, GlobalTensor lseMaxFdGm, + GlobalTensor lseSumFdGm, GlobalTensor topKGm) +{ + this->mm1ResGm = mm1ResGm; + this->vec1ResGm = vec1ResGm; + this->actualSeqLengthsQGm = actualSeqLengthsQGm; + this->actualSeqLengthsKVGm = actualSeqLengthsKVGm; + this->lseMaxFdGm = lseMaxFdGm; + this->lseSumFdGm = lseSumFdGm; + this->topkGm_ = topKGm; +} + +template +__aicore__ inline void SFAVectorService::InitVec2GlobalTensor(GlobalTensor accumOutGm, + GlobalTensor vec2ResGm, + GlobalTensor mm2ResGm, + GlobalTensor attentionOutGm) +{ + this->accumOutGm = accumOutGm; + this->vec2ResGm = vec2ResGm; + this->mm2ResGm = mm2ResGm; + this->attentionOutGm = attentionOutGm; +} + +template __aicore__ inline void SFAVectorService::AllocEventID() +{ + SetFlag(SYNC_INPUT_BUF1_FLAG); + SetFlag(SYNC_INPUT_BUF1_PONG_FLAG); + SetFlag(SYNC_INPUT_BUF2_FLAG); + SetFlag(SYNC_INPUT_BUF2_PONG_FLAG); + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); +} + +template __aicore__ inline void SFAVectorService::FreeEventID() +{ + WaitFlag(SYNC_INPUT_BUF1_FLAG); + WaitFlag(SYNC_INPUT_BUF1_PONG_FLAG); + WaitFlag(SYNC_INPUT_BUF2_FLAG); + WaitFlag(SYNC_INPUT_BUF2_PONG_FLAG); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); +} + +template __aicore__ inline void SFAVectorService::InitSoftmaxDefaultBuffer() +{ + Duplicate(softmaxMaxDefaultUb, SOFTMAX_MIN_NUM, SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)); + Duplicate(softmaxSumDefaultUb, ConstInfo::FLOAT_ZERO, SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)); +} + +template +__aicore__ inline void SFAVectorService::ComputeLogSumExpAndCopyToGm(const RunInfo &info, + const MSplitInfo &mSplitInfo, + LocalTensor &softmaxSumUb, + LocalTensor &softmaxMaxUb) +{ + if (mSplitInfo.vecDealM == 0) { + return; + } + uint64_t baseOffset = mSplitInfo.nBufferStartM / 2; + size_t size = mSplitInfo.vecDealM * FP32_BLOCK_ELEMENT_NUM; + uint64_t accumTmpOutNum = CalcAccumOffset(info.bIdx, info.gS1Idx); + uint64_t offset = (accumTmpOutNum * 
constInfo.kvHeadNum * constInfo.mBaseSize + // taskoffset + info.tndCoreStartKVSplitPos * constInfo.kvHeadNum * constInfo.mBaseSize + // 份数offset + mSplitInfo.nBufferStartM + mSplitInfo.vecStartM) * + FP32_BLOCK_ELEMENT_NUM; // m轴offset + if (info.actualSingleProcessSInnerSize != 0) { + LocalTensor tmp = outputBuff2.Get(); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + Brcb(tmp, softmaxSumUb[baseOffset], (mSplitInfo.vecDealM + 7) / 8, {1, 8}); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + DataCopy(lseSumFdGm[offset], tmp, size); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + + tmp = outputBuff2.Get(); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + Brcb(tmp, softmaxMaxUb[baseOffset], (mSplitInfo.vecDealM + 7) / 8, {1, 8}); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + DataCopy(lseMaxFdGm[offset], tmp, size); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + } else { + matmul::InitOutput(lseSumFdGm[offset], size, ConstInfo::FLOAT_ZERO); + matmul::InitOutput(lseMaxFdGm[offset], size, SOFTMAX_MIN_NUM); + } +} + +template +__aicore__ inline void SFAVectorService::ElewiseCompute(const RunInfo &info, + const LocalTensor &mmResUb, + uint32_t dealRowCount, uint32_t columnCount) +{ + Muls(mmResUb, mmResUb, static_cast(tilingData->baseParams.scaleValue), dealRowCount * columnCount); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + // v0的无效值判断 + uint64_t s2ValidSizeFirstPart = v0ValidSizeUb_.GetValue(128 + info.loop % MERGE_CACHE_GM_BUF_NUM); + uint64_t s2ValidSizeSecondPart = v0ValidSizeUb_.GetValue(256 + info.loop % MERGE_CACHE_GM_BUF_NUM); + + int64_t s2ProcessSize = info.actualSingleProcessSInnerSize; + int64_t s2Pair = CeilDiv(s2ProcessSize, 2L * constInfo.sparseBlockSize); + int64_t s2Mid = CeilDiv(s2Pair, 2L) * 2 * constInfo.sparseBlockSize; + if (s2Mid > s2ProcessSize) { + s2Mid = s2ProcessSize; + } + if (unlikely(s2ValidSizeFirstPart < s2Mid)) { + int64_t s2StartCeilAlign = CeilAlign(s2ValidSizeFirstPart, 8); + int64_t s2MidFloorAlign = s2Mid / 8 * 8; + // 场景一 s2Mid > s2ValidSizeFirstPart + oneBlk + // 可以推导出s2StartCeilAlign < s2Mid 第一阶段取到s2StartCeilAlign + // s2StartCeilAlign <= s2MidFloorAlign 第二阶段取到s2MidFloorAlign + // 场景二 s2Mid <= s2ValidSizeFirstPart + oneBlk + // 可以推导出 s2StartCeilAlign >= s2Mid 第一阶段取到mid + // s2StartCeilAlign > s2MidFloorAlign 第二阶段取到s2StartCeilAlign + SetInfInBlk(mmResUb, dealRowCount, columnCount, s2ValidSizeFirstPart, + s2StartCeilAlign >= s2Mid ? s2Mid : s2StartCeilAlign); + SetMidInf(mmResUb, dealRowCount, columnCount, s2StartCeilAlign, s2MidFloorAlign); + SetInfInBlk(mmResUb, dealRowCount, columnCount, + s2StartCeilAlign <= s2MidFloorAlign ? s2MidFloorAlign : s2StartCeilAlign, s2Mid); + } + if (unlikely(s2ValidSizeSecondPart < s2ProcessSize - s2Mid)) { + // 场景一 s2Mid + s2ValidSizeSecondPart > s2ProcessSize + oneBlk + // 可以推导出 s2StartCeilAlign < s2ProcessSize 第一阶段取到s2StartCeilAlign + // s2StartCeilAlign <= s2EndFloorAlign 第二阶段取到s2EndFloorAlign + // 场景二 s2Mid + s2ValidSizeSecondPart <= s2ProcessSize + oneBlk + // 可以推导出 s2StartCeilAlign >= s2ProcessSize 第一阶段取到s2ProcessSize + // s2StartCeilAlign > s2EndFloorAlign 第二阶段取到s2StartCeilAlign + int64_t s2StartCeilAlign = CeilAlign(s2Mid + s2ValidSizeSecondPart, 8); + int64_t s2EndFloorAlign = s2ProcessSize / 8 * 8; + SetInfInBlk(mmResUb, dealRowCount, columnCount, s2Mid + s2ValidSizeSecondPart, + s2StartCeilAlign >= s2ProcessSize ? 
s2ProcessSize : s2StartCeilAlign); + SetMidInf(mmResUb, dealRowCount, columnCount, s2StartCeilAlign, s2EndFloorAlign); + SetInfInBlk(mmResUb, dealRowCount, columnCount, + s2StartCeilAlign <= s2EndFloorAlign ? s2EndFloorAlign : s2StartCeilAlign, s2ProcessSize); + } + } +} + +template +__aicore__ inline void SFAVectorService::SetInfInBlk(const LocalTensor &mmResUb, + uint32_t dealRowCount, uint32_t columnCount, + uint64_t startId, uint64_t endId) +{ + // startId endId + // x x x 0 0 0 x x x + // 从startId到endId部分置-inf, endId、startId为endId一个blk内部的下标 + if (startId >= endId) { + return; + } + + uint64_t startFloorAlignSize = startId / BLOCK_ELEMENT_NUM * BLOCK_ELEMENT_NUM; + uint64_t notComputePreMaskOneBlk = (1 << (startId - startFloorAlignSize)) - 1; + uint64_t notComputePostMaskOneBlk = ~((1 << (endId - startFloorAlignSize)) - 1); + uint64_t notComputeMaskOneBlk = notComputePreMaskOneBlk ^ notComputePostMaskOneBlk; + + uint64_t maskOneBlk = ~notComputeMaskOneBlk; + uint64_t mask[1] = {maskOneBlk}; + for (int i = 1; i < 8; i++) { + mask[0] = mask[0] | (maskOneBlk << (i * 8)); + } + for (uint64_t rowId = 0; rowId < dealRowCount; rowId += 8) { + Duplicate(mmResUb[rowId * columnCount + startFloorAlignSize], SOFTMAX_MIN_NUM, mask, + 1, CeilDiv(columnCount, 8), 0); + } +} + +template +__aicore__ inline void SFAVectorService::SetMidInf(const LocalTensor &mmResUb, + uint32_t dealRowCount, uint32_t columnCount, + uint64_t startId, uint64_t endId) +{ + if (startId >= endId) { + return; + } + // startId endId + // 0 ... 0 + // 从startId到endId部分置-inf, startId、endId为32B对齐的下标 + for (uint64_t rowId = 0; rowId < dealRowCount; rowId++) { + Duplicate(mmResUb[rowId * columnCount + startId], SOFTMAX_MIN_NUM, endId - startId); + } +} + +template +__aicore__ inline void SFAVectorService::SoftmaxFlashV2Compute( + const RunInfo &info, const MSplitInfo &mSplitInfo, LocalTensor &mmResUb, LocalTensor &softmaxTmpUb, + uint32_t startRow, uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) +{ + LocalTensor inSumTensor; + LocalTensor inMaxTensor; + uint32_t baseOffset = mSplitInfo.nBufferStartM / 2 + startRow; + uint32_t outIdx = info.loop % (constInfo.preLoadNum); + uint32_t softmaxOutOffset = outIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset; + if (info.isFirstSInnerLoop) { + inMaxTensor = softmaxMaxDefaultUb; + inSumTensor = softmaxSumDefaultUb; + } else { + uint32_t inIdx = (info.loop - 1) % (constInfo.preLoadNum); + inMaxTensor = softmaxMaxUb[inIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset]; + inSumTensor = softmaxSumUb[inIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset]; + } + if (actualColumnCount !=0) { + SoftMaxShapeInfo srcShape{dealRowCount, columnCount, dealRowCount, actualColumnCount}; + SoftMaxTiling newTiling = + SoftMaxFlashV2TilingFunc(srcShape, sizeof(T), sizeof(T), softmaxTmpUb.GetSize(), true, false); + SoftmaxFlashV2( + mmResUb, softmaxSumUb[softmaxOutOffset], softmaxMaxUb[softmaxOutOffset], mmResUb, + softmaxExpUb[softmaxOutOffset], inSumTensor, inMaxTensor, softmaxTmpUb, newTiling, srcShape); + } else { + uint32_t dealRowCountAlign = SFAAlign(dealRowCount, FP32_BLOCK_ELEMENT_NUM); + DataCopy(softmaxSumUb[softmaxOutOffset], inSumTensor, dealRowCountAlign); + pipe_barrier(PIPE_V); + DataCopy(softmaxMaxUb[softmaxOutOffset], inMaxTensor, dealRowCountAlign); + } +} + +template +__aicore__ inline void SFAVectorService::AmlaVecCompute( + const RunInfo &info, const MSplitInfo &mSplitInfo, LocalTensor &mmResUb, LocalTensor &softmaxTmpUb, + uint32_t startRow, 
uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) +{ + uint32_t baseOffset = mSplitInfo.nBufferStartM / 2 + startRow; + uint32_t calCount = dealRowCount; + uint32_t outIdx = info.loop % (constInfo.preLoadNum); + uint32_t softmaxOutOffset = outIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset; + // compute n(i) + LocalTensor nTmp = softmaxTmpUb.template ReinterpretCast(); + LocalTensor nUpdateTmp = nTmp[SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + Muls(nTmp, softmaxMaxUb[softmaxOutOffset], ((T)(-1.0)) * RECIP_OF_LN2, calCount); + + pipe_barrier(PIPE_V); + Cast(nTmp, nTmp, RoundMode::CAST_ROUND, calCount); + pipe_barrier(PIPE_V); + + uint32_t prOutIdx = (info.loop - 1) % (constInfo.preLoadNum); + uint32_t PreSoftmaxOutOffset = prOutIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset; + // n(i) - n(i-1) + if (info.isFirstSInnerLoop) { + Duplicate(nUpdateTmp, ConstInfo::FLOAT_ZERO, calCount); // n1=n0 + } else { + Sub(nUpdateTmp, nTmp, nValueUb[PreSoftmaxOutOffset], calCount); + } + pipe_barrier(PIPE_V); + // update n(i), DataCopy not support when calCount is not align 32B, so use Adds + Adds(nValueUb[softmaxOutOffset], nTmp, ConstInfo::FLOAT_ZERO, calCount); + pipe_barrier(PIPE_V); + + // update softmax res + LocalTensor nUpdateTmp2 = nTmp[2 * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + LocalTensor nTmp_KvT = nTmp[3 * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)].template ReinterpretCast(); + LocalTensor tmpCofUb = nTmp[4 * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + LocalTensor epsUb = nTmp[5 * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + Muls(nUpdateTmp2, softmaxMaxUb[softmaxOutOffset], RECIP_OF_LN2, calCount); + pipe_barrier(PIPE_V); + Add(nTmp, nUpdateTmp2, nTmp, calCount); + pipe_barrier(PIPE_V); + Muls(nTmp, nTmp, LN2, calCount); + pipe_barrier(PIPE_V); + Exp(nTmp, nTmp, calCount); + pipe_barrier(PIPE_V); + Cast(nTmp_KvT, nTmp, RoundMode::CAST_ROUND, calCount); // fp32->fp16/bf16 + pipe_barrier(PIPE_V); + Cast(nUpdateTmp2, nTmp_KvT, RoundMode::CAST_NONE, calCount); // fp16/bf16->fp32 + pipe_barrier(PIPE_V); + if (info.s2Idx + 1 == info.curSInnerLoopTimes) { + Mul(aMlaSumUb[softmaxOutOffset], softmaxSumUb[softmaxOutOffset], nUpdateTmp2, calCount); + } + if (actualColumnCount == 0) { + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + return; + } + LocalTensor nTmp3 = nTmp[6 * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + Brcb(nTmp3, nUpdateTmp2, (dealRowCount + 7) / 8, {1, 8}); + pipe_barrier(PIPE_V); + RowMuls(mmResUb, mmResUb, nTmp3, dealRowCount, columnCount, actualColumnCount); + + Div(tmpCofUb, nTmp, nUpdateTmp2, calCount); // cof(i)=tmpS32/tmpS16 + if (info.isFirstSInnerLoop) { + Duplicate(cofValueUb[softmaxOutOffset], (T)1.0, calCount); // cof_0=1 + pipe_barrier(PIPE_V); + Div(epsUb, cofValueUb[softmaxOutOffset], tmpCofUb, calCount); // 1 / cof(i) + } else { + pipe_barrier(PIPE_V); + Div(epsUb, cofValueUb[PreSoftmaxOutOffset], tmpCofUb, calCount); // cof(i - 1) / cof(i) + } + pipe_barrier(PIPE_V); + + Adds(cofValueUb[softmaxOutOffset], tmpCofUb, ConstInfo::FLOAT_ZERO, calCount); // store cof(i) + Adds(epsUb, epsUb, (T)(-1.0), calCount); // cof(i - 1) / cof(i) - 1 + pipe_barrier(PIPE_V); + Muls(epsUb, epsUb, (T)1.5, calCount); // (cof(i - 1) - cof(i)) / cof(i) * 1.5 + + Maxs(nUpdateTmp, nUpdateTmp, (T)(-30.0), calCount); // N = max(n(i) - n(i-1), -30) + pipe_barrier(PIPE_V); + Adds(epsUb, epsUb, (T)(0.000001), calCount); + pipe_barrier(PIPE_V); + Add(nUpdateTmp, nUpdateTmp, epsUb, calCount); + pipe_barrier(PIPE_V); + Muls(nUpdateTmp, nUpdateTmp, 
FLOAT_E_SCALAR, calCount); // N = N * pow(2, 23) + pipe_barrier(PIPE_V); + + // nUpdate int32 out + LocalTensor tmQue = outputBuff2.Get(); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + LocalTensor nInt32Out = tmQue[startRow]; // 缓存nUpdate + + Cast(nInt32Out, nUpdateTmp, RoundMode::CAST_ROUND, dealRowCount); + pipe_barrier(PIPE_V); + + SetFlag(SYNC_OUTPUT_BUF2_FLAG); +} + +template +__aicore__ inline void SFAVectorService::DealBmm1ResBaseBlock( + const RunInfo &info, const MSplitInfo &mSplitInfo, uint32_t startRow, uint32_t dealRowCount, + uint32_t columnCount, uint32_t loopId) +{ + uint32_t computeSize = dealRowCount * columnCount; + uint64_t inOutGmOffset = (info.loop % constInfo.preLoadNum) * constInfo.mmResUbSize + + (mSplitInfo.nBufferStartM + mSplitInfo.vecStartM + startRow) * columnCount; + LocalTensor mmResUb = inputBuff1.Get(); + mmResUb = mmResUb[pingpongFlag * INPUT1_BUFFER_OFFSET / sizeof(MM1_OUT_T)]; + WaitFlag(SYNC_INPUT_BUF1_FLAG + pingpongFlag); + + DataCopy(mmResUb, mm1ResGm[inOutGmOffset], computeSize); + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + if (loopId == 0) { + WaitFlag(0); + } + } + SetFlag(SYNC_INPUT_BUF1_FLAG); + WaitFlag(SYNC_INPUT_BUF1_FLAG); + + ElewiseCompute(info, mmResUb, dealRowCount, columnCount); + + pipe_barrier(PIPE_V); + LocalTensor tmpAFloorUb = tmpBuff1.Get(); + LocalTensor softmaxTmpUb = tmpAFloorUb.template ReinterpretCast(); + + SoftmaxFlashV2Compute(info, mSplitInfo, mmResUb, softmaxTmpUb, startRow, dealRowCount, columnCount, + info.actualSingleProcessSInnerSize); + + pipe_barrier(PIPE_V); + AmlaVecCompute(info, mSplitInfo, mmResUb, softmaxTmpUb, startRow, dealRowCount, columnCount, + info.actualSingleProcessSInnerSize); + + pipe_barrier(PIPE_V); + LocalTensor tmpMMResCastTensor = outputBuff1.Get(); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + + Cast(tmpMMResCastTensor, mmResUb, AscendC::RoundMode::CAST_ROUND, computeSize); + SetFlag(SYNC_INPUT_BUF1_FLAG + pingpongFlag); + + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + DataCopy(vec1ResGm[inOutGmOffset], tmpMMResCastTensor, computeSize); + SetFlag(SYNC_OUTPUT_BUF1_FLAG); +} + +template +__aicore__ inline void SFAVectorService::ProcessAmlaNupdate(const RunInfo &info, const MSplitInfo &mSplitInfo) +{ + if (mSplitInfo.vecDealM == 0) { + return; + } + if (info.isFirstSInnerLoop) { + return; + } + + LocalTensor nUpdateTensor = outputBuff2.Get(); // shape:1/2*s1*g + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + SetFlag(SYNC_OUTPUT_BUF2_FLAG); + WaitFlag(SYNC_OUTPUT_BUF2_FLAG); + + constexpr uint32_t dGroupSize = 128U; + constexpr uint32_t mSplitSize = 64U; // tmpQue size 32KB,一次只能处理64个N,最大保存的数据大小:64*128*sizeof(int32) + constexpr uint32_t ONE_BLOCK_SIZE = 32U; // 32B + + uint32_t subMSize = SFAAlign(mSplitInfo.vecDealM, 16U); + uint16_t elementPerBlock = ONE_BLOCK_SIZE / sizeof(int32_t); // 单个datablock的元素数,int32_t类型的为32/4=8 + uint32_t loopCount = (subMSize + mSplitSize - 1) / mSplitSize; + uint32_t tailSplitSize = subMSize - (loopCount - 1) * mSplitSize; // 尾块 + + for (uint32_t loop = 0, processMSize = mSplitSize; loop < loopCount; loop++) { + if (loop == (loopCount - 1)) { + processMSize = tailSplitSize; + } + LocalTensor tmpQue = outputBuff1.Get(); + + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + // (m,1)单次brcb扩充成(m,8), 重复16次, 扩充为(m,128) + for (uint32_t i = 0; i < dGroupSize / elementPerBlock; i++) { + Brcb(tmpQue[i * elementPerBlock], + nUpdateTensor[loop * mSplitSize], + static_cast((processMSize + elementPerBlock - 1) / elementPerBlock), + {static_cast(dGroupSize / elementPerBlock), // 
单次迭代内,目的操作数不同datablock间地址步长,单位为datablock + static_cast(dGroupSize)}); // 相邻迭代间,目的操作数相同datablock地址步长 + } + + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + + uint64_t baseoffset = (info.bn2IdxInCurCore % constInfo.preLoadNum) * constInfo.bmm2ResUbSize + + (mSplitInfo.nBufferStartM + mSplitInfo.vecStartM + loop * mSplitSize) * constInfo.headDim; + + SetAtomicAdd(); + DataCopyParams dataCopyParams; + dataCopyParams.blockCount = static_cast(processMSize); + dataCopyParams.blockLen = dGroupSize * sizeof(int32_t) / ONE_BLOCK_SIZE; // 每个block是128个元素,单位为32B + dataCopyParams.srcStride = 0; // 前面一个数据块的尾与后面数据块的头的间隔 + dataCopyParams.dstStride = static_cast((constInfo.headDim - dGroupSize) * + sizeof(int32_t) / ONE_BLOCK_SIZE); // 单位为32B + for (uint32_t i = 0; i < constInfo.headDim / dGroupSize; i++) { // 4=512/128 + DataCopy(mm2ResInt32Gm[baseoffset + i * dGroupSize] ,tmpQue, dataCopyParams); + } + SetAtomicNone(); + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + } + SetFlag(SYNC_OUTPUT_BUF2_FLAG); +} + +template +__aicore__ inline void SFAVectorService::ProcessVec1SingleBuf(const RunInfo &info, + const MSplitInfo &mSplitInfo) +{ + if (mSplitInfo.vecDealM == 0) { + return; + } + uint32_t mSplitSize = info.actualSingleProcessSInnerSize == 0 ? + 16 : BASE_BLOCK_MAX_ELEMENT_NUM / info.actualSingleProcessSInnerSizeAlign; + // 1. 向下8对齐是因为UB操作至少32B + // 2. info.actualSingleProcessSInnerSizeAlign最大512, mSplitSize可以确保最小为16 + mSplitSize = mSplitSize / 8 * 8; + + if (mSplitSize > mSplitInfo.vecDealM) { + mSplitSize = mSplitInfo.vecDealM; + } + uint32_t loopCount = (mSplitInfo.vecDealM + mSplitSize - 1) / mSplitSize; + uint32_t tailSplitSize = mSplitInfo.vecDealM - (loopCount - 1) * mSplitSize; + + if constexpr (TEMPLATE_MODE == V_TEMPLATE) { + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = 1; + dataCopyParams.blockLen = 256 * sizeof(int32_t); + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = 0; + DataCopyPadExtParams padParams; + // 额外偏移128个元素,避免不同loop下v0和v1互相影响 + DataCopyPad(v0ValidSizeUb_[128], kvValidSizeGm_[info.loop % MERGE_CACHE_GM_BUF_NUM * (128 * 2)], + dataCopyParams, padParams); + SetFlag(0); + if (unlikely(loopCount == 0)) { + // scalar同步影响较大,挪到循环内部进行 + WaitFlag(0); + } + } + for (uint32_t i = 0, dealSize = mSplitSize; i < loopCount; i++) { + if (i == (loopCount - 1)) { + dealSize = tailSplitSize; + } + DealBmm1ResBaseBlock(info, mSplitInfo, i * mSplitSize, dealSize, info.actualSingleProcessSInnerSizeAlign, i); + pingpongFlag ^= 1; // pingpong 0 1切换 + } +} + +template +__aicore__ inline void SFAVectorService::GetRealS2Idx(int64_t s2GmOffset, int64_t &realS2Idx, + int64_t topkGmBaseOffset, const RunInfo &runInfo) +{ + int64_t topkGmIdx = (s2GmOffset + runInfo.s2Idx * constInfo.s2BaseSize) / constInfo.sparseBlockSize; + if (unlikely(topkGmIdx >= constInfo.sparseBlockCount)) { + realS2Idx = -1; + return; + } + realS2Idx = topkGm_.GetValue(topkGmBaseOffset + topkGmIdx) * static_cast(constInfo.sparseBlockSize) + + static_cast((s2GmOffset + runInfo.s2Idx * constInfo.s2BaseSize) % constInfo.sparseBlockSize); +} + +template +__aicore__ inline int64_t SFAVectorService::GetKeyGmOffset(int64_t realS2Idx, + const RunInfo &runInfo, int64_t s2IdLimit) +{ + if (realS2Idx < 0 || realS2Idx >= s2IdLimit) { + return -1; + } + int64_t realKeyGmOffset = 0; + if constexpr (PAGE_ATTENTION) { + int64_t blkTableIdx = realS2Idx / constInfo.kvCacheBlockSize; + int64_t blkTableOffset = realS2Idx % constInfo.kvCacheBlockSize; + realKeyGmOffset = blkTableGm_.GetValue(runInfo.bIdx * 
constInfo.maxBlockNumPerBatch + blkTableIdx) * + static_cast(constInfo.kvCacheBlockSize) * + static_cast(constInfo.kvHeadNum) + + blkTableOffset; + } else { + realKeyGmOffset = (runInfo.tensorBOffset + + realS2Idx * constInfo.kvHeadNum * constInfo.headDim) / + constInfo.headDim; + } + return realKeyGmOffset; +} + +template +__aicore__ inline int64_t SFAVectorService::GetKeyRopeGmOffset(int64_t realS2Idx, + const RunInfo &runInfo, int64_t s2IdLimit) +{ + if (realS2Idx < 0 || realS2Idx >= s2IdLimit) { + return -1; + } + int64_t realKeyRopeGmOffset = 0; + realKeyRopeGmOffset = (runInfo.tensorBRopeOffset + + realS2Idx * constInfo.kvHeadNum * constInfo.headDimRope) / + constInfo.headDimRope; + return realKeyRopeGmOffset; +} + +template +__aicore__ inline void +SFAVectorService::CopyInSingleKv(int64_t &mte2Size, int64_t mte3Size, int64_t mergeMte3Idx, int64_t realS2Idx, + int64_t keyBNBOffset,int64_t s2IdLimit, const RunInfo &runInfo) +{ + if (keyBNBOffset < 0) { + return; + } + int64_t validS2Count = + (realS2Idx + constInfo.sparseBlockSize > s2IdLimit ? s2IdLimit - realS2Idx : constInfo.sparseBlockSize); + DataCopyExtParams intriParams; + intriParams.blockLen = validS2Count * constInfo.headDim * sizeof(KV_T); + intriParams.blockCount = 1; + intriParams.dstStride = 0; + intriParams.srcStride = 0; + DataCopyPadExtParams padParams; + DataCopyPad(kvMergUb_[mergeMte3Idx % 2 * 32 * 512 + (mte2Size - mte3Size) * constInfo.headDim], + keyGm_[keyBNBOffset * constInfo.headDim], intriParams, padParams); + intriParams.blockLen = validS2Count * constInfo.headDimRope * sizeof(KV_T); + + DataCopyPad(ropeMergUb_[mergeMte3Idx % 2 * 32 * 64 + (mte2Size - mte3Size) * constInfo.headDimRope], + keyRopeGm_[keyBNBOffset * constInfo.headDimRope], intriParams, padParams); + mte2Size += validS2Count; +} + +template +__aicore__ inline void SFAVectorService::CopyInKv(int64_t &mte2Size, int64_t mte3Size, int64_t mergeMte3Idx, + int64_t realS2Idx1, int64_t realS2Idx2, const RunInfo &runInfo) +{ + int64_t s2IdLimit = runInfo.curActualSeqLenOri; + if (constInfo.sparseMode == 3) { + s2IdLimit = runInfo.curActualSeqLenOri - runInfo.actS1Size + runInfo.gS1Idx / constInfo.gSize + 1; + } + + int64_t keyOffset1 = GetKeyGmOffset(realS2Idx1, runInfo, s2IdLimit); + int64_t keyOffset2 = GetKeyGmOffset(realS2Idx2, runInfo, s2IdLimit); + if (unlikely(keyOffset1 < 0 && keyOffset2 < 0)) { + return; + } + + int64_t keySrcStride = 0; + int64_t keyRopeSrcStride = 0; + if constexpr (PAGE_ATTENTION) { + int64_t blkTableSrcStride = + ((keyOffset1 > keyOffset2 ? (keyOffset1 - keyOffset2) : + (keyOffset2 - keyOffset1)) - constInfo.sparseBlockSize); + keySrcStride = blkTableSrcStride * constInfo.headDim * sizeof(KV_T); + keyRopeSrcStride = blkTableSrcStride * constInfo.headDimRope * sizeof(KV_T); + } else { + int64_t keyRopeOffset1 = GetKeyRopeGmOffset(realS2Idx1, runInfo, s2IdLimit); + int64_t keyRopeOffset2 = GetKeyRopeGmOffset(realS2Idx2, runInfo, s2IdLimit); + keySrcStride = ((keyOffset1 > keyOffset2 ? (keyOffset1 - keyOffset2) : + (keyOffset2 - keyOffset1)) - constInfo.sparseBlockSize) * constInfo.headDim * sizeof(KV_T); + keyRopeSrcStride = ((keyRopeOffset1 > keyRopeOffset2 ? 
(keyRopeOffset1 - keyRopeOffset2) : + (keyRopeOffset2 - keyRopeOffset1)) - constInfo.sparseBlockSize) * + constInfo.headDimRope * sizeof(KV_T); + } + + if (unlikely(keySrcStride >= INT32_MAX || keySrcStride < 0 || + (!PAGE_ATTENTION && (keyRopeSrcStride >= INT32_MAX || keyRopeSrcStride < 0)) || + realS2Idx1 + constInfo.sparseBlockSize >= s2IdLimit || + realS2Idx2 + constInfo.sparseBlockSize >= s2IdLimit)) { + // stride溢出、stride为负数、s2超长等异常场景,还原成2条搬运指令 + CopyInSingleKv(mte2Size, mte3Size, mergeMte3Idx, realS2Idx1, keyOffset1, s2IdLimit, runInfo); + CopyInSingleKv(mte2Size, mte3Size, mergeMte3Idx, realS2Idx2, keyOffset2, s2IdLimit, runInfo); + } else { + DataCopyExtParams intriParams; + intriParams.blockLen = constInfo.sparseBlockSize * constInfo.headDim * sizeof(KV_T); + intriParams.blockCount = (keyOffset1 >= 0) + (keyOffset2 >= 0); + intriParams.dstStride = 0; + intriParams.srcStride = keySrcStride; + DataCopyPadExtParams padParams; + + int64_t startGmOffset = keyOffset1 > -1 ? keyOffset1 : keyOffset2; + if (keyOffset2 > -1 && keyOffset2 < keyOffset1) { + startGmOffset = keyOffset2; + } + DataCopyPad(kvMergUb_[mergeMte3Idx % 2 * 32 * 512 + (mte2Size - mte3Size) * constInfo.headDim], + keyGm_[startGmOffset * constInfo.headDim], intriParams, padParams); + + intriParams.blockLen = constInfo.sparseBlockSize * constInfo.headDimRope * sizeof(KV_T); + intriParams.dstStride = 0; + intriParams.srcStride = keyRopeSrcStride; + DataCopyPad(ropeMergUb_[mergeMte3Idx % 2 * 32 * 64 + (mte2Size - mte3Size) * constInfo.headDimRope], + keyRopeGm_[startGmOffset * constInfo.headDimRope], intriParams, padParams); + mte2Size += ((keyOffset1 > -1) + (keyOffset2 > -1)) * constInfo.sparseBlockSize; + } +} + +template +__aicore__ inline void SFAVectorService::CopyOutMrgeResult(int64_t mte2Size, int64_t mte3Size, + int64_t s2GmStartOffset, int64_t mergeMte3Idx, + const RunInfo &runInfo) +{ + if (mte2Size <= mte3Size) { + return; + } + SetFlag(0); + WaitFlag(0); + + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = mte2Size - mte3Size; + dataCopyParams.blockLen = constInfo.headDim * sizeof(KV_T); + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = 0; + + DataCopyPad(kvMergeGm_[runInfo.loop % 4 * 512 * 576 + (s2GmStartOffset + mte3Size)*constInfo.headDim], + kvMergUb_[mergeMte3Idx % 2 * 32 * 512], dataCopyParams); + + dataCopyParams.blockLen = constInfo.headDimRope * sizeof(KV_T); + DataCopyPad(kvMergeGm_[runInfo.loop % 4 * 512 * 576 + 512 * 512 + (s2GmStartOffset + mte3Size) * + constInfo.headDimRope], ropeMergUb_[mergeMte3Idx % 2 * 32 * 64], dataCopyParams); +} + +// b s1 k +template +__aicore__ inline void SFAVectorService::MergeKv(const RunInfo &runInfo) +{ + int64_t s2ProcessSize = runInfo.actualSingleProcessSInnerSize; + int64_t s2Pair = CeilDiv(s2ProcessSize, 2L * constInfo.sparseBlockSize); + int64_t topkGmBaseOffset = 0; + + if constexpr (LAYOUT_T == SFA_LAYOUT::TND) { + uint64_t actualSeqQPrefixSum = (runInfo.bIdx <= 0) ? 
0 : actualSeqLengthsQGm.GetValue(runInfo.bIdx - 1); + topkGmBaseOffset += (actualSeqQPrefixSum + runInfo.gS1Idx / constInfo.gSize) * constInfo.kvHeadNum * + constInfo.sparseBlockCount + runInfo.n2Idx * constInfo.sparseBlockCount; + } else { + topkGmBaseOffset += runInfo.bIdx * constInfo.qSeqSize * constInfo.sparseBlockCount + + runInfo.gS1Idx / constInfo.gSize * constInfo.sparseBlockCount; + } + int64_t mergeMte3Idx = 0; + int64_t mte2Size = 0; + int64_t mte3Size = 0; + int64_t s2IdxArray0 = -1; + int64_t s2IdxArray1 = -1; + bool needWaitMte3ToMte2 = true; + SetFlag(0); + SetFlag(1); + int64_t s2GmStartOffset = GetSubBlockIdx() == 0 ? 0 : CeilDiv(s2Pair, 2L) * 2 * constInfo.sparseBlockSize; + int64_t s2GmLimit = GetSubBlockIdx() == 0 ? CeilDiv(s2Pair, 2L) * 2 * constInfo.sparseBlockSize: s2ProcessSize; + if (s2GmLimit > s2ProcessSize) { + s2GmLimit = s2ProcessSize; + } + for (int64_t s2GmOffsetArray = s2GmStartOffset; s2GmOffsetArray < s2GmLimit; s2GmOffsetArray += 2 * constInfo.sparseBlockSize) { + if (needWaitMte3ToMte2) { + WaitFlag(mergeMte3Idx % 2); + needWaitMte3ToMte2 = false; + } + GetRealS2Idx(s2GmOffsetArray, s2IdxArray0, topkGmBaseOffset, runInfo); + if (unlikely(s2IdxArray0 < 0)) { + CopyOutMrgeResult(mte2Size, mte3Size, s2GmStartOffset, mergeMte3Idx, runInfo); + SetFlag(mergeMte3Idx % 2); + mergeMte3Idx++; + break; + } + GetRealS2Idx(s2GmOffsetArray + constInfo.sparseBlockSize, s2IdxArray1, topkGmBaseOffset, runInfo); + CopyInKv(mte2Size, mte3Size, mergeMte3Idx, s2IdxArray0, s2IdxArray1, runInfo); + if ((mte2Size - mte3Size + 2 * constInfo.sparseBlockSize > 32) || + s2GmOffsetArray + 2 * constInfo.sparseBlockSize >= s2GmLimit) { + CopyOutMrgeResult(mte2Size, mte3Size, s2GmStartOffset, mergeMte3Idx, runInfo); + mte3Size = mte2Size; + SetFlag(mergeMte3Idx % 2); + mergeMte3Idx++; + needWaitMte3ToMte2 = true; + } + } + + if (unlikely(s2GmStartOffset + mte2Size < s2GmLimit)) { + SetFlag(0); + WaitFlag(0); + WaitFlag(mergeMte3Idx & 1); + Duplicate(kvMergUb_, static_cast(0.0), constInfo.headDim); + SetFlag(0); + WaitFlag(0); + + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = 1; + dataCopyParams.blockLen = constInfo.headDim * sizeof(KV_T); + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = 0; + for (int64_t s2GmOffset = s2GmStartOffset + mte2Size; s2GmOffset < s2GmLimit; s2GmOffset++) { + DataCopyPad(kvMergeGm_[runInfo.loop % MERGE_CACHE_GM_BUF_NUM * 512 * 576 + s2GmOffset * constInfo.headDim], + kvMergUb_, dataCopyParams); + } + dataCopyParams.blockLen = constInfo.headDimRope * sizeof(KV_T); + for (int64_t s2GmOffset = s2GmStartOffset + mte2Size; s2GmOffset < s2GmLimit; s2GmOffset++) { + DataCopyPad(kvMergeGm_[runInfo.loop % MERGE_CACHE_GM_BUF_NUM * 512 * 576 + 512 * constInfo.headDim + + s2GmOffset * constInfo.headDimRope], + kvMergUb_, dataCopyParams); + } + SetFlag(mergeMte3Idx & 1); + mergeMte3Idx++; + } + WaitFlag(0); + WaitFlag(1); + v0ValidSizeUb_.SetValue(runInfo.loop % MERGE_CACHE_GM_BUF_NUM, mte2Size); + SetFlag(1); + WaitFlag(1); + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = 1; + dataCopyParams.blockLen = 128 * sizeof(int32_t); + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = 0; + DataCopyPad(kvValidSizeGm_[runInfo.loop % MERGE_CACHE_GM_BUF_NUM * (128 * 2) + GetSubBlockIdx() * 128], + v0ValidSizeUb_, dataCopyParams); + return; +} + +template +__aicore__ inline void SFAVectorService::ProcessVec1L(const RunInfo &info) +{ + uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / 
constInfo.nBufferMBaseSize; + uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize; + for (uint32_t i = 0; i < nBufferLoopTimes; i++) { + MSplitInfo mSplitInfo; + mSplitInfo.nBufferIdx = i; + mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize; + mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? constInfo.nBufferMBaseSize : nBufferTail; + + mSplitInfo.vecDealM = (mSplitInfo.nBufferDealM <= 16) ? mSplitInfo.nBufferDealM : + (((mSplitInfo.nBufferDealM + 15) / 16 + 1) / 2 * 16); + mSplitInfo.vecStartM = 0; + if (GetBlockIdx() % 2 == 1) { + mSplitInfo.vecStartM = mSplitInfo.vecDealM; + mSplitInfo.vecDealM = mSplitInfo.nBufferDealM - mSplitInfo.vecDealM; + } + + CrossCoreWaitFlag(constInfo.syncC1V1); + // vec1 compute + ProcessVec1SingleBuf(info, mSplitInfo); + CrossCoreSetFlag(constInfo.syncV1C2); + CrossCoreWaitFlag(constInfo.syncC2V1); + // add nUpdate to mm2ResGm + if (info.actualSingleProcessSInnerSize != 0) { + ProcessAmlaNupdate(info, mSplitInfo); + CrossCoreSetFlag(constInfo.syncV1NupdateC2); + } + // move lse for flash decode + if (info.s2Idx == info.curSInnerLoopTimes - 1) { + if (info.tndIsS2SplitCore) { + if constexpr (FLASH_DECODE) { + uint32_t outIdx = info.loop % (constInfo.preLoadNum); + auto sumTensor = softmaxSumUb[outIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + auto maxTensor = softmaxMaxUb[outIdx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T)]; + ComputeLogSumExpAndCopyToGm(info, mSplitInfo, sumTensor, maxTensor); + } + } + } + } +} + +template +__aicore__ inline uint64_t SFAVectorService::CalcAccumOffset(uint32_t bN2Idx, uint32_t gS1Idx) +{ + return 0; +} + +template +__aicore__ inline void SFAVectorService::ProcessVec2SingleBuf(const RunInfo &info, + const MSplitInfo &mSplitInfo) +{ + if (info.s2Idx + 1 != info.curSInnerLoopTimes) { + return; + } + if (mSplitInfo.vecDealM == 0) { + return; + } + + ProcessVec2Inner(info, mSplitInfo, 0, mSplitInfo.vecDealM); +} + +template __aicore__ inline void SFAVectorService::ProcessVec2L(const RunInfo &info) +{ + uint32_t nBufferLoopTimes = (info.actMBaseSize + constInfo.nBufferMBaseSize - 1) / constInfo.nBufferMBaseSize; + uint32_t nBufferTail = info.actMBaseSize - (nBufferLoopTimes - 1) * constInfo.nBufferMBaseSize; + for (uint32_t i = 0; i < nBufferLoopTimes; i++) { + MSplitInfo mSplitInfo; + mSplitInfo.nBufferIdx = i; + mSplitInfo.nBufferStartM = i * constInfo.nBufferMBaseSize; + mSplitInfo.nBufferDealM = (i + 1 != nBufferLoopTimes) ? constInfo.nBufferMBaseSize : nBufferTail; + + mSplitInfo.vecDealM = (mSplitInfo.nBufferDealM <= 16) ? 
mSplitInfo.nBufferDealM : + (((mSplitInfo.nBufferDealM + 15) / 16 + 1) / 2 * 16); + mSplitInfo.vecStartM = 0; + if (GetBlockIdx() % 2 == 1) { + mSplitInfo.vecStartM = mSplitInfo.vecDealM; + mSplitInfo.vecDealM = mSplitInfo.nBufferDealM - mSplitInfo.vecDealM; + } + CrossCoreWaitFlag(constInfo.syncC2V2); + ProcessVec2SingleBuf(info, mSplitInfo); + } +} + +template +__aicore__ inline void SFAVectorService::ProcessVec2Inner(const RunInfo &info, + const MSplitInfo &mSplitInfo, + uint32_t mStartRow, uint32_t mDealSize) +{ + uint32_t mSplitSize = BASE_BLOCK_MAX_ELEMENT_NUM / constInfo.headDim; + if (mSplitSize > mDealSize) { + mSplitSize = mDealSize; + } + + uint32_t loopCount = (mDealSize + mSplitSize - 1) / mSplitSize; + uint32_t tailSplitSize = mDealSize - (loopCount - 1) * mSplitSize; + for (uint32_t i = 0, dealSize = mSplitSize; i < loopCount; i++) { + if (i == (loopCount - 1)) { + dealSize = tailSplitSize; + } + DealBmm2ResBaseBlock(info, mSplitInfo, i * mSplitSize + mStartRow, dealSize, + constInfo.headDim, constInfo.headDim); + pingpongFlag ^= 1; // pingpong 0 1切换 + } +} + + +template +__aicore__ inline void SFAVectorService::GetConfusionTransposeTiling( + int64_t numR, int64_t numC, const uint32_t stackBufferSize, const uint32_t typeSize, + ConfusionTransposeTiling &tiling) +{ + (void)stackBufferSize; + uint32_t blockSize = ONE_BLK_SIZE / typeSize; + uint32_t height = numC; + uint32_t width = numR; + uint32_t highBlock = height / BLOCK_CUBE; + uint32_t stride = height * blockSize * typeSize / ONE_BLK_SIZE; + uint32_t repeat = width / blockSize; + + tiling.param0 = blockSize; + tiling.param1 = height; + tiling.param2 = width; + tiling.param3 = highBlock; + tiling.param4 = stride; + tiling.param5 = repeat; +} + +template +__aicore__ inline void +SFAVectorService::Bmm2FDDataCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, + uint32_t wsMStart, uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount) +{ + LocalTensor tmp = outputBuff1.Get(); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + DataCopy(tmp, bmm2ResUb, columnCount * dealRowCount); + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + uint64_t accumTmpOutNum = CalcAccumOffset(info.bIdx, info.gS1Idx); + uint64_t offset = accumTmpOutNum * constInfo.kvHeadNum * constInfo.mBaseSize * constInfo.headDim + // taskoffset + info.tndCoreStartKVSplitPos * constInfo.kvHeadNum * constInfo.mBaseSize * constInfo.headDim + // 份数offset + wsMStart * actualColumnCount; // m轴offset + GlobalTensor dst = accumOutGm[offset]; + if (info.actualSingleProcessSInnerSize== 0) { + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = dealRowCount; + dataCopyParams.blockLen = actualColumnCount * sizeof(T); + dataCopyParams.srcStride = (columnCount - actualColumnCount) / (BYTE_BLOCK / sizeof(T)); + dataCopyParams.dstStride = 0; + DataCopyPad(dst, tmp, dataCopyParams); + } else { + matmul::InitOutput(dst, dealRowCount * actualColumnCount, ConstInfo::FLOAT_ZERO); + } + SetFlag(SYNC_OUTPUT_BUF1_FLAG); +} + +template +__aicore__ inline void +SFAVectorService::Bmm2DataCopyOutTrans(const RunInfo &info, LocalTensor &attenOutUb, + uint32_t wsMStart, uint32_t dealRowCount, + uint32_t columnCount, uint32_t actualColumnCount) +{ + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = dealRowCount; + dataCopyParams.blockLen = actualColumnCount * sizeof(OUT_T); + dataCopyParams.srcStride = (columnCount - actualColumnCount) / (BYTE_BLOCK / sizeof(OUT_T)); + dataCopyParams.dstStride = 0; + 
DataCopyPad(attentionOutGm[info.attenOutOffset + wsMStart * actualColumnCount], attenOutUb, dataCopyParams); + return; +} + +template +__aicore__ inline void +SFAVectorService::Bmm2CastAndCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, + uint32_t wsMStart, uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount) +{ + LocalTensor tmpBmm2ResCastTensor = outputBuff1.Get(); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + if constexpr (IsSameType::value) { // bf16 采取四舍六入五成双模式 + Cast(tmpBmm2ResCastTensor, bmm2ResUb, AscendC::RoundMode::CAST_RINT, dealRowCount * columnCount); + } else { + Cast(tmpBmm2ResCastTensor, bmm2ResUb, AscendC::RoundMode::CAST_ROUND, dealRowCount * columnCount); + } + + SetFlag(SYNC_OUTPUT_BUF1_FLAG); + WaitFlag(SYNC_OUTPUT_BUF1_FLAG); + Bmm2DataCopyOutTrans(info, tmpBmm2ResCastTensor, wsMStart, dealRowCount, columnCount, actualColumnCount); + SetFlag(SYNC_OUTPUT_BUF1_FLAG); +} + +template +__aicore__ inline void +SFAVectorService::Bmm2ResCopyOut(const RunInfo &info, LocalTensor &bmm2ResUb, uint32_t wsMStart, + uint32_t dealRowCount, uint32_t columnCount, + uint32_t actualColumnCount) +{ + if constexpr (FLASH_DECODE) { + if (info.tndIsS2SplitCore) { + Bmm2FDDataCopyOut(info, bmm2ResUb, wsMStart, dealRowCount, columnCount, actualColumnCount); + } else { + Bmm2CastAndCopyOut(info, bmm2ResUb, wsMStart, dealRowCount, columnCount, actualColumnCount); + } + } else { + Bmm2CastAndCopyOut(info, bmm2ResUb, wsMStart, dealRowCount, columnCount, actualColumnCount); + } +} + +template +__aicore__ inline void +SFAVectorService::DealBmm2ResBaseBlock(const RunInfo &info, const MSplitInfo &mSplitInfo, + uint32_t startRow, uint32_t dealRowCount, + uint32_t columnCount, uint32_t actualColumnCount) +{ + uint32_t vec2ComputeSize = dealRowCount * columnCount; + uint32_t mStart = mSplitInfo.nBufferStartM + mSplitInfo.vecStartM + startRow; + uint64_t srcGmOffset = (info.bn2IdxInCurCore % constInfo.preLoadNum) * constInfo.bmm2ResUbSize + + mStart * columnCount; + LocalTensor tmpBmm2ResUb = inputBuff1.Get(); + tmpBmm2ResUb = tmpBmm2ResUb[pingpongFlag * INPUT1_BUFFER_OFFSET / sizeof(MM2_OUT_T)]; + WaitFlag(SYNC_INPUT_BUF1_FLAG + pingpongFlag); + DataCopy(tmpBmm2ResUb, mm2ResGm[srcGmOffset], vec2ComputeSize); + + SetFlag(SYNC_INPUT_BUF1_FLAG); + WaitFlag(SYNC_INPUT_BUF1_FLAG); + + // 将绝对值大于1e10的数置为0 + LocalTensor bmm2ResUb = tmpBuff1.Get(); + bmm2ResUb.SetSize(vec2ComputeSize); + LocalTensor absBmm2ResUb = bmm2ResUb.template ReinterpretCast(); + Abs(absBmm2ResUb, tmpBmm2ResUb, vec2ComputeSize); + pipe_barrier(PIPE_V); + LocalTensor cmpMaskUb = absBmm2ResUb.template ReinterpretCast(); + CompareScalar(cmpMaskUb, absBmm2ResUb, (T)1e10, CMPMODE::LE, vec2ComputeSize); + pipe_barrier(PIPE_V); + Select(tmpBmm2ResUb, cmpMaskUb, tmpBmm2ResUb, ConstInfo::FLOAT_ZERO, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vec2ComputeSize); + pipe_barrier(PIPE_V); + uint32_t baseOffset = mSplitInfo.nBufferStartM / 2 + startRow; + uint32_t idx = info.loop % (constInfo.preLoadNum); + LocalTensor tmpSumUb = v0ValidSizeBuff.Get()[384]; // sumUb用临时内存 16 * 32B = 512B + Brcb(tmpSumUb, aMlaSumUb[idx * SOFTMAX_TMP_BUFFER_OFFSET / sizeof(T) + baseOffset], (dealRowCount + 7) / 8, {1, 8}); + pipe_barrier(PIPE_V); + RowDivs(bmm2ResUb, tmpBmm2ResUb, tmpSumUb, dealRowCount, columnCount, actualColumnCount); + pipe_barrier(PIPE_V); + SetFlag(SYNC_INPUT_BUF1_FLAG + pingpongFlag); + Bmm2ResCopyOut(info, bmm2ResUb, mStart, dealRowCount, columnCount, actualColumnCount); +} + +template +__aicore__ inline void 
+SFAVectorService::RowDivs(LocalTensor dstUb, LocalTensor src0Ub, LocalTensor src1Ub, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) +{ + // divs by row, 每行的元素除以相同的元素 + // dstUb[i, (j * 8) : (j * 8 + 7)] = src0Ub[i, (j * 8) : (j * 8 + 7)] / src1Ub[i, 0 : 7] + // src0Ub:[dealRowCount, columnCount], src1Ub:[dealRowCount, FP32_BLOCK_ELEMENT_NUM] dstUb:[dealRowCount, + // columnCount] + uint32_t dtypeMask = FP32_REPEAT_ELEMENT_NUM; + uint32_t dLoop = actualColumnCount / dtypeMask; + uint32_t dRemain = actualColumnCount % dtypeMask; + + BinaryRepeatParams repeatParamsDiv; + repeatParamsDiv.src0BlkStride = 1; + repeatParamsDiv.src1BlkStride = 0; + repeatParamsDiv.dstBlkStride = 1; + repeatParamsDiv.src0RepStride = columnCount / FP32_BLOCK_ELEMENT_NUM; + repeatParamsDiv.src1RepStride = 1; + repeatParamsDiv.dstRepStride = columnCount / FP32_BLOCK_ELEMENT_NUM; + uint32_t columnRepeatCount = dLoop; + if (columnRepeatCount <= dealRowCount) { + uint32_t offset = 0; + for (uint32_t i = 0; i < dLoop; i++) { + Div(dstUb[offset], src0Ub[offset], src1Ub, dtypeMask, dealRowCount, repeatParamsDiv); + offset += dtypeMask; + } + } else { + BinaryRepeatParams columnRepeatParams; + columnRepeatParams.src0BlkStride = 1; + columnRepeatParams.src1BlkStride = 0; + columnRepeatParams.dstBlkStride = 1; + columnRepeatParams.src0RepStride = 8; // 列方向上两次repeat起始地址间隔dtypeMask=64个元素,即8个block + columnRepeatParams.src1RepStride = 0; + columnRepeatParams.dstRepStride = 8; // 列方向上两次repeat起始地址间隔dtypeMask=64个元素,即8个block + uint32_t offset = 0; + for (uint32_t i = 0; i < dealRowCount; i++) { + Div(dstUb[offset], src0Ub[offset], src1Ub[i * FP32_BLOCK_ELEMENT_NUM], dtypeMask, columnRepeatCount, + columnRepeatParams); + offset += columnCount; + } + } + if (dRemain > 0) { + Div(dstUb[dLoop * dtypeMask], src0Ub[dLoop * dtypeMask], src1Ub, dRemain, dealRowCount, repeatParamsDiv); + } +} + +template +__aicore__ inline void +SFAVectorService::RowMuls(LocalTensor dstUb, LocalTensor src0Ub, LocalTensor src1Ub, + uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) +{ + // muls by row, 每行的元素乘以相同的元素 + // dstUb[i, (j * 8) : (j * 8 + 7)] = src0Ub[i, (j * 8) : (j * 8 + 7)] * src1Ub[i, 0 : 7] + // src0Ub:[dealRowCount, columnCount] src1Ub:[dealRowCount, FP32_BLOCK_ELEMENT_NUM] dstUb:[dealRowCount, + // columnCount] + // dealRowCount is repeat times, must be less 256 + uint32_t repeatElementNum = FP32_REPEAT_ELEMENT_NUM; + uint32_t blockElementNum = FP32_BLOCK_ELEMENT_NUM; + + if constexpr (std::is_same::value) { + // 此限制由于每个repeat至多连续读取256B数据 + repeatElementNum = FP32_REPEAT_ELEMENT_NUM * 2; // 256/4 * 2=128 + blockElementNum = FP32_BLOCK_ELEMENT_NUM * 2; // 32/4 * 2 = 16 + } + + // 每次只能连续读取256B的数据进行计算,故每次只能处理256B/sizeof(dType)= + // 列方向分dLoop次,每次处理8列数据 + uint32_t dLoop = actualColumnCount / repeatElementNum; + uint32_t dRemain = actualColumnCount % repeatElementNum; + // REPEATE_STRIDE_UP_BOUND=256, 此限制由于src0RepStride数据类型为uint8之多256个datablock间距 + if (columnCount < REPEATE_STRIDE_UP_BOUND * blockElementNum) { + BinaryRepeatParams repeatParams; + repeatParams.src0BlkStride = 1; + repeatParams.src1BlkStride = 0; + repeatParams.dstBlkStride = 1; + repeatParams.src0RepStride = columnCount / blockElementNum; + repeatParams.src1RepStride = 1; + repeatParams.dstRepStride = columnCount / blockElementNum; + + // 如果以列为repeat所处理的次数小于行处理次数,则以列方式处理。反之则以行进行repeat处理 + if (dLoop <= dealRowCount) { + uint32_t offset = 0; + for (uint32_t i = 0; i < dLoop; i++) { + Mul(dstUb[offset], src0Ub[offset], src1Ub, 
repeatElementNum, dealRowCount, repeatParams); + offset += repeatElementNum; + } + } else { + BinaryRepeatParams columnRepeatParams; + columnRepeatParams.src0BlkStride = 1; + columnRepeatParams.src1BlkStride = 0; + columnRepeatParams.dstBlkStride = 1; + columnRepeatParams.src0RepStride = 8; // 列方向上两次repeat起始地址间隔dtypeMask=64个元素,即8个block + columnRepeatParams.src1RepStride = 0; + columnRepeatParams.dstRepStride = 8; // 列方向上两次repeat起始地址间隔dtypeMask=64个元素,即8个block + for (uint32_t i = 0; i < dealRowCount; i++) { + Mul(dstUb[i * columnCount], src0Ub[i * columnCount], src1Ub[i * blockElementNum], repeatElementNum, + dLoop, columnRepeatParams); + } + } + + // 最后一次完成[dealRowCount, dRemain] * [dealRowCount, blockElementNum] 只计算有效部分 + if (dRemain > 0) { + Mul(dstUb[dLoop * repeatElementNum], src0Ub[dLoop * repeatElementNum], src1Ub, dRemain, dealRowCount, + repeatParams); + } + } else { + BinaryRepeatParams repeatParams; + repeatParams.src0RepStride = 8; // 每个repeat为256B数据,正好8个datablock + repeatParams.src0BlkStride = 1; + repeatParams.src1RepStride = 0; + repeatParams.src1BlkStride = 0; + repeatParams.dstRepStride = 8; + repeatParams.dstBlkStride = 1; + // 每次计算一行,共计算dealRowCount行 + for (uint32_t i = 0; i < dealRowCount; i++) { + // 计算一行中的dLoop个repeat, 每个repeat计算256/block_size 个data_block + Mul(dstUb[i * columnCount], src0Ub[i * columnCount], src1Ub[i * blockElementNum], repeatElementNum, dLoop, + repeatParams); + // 计算一行中的尾块 + if (dRemain > 0) { + Mul(dstUb[i * columnCount + dLoop * repeatElementNum], + src0Ub[i * columnCount + dLoop * repeatElementNum], src1Ub[i * blockElementNum], dRemain, 1, + repeatParams); + } + } + } +} + +#endif // SPARSE_FLASH_ATTENTION_SERVICE_VECTOR_MLA_H diff --git a/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_template_tiling_key.h b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_template_tiling_key.h new file mode 100644 index 00000000000..29b542095c8 --- /dev/null +++ b/csrc/sparse_flash_attention/op_kernel/sparse_flash_attention_template_tiling_key.h @@ -0,0 +1,57 @@ +/** + * This program is free software, you can redistribute it and/or modify it. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file sparse_flash_attention_template_tiling_key.h + * \brief + */ + +#ifndef SPARSE_FLASH_ATTENTION_TEMPLATE_TILING_KEY_H +#define SPARSE_FLASH_ATTENTION_TEMPLATE_TILING_KEY_H + +#include "ascendc/host_api/tiling/template_argument.h" + +#define SFA_LAYOUT_BSND 0 +#define SFA_LAYOUT_TND 1 +#define SFA_LAYOUT_PA_BSND 2 + +#define ASCENDC_TPL_4_BW 4 + +#define C_TEMPLATE 0 +#define V_TEMPLATE 1 + +// 模板参数支持的范围定义 +ASCENDC_TPL_ARGS_DECL(SparseFlashAttention, // 算子OpType +ASCENDC_TPL_BOOL_DECL(FLASH_DECODE, 0, 1), +ASCENDC_TPL_UINT_DECL(LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND), +ASCENDC_TPL_UINT_DECL(KV_LAYOUT_T, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND, + SFA_LAYOUT_PA_BSND), +ASCENDC_TPL_UINT_DECL(TEMPLATE_MODE, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, C_TEMPLATE, V_TEMPLATE), +); + +// 支持的模板参数组合 +// 用于调用GET_TPL_TILING_KEY获取TilingKey时,接口内部校验TilingKey是否合法 +ASCENDC_TPL_SEL( + ASCENDC_TPL_ARGS_SEL( + ASCENDC_TPL_BOOL_SEL(FLASH_DECODE, 0), + ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(KV_LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_PA_BSND, SFA_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(TEMPLATE_MODE, ASCENDC_TPL_UI_LIST, C_TEMPLATE), + ), + + ASCENDC_TPL_ARGS_SEL( + ASCENDC_TPL_BOOL_SEL(FLASH_DECODE, 0), + ASCENDC_TPL_UINT_SEL(LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_BSND, SFA_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(KV_LAYOUT_T, ASCENDC_TPL_UI_LIST, SFA_LAYOUT_PA_BSND, SFA_LAYOUT_BSND, SFA_LAYOUT_TND), + ASCENDC_TPL_UINT_SEL(TEMPLATE_MODE, ASCENDC_TPL_UI_LIST, V_TEMPLATE), // V模板不支持非PA + ), +); + +#endif // TEMPLATE_TILING_KEY \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 06338e4f475..c7aa9d16815 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -27,6 +27,7 @@ #include "ops.h" #include "utils.h" #include "mla_preprocess/op_host/mla_preprocess.h" +#include "batch_matmul_transpose/op_host/batch_matmul_transpose.h" #include "aclnn_torch_adapter/op_api_common.h" #include @@ -587,6 +588,134 @@ std::tuple grouped_matmul_swiglu_quant_weigh return std::tuple(output, output_scale, output_offset); } +at::Tensor npu_lightning_indexer( + const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights, + const c10::optional &actual_seq_lengths_query, + const c10::optional &actual_seq_lengths_key, + const c10::optional &block_table, c10::string_view layout_query, + c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode) +{ + // npu tensor max size + constexpr int32_t SIZE = 8; + constexpr int32_t DIM_0 = 0; + constexpr int32_t DIM_1 = 1; + constexpr int32_t DIM_2 = 2; + constexpr int32_t DIM_3 = 3; + + TORCH_CHECK(query.numel() > 0, "Query is empty."); + TORCH_CHECK(key.numel() > 0, "Key is empty."); + TORCH_CHECK(weights.numel() > 0, "Weights is empty."); + for (size_t i = 0; i < query.sizes().size(); i++) { + TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " + "than 0, but shape[", i, "] is ", query.size(i)); + } + TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count); + + at::SmallVector output_size; + std::string query_layout_str = std::string(layout_query); + std::string key_layout_str = std::string(layout_key); + if (query_layout_str == "BSND") { + output_size = {query.size(DIM_0), query.size(DIM_1), key.size(DIM_2), sparse_count}; + } else { + int n_dim_index = 0; + n_dim_index = (key_layout_str 
== "TND") ? DIM_1 : DIM_2; + output_size = {query.size(DIM_0), key.size(n_dim_index), sparse_count}; + } + at::Tensor lightning_indexer_output = at::empty(output_size, query.options().dtype(at::kInt)); + // convert str + char *query_layout_ptr = const_cast(query_layout_str.c_str()); + char *key_layout_ptr = const_cast(key_layout_str.c_str()); + EXEC_NPU_CMD( + aclnnLightningIndexer, + query, + key, + weights, + actual_seq_lengths_query, + actual_seq_lengths_key, + block_table, + query_layout_ptr, + key_layout_ptr, + sparse_count, + sparse_mode, + lightning_indexer_output); + return lightning_indexer_output; +} + +at::Tensor npu_sparse_flash_attention( + const at::Tensor &query, const at::Tensor &key, const at::Tensor &value, + const at::Tensor &sparse_indices, double scale_value, int64_t sparse_block_size, + const c10::optional &block_table, + const c10::optional &actual_seq_lengths_query, + const c10::optional &actual_seq_lengths_kv, + const c10::optional &query_rope, + const c10::optional &key_rope, c10::string_view layout_query, + c10::string_view layout_kv, int64_t sparse_mode) +{ + std::string layout_query_str = std::string(layout_query); + std::string layout_kv_str = std::string(layout_kv); + + for (size_t i = 0; i < query.sizes().size(); i++) { + TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " + "than 0, but shape[", i, "] is ", query.size(i)); + } + // construct the output tensor + at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype())); + // convert str + char *layout_query_ptr = const_cast(layout_query_str.c_str()); + char *layout_kv_ptr = const_cast(layout_kv_str.c_str()); + + EXEC_NPU_CMD( + aclnnSparseFlashAttention, + query, + key, + value, + sparse_indices, + block_table, + actual_seq_lengths_query, + actual_seq_lengths_kv, + query_rope, + key_rope, + scale_value, + sparse_block_size, + layout_query_ptr, + layout_kv_ptr, + sparse_mode, + output); + return output; +} + +void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c, + c10::optional format_mode, + c10::optional quant_mode) +{ + auto [tiling_tensor, block_dim] = bmm_trans::batch_matmul_transpose_tiling( + tensor_a, + tensor_b, + tensor_c, + format_mode, + quant_mode + ); + + void *gm_a = tensor_a.data_ptr(); + void *gm_b = tensor_b.data_ptr(); + void *gm_c = tensor_c.data_ptr(); + void *gm_tiling_data = tiling_tensor.data_ptr(); + + aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); + at_npu::native::OpCommand cmd; + cmd.Name("batch_matmul_transpose"); + + cmd.SetCustomHandler([stream, gm_a, gm_b, gm_c, gm_tiling_data, + block_dim]() -> int { + batch_matmul_transpose_impl(stream, gm_a, gm_b, gm_c, gm_tiling_data, + block_dim); + return 0; + }); + cmd.Run(); + return; + +} + } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) @@ -641,6 +770,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ); ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess); + //batch_matmul ops refer to sgl-kernel-npu + ops.def( + "batch_matmul_transpose(Tensor tensor_a, Tensor tensor_b, Tensor tensor_c, str? format_mode=None, str? quant_mode=None) -> ()"); + ops.impl("batch_matmul_transpose", torch::kPrivateUse1, &vllm_ascend::batch_matmul_transpose); + ops.def("swap_blocks(Tensor! x, Tensor! 
y, Tensor z) -> ()"); ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks); @@ -657,4 +791,22 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " (Tensor output, Tensor output_scale, Tensor output_offset)" ); ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant_weight_nz_tensor_list); + + ops.def( + "npu_lightning_indexer(Tensor query, Tensor key, Tensor weights, *," + " Tensor? actual_seq_lengths_query=None, Tensor? actual_seq_lengths_key=None," + " Tensor? block_table=None, str layout_query='BSND', str layout_key='BSND'," + " int sparse_count=2048, int sparse_mode=3) -> Tensor" + ); + ops.impl("npu_lightning_indexer", torch::kPrivateUse1, &vllm_ascend::npu_lightning_indexer); + + ops.def( + "npu_sparse_flash_attention(Tensor query, Tensor key, Tensor value," + " Tensor sparse_indices, float scale_value, int sparse_block_size, *," + " Tensor? block_table=None, Tensor? actual_seq_lengths_query=None," + " Tensor? actual_seq_lengths_kv=None, Tensor? query_rope=None," + " Tensor? key_rope=None, str layout_query='BSND', str layout_kv='BSND'," + " int sparse_mode=3) -> Tensor" + ); + ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index 26b3d66de03..149db15862d 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -151,6 +151,71 @@ std::tuple grouped_matmul_swiglu_quant_weigh return std::tuple(output, output_scale, output_offset); } +at::Tensor npu_lightning_indexer_meta( + const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights, + const c10::optional &actual_seq_lengths_query, + const c10::optional &actual_seq_lengths_key, + const c10::optional &block_table, c10::string_view layout_query, + c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode) +{ + // npu tensor max size + constexpr int32_t SIZE = 8; + constexpr int32_t DIM_0 = 0; + constexpr int32_t DIM_1 = 1; + constexpr int32_t DIM_2 = 2; + constexpr int32_t DIM_3 = 3; + + TORCH_CHECK(query.numel() > 0, "Query is empty."); + TORCH_CHECK(key.numel() > 0, "Key is empty."); + TORCH_CHECK(weights.numel() > 0, "Weights is empty."); + for (size_t i = 0; i < query.sizes().size(); i++) { + TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " + "than 0, but shape[", i, "] is ", query.size(i)); + } + TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count); + + std::string query_layout_str = std::string(layout_query); + std::string key_layout_str = std::string(layout_key); + at::SmallVector output_size; + if (query_layout_str == "BSND") { + output_size = {query.size(DIM_0), query.size(DIM_1), key.size(DIM_2), sparse_count}; + } else { + int n_dim_index = 0; + n_dim_index = (key_layout_str == "TND") ? 
DIM_1 : DIM_2; + output_size = {query.size(DIM_0), key.size(n_dim_index), sparse_count}; + } + // construct the output tensor + at::Tensor lightning_indexer_output = at::empty(output_size, query.options().dtype(at::kInt)); + return lightning_indexer_output; +} + +at::Tensor npu_sparse_flash_attention_meta( + const at::Tensor &query, const at::Tensor &key, const at::Tensor &value, + const at::Tensor &sparse_indices, double scale_value, int64_t sparse_block_size, + const c10::optional &block_table, + const c10::optional &actual_seq_lengths_query, + const c10::optional &actual_seq_lengths_kv, + const c10::optional &query_rope, + const c10::optional &key_rope, c10::string_view layout_query, + c10::string_view layout_kv, int64_t sparse_mode) +{ + std::string layout_query_str = std::string(layout_query); + for (size_t i = 0; i < query.sizes().size(); i++) { + TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater " + "than 0, but shape[", i, "] is ", query.size(i)); + } + at::Tensor output = at::empty(query.sizes(), query.options().dtype(query.dtype())); + return output; +} + +void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c, + c10::optional format_mode, + c10::optional quant_mode) +{ + return; + +} + } // namespace meta } // namespace vllm_ascend @@ -172,5 +237,11 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant); // Grouped matmul swiglu quant weight nz tensor list ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta); + // Lightning indexer + ops.impl("npu_lightning_indexer", &vllm_ascend::meta::npu_lightning_indexer_meta); + // Sparse flash attention + ops.impl("npu_sparse_flash_attention", &vllm_ascend::meta::npu_sparse_flash_attention_meta); + // batch_matmul_transpose + ops.impl("batch_matmul_transpose", &vllm_ascend::meta::batch_matmul_transpose); } } diff --git a/docs/source/tutorials/DeepSeek-V3.2-Exp.md b/docs/source/tutorials/DeepSeek-V3.2-Exp.md index f00f8b40a65..97e6d1fddde 100644 --- a/docs/source/tutorials/DeepSeek-V3.2-Exp.md +++ b/docs/source/tutorials/DeepSeek-V3.2-Exp.md @@ -39,7 +39,44 @@ Only AArch64 architecture are supported currently due to extra operator's instal ::::{tab-item} A3 series :sync: A3 -1. Start the docker image on your node, refer to [using docker](../installation.md#set-up-using-docker). +1. Start the docker image on your each node. 
+ +```{code-block} bash + :substitutions: + +export IMAGE=quay.io/ascend/vllm-ascend:|vllm_ascend_version|-a3 +docker run --rm \ + --name vllm-ascend \ + --shm-size=1g \ + --net=host \ + --device /dev/davinci0 \ + --device /dev/davinci1 \ + --device /dev/davinci2 \ + --device /dev/davinci3 \ + --device /dev/davinci4 \ + --device /dev/davinci5 \ + --device /dev/davinci6 \ + --device /dev/davinci7 \ + --device /dev/davinci8 \ + --device /dev/davinci9 \ + --device /dev/davinci10 \ + --device /dev/davinci11 \ + --device /dev/davinci12 \ + --device /dev/davinci13 \ + --device /dev/davinci14 \ + --device /dev/davinci15 \ + --device /dev/davinci_manager \ + --device /dev/devmm_svm \ + --device /dev/hisi_hdc \ + -v /usr/local/dcmi:/usr/local/dcmi \ + -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \ + -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ + -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ + -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ + -v /etc/ascend_install.info:/etc/ascend_install.info \ + -v /root/.cache:/root/.cache \ + -it $IMAGE bash +``` 2. Install the package `custom-ops` to make the kernels available. @@ -57,7 +94,36 @@ pip install custom_ops-1.0-cp311-cp311-linux_aarch64.whl ::::{tab-item} A2 series :sync: A2 -1. Start the docker image on your node, refer to [using docker](../installation.md#set-up-using-docker). +1. Start the docker image on your each node. + +```{code-block} bash + :substitutions: + +export IMAGE=quay.io/ascend/vllm-ascend:|vllm_ascend_version| +docker run --rm \ + --name vllm-ascend \ + --shm-size=1g \ + --net=host \ + --device /dev/davinci0 \ + --device /dev/davinci1 \ + --device /dev/davinci2 \ + --device /dev/davinci3 \ + --device /dev/davinci4 \ + --device /dev/davinci5 \ + --device /dev/davinci6 \ + --device /dev/davinci7 \ + --device /dev/davinci_manager \ + --device /dev/devmm_svm \ + --device /dev/hisi_hdc \ + -v /usr/local/dcmi:/usr/local/dcmi \ + -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \ + -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ + -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ + -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ + -v /etc/ascend_install.info:/etc/ascend_install.info \ + -v /root/.cache:/root/.cache \ + -it $IMAGE bash +``` 2. Install the package `custom-ops` to make the kernels available. diff --git a/docs/source/tutorials/Qwen2.5-Omni.md b/docs/source/tutorials/Qwen2.5-Omni.md new file mode 100644 index 00000000000..5ea5481dce3 --- /dev/null +++ b/docs/source/tutorials/Qwen2.5-Omni.md @@ -0,0 +1,206 @@ +# Qwen2.5-Omni-7B + +## Introduction + +Qwen2.5-Omni is an end-to-end multimodal model designed to perceive diverse modalities, including text, images, audio, and video, while simultaneously generating text and natural speech responses in a streaming manner. + +The `Qwen2.5-Omni` model was supported since `vllm-ascend:v0.11.0rc0`. This document will show the main verification steps of the model, including supported features, feature configuration, environment preparation, single-NPU and multi-NPU deployment, accuracy and performance evaluation. + +## Supported Features + +Refer to [supported features](../user_guide/support_matrix/supported_models.md) to get the model's supported feature matrix. + +Refer to [feature guide](../user_guide/feature_guide/index.md) to get the feature's configuration. 
+ +## Environment Preparation + +### Model Weight + +- `Qwen2.5-Omni-3B`(BF16): [Download model weight](https://huggingface.co/Qwen/Qwen2.5-Omni-3B) +- `Qwen2.5-Omni-7B`(BF16): [Download model weight](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) + +The following examples use the 7B version by default. + +### Installation + +You can use our official docker image to run `Qwen2.5-Omni` directly. + +Select an image based on your machine type and start the docker image on your node; refer to [using docker](../installation.md#set-up-using-docker). + +```{code-block} bash + :substitutions: +# Update --device according to your device (Atlas A2: /dev/davinci[0-7] Atlas A3:/dev/davinci[0-15]). +# Update the vllm-ascend image according to your environment. +# Note you should download the weight to /root/.cache in advance. +# Update the vllm-ascend image +export IMAGE=m.daocloud.io/quay.io/ascend/vllm-ascend:|vllm_ascend_version| +export NAME=vllm-ascend +# Run the container using the defined variables +# Note: If you are running bridge network with docker, please expose available ports for multi-node communication in advance +docker run --rm \ +--name $NAME \ +--net=host \ +--shm-size=1g \ +--device /dev/davinci0 \ +--device /dev/davinci1 \ +--device /dev/davinci2 \ +--device /dev/davinci3 \ +--device /dev/davinci4 \ +--device /dev/davinci5 \ +--device /dev/davinci6 \ +--device /dev/davinci7 \ +--device /dev/davinci_manager \ +--device /dev/devmm_svm \ +--device /dev/hisi_hdc \ +-v /usr/local/dcmi:/usr/local/dcmi \ +-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \ +-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ +-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ +-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ +-v /etc/ascend_install.info:/etc/ascend_install.info \ +-v /mnt/sfs_turbo/.cache:/root/.cache \ +-it $IMAGE bash +``` + +## Deployment + +### Single-node Deployment + +#### Single NPU (Qwen2.5-Omni-7B) + +```bash +export VLLM_USE_MODELSCOPE=true +export MODEL_PATH=vllm-ascend/Qwen2.5-Omni-7B +export LOCAL_MEDIA_PATH=/local_path/to_media/ + +vllm serve ${MODEL_PATH} \ +--host 0.0.0.0 \ +--port 8000 \ +--served-model-name Qwen-Omni \ +--allowed-local-media-path ${LOCAL_MEDIA_PATH} \ +--trust-remote-code \ +--compilation-config '{"full_cuda_graph": 1}' \ +--no-enable-prefix-caching +``` + +:::{note} +The vllm-ascend docker image should already contain the vllm `[audio]` extra; if you encounter an *audio not supported* issue by any chance, please re-build vllm with the `[audio]` flag. + +```bash +VLLM_TARGET_DEVICE=empty pip install -v ".[audio]" +``` + +::: + +`--allowed-local-media-path` is optional; only set it if you need to run inference with local media files. + +`--gpu-memory-utilization` should not be set manually unless you know what this parameter does.
+ +#### Multiple NPU (Qwen2.5-Omni-7B) + +```bash +export VLLM_USE_MODELSCOPE=true +export MODEL_PATH=vllm-ascend/Qwen2.5-Omni-7B +export LOCAL_MEDIA_PATH=/local_path/to_media/ +export DP_SIZE=8 + +vllm serve ${MODEL_PATH} \ +--host 0.0.0.0 \ +--port 8000 \ +--served-model-name Qwen-Omni \ +--allowed-local-media-path ${LOCAL_MEDIA_PATH} \ +--trust-remote-code \ +--compilation-config '{"full_cuda_graph": 1}' \ +--data-parallel-size ${DP_SIZE} \ +--no-enable-prefix-caching +``` + +`--tensor_parallel_size` does not need to be set for this 7B model, but if you really need tensor parallelism, the TP size can be one of `1/2/4`. + +### Prefill-Decode Disaggregation + +Not supported yet. + +## Functional Verification + +If your service starts successfully, you can see the info shown below: + +```bash +INFO: Started server process [2736] +INFO: Waiting for application startup. +INFO: Application startup complete. +``` + +Once your server is started, you can query the model with input prompts: + +```bash +curl http://127.0.0.1:8000/v1/chat/completions -H "Content-Type: application/json" -H "Authorization: Bearer EMPTY" -d '{ + "model": "Qwen-Omni", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is the text in the illustrate?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" + } + } + ] + } + ], + "max_tokens": 100, + "temperature": 0.7 + }' + +``` + +If you query the server successfully, you can see the info shown below (client): + +```bash +{"id":"chatcmpl-a70a719c12f7445c8204390a8d0d8c97","object":"chat.completion","created":1764056861,"model":"Qwen-Omni","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is \"TONGYI Qwen\".","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":73,"total_tokens":88,"completion_tokens":15,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null} +``` + +## Accuracy Evaluation + +Qwen2.5-Omni on vllm-ascend has been tested with AISBench. + +### Using AISBench + +1. Refer to [Using AISBench](../developer_guide/evaluation/using_ais_bench.md) for details. + +2. After execution, you can get the result; here is the result of `Qwen2.5-Omni-7B` with `vllm-ascend:0.11.0rc0`, for reference only. + +| dataset | platform | metric | mode | vllm-api-stream-chat | +|----- | ----- | ----- | ----- | -----| +| textVQA | A2 | accuracy | gen_base64 | 83.47 | +| textVQA | A3 | accuracy | gen_base64 | 84.04 | + +## Performance Evaluation + +### Using AISBench + +Refer to [Using AISBench for performance evaluation](../developer_guide/evaluation/using_ais_bench.md#execute-performance-evaluation) for details. + +### Using vLLM Benchmark + +Run the performance evaluation of `Qwen2.5-Omni-7B` as an example. + +Refer to [vllm benchmark](https://docs.vllm.ai/en/latest/contributing/benchmarks.html) for more details. + +There are three `vllm bench` subcommands: +- `latency`: Benchmark the latency of a single batch of requests. +- `serve`: Benchmark the online serving throughput. +- `throughput`: Benchmark offline inference throughput. + +Take the `serve` subcommand as an example and run the command as follows.
+ +```shell +vllm bench serve --model vllm-ascend/Qwen2.5-Omni-7B --dataset-name random --random-input 1024 --num-prompt 200 --request-rate 1 --save-result --result-dir ./ +``` + +After a few minutes, you can get the performance evaluation result. diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md index 334417d525d..71fa2815ddc 100644 --- a/docs/source/tutorials/index.md +++ b/docs/source/tutorials/index.md @@ -22,4 +22,5 @@ multi_node_kimi multi_node_qwen3vl multi_node_pd_disaggregation_mooncake multi_node_ray +Qwen2.5-Omni.md ::: diff --git a/docs/source/tutorials/multi_npu_qwen3_next.md b/docs/source/tutorials/multi_npu_qwen3_next.md index 325745ac3d4..eeb57f5a6fa 100644 --- a/docs/source/tutorials/multi_npu_qwen3_next.md +++ b/docs/source/tutorials/multi_npu_qwen3_next.md @@ -55,8 +55,8 @@ source /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh Install Triton Ascend: ```bash -wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27.whl -pip install triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27.whl +wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl +pip install triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl ``` :::: diff --git a/examples/offline_data_parallel.py b/examples/offline_data_parallel.py index b16d50ffd1d..b7193f583f2 100644 --- a/examples/offline_data_parallel.py +++ b/examples/offline_data_parallel.py @@ -111,6 +111,10 @@ def parse_args(): parser.add_argument("--enable-expert-parallel", action="store_true", help="Enable expert parallel, used in MOE models.") + parser.add_argument("--quantization", + type=str, + default="", + help="Use quantization models") return parser.parse_args() @@ -134,6 +138,7 @@ def main( enable_expert_parallel, enforce_eager, trust_remote_code, + quantization, ): # DP only support on V1 engine os.environ["VLLM_DP_RANK"] = str(global_dp_rank) @@ -185,6 +190,7 @@ def start(rank): enforce_eager=enforce_eager, enable_expert_parallel=enable_expert_parallel, trust_remote_code=trust_remote_code, + quantization=quantization, ) outputs = llm.generate(prompts, sampling_params) # Print the outputs.
@@ -220,6 +226,8 @@ def start(rank): assert dp_size % node_size == 0, "dp_size should be divisible by node_size" dp_per_node = dp_size // node_size + quantization = args.quantization if args.quantization else None + from multiprocessing import Process procs = [] @@ -238,6 +246,7 @@ def start(rank): args.enable_expert_parallel, args.enforce_eager, args.trust_remote_code, + quantization, ), ) proc.start() diff --git a/pyproject.toml b/pyproject.toml index a10ff9a834d..66a5dc24578 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,8 @@ requires = [ "setuptools>=64", "setuptools-scm>=8", "transformers<=4.57.1", - "torch-npu==2.7.1", - "torch==2.7.1", + "torch-npu==2.8.0", + "torch==2.8.0", "torchvision", "wheel", "msgpack", diff --git a/requirements-dev.txt b/requirements-dev.txt index d3db952d14f..44bfc3c5aa2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -19,4 +19,5 @@ librosa soundfile pytest_mock msserviceprofiler>=1.2.2 -mindstudio-probe>=8.3.0 \ No newline at end of file +mindstudio-probe>=8.3.0 +arctic-inference==0.1.1 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2a176f84727..7dcd69d5358 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ scipy pandas setuptools>=64 setuptools-scm>=8 -torch==2.7.1 +torch==2.8.0 torchvision wheel pandas-stubs @@ -28,6 +28,6 @@ numba # Install torch_npu #--pre #--extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi -torch-npu==2.7.1 +torch-npu==2.8.0 transformers<=4.57.1 diff --git a/setup.py b/setup.py index 1bf800813e4..3e88affaf79 100644 --- a/setup.py +++ b/setup.py @@ -137,6 +137,9 @@ def gen_build_info(): # TODO(zzzzwwjj): Add A5 case soc_to_device = { + "910b": "_910B", + "910c": "_910_93", + "310p": "_310P", "ascend910b1": "_910B", "ascend910b2": "_910B", "ascend910b2c": "_910B", @@ -307,7 +310,14 @@ def configure(self, ext: CMakeExtension) -> None: cmake_args += [f"-DCMAKE_PREFIX_PATH={pybind11_cmake_path}"] - cmake_args += [f"-DSOC_VERSION={envs.SOC_VERSION}"] + soc_version_map = { + "910b": "ascend910b1", + "910c": "ascend910_9392", + "310p": "ascend310p1", + } + CANN_SOC_VERSION = soc_version_map.get(envs.SOC_VERSION, + envs.SOC_VERSION) + cmake_args += [f"-DSOC_VERSION={CANN_SOC_VERSION}"] # Override the base directory for FetchContent downloads to $ROOT/.deps # This allows sharing dependencies between profiles, diff --git a/tests/e2e/310p/test_offline_inference_parallel_310p.py b/tests/e2e/310p/test_offline_inference_parallel_310p.py index 6bf335686d1..7ba7ef73763 100644 --- a/tests/e2e/310p/test_offline_inference_parallel_310p.py +++ b/tests/e2e/310p/test_offline_inference_parallel_310p.py @@ -29,9 +29,6 @@ "additional_config": { "torchair_graph_config": { "enabled": True - }, - "ascend_scheduler_config": { - "enabled": True, } } }] diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 4d2c8c5f8f0..5292673d46b 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -40,7 +40,7 @@ BatchEncoding, BatchFeature) from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm import LLM, SamplingParams -from vllm.config.model import TaskOption, _get_and_verify_dtype +from vllm.config.model import _get_and_verify_dtype from vllm.inputs import TextPrompt from vllm.outputs import RequestOutput from vllm.platforms import current_platform @@ -270,7 +270,7 @@ class VllmRunner: def __init__( self, model_name: str, - task: TaskOption = "auto", + runner: str = "auto", tokenizer_name: Optional[str] = None, tokenizer_mode: 
str = "auto", # Use smaller max model length, otherwise bigger model cannot run due @@ -280,7 +280,7 @@ def __init__( disable_log_stats: bool = True, tensor_parallel_size: int = 1, block_size: int = 16, - enable_chunked_prefill: bool = False, + enable_chunked_prefill: bool = True, swap_space: int = 4, enforce_eager: Optional[bool] = False, quantization: Optional[str] = None, @@ -288,7 +288,7 @@ def __init__( ) -> None: self.model = LLM( model=model_name, - task=task, + runner=runner, tokenizer=tokenizer_name, tokenizer_mode=tokenizer_mode, trust_remote_code=True, diff --git a/tests/e2e/models/configs/ERNIE-4.5-21B-A3B-PT.yaml b/tests/e2e/models/configs/ERNIE-4.5-21B-A3B-PT.yaml new file mode 100644 index 00000000000..ae39aab9988 --- /dev/null +++ b/tests/e2e/models/configs/ERNIE-4.5-21B-A3B-PT.yaml @@ -0,0 +1,9 @@ +model_name: "PaddlePaddle/ERNIE-4.5-21B-A3B-PT" +hardware: "Atlas A2 Series" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,flexible-extract" + value: 0.71 +num_fewshot: 5 +trust_remote_code: True diff --git a/tests/e2e/models/configs/InternVL3_5-8B.yaml b/tests/e2e/models/configs/InternVL3_5-8B-hf.yaml similarity index 100% rename from tests/e2e/models/configs/InternVL3_5-8B.yaml rename to tests/e2e/models/configs/InternVL3_5-8B-hf.yaml diff --git a/tests/e2e/models/configs/Molmo-7B-D-0924.yaml b/tests/e2e/models/configs/Molmo-7B-D-0924.yaml new file mode 100644 index 00000000000..68951a40a23 --- /dev/null +++ b/tests/e2e/models/configs/Molmo-7B-D-0924.yaml @@ -0,0 +1,13 @@ +model_name: "LLM-Research/Molmo-7B-D-0924" +hardware: "Atlas A2 Series" +model: "vllm-vlm" +tasks: +- name: "ceval-valid" + metrics: + - name: "acc,none" + value: 0.71 +max_model_len: 4096 +trust_remote_code: True +apply_chat_template: False +fewshot_as_multiturn: False +gpu_memory_utilization: 0.8 diff --git a/tests/e2e/models/configs/accuracy.txt b/tests/e2e/models/configs/accuracy.txt index daa23e97639..c15d7986476 100644 --- a/tests/e2e/models/configs/accuracy.txt +++ b/tests/e2e/models/configs/accuracy.txt @@ -9,4 +9,10 @@ Qwen3-VL-30B-A3B-Instruct.yaml Qwen3-VL-8B-Instruct.yaml Qwen2.5-Omni-7B.yaml Meta-Llama-3.1-8B-Instruct.yaml -InternVL3_5-8B.yaml \ No newline at end of file +InternVL3_5-8B.yaml +ERNIE-4.5-21B-A3B-PT.yaml +gemma-2-9b-it.yaml +gemma-3-4b-it.yaml +internlm-7b.yaml +Molmo-7B-D-0924.yaml +llava-1.5-7b-hf.yaml diff --git a/tests/e2e/models/configs/gemma-2-9b-it.yaml b/tests/e2e/models/configs/gemma-2-9b-it.yaml new file mode 100644 index 00000000000..050e2f03279 --- /dev/null +++ b/tests/e2e/models/configs/gemma-2-9b-it.yaml @@ -0,0 +1,11 @@ +model_name: "LLM-Research/gemma-2-9b-it" +hardware: "Atlas A2 Series" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.46 + - name: "exact_match,flexible-extract" + value: 0.79 +num_fewshot: 5 +gpu_memory_utilization: 0.8 diff --git a/tests/e2e/models/configs/gemma-3-4b-it.yaml b/tests/e2e/models/configs/gemma-3-4b-it.yaml new file mode 100644 index 00000000000..42366800db0 --- /dev/null +++ b/tests/e2e/models/configs/gemma-3-4b-it.yaml @@ -0,0 +1,13 @@ +model_name: "LLM-Research/gemma-3-4b-it" +hardware: "Atlas A2 Series" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.59 + - name: "exact_match,flexible-extract" + value: 0.59 +num_fewshot: 5 +apply_chat_template: False +fewshot_as_multiturn: False +gpu_memory_utilization: 0.7 diff --git a/tests/e2e/models/configs/internlm-7b.yaml b/tests/e2e/models/configs/internlm-7b.yaml new file mode 100644 index 
00000000000..ceccc53d1d4 --- /dev/null +++ b/tests/e2e/models/configs/internlm-7b.yaml @@ -0,0 +1,13 @@ +model_name: "Shanghai_AI_Laboratory/internlm-7b" +hardware: "Atlas A2 Series" +tasks: +- name: "ceval-valid" + metrics: + - name: "acc,none" + value: 0.42 +num_fewshot: 5 +max_model_len: 2048 +trust_remote_code: True +dtype: "bfloat16" +apply_chat_template: False +fewshot_as_multiturn: False diff --git a/tests/e2e/models/configs/llava-1.5-7b-hf.yaml b/tests/e2e/models/configs/llava-1.5-7b-hf.yaml new file mode 100644 index 00000000000..7bd69de99f7 --- /dev/null +++ b/tests/e2e/models/configs/llava-1.5-7b-hf.yaml @@ -0,0 +1,11 @@ +model_name: "llava-hf/llava-1.5-7b-hf" +hardware: "Atlas A2 Series" +model: "vllm-vlm" +tasks: +- name: "ceval-valid" + metrics: + - name: "acc,none" + value: 0.30 +trust_remote_code: True +gpu_memory_utilization: 0.8 +dtype: "bfloat16" diff --git a/tests/e2e/models/test_lm_eval_correctness.py b/tests/e2e/models/test_lm_eval_correctness.py index a0862b8025d..3d0ce6be5e7 100644 --- a/tests/e2e/models/test_lm_eval_correctness.py +++ b/tests/e2e/models/test_lm_eval_correctness.py @@ -39,10 +39,11 @@ def env_config() -> EnvConfig: def build_model_args(eval_config, tp_size): trust_remote_code = eval_config.get("trust_remote_code", False) max_model_len = eval_config.get("max_model_len", 4096) + dtype = eval_config.get("dtype", "auto") model_args = { "pretrained": eval_config["model_name"], "tensor_parallel_size": tp_size, - "dtype": "auto", + "dtype": dtype, "trust_remote_code": trust_remote_code, "max_model_len": max_model_len, } diff --git a/tests/e2e/multicard/test_data_parallel.py b/tests/e2e/multicard/test_data_parallel.py index 2e8ba386fce..94c95887149 100644 --- a/tests/e2e/multicard/test_data_parallel.py +++ b/tests/e2e/multicard/test_data_parallel.py @@ -27,13 +27,17 @@ import pytest -MODELS = ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-30B-A3B"] +MODELS = [ + "Qwen/Qwen3-0.6B", "Qwen/Qwen3-30B-A3B", "vllm-ascend/Qwen3-30B-A3B-W8A8" +] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [32]) @patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"}) def test_data_parallel_inference(model, max_tokens): + moe_models = ["Qwen/Qwen3-30B-A3B", "vllm-ascend/Qwen3-30B-A3B-W8A8"] + quantization_models = ["vllm-ascend/Qwen3-30B-A3B-W8A8"] script = "examples/offline_data_parallel.py" env = os.environ.copy() @@ -54,8 +58,11 @@ def test_data_parallel_inference(model, max_tokens): "--trust-remote-code", ] - if model == "Qwen/Qwen3-30B-A3B": + if model in moe_models: cmd.append("--enable-expert-parallel") + if model in quantization_models: + cmd.append("--quantization") + cmd.append("ascend") print(f"Running subprocess: {' '.join(cmd)}") proc = subprocess.run(cmd, @@ -63,7 +70,7 @@ def test_data_parallel_inference(model, max_tokens): stdout=subprocess.PIPE, stderr=subprocess.STDOUT, timeout=600) - output = proc.stdout.decode() + output = proc.stdout.decode(errors='ignore') print(output) diff --git a/tests/e2e/multicard/test_data_parallel_tp2.py b/tests/e2e/multicard/test_data_parallel_tp2.py new file mode 100644 index 00000000000..6b0bdabe8dd --- /dev/null +++ b/tests/e2e/multicard/test_data_parallel_tp2.py @@ -0,0 +1,52 @@ +""" +Run `pytest tests/e2e/multicard/test_data_parallel_tp2.py`. 
+""" + +import os +import subprocess +import sys +from unittest.mock import patch + +import pytest + +MODELS = ["Qwen/Qwen3-0.6B"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [32]) +@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3"}) +def test_data_parallel_inference(model, max_tokens): + script = "examples/offline_data_parallel.py" + + env = os.environ.copy() + + cmd = [ + sys.executable, + script, + "--model", + model, + "--dp-size", + "2", + "--tp-size", + "2", + "--node-size", + "1", + "--node-rank", + "0", + "--trust-remote-code", + ] + + print(f"Running subprocess: {' '.join(cmd)}") + proc = subprocess.run(cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=600) + output = proc.stdout.decode(errors='ignore') + + print(output) + + assert "DP rank 0 needs to process" in output + assert "DP rank 1 needs to process" in output + assert "Generated text:" in output + assert proc.returncode == 0 diff --git a/tests/e2e/multicard/test_expert_parallel.py b/tests/e2e/multicard/test_expert_parallel.py index f1076013967..b8f03d5f905 100644 --- a/tests/e2e/multicard/test_expert_parallel.py +++ b/tests/e2e/multicard/test_expert_parallel.py @@ -15,23 +15,14 @@ def test_e2e_ep_correctness(model_name): max_tokens = 5 # FIXME: Really strange that chunked prefill might lead to different results, investigate further - with VllmRunner( - model_name, - tensor_parallel_size=2, - additional_config={"ascend_scheduler_config": { - "enabled": True - }}, - enforce_eager=False) as vllm_model: + with VllmRunner(model_name, tensor_parallel_size=2, + enforce_eager=False) as vllm_model: tp_output = vllm_model.generate_greedy(example_prompts, max_tokens) - with VllmRunner( - model_name, - tensor_parallel_size=2, - enable_expert_parallel=True, - additional_config={"ascend_scheduler_config": { - "enabled": True - }}, - enforce_eager=False) as vllm_model: + with VllmRunner(model_name, + tensor_parallel_size=2, + enable_expert_parallel=True, + enforce_eager=False) as vllm_model: ep_output = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( diff --git a/tests/e2e/multicard/test_external_launcher.py b/tests/e2e/multicard/test_external_launcher.py index 05851db1d69..ece35def697 100644 --- a/tests/e2e/multicard/test_external_launcher.py +++ b/tests/e2e/multicard/test_external_launcher.py @@ -67,7 +67,7 @@ def test_external_launcher(model): stderr=subprocess.STDOUT, timeout=600, ) - output = proc.stdout.decode() + output = proc.stdout.decode(errors='ignore') print(output) @@ -99,7 +99,7 @@ def test_moe_external_launcher(model): stderr=subprocess.STDOUT, timeout=600, ) - output = proc.stdout.decode() + output = proc.stdout.decode(errors='ignore') print(output) @@ -144,7 +144,7 @@ def test_external_launcher_and_sleepmode(): stderr=subprocess.STDOUT, timeout=300, ) - output = proc.stdout.decode() + output = proc.stdout.decode(errors='ignore') print(output) @@ -192,7 +192,7 @@ def test_external_launcher_and_sleepmode_level2(): stderr=subprocess.STDOUT, timeout=300, ) - output = proc.stdout.decode() + output = proc.stdout.decode(errors='ignore') print(output) @@ -232,7 +232,7 @@ def test_mm_allreduce(model): timeout=600, ) - output = proc.stdout.decode() + output = proc.stdout.decode(errors='ignore') print(output) assert "Generated text:" in output diff --git a/tests/e2e/multicard/test_fused_moe_allgather_ep.py b/tests/e2e/multicard/test_fused_moe_allgather_ep.py index 9335e19af69..85d246e56ba 100644 --- 
a/tests/e2e/multicard/test_fused_moe_allgather_ep.py +++ b/tests/e2e/multicard/test_fused_moe_allgather_ep.py @@ -49,13 +49,7 @@ def test_generate_with_allgather(): tensor_parallel_size=2, max_model_len=1024, dtype="auto", - enable_expert_parallel=True, - additional_config={ - "ascend_scheduler_config": { - "enabled": True, - "chunked_prefill_enabled": False, - }, - }) as vllm_model: + enable_expert_parallel=True) as vllm_model: vllm_model.generate(example_prompts, sampling_params) @@ -76,11 +70,5 @@ def test_generate_with_alltoall(): tensor_parallel_size=2, max_model_len=1024, dtype="auto", - enable_expert_parallel=True, - additional_config={ - "ascend_scheduler_config": { - "enabled": True, - "chunked_prefill_enabled": False, - }, - }) as vllm_model: + enable_expert_parallel=True) as vllm_model: vllm_model.generate(example_prompts, sampling_params) diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 320c3bdf0b9..1380c49e3d2 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -82,9 +82,6 @@ def test_models_distributed_DeepSeek_multistream_moe(): "enabled": True, }, "enable_multistream_moe": True, - "ascend_scheduler_config": { - "enabled": True, - }, "refresh": True, }, ) as vllm_model: @@ -154,14 +151,9 @@ def test_models_distributed_DeepSeek_W4A8DYNAMIC(model): quantization="ascend", enforce_eager=True, enable_expert_parallel=True, - additional_config={ - "torchair_graph_config": { - "enabled": False, - }, - "ascend_scheduler_config": { - "enabled": True, - } - }, + additional_config={"torchair_graph_config": { + "enabled": False, + }}, ) as vllm_model: vllm_model.generate_greedy(prompts, max_tokens) diff --git a/tests/e2e/multicard/test_prefix_caching.py b/tests/e2e/multicard/test_prefix_caching.py index e29916623ba..114d5d72195 100644 --- a/tests/e2e/multicard/test_prefix_caching.py +++ b/tests/e2e/multicard/test_prefix_caching.py @@ -58,7 +58,6 @@ ] -@pytest.mark.skip(reason="Fix me, the accuracy is not correct") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [50]) def test_prefix_cache_with_v1_scheduler(model: str, max_tokens: int) -> None: @@ -84,67 +83,3 @@ def test_prefix_cache_with_v1_scheduler(model: str, max_tokens: int) -> None: name_0="vllm_output", name_1="prefix_cache_output", ) - - -@pytest.mark.skip(reason="Fix me, the accuracy is not correct") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("max_tokens", [50]) -def test_prefix_cache_with_ascend_scheduler(model: str, - max_tokens: int) -> None: - - with VllmRunner(model, - additional_config={ - 'ascend_scheduler_config': { - 'enabled': True, - }, - }, - enforce_eager=False, - max_model_len=2048, - tensor_parallel_size=2, - gpu_memory_utilization=0.7) as vllm_model: - vllm_output = vllm_model.generate_greedy(INPUT_PROMPTS, max_tokens) - - with VllmRunner(model, - additional_config={ - 'ascend_scheduler_config': { - 'enabled': True, - 'enable_prefix_caching': True, - }, - }, - enforce_eager=False, - max_model_len=2048, - tensor_parallel_size=2, - gpu_memory_utilization=0.7) as vllm_model: - prefix_cache_output = vllm_model.generate_greedy( - INPUT_PROMPTS, max_tokens) - - # TODO: enable apc and chunked prefill with ascend scheduler will lead accuracy problem. - # Disable it now. Fix it or drop the ascend scheduler in the future. 
- # with VllmRunner(model, - # additional_config={ - # 'ascend_scheduler_config': { - # 'enabled': True, - # 'enable_prefix_caching': True, - # "enable_chunked_prefill": True, - # }, - # }, - # enforce_eager=True, - # max_model_len=2048, - # tensor_parallel_size=2, - # gpu_memory_utilization=0.7) as vllm_model: - # chunk_prefill_prefix_cache_output = vllm_model.generate_greedy( - # INPUT_PROMPTS, max_tokens) - - check_outputs_equal( - outputs_0_lst=vllm_output, - outputs_1_lst=prefix_cache_output, - name_0="vllm_output", - name_1="prefix_cache_output", - ) - - # check_outputs_equal( - # outputs_0_lst=chunk_prefill_prefix_cache_output, - # outputs_1_lst=prefix_cache_output, - # name_0="chunk_prefill_prefix_cache_output", - # name_1="prefix_cache_output", - # ) diff --git a/tests/e2e/multicard/test_qwen3_next.py b/tests/e2e/multicard/test_qwen3_next.py index e51748ea1e2..eaacd838ccd 100644 --- a/tests/e2e/multicard/test_qwen3_next.py +++ b/tests/e2e/multicard/test_qwen3_next.py @@ -89,12 +89,6 @@ def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY(): gpu_memory_utilization=0.8, distributed_executor_backend="mp", enforce_eager=True, - additional_config={ - "ascend_scheduler_config": { - "enabled": True, - "enable_chunked_prefill": False - } - }, speculative_config={ "method": "qwen3_next_mtp", "num_speculative_tokens": 1 diff --git a/tests/e2e/multicard/test_shared_expert_dp.py b/tests/e2e/multicard/test_shared_expert_dp.py new file mode 100644 index 00000000000..867d3ab6eaa --- /dev/null +++ b/tests/e2e/multicard/test_shared_expert_dp.py @@ -0,0 +1,93 @@ +import os + +import pytest +from vllm import SamplingParams + +from tests.e2e.conftest import VllmRunner +from tests.e2e.model_utils import check_outputs_equal + +MODELS = [ + "vllm-ascend/DeepSeek-V2-Lite", +] +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +@pytest.mark.parametrize("model", MODELS) +def test_models_with_enable_shared_expert_dp(model: str) -> None: + + if 'HCCL_OP_EXPANSION_MODE' in os.environ: + del os.environ['HCCL_OP_EXPANSION_MODE'] + + prompts = [ + "Hello, my name is", "The capital of the United States is", + "The capital of France is", "The future of AI is" + ] + sampling_params = SamplingParams(max_tokens=32, temperature=0.0) + + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=True, + tensor_parallel_size=2, + enable_expert_parallel=True, + ) as runner: + vllm_eager_outputs = runner.model.generate(prompts, sampling_params) + + os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1" + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=True, + tensor_parallel_size=2, + enable_expert_parallel=True, + additional_config={ + "enable_shared_expert_dp": True, + }, + ) as runner: + shared_expert_dp_eager_outputs = runner.model.generate( + prompts, sampling_params) + + with VllmRunner( + model, + max_model_len=1024, + tensor_parallel_size=2, + enforce_eager=False, + compilation_config={ + "cudagraph_capture_sizes": [1, 4, 8, 16], + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + additional_config={ + "enable_shared_expert_dp": True, + }, + ) as runner: + shared_expert_dp_aclgraph_outputs = runner.model.generate( + prompts, sampling_params) + + vllm_eager_outputs_list = [] + for output in vllm_eager_outputs: + vllm_eager_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + shared_expert_dp_eager_outputs_list = [] + for output in shared_expert_dp_eager_outputs: + shared_expert_dp_eager_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + 
shared_expert_dp_aclgraph_outputs_list = [] + for output in shared_expert_dp_aclgraph_outputs: + shared_expert_dp_aclgraph_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) + + check_outputs_equal( + outputs_0_lst=vllm_eager_outputs_list, + outputs_1_lst=shared_expert_dp_eager_outputs_list, + name_0="vllm_eager_outputs", + name_1="shared_expert_dp_eager_outputs", + ) + + check_outputs_equal( + outputs_0_lst=vllm_eager_outputs_list, + outputs_1_lst=shared_expert_dp_aclgraph_outputs_list, + name_0="vllm_eager_outputs", + name_1="shared_expert_dp_aclgraph_outputs", + ) diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py index a6f3f16d860..3472051e870 100644 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ b/tests/e2e/multicard/test_torchair_graph_mode.py @@ -44,9 +44,6 @@ def _deepseek_torchair_test_fixture( kwargs = {} if not use_v1_schduler: kwargs = { - "ascend_scheduler_config": { - "enabled": True, - }, "refresh": True, } additional_config.update(**kwargs) @@ -97,6 +94,7 @@ def test_e2e_deepseekv3_with_torchair_ms_mla(): _deepseek_torchair_test_fixture(additional_config) +@pytest.mark.skip("accuracy test failed. Fix me") def test_e2e_deepseekv3_with_torchair_v1scheduler(): additional_config = { "torchair_graph_config": { @@ -120,9 +118,6 @@ def _pangu_torchair_test_fixture( # torchair is only work without chunked-prefill now kwargs = { - "ascend_scheduler_config": { - "enabled": True, - }, "refresh": True, } additional_config.update(**kwargs) @@ -185,9 +180,6 @@ def _qwen_torchair_test_fixture( "torchair_graph_config": { "enabled": False, }, - "ascend_scheduler_config": { - "enabled": True, - }, "refresh": True, } @@ -244,9 +236,6 @@ def _deepseek_v2_lite_torchair_test_fixure( kwargs = {} if not use_v1_schduler: kwargs = { - "ascend_scheduler_config": { - "enable": True, - }, "refresh": True, } additional_config.update(**kwargs) diff --git a/tests/e2e/multicard/test_weight_loader.py b/tests/e2e/multicard/test_weight_loader.py index 2150a440751..6bb616dfc3f 100644 --- a/tests/e2e/multicard/test_weight_loader.py +++ b/tests/e2e/multicard/test_weight_loader.py @@ -61,7 +61,7 @@ def test_external_launcher(model): stderr=subprocess.STDOUT, timeout=600, ) - output = proc.stdout.decode() + output = proc.stdout.decode(errors='ignore') print(output) @@ -99,7 +99,7 @@ def test_external_launcher_dense(model): stderr=subprocess.STDOUT, timeout=600, ) - output = proc.stdout.decode() + output = proc.stdout.decode(errors='ignore') print(output) diff --git a/tests/e2e/nightly/features/test_mtpx_deepseek_r1_0528_w8a8.py b/tests/e2e/nightly/features/test_mtpx_deepseek_r1_0528_w8a8.py index 65d01b21240..880b44ae171 100644 --- a/tests/e2e/nightly/features/test_mtpx_deepseek_r1_0528_w8a8.py +++ b/tests/e2e/nightly/features/test_mtpx_deepseek_r1_0528_w8a8.py @@ -73,11 +73,7 @@ async def test_models(model: str, mode: str) -> None: "VLLM_RPC_TIMEOUT": "3600000", "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS": "3600000" } - additional_config: dict[str, Any] = { - "ascend_scheduler_config": { - "enabled": False - }, - } + additional_config: dict[str, Any] = {} speculative_config = { "num_speculative_tokens": 2, "method": "deepseek_mtp" diff --git a/tests/e2e/nightly/features/test_prefix_cache_deepseek_r1_0528_w8a8.py b/tests/e2e/nightly/features/test_prefix_cache_deepseek_r1_0528_w8a8.py index 8ac1883d1c1..80157588e71 100644 --- a/tests/e2e/nightly/features/test_prefix_cache_deepseek_r1_0528_w8a8.py +++ 
b/tests/e2e/nightly/features/test_prefix_cache_deepseek_r1_0528_w8a8.py @@ -74,9 +74,6 @@ async def test_models(model: str) -> None: "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True", } additional_config = { - "ascend_scheduler_config": { - "enabled": False - }, "torchair_graph_config": { "enabled": True, "enable_multistream_moe": False, diff --git a/tests/e2e/nightly/features/test_prefix_cache_qwen3_32b_int8.py b/tests/e2e/nightly/features/test_prefix_cache_qwen3_32b_int8.py index 3ee23287c6a..fdf7167b8ff 100644 --- a/tests/e2e/nightly/features/test_prefix_cache_qwen3_32b_int8.py +++ b/tests/e2e/nightly/features/test_prefix_cache_qwen3_32b_int8.py @@ -68,12 +68,7 @@ async def test_models(model: str) -> None: port = get_open_port() env_dict = {"TASK_QUEUE_ENABLE": "1", "HCCL_OP_EXPANSION_MODE": "AIV"} - additional_config = { - "ascend_scheduler_config": { - "enabled": False - }, - "enable_weight_nz_layout": True - } + additional_config = {"enable_weight_nz_layout": True} server_args = [ "--quantization", "ascend", "--reasoning-parser", "qwen3", "--tensor-parallel-size", "4", "--port", diff --git a/tests/e2e/nightly/features/test_qwen3_32b_int8_a3_feature_stack3.py b/tests/e2e/nightly/features/test_qwen3_32b_int8_a3_feature_stack3.py index 17a7f4b6e0b..9fa2d1e54d2 100644 --- a/tests/e2e/nightly/features/test_qwen3_32b_int8_a3_feature_stack3.py +++ b/tests/e2e/nightly/features/test_qwen3_32b_int8_a3_feature_stack3.py @@ -83,8 +83,7 @@ async def test_models(model: str, tp_size: int) -> None: "0.9", "--block-size", "128", "--max-num-seqs", "256", "--enforce-eager", "--max-model-len", "35840", "--max-num-batched-tokens", "35840", "--additional-config", - '{"ascend_scheduler_config":{"enabled":true},"enable_weight_nz_layout":true}', - "--compilation-config", + '{"enable_weight_nz_layout":true}', "--compilation-config", '{"cudagraph_mode":"FULL_DECODE_ONLY", "cudagraph_capture_sizes":[1,8,24,48,60]}' ] with RemoteOpenAIServer(model, diff --git a/tests/e2e/nightly/models/test_deepseek_r1_0528_w8a8.py b/tests/e2e/nightly/models/test_deepseek_r1_0528_w8a8.py index c912657784a..35082edb4be 100644 --- a/tests/e2e/nightly/models/test_deepseek_r1_0528_w8a8.py +++ b/tests/e2e/nightly/models/test_deepseek_r1_0528_w8a8.py @@ -33,7 +33,6 @@ "single", "aclgraph", "aclgraph_mlapo", - "no_chunkprefill", ] prompts = [ @@ -82,9 +81,6 @@ async def test_models(model: str, mode: str) -> None: "method": "deepseek_mtp" } additional_config = { - "ascend_scheduler_config": { - "enabled": False - }, "torchair_graph_config": { "enabled": True, "enable_multistream_moe": False, @@ -112,10 +108,6 @@ async def test_models(model: str, mode: str) -> None: if mode == "aclgraph_mlapo": env_dict["VLLM_ASCEND_ENABLE_MLAPO"] = "1" additional_config["torchair_graph_config"] = {"enabled": False} - if mode == "no_chunkprefill": - additional_config["ascend_scheduler_config"] = {"enabled": True} - i = server_args.index("--max-num-batched-tokens") + 1 - server_args[i] = "36864" server_args.extend(["--additional-config", json.dumps(additional_config)]) request_keyword_args: dict[str, Any] = { **api_keyword_args, @@ -134,7 +126,7 @@ async def test_models(model: str, mode: str) -> None: choices: list[openai.types.CompletionChoice] = batch.choices assert choices[0].text, "empty response" print(choices) - if mode in ["single", "no_chunkprefill"]: + if mode in ["single"]: return # aisbench test run_aisbench_cases(model, diff --git a/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py 
b/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py index bca2baf0dfd..6413aba0fcb 100644 --- a/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py +++ b/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py @@ -71,9 +71,6 @@ async def test_models(model: str) -> None: "cudagraph_mode": "FULL_DECODE_ONLY" } additional_config: dict[str, Any] = { - "ascend_scheduler_config": { - "enabled": False - }, "torchair_graph_config": { "enabled": True }, diff --git a/tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py b/tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py index 217b27866d9..9d5b78f051d 100644 --- a/tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py +++ b/tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py @@ -92,8 +92,7 @@ async def test_models(model: str, tp_size: int, dp_size: int, "--gpu-memory-utilization", "0.9", "--additional-config", - '{"ascend_scheduler_config":{"enabled":true},' - '"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}', + '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}', ] if full_graph: server_args += [ diff --git a/tests/e2e/nightly/models/test_glm4_5.py b/tests/e2e/nightly/models/test_glm4_5.py new file mode 100644 index 00000000000..aeb71f6802d --- /dev/null +++ b/tests/e2e/nightly/models/test_glm4_5.py @@ -0,0 +1,111 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# +from typing import Any + +import openai +import pytest +from vllm.utils import get_open_port + +from tests.e2e.conftest import RemoteOpenAIServer +from tools.aisbench import run_aisbench_cases + +MODELS = [ + "ZhipuAI/GLM-4.5", +] + +TENSOR_PARALLELS = [8] +DATA_PARALLELS = [2] + +prompts = [ + "San Francisco is a", +] + +api_keyword_args = { + "max_tokens": 10, +} + +aisbench_cases = [{ + "case_type": "accuracy", + "dataset_path": "vllm-ascend/gsm8k-lite", + "request_conf": "vllm_api_general_chat", + "dataset_conf": "gsm8k/gsm8k_gen_0_shot_cot_chat_prompt", + "max_out_len": 4096, + "batch_size": 8, + "baseline": 95, + "threshold": 5 +}, { + "case_type": "performance", + "dataset_path": "vllm-ascend/GSM8K-in3500-bs400", + "request_conf": "vllm_api_stream_chat", + "dataset_conf": "gsm8k/gsm8k_gen_0_shot_cot_str_perf", + "num_prompts": 16, + "max_out_len": 1500, + "batch_size": 8, + "request_rate": 0, + "baseline": 1, + "threshold": 0.97 +}] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS) +@pytest.mark.parametrize("dp_size", DATA_PARALLELS) +async def test_models( + model: str, + tp_size: int, + dp_size: int, +) -> None: + port = get_open_port() + env_dict = {"HCCL_BUFFSIZE": "1024"} + server_args = [ + "--no-enable-prefix-caching", + "--enable-expert-parallel", + "--tensor-parallel-size", + str(tp_size), + "--data-parallel-size", + str(dp_size), + "--port", + str(port), + "--max-model-len", + "8192", + "--max-num-batched-tokens", + "8192", + "--block-size", + "16", + "--trust-remote-code", + "--gpu-memory-utilization", + "0.9", + ] + request_keyword_args: dict[str, Any] = { + **api_keyword_args, + } + with RemoteOpenAIServer(model, + server_args, + server_port=port, + env_dict=env_dict, + auto_port=False) as server: + client = server.get_async_client() + batch = await client.completions.create( + model=model, + prompt=prompts, + **request_keyword_args, + ) + choices: list[openai.types.CompletionChoice] = batch.choices + assert choices[0].text, "empty response" + # aisbench test + run_aisbench_cases(model, port, aisbench_cases) diff --git a/tests/e2e/nightly/models/test_qwen2_5_vl_32b.py b/tests/e2e/nightly/models/test_qwen2_5_vl_32b.py index fe6bbedf2eb..77c1a7e1d73 100644 --- a/tests/e2e/nightly/models/test_qwen2_5_vl_32b.py +++ b/tests/e2e/nightly/models/test_qwen2_5_vl_32b.py @@ -85,9 +85,8 @@ async def test_models(model: str, tp_size: int) -> None: str(tp_size), "--port", str(port), "--max-model-len", "30000", "--max-num-batched-tokens", "40000", "--max-num-seqs", "400", "--trust-remote-code", - "--gpu-memory-utilization", "0.8", "--additional-config", - '{"ascend_scheduler_config":{"enabled":false}}', - "--compilation_config", '{"cudagraph_mode": "FULL_DECODE_ONLY"}' + "--gpu-memory-utilization", "0.8", "--compilation_config", + '{"cudagraph_mode": "FULL_DECODE_ONLY"}' ] request_keyword_args: dict[str, Any] = { **api_keyword_args, diff --git a/tests/e2e/nightly/models/test_qwen3_235b_a22b_w8a8_eplb.py b/tests/e2e/nightly/models/test_qwen3_235b_a22b_w8a8_eplb.py index 945d7cae3b1..efbf77d20f8 100644 --- a/tests/e2e/nightly/models/test_qwen3_235b_a22b_w8a8_eplb.py +++ b/tests/e2e/nightly/models/test_qwen3_235b_a22b_w8a8_eplb.py @@ -60,11 +60,7 @@ async def test_models(model: str) -> None: "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True", "VLLM_ASCEND_ENABLE_FLASHCOMM1": "1" } - additional_config: dict[str, Any] = { - "ascend_scheduler_config": { - "enabled": False - }, - } + additional_config: dict[str, 
Any] = {} compilation_config = {"cudagraph_mode": "FULL_DECODE_ONLY"} server_args = [ "--quantization", "ascend", "--async-scheduling", diff --git a/tests/e2e/nightly/models/test_qwen3_235b_w8a8.py b/tests/e2e/nightly/models/test_qwen3_235b_w8a8.py index 8220e4d59af..055a452e5b2 100644 --- a/tests/e2e/nightly/models/test_qwen3_235b_w8a8.py +++ b/tests/e2e/nightly/models/test_qwen3_235b_w8a8.py @@ -63,11 +63,6 @@ async def test_models(model: str, mode: str) -> None: "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True", "VLLM_ASCEND_ENABLE_FLASHCOMM1": "1" } - additional_config: dict[str, Any] = { - "ascend_scheduler_config": { - "enabled": False - }, - } compilation_config = {"cudagraph_mode": "FULL_DECODE_ONLY"} server_args = [ "--quantization", "ascend", "--async-scheduling", @@ -82,7 +77,6 @@ async def test_models(model: str, mode: str) -> None: server_args.extend( ["--compilation-config", json.dumps(compilation_config)]) - server_args.extend(["--additional-config", json.dumps(additional_config)]) request_keyword_args: dict[str, Any] = { **api_keyword_args, } diff --git a/tests/e2e/nightly/models/test_qwq_32b.py b/tests/e2e/nightly/models/test_qwq_32b.py index a60eff224b1..824651ba6c6 100644 --- a/tests/e2e/nightly/models/test_qwq_32b.py +++ b/tests/e2e/nightly/models/test_qwq_32b.py @@ -93,8 +93,6 @@ async def test_models(model: str, mode: str, tp_size: int) -> None: server_args.remove( '{"cudagraph_mode":"FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 8, 24, 48, 60]}' ) - server_args.append("--additional-config") - server_args.append('{"ascend_scheduler_config":{"enabled":true}}') server_args.append("--enforce-eager") request_keyword_args: dict[str, Any] = { **api_keyword_args, diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2-torchair.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2-torchair.yaml index 42b70f76456..7bfe3f5e99c 100644 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2-torchair.yaml +++ b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2-torchair.yaml @@ -30,7 +30,7 @@ deployment: --quantization ascend --gpu-memory-utilization 0.9 --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' - --additional-config '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":true,"enable_multistream_moe":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}' + --additional-config '{"torchair_graph_config":{"enabled":true,"enable_multistream_moe":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}' - server_cmd: > @@ -51,7 +51,7 @@ deployment: --quantization ascend --gpu-memory-utilization 0.9 --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' - --additional-config '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":true,"enable_multistream_moe":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}' + --additional-config '{"torchair_graph_config":{"enabled":true,"enable_multistream_moe":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}' benchmarks: acc: case_type: accuracy diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2.yaml index cf44bc8f5e6..01100f29481 100644 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2.yaml +++ b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2.yaml @@ -31,7 +31,7 @@ deployment: 
--gpu-memory-utilization 0.9 --enforce-eager --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' - --additional-config '{"ascend_scheduler_config":{"enabled":false},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}' + --additional-config '{"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}' - server_cmd: > @@ -53,5 +53,5 @@ deployment: --gpu-memory-utilization 0.9 --enforce-eager --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' - --additional-config '{"ascend_scheduler_config":{"enabled":false},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}' + --additional-config '{"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}' benchmarks: diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-EPLB.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-EPLB.yaml index 9a4c3d94407..6ca189c4298 100644 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-EPLB.yaml +++ b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-EPLB.yaml @@ -50,7 +50,7 @@ deployment: "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' --additional-config - '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' + '{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' - server_cmd: > @@ -80,7 +80,7 @@ deployment: "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' --additional-config - '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' + '{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' - server_cmd: > vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 @@ -111,7 +111,7 @@ deployment: "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' --additional-config - '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' + '{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' - server_cmd: > vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 @@ -141,7 +141,7 @@ deployment: "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' --additional-config - 
'{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' + '{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' benchmarks: perf: case_type: performance diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8.yaml index a8e49290bd8..37a024b989a 100644 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8.yaml +++ b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8.yaml @@ -49,7 +49,7 @@ deployment: "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' --additional-config - '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}' + '{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}' - server_cmd: > @@ -79,7 +79,7 @@ deployment: "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' --additional-config - '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}' + '{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}' - server_cmd: > vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 @@ -110,7 +110,7 @@ deployment: "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' --additional-config - '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true}' + '{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true}' - server_cmd: > vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 @@ -140,7 +140,7 @@ deployment: "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" }' --additional-config - '{"ascend_scheduler_config":{"enabled":false},"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true}' + '{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true}' benchmarks: perf: case_type: performance diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3_2-Exp-bf16.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3_2-Exp-bf16.yaml index 6dafd3ccd31..93e76ca5a2e 100644 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3_2-Exp-bf16.yaml +++ 
b/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3_2-Exp-bf16.yaml @@ -13,7 +13,7 @@ env_common: deployment: - server_cmd: > - vllm serve Yanguan/DeepSeek-V3.2-Exp-bf16 \ + vllm serve "Yanguan/DeepSeek-V3.2-Exp-bf16" --host 0.0.0.0 --port $SERVER_PORT --data-parallel-address $LOCAL_IP @@ -29,11 +29,11 @@ deployment: --trust-remote-code --no-enable-prefix-caching --gpu-memory-utilization 0.9 - --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' + --additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' - server_cmd: > - vllm serve Yanguan/DeepSeek-V3.2-Exp-bf16 \ + vllm serve "Yanguan/DeepSeek-V3.2-Exp-bf16" --headless --data-parallel-size 2 --data-parallel-size-local 1 @@ -49,5 +49,5 @@ deployment: --trust-remote-code --no-enable-prefix-caching --gpu-memory-utilization 0.92 - --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' + --additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' benchmarks: diff --git a/tests/e2e/nightly/multi_node/scripts/run.sh b/tests/e2e/nightly/multi_node/scripts/run.sh index 48d1c39dc79..0c134441c82 100644 --- a/tests/e2e/nightly/multi_node/scripts/run.sh +++ b/tests/e2e/nightly/multi_node/scripts/run.sh @@ -108,8 +108,8 @@ install_extra_components() { fi pip install custom_ops-1.0-cp311-cp311-linux_aarch64.whl - export ASCEND_CUSTOM_OPP_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize:${ASCEND_CUSTOM_OPP_PATH} - export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH} + export ASCEND_CUSTOM_OPP_PATH="/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize${ASCEND_CUSTOM_OPP_PATH:+:${ASCEND_CUSTOM_OPP_PATH}}" + export LD_LIBRARY_PATH="/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" source /usr/local/Ascend/ascend-toolkit/set_env.sh rm -f CANN-custom_ops-sfa-linux.aarch64.run \ diff --git a/tests/e2e/nightly/ops/test_batch_matmul_transpose.py b/tests/e2e/nightly/ops/test_batch_matmul_transpose.py new file mode 100644 index 00000000000..6c81b9ebce7 --- /dev/null +++ b/tests/e2e/nightly/ops/test_batch_matmul_transpose.py @@ -0,0 +1,141 @@ +import random +import unittest + +import torch + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + +torch.set_printoptions(threshold=float("inf")) + + +class TestMatrixMultiplication(unittest.TestCase): + + def compute_golden(self, a, b, res1, m, n): + """Compute reference result (golden)""" + torch.bmm(a.transpose(0, 1), + b, + out=res1.view(-1, m, n).transpose(0, 1)) + + def assert_tensors_almost_equal(self, actual, expected, dtype): + """Check if two tensors are approximately equal (considering floating point errors)""" + self.assertEqual(actual.shape, expected.shape, "Shape mismatch") + + # Check for NaN + self.assertFalse( + torch.isnan(actual).any(), "Actual result contains NaN") + self.assertFalse( + torch.isnan(expected).any(), "Expected result contains NaN") + + # Check for Inf + self.assertFalse( + torch.isinf(actual).any(), "Actual result contains Inf") + self.assertFalse( + torch.isinf(expected).any(), "Expected result contains Inf") + + # Set different tolerances based on data type + if dtype == torch.float16: + rtol, atol = 1e-5, 1e-5 + else: # bfloat16 + rtol, atol = 1.5e-5, 1.5e-5 + + # Compare values + diff = 
torch.abs(actual - expected) + max_diff = diff.max().item() + max_expected = torch.abs(expected).max().item() + + # Check relative and absolute errors + if max_expected > 0: + relative_diff = max_diff / max_expected + self.assertLessEqual( + relative_diff, + rtol, + f"Relative error too large: {relative_diff} > {rtol}. Max difference: {max_diff}", + ) + + self.assertLessEqual(max_diff, atol, + f"Absolute error too large: {max_diff} > {atol}") + + def test_boundary_conditions(self): + """Test boundary conditions""" + test_cases = [ + # (b, m, k, n) + (1, 1, 1, 1), # Minimum size + (1, 10, 1, 1), # b=1 + (10, 1, 1, 10), # m=1 + (5, 5, 1, 5), # k=1 + (2, 2, 2, 1), # n=1 + (100, 1, 1, 100), # Flat case + (1, 100, 100, 1), # Flat case + (2, 3, 4, 5), # Random small size + (10, 20, 30, 40), # Medium size + (36, 128, 512, 128), # target case + (8, 160, 512, 128), + ] + + dtypes = [torch.float16, torch.bfloat16] + + for dtype in dtypes: + for b, m, k, n in test_cases: + with self.subTest(dtype=dtype, shape=f"({b}, {m}, {k}, {n})"): + a = torch.randn(b, m, k, dtype=dtype, device="npu") + b_tensor = torch.randn(m, k, n, dtype=dtype, device="npu") + res1 = torch.empty((b, m * n), dtype=dtype, device="npu") + res2 = torch.empty((b, m, n), dtype=dtype, device="npu") + + self.compute_golden(a, b_tensor, res1, m, n) + torch.ops._C_ascend.batch_matmul_transpose( + a, b_tensor, res2) + + self.assert_tensors_almost_equal(res1.view(-1, m, n), res2, + dtype) + + def test_random_shapes(self): + """Test randomly generated shapes""" + num_tests = 1 + dtypes = [torch.float16, torch.bfloat16] + + for dtype in dtypes: + for _ in range(num_tests): + # Generate reasonable random sizes + b = random.randint(1, 500) + m = random.randint(1, 500) + k = random.randint(1, 500) + n = random.randint(1, 500) + + with self.subTest(dtype=dtype, + shape=f"Random ({b}, {m}, {k}, {n})"): + a = torch.randn(b, m, k, dtype=dtype, device="npu") + b_tensor = torch.randn(m, k, n, dtype=dtype, device="npu") + res1 = torch.empty((b, m * n), dtype=dtype, device="npu") + res2 = torch.empty((b, m, n), dtype=dtype, device="npu") + + self.compute_golden(a, b_tensor, res1, m, n) + torch.ops._C_ascend.batch_matmul_transpose( + a, b_tensor, res2) + self.assert_tensors_almost_equal(res1.view(-1, m, n), res2, + dtype) + + def test_zero_values(self): + """Test zero input values""" + dtypes = [torch.float16, torch.bfloat16] + b, m, k, n = 5, 4, 3, 2 + + for dtype in dtypes: + with self.subTest(dtype=dtype): + a = torch.zeros(b, m, k, dtype=dtype, device="npu") + b_tensor = torch.zeros(m, k, n, dtype=dtype, device="npu") + res1 = torch.empty((b, m * n), dtype=dtype, device="npu") + res2 = torch.empty((b, m, n), dtype=dtype, device="npu") + + self.compute_golden(a, b_tensor, res1, m, n) + torch.ops._C_ascend.batch_matmul_transpose(a, b_tensor, res2) + + self.assert_tensors_almost_equal(res1.view(-1, m, n), res2, + dtype) + self.assertTrue(torch.all(res2 == 0)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index 2f56d9d2ab4..6b90ec365ce 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -48,27 +48,26 @@ def mtp_correctness(sampling_config: SamplingParams, if graph_mode == CUDAGraphMode.FULL: graph_mode_str = "FULL_DECODE_ONLY" - with VllmRunner( - model_name, - tensor_parallel_size=1, - max_num_seqs=256, - 
gpu_memory_utilization=0.7,
-            distributed_executor_backend="mp",
-            enable_expert_parallel=True,
-            speculative_config={
-                "method": "deepseek_mtp",
-                "num_speculative_tokens": num_speculative_tokens,
-                "disable_padded_drafter_batch": disable_padded_drafter_batch,
-            },
-            enforce_eager=enforce_eager,
-            max_model_len=2000,
-            compilation_config=CompilationConfig(
-                cudagraph_mode=graph_mode_str,
-                cudagraph_capture_sizes=[12],
-            ),
-            additional_config={"ascend_scheduler_config": {
-                "enabled": False
-            }}) as spec_llm:
+    with VllmRunner(model_name,
+                    tensor_parallel_size=1,
+                    max_num_seqs=256,
+                    gpu_memory_utilization=0.7,
+                    distributed_executor_backend="mp",
+                    enable_expert_parallel=True,
+                    speculative_config={
+                        "method":
+                        "deepseek_mtp",
+                        "num_speculative_tokens":
+                        num_speculative_tokens,
+                        "disable_padded_drafter_batch":
+                        disable_padded_drafter_batch,
+                    },
+                    enforce_eager=enforce_eager,
+                    max_model_len=2000,
+                    compilation_config=CompilationConfig(
+                        cudagraph_mode=graph_mode_str,
+                        cudagraph_capture_sizes=[12],
+                    )) as spec_llm:
         spec_outputs = spec_llm.generate(example_prompts, sampling_config)
     matches = 0
diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
index b35de243977..0902fe6dd68 100644
--- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
+++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
@@ -61,6 +61,7 @@ def eagle3_model_name():
     return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B"


+@pytest.mark.skip("TODO: Revert me after ngram oom issue on ci is fixed")
 def test_ngram_correctness(
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
@@ -117,7 +118,6 @@ def test_eagle_correctness(
     spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
     with VllmRunner(
             model_name,
-            enable_chunked_prefill=True,
             max_num_seqs=1,
             max_num_batched_tokens=2048,
             gpu_memory_utilization=0.6,
@@ -145,3 +145,88 @@
     # Heuristic: expect at least 66% of the prompts to match exactly
     # Upon failure, inspect the outputs to check for inaccuracy.
     assert matches > int(0.66 * len(ref_outputs))
+
+
+def test_suffix_correctness(
+    test_prompts: list[list[dict[str, Any]]],
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    The outputs of an original LLM and a speculative LLM
+    should be the same when using suffix speculative decoding.
+    '''
+    ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=False)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    with VllmRunner(model_name,
+                    speculative_config={
+                        "method": "suffix",
+                        "num_speculative_tokens": 8,
+                    },
+                    max_model_len=1024,
+                    enforce_eager=False) as runner:
+        spec_outputs = runner.model.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches > int(0.66 * len(ref_outputs))
+
+
+def test_suffix_acceptance(
+    test_prompts: list[list[dict[str, Any]]],
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    Check that suffix decoding caching takes effect and improves acceptance
+    lengths and acceptance rates over multiple runs of the same prompts.
+    '''
+    num_draft = []
+    num_accept = []
+    with VllmRunner(model_name,
+                    speculative_config={
+                        "method": "suffix",
+                        "suffix_decoding_max_spec_factor": 2.0,
+                        "suffix_decoding_max_cached_requests": 1000,
+                        "num_speculative_tokens": 10,
+                    },
+                    max_model_len=1024,
+                    disable_log_stats=False,
+                    enforce_eager=False) as runner:
+        for i in range(10):
+            runner.model.chat(test_prompts[i], sampling_config)
+            metrics = runner.model.get_metrics()
+            for metric in metrics:
+                print(metric)
+                if metric.name == "vllm:spec_decode_num_draft_tokens":
+                    num_draft.append(metric.value)
+                if metric.name == "vllm:spec_decode_num_accepted_tokens":
+                    num_accept.append(metric.value)
+    # Calculate the acceptance rates for the first and last runs.
+    first_accept_tokens = num_accept[0]
+    first_draft_tokens = num_draft[0]
+    first_accept_rate = first_accept_tokens / first_draft_tokens
+
+    # Take the diff since the stats are cumulative.
+    last_accept_tokens = num_accept[-1] - num_accept[-2]
+    last_draft_tokens = num_draft[-1] - num_draft[-2]
+    last_accept_rate = last_accept_tokens / last_draft_tokens
+
+    # Expect the acceptance length to improve.
+    assert first_accept_tokens < last_accept_tokens
+
+    # Expect the acceptance rate to improve.
+    assert first_accept_rate < last_accept_rate
+
+    # Heuristic: expect at least 60% acceptance rate at the end.
+    assert last_accept_rate > 0.60
diff --git a/tests/e2e/singlecard/test_ascend_scheduler.py b/tests/e2e/singlecard/test_ascend_scheduler.py
index 502a810376e..0c996e4eaaa 100644
--- a/tests/e2e/singlecard/test_ascend_scheduler.py
+++ b/tests/e2e/singlecard/test_ascend_scheduler.py
@@ -12,11 +12,6 @@
 @pytest.mark.parametrize("enforce_eager", [True, False])
 def test_concurrent_partial_prefill(enforce_eager):
     with VllmRunner(MODEL,
-                    additional_config={
-                        'ascend_scheduler_config': {
-                            'enabled': True,
-                        },
-                    },
                     max_num_seqs=3,
                     max_num_batched_tokens=8192,
                     enforce_eager=enforce_eager,
@@ -31,11 +26,6 @@ def test_concurrent_partial_prefill(enforce_eager):
 @pytest.mark.parametrize("enforce_eager", [True, False])
 def test_prefix_cache_stats_is_recorded(enforce_eager):
     with VllmRunner(MODEL,
-                    additional_config={
-                        'ascend_scheduler_config': {
-                            'enabled': True,
-                        },
-                    },
                     max_num_seqs=3,
                     max_num_batched_tokens=8192,
                     enforce_eager=enforce_eager,
@@ -47,48 +37,6 @@ def test_prefix_cache_stats_is_recorded(enforce_eager):
         assert outputs[0].num_cached_tokens == 128


-@pytest.mark.parametrize("max_tokens",
-                         [4])  # cannot align results when max_tokens > 4
-@pytest.mark.parametrize("chunked_prefill_token_size", [2048])
-def test_chunked_prefill_with_ascend_scheduler(
-        max_tokens: int, chunked_prefill_token_size: int) -> None:
-    example_prompts = [
-        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
- ] - max_num_seqs = chunked_prefill_token_size - max_num_batched_tokens = chunked_prefill_token_size - with VllmRunner(MODEL, - additional_config={ - 'ascend_scheduler_config': { - 'enabled': True, - 'enable_chunked_prefill': True, - }, - }, - max_num_seqs=max_num_seqs, - max_num_batched_tokens=max_num_batched_tokens, - max_model_len=2048, - gpu_memory_utilization=0.7) as vllm_model: - chunked_prefill_output = vllm_model.generate_greedy( - example_prompts, max_tokens) - - with VllmRunner(MODEL, - additional_config={ - 'ascend_scheduler_config': { - 'enabled': True, - }, - }, - max_model_len=2048, - gpu_memory_utilization=0.7) as vllm_model: - vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=vllm_output, - outputs_1_lst=chunked_prefill_output, - name_0="vllm_output", - name_1="chunked_prefill_output", - ) - - @pytest.mark.parametrize("max_tokens", [4]) # cannot align results when max_tokens > 4 @pytest.mark.parametrize("chunked_prefill_token_size", [2048]) diff --git a/tests/e2e/singlecard/test_bge_model.py b/tests/e2e/singlecard/test_bge_model.py index 968bf1c7d43..48d4bf08539 100644 --- a/tests/e2e/singlecard/test_bge_model.py +++ b/tests/e2e/singlecard/test_bge_model.py @@ -28,7 +28,7 @@ def test_bge_model_correctness(): model_name = snapshot_download("BAAI/bge-m3") with VllmRunner( model_name, - task="embed", + runner="pooling", enforce_eager=True, ) as vllm_runner: vllm_outputs = vllm_runner.encode(queries) diff --git a/tests/e2e/singlecard/test_chunked.py b/tests/e2e/singlecard/test_chunked.py deleted file mode 100644 index f6eacb71dac..00000000000 --- a/tests/e2e/singlecard/test_chunked.py +++ /dev/null @@ -1,82 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -Compare the outputs of vLLM with and without aclgraph. - -Run `pytest tests/compile/test_aclgraph.py`. 
-""" -import gc - -import pytest -import torch -from vllm import SamplingParams - -from tests.e2e.conftest import VllmRunner - -MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("max_tokens", [1]) -def test_models( - model: str, - max_tokens: int, -) -> None: - prompts = ["The president of the United States is"] - - sampling_params = SamplingParams( - max_tokens=max_tokens, - temperature=0.0, - ) - - with VllmRunner(model, - long_prefill_token_threshold=20, - enforce_eager=False) as vllm_model: - output1 = vllm_model.generate(prompts, sampling_params) - - with VllmRunner(model, - enforce_eager=False, - additional_config={ - 'ascend_scheduler_config': { - 'enabled': True - }, - }) as vllm_model: - output2 = vllm_model.generate(prompts, sampling_params) - - # Extract the generated token IDs for comparison - token_ids1 = output1[0][0][0] - token_ids2 = output2[0][0][0] - - print(f"Token IDs 1: {token_ids1}") - print(f"Token IDs 2: {token_ids2}") - - # Convert token IDs to tensors and calculate cosine similarity - # Take the length of a shorter sequence to ensure consistent dimensions - min_len = min(len(token_ids1), len(token_ids2)) - - tensor1 = torch.tensor(token_ids1[:min_len], dtype=torch.float32) - tensor2 = torch.tensor(token_ids2[:min_len], dtype=torch.float32) - - # Calculate similarity using torch.cosine_similarity - similarity = torch.cosine_similarity(tensor1, tensor2, dim=0) - print(f"Token IDs cosine similarity: {similarity.item()}") - - assert similarity > 0.95 - - gc.collect() - torch.npu.empty_cache() - torch.npu.reset_peak_memory_stats() diff --git a/tests/e2e/singlecard/test_embedding.py b/tests/e2e/singlecard/test_embedding.py index 8c63a980e8a..3ff8d3416f3 100644 --- a/tests/e2e/singlecard/test_embedding.py +++ b/tests/e2e/singlecard/test_embedding.py @@ -28,7 +28,7 @@ def test_embed_models_correctness(): model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B") with VllmRunner( model_name, - task="embed", + runner="pooling", enforce_eager=False, ) as vllm_runner: vllm_outputs = vllm_runner.encode(queries) diff --git a/tests/e2e/singlecard/test_embedding_aclgraph.py b/tests/e2e/singlecard/test_embedding_aclgraph.py index e0851b06468..4c164900b41 100644 --- a/tests/e2e/singlecard/test_embedding_aclgraph.py +++ b/tests/e2e/singlecard/test_embedding_aclgraph.py @@ -34,14 +34,14 @@ def test_aclgrpah_embed_models_correctness(model_name): with VllmRunner( model_name, - task="embed", + runner="pooling", enforce_eager=False, ) as vllm_aclgraph_runner: vllm_aclgraph_outputs = vllm_aclgraph_runner.encode(queries) with VllmRunner( model_name, - task="embed", + runner="pooling", enforce_eager=True, ) as vllm_runner: vllm_outputs = vllm_runner.encode(queries) diff --git a/tests/e2e/singlecard/test_vlm.py b/tests/e2e/singlecard/test_vlm.py index cc3d50f8b3d..954566799c0 100644 --- a/tests/e2e/singlecard/test_vlm.py +++ b/tests/e2e/singlecard/test_vlm.py @@ -20,7 +20,6 @@ Run `pytest tests/test_offline_inference.py`. """ -import pytest from vllm import SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset @@ -55,40 +54,6 @@ def test_multimodal_vl(prompt_template): assert output_str, "Generated output should not be empty." -@pytest.mark.skip(reason="This e2e test will stuck in multi-batch scenario. 
" - "Add this back after fixing the issue.") -def test_multimodal_ascend_scheduler(prompt_template): - image = ImageAsset("cherry_blossom") \ - .pil_image.convert("RGB") - img_questions = [ - "What is the content of this image?", - "Describe the content of this image in detail.", - "What's in the image?", - "Where is this image taken?", - ] - images = [image] * len(img_questions) - prompts = prompt_template(img_questions) - with VllmRunner("Qwen/Qwen2.5-VL-3B-Instruct", - max_model_len=4096, - additional_config={ - 'ascend_scheduler_config': { - 'enabled': True, - }, - }, - mm_processor_kwargs={ - "min_pixels": 28 * 28, - "max_pixels": 1280 * 28 * 28, - "fps": 1, - }, - enforce_eager=True) as vllm_model: - outputs = vllm_model.generate_greedy(prompts=prompts, - images=images, - max_tokens=64) - assert len(outputs) == len(prompts) - for _, output_str in outputs: - assert output_str, "Generated output should not be empty." - - def test_multimodal_audio(): audio_prompt = "".join([ f"Audio {idx+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n" diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index 129b54102e3..3a94e9e8d56 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -25,12 +25,6 @@ def test_get_builder_cls(self): self.assertEqual(AscendAttentionBackend.get_builder_cls(), AscendAttentionMetadataBuilder) - @patch('vllm_ascend.attention.attention_v1.get_ascend_device_type', - return_value=AscendDeviceType._310P) - def test_get_kv_cache_shape_310p(self, mock_soc_version): - result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40) - self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16)) - @patch('vllm_ascend.utils.get_ascend_device_type', return_value=AscendDeviceType._910_93) def test_get_kv_cache_shape_not_310p(self, mock_soc_version): @@ -95,76 +89,6 @@ def test_reorder_batch(self): self.assertFalse(result) - @patch('vllm_ascend.attention.attention_v1.AscendMetadata') - @patch('torch_npu.npu_format_cast') - @patch('vllm_ascend.utils.nd_to_nz_2d') - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType._310P) - def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d, - mock_npu_format_cast, - mock_ascend_metadata): - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=torch.tensor([0, 3, 7]), - query_start_loc_cpu=torch.tensor([0, 3, 7]), - seq_lens_cpu=torch.tensor([5, 6]), - num_reqs=2, - num_actual_tokens=10, - max_query_len=5, - decode_token_per_req=torch.tensor([1, 1]), - block_table_tensor=torch.zeros((10, 10)), - slot_mapping=torch.tensor(range(20)), - actual_seq_lengths_q=torch.tensor([0, 1]), - positions=torch.tensor([10, 10]), - attn_mask=torch.ones((10, 10)), - spec_attn_mask=None, - attn_state=AscendAttentionState.PrefillNoCache, - num_computed_tokens_cpu=None, - seq_lens=None) - - mock_nz_tensor = MagicMock() - mock_model = MagicMock() - mock_nd_to_nz_2d.return_value = mock_nz_tensor - mock_npu_format_cast.return_value = mock_nz_tensor - - self.builder.build(1, common_attn_metadata, mock_model) - - @patch('vllm_ascend.attention.attention_v1.AscendMetadata') - @patch('torch_npu.npu_format_cast') - @patch('vllm_ascend.utils.nd_to_nz_spec') - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType._310P) - @patch('vllm_ascend.attention.attention_v1.AscendAttentionState') - def test_build_chunked_prefill(self, mock_ascend_attention_state, - mock_soc_version, mock_nd_to_nz_spec, - 
mock_npu_format_cast, mock_ascend_metadata): - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=torch.tensor([0, 2, 5, 9]), - query_start_loc_cpu=torch.tensor([0, 2, 5, 9]), - seq_lens_cpu=torch.tensor([4, 5, 6]), - num_reqs=3, - num_actual_tokens=15, - max_query_len=6, - decode_token_per_req=torch.tensor([1, 1, 1]), - block_table_tensor=torch.zeros((10, 10)), - slot_mapping=torch.tensor(range(20)), - actual_seq_lengths_q=torch.tensor([0, 1, 2]), - positions=torch.tensor([10, 10]), - attn_mask=torch.ones((15, 15)), - spec_attn_mask=None, - attn_state=AscendAttentionState.ChunkedPrefill, - num_computed_tokens_cpu=None, - seq_lens=None) - - mock_ascend_attention_state = MagicMock() - mock_ascend_attention_state.PrefillNoCache = 0 - - mock_nz_tensor = MagicMock() - mock_model = MagicMock() - mock_nd_to_nz_spec.return_value = mock_nz_tensor - mock_npu_format_cast.return_value = mock_nz_tensor - - self.builder.build(1, common_attn_metadata, mock_model) - @patch('vllm_ascend.attention.attention_v1.AscendMetadata') @patch('vllm_ascend.utils.get_ascend_device_type', return_value=AscendDeviceType._910_93) @@ -286,73 +210,40 @@ def test_forward_no_attn_metadata(self): assert output.shape == (10, 8 * 64) - @patch('vllm_ascend.attention.attention_v1.get_forward_context') - @patch('torch_npu._npu_reshape_and_cache') - @patch('torch_npu._npu_flash_attention') - def test_forward_prefill_no_cache(self, mock_flash_attention, - mock_reshape_cache, - mock_get_forward_context): - """Test forward pass in PrefillNoCache state""" - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) - kv_cache = torch.empty(2, 5, 128, 8, 64) - output = torch.empty_like(query) - - mock_get_forward_context.return_value = MagicMock(capturing=False) - - metadata = self.attn_metadata - metadata.attn_state = AscendAttentionState.PrefillNoCache - metadata.attn_mask = torch.randn(1, 1, 10, 10) - metadata.seq_lens = torch.tensor([10]) - metadata.num_actual_tokens = 10 - metadata.slot_mapping = torch.zeros(10, dtype=torch.long) - metadata.num_decodes = 0 - metadata.num_prefills = 10 - layer = self.layer_no_quant - - output = self.impl.forward(layer, query, key, value, kv_cache, - metadata, output) - - mock_reshape_cache.assert_called_once() - mock_flash_attention.assert_called_once() - assert output.shape == (10, 8 * 64) - @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu.npu_fused_infer_attention_score') @patch('vllm_ascend.attention.attention_v1.get_forward_context') - def test_forward_prefill_cache_hit(self, mock_get_forward_context, - mock_npu_fused_infer_attention_score, - mock_npu_reshape_and_cache): + def test_forward_prefill(self, mock_get_forward_context, + mock_npu_fused_infer_attention_score, + mock_npu_reshape_and_cache): """Test forward pass in PrefillCacheHit state""" - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) + query = torch.randn(10, 8, 64) + key = torch.randn(10, 8, 64) + value = torch.randn(10, 8, 64) kv_cache = torch.empty(2, 5, 128, 8, 64) output = torch.empty_like(query) - metadata = self.attn_metadata metadata.attn_state = AscendAttentionState.PrefillCacheHit metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.query_lens = torch.tensor([10]) metadata.seq_lens = torch.tensor([10]) + metadata.actual_seq_lengths_q = [10] metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) metadata.num_actual_tokens = 10 - metadata.slot_mapping = torch.zeros(10, dtype=torch.long) + 
metadata.num_decode_tokens = 0 metadata.num_decodes = 0 metadata.num_prefills = 10 + metadata.slot_mapping = torch.zeros(10, dtype=torch.long) layer = self.layer_no_quant mock_get_forward_context.return_value = MagicMock(capturing=False) - mock_npu_fused_infer_attention_score.return_value = (output, - torch.ones( - 10, 8, 64)) - + mock_npu_fused_infer_attention_score.return_value = (torch.ones( + 10, 8, 64), torch.ones(10, 8, 64)) output = self.impl.forward(layer, query, key, value, kv_cache, metadata, output) mock_npu_fused_infer_attention_score.assert_called_once() - assert output.shape == (10, 8 * 64) + assert output.shape == (10, 8, 64) @patch('torch_npu._npu_paged_attention') @patch('torch_npu._npu_reshape_and_cache') @@ -454,119 +345,6 @@ def test_forward_decode_only_swa_seq_len_mismatch( assert output.shape == (10, 8 * 64) - @patch('vllm_ascend.attention.attention_v1.get_forward_context') - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType._910_93) - @patch('torch_npu._npu_reshape_and_cache') - @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill') - def test_forward_head_size_192(self, mock_vanilla_prefill, - mock_npu_reshape_and_cache, - mock_soc_version, mock_get_forward_context): - """Test forward pass when head_size is 192""" - - self.impl.head_size = 192 - query = torch.randn(10, 8 * 192) - key = torch.randn(10, 8 * 192) - value = torch.randn(10, 8 * 192) - kv_cache = torch.empty(2, 5, 128, 8, 192) - output = torch.empty_like(query) - - mock_get_forward_context.return_value = MagicMock(capturing=False) - - metadata = self.attn_metadata - metadata.attn_mask = torch.randn(1, 1, 10, 10) - metadata.query_lens = torch.tensor([10]) - metadata.seq_lens = torch.tensor([10]) - metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) - metadata.num_actual_tokens = 10 - metadata.slot_mapping = torch.zeros(10, dtype=torch.long) - metadata.num_decodes = 10 - metadata.num_prefills = 0 - layer = self.layer_no_quant - mock_vanilla_prefill.return_value = MagicMock() - - output = self.impl_192.forward(layer, query, key, value, kv_cache, - metadata, output) - - mock_vanilla_prefill.assert_called_once() - assert output.shape == (10, 8 * 192) - - @patch('vllm_ascend.attention.attention_v1.get_forward_context') - @patch('torch_npu.npu_fused_infer_attention_score') - @patch('torch_npu._npu_reshape_and_cache') - def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache, - mock_npu_fused_infer_attention_score, - mock_get_forward_context): - """Test forward pass in normal V1 situation""" - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) - kv_cache = torch.empty(2, 5, 128, 8, 64) - output = torch.empty_like(query) - - metadata = self.attn_metadata - metadata.attn_mask = torch.randn(1, 1, 10, 10) - metadata.query_lens = torch.tensor([10]) - metadata.seq_lens = torch.tensor([10]) - metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) - metadata.num_actual_tokens = 10 - metadata.slot_mapping = torch.zeros(10, dtype=torch.long) - metadata.num_decodes = 0 - metadata.num_prefills = 10 - layer = self.layer_no_quant - mock_get_forward_context.return_value = MagicMock(capturing=False) - mock_npu_fused_infer_attention_score.return_value = (output, - torch.ones( - 10, 8, 64)) - - output = self.impl.forward(layer, query, key, value, kv_cache, - metadata, output) - - mock_npu_fused_infer_attention_score.assert_called_once() - assert output.shape == (10, 8 * 64) - - 
@patch('torch_npu.npu_format_cast') - @patch('torch_npu._npu_reshape_and_cache') - @patch('torch_npu.npu_fused_infer_attention_score') - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType._310P) - @patch('vllm_ascend.attention.attention_v1.get_forward_context') - def test_forward_310p_device(self, mock_get_forward_context, - mock_soc_version, - mock_npu_fused_infer_attention_score, - mock_npu_reshape_and_cache, - mock_npu_format_cast): - """Test forward pass on 310P device""" - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) - kv_cache = torch.empty(2, 5, 128, 8, 64) - output = torch.empty_like(query) - - metadata = self.attn_metadata - metadata.attn_mask = torch.randn(1, 1, 10, 10) - metadata.query_lens = torch.tensor([10]) - metadata.seq_lens = torch.tensor([10]) - metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) - metadata.num_actual_tokens = 10 - metadata.slot_mapping = torch.zeros(10, dtype=torch.long) - metadata.num_decodes = 0 - metadata.num_prefills = 10 - layer = self.layer_no_quant - - mock_npu_format_cast.return_value = metadata.attn_mask - - mock_get_forward_context.return_value = MagicMock(capturing=False) - mock_npu_fused_infer_attention_score.return_value = (output, - torch.ones( - 10, 8, 64)) - - output = self.impl.forward(layer, query, key, value, kv_cache, - metadata, output) - - mock_npu_fused_infer_attention_score.assert_called_once() - assert output.shape == (10, 8 * 64) - @patch('torch_npu._npu_reshape_and_cache') def test_forward_raise_error(self, mock_paged_attention): query = torch.randn(10, 8 * 64) diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py index 15a5c50986b..4a9109166ed 100644 --- a/tests/ut/distributed/test_parallel_state.py +++ b/tests/ut/distributed/test_parallel_state.py @@ -24,11 +24,13 @@ def mock_distributed(): patch('torch.distributed.get_backend', return_value='nccl'), \ patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \ patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \ - patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group: + patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \ + patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group: mock_group.return_value.local_rank = 0 mock_group.return_value.device_group = MagicMock() mock_tp_group.return_value.world_size = 4 mock_dp_group.return_value.world_size = 2 + mock_pp_group.return_value.world_size = 2 yield diff --git a/tests/ut/models/test_mla.py b/tests/ut/models/test_mla.py index 6b03b05be6f..28363450dca 100644 --- a/tests/ut/models/test_mla.py +++ b/tests/ut/models/test_mla.py @@ -7,8 +7,7 @@ from vllm.model_executor.layers.mla import MLAModules from tests.ut.base import TestBase -from vllm_ascend.models.layers.mla import (AscendMultiHeadLatentAttention, - IndexerWrapper) +from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention, IndexerWrapper class TestIndexerWrapper(TestBase): @@ -78,15 +77,13 @@ def setUp(self): self.mock_cache_config = MagicMock(spec=CacheConfig) self.mock_quant_config = MagicMock() - @patch("vllm_ascend.models.layers.mla.get_current_vllm_config") - @patch("vllm_ascend.models.layers.mla.get_ascend_config") - @patch( - "vllm_ascend.models.layers.mla.get_tensor_model_parallel_world_size") + @patch("vllm_ascend.ops.mla.get_current_vllm_config") + @patch("vllm_ascend.ops.mla.get_ascend_config") 
+ @patch("vllm_ascend.ops.mla.get_tensor_model_parallel_world_size") def test_initialization(self, mock_tp_size, mock_ascend_config, mock_get_vllm_config): - with patch("vllm_ascend.models.layers.mla.MLAAttention", - return_value=True): + with patch("vllm_ascend.ops.mla.MLAAttention", return_value=True): mock_tp_size.return_value = 2 mock_ascend_config.return_value.enable_shared_expert_dp = True mock_vllm_config = MagicMock(spec=VllmConfig) @@ -114,12 +111,11 @@ def test_initialization(self, mock_tp_size, mock_ascend_config, self.assertTrue(attn.enable_shared_expert_dp) self.assertIsNotNone(attn.mla_attn) - @patch("vllm_ascend.models.layers.mla.torch.ops.vllm.mla_forward") - @patch("vllm_ascend.models.layers.mla.get_current_vllm_config") - @patch("vllm_ascend.models.layers.mla.get_ascend_config") - @patch( - "vllm_ascend.models.layers.mla.get_tensor_model_parallel_world_size") - @patch("vllm_ascend.models.layers.mla.get_forward_context") + @patch("vllm_ascend.ops.mla.torch.ops.vllm.mla_forward") + @patch("vllm_ascend.ops.mla.get_current_vllm_config") + @patch("vllm_ascend.ops.mla.get_ascend_config") + @patch("vllm_ascend.ops.mla.get_tensor_model_parallel_world_size") + @patch("vllm_ascend.ops.mla.get_forward_context") def test_forward(self, mock_get_forward_context, mock_tp_size, mock_ascend_config, mock_get_vllm_config, mock_mla_forward): @@ -130,8 +126,7 @@ def test_forward(self, mock_get_forward_context, mock_tp_size, num_hidden_layers=32, first_k_dense_replace=False) mock_get_vllm_config.return_value = mock_vllm_config mock_vllm_config.compilation_config = CompilationConfig() - with patch("vllm_ascend.models.layers.mla.MLAAttention", - return_value=True): + with patch("vllm_ascend.ops.mla.MLAAttention", return_value=True): attn = AscendMultiHeadLatentAttention( hidden_size=self.hidden_size, num_heads=self.num_heads, diff --git a/tests/ut/models/test_qwen2_vl.py b/tests/ut/models/test_qwen2_vl.py deleted file mode 100644 index d62b8594bae..00000000000 --- a/tests/ut/models/test_qwen2_vl.py +++ /dev/null @@ -1,200 +0,0 @@ -import pytest -import torch -from pytest_mock import MockerFixture -from vllm.model_executor.layers.activation import QuickGELU - -from tests.ut.base import PytestBase -from vllm_ascend.models.qwen2_vl import (AscendQwen2VisionAttention, - AscendQwen2VisionBlock) - - -class TestAscendQwen2VisionAttention(PytestBase): - - def init_attention( - self, - mocker, - embed_dim=1000, - num_heads=10, - projection_size=100, - quant_config=None, - prefix="", - ): - mocker_attn = mocker.patch( - "vllm_ascend.models.qwen2_vl.Qwen2VisionAttention.__init__") - - attention = AscendQwen2VisionAttention( - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - ) - args, kwargs = mocker_attn.call_args - assert args == (embed_dim, num_heads, projection_size, None, "") - assert not kwargs - attention.num_attention_heads_per_partition = num_heads - return attention - - def test_attn_init_should_normal(self, mocker: MockerFixture): - embed_dim = 1000 - num_heads = 10 - projection_size = 100 - quant_config = None - prefix = "" - vit = self.init_attention( - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - mocker=mocker, - ) - assert vit.hidden_size_per_attention_head == 10 - - def test_attn_init_should_raise_error(self, mocker: MockerFixture): - embed_dim = 1000 - num_heads = 7 - projection_size = 100 - quant_config = None - prefix = "" - with 
pytest.raises(AssertionError): - # projection_size should divided by num heads - self.init_attention( - mocker=mocker, - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - ) - - def test_attn_forward(self, mocker: MockerFixture): - attention = self.init_attention(mocker=mocker) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - x = torch.rand((100, 3, 10 * 3 * 128)) # s,b, head*3*head_dim - cu_seqlens = torch.tensor([10, 50, 100]) - cos = torch.rand((1, 100, 1, 128)) - sin = torch.rand((1, 100, 1, 128)) - - qkv = lambda x: (x, 0) # noqa - split_qkv = lambda x: [ #noqa - torch.rand((100, 3, 10, 128)) for i in range(3) - ] # noqa - npu_rotary_mul = lambda q, cos, sin: q # noqa - _npu_flash_attention_unpad = lambda **kwargs: kwargs["out"] # noqa - proj = lambda x: (x, 0) # noqa - - mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv) - mocker_split_qkv = mocker.patch.object( - attention, - "split_qkv", - side_effect=split_qkv, - ) - mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul", - side_effect=npu_rotary_mul) - mocker_npu_flash_attention_unpad = mocker.patch( - "torch_npu._npu_flash_attention_unpad", - side_effect=_npu_flash_attention_unpad, - ) - mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj) - attention.__dict__["qkv"] = mocker_qkv - attention.__dict__["split_qkv"] = mocker_split_qkv - attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul - attention.__dict__["_npu_flash_attention_unpad"] = ( - mocker_npu_flash_attention_unpad) - attention.__dict__["proj"] = mocker_proj - - output = attention.forward( - x=x, - cu_seqlens=cu_seqlens, - cos=cos, - sin=sin, - ) - qkv_args, qkv_kwargs = mocker_qkv.call_args - assert qkv_args == (x, ) - assert not qkv_kwargs - - split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args - assert split_qkv_args == (x, ) - assert not split_qkv_kwargs - - npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args - assert npu_rotary_mul_args[1:] == (cos, sin) - assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128]) - assert not npu_rotary_mul_kwargs - - assert output.shape == torch.Size([100, 3, 1280]) - - -class TestAscendQwen2VisionBlock(PytestBase): - - def init_vision_block( - self, - mocker, - dim=100, - num_heads=10, - mlp_ratio=0.5, - ): - mocker_vit = mocker.patch( - "vllm.model_executor.models.qwen2_vl.Qwen2VisionBlock.__init__", - return_value=None, - ) - - mocker_attn = mocker.patch( - "vllm_ascend.models.qwen2_vl.AscendQwen2VisionAttention.__init__", - return_value=None, - ) - - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - vision_block = AscendQwen2VisionBlock( - dim=dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - ) - args, kwargs = mocker_vit.call_args - assert args == (dim, num_heads, mlp_ratio, QuickGELU, None, None, "") - assert not kwargs - - args1, kwargs1 = mocker_attn.call_args - assert not args1 - assert kwargs1 == { - "embed_dim": dim, - "num_heads": num_heads, - "projection_size": dim, - "quant_config": None, - "prefix": ".attn", - } - return vision_block - - def test_init_vision_block_should_normal( - self, - mocker: MockerFixture, - ): - vision_block = self.init_vision_block(mocker) - assert isinstance(vision_block, AscendQwen2VisionBlock) - - def test_vision_block_forward(self, mocker: 
MockerFixture): - x = torch.randint(1, 100, (100, 3, 1280)) # s,b,d - cu_seqlens = torch.tensor([10, 50, 100]) - cos = torch.rand((1, 100, 1, 128)) - sin = torch.rand((1, 100, 1, 128)) - vision_block = self.init_vision_block(mocker) - mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x) - mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x) - vision_block.__dict__["attn"] = mocker_attn - vision_block.__dict__["mlp"] = mocker_mlp - - output = vision_block.forward(x.clone(), cu_seqlens, cos, sin) - - _, attn_kwargs = mocker_attn.call_args - assert attn_kwargs == { - "cu_seqlens": cu_seqlens, - "cos": cos, - "sin": sin, - } - - assert torch.all(x * 3 == output) diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index 314775f883c..77af2649aae 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -1,4 +1,5 @@ import unittest +from unittest.mock import patch import pytest import torch @@ -42,7 +43,9 @@ def context(self, mocker: MockerFixture): # Test case for the most common and basic scenario @pytest.mark.parametrize( "residual", [None, torch.randn(4, 8, dtype=torch.float16)]) - def test_forward_oot_basic(self, residual): + @patch("torch.ops.vllm.maybe_chunk_residual") + def test_forward_oot_basic(self, mock_maybe_chunk_residual, residual): + mock_maybe_chunk_residual.side_effect = lambda x, residual: residual layer = RMSNorm(hidden_size=8, eps=1e-05) x = torch.randn(4, 8, dtype=torch.float16) if residual is not None: @@ -107,6 +110,8 @@ def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture): mock_forward_context.num_hidden_layers = num_hidden_layers mock_forward_context.fusion_linear = "gate_up_dense" mock_forward_context.weight_prefetch_method = None + mocker.patch("torch.ops.vllm.maybe_chunk_residual", + lambda x, residual: residual) # Ensure fusion and layer_idx increment are handled correctly x = torch.randn(4, 8, dtype=torch.float16) diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py index f258f8e709e..8adde876a0b 100644 --- a/tests/ut/ops/test_moe_comm_method.py +++ b/tests/ut/ops/test_moe_comm_method.py @@ -226,8 +226,8 @@ def test_fused_experts_method(self, mock_unified_apply_mlp, w2 = w2.contiguous() result = comm_impl.fused_experts(hidden_states=hidden_states, - w1=w1, - w2=w2, + w1=[w1], + w2=[w2], topk_weights=topk_weights, topk_ids=topk_ids, activation="silu") diff --git a/typos.toml b/typos.toml index bd75b50aa0b..d15e113700d 100644 --- a/typos.toml +++ b/typos.toml @@ -19,7 +19,7 @@ locale = "en" extend-ignore-identifiers-re = [".*Unc.*", ".*_thw", ".*UE8M0.*", ".*[UE4M3|ue4m3].*", ".*eles.*", ".*fo.*", ".*ba.*", ".*ot.*", ".*[Tt]h[rR].*"] -extend-ignore-words-re = ["CANN", "cann"] +extend-ignore-words-re = ["CANN", "cann","ND"] extend-ignore-re = [] [default.extend-identifiers] diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 16d16a4d7c8..115dbef1209 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -72,6 +72,10 @@ def __init__(self, vllm_config): self.enable_shared_expert_dp = additional_config.get( "enable_shared_expert_dp", False ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel + if self.enable_shared_expert_dp: + from vllm_ascend.utils import enable_sp + assert enable_sp(vllm_config=vllm_config, + enable_shared_expert_dp=True) self.multistream_overlap_shared_expert = additional_config.get( 
"multistream_overlap_shared_expert", False) self.recompute_scheduler_enable = additional_config.get( diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 1d9139c5113..0cb2b75cdc6 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -41,11 +41,7 @@ split_decodes_and_prefills) from vllm_ascend.compilation.acl_graph import (get_graph_params, update_graph_params_workspaces) -from vllm_ascend.ops.attention import vanilla_chunked_prefill -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, - aligned_16, get_ascend_device_type, nd_to_nz_2d, - nd_to_nz_spec, prefill_context_parallel_enable, - weak_ref_tensors) +from vllm_ascend.utils import prefill_context_parallel_enable, weak_ref_tensors # isort: off if prefill_context_parallel_enable(): @@ -83,9 +79,6 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - if get_ascend_device_type() == AscendDeviceType._310P: - return (2, num_blocks, num_kv_heads * head_size // 16, block_size, - 16) return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod @@ -351,16 +344,6 @@ def build( query_start_loc = query_start_loc_cpu.to(self.device, non_blocking=True) - if get_ascend_device_type() == AscendDeviceType._310P: - if attn_state == AscendAttentionState.PrefillNoCache: - mask_nz = nd_to_nz_2d(attn_mask) - attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), - ACL_FORMAT_FRACTAL_NZ) - elif attn_state == AscendAttentionState.ChunkedPrefill: - mask_nz = nd_to_nz_spec(attn_mask) - attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), - ACL_FORMAT_FRACTAL_NZ) - common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata prefill_metadata = None decode_metadata = None @@ -585,9 +568,9 @@ def full_graph_attention(self, output: torch.Tensor, num_tokens=0): if self.pcp_size * self.dcp_size > 1: - intermediate_output = self._forward_pcp_dcp( - query, key, value, kv_cache, attn_metadata, output) - return intermediate_output, query.shape[0] + attn_output = self._forward_pcp_dcp(query, key, value, kv_cache, + attn_metadata, output) + return attn_output, query.shape[0] elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: block_size = 128 block_table = None @@ -688,93 +671,58 @@ def full_graph_attention(self, graph_params.handles[num_tokens].append(handle) return output, num_tokens - def _forward_prefill_no_cache( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, - num_tokens=0, - ) -> torch.Tensor: - assert attn_metadata is not None - assert attn_metadata.attn_mask is not None - - mask = attn_metadata.attn_mask - - if get_ascend_device_type() == AscendDeviceType._310P: - # align q k v output tensors - query = aligned_16(query) - key = aligned_16(key) - value = aligned_16(value) - output = aligned_16(output) - # do reformat in case of broadcasted tensors - mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1) - mask = torch_npu.npu_format_cast(mask.contiguous(), - ACL_FORMAT_FRACTAL_NZ) - - torch_npu._npu_flash_attention(query=query, - key=key, - value=value, - mask=mask, - seq_len=attn_metadata.seq_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - out=output) - assert output is not None - return output[:num_tokens] - - def _forward_prefill_cache_hit( - self, - query: torch.Tensor, - attn_metadata: AscendMetadata, - output: 
Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert attn_metadata is not None - assert attn_metadata.attn_mask is not None - - compress_mask = attn_metadata.attn_mask - batch_size = attn_metadata.query_lens.shape[0] - block_table = attn_metadata.block_tables[:batch_size, :] - num_block, block_size, _, _ = self.key_cache.shape # type: ignore - - if block_size == 128: - # TODO:The npu_fused_infer_attention_score op is planned to - # be utilized in a wider range in upcoming versions. + def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, attn_metadata: AscendMetadata, + output: torch.Tensor): + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + block_size = 128 + block_table = None + actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q + elif attn_metadata.attn_state == \ + AscendAttentionState.PrefillCacheHit: + batch_size = attn_metadata.query_lens.shape[0] + block_table = attn_metadata.block_tables[:batch_size, :] + num_block, block_size, _, _ = self.key_cache.shape # type: ignore key = self.key_cache.view( # type: ignore num_block, block_size, -1) value = self.value_cache.view( # type: ignore num_block, block_size, -1) - - output, _ = torch_npu.npu_fused_infer_attention_score( - query=query, - key=key, - value=value, - atten_mask=compress_mask, - block_table=block_table, - input_layout="TND", - block_size=block_size, - actual_seq_lengths=attn_metadata.actual_seq_lengths_q, - actual_seq_lengths_kv=attn_metadata.seq_lens_list, - num_key_value_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale=self.scale, - sparse_mode=3, - ) + actual_seq_lengths_kv = attn_metadata.seq_lens_list + # chunked_prefill. else: - torch_npu._npu_flash_attention_qlens( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - block_table=block_table, - mask=compress_mask, - seq_len=attn_metadata.query_lens, - context_lens=attn_metadata.seq_lens, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - out=output) + num_block, block_size, _, _ = self.key_cache.shape # type: ignore + key = self.key_cache.view( # type: ignore + num_block, block_size, -1) + value = self.value_cache.view( # type: ignore + num_block, block_size, -1) + block_table = attn_metadata.block_tables + actual_seq_lengths_kv = attn_metadata.seq_lens_list + + num_tokens = attn_metadata.actual_seq_lengths_q[-1] + query = query[:num_tokens] + # Prepare tensors for attention output + # TODO: Refactor this to step-level instead of layer-level + + # Get workspace from cache or calculate it if not present. + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + query=query, + key=key, + value=value, + atten_mask=attn_metadata.attn_mask, + block_table=block_table, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=attn_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=actual_seq_lengths_kv, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale=self.scale, + sparse_mode=3, + ) + + attn_output = attn_output.view(num_tokens, self.num_heads, + self.head_size) + output[:num_tokens] = attn_output[:num_tokens] return output def _forward_decode_only( @@ -783,10 +731,6 @@ def _forward_decode_only( attn_metadata: AscendMetadata, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if get_ascend_device_type() == AscendDeviceType._310P: - # seq_lens_tensor needs to be transferred to the device for 310P. 
- attn_metadata.seq_lens = \ - attn_metadata.seq_lens.to(device=query.device) if self.sliding_window is not None and attn_metadata.seq_lens.shape[ 0] == query.size(0): batch_size = attn_metadata.seq_lens.shape[0] @@ -827,69 +771,6 @@ def _forward_decode_only( out=output) return output - def _forward_v1_style( - self, - query: torch.Tensor, - attn_metadata: AscendMetadata, - output: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # Use chunked prefill for head size 192 scenario, like deepseek - # paged_attention_splitfuse maybe crash at such scenario. - # TODO: vanilla path will be removed after the kernel support - # head_size 192 scenario. - if self.head_size == 192: - cu_seqlen_q = [0] + attn_metadata.query_lens.tolist() - cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist() - cu_seqlen_q = torch.tensor(cu_seqlen_q, device=query.device) - cu_seqlen_k = torch.tensor(cu_seqlen_k, device=query.device) - cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0) - cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0) - max_seqlen_q = torch.max(attn_metadata.query_lens) - max_seqlen_k = torch.max(attn_metadata.seq_lens) - vanilla_chunked_prefill(output, query, self.key_cache, - self.value_cache, - attn_metadata.block_tables, cu_seqlen_q, - cu_seqlen_k, max_seqlen_q, max_seqlen_k, - self.scale, None, True) - return output - - # Use paged attention. - assert attn_metadata is not None - assert attn_metadata.attn_mask is not None - - if get_ascend_device_type() == AscendDeviceType._310P: - # Do reformat in case of broadcasted tensors. - attn_metadata.attn_mask = \ - torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), - ACL_FORMAT_FRACTAL_NZ) - attn_metadata.seq_lens = \ - attn_metadata.seq_lens.to(device=query.device) - - # TODO:The npu_fused_infer_attention_score op is planned to - # be utilized in a wider range in upcoming versions. 
- num_block, block_size, _, _ = self.key_cache.shape # type: ignore - key = self.key_cache.view( # type: ignore - num_block, block_size, -1) - value = self.value_cache.view( # type: ignore - num_block, block_size, -1) - - output, _ = torch_npu.npu_fused_infer_attention_score( - query=query, - key=key, - value=value, - atten_mask=attn_metadata.attn_mask, - block_table=attn_metadata.block_tables, - input_layout="TND", - block_size=block_size, - actual_seq_lengths=attn_metadata.actual_seq_lengths_q, - actual_seq_lengths_kv=attn_metadata.seq_lens_list, - num_key_value_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale=self.scale, - sparse_mode=3, - ) - return output - def _attention_with_nomask_and_mask(self, q: torch.Tensor, q_seqlens: List[int], k_nomask: torch.Tensor, @@ -1464,6 +1345,31 @@ def _load_kv_for_chunk(self, attn_metadata, kv_cache, ) return key, value + def _forward_encode( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata, + output: torch.Tensor, + ) -> torch.Tensor: + cum_seq_len = attn_metadata.query_start_loc[1:].tolist() + output = torch_npu.npu_fusion_attention( + query, + key, + value, + head_num=self.num_heads, + input_layout="TND", + scale=self.scale, + sparse_mode=4, + atten_mask=attn_metadata.attn_mask, + pre_tockens=attn_metadata.max_query_len, + next_tockens=attn_metadata.max_query_len, + actual_seq_qlen=cum_seq_len, + actual_seq_kvlen=cum_seq_len, + )[0] + return output + def forward( self, layer: AttentionLayer, @@ -1494,24 +1400,16 @@ def forward( "fused output quantization is not yet supported" " for AscendAttentionBackendImpl") - num_tokens = query.shape[0] - if attn_metadata is None: - return output - - # NOTE: Currently, we have various attention paths for different - # scenarios, and not all of them are in-place operations. Therefore, - # we need to create a separate tensor to hold the attention result. - # In the future, we may consolidate them into fewer paths, which will - # hopefully allow us to use in-place operation by default. - intermediate_output: torch.Tensor - assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - attn_type = self.attn_type - if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY: + if self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_ONLY: raise NotImplementedError("Encoder/decoder cross-attention " "are not implemented for " "PallasAttentionBackendImpl") + num_tokens = query.shape[0] + if attn_metadata is None: + return output.fill_(0) + num_decode_tokens = attn_metadata.num_decode_tokens has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 @@ -1558,48 +1456,25 @@ def forward( forward_context: ForwardContext = get_forward_context() if not forward_context.capturing: if self.pcp_size * self.dcp_size > 1: - intermediate_output = self._forward_pcp_dcp( - query, key, value, kv_cache, attn_metadata, output) - elif attn_type == AttentionType.ENCODER_ONLY: - # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly. 
- cum_seq_len = attn_metadata.query_start_loc[1:].tolist() - intermediate_output = torch_npu.npu_fusion_attention( - query, - key, - value, - head_num=self.num_heads, - input_layout="TND", - scale=self.scale, - sparse_mode=4, - atten_mask=attn_metadata.attn_mask, - pre_tockens=attn_metadata.max_query_len, - next_tockens=attn_metadata.max_query_len, - actual_seq_qlen=cum_seq_len, - actual_seq_kvlen=cum_seq_len, - )[0] - # V0-Style scheduler situation. - elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - intermediate_output = self._forward_prefill_no_cache( - query, key, value, attn_metadata, output, num_tokens) - elif attn_metadata.attn_state == \ - AscendAttentionState.PrefillCacheHit: - intermediate_output = self._forward_prefill_cache_hit( - query, attn_metadata, output) - elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - intermediate_output = self._forward_decode_only( - query, attn_metadata, output) - # Normal V1 situation. + attn_output = self._forward_pcp_dcp(query, key, value, + kv_cache, attn_metadata, + output) + output[:num_tokens] = attn_output[:num_tokens] + return output + if self.attn_type == AttentionType.ENCODER_ONLY: + attn_output = self._forward_encode(query, key, value, + attn_metadata, output) + output[:num_tokens] = attn_output[:num_tokens] + return output + if attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + output = self._forward_decode_only(query, attn_metadata, + output) else: - # npu_fused_infer_attention_score does not support cases - # where query.shape[0] != attn_metadata.query_start_loc[-1]. - # Thus we need unpad it here. - num_tokens = attn_metadata.query_start_loc[-1] - query = query[:num_tokens] - intermediate_output = self._forward_v1_style( - query, attn_metadata, output) + output = self._forward_prefill(query, key, value, + attn_metadata, output) else: - intermediate_output, num_tokens = self.full_graph_attention( + attn_output, num_tokens = self.full_graph_attention( query, key, value, kv_cache, attn_metadata, output) - output[:num_tokens] = intermediate_output[:num_tokens] + output[:num_tokens] = attn_output[:num_tokens] return output diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 188e66a5948..5d341d032a2 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -887,15 +887,16 @@ def __init__( ).device_group if self.tp_size > 1 else None def _v_up_proj(self, x): - if self.W_UV.shape[0] * self.W_UV.shape[ - 1] < 65536 and not self.dcp_size * self.pcp_size > 1: + if x.dtype in [torch.float16, torch.bfloat16] \ + and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \ + and not self.dcp_size * self.pcp_size > 1: x = x.view(-1, self.num_heads, self.kv_lora_rank) - x = torch_npu.npu_transpose_batchmatmul(x, - self.W_UV, - perm_x1=[1, 0, 2], - perm_x2=[0, 1, 2], - perm_y=[1, 0, 2]) - x = x.reshape(-1, self.num_heads * self.v_head_dim) + b, _, _ = x.shape + res = torch.empty((b, self.num_heads, self.v_head_dim), + dtype=x.dtype, + device=x.device) + torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) + x = res.reshape(-1, self.num_heads * self.v_head_dim) else: # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -923,8 +924,10 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: - if hasattr(layer, attr): + try: return getattr(layer, attr) + except 
AttributeError: + pass raise AttributeError( f"Layer '{layer}' has no recognized weight attribute:" f" {WEIGHT_NAMES}.") @@ -1674,6 +1677,8 @@ def forward( forward_context = get_forward_context() if (self.enable_mlapo and (attn_metadata is None or not forward_context.with_prefill)): + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + hidden_states.contiguous(), need_gather_q_kv) decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess( hidden_states, kv_cache, attn_metadata) else: diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 874ee39286e..1b979a89dc7 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -273,8 +273,10 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: - if hasattr(layer, attr): + try: return getattr(layer, attr) + except AttributeError: + pass raise AttributeError( f"Layer '{layer}' has no recognized weight attribute:" f" {WEIGHT_NAMES}.") @@ -455,7 +457,7 @@ def forward( kv_cache=kv_cache, attn_metadata=attn_metadata, need_gather_q_kv=need_gather_q_kv) - attn_output = torch.ops.custom.npu_sparse_flash_attention( + attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( query=ql_nope, key=k_nope, value=k_nope, @@ -535,7 +537,7 @@ def indexer_select( seq_lens = attn_metadata.seq_lens cum_query_lens = attn_metadata.cum_query_lens - topk_indices = torch.ops.custom.npu_lightning_indexer( + topk_indices = torch.ops._C_ascend.npu_lightning_indexer( query=q, key=kv_cache[2], weights=weights, diff --git a/vllm_ascend/distributed/kvpool/ascend_store_connector.py b/vllm_ascend/distributed/kvpool/ascend_store_connector.py index 9f4833555db..4107afdfab5 100644 --- a/vllm_ascend/distributed/kvpool/ascend_store_connector.py +++ b/vllm_ascend/distributed/kvpool/ascend_store_connector.py @@ -43,8 +43,6 @@ def __init__(self, self.kv_caches: dict[str, torch.Tensor] = {} - self._block_size = vllm_config.cache_config.block_size - self.sended_but_unfinished_reqs: set[str] = set() if role == KVConnectorRole.SCHEDULER: diff --git a/vllm_ascend/distributed/kvpool/config_data.py b/vllm_ascend/distributed/kvpool/config_data.py index e3b0873d686..0d89021bb3a 100644 --- a/vllm_ascend/distributed/kvpool/config_data.py +++ b/vllm_ascend/distributed/kvpool/config_data.py @@ -17,6 +17,10 @@ class KeyMetadata: model_name: str """ worker id when running under a distributed setting """ head_or_tp_rank: int + """ Initialize the current prefill context model parallel rank """ + pcp_rank: int + """ Initialize the current decode context model parallel rank """ + dcp_rank: int @dataclass(order=True) @@ -28,12 +32,15 @@ def __hash__(self): return hash(( self.key_metadata.model_name, self.key_metadata.head_or_tp_rank, + self.key_metadata.pcp_rank, + self.key_metadata.dcp_rank, self.chunk_hash, )) def to_string(self): return ( f"{self.key_metadata.model_name}" + f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}" f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}" ) @@ -60,6 +67,8 @@ def __hash__(self): return hash(( self.key_metadata.model_name, self.key_metadata.head_or_tp_rank, + self.key_metadata.pcp_rank, + self.key_metadata.dcp_rank, self.chunk_hash, self.layer_id, )) @@ -67,6 +76,7 @@ def __hash__(self): def to_string(self): return ( f"{self.key_metadata.model_name}" + f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}" 
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}" ) diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py index b30158ae8c2..0265d6a320c 100644 --- a/vllm_ascend/distributed/kvpool/kv_transfer.py +++ b/vllm_ascend/distributed/kvpool/kv_transfer.py @@ -19,11 +19,13 @@ class KVTransferThread(threading.Thread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, ready_event: threading.Event, name: str): + tp_rank: int, dcp_size: int, ready_event: threading.Event, + name: str): super().__init__(daemon=True, name=name) self.m_store = m_store self.ready_event = ready_event self.tp_rank = tp_rank + self.dcp_size = dcp_size self.token_database = token_database self.done_task_lock = threading.Lock() self.request_queue: queue.Queue[Any] = queue.Queue() @@ -87,10 +89,12 @@ def _handle_request(self, req_meta: dict[str, Any]): class KVCacheStoreSendingThread(KVTransferThread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, put_step: int, ready_event: threading.Event): + tp_rank: int, dcp_size: int, put_step: int, + ready_event: threading.Event): super().__init__(m_store, token_database, tp_rank, + dcp_size, ready_event, name="KVCacheSendingThread") self.put_step = put_step @@ -112,12 +116,16 @@ def _handle_request(self, req_meta: dict[str, Any]): key_list.append(key.to_string()) addr_list.append(addr) size_list.append(size) - key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] - addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] - size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] - if key_list_tp: - torch.npu.current_stream().synchronize() - self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) + if self.dcp_size > 1: + self.m_store.put(key_list, addr_list, size_list) + else: + key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] + addr_list_tp = addr_list[self.tp_rank % + self.put_step::self.put_step] + size_list_tp = size_list[self.tp_rank % + self.put_step::self.put_step] + if key_list_tp: + self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) if is_last_chunk: self.set_finished_request(req_id) self.request_queue.task_done() @@ -126,10 +134,11 @@ def _handle_request(self, req_meta: dict[str, Any]): class KVCacheStoreRecvingThread(KVTransferThread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, ready_event: threading.Event): + tp_rank: int, dcp_size: int, ready_event: threading.Event): super().__init__(m_store, token_database, tp_rank, + dcp_size, ready_event, name="KVCacheStoreRecvingThread") @@ -166,11 +175,12 @@ def _handle_request(self, req_meta: dict[str, Any]): class KVCacheStoreLayerSendingThread(KVTransferThread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, put_step: int, ready_event: threading.Event, - num_layers: int): + tp_rank: int, dcp_size: int, put_step: int, + ready_event: threading.Event, num_layers: int): super().__init__(m_store, token_database, tp_rank, + dcp_size, ready_event, name="KVCacheStoreLayerSendingThread") self.final_layer_id = num_layers - 1 @@ -192,12 +202,16 @@ def _handle_request( # type: ignore[override] key_list.append(key.to_string()) addr_list.append(addr) size_list.append(size) - key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] - addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] - 
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] - if key_list_tp: - torch.npu.current_stream().synchronize() - self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) + if self.dcp_size > 1: + self.m_store.put(key_list, addr_list, size_list) + else: + key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] + addr_list_tp = addr_list[self.tp_rank % + self.put_step::self.put_step] + size_list_tp = size_list[self.tp_rank % + self.put_step::self.put_step] + if key_list_tp: + self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk: self.set_finished_request(req_meta.req_id) self.request_queue.task_done() @@ -206,11 +220,12 @@ def _handle_request( # type: ignore[override] class KVCacheStoreLayerRecvingThread(KVTransferThread): def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - tp_rank: int, ready_event: threading.Event, + tp_rank: int, dcp_size: int, ready_event: threading.Event, get_event: threading.Event): super().__init__(m_store, token_database, tp_rank, + dcp_size, ready_event, name="KVCacheStoreLayerRecvingThread") self.get_event = get_event diff --git a/vllm_ascend/distributed/kvpool/pool_scheduler.py b/vllm_ascend/distributed/kvpool/pool_scheduler.py index 06041b5a6e5..e4274becf07 100644 --- a/vllm_ascend/distributed/kvpool/pool_scheduler.py +++ b/vllm_ascend/distributed/kvpool/pool_scheduler.py @@ -29,7 +29,16 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise): "load_async", False) # request_id -> (vllm cached tokes, kvpool cached tokens) self.load_specs: dict[str, LoadSpec] = {} + self.pcp_size = getattr(vllm_config.parallel_config, + "prefill_context_parallel_size", 1) + self.dcp_size = getattr(vllm_config.parallel_config, + "decode_context_parallel_size", 1) + self._block_size = vllm_config.cache_config.block_size + if self.pcp_size > 1: + self._block_size *= self.pcp_size + if self.dcp_size > 1: + self._block_size *= self.dcp_size # request_id -> full_token_ids self._request_trackers: dict[str, RequestTracker] = {} # Whether to discard partial chunks diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py index b03d2808928..25322c5f75d 100644 --- a/vllm_ascend/distributed/kvpool/pool_worker.py +++ b/vllm_ascend/distributed/kvpool/pool_worker.py @@ -1,11 +1,13 @@ -# Standard import math import threading from typing import Dict, Generator, Optional, Type -# Third Party import torch from vllm.config import VllmConfig +from vllm.distributed import (get_decode_context_model_parallel_rank, + get_decode_context_model_parallel_world_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.utils import logger from vllm.v1.core.kv_cache_utils import BlockHash @@ -20,6 +22,14 @@ from vllm_ascend.distributed.kvpool.kv_transfer import ( KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) +from vllm_ascend.utils import prefill_context_parallel_enable + +if prefill_context_parallel_enable(): + # isort: off + from vllm.distributed import (get_prefill_context_model_parallel_rank, + get_prefill_context_model_parallel_world_size + ) + # isort: on backend_map: Dict[str, Type[Backend]] = { "mooncake": MooncakeBackend, @@ -44,17 +54,30 @@ def __init__( and model_config.use_mla): self.use_mla = True self.use_layerwise = use_layerwize - self.tp_rank = parallel_config.rank - self.tp_size = 
parallel_config.tensor_parallel_size + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + self.pcp_size = get_prefill_context_model_parallel_world_size( + ) if prefill_context_parallel_enable() else 1 + self.pcp_rank = get_prefill_context_model_parallel_rank( + ) if self.pcp_size > 1 else 0 + self.dcp_size = get_decode_context_model_parallel_world_size() + self.dcp_rank = get_decode_context_model_parallel_rank( + ) if self.dcp_size > 1 else 0 + self.kv_role = vllm_config.kv_transfer_config.kv_role self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "load_async", False) self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "backend", "mooncake") self.block_size = vllm_config.cache_config.block_size + + if self.pcp_size > 1: + self.block_size *= self.pcp_size + if self.dcp_size > 1: + self.block_size *= self.dcp_size self.current_layer = 0 self.num_layers = model_config.get_num_layers(parallel_config) - self.block_size = vllm_config.cache_config.block_size if self.use_mla: self.num_kv_head = 1 @@ -69,8 +92,10 @@ def __init__( self.put_step = 1 self.metadata = KeyMetadata( - model_config.model, + model_config.model.split('/')[-1], self.head_or_tp_rank, + self.pcp_rank, + self.dcp_rank, ) self.token_database = ChunkedTokenDatabase(self.metadata, @@ -147,12 +172,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ready_event_sending = threading.Event() self.kv_send_thread = KVCacheStoreLayerSendingThread( self.m_store, self.token_database, self.tp_rank, - self.put_step, ready_event_sending, self.num_layers) + self.dcp_size, self.put_step, ready_event_sending, + self.num_layers) self.kv_send_thread.start() ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreLayerRecvingThread( - self.m_store, self.token_database, self.tp_rank, ready_event, - self.get_event) + self.m_store, self.token_database, self.tp_rank, self.dcp_size, + ready_event, self.get_event) self.kv_recv_thread.start() ready_event.wait() else: @@ -160,13 +186,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ready_event_sending = threading.Event() self.kv_send_thread = KVCacheStoreSendingThread( self.m_store, self.token_database, self.tp_rank, - self.put_step, ready_event_sending) + self.dcp_size, self.put_step, ready_event_sending) self.kv_send_thread.start() if self.load_async: ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreRecvingThread( self.m_store, self.token_database, self.tp_rank, - ready_event) + self.dcp_size, ready_event) self.kv_recv_thread.start() ready_event.wait() diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 5c5a0a5bef3..61f5d7a1164 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -107,6 +107,7 @@ def __init__(self, kv_cache_config: Optional[KVCacheConfig] = None): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id + self._connector_metadata = LLMDataDistCMgrConnectorMetadata() if role == KVConnectorRole.SCHEDULER: self.connector_scheduler: Optional[ LLMDataDistCMgrConnectorScheduler] = LLMDataDistCMgrConnectorScheduler( diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 9b5dde0fee1..00de0627b4c 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ 
b/vllm_ascend/distributed/parallel_state.py @@ -3,7 +3,8 @@ import torch from vllm.config import ParallelConfig, get_current_vllm_config from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group, - get_tp_group, get_world_group, + get_pp_group, get_tp_group, + get_world_group, init_model_parallel_group) import vllm_ascend.envs as envs_ascend @@ -185,6 +186,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): ).flashcomm2_oproj_tensor_parallel_size global_tp_size = get_tp_group().world_size global_dp_size = get_dp_group().world_size + global_pp_size = get_pp_group().world_size num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size // flashcomm2_otp_size) @@ -197,18 +199,27 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): if flashcomm2_otp_size > 1: otp_group_ranks = [] odp_group_ranks: list[list[int]] = [ - [] for _ in range(flashcomm2_otp_size * global_dp_size) + [] for _ in range(flashcomm2_otp_size * global_dp_size * + global_pp_size) ] - for dp_group_index in range(global_dp_size): - for i in range(num_fc2_oproj_tensor_parallel_groups): - ranks = [] - for j in range(flashcomm2_otp_size): - rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups - ranks.append(rank_idx) - odp_group_index = dp_group_index * flashcomm2_otp_size + j - odp_group_ranks[odp_group_index].append(rank_idx) - otp_group_ranks.append(ranks) + for pp_group_index in range(global_pp_size): + dp_pp_serial_index = dp_group_index * global_pp_size + pp_group_index + tp_base_rank = dp_pp_serial_index * global_tp_size + odp_base_index = dp_pp_serial_index * flashcomm2_otp_size + + for i in range(num_fc2_oproj_tensor_parallel_groups): + ranks = [] + for j in range(flashcomm2_otp_size): + tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups + assert tp_local_rank < global_tp_size + global_rank = tp_base_rank + tp_local_rank + ranks.append(global_rank) + + odp_group_index = odp_base_index + j + odp_group_ranks[odp_group_index].append( + global_rank) + otp_group_ranks.append(ranks) _FLASHCOMM2_OTP = init_model_parallel_group( otp_group_ranks, diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py index 726763013f4..47a99d1b10d 100644 --- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -44,11 +44,22 @@ def __init__(self, model, **args): self.init_redundancy_expert = get_ascend_config( ).init_redundancy_expert + for i in range(self.num_dense_layers, + self.model.config.num_hidden_layers): + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = \ + self.model.model.layers[i].mlp.experts.w13_weight_list + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = \ + self.model.model.layers[i].mlp.experts.w2_weight_list + self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = \ + self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list + self.param_dict["model.layers." + str(i) + ".mlp.experts." 
+ "w2_weight_scale_list"] = \ + self.model.model.layers[i].mlp.experts.w2_weight_scale_list # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here if self.model.quant_config is not None: self.expert_weight_names = [ - "w13_weight", "w2_weight", "w13_weight_scale", - "w13_weight_offset", "w2_weight_scale", "w2_weight_offset" + "w13_weight_list", "w2_weight_list", + "w13_weight_scale_fp32_list", "w13_weight_offset", + "w2_weight_scale_list", "w2_weight_offset" ] else: self.expert_weight_names = ["w13_weight", "w2_weight"] @@ -84,9 +95,14 @@ def init_buffer_tensor(self, num_buffer_tensor): for name in self.expert_weight_names: complete_name = "model.layers." + str( self.num_dense_layers) + ".mlp.experts." + name - expert_tensor = self.param_dict[complete_name].data[0] - if name in ["w13_weight", "w2_weight"]: + if name in [ + "w13_weight_list", "w2_weight_list", + "w13_weight_scale_fp32_list", "w2_weight_scale_list" + ]: + expert_tensor = self.param_dict[complete_name][0] expert_tensor = expert_tensor.clone() + else: + expert_tensor = self.param_dict[complete_name][0].data[0] buffer_tensor = torch.empty_like(expert_tensor) self.buffer_tensor_list[buffer_id].append(buffer_tensor) @@ -97,12 +113,23 @@ def init_expert_param_per_layer(self): layer_idx = self.num_dense_layers + moe_layer_id self.expert_param_per_layer[layer_idx] = list() for local_expert_id in range(num_local_expert): - self.expert_param_per_layer[layer_idx].append([ - self.param_dict["model.layers." + str(layer_idx) + - ".mlp.experts." + - name].data[local_expert_id] - for name in self.expert_weight_names - ]) + per_expert_param = list() + for name in self.expert_weight_names: + if name in [ + "w13_weight_list", "w2_weight_list", + "w13_weight_scale_fp32_list", + "w2_weight_scale_list" + ]: + per_expert_param.append( + self.param_dict["model.layers." + str(layer_idx) + + ".mlp.experts." + + name][local_expert_id]) + else: + per_expert_param.append( + self.param_dict["model.layers." + str(layer_idx) + + ".mlp.experts." 
+ + name][0].data[local_expert_id]) + self.expert_param_per_layer[layer_idx].append(per_expert_param) def get_rank_expert_workload(self) -> torch.Tensor: self.moe_load = self.model.get_all_moe_loads() @@ -194,15 +221,15 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str): json.dump(record, f, indent=4) def do_update_expert_map(self, layer_id, updated_expert_map): - self.expert_map_per_layer[layer_id] = updated_expert_map.clone() - self.expert_map_per_layer_cpu[layer_id] = updated_expert_map.clone() + self.expert_map_per_layer[layer_id].copy_(updated_expert_map) + self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map) def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id): for expert_tensor, buffer_tensor in zip( self.expert_param_per_layer[layer_id][local_expert_to_replace], self.buffer_tensor_list[buffer_tensor_id]): - expert_tensor = buffer_tensor.clone() + expert_tensor.copy_(buffer_tensor) logger.debug(f"Expert tensor shape is :{expert_tensor.shape}") def do_update_log2phy_map(self, layer_id, updated_log2phy_map): diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 31eae8d7cbe..b1957fe8f04 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -2,28 +2,9 @@ def register_model(): - ModelRegistry.register_model( - "Qwen2VLForConditionalGeneration", - "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") - - ModelRegistry.register_model( - "Qwen3VLMoeForConditionalGeneration", - "vllm_ascend.models.qwen3_vl:AscendQwen3VLMoeForConditionalGeneration") - - ModelRegistry.register_model( - "Qwen3VLForConditionalGeneration", - "vllm_ascend.models.qwen3_vl:AscendQwen3VLForConditionalGeneration") - # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. ModelRegistry.register_model( "PanguProMoEForCausalLM", "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM" ) - - ModelRegistry.register_model( - "Qwen3NextForCausalLM", - "vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM") - - ModelRegistry.register_model( - "Qwen3NextMTP", "vllm_ascend.models.qwen3_next_mtp:CustomQwen3NextMTP") diff --git a/vllm_ascend/models/layers/__init__.py b/vllm_ascend/models/layers/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py deleted file mode 100644 index f24f9823648..00000000000 --- a/vllm_ascend/models/qwen2_vl.py +++ /dev/null @@ -1,373 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from vllm/model_executor/models/qwen2_vl.py -# This file is a part of the vllm-ascend project. 
- -from collections.abc import Iterable -from functools import partial -from typing import Callable, Optional, Set, Tuple, Type - -import torch -import torch.nn as nn -import torch_npu -from einops import rearrange -from transformers.models.qwen2_vl.configuration_qwen2_vl import \ - Qwen2VLVisionConfig -from vllm.config import VllmConfig -from vllm.distributed import utils as dist_utils -from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.qwen2_vl import ( - Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionPatchEmbed, - Qwen2VisionTransformer, Qwen2VLDummyInputsBuilder, - Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor, - Qwen2VLProcessingInfo) -from vllm.model_executor.models.utils import maybe_prefix -from vllm.model_executor.models.vision import conv3d_to_linear_weight -from vllm.multimodal import MULTIMODAL_REGISTRY - -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, is_enable_nz - -MIN_PAD_SIZE = 64 # min_size to pad weight -MAX_PAD_SIZE = 128 # max_size to pad weight - - -class AscendQwen2VisionAttention(Qwen2VisionAttention): - - def __init__( - self, - embed_dim: int, - num_heads: int, - projection_size: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__( - embed_dim, - num_heads, - projection_size, - quant_config, - prefix, - ) - self.cu_seqlens = None - self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) - self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head - if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: - self.hidden_size_per_attention_head = MAX_PAD_SIZE - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - - self.cu_seqlens = cu_seqlens - - # [s, b, c] --> [s, b, 3 * head * head_dim] - x, _ = self.qkv(x) - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] - q, k, v = self.split_qkv(x) - batch_size = q.shape[1] - - q, k, v = [ - rearrange(x, "s b ... 
-> b s ...").contiguous() for x in (q, k, v) - ] - q = torch_npu.npu_rotary_mul(q, cos, sin) - k = torch_npu.npu_rotary_mul(k, cos, sin) - q, k, v = [ - rearrange(x, "b s h d -> (b s) h d").contiguous() - for x in (q, k, v) - ] - - context_layer = torch.empty_like(q) - - # operator requires pta version >= 2.5.1 - torch_npu._npu_flash_attention_unpad( - query=q, - key=k, - value=v, - seq_len=self.cu_seqlens, - scale_value=self.origin_hidden_size_per_attention_head**-0.5, - num_heads=self.num_attention_heads_per_partition, - num_kv_heads=self.num_attention_heads_per_partition, - out=context_layer) - context_layer = rearrange(context_layer, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() - - output, _ = self.proj(context_layer) - return output - - -class AscendQwen2VisionBlock(Qwen2VisionBlock): - - def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float, - act_layer: Type[nn.Module] = QuickGELU, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(dim, num_heads, mlp_ratio, act_layer, norm_layer, - quant_config, prefix) - self.attn = AscendQwen2VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), - cu_seqlens=cu_seqlens, - cos=cos, - sin=sin, - ) - - x = x + self.mlp(self.norm2(x)) - return x - - -class AscendQwen2VisionPatchEmbed(Qwen2VisionPatchEmbed): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.matmul( - self.proj.weight.data.view(self.embed_dim, -1).transpose(0, 1)) - return x - - -class AscendQwen2VisionTransformer(Qwen2VisionTransformer): - - def __init__( - self, - vision_config: Qwen2VLVisionConfig, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - interleaved=False, - ) -> None: - super().__init__(vision_config, norm_eps, quant_config, prefix) - - self.interleaved = interleaved - self.enable_pad = False - self.depth = vision_config.depth - self.hidden_size = vision_config.embed_dim - self.num_heads = vision_config.num_heads - self.patch_embed = AscendQwen2VisionPatchEmbed( - patch_size=vision_config.patch_size, - temporal_patch_size=vision_config.temporal_patch_size, - in_channels=vision_config.in_channels, - embed_dim=vision_config.embed_dim, - ) - - self.blocks = nn.ModuleList([ - AscendQwen2VisionBlock(dim=self.embed_dim, - num_heads=self.num_heads, - mlp_ratio=vision_config.mlp_ratio, - norm_layer=partial(nn.LayerNorm, - eps=norm_eps), - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(vision_config.depth) - ]) - - self.hidden_size_per_attention_head = dist_utils.divide( - self.hidden_size, self.num_heads) - - if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: - self.enable_pad = True - self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head - self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2 - self.half_pad_hidden_size_per_attention_head = ( - MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2 - self.hidden_size_per_attention_head = MAX_PAD_SIZE - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = 
rotary_pos_emb.sin() - if self.enable_pad: - cos = torch.nn.functional.pad( - cos, (0, self.half_pad_hidden_size_per_attention_head)) - sin = torch.nn.functional.pad( - sin, (0, self.half_pad_hidden_size_per_attention_head)) - - if not self.interleaved: - cos_new = torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - else: - cos_new = rearrange(torch.stack((cos, cos), dim=-1), - "... d two -> ...(d two)", - two=2) - sin_new = rearrange(torch.stack((sin, sin), dim=-1), - "... d two -> ...(d two)", - two=2) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def pad_qkv_bias(self, bias): - first_half = bias.reshape( - -1, 3, self.origin_hidden_size_per_attention_head - )[:, :, :self.half_origin_hidden_size_per_attention_head] - second_half = bias.reshape( - -1, 3, self.origin_hidden_size_per_attention_head - )[:, :, self.half_origin_hidden_size_per_attention_head:] - first_half_padded = torch.nn.functional.pad( - first_half, (0, self.half_pad_hidden_size_per_attention_head)) - second_half_padded = torch.nn.functional.pad( - second_half, (0, self.half_pad_hidden_size_per_attention_head)) - bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2) - bias_final = bias_padded.reshape(-1) - return bias_final - - def pad_qkv_weight(self, data): - qkv_weight_first_half = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size - )[:, :, :self.half_origin_hidden_size_per_attention_head, :] - qkv_weight_second_half = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size - )[:, :, self.half_origin_hidden_size_per_attention_head:, :] - - qkv_weight_first_half_padded = torch.nn.functional.pad( - qkv_weight_first_half, - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) - qkv_weight_second_half_padded = torch.nn.functional.pad( - qkv_weight_second_half, - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) - qkv_weight_padded = torch.cat( - [qkv_weight_first_half_padded, qkv_weight_second_half_padded], - dim=2) - qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size) - - if is_enable_nz(): - qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_( - qkv_weight_final) - qkv_weight_final_copy = torch_npu.npu_format_cast( - qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND) - return qkv_weight_final_copy - - return qkv_weight_final - - def pad_proj_weight(self, data): - out_weight = torch.nn.functional.pad( - data.reshape(self.hidden_size, -1, - self.half_origin_hidden_size_per_attention_head), - (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape( - self.hidden_size, -1) - - if is_enable_nz(): - out_weight_copy = torch.empty_like(out_weight).copy_(out_weight) - out_weight_copy = torch_npu.npu_format_cast( - out_weight_copy, ACL_FORMAT_FRACTAL_ND) - return out_weight_copy - - return out_weight - - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - - for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - - for (param_name, weight_name, shard_id) in 
stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - if ("attn.proj.weight" in name) and self.enable_pad: - param.data = self.pad_proj_weight(param.data) - if ("attn.qkv.weight" in name) and self.enable_pad: - param.data = self.pad_qkv_weight(param.data) - if ("attn.qkv.bias" in name) and self.enable_pad: - param.data = self.pad_qkv_bias(param.data) - loaded_params.add(name) - return loaded_params - - def forward( - self, - x: torch.Tensor, - grid_thw: torch.Tensor, - ) -> torch.Tensor: - grid_thw = torch.tensor(grid_thw, dtype=torch.int32) - # compute cu_seqlens and avoid cumsum to fit operator unpadFA - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, - 0]).cpu().to(torch.int32) - - # patchify - x = x.to(device=self.device, dtype=self.dtype) - x = self.patch_embed(x) - - # compute position embedding - rotary_pos_emb = self.rot_pos_emb(grid_thw) - cos, sin = self.cal_cos_sin(rotary_pos_emb) - - x = x.unsqueeze(1) - for blk in self.blocks: - x = blk(x, cu_seqlens=cu_seqlens, cos=cos, sin=sin) - - # adapter - x = self.merger(x) - return x - - -@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, - info=Qwen2VLProcessingInfo, - dummy_inputs=Qwen2VLDummyInputsBuilder) -class AscendQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - self.visual = AscendQwen2VisionTransformer( - self.config.vision_config, - norm_eps=getattr(self.config, "rms_norm_eps", 1e-6), - quant_config=vllm_config.quant_config, - prefix=maybe_prefix(prefix, "visual"), - ) diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py deleted file mode 100644 index b1d7b5444a9..00000000000 --- a/vllm_ascend/models/qwen3_next.py +++ /dev/null @@ -1,981 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# mypy: ignore-errors -"""Inference-only Qwen3Next model.""" -from collections.abc import Iterable -from typing import Optional - -import torch -from einops import rearrange -from torch import nn -from transformers.activations import ACT2FN -from vllm.attention import AttentionBackend, AttentionMetadata -from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, - VllmConfig, get_current_vllm_config) -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fla.ops import chunk -from vllm.model_executor.layers.fla.ops.fused_recurrent import \ - fused_recurrent_gated_delta_rule -from vllm.model_executor.layers.fused_moe import FusedMoE -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.layers.layernorm import \ - GemmaRMSNorm as Qwen3NextRMSNorm -from vllm.model_executor.layers.layernorm import RMSNormGated -# yapf: enable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) -from 
vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba_mixer2 import \ - mamba_v2_sharded_weight_loader -from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) -from vllm.model_executor.layers.mamba.ops import causal_conv1d -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP -from vllm.model_executor.models.utils import ( - PPMissingLayer, extract_layer_index, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.utils import set_weight_attrs -from vllm.transformers_utils.configs import Qwen3NextConfig -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata - -from vllm.model_executor.models.qwen3_next import ( # isort: skip - Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM, - Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock, - fused_gdn_gating) - - -class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase): - - @property - def mamba_type(self) -> str: - return "linear_attention" - - def get_attn_backend(self) -> type["AttentionBackend"]: - from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend - return GDNAttentionBackend - - def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - self.model_config.dtype, self.cache_config.mamba_cache_dtype) - - def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: - return MambaStateShapeCalculator.gated_delta_net_state_shape( - self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim, - self.head_v_dim, self.conv_kernel_size, self.num_spec) - - def __init__( - self, - config: Qwen3NextConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.hidden_size = config.hidden_size - self.num_v_heads = config.linear_num_value_heads - self.num_k_heads = config.linear_num_key_heads - self.head_k_dim = config.linear_key_head_dim - self.head_v_dim = config.linear_value_head_dim - self.key_dim = self.head_k_dim * self.num_k_heads - self.value_dim = self.head_v_dim * self.num_v_heads - - self.conv_kernel_size = config.linear_conv_kernel_dim - self.layer_idx = extract_layer_index(prefix) - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - self.layer_norm_epsilon = config.rms_norm_eps - self.prefix = prefix - - self.config = config - self.model_config = model_config - self.cache_config = cache_config - self.quant_config = quant_config - self.speculative_config = speculative_config - self.num_spec = (self.speculative_config.num_speculative_tokens - if self.speculative_config else 0) - - # QKV - self.conv_dim = self.key_dim * 2 + self.value_dim - self.conv1d = 
ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.conv_dim, - bias=False, - prefix=f"{prefix}.conv1d", - ) - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - # projection of the input hidden states - self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 - self.projection_size_ba = self.num_v_heads * 2 - self.in_proj = MergedColumnParallelLinear( - input_size=self.hidden_size, - output_sizes=[self.projection_size_qkvz, self.projection_size_ba], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj", - ) - - query_key_settings = (self.key_dim, 0, False) - value_settings = (self.value_dim, 0, False) - - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, { - "weight_loader": - mamba_v2_sharded_weight_loader([ - query_key_settings, - query_key_settings, - value_settings, - ], self.tp_size, self.tp_rank) - }) - - # selective projection used to make dt, B and C input dependent - - # time step projection (discretization) - # instantiate once and copy inv_dt in init_weights of PretrainedModel - self.dt_bias = nn.Parameter( - torch.ones(self.num_v_heads // self.tp_size), ) - self.A_log = nn.Parameter( - torch.empty( - divide(self.num_v_heads, self.tp_size), - dtype=torch.float32, - )) - - set_weight_attrs(self.A_log, - {"weight_loader": sharded_weight_loader(0)}) - set_weight_attrs(self.dt_bias, - {"weight_loader": sharded_weight_loader(0)}) - - self.norm = RMSNormGated( - self.head_v_dim, - eps=self.layer_norm_epsilon, - norm_before_gate=True, - device="npu", - ) - - self.out_proj = RowParallelLinear(self.value_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj") - - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - - def forward( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - ): - """ - Forward pass with three parts: - 1. Input projection - 2. Core attention (custom op) - 3. Output projection - """ - num_tokens = hidden_states.size(0) - - # ============================================================ - # Part 1: Input Projection - # ============================================================ - - forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata - - if attn_metadata is None: - # V1 profile run - return - - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, GDNAttentionMetadata) - - num_actual_tokens = (attn_metadata.num_prefill_tokens + - attn_metadata.num_decode_tokens + - attn_metadata.num_spec_decode_tokens) - - # 1. 
Set up dimensions for reshapes later - projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) - projected_states_qkvz, projected_states_ba = torch.split( - projected_states, - [ - self.projection_size_qkvz // self.tp_size, - self.projection_size_ba // self.tp_size - ], - dim=-1, - ) - query, key, value, z, b, a = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba) - query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), - (query, key, value)) - mixed_qkv = torch.cat((query, key, value), dim=-1) - - # ============================================================ - # Part 2: Core Attention (Custom Op) - # ============================================================ - core_attn_out = torch.zeros( - (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - torch.ops.vllm.gdn_attention_core( - mixed_qkv, - b, - a, - core_attn_out, - self.prefix, - ) - - # ============================================================ - # Part 3: Output Projection - # ============================================================ - z_shape_og = z.shape - # Reshape input data into 2D tensor - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") - output[:num_tokens], _ = self.out_proj(core_attn_out) - - def _forward( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - ): - forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata - - if attn_metadata is None: - # V1 profile run - return - - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, GDNAttentionMetadata) - has_initial_state = attn_metadata.has_initial_state - spec_query_start_loc = attn_metadata.spec_query_start_loc - non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc - spec_sequence_masks = attn_metadata.spec_sequence_masks - spec_token_indx = attn_metadata.spec_token_indx - non_spec_token_indx = attn_metadata.non_spec_token_indx - spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 - non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - - num_actual_tokens = (attn_metadata.num_prefill_tokens + - attn_metadata.num_decode_tokens + - attn_metadata.num_spec_decode_tokens) - num_accepted_tokens = attn_metadata.num_accepted_tokens - - # 1. Set up dimensions for reshapes later - projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens]) - projected_states_qkvz, projected_states_ba = torch.split( - projected_states, - [ - self.projection_size_qkvz // self.tp_size, - self.projection_size_ba // self.tp_size - ], - dim=-1, - ) - query, key, value, z, b, a = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba) - query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), - (query, key, value)) - mixed_qkv = torch.cat((query, key, value), dim=-1) - - # 2. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - - if spec_sequence_masks is not None: - if (attn_metadata.num_prefills == 0 - and attn_metadata.num_decodes == 0): - mixed_qkv_spec = mixed_qkv - mixed_qkv_non_spec = None - else: - mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) - mixed_qkv_non_spec = mixed_qkv.index_select( - 0, non_spec_token_indx) - else: - mixed_qkv_spec = None - mixed_qkv_non_spec = mixed_qkv - - # 2.1: process the mutli-query part - if spec_sequence_masks is not None: - mixed_qkv_spec = mixed_qkv_spec.view( - attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1)) - mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l') - mixed_qkv_spec = causal_conv1d.causal_conv1d_update( - mixed_qkv_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=spec_state_indices_tensor[:, 0] - [:attn_metadata.num_spec_decodes], - num_accepted_tokens=num_accepted_tokens, - validate_data=False, - ) - mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d') - - # 2.2: process the remaining part - if attn_metadata.num_prefills > 0: - # - "cache_indices" updates the conv_state cache in positions - # pointed to by "mamba_cache_params.state_indices_tensor" - mixed_qkv_non_spec = causal_conv1d.causal_conv1d_fn( - mixed_qkv_non_spec.transpose(0, 1), - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=conv_state, - has_initial_state=has_initial_state, - cache_indices=non_spec_state_indices_tensor, - query_start_loc=non_spec_query_start_loc, - ).transpose(0, 1) - elif attn_metadata.num_decodes > 0: - mixed_qkv_non_spec = causal_conv1d.causal_conv1d_update( - mixed_qkv_non_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=non_spec_state_indices_tensor[:attn_metadata - .num_decodes], - # validate_data=True, - ) - else: - mixed_qkv_non_spec = None - - query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( - mixed_qkv_spec) - query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( - mixed_qkv_non_spec) - - beta = b.sigmoid() - g = fused_gdn_gating(self.A_log, a, self.dt_bias) - g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta)) - - if spec_sequence_masks is not None: - if (attn_metadata.num_prefills == 0 - and attn_metadata.num_decodes == 0): - g_spec = g - beta_spec = beta - g_non_spec = None - beta_non_spec = None - else: - g_spec = g.index_select(1, spec_token_indx) - beta_spec = beta.index_select(1, spec_token_indx) - g_non_spec = g.index_select(1, non_spec_token_indx) - beta_non_spec = beta.index_select(1, non_spec_token_indx) - else: - g_spec = None - beta_spec = None - g_non_spec = g - beta_non_spec = beta - - # 3. Recurrent attention - # 3.1: process the mutlti-query part - if spec_sequence_masks is not None: - core_attn_out_spec, last_recurrent_state = ( - fused_recurrent_gated_delta_rule( - q=query_spec, - k=key_spec, - v=value_spec, - g=g_spec, - beta=beta_spec, - initial_state=ssm_state, - inplace_final_state=True, - cu_seqlens=spec_query_start_loc[:attn_metadata. 
- num_spec_decodes + 1], - ssm_state_indices=spec_state_indices_tensor, - num_accepted_tokens=num_accepted_tokens, - use_qk_l2norm_in_kernel=True, - )) - else: - core_attn_out_spec, last_recurrent_state = None, None - - # 3.2: process the remaining part - if attn_metadata.num_prefills > 0: - initial_state = ssm_state[ - non_spec_state_indices_tensor].contiguous() - initial_state[~has_initial_state, ...] = 0 - - ( - core_attn_out_non_spec, - last_recurrent_state, - ) = chunk.chunk_gated_delta_rule( - q=query_non_spec, - k=key_non_spec, - v=value_non_spec, - g=g_non_spec, - beta=beta_non_spec, - initial_state=initial_state, - output_final_state=True, - cu_seqlens=non_spec_query_start_loc, - head_first=False, - use_qk_l2norm_in_kernel=True) - - # Init cache - ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( - ssm_state.dtype) - elif attn_metadata.num_decodes > 0: - core_attn_out_non_spec, last_recurrent_state = ( - fused_recurrent_gated_delta_rule( - q=query_non_spec, - k=key_non_spec, - v=value_non_spec, - g=g_non_spec, - beta=beta_non_spec, - initial_state=ssm_state, - inplace_final_state=True, - cu_seqlens=non_spec_query_start_loc[:attn_metadata. - num_decodes + 1], - ssm_state_indices=non_spec_state_indices_tensor, - use_qk_l2norm_in_kernel=True, - )) - else: - core_attn_out_non_spec, last_recurrent_state = None, None - - # Merge core attention output - if (spec_sequence_masks is not None - and core_attn_out_non_spec is not None): - core_attn_out = torch.empty( - (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), - dtype=core_attn_out_non_spec.dtype, - device=core_attn_out_non_spec.device, - ) - core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec) - core_attn_out.index_copy_(1, non_spec_token_indx, - core_attn_out_non_spec) - elif spec_sequence_masks is not None: - core_attn_out = core_attn_out_spec - else: - core_attn_out = core_attn_out_non_spec - - z_shape_og = z.shape - # reshape input data into 2D tensor - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)') - - output[:num_actual_tokens], _ = self.out_proj(core_attn_out) - - def _forward_core( - self, - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, - core_attn_out: torch.Tensor, - ): - """ - Core attention computation (called by custom op). 
- """ - forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata - - if attn_metadata is None: - # V1 profile run - return - - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, GDNAttentionMetadata) - has_initial_state = attn_metadata.has_initial_state - spec_query_start_loc = attn_metadata.spec_query_start_loc - non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc - spec_sequence_masks = attn_metadata.spec_sequence_masks - spec_token_indx = attn_metadata.spec_token_indx - non_spec_token_indx = attn_metadata.non_spec_token_indx - spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 - non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - - num_actual_tokens = (attn_metadata.num_prefill_tokens + - attn_metadata.num_decode_tokens + - attn_metadata.num_spec_decode_tokens) - num_accepted_tokens = attn_metadata.num_accepted_tokens - - mixed_qkv = mixed_qkv[:num_actual_tokens] - b = b[:num_actual_tokens] - a = a[:num_actual_tokens] - - # 1. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - - if spec_sequence_masks is not None: - if (attn_metadata.num_prefills == 0 - and attn_metadata.num_decodes == 0): - mixed_qkv_spec = mixed_qkv - mixed_qkv_non_spec = None - else: - mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) - mixed_qkv_non_spec = mixed_qkv.index_select( - 0, non_spec_token_indx) - else: - mixed_qkv_spec = None - mixed_qkv_non_spec = mixed_qkv - - # 1.1: Process the multi-query part - if spec_sequence_masks is not None: - mixed_qkv_spec = mixed_qkv_spec.view( - attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1)) - mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l') - mixed_qkv_spec = causal_conv1d.causal_conv1d_update( - mixed_qkv_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=spec_state_indices_tensor[:, 0] - [:attn_metadata.num_spec_decodes], - num_accepted_tokens=num_accepted_tokens, - validate_data=False, - ) - mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d') - - # 1.2: Process the remaining part - if attn_metadata.num_prefills > 0: - # - "cache_indices" updates the conv_state cache in positions - # pointed to by "mamba_cache_params.state_indices_tensor" - mixed_qkv_non_spec = causal_conv1d.causal_conv1d_fn( - mixed_qkv_non_spec.transpose(0, 1), - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=conv_state, - has_initial_state=has_initial_state, - cache_indices=non_spec_state_indices_tensor, - query_start_loc=non_spec_query_start_loc, - ).transpose(0, 1) - elif attn_metadata.num_decodes > 0: - mixed_qkv_non_spec = causal_conv1d.causal_conv1d_update( - mixed_qkv_non_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=non_spec_state_indices_tensor[:attn_metadata - .num_decodes], - # validate_data=True, - ) - else: - mixed_qkv_non_spec = None - - query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( - mixed_qkv_spec) - query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( - mixed_qkv_non_spec) - - beta = b.sigmoid() - g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) - 
- if spec_sequence_masks is not None: - if (attn_metadata.num_prefills == 0 - and attn_metadata.num_decodes == 0): - g_spec = g - beta_spec = beta - g_non_spec = None - beta_non_spec = None - else: - g_spec = g.index_select(1, spec_token_indx) - beta_spec = beta.index_select(1, spec_token_indx) - g_non_spec = g.index_select(1, non_spec_token_indx) - beta_non_spec = beta.index_select(1, non_spec_token_indx) - else: - g_spec = None - beta_spec = None - g_non_spec = g - beta_non_spec = beta - - # 2. Recurrent attention - - # 2.1: Process the multi-query part - if spec_sequence_masks is not None: - core_attn_out_spec, last_recurrent_state = ( - fused_recurrent_gated_delta_rule( - q=query_spec, - k=key_spec, - v=value_spec, - g=g_spec, - beta=beta_spec, - initial_state=ssm_state, - inplace_final_state=True, - cu_seqlens=spec_query_start_loc[:attn_metadata. - num_spec_decodes + 1], - ssm_state_indices=spec_state_indices_tensor, - num_accepted_tokens=num_accepted_tokens, - use_qk_l2norm_in_kernel=True, - )) - else: - core_attn_out_spec, last_recurrent_state = None, None - - # 3.2: process the remaining part - if attn_metadata.num_prefills > 0: - initial_state = ssm_state[ - non_spec_state_indices_tensor].contiguous() - initial_state[~has_initial_state, ...] = 0 - - batch_size = initial_state.shape[0] - temp_core_attn_out = [] - last_recurrent_state = [] - - for b_idx in range(batch_size): - start, end = non_spec_query_start_loc[ - b_idx], non_spec_query_start_loc[b_idx + 1] - cur_q = query_non_spec[:, start:end, ...] - cur_k = key_non_spec[:, start:end, ...] - cur_v = value_non_spec[:, start:end, ...] - cur_g = g_non_spec[:, start:end, ...] - cur_b = beta_non_spec[:, start:end, ...] - cur_state = initial_state[b_idx].unsqueeze(0) - - ( - cur_core_attn_out_non_spec, - cur_last_recurrent_state, - ) = chunk.chunk_gated_delta_rule( - query=cur_q, - key=cur_k, - value=cur_v, - g=cur_g, - beta=cur_b, - initial_state=cur_state, - output_final_state=True, - use_qk_l2norm_in_kernel=True, - ) - - temp_core_attn_out.append(cur_core_attn_out_non_spec) - last_recurrent_state.append(cur_last_recurrent_state) - - tar_dtype = temp_core_attn_out[0].dtype - tar_device = temp_core_attn_out[0].device - tar_shape = list(temp_core_attn_out[0].shape) - tar_shape[1] = non_spec_query_start_loc[-1] - core_attn_out_non_spec = torch.empty(tar_shape, - dtype=tar_dtype, - device=tar_device) - for b_idx in range(batch_size): - cur_core_attn_out = temp_core_attn_out[b_idx] - start, end = non_spec_query_start_loc[ - b_idx], non_spec_query_start_loc[b_idx + 1] - core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out - last_recurrent_state = torch.cat(last_recurrent_state, dim=0) - - # Init cache - ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( - ssm_state.dtype) - elif attn_metadata.num_decodes > 0: - core_attn_out_non_spec, last_recurrent_state = ( - fused_recurrent_gated_delta_rule( - q=query_non_spec, - k=key_non_spec, - v=value_non_spec, - g=g_non_spec, - beta=beta_non_spec, - initial_state=ssm_state, - inplace_final_state=True, - cu_seqlens=non_spec_query_start_loc[:attn_metadata. - num_decodes + 1], - ssm_state_indices=non_spec_state_indices_tensor, - use_qk_l2norm_in_kernel=True, - )) - else: - core_attn_out_non_spec, last_recurrent_state = None, None - - # 3. 
Merge core attention output - if spec_sequence_masks is not None and core_attn_out_non_spec is not None: - merged_out = torch.empty( - (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), - dtype=core_attn_out_non_spec.dtype, - device=core_attn_out_non_spec.device, - ) - merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) - merged_out.index_copy_(1, non_spec_token_indx, - core_attn_out_non_spec) - core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) - elif spec_sequence_masks is not None: - core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) - else: - core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze( - 0) - - -class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer): - - def __init__( - self, - vllm_config: VllmConfig, - layer_type: str, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - speculative_config = vllm_config.speculative_config - - self.layer_type = layer_type - self.layer_idx = extract_layer_index(prefix) - - if self.layer_type == "linear_attention": - self.linear_attn = CustomQwen3NextGatedDeltaNet( - config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - speculative_config=speculative_config, - prefix=f'{prefix}.linear_attn') - elif self.layer_type == "full_attention": - self.self_attn = Qwen3NextAttention( - config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f'{prefix}.self_attn', - ) - else: - raise ValueError(f"Invalid layer_type {self.layer_type}") - - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) - if (self.layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 and - (self.layer_idx + 1) % config.decoder_sparse_step == 0): - self.mlp = Qwen3NextSparseMoeBlock(vllm_config=vllm_config, - prefix=f"{prefix}.mlp") - else: - self.mlp = Qwen3NextMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - ) - - self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen3NextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps) - - self.layer_scale = getattr(config, "layer_scale", False) - if self.layer_scale: - self.attn_layer_scale = torch.nn.Parameter( - torch.zeros( - 1, - 1, - config.hidden_size, - dtype=config.torch_dtype, - ), ) - self.ffn_layer_scale = torch.nn.Parameter( - torch.zeros( - 1, - 1, - config.hidden_size, - dtype=config.torch_dtype, - ), ) - - -@support_torch_compile -class CustomQwen3NextModel(Qwen3NextModel): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - config: Qwen3NextConfig = vllm_config.model_config.hf_config - parallel_config = vllm_config.parallel_config - lora_config = vllm_config.lora_config - eplb_config = parallel_config.eplb_config - self.num_redundant_experts = eplb_config.num_redundant_experts - - self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - def get_layer(prefix: str): - return 
CustomQwen3NextDecoderLayer( - vllm_config, - layer_type=config.layer_types[extract_layer_index(prefix)], - prefix=prefix, - ) - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - self.norm = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ("in_proj", "in_proj_qkvz", 0), - ("in_proj", "in_proj_ba", 1), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - expert_params_mapping = self.get_expert_mapping() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - if name.startswith("mtp."): - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - - if "mlp.experts" in name: - continue - - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # name = apply_attn_prefix(name, params_dict) - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config - scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Qwen3Next currently does not support prefix caching" - self.quant_config = vllm_config.quant_config - self.config = config - self.scheduler_config = scheduler_config - self.model = CustomQwen3NextModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - # Set MoE hyperparameters - self.expert_weights = [] - - self.moe_layers: list[FusedMoE] = [] - example_layer = None - for layer in self.model.layers: - if isinstance(layer, PPMissingLayer): - continue - - assert isinstance(layer, Qwen3NextDecoderLayer) - if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): - example_layer = layer.mlp - self.moe_layers.append(layer.mlp.experts) - - if example_layer is None: - raise RuntimeError("No Qwen3Next layer found in the model.layers.") - - self.num_moe_layers = len(self.moe_layers) - self.num_expert_groups = 1 - self.num_shared_experts = 0 - self.num_logical_experts = example_layer.n_logical_experts - self.num_physical_experts = example_layer.n_physical_experts - self.num_local_physical_experts = example_layer.n_local_physical_experts - self.num_routed_experts = example_layer.n_routed_experts - self.num_redundant_experts = example_layer.n_redundant_experts diff --git a/vllm_ascend/models/qwen3_next_mtp.py b/vllm_ascend/models/qwen3_next_mtp.py deleted file mode 100644 index c17d969cb29..00000000000 --- a/vllm_ascend/models/qwen3_next_mtp.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Inference-only Qwen3Next MTP model.""" -import torch -from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig -from vllm.model_executor.layers.linear import ColumnParallelLinear -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.models.interfaces import SupportsPP -from vllm.model_executor.models.qwen3_next_mtp import ( - Qwen3NextMTP, Qwen3NextMultiTokenPredictor) -from vllm.model_executor.models.utils import ( - 
make_empty_intermediate_tensors_factory, maybe_prefix) -from vllm.transformers_utils.configs import Qwen3NextConfig - -from vllm_ascend.models.qwen3_next import (CustomQwen3NextDecoderLayer, - Qwen3NextRMSNorm) - - -@support_torch_compile -class CustomQwen3NextMultiTokenPredictor(Qwen3NextMultiTokenPredictor): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super(Qwen3NextMultiTokenPredictor, self).__init__() - - model_config = vllm_config.model_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - config: Qwen3NextConfig = model_config.hf_config - - self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.mtp_start_layer_idx = config.num_hidden_layers - self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1) - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - self.fc = ColumnParallelLinear(self.config.hidden_size * 2, - self.config.hidden_size, - gather_output=True, - bias=False, - return_bias=False, - quant_config=quant_config, - prefix=f'{prefix}.fc') - - # use old version mtp layer name to avoid a exception in vllm - self.layers = torch.nn.ModuleList( - CustomQwen3NextDecoderLayer( - vllm_config, - layer_type="full_attention", - prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}', - ) for idx in range(self.num_mtp_layers)) - - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - self.norm = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - -@support_torch_compile -class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": ["up_proj", "down_proj"] - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - self.vllm_config = vllm_config - cache_config = vllm_config.cache_config - assert not cache_config.enable_prefix_caching, \ - "Qwen3NextMTP currently does not support prefix caching" - - self.quant_config = vllm_config.quant_config - - super(Qwen3NextMTP, self).__init__() - self.config = config - self.model = CustomQwen3NextMultiTokenPredictor( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self.unpadded_vocab_size = config.vocab_size - self.lm_head = ParallelLMHead(self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - prefix=maybe_prefix(prefix, "lm_head")) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) diff --git a/vllm_ascend/models/qwen3_vl.py b/vllm_ascend/models/qwen3_vl.py deleted file mode 100644 index c79e71e7197..00000000000 --- a/vllm_ascend/models/qwen3_vl.py +++ /dev/null @@ -1,264 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# This file is a part of the vllm-ascend project. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from typing import Callable, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - -try: - from transformers.models.qwen3_vl.configuration_qwen3_vl import \ - Qwen3VLConfig - from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \ - Qwen3VLMoeConfig -except ImportError: - pass -from vllm.config import VllmConfig -from vllm.distributed import utils as dist_utils -from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention - -try: - from vllm.model_executor.models.qwen3_vl import ( - Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer, - Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration, - Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) - from vllm.model_executor.models.qwen3_vl_moe import ( - Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo) -except ImportError: - Qwen3_VisionBlock = object - Qwen3_VisionPatchEmbed = object - Qwen3_VisionTransformer = object - Qwen3VLDummyInputsBuilder = object - Qwen3VLForConditionalGeneration = object - Qwen3VLMultiModalProcessor = object - Qwen3VLProcessingInfo = object - Qwen3VLMoeForConditionalGeneration = object - Qwen3VLMoeProcessingInfo = object -from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix -from vllm.multimodal import MULTIMODAL_REGISTRY - - -class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.matmul( - self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) - x = x + self.proj.bias - return x - - -class AscendQwen3_VisionBlock(Qwen3_VisionBlock): - - def __init__( - self, - dim: int, - num_heads: int, - mlp_hidden_dim: int, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False, - ) -> None: - super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, - quant_config, prefix, use_data_parallel) - self.attn = Qwen2_5_VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, - cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) - - x = x + self.mlp(self.norm2(x)) - return x - - -class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer): - - def __init__( - self, - vision_config, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False, - ) -> None: - super().__init__(vision_config, norm_eps, quant_config, prefix, - use_data_parallel) - norm_layer = 
partial(nn.LayerNorm, eps=norm_eps) - self.patch_embed = AscendQwen3_VisionPatchEmbed( - patch_size=self.patch_size, - temporal_patch_size=self.temporal_patch_size, - in_channels=vision_config.in_channels, - hidden_size=self.hidden_size, - ) - self.blocks = nn.ModuleList([ - AscendQwen3_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(vision_config.depth) - ]) - self.hidden_size_per_attention_head = dist_utils.divide( - self.hidden_size, self.num_heads) - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = rotary_pos_emb.sin() - cos_new = torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def forward( - self, - x: torch.Tensor, - grid_thw: list[list[int]], - ) -> torch.Tensor: - hidden_states = x.to(device=self.device, dtype=self.dtype) - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) - grid_thw_tensor = torch.tensor(grid_thw, - device=self.device, - dtype=torch.int32) - cu_seqlens = torch.repeat_interleave( - grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], - grid_thw_tensor[:, 0]).cpu().to(torch.int32) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - - hidden_states = hidden_states.unsqueeze(1) - rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) - - cos, sin = self.cal_cos_sin(rotary_pos_emb) - - deepstack_feature_lists = [] - for layer_num, blk in enumerate(self.blocks): - hidden_states = blk(hidden_states, - cu_seqlens=cu_seqlens, - cos=cos, - sin=sin) - if layer_num in self.deepstack_visual_indexes: - deepstack_merger_idx = self.deepstack_visual_indexes.index( - layer_num) - deepstack_feature = self.deepstack_merger_list[ - deepstack_merger_idx](hidden_states) - deepstack_feature_lists.append(deepstack_feature) - hidden_states = self.merger(hidden_states) - hidden_states = torch.cat( - [hidden_states] + deepstack_feature_lists, - dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] - return hidden_states - - -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) -class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - supports_encoder_tp_data = True - - # To ensure correct weight loading and mapping. 
- hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.visual.": "visual.", - "lm_head.": "language_model.lm_head.", - "model.language_model.": "language_model.model.", - }) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen3VLConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.visual = AscendQwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel) - - -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLMoeProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) -class AscendQwen3VLMoeForConditionalGeneration( - Qwen3VLMoeForConditionalGeneration): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - supports_encoder_tp_data = True - - # To ensure correct weight loading and mapping. - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.visual.": "visual.", - "lm_head.": "language_model.lm_head.", - "model.language_model.": "language_model.model.", - }) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config - self.multimodal_config = multimodal_config - self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.visual = AscendQwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - ) diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index b9667abbccb..945ea19743c 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -18,6 +18,7 @@ from typing import Any, Callable, Optional import torch +import torch.nn.functional as F import torch_npu from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, @@ -183,6 +184,9 @@ def __init__(self, *args, **kwargs): # init moe. self.local_num_experts, self.expert_map, _ = determine_expert_map( self.ep_size, self.ep_rank, self.global_num_experts) + # TODO: Temporary flag to indicate if static EPLB is enabled. This is a + # workaround to bypass a quantization check that fails with float weights. 
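Further down in this hunk, the new `forward` override pads `hidden_states` out to `self.hidden_size` before calling `torch.ops.vllm.moe_forward` and then trims the result back to the original width. A minimal standalone sketch of that pad-and-slice pattern, using hypothetical sizes and a stand-in for the fused op:

```python
import torch
import torch.nn.functional as F


def pad_and_slice(hidden_states: torch.Tensor, padded_hidden_size: int) -> torch.Tensor:
    # Pad the last dim up to the width the fused MoE path expects ...
    og_hidden_size = hidden_states.shape[-1]
    if padded_hidden_size != og_hidden_size:
        hidden_states = F.pad(hidden_states,
                              (0, padded_hidden_size - og_hidden_size),
                              mode="constant", value=0.0)
    fused_output = hidden_states * 2.0  # stand-in for torch.ops.vllm.moe_forward
    # ... then drop the padding so callers see the original hidden size.
    return fused_output[..., :og_hidden_size]


x = torch.randn(4, 1000)      # hypothetical hidden size 1000
out = pad_and_slice(x, 1024)  # padded to 1024 internally
assert out.shape == x.shape
```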
+ init_eplb_enable = False # static eplb initializing with expert_map_path if self.expert_map_path and os.path.exists( self.expert_map_path) and os.access(self.expert_map_path, @@ -199,6 +203,7 @@ def __init__(self, *args, **kwargs): self.moe_instance_id, self.ep_rank)) self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( self.moe_instance_id, self.ep_rank).npu() + init_eplb_enable = True except Exception as e: logger.warning( f"Init expert map of mtp/eagle when using sample.{e}") @@ -224,10 +229,10 @@ def __init__(self, *args, **kwargs): self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64).npu() - eplb_enable = self.dynamic_eplb or (self.expert_map_path is not None) - if eplb_enable and (not hasattr(self.quant_method, "quant_method") or - not isinstance(self.quant_method.quant_method, - AscendW8A8DynamicFusedMoEMethod)): + if init_eplb_enable and ( + not hasattr(self.quant_method, "quant_method") + or not isinstance(self.quant_method.quant_method, + AscendW8A8DynamicFusedMoEMethod)): raise ValueError("Eplb supports only w8a8_dynamic quantization.") self.moe_config.num_experts = self.global_num_experts @@ -275,7 +280,7 @@ def get_map(self): return self.expert_map def get_log2phy_map(self): - return self.logical_to_physical_map + return self.log2phy def clear_moe_load(self): if self.moe_load is not None: @@ -292,6 +297,32 @@ def maybe_all_reduce_tensor_model_parallel( return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel( final_hidden_states) + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + og_hidden_states = hidden_states.shape[-1] + if self.hidden_size != og_hidden_states: + hidden_states = F.pad( + hidden_states, + (0, self.hidden_size - og_hidden_states), + mode="constant", + value=0.0, + ) + if self.shared_experts is None: + fused_output = torch.ops.vllm.moe_forward(hidden_states, + router_logits, + self.layer_name) + return fused_output[..., :og_hidden_states] + else: + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name) + return ( + shared_output[..., :og_hidden_states], + fused_output[..., :og_hidden_states], + ) + def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None @@ -360,7 +391,7 @@ def forward_impl(self, hidden_states: torch.Tensor, def transpose_weight(self, loaded_weight, expert_data, shard_dim): # Ensure training and inference weight shapes match during RL weight updates - if ( + if (len(loaded_weight.shape) >= 2 and len(expert_data.shape) >= 2 and \ loaded_weight.shape[1] != expert_data.shape[1] and \ loaded_weight.shape[0] != expert_data.shape[0] ): diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index c48ce1a49be..802dbe5d66e 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -83,8 +83,8 @@ def finalize(self, def fused_experts( self, hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", @@ -93,8 +93,8 @@ def fused_experts( use_int4_w4a8: bool = False, global_num_experts: Optional[int] = None, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, + w1_scale: 
Optional[list[torch.Tensor]] = None, + w2_scale: Optional[list[torch.Tensor]] = None, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, # For TorchAir graph diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index 07ba732f199..13e1efc0acd 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -23,7 +23,11 @@ from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.utils import (AscendDeviceType, dispose_tensor, - get_ascend_device_type) + enable_custom_op, get_ascend_device_type) + + +def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + return fusion and dynamic_eplb and enable_custom_op() def cumsum_group_list(group_list: torch.Tensor, @@ -55,10 +59,10 @@ def cumsum_group_list(group_list: torch.Tensor, def quant_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, + w1: list[torch.Tensor], + w1_scale: list[torch.Tensor], + w2: list[torch.Tensor], + w2_scale: list[torch.Tensor], group_list: torch.Tensor, group_list_type: int = 1, dynamic_scale: torch.Tensor = None, @@ -79,7 +83,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor, quantized_hidden_states = hidden_states bias1, bias2 = None, None - _output_dtype = w2_scale.dtype + _output_dtype = w2_scale[0].dtype weight_prefetch_method = get_forward_context().weight_prefetch_method if weight_prefetch_method: @@ -87,23 +91,34 @@ def quant_apply_mlp(hidden_states: torch.Tensor, hidden_states) is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 if w1_scale_bias is None and is_mc2: - if fusion and not dynamic_eplb: + if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = ( + torch.ops._C_ascend. 
+ grouped_matmul_swiglu_quant_weight_nz_tensor_list( + x=hidden_states, + weight=w1, + weight_scale=w1_scale, + x_scale=pertoken_scale, + group_list=cumsum_group_list(group_list, group_list_type), + )) + elif fusion and not dynamic_eplb: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( x=hidden_states, - weight=w1, + weight=w1[0], group_list=cumsum_group_list(group_list, group_list_type), - weight_scale=w1_scale, + weight_scale=w1_scale[0], x_scale=pertoken_scale) if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) else: - if w1_scale.dtype != torch.float32: - w1_scale = w1_scale.to(torch.float32) + if w1_scale[0].dtype != torch.float32: + w1_scale[0] = w1_scale[0].to(torch.float32) # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w1], + weight=w1, split_item=3, group_list_type=group_list_type, group_type=0, @@ -126,14 +141,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor, # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w2], - scale=[w2_scale], + weight=w2, + scale=w2_scale, per_token_scale=[swiglu_out_scale], split_item=2, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=w2_scale.dtype)[0] + output_dtype=w2_scale[0].dtype)[0] else: if w1_scale_bias is not None: if group_list_type == 0: @@ -146,23 +161,36 @@ def quant_apply_mlp(hidden_states: torch.Tensor, # TODO w4a8 scene: dynamic acquisition of dtype in the future _output_dtype = torch.bfloat16 - if fusion and not dynamic_eplb: + if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = ( + torch.ops._C_ascend. 
+ grouped_matmul_swiglu_quant_weight_nz_tensor_list( + x=hidden_states, + weight=w1, + weight_scale=w1_scale, + x_scale=pertoken_scale, + group_list=cumsum_group_list(group_list, group_list_type), + bias=bias1, + )) + elif fusion and not dynamic_eplb: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( x=hidden_states, - weight=w1, + weight=w1[0], bias=bias1, group_list=cumsum_group_list(group_list, group_list_type), - weight_scale=w1_scale, + weight_scale=w1_scale[0], x_scale=pertoken_scale) if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) else: + w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype) # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w1], - scale=[w1_scale.to(w2_scale.dtype)], + weight=w1, + scale=w1_scale, bias=bias1, per_token_scale=[pertoken_scale], split_item=2, @@ -179,8 +207,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], - weight=[w2], - scale=[w2_scale], + weight=w2, + scale=w2_scale, bias=bias2, per_token_scale=[swiglu_out_scale], split_item=2, @@ -232,11 +260,11 @@ def unquant_apply_mlp(hidden_states: torch.Tensor, def unified_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], group_list: torch.Tensor, + w1_scale: Optional[list[torch.Tensor]] = None, + w2_scale: Optional[list[torch.Tensor]] = None, dynamic_scale: torch.Tensor = None, group_list_type: int = 1, w1_scale_bias: torch.Tensor = None, @@ -247,6 +275,7 @@ def unified_apply_mlp(hidden_states: torch.Tensor, need_trans: bool = True, dynamic_eplb: bool = False) -> torch.Tensor: if with_quant: + assert w1_scale is not None and w2_scale is not None return quant_apply_mlp(hidden_states=hidden_states, w1=w1, w1_scale=w1_scale, diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 8c395b54fd4..8dad11c2512 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -108,12 +108,13 @@ def forward_oot( residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: import torch_npu - if residual is not None: + residual = torch.ops.vllm.maybe_chunk_residual(x, residual) assert x.size(0) == residual.size(0) + next_need_quant_fusion_linear = getattr( + self, 'next_need_quant_fusion_linear', None) x, residual = _addrmsnorm_forward_oot( - self, x, residual, self.next_need_quant_fusion_linear, - self.bias) + self, x, residual, next_need_quant_fusion_linear, self.bias) return x, residual x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index 844cdcbde72..53a3b26b74f 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -297,7 +297,7 @@ def __init__( def forward( self, input_, - is_prefill: bool = True, + **kwargs, ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: if self.custom_op is not None: return self.custom_op.apply(input_) diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/ops/mla.py similarity index 100% rename from vllm_ascend/models/layers/mla.py rename to vllm_ascend/ops/mla.py diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index bb16bc006a4..7c7fd6f08a9 100644 --- 
a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -2,6 +2,7 @@ import torch.nn.functional as F import torch_npu from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, @@ -15,6 +16,27 @@ from vllm_ascend.utils import npu_stream_switch, prefetch_stream +def _maybe_chunk_residual_impl(x: torch.Tensor, + residual: torch.Tensor) -> torch.Tensor: + try: + forward_context = get_forward_context() + except AssertionError: + return residual + + if x.size(0) != residual.size(0): + sp_enabled = forward_context.sp_enabled + assert sp_enabled is True, ("Currently, this situation only occurs " + "when sp is enabled") + pad_size = forward_context.pad_size + if pad_size > 0: + residual = F.pad(residual, (0, 0, 0, pad_size)) + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + residual = torch.chunk(residual, tp_size, dim=0)[tp_rank] + + return residual + + def _maybe_all_gather_and_maybe_unpad_impl( x: torch.Tensor, label: bool, @@ -151,7 +173,9 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None: except AssertionError: return - if not forward_context.prefetch_mlp_enabled: + prefetch_mlp_enabled = getattr(forward_context, 'prefetch_mlp_enabled', + False) + if not prefetch_mlp_enabled: return forward_context.prefetch_mlp_down_proj = True model_instance = forward_context.model_instance @@ -180,7 +204,9 @@ def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None: except AssertionError: return - if not forward_context.prefetch_mlp_enabled: + prefetch_mlp_enabled = getattr(forward_context, 'prefetch_mlp_enabled', + False) + if not prefetch_mlp_enabled: return if forward_context.prefetch_mlp_gate_up_proj or \ forward_context.prefetch_mlp_down_proj: @@ -259,6 +285,12 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, return output +direct_register_custom_op(op_name="maybe_chunk_residual", + op_func=_maybe_chunk_residual_impl, + fake_impl=lambda x, residual: x, + mutates_args=[], + dispatch_key="PrivateUse1") + direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad", op_func=_maybe_all_gather_and_maybe_unpad_impl, fake_impl=_maybe_all_gather_and_maybe_unpad_fake, diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index 8e0a71ab667..ca24083f04b 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -18,7 +18,6 @@ import vllm_ascend.patch.platform.patch_config # noqa import vllm_ascend.patch.platform.patch_distributed # noqa -import vllm_ascend.patch.platform.patch_dynamo_vllm_backend # noqa import vllm_ascend.patch.platform.patch_mamba_config # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa diff --git a/vllm_ascend/patch/platform/patch_config.py b/vllm_ascend/patch/platform/patch_config.py index 0e8642d1cea..b798fda3bc7 100644 --- a/vllm_ascend/patch/platform/patch_config.py +++ b/vllm_ascend/patch/platform/patch_config.py @@ -28,6 +28,8 @@ def __post_init__(self): self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" + elif self.method == "suffix": + self.model = "suffix" else: raise ValueError("num_speculative_tokens was provided but without " "speculative model.") @@ -70,6 +72,10 @@ def __post_init__(self): # draft related config as None here. 
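Separately, the `maybe_chunk_residual` op added to register_custom_ops.py earlier in this section pads the residual and keeps only the current TP rank's chunk when sequence parallelism has already sharded `x`. A rough sketch of that pad-and-chunk arithmetic; the sizes, and computing `pad_size` locally rather than reading it from the forward context, are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Hypothetical: 10 residual rows, TP world size 4, sequence parallelism enabled.
residual = torch.randn(10, 8)
tp_size, tp_rank = 4, 1

# In the real op pad_size comes from forward_context; here it is derived for the example.
pad_size = (-residual.size(0)) % tp_size
if pad_size > 0:
    residual = F.pad(residual, (0, 0, 0, pad_size))  # append padded rows

# Each rank keeps its own slice so residual matches the sequence-parallel x.
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
print(residual.shape)  # torch.Size([3, 8])
```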
self.draft_model_config = self.target_model_config self.draft_parallel_config = self.target_parallel_config + elif self.method == "suffix": + self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + self._validate_suffix_decoding() else: self.prompt_lookup_max = 0 self.prompt_lookup_min = 0 diff --git a/vllm_ascend/patch/platform/patch_distributed.py b/vllm_ascend/patch/platform/patch_distributed.py index 467cc0450be..f4f342d245c 100644 --- a/vllm_ascend/patch/platform/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_distributed.py @@ -18,32 +18,10 @@ # This file is a part of the vllm-ascend project. import torch -import vllm.envs as envs_vllm -from vllm.config import ParallelConfig from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type -def parallel_config_get_dp_port(self) -> int: - """ - We might need to initialize process groups in multiple - processes that is related to data parallelism, - e.g. both in the worker and in the engine, which - can live in different processes. To avoid port conflicts, we - increment the port number each time we need to initialize a - new process group related to data parallelism. - """ - answer = self.data_parallel_master_port - self.data_parallel_master_port += 1 - - # NOTE: Get port from envs directly when using torchrun - port = envs_vllm.VLLM_DP_MASTER_PORT if envs_vllm.VLLM_DP_MASTER_PORT else answer - return port - - -ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port - - class NullHandle: def __init__(self): diff --git a/vllm_ascend/patch/platform/patch_dynamo_vllm_backend.py b/vllm_ascend/patch/platform/patch_dynamo_vllm_backend.py deleted file mode 100644 index 9b753622f4f..00000000000 --- a/vllm_ascend/patch/platform/patch_dynamo_vllm_backend.py +++ /dev/null @@ -1,16 +0,0 @@ -# mypy: ignore-errors -from typing import Any, Dict - -import torch.fx as fx -from vllm.compilation.backends import VllmBackend -from vllm.compilation.caching import VllmSerializableFunction - -_original_vllmbackend_call = VllmBackend.__call__ - - -def __patch_call__(self, graph: fx.GraphModule, example_inputs, - options: Dict[str, Any]) -> VllmSerializableFunction: - return _original_vllmbackend_call(self, graph, example_inputs) - - -VllmBackend.__call__ = __patch_call__ diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index faa57b6140f..0d1dd559880 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -28,4 +28,6 @@ import vllm_ascend.patch.worker.patch_multimodal_merge # noqa import vllm_ascend.patch.worker.patch_minicpm # noqa import vllm_ascend.patch.worker.patch_qwen2_5_vl # noqa +import vllm_ascend.patch.worker.patch_qwen2_5_omni # noqa +import vllm_ascend.patch.worker.patch_qwen3_vl # noqa import vllm_ascend.patch.worker.patch_rope # noqa diff --git a/vllm_ascend/patch/worker/patch_qwen2_5_omni.py b/vllm_ascend/patch/worker/patch_qwen2_5_omni.py new file mode 100644 index 00000000000..bd91a33ee42 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_qwen2_5_omni.py @@ -0,0 +1,72 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import torch.nn as nn +from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs, + Qwen2_5OmniThinkerForConditionalGeneration) + +from vllm_ascend.ascend_forward_context import set_ascend_forward_context + + +class AscendQwen2_5OmniThinkerForConditionalGeneration(nn.Module): + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"].type(self.visual.dtype) + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + with set_ascend_forward_context(None, self.vllm_config): + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs, + video_hashes: list[str] | None = None, + cached_video_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + if video_input["type"] == "video_embeds": + return video_input["video_embeds"].type(self.visual.dtype) + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + with set_ascend_forward_context(None, self.vllm_config): + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + +# NOTE: These will be removed after https://github.com/vllm-project/vllm/pull/29388 is merged. 
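The `_process_image_input` / `_process_video_input` overrides above split the concatenated visual embeddings back into per-item chunks via `grid_thw.prod(-1) // merge_size // merge_size`. A small worked example of that size computation, with made-up grid values:

```python
import torch

# Made-up grids for two images: (t, h, w) in patch units.
grid_thw = torch.tensor([[1, 32, 32],
                         [1, 16, 48]])
merge_size = 2  # spatial_merge_size: a 2x2 patch window merges into one embedding

sizes = grid_thw.prod(-1) // merge_size // merge_size
print(sizes.tolist())  # [256, 192]

# The concatenated embeddings are then handed back one tensor per item.
embeds = torch.randn(int(sizes.sum()), 8)
per_item = embeds.split(sizes.tolist())
assert [e.shape[0] for e in per_item] == [256, 192]
```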
+Qwen2_5OmniThinkerForConditionalGeneration._process_image_input = AscendQwen2_5OmniThinkerForConditionalGeneration._process_image_input +Qwen2_5OmniThinkerForConditionalGeneration._process_video_input = AscendQwen2_5OmniThinkerForConditionalGeneration._process_video_input diff --git a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py index 27f08751bff..bb22acf3f17 100644 --- a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py +++ b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py @@ -24,17 +24,27 @@ import torch_npu from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import \ Qwen2_5_VLVisionConfig +from transformers.models.qwen2_vl.configuration_qwen2_vl import \ + Qwen2VLVisionConfig from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import maybe_get_vit_flash_attn_backend +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding.common import ( + apply_rotary_emb_torch, dispatch_rotary_emb_function) from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionPatchMerger, Qwen2_5_VisionTransformer, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs) +from vllm.model_executor.models.qwen2_vl import (Qwen2VisionAttention, + Qwen2VisionBlock, + Qwen2VisionPatchEmbed, + Qwen2VisionPatchMerger, + Qwen2VisionTransformer) from vllm.model_executor.models.utils import cast_overflow_tensors from vllm.model_executor.models.vision import ( get_vit_attn_backend, run_dp_sharded_mrope_vision_model) @@ -55,7 +65,7 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, - seqlens: torch.Tensor, + seqlens: torch.Tensor = None, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -74,18 +84,8 @@ def forward( # Convert cumulative tensor to intervals and move it to cpu. cu_seqlens = torch.diff(cu_seqlens).to("cpu") - cos = rotary_pos_emb_cos - sin = rotary_pos_emb_sin - cos = einops.rearrange( - torch.stack((cos, cos), dim=-1), - "... d two -> ...(d two)", - two=2, - ) - sin = einops.rearrange( - torch.stack((sin, sin), dim=-1), - "... 
d two -> ...(d two)", - two=2, - ) + cos = torch.cat((rotary_pos_emb_cos, rotary_pos_emb_cos), dim=-1) + sin = torch.cat((rotary_pos_emb_sin, rotary_pos_emb_sin), dim=-1) cos = cos.reshape(1, -1, 1, self.hidden_size_per_attention_head) sin = sin.reshape(1, -1, 1, self.hidden_size_per_attention_head) q = torch_npu.npu_rotary_mul(q, cos, sin) @@ -132,6 +132,191 @@ def forward( return output +class AscendQwen2VisionBlock(nn.Module): + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen2VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen2VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + attn_backend_override: AttentionBackendEnum | None = None, + ) -> None: + nn.Module.__init__(self) + + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio + + self.use_data_parallel = use_data_parallel + self.out_hidden_size = vision_config.hidden_size + + self.spatial_merge_size = spatial_merge_size + self.num_heads = num_heads + self.embed_dim = embed_dim + + self.patch_embed = Qwen2VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = get_rope( + head_size=head_dim, + rotary_dim=head_dim // 2, + max_position=8192, + base=10000.0, + is_neox_style=True, + ) + + self.blocks = nn.ModuleList([ + Qwen2VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, + ) for layer_idx in range(depth) + ]) + self.merger = Qwen2VisionPatchMerger( + d_model=hidden_size, + context_dim=embed_dim, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, + ) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + + if (self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype())): + self.attn_backend = AttentionBackendEnum.FLASH_ATTN + + def rot_pos_emb( + self, + grid_thw: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]: + pos_ids = [] + max_grid_size = 0 + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = (hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w 
// self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten()) + wpos_ids = (wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten()) + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + max_grid_size = max(max_grid_size, h, w) + pos_ids = torch.cat(pos_ids, dim=0) + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + + # (num_tokens, rotary_dim // 2) + cos_h = cos[pos_ids[:, 0]] # type: ignore + cos_w = cos[pos_ids[:, 1]] # type: ignore + sin_h = sin[pos_ids[:, 0]] # type: ignore + sin_w = sin[pos_ids[:, 1]] # type: ignore + + cos_combined = torch.cat([cos_h, cos_w], dim=-1) + sin_combined = torch.cat([sin_h, sin_w], dim=-1) + return cos_combined, sin_combined + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor | list[list[int]], + ) -> torch.Tensor: + # patchify + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + else: + grid_thw_list = grid_thw.tolist() + + # compute position embedding + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb( + grid_thw_list) + + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) + + # transformers + x = x.unsqueeze(1) + + # pre-compute seqlens for attn mask to reduce cuMemcpy operations + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + for blk in self.blocks: + x = blk( + x, + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + # adapter + x = self.merger(x) + + return x + + class AscendQwen2_5_VisionBlock(nn.Module): def forward( @@ -486,7 +671,16 @@ def _process_video_input( return video_embeds.split(sizes) +def _apply_rotary_pos_emb_vision(t: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor) -> torch.Tensor: + rotary_emb_function = dispatch_rotary_emb_function( + default=partial(apply_rotary_emb_torch, is_neox_style=True)) + output = rotary_emb_function(t, cos, sin).type_as(t) + return output + + # NOTE: This will be removed after MMEncoderAttention has been extract as a CustomOp in vllm. +Qwen2VisionAttention.forward = AscendQwen2_5_VisionAttention.forward Qwen2_5_VisionAttention.forward = AscendQwen2_5_VisionAttention.forward # NOTE: These will be removed after https://github.com/vllm-project/vllm/pull/29388 is merged. @@ -494,8 +688,13 @@ def _process_video_input( Qwen2_5_VLForConditionalGeneration._process_video_input = AscendQwen2_5_VLForConditionalGeneration._process_video_input # NOTE: These will be removed after vllm-ascend is aligned with vllm latest main. 
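The patched vision attention above expands cos/sin with `torch.cat((cos, cos), dim=-1)` instead of the previous einops interleave, i.e. it duplicates the half-dim values as two contiguous halves rather than as adjacent pairs, which appears intended to match the `is_neox_style=True` rotary helper wired up in this file. A small comparison of the two orderings with illustrative values:

```python
import torch
from einops import rearrange

cos = torch.tensor([[0.1, 0.2, 0.3]])  # (seq, rotary_dim // 2), illustrative values

interleaved = rearrange(torch.stack((cos, cos), dim=-1),
                        "... d two -> ... (d two)", two=2)
half_duplicated = torch.cat((cos, cos), dim=-1)

print(interleaved)      # pairs:  c0, c0, c1, c1, c2, c2
print(half_duplicated)  # halves: c0, c1, c2, c0, c1, c2
```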
+Qwen2VisionBlock.forward = AscendQwen2VisionBlock.forward +Qwen2VisionTransformer.__init__ = AscendQwen2VisionTransformer.__init__ +Qwen2VisionTransformer.rot_pos_emb = AscendQwen2VisionTransformer.rot_pos_emb +Qwen2VisionTransformer.forward = AscendQwen2VisionTransformer.forward Qwen2_5_VisionBlock.forward = AscendQwen2_5_VisionBlock.forward Qwen2_5_VisionTransformer.__init__ = AscendQwen2_5_VisionTransformer.__init__ Qwen2_5_VisionTransformer.rotary_pos_emb_thw = AscendQwen2_5_VisionTransformer.rotary_pos_emb_thw Qwen2_5_VisionTransformer.get_rope_by_thw = AscendQwen2_5_VisionTransformer.get_rope_by_thw Qwen2_5_VisionTransformer.forward = AscendQwen2_5_VisionTransformer.forward +apply_rotary_pos_emb_vision = _apply_rotary_pos_emb_vision diff --git a/vllm_ascend/patch/worker/patch_qwen3_vl.py b/vllm_ascend/patch/worker/patch_qwen3_vl.py new file mode 100644 index 00000000000..1b80bbdcfa1 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_qwen3_vl.py @@ -0,0 +1,251 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +from transformers.models.qwen3_vl.configuration_qwen3_vl import \ + Qwen3VLVisionConfig +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.layer import check_upstream_fa_availability +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionBlock, + Qwen3_VisionPatchEmbed, + Qwen3_VisionPatchMerger, + Qwen3_VisionTransformer) +from vllm.model_executor.models.vision import get_vit_attn_backend + + +class AscendQwen3_VisionBlock(nn.Module): + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + max_seqlen: torch.Tensor, # Only used for Flash Attention + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + ) + + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen3_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen3VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + attn_backend_override: AttentionBackendEnum | None = None, + ) -> None: + nn.Module.__init__(self) + + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.num_position_embeddings = vision_config.num_position_embeddings + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = 
self.spatial_merge_size**2 + self.temporal_patch_size = vision_config.temporal_patch_size + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + self.use_data_parallel = use_data_parallel + self.num_grid_per_side = int(self.num_position_embeddings**0.5) + + # NOTE: This is used for creating empty tensor for all_gather for + # DP ViT. Here out_hidden_size is enlarged due to deepstack + self.out_hidden_size = vision_config.out_hidden_size * ( + 1 + len(self.deepstack_visual_indexes)) + + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + self.pos_embed = nn.Embedding(self.num_position_embeddings, + self.hidden_size) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = get_rope( + head_size=head_dim, + rotary_dim=head_dim // 2, + max_position=8192, + base=10000.0, + is_neox_style=True, + ) + + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, + ) + + self.deepstack_merger_list = nn.ModuleList([ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", + use_data_parallel=use_data_parallel, + ) for layer_idx in range(len(self.deepstack_visual_indexes)) + ]) + + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + use_upstream_fa = False + if (self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA + and check_upstream_fa_availability(torch.get_default_dtype())): + self.attn_backend = AttentionBackendEnum.FLASH_ATTN + use_upstream_fa = True + + if self.attn_backend not in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, + }: + raise RuntimeError( + f"Qwen3-VL does not support {self.attn_backend} backend now.") + self.blocks = nn.ModuleList([ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa, + ) for layer_idx in range(vision_config.depth) + ]) + + def rot_pos_emb(self, grid_thw: list[list[int]]): + max_grid_size = max(max(h, w) for _, h, w in grid_thw) + pos_ids = [ + self.rot_pos_ids(h, w, self.spatial_merge_size) if t == 1 else + self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1) + for t, h, w in grid_thw + ] + pos_ids = torch.cat(pos_ids, dim=0) + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + + # (num_tokens, rotary_dim // 2) + cos_h = cos[pos_ids[:, 0]] # type: ignore + cos_w = cos[pos_ids[:, 1]] # type: ignore + sin_h = sin[pos_ids[:, 0]] # type: 
ignore + sin_w = sin[pos_ids[:, 1]] # type: ignore + + cos_combined = torch.cat([cos_h, cos_w], dim=-1) + sin_combined = torch.cat([sin_h, sin_w], dim=-1) + + return cos_combined, sin_combined + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor | list[list[int]], + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, + dtype=self.dtype, + non_blocking=True) + hidden_states = self.patch_embed(hidden_states) + + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = np.array(grid_thw, dtype=np.int32) + else: + grid_thw = grid_thw.to("cpu") + grid_thw_list = grid_thw.tolist() + grid_thw = grid_thw.numpy() + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb( + grid_thw_list) + rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device, + non_blocking=True) + rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device, + non_blocking=True) + + cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum(axis=0, dtype=np.int32) + cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens]) + cu_seqlens = torch.from_numpy(cu_seqlens) + + hidden_states = hidden_states.unsqueeze(1) + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index( + layer_num) + deepstack_feature = self.deepstack_merger_list[ + deepstack_merger_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat( + [hidden_states] + deepstack_feature_lists, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + + +# NOTE: These will be removed after vllm-ascend is aligned with vllm latest main. +Qwen3_VisionBlock.forward = AscendQwen3_VisionBlock.forward +Qwen3_VisionTransformer.__init__ = AscendQwen3_VisionTransformer.__init__ +Qwen3_VisionTransformer.rot_pos_emb = AscendQwen3_VisionTransformer.rot_pos_emb +Qwen3_VisionTransformer.forward = AscendQwen3_VisionTransformer.forward diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 7cc84fc6ae3..5ff66926aa7 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -283,7 +283,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if parallel_config and parallel_config.worker_cls == "auto": # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm. 
parallel_config.all2all_backend = "flashinfer_all2allv" - if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp: + if ascend_config.torchair_graph_config.enabled: parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker" else: parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker" @@ -379,8 +379,6 @@ def get_attn_backend_cls( ascend_config = get_ascend_config() if use_mla and ascend_config.enable_shared_expert_dp: - if use_mla and not use_sparse: - return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend" if use_mla and use_sparse: return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend" diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 72c04e50b70..e9d0c97f942 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -408,11 +408,10 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): quant_config: The Ascend quantization config. """ - def __init__(self, - quant_config: AscendQuantConfig, - prefix: str, - packed_modules_mapping: Dict[str, Any], - layer: torch.nn.Module = None): + def __init__(self, quant_config: AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, + Any], layer: torch.nn.Module): + super().__init__(layer.moe_config) self.quant_method = get_quant_method(quant_config.quant_description, prefix, "moe", diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index eaaaee86702..be43726e8d9 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -12,6 +12,8 @@ AscendW8A8LinearMethod) from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod) +from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod, + AscendW8A8PDMixLinearMethod) ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { "W4A8_DYNAMIC": { @@ -30,6 +32,10 @@ "linear": AscendW8A8DynamicLinearMethod, "moe": AscendW8A8DynamicFusedMoEMethod, }, + "W8A8_MIX": { + "linear": AscendW8A8PDMixLinearMethod, + "moe": AscendW8A8PDMixFusedMoeMethod, + }, "C8": { "attention": AscendC8KVCacheMethod, }, diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index c7f1dfabb86..a73050c3123 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -379,10 +379,10 @@ def apply( moe_comm_method = get_forward_context().moe_comm_method return moe_comm_method.fused_experts( hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, + w1=[layer.w13_weight], + w2=[layer.w2_weight], + w1_scale=[layer.w13_weight_scale], + w2_scale=[layer.w2_weight_scale], w1_scale_bias=layer.w13_scale_bias, w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 8a7bbfe7263..bfa39e69b39 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -118,8 +118,10 @@ def apply( weight=layer.weight, start_flag=x, ) - - quant_comm_config = getattr(layer, "_quant_comm_config", {}) + try: + quant_comm_config = getattr(layer, "_quant_comm_config") + except AttributeError: + quant_comm_config = {} comm_fn = quant_comm_config.get("communication_fn") enable_flashcomm2_quant_comm = comm_fn is not None and ( "o_proj" in layer.prefix or "out_proj" in layer.prefix) @@ -150,8 +152,12 @@ def 
apply( ) quant_bias = layer.quant_bias if tp_rank == 0 else None - if getattr(layer, "ascend_quant_method", - "") == COMPRESSED_TENSORS_METHOD: + + try: + ascend_quant_method = getattr(layer, "ascend_quant_method") + except AttributeError: + ascend_quant_method = "" + if ascend_quant_method == COMPRESSED_TENSORS_METHOD: quant_bias = bias if get_ascend_device_type() == AscendDeviceType._310P: @@ -192,8 +198,8 @@ def process_weights_after_loading(self, layer): layer.weight.data, ACL_FORMAT_FRACTAL_NZ) layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data) - if getattr(layer, "ascend_quant_method", - "") == COMPRESSED_TENSORS_METHOD: + ascend_quant_method = getattr(layer, "ascend_quant_method", "") + if ascend_quant_method == COMPRESSED_TENSORS_METHOD: deq_scale = layer.input_scale.data * layer.weight_scale.data layer.deq_scale = torch.nn.Parameter(deq_scale, requires_grad=False) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 6b7d6b0875c..2901d17504e 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional import torch import torch_npu @@ -72,33 +72,20 @@ def get_pergroup_param(self, @staticmethod def apply( layer: torch.nn.Module, - x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + x: torch.Tensor, bias: Optional[torch.Tensor] = None, tp_rank: Optional[int] = 0, ) -> torch.Tensor: - config = getattr(layer, "_ascend_quant_config", {}) - if not isinstance(x, tuple): - output_dtype = config.get("output_dtype", x.dtype) - quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x) - else: - assert "output_dtype" in config.keys(), ( - f"DynamicLinearMethod needs explicitly specified `output_dtype`" - f"for pre-quantized input, got config [{config}]") - output_dtype = config["output_dtype"] - quantized_x, dynamic_scale = x - pertoken_scale = (dynamic_scale - if config.get("pertoken_scale", True) else None) - + quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x) output = torch_npu.npu_quant_matmul( quantized_x, layer.weight, layer.weight_scale, pertoken_scale=pertoken_scale, bias=bias, - output_dtype=output_dtype, + output_dtype=x.dtype, ) - return ((output, dynamic_scale) - if config.get("return_scale", False) else output) + return output def process_weights_after_loading(self, layer): if self.transpose_weight: @@ -234,13 +221,24 @@ def apply( topk_weights = topk_weights.to(self.in_dtype) moe_comm_method = get_forward_context().moe_comm_method + if self.dynamic_eplb: + w1 = layer.w13_weight_list + w1_scale = layer.w13_weight_scale_fp32_list + w2 = layer.w2_weight_list + w2_scale = layer.w2_weight_scale_list + else: + w1 = [layer.w13_weight] + w1_scale = [layer.w13_weight_scale_fp32] + w2 = [layer.w2_weight] + w2_scale = [layer.w2_weight_scale] + return moe_comm_method.fused_experts( hidden_states=x, pertoken_scale=pertoken_scale, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale_fp32, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, topk_weights=topk_weights, topk_ids=topk_ids, use_int8_w8a8=True, @@ -272,3 +270,25 @@ def process_weights_after_loading(self, layer): layer.w2_weight_scale.data.shape[0], -1) layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( 
layer.w2_weight_offset.data.shape[0], -1) + if self.dynamic_eplb: + layer.w13_weight_list = [ + weight.clone() + for weight in layer.w13_weight.data.unbind(dim=0) + ] + layer.w2_weight_list = [ + weight.clone() for weight in layer.w2_weight.data.unbind(dim=0) + ] + layer.w13_weight_scale_fp32_list = [ + weight.clone() + for weight in layer.w13_weight_scale.data.unbind(dim=0) + ] + layer.w2_weight_scale_list = [ + weight.clone() + for weight in layer.w2_weight_scale.data.unbind(dim=0) + ] + del layer.w13_weight + del layer.w2_weight + del layer.w13_weight_scale + del layer.w13_weight_scale_fp32 + del layer.w2_weight_scale + torch.npu.empty_cache() diff --git a/vllm_ascend/quantization/w8a8_pdmix.py b/vllm_ascend/quantization/w8a8_pdmix.py new file mode 100644 index 00000000000..0fa74f7e9a0 --- /dev/null +++ b/vllm_ascend/quantization/w8a8_pdmix.py @@ -0,0 +1,70 @@ +from typing import Any, Dict, cast + +import torch +from vllm.config import get_current_vllm_config + +from .w8a8 import AscendW8A8LinearMethod +from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, + AscendW8A8DynamicLinearMethod) + + +class AscendW8A8PDMixLinearMethod(AscendW8A8DynamicLinearMethod): + + def __init__(self): + self.kv_transfer_config = get_current_vllm_config().kv_transfer_config + super().__init__() + + @staticmethod + def apply(layer, x, bias=None, tp_rank=0): + if layer.is_kv_consumer: + return AscendW8A8LinearMethod.apply(layer, x, bias, tp_rank) + else: + return AscendW8A8DynamicLinearMethod.apply(layer, x, bias, tp_rank) + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return AscendW8A8LinearMethod.get_pertensor_param(params_dtype) + + @staticmethod + def get_perchannel_param( + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + return AscendW8A8LinearMethod.get_perchannel_param( + output_size, params_dtype) + + def process_weights_after_loading(self, layer): + AscendW8A8LinearMethod.process_weights_after_loading( + cast(AscendW8A8LinearMethod, self), layer) + layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) + layer.is_kv_consumer = self.kv_transfer_config is not None and self.kv_transfer_config.is_kv_consumer + + +class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod): + + def __init__(self): + super().__init__() + + @staticmethod + def get_dynamic_quant_param(num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = AscendW8A8DynamicFusedMoEMethod.get_dynamic_quant_param( + num_experts, intermediate_size_per_partition, hidden_sizes, + params_dtype) + param_dict["w2_deq_scale"] = torch.empty(num_experts, + hidden_sizes, + dtype=torch.float32) + param_dict["w13_deq_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32) + param_dict["w2_input_offset"] = torch.empty(num_experts, + 1, + dtype=torch.int8) + param_dict["w13_input_offset"] = torch.empty(num_experts, + 1, + dtype=torch.int8) + + return param_dict diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index a17f534045e..9bd941fc731 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import vllm.v1.sample.rejection_sampler as rs +from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import 
(RejectionSampler, apply_sampling_constraints, @@ -149,25 +150,36 @@ def rejection_sample( if not sampling_metadata.all_random: # Rejection sampling for greedy sampling requests. target_argmax = target_probs.argmax(dim=-1) - if min(num_draft_tokens) == 1 and max( - num_draft_tokens) == 1 and sampling_metadata.all_greedy: - rejection_greedy_sample_spec_len_1_pytorch( - output_token_ids, - draft_token_ids, - target_argmax, - bonus_token_ids, - ) - else: - rejection_greedy_sample_pytorch( + if HAS_TRITON: + rejection_greedy_sample_kernel[(batch_size, )]( output_token_ids, cu_num_draft_tokens, draft_token_ids, target_argmax, bonus_token_ids, - num_draft_tokens, - max_spec_len, is_greedy, + max_spec_len, ) + else: + if min(num_draft_tokens) == 1 and max( + num_draft_tokens) == 1 and sampling_metadata.all_greedy: + rejection_greedy_sample_spec_len_1_pytorch( + output_token_ids, + draft_token_ids, + target_argmax, + bonus_token_ids, + ) + else: + rejection_greedy_sample_pytorch( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + num_draft_tokens, + max_spec_len, + is_greedy, + ) if sampling_metadata.all_greedy: return output_token_ids @@ -194,21 +206,37 @@ def rejection_sample( ) # Rejection sampling for random sampling requests. - rejection_random_sample_pytorch( - output_token_ids, - cu_num_draft_tokens, - draft_token_ids, - draft_probs, - target_probs, - bonus_token_ids, - recovered_token_ids, - uniform_probs, - is_greedy, - max_spec_len, - vocab_size, - IS_NGRAM=draft_probs is None, - # num_warps=1, - ) + if HAS_TRITON: + rejection_random_sample_kernel[(batch_size, )]( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + NO_DRAFT_PROBS=draft_probs is None, + ) + else: + rejection_random_sample_pytorch( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + IS_NGRAM=draft_probs is None, + # num_warps=1, + ) return output_token_ids @@ -241,14 +269,24 @@ def expand_batch_to_tokens( batch_size = x.shape[0] assert cu_num_tokens.shape[0] == batch_size expanded_x = x.new_empty(num_tokens) - expand_pytorch( - expanded_x, - x, - cu_num_tokens, - replace_from, - replace_to, - MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. - ) + if HAS_TRITON: + expand_kernel[(batch_size, )]( + expanded_x, + x, + cu_num_tokens, + replace_from, + replace_to, + MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. + ) + else: + expand_pytorch( + expanded_x, + x, + cu_num_tokens, + replace_from, + replace_to, + MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. 
+ ) return expanded_x @@ -282,16 +320,29 @@ def sample_recovered_tokens( q[i].exponential_(generator=generator) recovered_token_ids = torch.empty_like(draft_token_ids) - sample_recovered_tokens_pytorch( - recovered_token_ids, - cu_num_draft_tokens, - draft_token_ids, - draft_probs, - target_probs, - q, - vocab_size, - IS_NGRAM=draft_probs is None, - ) + if HAS_TRITON: + sample_recovered_tokens_kernel[(batch_size, max_spec_len)]( + recovered_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + q, + vocab_size, + triton.next_power_of_2(vocab_size), + NO_DRAFT_PROBS=draft_probs is None, + ) + else: + sample_recovered_tokens_pytorch( + recovered_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + q, + vocab_size, + IS_NGRAM=draft_probs is None, + ) return recovered_token_ids @@ -504,4 +555,192 @@ def sample_recovered_tokens_pytorch( target_probs[token_idx, draft_token_id] = orig_prob +@triton.jit(do_not_specialize=["max_spec_len"]) +def rejection_greedy_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + target_argmax_ptr, # [num_tokens] + bonus_token_ids_ptr, # [batch_size] + is_greedy_ptr, # [batch_size] or None + max_spec_len, +): + req_idx = tl.program_id(0) + # Because is_greedy_ptr is not None at the profiling run, + # re-compilation may happen during runtime when is_greedy_ptr is None. + is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + + req_idx) + if not is_greedy: + # Early exit for non-greedy sampling requests + return + + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + rejected = False + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + target_argmax_id, + ) + if draft_token_id != target_argmax_id: + # Reject + rejected = True + + if not rejected: + # If all tokens are accepted, append the bonus token + bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + num_draft_tokens, + bonus_token_id, + ) + + +@triton.jit(do_not_specialize=["max_spec_len"]) +def rejection_random_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + bonus_token_ids_ptr, # [batch_size] + recovered_token_ids_ptr, # [num_tokens] + uniform_probs_ptr, # [num_tokens] + is_greedy_ptr, # [batch_size] + max_spec_len, + vocab_size, + NO_DRAFT_PROBS: tl.constexpr, +): + req_idx = tl.program_id(0) + is_greedy = tl.load(is_greedy_ptr + req_idx) + if is_greedy: + # Early exit for greedy sampling requests + return + + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + rejected = False + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + if NO_DRAFT_PROBS: + draft_prob = 1 + else: + draft_prob = tl.load(draft_probs_ptr + +
(start_idx + pos) * vocab_size + + draft_token_id) + target_prob = tl.load(target_probs_ptr + + (start_idx + pos) * vocab_size + + draft_token_id) + uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) + if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: + # Accept + token_id = draft_token_id + else: + # Reject. Use recovered token + rejected = True + token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + token_id) + + if not rejected: + # If all tokens are accepted, append the bonus token + bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + num_draft_tokens, + bonus_token_id, + ) + + +@triton.jit(do_not_specialize=["replace_from", "replace_to"]) +def expand_kernel( + output_ptr, # [num_tokens] + input_ptr, # [batch_size] + cu_num_tokens_ptr, # [batch_size] + replace_from, + replace_to, + MAX_NUM_TOKENS: tl.constexpr, +): + req_idx = tl.program_id(0) + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_tokens_ptr + req_idx) + num_tokens = end_idx - start_idx + + src_val = tl.load(input_ptr + req_idx) + src_val = tl.where(src_val == replace_from, replace_to, src_val) + offset = tl.arange(0, MAX_NUM_TOKENS) + tl.store(output_ptr + start_idx + offset, + src_val, + mask=offset < num_tokens) + + +@triton.jit +def sample_recovered_tokens_kernel( + output_token_ids_ptr, # [num_tokens] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + q_ptr, # [batch_size, vocab_size] + vocab_size, + PADDED_VOCAB_SIZE: tl.constexpr, + NO_DRAFT_PROBS: tl.constexpr, +): + req_idx = tl.program_id(0) + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + # Early exit for out-of-range positions + pos = tl.program_id(1) + if pos >= num_draft_tokens: + return + + vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) + if NO_DRAFT_PROBS: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=((vocab_offset < vocab_size) & + (vocab_offset != draft_token_id)), + other=0, + ) + else: + draft_prob = tl.load( + draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0, + ) + target_prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0, + ) + prob = tl.maximum(target_prob - draft_prob, 0) + # We don't need `prob = prob / tl.sum(prob)` here because + # `tl.argmax` will select the maximum value. 
+ + q = tl.load( + q_ptr + req_idx * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=float("-inf"), + ) + recovered_id = tl.argmax(prob / q, axis=-1) + tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) + + rs.expand_batch_to_tokens = expand_batch_to_tokens diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index 6abe8777cd3..a8d448750b8 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -19,6 +19,7 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.spec_decode.ngram_proposer import NgramProposer +from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer @@ -35,6 +36,8 @@ def get_spec_decode_method(method, if is_torchair_graph: return TorchairMtpProposer(vllm_config, device, runner) return MtpProposer(vllm_config, device, runner) + elif method == 'suffix': + return SuffixDecodingProposer(vllm_config, device, runner) else: raise ValueError("Unknown speculative decoding method: " f"{method}") diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 75f01ee9bdb..791c487ddb8 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -123,7 +123,8 @@ def dummy_run(self, num_reqs: int = 0, num_tokens_across_dp: Optional[torch.Tensor] = None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None): + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None): moe_comm_type = self.runner._select_moe_comm_method(num_tokens) with set_ascend_forward_context(None, self.vllm_config, @@ -134,6 +135,7 @@ def dummy_run(self, positions=self.positions[:num_tokens], hidden_states=self.hidden_states[:num_tokens], ) + dummy_compute_logits(self.hidden_states) def generate_token_ids(self, valid_sampled_token_ids: list[np.ndarray], diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py index 5fdb494515f..098f171fbe4 100644 --- a/vllm_ascend/spec_decode/interface.py +++ b/vllm_ascend/spec_decode/interface.py @@ -14,6 +14,7 @@ class SpecDcodeType(enum.Enum): EAGLE = 1 EAGLE3 = 2 MTP = 4 + SUFFIX = 5 class Proposer: @@ -51,4 +52,4 @@ def generate_token_ids(self, attn_metadata=None, aux_hidden_states: torch.Tensor = None): """Called by execute_model in model_runner""" - raise NotImplementedError + raise NotImplementedError \ No newline at end of file diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 73b65aedfd9..cacc2bdf0ee 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -32,7 +32,8 @@ update_mla_attn_params) from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, - prefill_context_parallel_enable) + prefill_context_parallel_enable, + shared_expert_dp_enabled) if prefill_context_parallel_enable(): from vllm.distributed import get_pcp_group @@ -46,9 +47,7 @@ _MTP_MODELS = { "DeepseekV3ForCausalLM": - ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"), - "Qwen3NextForCausalLM": - ("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP") + ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP") } _DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn' @@ 
-96,6 +95,7 @@ def __init__( # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). self.hidden_size = self.draft_model_config.get_hidden_size() + self.enable_shared_expert_dp = shared_expert_dp_enabled() self.pcp_size = self.runner.pcp_size self.dcp_size = self.runner.dcp_size @@ -215,7 +215,8 @@ def dummy_run(self, num_reqs: int = 0, num_tokens_across_dp=None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None) -> None: + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None) -> None: ( num_tokens, @@ -287,6 +288,12 @@ def dummy_run(self, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, is_mtp_model=True): + if self.enable_shared_expert_dp: + positions = positions.unsqueeze(-1) + positions = torch.ops.vllm.maybe_pad_and_reduce(positions) + positions = positions.squeeze(-1) + previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce( + previous_hidden_states) self.model(input_ids=input_ids, positions=positions, hidden_states=previous_hidden_states) @@ -295,9 +302,14 @@ def dummy_run(self, not forward_context.capturing: if self.vllm_config.model_config.use_mla: update_mla_attn_params( - self.update_stream, forward_context, - positions.shape[0], + self.update_stream, forward_context, num_tokens, self.vllm_config.speculative_config) + if self.enable_shared_expert_dp: + positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + positions, True) + previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + previous_hidden_states, True) + dummy_compute_logits(previous_hidden_states) if with_prefill: break @@ -675,7 +687,8 @@ def _propose( moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens) - if scheduler_output: + # Enabling shared_expert_dp with MTP FULL graph may cause accuracy issues. 
+ if scheduler_output and not self.enable_shared_expert_dp: max_query_len = common_attn_metadata.max_query_len uniform_decode = (max_query_len in list( range(1, self.num_speculative_tokens + @@ -725,11 +738,22 @@ def _propose( with ProfileExecuteDuration().capture_async('mtp_forward'): model_kwargs = {} model_kwargs["attn_metadata"] = attn_metadata - - hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - hidden_states=self.hidden_states[:num_input_tokens]) + input_ids = self.input_ids[:num_input_tokens] + positions = self.positions[:num_input_tokens] + hidden_states = self.hidden_states[:num_input_tokens] + + if self.enable_shared_expert_dp: + # positions [N] -> [N, 1] for padding + positions = positions.unsqueeze(-1) + positions = torch.ops.vllm.maybe_pad_and_reduce( + positions) + positions = positions.squeeze(-1) + hidden_states = torch.ops.vllm.maybe_pad_and_reduce( + hidden_states) + + hidden_states = self.model(input_ids=input_ids, + positions=positions, + hidden_states=hidden_states) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: if self.vllm_config.model_config.use_mla: @@ -738,6 +762,12 @@ def _propose( num_input_tokens, self.vllm_config.speculative_config) + if self.enable_shared_expert_dp: + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + hidden_states.contiguous(), True) + positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + positions.contiguous(), True) + num_indices = last_token_indices.shape[0] if lmhead_tp_enable(): if not self.runner.with_prefill: @@ -758,6 +788,7 @@ def _propose( logits = self.model.compute_logits(sample_hidden_states) if lmhead_tp_enable() and num_indices < logits.shape[0]: logits = logits[:num_indices] + last_token_indices = last_token_indices[:num_indices] draft_token_ids = logits.argmax(dim=-1) if self.num_speculative_tokens == 1: @@ -804,26 +835,27 @@ def _propose( batch_size, attn_metadata_i.decode.actual_seq_lengths_q) attn_metadata_i.decode.cos = builder.cos_cache[ - positions].unsqueeze(1).unsqueeze(2) + positions[:batch_size]].unsqueeze(1).unsqueeze(2) attn_metadata_i.decode.sin = builder.sin_cache[ - positions].unsqueeze(1).unsqueeze(2) + positions[:batch_size]].unsqueeze(1).unsqueeze(2) # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length. Since it is complex # to remove such requests from the batch, we keep them in the batch # but adjust the position ids and slot mappings to avoid the # out-of-range access during the model execution. The draft tokens # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.runner.model_config.max_model_len + exceeds_max_model_len = positions[: + batch_size] >= self.runner.model_config.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + positions[:batch_size]) # Increment the sequence lengths. attn_metadata_i.seq_lens[:batch_size] += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. 
exceeds_max_model_len_cpu = exceeds_max_model_len.to( - attn_metadata_i.seq_lens.device, non_blocking=True) + attn_metadata_i.seq_lens.device, non_blocking=False) attn_metadata_i.seq_lens[:batch_size].masked_fill_( exceeds_max_model_len_cpu, 1) # Mask out the slot mappings that exceed the max model length. diff --git a/vllm_ascend/spec_decode/ngram_proposer.py b/vllm_ascend/spec_decode/ngram_proposer.py index 065d290fa44..43f94c8e2ba 100644 --- a/vllm_ascend/spec_decode/ngram_proposer.py +++ b/vllm_ascend/spec_decode/ngram_proposer.py @@ -27,7 +27,8 @@ def dummy_run(self, num_reqs=None, num_tokens_across_dp=None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None): + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None): pass def generate_token_ids(self, diff --git a/vllm_ascend/spec_decode/suffix_proposer.py b/vllm_ascend/spec_decode/suffix_proposer.py new file mode 100644 index 00000000000..e607044906e --- /dev/null +++ b/vllm_ascend/spec_decode/suffix_proposer.py @@ -0,0 +1,45 @@ +import torch +from vllm.config import CUDAGraphMode +from vllm.v1.spec_decode.suffix_decoding import \ + SuffixDecodingProposer as VllmSuffixDecodingProposer + +from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType + + +class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer): + + def __init__(self, vllm_config, device, runner): + super().__init__(vllm_config) + self.name = SpecDcodeType.SUFFIX + self.device = device + self.runner = runner + + def load_model(self, *args, **kwargs): + # No model to load. + pass + + @torch.inference_mode() + def dummy_run(self, + num_tokens, + with_prefill=None, + skip_attn=None, + num_reqs=None, + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None): + pass + + def generate_token_ids(self, + valid_sampled_token_ids, + sampling_metadata=None, + scheduler_output=None, + spec_decode_metadata=None, + positions=None, + num_scheduled_tokens=None, + hidden_states=None, + attn_metadata=None, + aux_hidden_states=None) -> list[list[int]]: + draft_token_ids = self.propose(self.runner.input_batch, + valid_sampled_token_ids) + return draft_token_ids diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 4408b310c70..87f23b9b3bf 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -16,7 +16,7 @@ # Adapted from vllm/tests/kernels/test_moe.py import os -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch import torch.distributed as dist @@ -45,7 +45,9 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod +from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod, + AscendQuantConfig) +from vllm_ascend.quantization.utils import get_quant_method from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding from vllm_ascend.torchair.utils import (get_all_reduce_merge_state, get_rm_router_logits_state, @@ -936,6 +938,15 @@ def apply( ep_group=get_ep_group()) +class TorchairAscendFusedMoEMethod(AscendFusedMoEMethod): + + def __init__(self, quant_config: 
AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, Any]): + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "moe", + packed_modules_mapping) + + class TorchairAscendFusedMoE(FusedMoE): # The moe_counter parameter is required during the initialization of EPLB @@ -1115,7 +1126,7 @@ def __init__( self.quant_method = TorchairAscendUnquantizedFusedMoEMethod( self.moe) else: - self.quant_method = AscendFusedMoEMethod( + self.quant_method = TorchairAscendFusedMoEMethod( quant_config, prefix, quant_config.packed_modules_mapping) assert self.quant_method is not None @@ -1385,7 +1396,7 @@ def get_map(self): return self.expert_map def get_log2phy_map(self): - return self.logical_to_physical_map + return self.log2phy def clear_moe_load(self): if self.moe_load is not None: diff --git a/vllm_ascend/torchair/torchair_mtp_proposer.py b/vllm_ascend/torchair/torchair_mtp_proposer.py index 476ff479966..bcbf7dc3d9b 100644 --- a/vllm_ascend/torchair/torchair_mtp_proposer.py +++ b/vllm_ascend/torchair/torchair_mtp_proposer.py @@ -81,7 +81,8 @@ def dummy_run(self, num_reqs: int = 0, num_tokens_across_dp=None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None) -> None: + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None) -> None: moe_comm_type = self.runner._select_moe_comm_method(num_tokens) if not with_prefill: @@ -143,6 +144,7 @@ def dummy_run(self, self.model(input_ids=input_ids, positions=positions, hidden_states=previous_hidden_states) + dummy_compute_logits(previous_hidden_states) if with_prefill: break diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 0a74bcbfdcf..e9441e28681 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -648,7 +648,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): return from vllm.model_executor.custom_op import CustomOp - from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE, AscendSharedFusedMoE) @@ -658,6 +657,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): AscendQKVParallelLinear, AscendReplicatedLinear, AscendRowParallelLinear) + from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding, AscendRotaryEmbedding, AscendYaRNRotaryEmbedding) @@ -758,7 +758,7 @@ def dense_optim_enable() -> bool: return envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE -def enable_sp(vllm_config=None) -> bool: +def enable_sp(vllm_config=None, enable_shared_expert_dp: bool = False) -> bool: global _ENABLE_SP if _ENABLE_SP is None: if vllm_config is None: @@ -772,6 +772,12 @@ def enable_sp(vllm_config=None) -> bool: # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility. or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0')))) + if not _ENABLE_SP and enable_shared_expert_dp: + _ENABLE_SP = True + logger.info( + "shared_expert_dp requires enable_sp = True. 
enable_sp has been set to True." + ) + if not _ENABLE_SP: return _ENABLE_SP @@ -948,7 +954,7 @@ def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config, global_tp_size = vllm_config.parallel_config.tensor_parallel_size if not flashcomm2_enable(): - logger.info("FLASHCOMM2 not enable.") + logger.debug("FLASHCOMM2 not enabled.") return flashcomm2_oproj_tp_size logger.info( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 2e7c4ea299b..37fb4381e6a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -96,6 +96,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -630,7 +631,8 @@ def _set_up_drafter(self): # Set up speculative decoding. self.spec_attn_mask = None self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer, - TorchairMtpProposer]] = None + TorchairMtpProposer, + SuffixDecodingProposer]] = None self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1 if self.speculative_config: @@ -977,25 +979,21 @@ def _make_attention_mask(self, seq_lens, position, # dcp situation. if self.dcp_size > 1: return self.attn_mask_builder.get_splitfuse_attn_mask() + if self.vllm_config.model_config.use_mla: + return None # Pooling situation. if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS": return self.attn_mask_builder.get_pooling_mask(self.device) - # Chunk Prefill situation. - elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse: + # fia prefill situation. + if attn_state in [ + AscendAttentionState.PrefillNoCache, + AscendAttentionState.PrefillCacheHit, + AscendAttentionState.ChunkedPrefill + ]: return self.attn_mask_builder.get_splitfuse_attn_mask() - # Prefill without cache situation. - elif attn_state == AscendAttentionState.PrefillNoCache: - max_seq_len = max(seq_lens.max().item(), 0) - return self.attn_mask_builder.get_attn_mask( - max_seq_len, self.dtype, self.device) - # Prefill with cache hit. - elif attn_state == AscendAttentionState.PrefillCacheHit: - return self.attn_mask_builder.get_splitfuse_attn_mask().to( - torch.bool) # Decode-only situation. - else: - return None + return None def _make_fia_attention_mask(self) -> torch.Tensor: # pcp situation. @@ -2339,7 +2337,6 @@ def execute_model( attn_metadata, self.with_prefill, maybe_padded_num_tokens, input_ids, positions, intermediate_tensors, inputs_embeds) - self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( scheduler_output) @@ -2603,7 +2600,7 @@ def propose_draft_token_ids(sampled_token_ids): # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. 
propose_draft_token_ids(valid_sampled_token_ids) - + self.maybe_wait_for_kv_save() if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() @@ -3004,14 +3001,21 @@ def _dummy_run( need_dummy_logits = (not self.in_profile_run and lmhead_tp_enable()) - - if need_dummy_logits: - max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs - dummy_indices = torch.zeros(max_num_reqs_across_dp, - dtype=torch.int32) - - def dummy_compute_logits(hidden_states): - return self.model.compute_logits( + max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs + dummy_indices = torch.zeros(max_num_reqs_across_dp, + dtype=torch.int32) + + def dummy_compute_logits(hidden_states): + if not need_dummy_logits: + return None + return self.model.compute_logits(hidden_states[dummy_indices]) + + def dummy_drafter_compute_logits(hidden_states): + if not need_dummy_logits or self.drafter is None: + return + if hasattr(self.drafter, "model") and hasattr( + self.drafter.model, "compute_logits"): + return self.drafter.model.compute_logits( hidden_states[dummy_indices]) with set_ascend_forward_context( @@ -3033,8 +3037,7 @@ def dummy_compute_logits(hidden_states): with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) - if need_dummy_logits: - dummy_compute_logits(hidden_states) + dummy_compute_logits(hidden_states) if self.drafter: self.drafter.dummy_run( @@ -3043,10 +3046,8 @@ def dummy_compute_logits(hidden_states): num_reqs=num_reqs, num_tokens_across_dp=num_tokens_across_dp, aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor) - if need_dummy_logits: - self.drafter.model.compute_logits( - hidden_states[dummy_indices]) + batch_descriptor=batch_descriptor, + dummy_compute_logits=dummy_drafter_compute_logits) if self.in_profile_run and self.dynamic_eplb: self.model.clear_all_moe_loads() if not self.in_profile_run and self.dynamic_eplb: diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index df7fec602d0..f64d06475c0 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -92,21 +92,6 @@ def __init__( # init ascend config and soc version init_ascend_config(vllm_config) check_ascend_device_type() - use_sparse = False - if vllm_config.model_config is not None: - use_sparse = hasattr(vllm_config.model_config.hf_config, - "index_topk") - if use_sparse: - # Direct import instead of using try_register_lib to ensure proper error handling when - # custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments) - # yapf: disable - import custom_ops # type: ignore # noqa - - # yapf: enable - logger.info( - "custom_ops module loaded successfully. Custom operators like " - "torch.ops.custom.npu_sparse_flash_attention are now available." - ) super().__init__(vllm_config=vllm_config, local_rank=local_rank, @@ -369,6 +354,9 @@ def _warm_up_atb(self): def get_model(self) -> nn.Module: return self.model_runner.get_model() + def get_kv_connector_handshake_metadata(self) -> Optional[dict]: + return None + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec()
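Note (illustrative, not part of the patch): the rejection_random_sample_kernel added in vllm_ascend/sample/rejection_sampler.py applies the standard speculative-decoding accept/reject rule once per request. A minimal plain-Python sketch of that rule with stand-in names; for the no-draft-probs (ngram) case, draft_probs is effectively all ones.

def rejection_sample_one_request(draft_ids, draft_probs, target_probs,
                                 uniform_probs, recovered_ids, bonus_id):
    # draft_probs / target_probs: probability of each draft token under the
    # draft and target model; uniform_probs: one U(0, 1) sample per position.
    output = []
    for pos, token in enumerate(draft_ids):
        if draft_probs[pos] > 0 and target_probs[pos] / draft_probs[pos] >= uniform_probs[pos]:
            output.append(token)               # accept the draft token
        else:
            output.append(recovered_ids[pos])  # reject: emit the recovered token instead
            return output                      # nothing is emitted after the first rejection
    output.append(bonus_id)                    # all draft tokens accepted: append the bonus token
    return output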
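Note (illustrative, not part of the patch): the dynamic-EPLB branch added to w8a8_dynamic.py splits the fused MoE weights into per-expert tensors after loading so experts can be rebalanced individually. A self-contained sketch of the unbind-and-clone idiom with stand-in sizes:

import torch

num_experts, out_dim, in_dim = 8, 4096, 2048      # stand-in sizes only
w13_weight = torch.randn(num_experts, out_dim, in_dim)

# Split along the expert dimension and clone so each expert owns contiguous
# storage; the stacked tensor can then be deleted to release its memory.
w13_weight_list = [w.clone() for w in w13_weight.unbind(dim=0)]
del w13_weight

assert len(w13_weight_list) == num_experts
assert w13_weight_list[0].shape == (out_dim, in_dim)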
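Note (illustrative, not part of the patch): with SuffixDecodingProposer wired into vllm_ascend/spec_decode, the "suffix" method is resolved by get_spec_decode_method from the speculative config. A hypothetical usage sketch; the model name and the exact speculative_config keys follow upstream vLLM's SpeculativeConfig and are assumptions here.

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",   # any supported model, used only as an example
    speculative_config={
        "method": "suffix",              # routed to SuffixDecodingProposer
        "num_speculative_tokens": 4,
    },
)
outputs = llm.generate(["Suffix decoding reuses previously generated text as draft tokens."])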