|
| 1 | +/** |
| 2 | + * This program is free software, you can redistribute it and/or modify it. |
| 3 | + * Copyright (c) 2025 Huawei Technologies Co., Ltd. |
| 4 | + * This file is a part of the CANN Open Software. |
| 5 | + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). |
| 6 | + * Please refer to the License for details. You may not use this file except in compliance with the License. |
| 7 | + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. |
| 8 | + * See LICENSE in the root of the software repository for the full text of the License. |
| 9 | + */ |
| 10 | + |
| 11 | +/*! |
| 12 | + * \file lightning_indexer_proto.cpp |
| 13 | + * \brief Shape and data-type inference functions for the LightningIndexer operator and their registration. |
| 14 | + */ |
| 15 | +#include <graph/utils/type_utils.h> |
| 16 | +#include <register/op_impl_registry.h> |
| 17 | +#include "error/ops_error.h" |
| 18 | + |
| 19 | + |
| 20 | +using namespace ge; |
| 21 | + |
| 22 | +namespace ops { |
// Positions of the operator inputs as passed to GetInputShape().
constexpr uint32_t QUERY_INDEX{0};
constexpr uint32_t KEY_INDEX{1};
// NOTE(review): not referenced in this file; presumably the optional actual-seq-lengths input of key - confirm.
constexpr uint32_t ACTUAL_SEQ_K_INDEX{4};

// Positions of the operator attributes as passed to GetAttrPointer()/GetInt().
constexpr uint32_t ATTR_QUERY_LAYOUT_INDEX{0};
constexpr uint32_t ATTR_KEY_LAYOUT_INDEX{1};
constexpr uint32_t ATTR_SPARSE_COUNT_INDEX{2};
| 29 | + |
| 30 | +static ge::graphStatus InferShapeLightningIndexer(gert::InferShapeContext *context) |
| 31 | +{ |
| 32 | + OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferShapeContext is nullptr!"), |
| 33 | + return ge::GRAPH_FAILED); |
| 34 | + const gert::Shape *queryShape = context->GetInputShape(QUERY_INDEX); |
| 35 | + OPS_LOG_E_IF_NULL(context, queryShape, return ge::GRAPH_FAILED); |
| 36 | + const gert::Shape *keyShape = context->GetInputShape(KEY_INDEX); |
| 37 | + OPS_LOG_E_IF_NULL(context, keyShape, return ge::GRAPH_FAILED); |
| 38 | + gert::Shape *outShape = context->GetOutputShape(0); |
| 39 | + |
| 40 | + auto attrs = context->GetAttrs(); |
| 41 | + OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED); |
| 42 | + const char *inputLayoutQueryPtr = attrs->GetAttrPointer<char>(ATTR_QUERY_LAYOUT_INDEX); |
| 43 | + OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED); |
| 44 | + const char *inputLayoutKeyPtr = attrs->GetAttrPointer<char>(ATTR_KEY_LAYOUT_INDEX); |
| 45 | + OPS_LOG_E_IF_NULL(context, inputLayoutKeyPtr, return ge::GRAPH_FAILED); |
| 46 | + const int64_t *seleced_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX); |
| 47 | + OPS_LOG_E_IF_NULL(context, seleced_count, return ge::GRAPH_FAILED); |
| 48 | + std::string inputLayoutQueryPtrStr = std::string(inputLayoutQueryPtr); |
| 49 | + std::string inputLayoutKeyPtrStr = std::string(inputLayoutKeyPtr); |
| 50 | + OPS_ERR_IF( |
| 51 | + inputLayoutQueryPtrStr != "TND" && inputLayoutQueryPtrStr != "BSND", |
| 52 | + OPS_LOG_E(context, "The attr layout_query should be TND or BSND, but got %s.", inputLayoutQueryPtrStr.c_str()), |
| 53 | + return ge::GRAPH_FAILED); |
| 54 | + |
| 55 | + outShape->SetDimNum(queryShape->GetDimNum()); |
| 56 | + if (inputLayoutQueryPtrStr == "BSND") { |
| 57 | + OPS_ERR_IF( |
| 58 | + queryShape->GetDimNum() != 4, |
| 59 | + OPS_LOG_E(context, "Layout BSND, queryDims (%zu) must be 4!", queryShape->GetDimNum()), |
| 60 | + return ge::GRAPH_FAILED); |
| 61 | + outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim B |
| 62 | + outShape->SetDim(1, queryShape->GetDim(1)); // 1:Dim S |
| 63 | + outShape->SetDim(2, keyShape->GetDim(2)); // 2:Dim N |
| 64 | + outShape->SetDim(3, *seleced_count); // 3:Dim K |
| 65 | + } else { |
| 66 | + OPS_ERR_IF( |
| 67 | + queryShape->GetDimNum() != 3, |
| 68 | + OPS_LOG_E(context, "Layout TND, queryDims (%zu) must be 3!", queryShape->GetDimNum()), |
| 69 | + return ge::GRAPH_FAILED); |
| 70 | + outShape->SetDim(0, queryShape->GetDim(0)); // 0:Dim T |
| 71 | + int32_t nDimIndex = (inputLayoutKeyPtrStr == "PA_BSND") ? 2 : 1; // 2:Key Dim N |
| 72 | + outShape->SetDim(1, keyShape->GetDim(nDimIndex)); // 1:Dim N |
| 73 | + outShape->SetDim(2, *seleced_count); // 2:Dim K |
| 74 | + } |
| 75 | + OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferShape end."); |
| 76 | + |
| 77 | + return ge::GRAPH_SUCCESS; |
| 78 | +} |
| 79 | + |
| 80 | +static ge::graphStatus InferDataTypeLightningIndexer(gert::InferDataTypeContext *context) |
| 81 | +{ |
| 82 | + OPS_ERR_IF(context == nullptr, OPS_LOG_E("LightningIndexer", "InferDataTypeContext is nullptr!"), |
| 83 | + return ge::GRAPH_FAILED); |
| 84 | + OPS_LOG_D(context->GetNodeName(), "Enter LightningIndexer InferDataType impl."); |
| 85 | + // default set q's dtype as fia's output type |
| 86 | + ge::DataType outputType = ge::DT_INT32; |
| 87 | + // attention_out, outidx:0 |
| 88 | + context->SetOutputDataType(0, outputType); |
| 89 | + OPS_LOG_D(context->GetNodeName(), "LightningIndexer InferDataType end."); |
| 90 | + return GRAPH_SUCCESS; |
| 91 | +} |
| 92 | + |
| 93 | +IMPL_OP_INFERSHAPE(LightningIndexer) |
| 94 | + .InferShape(InferShapeLightningIndexer) |
| 95 | + .InferDataType(InferDataTypeLightningIndexer); |
| 96 | +} // namespace ops |
0 commit comments