@@ -47,9 +47,9 @@ struct MaxOp<half> {
     }
 };

-template <typename T, template <typename> class ReduceOp, int thread_group_width = 32>
+template <typename T, template <typename> class ReduceOp, int THREAD_GROUP_WIDTH = 32>
 __device__ __forceinline__ T warpReduce(T value) {
-    for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
+    for (int mask = THREAD_GROUP_WIDTH / 2; mask > 0; mask /= 2) {
         value = ReduceOp<T>()(value, __shfl_xor_sync(0xffffffff, value, mask));
     }
     return value;
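The XOR-shuffle loop above is a butterfly reduction: each round halves the exchange distance, and after log2(width) rounds every lane in the group holds the reduced value. A minimal usage sketch, assuming the `MaxOp` functor template defined earlier in this file (the `demo_warp_max` kernel itself is illustrative, not part of this PR):

```cuda
// Illustrative only: one warp of 32 lanes reduces 32 floats to their maximum.
__global__ void demo_warp_max(const float *in, float *out) {
    float v = in[threadIdx.x];        // one value per lane
    v = warpReduce<float, MaxOp>(v);  // after the butterfly, every lane holds the max
    if (threadIdx.x == 0) {
        *out = v;                     // any lane could write; lane 0 by convention
    }
}
```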
@@ -73,23 +73,23 @@ __device__ __forceinline__ T warpReduce(T value) {
 That is, we also need i and j:
 i is (blockIdx.x * blockDim.y + threadIdx.y) / stride
 j is (blockIdx.x * blockDim.y + threadIdx.y) % stride
-Converting i to a linear offset then gives i * stride * dimsize,
+Converting i to a linear offset then gives i * stride * dim_size,
 and j is simply added on top.
 */
-template <int elemPerThread, int BLOCK_DIM_Y, int BLOCK_DIM_X, typename T>
-__global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dimsize, int otherdim_size) {
-    float dataPerThread[elemPerThread];
+template <int ELEM_PER_THREAD, int BLOCK_DIM_Y, int BLOCK_DIM_X, typename T>
+__global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dim_size, int other_size) {
+    float dataPerThread[ELEM_PER_THREAD];
     int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y;
-    int group_offset = global_warp_id % stride + (global_warp_id - global_warp_id % stride) * dimsize;
+    int group_offset = global_warp_id % stride + (global_warp_id - global_warp_id % stride) * dim_size;
     int tid = threadIdx.x;
-    if (global_warp_id >= otherdim_size) {
+    if (global_warp_id >= other_size) {
         return;
     }
     __shared__ float group_max[BLOCK_DIM_X];
     __shared__ float group_sum[BLOCK_DIM_X];
     float thread_max = -INFINITY;
     float thread_sum = 0.0f;
-    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
         dataPerThread[i] = static_cast<float>(x[(tid + i * BLOCK_DIM_X) * stride + group_offset]);
         thread_max = max(thread_max, dataPerThread[i]);
     }
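The index math in the comment above is easiest to verify with concrete numbers. A hedged host-side check, assuming an illustrative shape of [2, 4, 3] with softmax over the middle axis (so dim_size = 4, stride = 3, other_size = 6):

```cpp
#include <cassert>

int main() {
    int stride = 3, dim_size = 4;  // shape [2, 4, 3], softmax over axis 1
    int gid = 5;                   // example global_warp_id; i = 5/3 = 1, j = 5%3 = 2
    int group_offset = gid % stride + (gid - gid % stride) * dim_size;
    // Same value as the comment's i * stride * dim_size + j: 1*3*4 + 2 = 14.
    assert(group_offset == (gid / stride) * stride * dim_size + gid % stride);
    // Element k of this row sits at k * stride + group_offset: 14, 17, 20, 23,
    // matching the row-major linear index (i * dim_size + k) * stride + j.
    for (int k = 0; k < dim_size; ++k) {
        assert(k * stride + group_offset
               == ((gid / stride) * dim_size + k) * stride + gid % stride);
    }
    return 0;
}
```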
@@ -99,7 +99,7 @@ __global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dimsize, int
         group_max[threadIdx.y] = thread_max;
     }

-    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
         dataPerThread[i] = __expf(dataPerThread[i] - group_max[threadIdx.y]);
         thread_sum += dataPerThread[i];
     }
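Subtracting the row max before exponentiating is what makes this a safe softmax: softmax(x) == softmax(x - c) for any constant c, and choosing c = max(x) keeps every `__expf` argument at or below zero, so each term stays in (0, 1] instead of overflowing (expf of a logit around 89 is already inf in float32). A minimal sketch of the shifted term:

```cuda
// Sketch only: v - row_max <= 0 for every v in the row, so the
// exponential below is bounded by 1 and can never overflow.
__device__ __forceinline__ float safe_exp_term(float v, float row_max) {
    return __expf(v - row_max); // in (0, 1]
}
```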
@@ -109,18 +109,18 @@ __global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dimsize, int
         group_sum[threadIdx.y] = thread_sum;
     }

-    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
         y[(tid + i * BLOCK_DIM_X) * stride + group_offset] = static_cast<T>(dataPerThread[i] * __fdividef(1.0f, group_sum[threadIdx.y]));
     }
 }

-template <int elemPerThread, int BLOCK_DIM, typename T>
+template <int ELEM_PER_THREAD, int BLOCK_DIM, typename T>
 __launch_bounds__(BLOCK_DIM)
-__global__ void Softmax_block_impl(const T *x, T *y, int stride, int dimsize, int otherdim_size) {
-    // remain = dimsize - BLOCK_DIM * elemPerThread
+__global__ void Softmax_block_impl(const T *x, T *y, int stride, int dim_size, int other_size) {
+    // remain = dim_size - (BLOCK_DIM - 1) * ELEM_PER_THREAD
     int tid = threadIdx.x;
-    int block_offset = (blockIdx.x - blockIdx.x % stride) * dimsize + blockIdx.x % stride;
-    int remain = dimsize - (BLOCK_DIM - 1) * elemPerThread; // 🔧 fix: number of elements the last thread handles
+    int block_offset = (blockIdx.x - blockIdx.x % stride) * dim_size + blockIdx.x % stride;
+    int remain = dim_size - (BLOCK_DIM - 1) * ELEM_PER_THREAD;

     MD md_partial;
     md_partial.max = -INFINITY;
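`MD` carries a running (max, sum) pair, the online-softmax trick: when two partials merge, each sum is rescaled to the joint maximum so the invariant sum == Σ exp(x_k − max) is preserved across the `reduce_for_md` calls below. `MD` and `reduce_for_md` are defined elsewhere in this file; the sketch here shows the usual formulation and is an assumption about their exact shape:

```cuda
// Hedged sketch of the merge step; names are illustrative stand-ins
// for the real MD / reduce_for_md defined elsewhere in this file.
struct MD_sketch {
    float max;
    float sum;
};

__device__ __forceinline__ MD_sketch merge_md(MD_sketch a, MD_sketch b) {
    MD_sketch out;
    out.max = fmaxf(a.max, b.max);
    // Rescale each partial sum to the joint max before adding.
    out.sum = a.sum * __expf(a.max - out.max) + b.sum * __expf(b.max - out.max);
    return out;
}
```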
@@ -129,16 +129,16 @@ __launch_bounds__(BLOCK_DIM)
     // tid = [0, BLOCK_DIM - 1], so the last thread handles the remainder
     if (tid < BLOCK_DIM - 1) {
 #pragma unroll
-        for (int i = 0; i < elemPerThread; i++) {
-            int index = (tid * elemPerThread + i) * stride + block_offset;
+        for (int i = 0; i < ELEM_PER_THREAD; i++) {
+            int index = (tid * ELEM_PER_THREAD + i) * stride + block_offset;
             input.max = static_cast<float>(x[index]);
             input.sum = 1.0f;
             md_partial = reduce_for_md(md_partial, input);
         }
     } else {
 #pragma unroll
         for (int i = 0; i < remain; i++) {
-            int index = ((BLOCK_DIM - 1) * elemPerThread + i) * stride + block_offset;
+            int index = ((BLOCK_DIM - 1) * ELEM_PER_THREAD + i) * stride + block_offset;
             input.max = static_cast<float>(x[index]);
             input.sum = 1.0f;
             md_partial = reduce_for_md(md_partial, input);
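The split gives threads 0 through BLOCK_DIM − 2 exactly ELEM_PER_THREAD elements each and leaves the tail to the last thread. A compile-time sanity check with illustrative numbers (not part of the PR):

```cpp
// Illustrative: BLOCK_DIM = 1024, ELEM_PER_THREAD = 2, dim_size = 2047.
// Threads 0..1022 cover 1023 * 2 = 2046 elements; the last thread takes the rest.
static_assert(2047 - (1024 - 1) * 2 == 1, "remain: the last thread covers the tail");
```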
@@ -153,85 +153,94 @@ __launch_bounds__(BLOCK_DIM)
     }
     __syncthreads();
     if (tid < BLOCK_DIM - 1) {
-        for (int i = 0; i < elemPerThread; i++) {
-            int index = (tid * elemPerThread + i) * stride + block_offset;
+        for (int i = 0; i < ELEM_PER_THREAD; i++) {
+            int index = (tid * ELEM_PER_THREAD + i) * stride + block_offset;
             y[index] = static_cast<T>(__expf(static_cast<float>(x[index]) - md_total.max) * __fdividef(1.0f, md_total.sum));
         }
     } else {
         for (int i = 0; i < remain; i++) {
-            int index = ((BLOCK_DIM - 1) * elemPerThread + i) * stride + block_offset;
+            int index = ((BLOCK_DIM - 1) * ELEM_PER_THREAD + i) * stride + block_offset;
             y[index] = static_cast<T>(__expf(static_cast<float>(x[index]) - md_total.max) * __fdividef(1.0f, md_total.sum));
         }
     }
 }

 template <typename T>
-infiniStatus_t softmax_dispatch(const op::softmax::SoftmaxInfo &info, void *y, const void *x, void *stream) {
-    int dimsize = info.dimsize;
-    int stride = info.stride;
-    int otherdim_size = info.otherdim_size;
-    if (dimsize <= 1024) {
-        dim3 block(32, 32); // BLOCK_DIM_X=32, BLOCK_DIM_Y=4
-        int num_blocks = (otherdim_size + block.y - 1) / block.y;
-        dim3 grid(num_blocks, 1, 1);
-        int elemPerThread = (dimsize + 31) / 32; // number of elements each thread processes
-        elemPerThread = min(elemPerThread, 32);  // cap the maximum
+void dispatchSoftmaxKernel(
+    const void *x, void *y,
+    int stride, int dim_size, int other_size,
+    void *stream, bool use_warp_impl) {
+
+    int elemPerThread;
+    dim3 grid, block;
+
+    if (use_warp_impl) {
+        block = dim3(32, 32);
+        grid = dim3((other_size + block.y - 1) / block.y, 1, 1);
+        elemPerThread = min((dim_size + 31) / 32, 32);
+
+#define LAUNCH_WARP_KERNEL(ELEM_PER_THREAD)                           \
+    Softmax_warp_impl<ELEM_PER_THREAD, 32, 32, T>                     \
+        <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>( \
+            reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), \
+            stride, dim_size, other_size)
+
         if (elemPerThread <= 1) {
-            Softmax_warp_impl<1, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(1);
         } else if (elemPerThread <= 2) {
-            Softmax_warp_impl<2, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(2);
         } else if (elemPerThread <= 4) {
-            Softmax_warp_impl<4, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(4);
         } else if (elemPerThread <= 8) {
-            Softmax_warp_impl<8, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(8);
         } else if (elemPerThread <= 16) {
-            Softmax_warp_impl<16, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(16);
         } else {
-            Softmax_warp_impl<32, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(32);
         }
-    } else if (dimsize > 1024) {
-        int block_size = 1024;
-        int elemPerThread = (dimsize + block_size - 1) / block_size; // number of elements each thread processes
-        elemPerThread = min(elemPerThread, 32); // cap at 32
-        dim3 block(block_size);
-        dim3 grid(otherdim_size);
+
+#undef LAUNCH_WARP_KERNEL
+
+    } else {
+        // Block implementation for dim_size > 1024
+        constexpr int BLOCK_SIZE = 1024;
+        block = dim3(BLOCK_SIZE);
+        grid = dim3(other_size);
+        elemPerThread = min((dim_size + BLOCK_SIZE - 1) / BLOCK_SIZE, 32);
+
+#define LAUNCH_BLOCK_KERNEL(ELEM_PER_THREAD)                          \
+    Softmax_block_impl<ELEM_PER_THREAD, BLOCK_SIZE, T>                \
+        <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>( \
+            reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), \
+            stride, dim_size, other_size)
+
         if (elemPerThread <= 1) {
-            Softmax_block_impl<1, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(1);
         } else if (elemPerThread <= 2) {
-            Softmax_block_impl<2, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(2);
         } else if (elemPerThread <= 4) {
-            Softmax_block_impl<4, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(4);
         } else if (elemPerThread <= 8) {
-            Softmax_block_impl<8, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(8);
         } else if (elemPerThread <= 16) {
-            Softmax_block_impl<16, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(16);
         } else {
-            Softmax_block_impl<32, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(32);
         }
+
+#undef LAUNCH_BLOCK_KERNEL
+    }
+}
+
+template <typename T>
+infiniStatus_t softmax_dispatch(const op::softmax::SoftmaxInfo &info, void *y, const void *x, void *stream) {
+    int dim_size = info.dim_size;
+    int stride = info.stride;
+    int other_size = info.other_size;
+    if (dim_size <= 1024) {
+        dispatchSoftmaxKernel<T>(x, y, stride, dim_size, other_size, stream, true);
+    } else if (dim_size > 1024) {
+        dispatchSoftmaxKernel<T>(x, y, stride, dim_size, other_size, stream, false);
     }
     return INFINI_STATUS_SUCCESS;
 }
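For quick validation of either kernel, a plain CPU reference that walks the same strided layout is handy. A hedged sketch, assuming float data; the helper name and signature are illustrative, not part of this PR:

```cpp
#include <cmath>

// Reference softmax over the strided layout used above: element k of row r
// lives at k * stride + offset, with offset derived exactly as in the kernels.
void softmax_ref(const float *x, float *y, int stride, int dim_size, int other_size) {
    for (int r = 0; r < other_size; ++r) {
        int offset = r % stride + (r - r % stride) * dim_size;
        float m = -INFINITY;
        for (int k = 0; k < dim_size; ++k) {
            m = std::fmax(m, x[k * stride + offset]);
        }
        float s = 0.0f;
        for (int k = 0; k < dim_size; ++k) {
            s += std::exp(x[k * stride + offset] - m);
        }
        for (int k = 0; k < dim_size; ++k) {
            y[k * stride + offset] = std::exp(x[k * stride + offset] - m) / s;
        }
    }
}
```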