Commit f53154d

issue/383: Add logsoftmax ops

1 parent 37411f6

13 files changed: +1065 additions, 0 deletions

include/infiniop.h

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 #include "infiniop/ops/attention.h"
 #include "infiniop/ops/causal_softmax.h"
 #include "infiniop/ops/clip.h"
+#include "infiniop/ops/logsoftmax.h"
 #include "infiniop/ops/conv.h"
 #include "infiniop/ops/dequantize_awq.h"
 #include "infiniop/ops/gemm.h"

include/infiniop/ops/logsoftmax.h

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
#ifndef __INFINIOP_LOGSOFTMAX_API_H__
#define __INFINIOP_LOGSOFTMAX_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopLogSoftmaxDescriptor_t;

__C __export infiniStatus_t infiniopCreateLogSoftmaxDescriptor(infiniopHandle_t handle,
                                                               infiniopLogSoftmaxDescriptor_t *desc_ptr,
                                                               infiniopTensorDescriptor_t y_desc,
                                                               infiniopTensorDescriptor_t x_desc);

__C __export infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopLogSoftmax(infiniopLogSoftmaxDescriptor_t desc,
                                               void *workspace,
                                               size_t workspace_size,
                                               void *y,
                                               const void *x,
                                               void *stream);

__C __export infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescriptor_t desc);

#endif
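For orientation, the snippet below is a minimal sketch of how this API is meant to be driven through the usual infiniop descriptor lifecycle (create, query workspace, run, destroy). The CHECK macro, the use of malloc for workspace, and the <infiniop.h> include path are assumptions for illustration; the handle, tensor descriptors, buffers, and stream are assumed to be set up elsewhere.

/* Hedged usage sketch; CHECK and the workspace allocation are placeholders,
 * not library functionality added by this commit. */
#include <infiniop.h>
#include <stdlib.h>

#define CHECK(call)                                     \
    do {                                                \
        infiniStatus_t _s = (call);                     \
        if (_s != INFINI_STATUS_SUCCESS) return _s;     \
    } while (0)

static infiniStatus_t run_logsoftmax(infiniopHandle_t handle,
                                     infiniopTensorDescriptor_t y_desc,
                                     infiniopTensorDescriptor_t x_desc,
                                     void *y, const void *x, void *stream) {
    infiniopLogSoftmaxDescriptor_t desc;
    CHECK(infiniopCreateLogSoftmaxDescriptor(handle, &desc, y_desc, x_desc));

    size_t workspace_size = 0;
    CHECK(infiniopGetLogSoftmaxWorkspaceSize(desc, &workspace_size));
    void *workspace = malloc(workspace_size); /* would be device memory for a GPU backend */

    /* y and x must match the tensor descriptors used at creation time. */
    CHECK(infiniopLogSoftmax(desc, workspace, workspace_size, y, x, stream));

    free(workspace);
    CHECK(infiniopDestroyLogSoftmaxDescriptor(desc));
    return INFINI_STATUS_SUCCESS;
}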

scripts/python_test.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ def run_tests(args):
         "causal_softmax.py",
         "clip.py",
         "gemm.py",
+        "logsoftmax.py",
         "mul.py",
         "random_sample.py",
         "rearrange.py",
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
#include "logsoftmax_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
#include <algorithm>
#include <cmath>

namespace op::logsoftmax::cpu {

Descriptor::~Descriptor() {}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto result = LogSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(result);
    *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

template <typename Tx, typename Ty>
infiniStatus_t logsoftmax(const LogSoftmaxInfo *info, Ty *y, const Tx *x) {
#pragma omp parallel for
    for (ptrdiff_t batch = 0; batch < ptrdiff_t(info->batch_size); batch++) {
        ptrdiff_t y_offset, x_offset;

        if (info->ndim == 3) {
            // For 3D tensors, convert the linear batch index back to 2D indices
            ptrdiff_t batch_idx = batch / info->seq_len;
            ptrdiff_t seq_idx = batch % info->seq_len;
            y_offset = batch_idx * info->y_stride_0 + seq_idx * info->y_stride_1;
            x_offset = batch_idx * info->x_stride_0 + seq_idx * info->x_stride_1;
        } else {
            // For 2D tensors, use the flattened strides
            y_offset = batch * info->y_stride_b;
            x_offset = batch * info->x_stride_b;
        }

        Ty *y_ = y + y_offset;
        const Tx *x_ = x + x_offset;

        // Find the max value for numerical stability
        // (the reduce helper handles fp16_t, bf16_t, and float inputs alike)
        float max_val = op::common_cpu::reduce_op::max(x_, info->probs_size, info->x_stride_p);

        // Compute exp(x - max) and sum
        float sum = 0.0f;
        for (size_t i = 0; i < info->probs_size; i++) {
            float x_val;
            if constexpr (std::is_same<Tx, fp16_t>::value || std::is_same<Tx, bf16_t>::value) {
                x_val = utils::cast<float>(x_[i * info->x_stride_p]);
            } else {
                x_val = x_[i * info->x_stride_p];
            }
            sum += std::exp(x_val - max_val);
        }

        // Compute log(sum)
        float log_sum = std::log(sum);

        // Compute log_softmax = x - max - log(sum)
        for (size_t i = 0; i < info->probs_size; i++) {
            float x_val;
            if constexpr (std::is_same<Tx, fp16_t>::value || std::is_same<Tx, bf16_t>::value) {
                x_val = utils::cast<float>(x_[i * info->x_stride_p]);
            } else {
                x_val = x_[i * info->x_stride_p];
            }

            float result = x_val - max_val - log_sum;

            if constexpr (std::is_same<Ty, fp16_t>::value || std::is_same<Ty, bf16_t>::value) {
                y_[i * info->y_stride_p] = utils::cast<Ty>(result);
            } else {
                y_[i * info->y_stride_p] = result;
            }
        }
    }

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y,
    const void *x,
    void *stream) const {

    // Handle different input/output dtype combinations
    if (_info.x_dtype == INFINI_DTYPE_F16) {
        if (_info.y_dtype == INFINI_DTYPE_F16) {
            return logsoftmax<fp16_t, fp16_t>(&_info, (fp16_t *)y, (const fp16_t *)x);
        } else if (_info.y_dtype == INFINI_DTYPE_BF16) {
            return logsoftmax<fp16_t, bf16_t>(&_info, (bf16_t *)y, (const fp16_t *)x);
        } else if (_info.y_dtype == INFINI_DTYPE_F32) {
            return logsoftmax<fp16_t, float>(&_info, (float *)y, (const fp16_t *)x);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (_info.x_dtype == INFINI_DTYPE_BF16) {
        if (_info.y_dtype == INFINI_DTYPE_F16) {
            return logsoftmax<bf16_t, fp16_t>(&_info, (fp16_t *)y, (const bf16_t *)x);
        } else if (_info.y_dtype == INFINI_DTYPE_BF16) {
            return logsoftmax<bf16_t, bf16_t>(&_info, (bf16_t *)y, (const bf16_t *)x);
        } else if (_info.y_dtype == INFINI_DTYPE_F32) {
            return logsoftmax<bf16_t, float>(&_info, (float *)y, (const bf16_t *)x);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (_info.x_dtype == INFINI_DTYPE_F32) {
        if (_info.y_dtype == INFINI_DTYPE_F16) {
            return logsoftmax<float, fp16_t>(&_info, (fp16_t *)y, (const float *)x);
        } else if (_info.y_dtype == INFINI_DTYPE_BF16) {
            return logsoftmax<float, bf16_t>(&_info, (bf16_t *)y, (const float *)x);
        } else if (_info.y_dtype == INFINI_DTYPE_F32) {
            return logsoftmax<float, float>(&_info, (float *)y, (const float *)x);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

} // namespace op::logsoftmax::cpu
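As a quick reference for the formulation used above (subtract the per-row maximum, then subtract the log of the exponential sum, so y[i] = x[i] - max(x) - log(sum_j exp(x[j] - max(x)))), here is a small standalone C++ snippet. It is illustrative only and uses no infiniop types; log_softmax_ref is a name introduced here, not part of the library.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable log-softmax of one row.
std::vector<float> log_softmax_ref(const std::vector<float> &x) {
    float max_val = x[0];
    for (float v : x) max_val = std::max(max_val, v);

    float sum = 0.0f;
    for (float v : x) sum += std::exp(v - max_val);
    float log_sum = std::log(sum);

    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] - max_val - log_sum;
    return y;
}

int main() {
    // exp(y) sums to 1, so y is a valid vector of log-probabilities.
    for (float v : log_softmax_ref({1.0f, 2.0f, 3.0f})) std::printf("%f ", v);
    std::printf("\n"); // approximately: -2.407606 -1.407606 -0.407606
    return 0;
}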
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
#ifndef __LOGSOFTMAX_CPU_H__
#define __LOGSOFTMAX_CPU_H__
#include "../logsoftmax.h"

DESCRIPTOR(cpu)

#endif
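The DESCRIPTOR(cpu) macro is defined in ../logsoftmax.h, which is not among the hunks shown here. Judging only from the members used in the CPU implementation above (Descriptor::create, Descriptor::calculate, the destructor, the _info field, and the constructor call with device and device_id), it plausibly declares a per-backend class along the following lines. Treat this as a sketch of the pattern, not the actual macro body.

// Hypothetical expansion of DESCRIPTOR(cpu); the real definition lives in ../logsoftmax.h.
namespace op::logsoftmax::cpu {
class Descriptor final : public InfiniopDescriptor {
    LogSoftmaxInfo _info;
    size_t _workspace_size;

    Descriptor(void *opaque, LogSoftmaxInfo info, size_t workspace_size,
               infiniDevice_t device, int device_id); // mirrors the call in Descriptor::create

public:
    ~Descriptor();

    static infiniStatus_t create(infiniopHandle_t handle,
                                 Descriptor **desc_ptr,
                                 infiniopTensorDescriptor_t y_desc,
                                 infiniopTensorDescriptor_t x_desc);

    size_t workspaceSize() const { return _workspace_size; }

    infiniStatus_t calculate(void *workspace, size_t workspace_size,
                             void *y, const void *x, void *stream) const;
};
} // namespace op::logsoftmax::cpu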
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
#ifndef __LOGSOFTMAX_KERNEL_CUH__
#define __LOGSOFTMAX_KERNEL_CUH__

#include <cub/block/block_reduce.cuh>
#include <type_traits>

template <unsigned int BLOCK_SIZE, typename Tdata_out, typename Tdata_in, typename Tcompute>
__device__ void logSoftmaxKernel(
    Tdata_out *y, const Tdata_in *x,
    size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len,
    ptrdiff_t y_stride_b, ptrdiff_t y_stride_p,
    ptrdiff_t x_stride_b, ptrdiff_t x_stride_p,
    ptrdiff_t y_stride_0, ptrdiff_t y_stride_1,
    ptrdiff_t x_stride_0, ptrdiff_t x_stride_1) {

    typedef cub::BlockReduce<Tcompute, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ Tcompute shared_max_val;
    __shared__ Tcompute shared_sum_exp;

    int batch_idx = blockIdx.x;
    int tid = threadIdx.x;

    if (batch_idx >= batch_size) {
        return;
    }

    // Calculate the correct memory offsets
    ptrdiff_t y_offset, x_offset;
    if (ndim == 3) {
        // For 3D tensors, convert the linear batch index back to 2D indices
        ptrdiff_t batch_dim_idx = batch_idx / seq_len;
        ptrdiff_t seq_dim_idx = batch_idx % seq_len;
        y_offset = batch_dim_idx * y_stride_0 + seq_dim_idx * y_stride_1;
        x_offset = batch_dim_idx * x_stride_0 + seq_dim_idx * x_stride_1;
    } else {
        // For 2D tensors, use the flattened strides
        y_offset = batch_idx * y_stride_b;
        x_offset = batch_idx * x_stride_b;
    }

    const Tdata_in *x_batch = x + x_offset;
    Tdata_out *y_batch = y + y_offset;

    // Find the maximum value for numerical stability
    Tcompute max_val = static_cast<Tcompute>(-INFINITY);
    for (int i = tid; i < probs_size; i += BLOCK_SIZE) {
        Tcompute val = static_cast<Tcompute>(x_batch[i * x_stride_p]);
        if constexpr (std::is_same_v<Tcompute, float>) {
            max_val = fmaxf(max_val, val);
        } else {
            max_val = fmax(max_val, val);
        }
    }
    max_val = BlockReduce(temp_storage).Reduce(max_val, cub::Max());
    if (tid == 0) {
        shared_max_val = max_val;
    }
    __syncthreads();

    // Compute sum of exp(x - max)
    Tcompute sum_exp = static_cast<Tcompute>(0.0);
    for (int i = tid; i < probs_size; i += BLOCK_SIZE) {
        Tcompute val = static_cast<Tcompute>(x_batch[i * x_stride_p]);
        if constexpr (std::is_same_v<Tcompute, float>) {
            sum_exp += expf(val - shared_max_val);
        } else {
            sum_exp += exp(val - shared_max_val);
        }
    }
    sum_exp = BlockReduce(temp_storage).Sum(sum_exp);
    if (tid == 0) {
        shared_sum_exp = sum_exp;
    }
    __syncthreads();

    // Compute log_softmax = x - max - log(sum_exp)
    Tcompute log_sum_exp;
    if constexpr (std::is_same_v<Tcompute, float>) {
        log_sum_exp = logf(shared_sum_exp);
    } else {
        log_sum_exp = log(shared_sum_exp);
    }
    for (int i = tid; i < probs_size; i += BLOCK_SIZE) {
        Tcompute val = static_cast<Tcompute>(x_batch[i * x_stride_p]);
        Tcompute result = val - shared_max_val - log_sum_exp;
        y_batch[i * y_stride_p] = static_cast<Tdata_out>(result);
    }
}

template <unsigned int BLOCK_SIZE, typename Tdata_out, typename Tdata_in, typename Tcompute>
__global__ void logSoftmax(
    Tdata_out *y, const Tdata_in *x,
    size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len,
    ptrdiff_t y_stride_b, ptrdiff_t y_stride_p,
    ptrdiff_t x_stride_b, ptrdiff_t x_stride_p,
    ptrdiff_t y_stride_0, ptrdiff_t y_stride_1,
    ptrdiff_t x_stride_0, ptrdiff_t x_stride_1) {
    logSoftmaxKernel<BLOCK_SIZE, Tdata_out, Tdata_in, Tcompute>(
        y, x, batch_size, probs_size, ndim, seq_len,
        y_stride_b, y_stride_p, x_stride_b, x_stride_p,
        y_stride_0, y_stride_1, x_stride_0, x_stride_1);
}

#endif // __LOGSOFTMAX_KERNEL_CUH__
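The host-side launcher for this kernel lives in the CUDA backend's descriptor implementation, which is not among the hunks shown here. A plausible launch, assuming float input, output, and accumulation and a LogSoftmaxInfo-like set of sizes and strides, would look roughly like this; the function name, the header file name in the include, and the block size are illustrative assumptions.

// Illustrative launcher only; the real one in the CUDA backend is not shown in this commit view.
#include <cuda_runtime.h>
#include "logsoftmax_kernel.cuh" // assumed name for the header above

// One block per softmax row; BLOCK_SIZE threads cooperate via cub::BlockReduce.
void launchLogSoftmaxF32(float *y, const float *x,
                         size_t batch_size, size_t probs_size, size_t ndim, size_t seq_len,
                         ptrdiff_t y_stride_b, ptrdiff_t y_stride_p,
                         ptrdiff_t x_stride_b, ptrdiff_t x_stride_p,
                         ptrdiff_t y_stride_0, ptrdiff_t y_stride_1,
                         ptrdiff_t x_stride_0, ptrdiff_t x_stride_1,
                         cudaStream_t stream) {
    constexpr unsigned int BLOCK_SIZE = 256;
    dim3 grid(static_cast<unsigned int>(batch_size));
    logSoftmax<BLOCK_SIZE, float, float, float><<<grid, BLOCK_SIZE, 0, stream>>>(
        y, x, batch_size, probs_size, ndim, seq_len,
        y_stride_b, y_stride_p, x_stride_b, x_stride_p,
        y_stride_0, y_stride_1, x_stride_0, x_stride_1);
}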
