
Commit 302ea32

Issue/259 softmax_cuda: abstract the operator dispatch and normalize formatting
1 parent 3d63b2c commit 302ea32

File tree

8 files changed, +147 −107 lines changed


include/infiniop/ops/softmax.h

Lines changed: 1 addition & 0 deletions
@@ -16,4 +16,5 @@ __C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t d
 __C infiniStatus_t infiniopSoftmax(infiniopSoftmaxDescriptor_t desc, void *workspace, size_t workspace_size, void *y, const void *x, void *stream);
 
 __C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc);
+
 #endif

src/infiniop/ops/softmax/cpu/softmax_cpu.cc

Lines changed: 14 additions & 27 deletions
@@ -38,48 +38,35 @@ infiniStatus_t Descriptor::create(
 template <typename T>
 void softmax_cpu(const SoftmaxInfo &info,
                  const void *x, void *y, int axis) {
-    int dimsize = info.dimsize;
+    int dim_size = info.dim_size;
     int stride = info.stride;
-    int othersize = info.otherdim_size;
-    auto to_float = [](const T &val) -> float {
-        if constexpr (std::is_same_v<T, fp16_t>) {
-            return utils::cast<float>(val);
-        } else {
-            return val;
-        }
-    };
-
-    auto from_float = [](float val) -> T {
-        if constexpr (std::is_same_v<T, fp16_t>) {
-            return utils::cast<fp16_t>(val);
-        } else {
-            return val;
-        }
-    };
-
+    int other_size = info.other_size;
     auto input = reinterpret_cast<const T *>(x);
     auto output = reinterpret_cast<T *>(y);
 
     auto compute_softmax = [&](int i) {
-        int tid = i % stride + (i - i % stride) * dimsize;
+        int tid = i % stride + (i - i % stride) * dim_size;
+
         float max_data = -INFINITY;
-        for (int j = 0; j < dimsize; j++) {
+        for (int j = 0; j < dim_size; j++) {
             int index = tid + j * stride;
-            max_data = fmax(max_data, to_float(input[index]));
+            max_data = fmax(max_data, utils::cast<float>(input[index]));
         }
+
         float sum_data = 0.0f;
-        for (int j = 0; j < dimsize; j++) {
+        for (int j = 0; j < dim_size; j++) {
             int index = tid + j * stride;
-            sum_data += std::exp(to_float(input[index]) - max_data);
+            sum_data += std::exp(utils::cast<float>(input[index]) - max_data);
         }
-        for (int j = 0; j < dimsize; j++) {
+
+        for (int j = 0; j < dim_size; j++) {
             int index = tid + j * stride;
-            float result = std::exp(to_float(input[index]) - max_data) / sum_data;
-            output[index] = from_float(result);
+            float result = std::exp(utils::cast<float>(input[index]) - max_data) / sum_data;
+            output[index] = utils::cast<T>(result);
         }
     };
 #pragma omp parallel for
-    for (int i = 0; i < othersize; i++) {
+    for (int i = 0; i < other_size; i++) {
         compute_softmax(i);
     }
 }
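
Note on the indexing above: tid = i % stride + (i - i % stride) * dim_size maps the i-th softmax group (one slice along the reduced axis) to its first linear offset, and the same mapping reappears in the CUDA kernels below. A minimal host-side sanity check, assuming a row-major (2, 3, 4) tensor with softmax over axis 1 (dim_size = 3, stride = 4, other_size = 8); the shape is chosen purely for illustration:

#include <cassert>

int main() {
    const int dim_size = 3, stride = 4, other_size = 8;
    for (int i = 0; i < other_size; i++) {
        int n = i / stride, w = i % stride;                 // the two non-reduced coordinates
        int tid = i % stride + (i - i % stride) * dim_size; // formula from softmax_cpu
        for (int j = 0; j < dim_size; j++) {
            // Direct (n, c, w) row-major indexing must give the same offsets.
            assert(tid + j * stride == n * dim_size * stride + j * stride + w);
        }
    }
    return 0;
}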

src/infiniop/ops/softmax/cuda/softmax_cuda.cu

Lines changed: 1 addition & 1 deletion
@@ -49,4 +49,4 @@ infiniStatus_t Descriptor::calculate(
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
 }
-} // namespace op::softmax::cuda
+} // namespace op::softmax::cuda

src/infiniop/ops/softmax/cuda/softmax_kernel.cuh

Lines changed: 83 additions & 74 deletions
@@ -47,9 +47,9 @@ struct MaxOp<half> {
     }
 };
 
-template <typename T, template <typename> class ReduceOp, int thread_group_width = 32>
+template <typename T, template <typename> class ReduceOp, int THREAD_GROUP_WIDTH = 32>
 __device__ __forceinline__ T warpReduce(T value) {
-    for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
+    for (int mask = THREAD_GROUP_WIDTH / 2; mask > 0; mask /= 2) {
         value = ReduceOp<T>()(value, __shfl_xor_sync(0xffffffff, value, mask));
     }
     return value;
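
The loop above is the standard XOR-shuffle butterfly reduction over a warp: after log2(THREAD_GROUP_WIDTH) steps every lane holds the reduced value. A self-contained CUDA sketch of the same pattern for a 32-lane warp sum (the toy kernel and its names are illustrative, not part of this commit):

#include <cstdio>

__global__ void warp_sum_demo(float *out) {
    float value = static_cast<float>(threadIdx.x); // lane i contributes i
    for (int mask = 16; mask > 0; mask /= 2) {     // 5 butterfly steps for width 32
        value += __shfl_xor_sync(0xffffffff, value, mask);
    }
    if (threadIdx.x == 0) {
        *out = value; // expected: 0 + 1 + ... + 31 = 496
    }
}

int main() {
    float *out = nullptr, host = 0.0f;
    cudaMalloc(&out, sizeof(float));
    warp_sum_demo<<<1, 32>>>(out);
    cudaMemcpy(&host, out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %.0f\n", host);
    cudaFree(out);
    return 0;
}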
@@ -73,23 +73,23 @@ __device__ __forceinline__ T warpReduce(T value) {
     in other words, we still need i and j
     i is (blockIdx.x * blockDim.y + threadIdx.y) / stride
     j is (blockIdx.x * blockDim.y + threadIdx.y) % stride
-    converting i to a linear offset then gives i * stride * dimsize
+    converting i to a linear offset then gives i * stride * dim_size
     and j is simply added on top
 */
-template <int elemPerThread, int BLOCK_DIM_Y, int BLOCK_DIM_X, typename T>
-__global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dimsize, int otherdim_size) {
-    float dataPerThread[elemPerThread];
+template <int ELEM_PER_THREAD, int BLOCK_DIM_Y, int BLOCK_DIM_X, typename T>
+__global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dim_size, int other_size) {
+    float dataPerThread[ELEM_PER_THREAD];
     int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y;
-    int group_offset = global_warp_id % stride + (global_warp_id - global_warp_id % stride) * dimsize;
+    int group_offset = global_warp_id % stride + (global_warp_id - global_warp_id % stride) * dim_size;
     int tid = threadIdx.x;
-    if (global_warp_id >= otherdim_size) {
+    if (global_warp_id >= other_size) {
         return;
     }
     __shared__ float group_max[BLOCK_DIM_X];
     __shared__ float group_sum[BLOCK_DIM_X];
     float thread_max = -INFINITY;
     float thread_sum = 0.0f;
-    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
         dataPerThread[i] = static_cast<float>(x[(tid + i * BLOCK_DIM_X) * stride + group_offset]);
         thread_max = max(thread_max, dataPerThread[i]);
     }
@@ -99,7 +99,7 @@ __global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dimsize, int
         group_max[threadIdx.y] = thread_max;
     }
 
-    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
         dataPerThread[i] = __expf(dataPerThread[i] - group_max[threadIdx.y]);
         thread_sum += dataPerThread[i];
     }
@@ -109,18 +109,18 @@ __global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dimsize, int
         group_sum[threadIdx.y] = thread_sum;
     }
 
-    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
         y[(tid + i * BLOCK_DIM_X) * stride + group_offset] = static_cast<T>(dataPerThread[i] * __fdividef(1.0f, group_sum[threadIdx.y]));
     }
 }
 
-template <int elemPerThread, int BLOCK_DIM, typename T>
+template <int ELEM_PER_THREAD, int BLOCK_DIM, typename T>
 __launch_bounds__(BLOCK_DIM)
-    __global__ void Softmax_block_impl(const T *x, T *y, int stride, int dimsize, int otherdim_size) {
-    // remain = dimsize - BLOCK_DIM * elemPerThread
+    __global__ void Softmax_block_impl(const T *x, T *y, int stride, int dim_size, int other_size) {
+    // remain = dim_size - BLOCK_DIM * ELEM_PER_THREAD
     int tid = threadIdx.x;
-    int block_offset = (blockIdx.x - blockIdx.x % stride) * dimsize + blockIdx.x % stride;
-    int remain = dimsize - (BLOCK_DIM - 1) * elemPerThread; // 🔧 fix: number of elements handled by the last thread
+    int block_offset = (blockIdx.x - blockIdx.x % stride) * dim_size + blockIdx.x % stride;
+    int remain = dim_size - (BLOCK_DIM - 1) * ELEM_PER_THREAD;
 
     MD md_partial;
     md_partial.max = -INFINITY;
@@ -129,16 +129,16 @@ __launch_bounds__(BLOCK_DIM)
     // tid = [0, BLOCK_DIM - 1], so the last thread handles the remainder
     if (tid < BLOCK_DIM - 1) {
 #pragma unroll
-        for (int i = 0; i < elemPerThread; i++) {
-            int index = (tid * elemPerThread + i) * stride + block_offset;
+        for (int i = 0; i < ELEM_PER_THREAD; i++) {
+            int index = (tid * ELEM_PER_THREAD + i) * stride + block_offset;
             input.max = static_cast<float>(x[index]);
             input.sum = 1.0f;
             md_partial = reduce_for_md(md_partial, input);
         }
     } else {
 #pragma unroll
         for (int i = 0; i < remain; i++) {
-            int index = ((BLOCK_DIM - 1) * elemPerThread + i) * stride + block_offset;
+            int index = ((BLOCK_DIM - 1) * ELEM_PER_THREAD + i) * stride + block_offset;
             input.max = static_cast<float>(x[index]);
             input.sum = 1.0f;
             md_partial = reduce_for_md(md_partial, input);
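
The MD value accumulated here is a (running max, running sum of exp) pair; reduce_for_md is defined elsewhere in this header and is not touched by the commit. As an assumption for readers of the diff, the merge rule in the usual online-softmax formulation would look like the sketch below (MD_sketch and reduce_for_md_sketch are hypothetical names, not the repository's definitions):

struct MD_sketch {
    float max; // running maximum seen so far
    float sum; // sum of exp(x - max), expressed relative to that maximum
};

__device__ __forceinline__ MD_sketch reduce_for_md_sketch(MD_sketch a, MD_sketch b) {
    MD_sketch out;
    out.max = fmaxf(a.max, b.max);
    // Rescale both partial sums to the combined maximum before adding.
    out.sum = a.sum * __expf(a.max - out.max) + b.sum * __expf(b.max - out.max);
    return out;
}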
@@ -153,85 +153,94 @@ __launch_bounds__(BLOCK_DIM)
     }
     __syncthreads();
     if (tid < BLOCK_DIM - 1) {
-        for (int i = 0; i < elemPerThread; i++) {
-            int index = (tid * elemPerThread + i) * stride + block_offset;
+        for (int i = 0; i < ELEM_PER_THREAD; i++) {
+            int index = (tid * ELEM_PER_THREAD + i) * stride + block_offset;
             y[index] = static_cast<T>(__expf(static_cast<float>(x[index]) - md_total.max) * __fdividef(1.0f, md_total.sum));
         }
     } else {
         for (int i = 0; i < remain; i++) {
-            int index = ((BLOCK_DIM - 1) * elemPerThread + i) * stride + block_offset;
+            int index = ((BLOCK_DIM - 1) * ELEM_PER_THREAD + i) * stride + block_offset;
             y[index] = static_cast<T>(__expf(static_cast<float>(x[index]) - md_total.max) * __fdividef(1.0f, md_total.sum));
         }
     }
 }
 
 template <typename T>
-infiniStatus_t softmax_dispatch(const op::softmax::SoftmaxInfo &info, void *y, const void *x, void *stream) {
-    int dimsize = info.dimsize;
-    int stride = info.stride;
-    int otherdim_size = info.otherdim_size;
-    if (dimsize <= 1024) {
-        dim3 block(32, 32); // BLOCK_DIM_X=32, BLOCK_DIM_Y=4
-        int num_blocks = (otherdim_size + block.y - 1) / block.y;
-        dim3 grid(num_blocks, 1, 1);
-        int elemPerThread = (dimsize + 31) / 32; // number of elements each thread has to process
-        elemPerThread = min(elemPerThread, 32); // cap the maximum
+void dispatchSoftmaxKernel(
+    const void *x, void *y,
+    int stride, int dim_size, int other_size,
+    void *stream, bool use_warp_impl) {
+
+    int elemPerThread;
+    dim3 grid, block;
+
+    if (use_warp_impl) {
+        block = dim3(32, 32);
+        grid = dim3((other_size + block.y - 1) / block.y, 1, 1);
+        elemPerThread = min((dim_size + 31) / 32, 32);
+
+#define LAUNCH_WARP_KERNEL(ELEM_PER_THREAD)                           \
+    Softmax_warp_impl<ELEM_PER_THREAD, 32, 32, T>                     \
+        <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>( \
+            reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), \
+            stride, dim_size, other_size)
+
         if (elemPerThread <= 1) {
-            Softmax_warp_impl<1, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(1);
         } else if (elemPerThread <= 2) {
-            Softmax_warp_impl<2, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(2);
         } else if (elemPerThread <= 4) {
-            Softmax_warp_impl<4, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(4);
         } else if (elemPerThread <= 8) {
-            Softmax_warp_impl<8, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(8);
         } else if (elemPerThread <= 16) {
-            Softmax_warp_impl<16, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(16);
         } else {
-            Softmax_warp_impl<32, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(32);
         }
-    } else if (dimsize > 1024) {
-        int block_size = 1024;
-        int elemPerThread = (dimsize + block_size - 1) / block_size; // number of elements each thread has to process
-        elemPerThread = min(elemPerThread, 32); // cap at 32
-        dim3 block(block_size);
-        dim3 grid(otherdim_size);
+
+#undef LAUNCH_WARP_KERNEL
+
+    } else {
+        // Block implementation for dim_size > 1024
+        constexpr int BLOCK_SIZE = 1024;
+        block = dim3(BLOCK_SIZE);
+        grid = dim3(other_size);
+        elemPerThread = min((dim_size + BLOCK_SIZE - 1) / BLOCK_SIZE, 32);
+
+#define LAUNCH_BLOCK_KERNEL(ELEM_PER_THREAD)                          \
+    Softmax_block_impl<ELEM_PER_THREAD, BLOCK_SIZE, T>                \
+        <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>( \
+            reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), \
+            stride, dim_size, other_size)
+
         if (elemPerThread <= 1) {
-            Softmax_block_impl<1, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(1);
         } else if (elemPerThread <= 2) {
-            Softmax_block_impl<2, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(2);
         } else if (elemPerThread <= 4) {
-            Softmax_block_impl<4, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(4);
         } else if (elemPerThread <= 8) {
-            Softmax_block_impl<8, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(8);
        } else if (elemPerThread <= 16) {
-            Softmax_block_impl<16, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(16);
        } else {
-            Softmax_block_impl<32, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(32);
         }
+
+#undef LAUNCH_BLOCK_KERNEL
+    }
+}
+
+template <typename T>
+infiniStatus_t softmax_dispatch(const op::softmax::SoftmaxInfo &info, void *y, const void *x, void *stream) {
+    int dim_size = info.dim_size;
+    int stride = info.stride;
+    int other_size = info.other_size;
+    if (dim_size <= 1024) {
+        dispatchSoftmaxKernel<T>(x, y, stride, dim_size, other_size, stream, true);
+    } else if (dim_size > 1024) {
+        dispatchSoftmaxKernel<T>(x, y, stride, dim_size, other_size, stream, false);
     }
     return INFINI_STATUS_SUCCESS;
 }
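
To make the dispatch arithmetic concrete: the warp path assigns one 32-lane warp per softmax group, so elemPerThread = ceil(dim_size / 32) capped at 32; the block path assigns one 1024-thread block per group. A small host-side check with illustrative sizes (not code from this commit):

#include <algorithm>
#include <cassert>

int main() {
    // Warp path: dim_size = 300 -> elemPerThread = 10 -> "<= 16" branch,
    // i.e. Softmax_warp_impl<16, 32, 32, T>.
    int elemPerThread = std::min((300 + 31) / 32, 32);
    assert(elemPerThread == 10);

    // Block path: dim_size = 5000 -> elemPerThread = 5 -> "<= 8" branch,
    // i.e. Softmax_block_impl<8, 1024, T>.
    elemPerThread = std::min((5000 + 1024 - 1) / 1024, 32);
    assert(elemPerThread == 5);
    return 0;
}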

src/infiniop/ops/softmax/info.h

Lines changed: 5 additions & 5 deletions
@@ -10,10 +10,10 @@ namespace op::softmax {
 class SoftmaxInfo {
 public:
     int axis;
-    int otherdim_size;
+    int other_size;
     int stride;
     int size;
-    int dimsize;
+    int dim_size;
 
     static utils::Result<SoftmaxInfo> create(
         infiniopTensorDescriptor_t y_desc,
@@ -27,9 +27,9 @@ class SoftmaxInfo {
         SoftmaxInfo info;
         info.axis = axis;
         info.size = 1;
-        info.otherdim_size = 1;
+        info.other_size = 1;
         info.stride = 1;
-        info.dimsize = static_cast<int>(x_desc->dim(axis));
+        info.dim_size = static_cast<int>(x_desc->dim(axis));
         int ndim = static_cast<int>(y_desc->ndim());
         for (int i = ndim - 1; i >= 0; i--) {
             info.size *= static_cast<int>(y_desc->dim(i));
@@ -38,7 +38,7 @@ class SoftmaxInfo {
         for (int i = axis + 1; i < ndim; i++) {
             info.stride *= static_cast<int>(x_desc->dim(i));
         }
-        info.otherdim_size = info.size / info.dimsize;
+        info.other_size = info.size / info.dim_size;
         return utils::Result<SoftmaxInfo>(info);
     }
 };
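
Concretely, size is the total element count, stride the product of the dimensions after the softmax axis, dim_size the length of the reduced axis, and other_size the number of independent softmax groups. A host-side recomputation for an illustrative shape (4, 2048, 6) with axis = 1 (the snippet mirrors the loops in create but is not library code):

#include <cassert>
#include <vector>

int main() {
    std::vector<int> shape = {4, 2048, 6};
    const int axis = 1;
    int size = 1, stride = 1;
    for (int d : shape) size *= d;
    for (int i = axis + 1; i < static_cast<int>(shape.size()); i++) stride *= shape[i];
    int dim_size = shape[axis];
    int other_size = size / dim_size;

    assert(size == 49152 && stride == 6 && dim_size == 2048 && other_size == 24);
    // dim_size > 1024 here, so softmax_dispatch would take the Softmax_block_impl path.
    return 0;
}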

src/infiniop/ops/softmax/softmax.h

Lines changed: 1 addition & 0 deletions
@@ -46,4 +46,5 @@
         void *stream) const; \
     }; \
 }
+
 #endif // __SOFTMAX_H__

0 commit comments
