Skip to content

Commit ac75e2d

Browse files
committed
add lightning_indexer and sparse_flash_attention
1 parent 9af3475 commit ac75e2d

28 files changed

+10127
-15
lines changed

csrc/build_aclnn.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
1111
exit 0
1212
elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
1313
# ASCEND910B (A2) series
14-
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list"
14+
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
1515
SOC_ARG="ascend910b"
1616
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
1717
# ASCEND910C (A3) series
18-
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list"
18+
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
1919
SOC_ARG="ascend910_93"
2020
else
2121
# others
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# This program is free software, you can redistribute it and/or modify it.
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
3+
# This file is a part of the CANN Open Software.
4+
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
5+
# Please refer to the License for details. You may not use this file except in compliance with the License.
6+
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
7+
# See LICENSE in the root of the software repository for the full text of the License.
8+
# ======================================================================================================================
9+
10+
# Kernel compile options applied to the LightningIndexer operator.
add_ops_compile_options(
    OP_NAME LightningIndexer
    OPTIONS
        --cce-auto-sync=off
        -Wno-deprecated-declarations
        -Werror
        -mllvm -cce-aicore-hoist-movemask=false
        --op_relocatable_kernel_binary=true
)

# Publish the operator's dependency path to the including (parent) scope.
set(lightning_indexer_depends transformer/attention/lightning_indexer PARENT_SCOPE)

# Operator definition goes into the aclnn host target.
target_sources(op_host_aclnn PRIVATE lightning_indexer_def.cpp)

# Tiling implementation goes into the tiling target.
target_sources(optiling PRIVATE lightning_indexer_tiling.cpp)

# The tiling source is additionally added to opmaster_ct when BUILD_OPEN_PROJECT is not set.
if (NOT BUILD_OPEN_PROJECT)
    target_sources(opmaster_ct PRIVATE lightning_indexer_tiling.cpp)
endif ()

# Let the tiling target find headers that live next to this file.
target_include_directories(optiling PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})

# Shape / data type inference goes into the proto target.
target_sources(opsproto PRIVATE lightning_indexer_proto.cpp)
42+
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/**
2+
* This program is free software, you can redistribute it and/or modify it.
3+
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
4+
* This file is a part of the CANN Open Software.
5+
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
6+
* Please refer to the License for details. You may not use this file except in compliance with the License.
7+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
8+
* See LICENSE in the root of the software repository for the full text of the License.
9+
*/
10+
11+
/*!
12+
* \file lightning_indexer_def.cpp
13+
* \brief
14+
*/
15+
#include <cstdint>
16+
#include "register/op_def_registry.h"
17+
18+
namespace ops {
19+
class LightningIndexer : public OpDef {
20+
public:
21+
explicit LightningIndexer(const char *name) : OpDef(name)
22+
{
23+
this->Input("query")
24+
.ParamType(REQUIRED)
25+
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
26+
.FormatList({ge::FORMAT_ND})
27+
.AutoContiguous();
28+
this->Input("key")
29+
.ParamType(REQUIRED)
30+
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
31+
.FormatList({ge::FORMAT_ND})
32+
.AutoContiguous();
33+
this->Input("weights")
34+
.ParamType(REQUIRED)
35+
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
36+
.FormatList({ge::FORMAT_ND})
37+
.AutoContiguous();
38+
this->Input("actual_seq_lengths_query")
39+
.ParamType(OPTIONAL)
40+
.DataType({ge::DT_INT32, ge::DT_INT32})
41+
.FormatList({ge::FORMAT_ND})
42+
.AutoContiguous();
43+
this->Input("actual_seq_lengths_key")
44+
.ParamType(OPTIONAL)
45+
.DataType({ge::DT_INT32, ge::DT_INT32})
46+
.FormatList({ge::FORMAT_ND})
47+
.AutoContiguous();
48+
this->Input("block_table")
49+
.ParamType(OPTIONAL)
50+
.DataTypeList({ge::DT_INT32})
51+
.FormatList({ge::FORMAT_ND})
52+
.AutoContiguous();
53+
this->Output("sparse_indices").ParamType(REQUIRED).DataTypeList({ge::DT_INT32}).FormatList({ge::FORMAT_ND});
54+
this->Attr("layout_query").AttrType(OPTIONAL).String("BSND");
55+
this->Attr("layout_key").AttrType(OPTIONAL).String("PA_BSND");
56+
this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048:默认值,筛选前2048
57+
this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3:默认值,只计算下三角
58+
OpAICoreConfig aicore_config;
59+
aicore_config.DynamicCompileStaticFlag(true)
60+
.DynamicFormatFlag(true)
61+
.DynamicRankSupportFlag(true)
62+
.DynamicShapeSupportFlag(true)
63+
.NeedCheckSupportFlag(false)
64+
.PrecisionReduceFlag(true)
65+
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
66+
.ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false");
67+
this->AICore().AddConfig("ascend910b", aicore_config);
68+
this->AICore().AddConfig("ascend910_93", aicore_config);
69+
}
70+
};
71+
OP_ADD(LightningIndexer);
72+
} // namespace ops
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/**
2+
* This program is free software, you can redistribute it and/or modify it.
3+
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
4+
* This file is a part of the CANN Open Software.
5+
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
6+
* Please refer to the License for details. You may not use this file except in compliance with the License.
7+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
8+
* See LICENSE in the root of the software repository for the full text of the License.
9+
*/
10+
11+
/*!
12+
* \file lightning_indexer_proto.cpp
13+
* \brief
14+
*/
15+
#include <graph/utils/type_utils.h>
16+
#include <register/op_impl_registry.h>
17+
#include "error/ops_error.h"
18+
19+
20+
using namespace ge;
21+
22+
namespace ops {
23+
// Input positions, in the order the inputs are declared by the operator definition.
constexpr uint32_t QUERY_INDEX = 0;
constexpr uint32_t KEY_INDEX = 1;
constexpr uint32_t ACTUAL_SEQ_K_INDEX = 4;

// Attribute positions, in the order the attributes are declared by the operator definition.
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX = 0;
constexpr uint32_t ATTR_KEY_LAYOUT_INDEX = 1;
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX = 2;
29+
30+
/*!
 * \brief Infers the shape of output 0 (sparse_indices) of LightningIndexer.
 *
 * The output has the same rank as the query:
 *   - layout_query == "BSND": [query dim 0, query dim 1, key dim 2, sparse_count]
 *   - layout_query == "TND" : [query dim 0, key N dim, sparse_count], where the key N dim is
 *     dim 2 when layout_key == "PA_BSND" and dim 1 otherwise.
 *
 * \param context Shape inference context supplying input shapes, attributes and the output shape.
 * \return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context, a required shape or
 *         attribute is missing, when layout_query is neither "TND" nor "BSND", or when the
 *         query / key rank does not fit the layout.
 */
static ge::graphStatus InferShapeLightningIndexer(gert::InferShapeContext *context)
{
    OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferShapeContext is nullptr!"),
               return ge::GRAPH_FAILED);

    const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX);
    OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED);
    const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX);
    OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED);
    gert::Shape *outShape = context->GetOutputShape(0);
    // The output shape is written below, so it needs the same null guard as the inputs.
    OPS_LOG_E_IF_NULL(context, outShape, return ge::GRAPH_FAILED);

    auto attrs = context->GetAttrs();
    OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED);
    const char *layoutQueryPtr = attrs->GetAttrPointer<char>(ATTR_QUERY_LAYOUT_INDEX);
    OPS_LOG_E_IF_NULL(context, layoutQueryPtr, return ge::GRAPH_FAILED);
    const char *layoutKeyPtr = attrs->GetAttrPointer<char>(ATTR_KEY_LAYOUT_INDEX);
    OPS_LOG_E_IF_NULL(context, layoutKeyPtr, return ge::GRAPH_FAILED);
    const int64_t *sparseCount = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX);
    OPS_LOG_E_IF_NULL(context, sparseCount, return ge::GRAPH_FAILED);

    const std::string layoutQuery(layoutQueryPtr);
    const std::string layoutKey(layoutKeyPtr);
    OPS_ERR_IF(
        layoutQuery != "TND" && layoutQuery != "BSND",
        OPS_LOG_E(context, "The attr layout_query should be TND or BSND, but got %s.", layoutQuery.c_str()),
        return ge::GRAPH_FAILED);

    outShape->SetDimNum(queryShape->GetDimNum());
    if (layoutQuery == "BSND") {
        OPS_ERR_IF(
            queryShape->GetDimNum() != 4,
            OPS_LOG_E(context, "Layout BSND, queryDims (%zu) must be 4!", queryShape->GetDimNum()),
            return ge::GRAPH_FAILED);
        const size_t keyNDimIndex = 2; // 2:Key Dim N
        // Reject a key whose rank is too small to contain the N dimension read below.
        OPS_ERR_IF(
            keyShape->GetDimNum() <= keyNDimIndex,
            OPS_LOG_E(context, "Layout BSND, keyDims (%zu) must be greater than %zu!", keyShape->GetDimNum(),
                      keyNDimIndex),
            return ge::GRAPH_FAILED);
        outShape->SetDim(0, queryShape->GetDim(0));          // 0:Dim B
        outShape->SetDim(1, queryShape->GetDim(1));          // 1:Dim S
        outShape->SetDim(2, keyShape->GetDim(keyNDimIndex)); // 2:Dim N
        outShape->SetDim(3, *sparseCount);                   // 3:Dim K
    } else {
        OPS_ERR_IF(
            queryShape->GetDimNum() != 3,
            OPS_LOG_E(context, "Layout TND, queryDims (%zu) must be 3!", queryShape->GetDimNum()),
            return ge::GRAPH_FAILED);
        const size_t keyNDimIndex = (layoutKey == "PA_BSND") ? 2 : 1; // 2:Key Dim N for PA_BSND, otherwise 1
        // Reject a key whose rank is too small to contain the N dimension read below.
        OPS_ERR_IF(
            keyShape->GetDimNum() <= keyNDimIndex,
            OPS_LOG_E(context, "Layout TND, keyDims (%zu) must be greater than %zu!", keyShape->GetDimNum(),
                      keyNDimIndex),
            return ge::GRAPH_FAILED);
        outShape->SetDim(0, queryShape->GetDim(0));          // 0:Dim T
        outShape->SetDim(1, keyShape->GetDim(keyNDimIndex)); // 1:Dim N
        outShape->SetDim(2, *sparseCount);                   // 2:Dim K
    }
    OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferShape end.");

    return ge::GRAPH_SUCCESS;
}
79+
80+
/*!
 * \brief Infers the data type of output 0 of LightningIndexer.
 *
 * Output 0 is unconditionally set to INT32; no input data type is consulted.
 *
 * \param context Data type inference context.
 * \return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context is null or the output
 *         data type cannot be set.
 */
static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext *context)
{
    OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferDataTypeContext is nullptr!"),
               return ge::GRAPH_FAILED);
    OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexer InferDataType impl.");
    // Output 0 is fixed to INT32.
    const ge::DataType outputType = ge::DT_INT32;
    // Propagate a failure instead of silently reporting success.
    OPS_ERR_IF(context->SetOutputDataType(0, outputType) != ge::GRAPH_SUCCESS,
               OPS_LOG_E(context->GetNodeName(), "Failed to set the data type of output 0."),
               return ge::GRAPH_FAILED);
    OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferDataType end.");
    return ge::GRAPH_SUCCESS;
}
92+
93+
// Register the shape and data type inference functions for the LightningIndexer operator.
IMPL_OP_INFERSHAPE(LightningIndexer)
    .InferShape(InferShapeLightningIndexer)
    .InferDataType(InferDataTypeLightningIndexer);
96+
} // namespace ops

0 commit comments

Comments
 (0)