|
| 1 | +#include "../../../devices/ascend/ascend_kernel_common.h" |
| 2 | + |
| 3 | +using namespace AscendC; |
| 4 | + |
| 5 | +template <typename T> |
| 6 | +class SwigluKernel { |
| 7 | +public: |
| 8 | + __aicore__ inline SwigluKernel() {} |
| 9 | + __aicore__ inline void init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd, |
| 10 | + int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b, |
| 11 | + int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b); |
| 12 | + __aicore__ inline void process(); |
| 13 | + |
| 14 | +private: |
| 15 | + __aicore__ inline void copyIn(int64_t i); |
| 16 | + __aicore__ inline void compute(int64_t i); |
| 17 | + __aicore__ inline void copyOut(int64_t i); |
| 18 | + |
| 19 | +private: |
| 20 | + GlobalTensor<T> _c_gm, _a_gm, _b_gm; |
| 21 | + TQue<QuePosition::VECIN, BUFFER_NUM> _in_queue_a, _in_queue_b; |
| 22 | + TQue<QuePosition::VECOUT, BUFFER_NUM> _out_queue_c; |
| 23 | + |
| 24 | + TPipe _pipe; |
| 25 | + float _beta_value = 1.0f; |
| 26 | + int64_t _block_idx, _tile_len, _copy_len, |
| 27 | + _batch, _seq_len, _hidden_size, |
| 28 | + _stride_seq_a, _stride_seq_b, _stride_seq_c; |
| 29 | + int64_t _stride_batch_a = 1, _stride_batch_b = 1, _stride_batch_c = 1; |
| 30 | +}; |
| 31 | + |
| 32 | +template <typename T> |
| 33 | +__aicore__ inline void SwigluKernel<T>::init(GM_ADDR c, GM_ADDR a, GM_ADDR b, int64_t batch_, int64_t seq, int64_t hd, |
| 34 | + int64_t stride_batch_c, int64_t stride_batch_a, int64_t stride_batch_b, |
| 35 | + int64_t stride_seq_c, int64_t stride_seq_a, int64_t stride_seq_b) { |
| 36 | + // Init Shape & StrideVariables |
| 37 | + _batch = batch_; |
| 38 | + _seq_len = seq; |
| 39 | + _hidden_size = hd; |
| 40 | + _stride_batch_a = stride_batch_a; |
| 41 | + _stride_batch_b = stride_batch_b; |
| 42 | + _stride_batch_c = stride_batch_c; |
| 43 | + _stride_seq_a = stride_seq_a; |
| 44 | + _stride_seq_b = stride_seq_b; |
| 45 | + _stride_seq_c = stride_seq_c; |
| 46 | + |
| 47 | + _block_idx = GetBlockIdx(); |
| 48 | + _tile_len = _block_idx < (_hidden_size % BLOCK_NUM) ? (_hidden_size / BLOCK_NUM) + 1 : (_hidden_size / BLOCK_NUM); |
| 49 | + _copy_len = (_tile_len * sizeof(T)) % BYTE_ALIGN == 0 ? _tile_len : (_tile_len * sizeof(T) + (BYTE_ALIGN - _tile_len * sizeof(T) % BYTE_ALIGN)) / sizeof(T); |
| 50 | + |
| 51 | + // Set global tensor |
| 52 | + _a_gm.SetGlobalBuffer((__gm__ T *)a); |
| 53 | + _b_gm.SetGlobalBuffer((__gm__ T *)b); |
| 54 | + _c_gm.SetGlobalBuffer((__gm__ T *)c); |
| 55 | + |
| 56 | + // _pipe alloc memory to queue, the unit is bytes |
| 57 | + _pipe.InitBuffer(_in_queue_a, BUFFER_NUM, _copy_len * sizeof(T)); |
| 58 | + _pipe.InitBuffer(_in_queue_b, BUFFER_NUM, _copy_len * sizeof(T)); |
| 59 | + _pipe.InitBuffer(_out_queue_c, BUFFER_NUM, _copy_len * sizeof(T)); |
| 60 | +} |
| 61 | + |
| 62 | +template <typename T> |
| 63 | +__aicore__ inline void SwigluKernel<T>::copyIn(int64_t i) { |
| 64 | + // Alloc tensor from queue memory |
| 65 | + LocalTensor<T> aLocal = _in_queue_a.AllocTensor<T>(); |
| 66 | + LocalTensor<T> bLocal = _in_queue_b.AllocTensor<T>(); |
| 67 | + // Get idx of current tile |
| 68 | + auto batch_idx = _batch == 1 ? 0 : i / _seq_len; |
| 69 | + auto seq_idx = _batch == 1 ? i : i % _seq_len; |
| 70 | + |
| 71 | + int64_t idxa = batch_idx * _stride_batch_a + seq_idx * _stride_seq_a + _block_idx * _tile_len; |
| 72 | + int64_t idxb = batch_idx * _stride_batch_b + seq_idx * _stride_seq_b + _block_idx * _tile_len; |
| 73 | + // Copy process_th tile from global tensor to local tensor |
| 74 | + DataCopy(aLocal, _a_gm[idxa], _copy_len); |
| 75 | + DataCopy(bLocal, _b_gm[idxb], _copy_len); |
| 76 | + |
| 77 | + // Enque input tensor to VECIN queue |
| 78 | + _in_queue_a.EnQue(aLocal); |
| 79 | + _in_queue_b.EnQue(bLocal); |
| 80 | +} |
| 81 | + |
| 82 | +template <typename T> |
| 83 | +__aicore__ inline void SwigluKernel<T>::compute(int64_t i) { |
| 84 | + // Deque input tensors from VECIN queue |
| 85 | + LocalTensor<T> aLocal = _in_queue_a.DeQue<T>(); |
| 86 | + LocalTensor<T> bLocal = _in_queue_b.DeQue<T>(); |
| 87 | + LocalTensor<T> cLocal = _out_queue_c.AllocTensor<T>(); |
| 88 | + // Call SwiGLU ascend api |
| 89 | + SwiGLU<T, false>(cLocal, aLocal, bLocal, _beta_value, _copy_len); |
| 90 | + // Enque result and free input |
| 91 | + _out_queue_c.EnQue<T>(cLocal); |
| 92 | + _in_queue_a.FreeTensor(aLocal); |
| 93 | + _in_queue_b.FreeTensor(bLocal); |
| 94 | +} |
| 95 | + |
| 96 | +template <typename T> |
| 97 | +__aicore__ inline void SwigluKernel<T>::copyOut(int64_t i) { |
| 98 | + // Deque output tensor from VECOUT queue |
| 99 | + LocalTensor<T> cLocal = _out_queue_c.DeQue<T>(); |
| 100 | + auto batch_idx = _batch == 1 ? 0 : i / _seq_len; |
| 101 | + auto seq_idx = _batch == 1 ? i : i % _seq_len; |
| 102 | + int64_t idxc = batch_idx * _stride_batch_c + seq_idx * _stride_seq_c + _block_idx * _tile_len; |
| 103 | + // Copy progress_th tile from local tensor to global tensor |
| 104 | + if (_tile_len * sizeof(T) % BYTE_ALIGN != 0) { |
| 105 | + DataCopyExtParams dcep = {1, static_cast<uint32_t>(_tile_len * sizeof(T)), 0, 0, 0}; |
| 106 | + DataCopyPad(_c_gm[idxc], cLocal, dcep); |
| 107 | + } else { |
| 108 | + DataCopy(_c_gm[idxc], cLocal, _tile_len); |
| 109 | + } |
| 110 | + // Free output Local tensor |
| 111 | + _out_queue_c.FreeTensor(cLocal); |
| 112 | +} |
| 113 | + |
| 114 | +template <typename T> |
| 115 | +__aicore__ inline void SwigluKernel<T>::process() { |
| 116 | + for (int64_t i = 0; i < _batch * _seq_len; ++i) { |
| 117 | + copyIn(i); |
| 118 | + compute(i); |
| 119 | + copyOut(i); |
| 120 | + } |
| 121 | +} |
| 122 | + |
| 123 | +#define DEFINE_SWIGLU_KERNEL(KERNEL_NAME, TYPE) \ |
| 124 | + __global__ __aicore__ void KERNEL_NAME(GM_ADDR c, GM_ADDR a, GM_ADDR b, \ |
| 125 | + int64_t batch, int64_t seq, int64_t hd, \ |
| 126 | + int64_t stride_batch_c, \ |
| 127 | + int64_t stride_batch_a, \ |
| 128 | + int64_t stride_batch_b, \ |
| 129 | + int64_t stride_seq_c, \ |
| 130 | + int64_t stride_seq_a, \ |
| 131 | + int64_t stride_seq_b) { \ |
| 132 | + SwigluKernel<TYPE> op; \ |
| 133 | + op.init(c, a, b, \ |
| 134 | + batch, seq, hd, \ |
| 135 | + stride_batch_c, stride_batch_a, stride_batch_b, \ |
| 136 | + stride_seq_c, stride_seq_a, stride_seq_b); \ |
| 137 | + op.process(); \ |
| 138 | + } |
| 139 | + |
| 140 | +DEFINE_SWIGLU_KERNEL(swiglu_kernel_half, half) |
| 141 | +DEFINE_SWIGLU_KERNEL(swiglu_kernel_float, float) |
| 142 | + |
| 143 | +#undef DEFINE_SWIGLU_KERNEL |
| 144 | + |
| 145 | +extern "C" infiniStatus_t swiglu_kernel_launch( |
| 146 | + void *c, void *a, void *b, |
| 147 | + infiniDtype_t dtype, size_t batch, size_t seq, size_t hd, |
| 148 | + ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b, |
| 149 | + ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream) { |
| 150 | + |
| 151 | +#define LAUNCH_SWIGLU_KERNEL(DTYPE_ENUM, KERNEL_NAME) \ |
| 152 | + case DTYPE_ENUM: \ |
| 153 | + KERNEL_NAME<<<BLOCK_NUM, nullptr, stream>>>( \ |
| 154 | + c, a, b, \ |
| 155 | + static_cast<int64_t>(batch), \ |
| 156 | + static_cast<int64_t>(seq), \ |
| 157 | + static_cast<int64_t>(hd), \ |
| 158 | + stride_batch_c, stride_batch_a, stride_batch_b, \ |
| 159 | + stride_seq_c, stride_seq_a, stride_seq_b); \ |
| 160 | + return INFINI_STATUS_SUCCESS; |
| 161 | + |
| 162 | + switch (dtype) { |
| 163 | + LAUNCH_SWIGLU_KERNEL(INFINI_DTYPE_F16, swiglu_kernel_half) |
| 164 | + LAUNCH_SWIGLU_KERNEL(INFINI_DTYPE_F32, swiglu_kernel_float) |
| 165 | + default: |
| 166 | + return INFINI_STATUS_BAD_TENSOR_DTYPE; |
| 167 | + } |
| 168 | + |
| 169 | +#undef LAUNCH_SWIGLU_KERNEL |
| 170 | +} |
0 commit comments