Skip to content

Commit c68ddc1

Browse files
authored
[OPS] add bmm_transpose ops (#3990)
### What this PR does / why we need it? Add a new fusion op to custom_op, which can combine torch.bmm() and transpose to achieve better performance. This op is used in mla_v1 to replace the bmm and transpose ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - vLLM version: v0.11.2 --------- Signed-off-by: hust17yixuan <303660421@qq.com>
1 parent bc67696 commit c68ddc1

File tree

15 files changed

+1737
-14
lines changed

15 files changed

+1737
-14
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ repos:
1313
args: [
1414
--toml, pyproject.toml,
1515
'--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
16-
'-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND'
16+
'-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND,ND'
1717
]
1818
additional_dependencies:
1919
- tomli

CMakeLists.txt

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,36 @@ include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
5555
file(GLOB KERNEL_FILES
5656
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp)
5757

58-
ascendc_library(vllm_ascend_kernels SHARED
58+
set(VLLM_ASCEND_CUSTOM_OP
5959
${KERNEL_FILES}
6060
${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
61+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp
62+
)
63+
64+
set(VLLM_ASCEND_CUSTOM_OP_EXCLUDE
65+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp
66+
)
67+
68+
if(SOC_VERSION STREQUAL "ASCEND310P3")
69+
list(REMOVE_ITEM VLLM_ASCEND_CUSTOM_OP ${VLLM_ASCEND_CUSTOM_OP_EXCLUDE})
70+
endif()
71+
72+
ascendc_library(vllm_ascend_kernels SHARED
73+
${VLLM_ASCEND_CUSTOM_OP}
6174
)
6275

6376
message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
6477

65-
file(GLOB VLLM_ASCEND_SRC
66-
${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp
67-
${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp)
78+
if(SOC_VERSION STREQUAL "ASCEND310P3")
79+
file(GLOB VLLM_ASCEND_SRC
80+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp
81+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp)
82+
else()
83+
file(GLOB VLLM_ASCEND_SRC
84+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp
85+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/aclnn_torch_adapter/*.cpp
86+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp)
87+
endif()
6888

6989
include_directories(
7090
${pybind11_INCLUDE_DIRS}
@@ -74,6 +94,7 @@ include_directories(
7494
${ASCEND_HOME_PATH}/include
7595
${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform
7696
${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform
97+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_host
7798
)
7899

79100
set(
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#include <iostream>
2+
#include <string>
3+
#include "acl/acl.h"
4+
#include "kernel_tiling/kernel_tiling.h"
5+
#include "tiling/platform/platform_ascendc.h"
6+
#include "tiling/tiling_data.h"
7+
#include "common_tiling.h"
8+
9+
10+
namespace bmm_trans {
11+
using namespace pp_matmul;
12+
13+
std::unordered_map<c10::string_view, uint16_t> quantModeMap = {
14+
{"per_channel_symm", 0},
15+
{"per_channel_asymm", 1},
16+
{"per_token_symm", 2},
17+
};
18+
19+
std::unordered_map<c10::string_view, uint16_t> formatModeMap = {
20+
{"ND", 0},
21+
{"NZ", 1},
22+
};
23+
24+
std::unordered_map<c10::ScalarType, TensorDType> atType2tensorDType = {
25+
{at::ScalarType::BFloat16, TensorDType::TENSOR_DTYPE_BF16},
26+
{at::ScalarType::Half, TensorDType::TENSOR_DTYPE_FLOAT16}};
27+
28+
// batch size -> memory index
29+
constexpr uint32_t MAX_CAPTURE_NUM = 1024;
30+
31+
template <typename MapType>
32+
inline int GetModeVal(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
33+
const char *mode_name)
34+
{
35+
std::string modeStr(mode_name);
36+
c10::string_view mode_str = mode_opt.value_or(default_mode);
37+
auto it = mode_map.find(mode_str);
38+
// if input mode is unsupported, use default value
39+
TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str));
40+
return it->second;
41+
}
42+
43+
std::tuple<at::Tensor, uint32_t> batch_matmul_transpose_tiling(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
44+
c10::optional<c10::string_view> format_mode,
45+
c10::optional<c10::string_view> quant_mode)
46+
{
47+
auto tensorAShape = tensor_a.sizes();
48+
auto tensorBShape = tensor_b.sizes();
49+
auto tensorCShape = tensor_c.sizes();
50+
uint32_t n;
51+
uint32_t block_dim;
52+
53+
//auto &platform = PlatformInfo::Instance();
54+
HardwareInfo hwInfo;
55+
std::map<c10::ScalarType, float> dTypeMap = {{at::ScalarType::Half, 2.0}, {at::ScalarType::BFloat16, 2.0}};
56+
57+
at::ScalarType aType = tensor_a.scalar_type();
58+
at::ScalarType bType = tensor_b.scalar_type();
59+
at::ScalarType cType = tensor_c.scalar_type();
60+
TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same");
61+
TORCH_CHECK((aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half),
62+
"tensor type only support half or bf16");
63+
64+
TensorFormat formatMode = static_cast<TensorFormat>(GetModeVal(formatModeMap, format_mode, "ND", "format_mode"));
65+
MatMul::QuantMode quantMode =
66+
static_cast<MatMul::QuantMode>(GetModeVal(quantModeMap, quant_mode, "per_channel_symm", "quant_mode"));
67+
68+
TORCH_CHECK(tensorAShape.size() == 3, "batch size is not same between srcTensor and dstTensor");
69+
if (formatMode == TensorFormat::TENSOR_FORMAT_ND) {
70+
TORCH_CHECK(tensorBShape.size() == 3, "tensor shape should be dim3 in ND format");
71+
TORCH_CHECK(tensorAShape[2] == tensorBShape[1], "tensor shape is wrong");
72+
n = tensorBShape[2];
73+
} else {
74+
TORCH_CHECK(tensorBShape.size() == 4, "tensor shape should be dim4 in nz format");
75+
TORCH_CHECK(tensorAShape[2] == tensorBShape[2], "tensor shape is wrong");
76+
n = tensorBShape[1] * tensorBShape[3];
77+
}
78+
TORCH_CHECK(tensorAShape[1] == tensorBShape[0], "tensor shape is wrong");
79+
80+
OpShape opShape = {.batchSize = static_cast<uint32_t>(tensorAShape[1]),
81+
.m = static_cast<uint32_t>(tensorAShape[0]),
82+
.k = static_cast<uint32_t>(tensorAShape[2]),
83+
.n = n};
84+
pp_matmul::PpMatmulTilingData matmulTilingData = {
85+
.opShape = opShape,
86+
};
87+
auto dType = atType2tensorDType[aType];
88+
MatMulInfo mmInfo = {.batchSize = opShape.batchSize,
89+
.m = opShape.m,
90+
.k = opShape.k,
91+
.n = opShape.n,
92+
.dtypeA = dType,
93+
.dtypeB = dType,
94+
.dtypeC = dType,
95+
.formatB = formatMode,
96+
.mmType = MatMul::MatMulType::MATMUL_EIN_SUM,
97+
.inDtype = dTypeMap[aType],
98+
.outDtype = dTypeMap[cType],
99+
.quantMode = quantMode};
100+
GetPpMatmulTiling(mmInfo, hwInfo, block_dim, matmulTilingData);
101+
host_utils::PpMatmulTilingCheck(matmulTilingData);
102+
103+
// tiling
104+
int32_t batchIdx = opShape.m - 1;
105+
uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData);
106+
static auto global_tiling_data = at::empty(
107+
{tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));
108+
if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) {
109+
aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, &matmulTilingData,
110+
tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
111+
} else {
112+
// Handle the case where batchIdx is out of range
113+
TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx);
114+
}
115+
at::Tensor tiling_tensor =
116+
at::from_blob(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, at::kByte);
117+
118+
return std::make_tuple(tiling_tensor, block_dim);
119+
120+
}
121+
122+
}
123+
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
// Licensed under the BSD 3-Clause License (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// Unless required by applicable law or agreed to in writing, software
7+
// distributed under the License is distributed on an "AS IS" BASIS,
8+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
// See the License for the specific language governing permissions and
10+
// limitations under the License.
11+
12+
#ifndef UTILS_COMMON_H
#define UTILS_COMMON_H

// Required for the fixed-width types and UINT32_MAX used below; previously
// this header relied on its includers to provide them.
#include <cstdint>

namespace host_utils {

// Block alignment granularity (elements per 32-byte block) for int64/int32.
constexpr uint32_t BLK_SIZE_ALIN_FOR_INT64 = 4;
constexpr uint32_t BLK_SIZE_ALIN_FOR_INT32 = 8;

// Round `count` up to the next multiple of the int64 block granularity (4).
inline uint64_t alinInt64Count(uint64_t count)
{
    return (count + BLK_SIZE_ALIN_FOR_INT64 - 1) / BLK_SIZE_ALIN_FOR_INT64 * BLK_SIZE_ALIN_FOR_INT64;
}

// Round `count` up to the next multiple of the int32 block granularity (8).
inline uint64_t alinInt32Count(uint64_t count)
{
    return (count + BLK_SIZE_ALIN_FOR_INT32 - 1) / BLK_SIZE_ALIN_FOR_INT32 * BLK_SIZE_ALIN_FOR_INT32;
}

// Ceiling division. A zero divisor is treated as invalid input and yields
// UINT32_MAX as a sentinel.
// NOTE(review): the sentinel is UINT32_MAX even when T is 64-bit — confirm
// callers never compare against std::numeric_limits<T>::max() instead.
template <typename T>
inline T CeilDiv(const T dividend, const T divisor)
{
    if (divisor == 0) {
        return UINT32_MAX;
    }
    return (dividend + divisor - 1) / divisor;
}

// Round `val` up to a multiple of `align` (default 16).
// Returns 0 when `align` is 0 or when `val + align - 1` would overflow T.
template <typename T>
inline T RoundUp(const T val, const T align = 16)
{
    if (align == 0 || val + align - 1 < val) {
        return 0;
    }
    return (val + align - 1) / align * align;
}

// Round `val` down to a multiple of `align` (default 16).
// Returns 0 when `align` is 0.
template <typename T>
inline T RoundDown(const T val, const T align = 16)
{
    if (align == 0) {
        return 0;
    }
    return val / align * align;
}
} // namespace host_utils
#endif // UTILS_COMMON_H

0 commit comments

Comments
 (0)