Skip to content

Commit c828536

Browse files
committed
issue/207 operator: sigmoid op on cpu and cuda
1 parent 65bed07 commit c828536

File tree

13 files changed

+782
-0
lines changed

13 files changed

+782
-0
lines changed

include/infiniop.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "infiniop/ops/relu.h"
1515
#include "infiniop/ops/rms_norm.h"
1616
#include "infiniop/ops/rope.h"
17+
#include "infiniop/ops/sigmoid.h"
1718
#include "infiniop/ops/sub.h"
1819
#include "infiniop/ops/swiglu.h"
1920
#include "infiniop/tensor_descriptor.h"

include/infiniop/ops/sigmoid.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef __INFINIOP_SIGMOID_API_H__
2+
#define __INFINIOP_SIGMOID_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopSigmoidDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateSigmoidDescriptor(infiniopHandle_t handle,
9+
infiniopSigmoidDescriptor_t *desc_ptr,
10+
infiniopTensorDescriptor_t y,
11+
infiniopTensorDescriptor_t x);
12+
13+
__C __export infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t desc, size_t *size);
14+
15+
__C __export infiniStatus_t infiniopSigmoid(infiniopSigmoidDescriptor_t desc,
16+
void *workspace,
17+
size_t workspace_size,
18+
void *y,
19+
const void *x,
20+
void *stream);
21+
22+
__C __export infiniStatus_t infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc);
23+
24+
#endif

scripts/python_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def run_tests(args):
2222
"rearrange.py",
2323
"rms_norm.py",
2424
"rope.py",
25+
"sigmoid.py",
2526
"sub.py",
2627
"swiglu.py",
2728
]:

src/infiniop-test/include/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ DECLARE_INFINIOP_TEST(add)
1616
DECLARE_INFINIOP_TEST(causal_softmax)
1717
DECLARE_INFINIOP_TEST(rearrange)
1818
DECLARE_INFINIOP_TEST(sub)
19+
DECLARE_INFINIOP_TEST(sigmoid)
1920

2021
#define REGISTER_INFINIOP_TEST(name) \
2122
{ \
@@ -43,6 +44,7 @@ DECLARE_INFINIOP_TEST(sub)
4344
REGISTER_INFINIOP_TEST(causal_softmax) \
4445
REGISTER_INFINIOP_TEST(rearrange) \
4546
REGISTER_INFINIOP_TEST(sub) \
47+
REGISTER_INFINIOP_TEST(sigmoid) \
4648
}
4749

4850
namespace infiniop_test {
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#include "ops.hpp"
2+
#include "utils.hpp"
3+
#include <infinirt.h>
4+
#include <iomanip>
5+
#include <iostream>
6+
7+
namespace infiniop_test::sigmoid {
8+
struct Test::Attributes {
9+
std::shared_ptr<Tensor> x;
10+
std::shared_ptr<Tensor> y;
11+
std::shared_ptr<Tensor> ans;
12+
};
13+
14+
std::shared_ptr<Test> Test::build(
15+
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
16+
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
17+
double rtol, double atol) {
18+
auto test = std::shared_ptr<Test>(new Test(rtol, atol));
19+
test->_attributes = new Attributes();
20+
if (tensors.find("x") == tensors.end()
21+
|| tensors.find("y") == tensors.end()
22+
|| tensors.find("ans") == tensors.end()) {
23+
throw std::runtime_error("Invalid Test");
24+
}
25+
26+
test->_attributes->x = tensors["x"];
27+
test->_attributes->y = tensors["y"];
28+
test->_attributes->ans = tensors["ans"];
29+
30+
return test;
31+
}
32+
33+
std::shared_ptr<infiniop_test::Result> Test::run(
34+
infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
35+
infiniopSigmoidDescriptor_t op_desc;
36+
auto x = _attributes->x->to(device, device_id);
37+
auto y = _attributes->y->to(device, device_id);
38+
CHECK_OR(infiniopCreateSigmoidDescriptor(handle, &op_desc,
39+
y->desc(),
40+
x->desc()),
41+
return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));
42+
size_t workspace_size;
43+
CHECK_OR(infiniopGetSigmoidWorkspaceSize(op_desc, &workspace_size),
44+
return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
45+
void *workspace;
46+
CHECK_OR(infinirtMalloc(&workspace, workspace_size),
47+
return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
48+
CHECK_OR(infiniopSigmoid(op_desc, workspace, workspace_size,
49+
y->data(),
50+
x->data(),
51+
nullptr),
52+
return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));
53+
54+
try {
55+
allClose(y, _attributes->ans, _rtol, _atol);
56+
} catch (const std::exception &e) {
57+
return TEST_FAILED(RESULT_INCORRECT, e.what());
58+
}
59+
60+
double elapsed_time = 0.;
61+
62+
elapsed_time = benchmark(
63+
[=]() {
64+
infiniopSigmoid(
65+
op_desc, workspace, workspace_size,
66+
y->data(),
67+
x->data(),
68+
nullptr);
69+
},
70+
warm_ups, iterations);
71+
72+
infiniopDestroySigmoidDescriptor(op_desc);
73+
infinirtFree(workspace);
74+
return TEST_PASSED(elapsed_time);
75+
}
76+
77+
std::vector<std::string> Test::attribute_names() {
78+
return {};
79+
}
80+
81+
std::vector<std::string> Test::tensor_names() {
82+
return {"x", "y", "ans"};
83+
}
84+
85+
std::vector<std::string> Test::output_names() {
86+
return {"y"};
87+
}
88+
89+
std::string Test::toString() const {
90+
std::ostringstream oss;
91+
oss << op_name() << std::endl;
92+
oss << "- x: " << _attributes->x->info() << std::endl;
93+
oss << "- y: " << _attributes->y->info() << std::endl;
94+
oss << std::scientific << std::setprecision(2);
95+
oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
96+
return oss.str();
97+
}
98+
99+
Test::~Test() {
100+
delete _attributes;
101+
}
102+
103+
} // namespace infiniop_test::sigmoid
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "sigmoid_cpu.h"
2+
3+
namespace op::sigmoid::cpu {
4+
5+
Descriptor::~Descriptor() = default;
6+
7+
infiniStatus_t Descriptor::create(
8+
infiniopHandle_t handle_,
9+
Descriptor **desc_ptr,
10+
infiniopTensorDescriptor_t out_desc,
11+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
12+
13+
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
14+
auto dtype = out_desc->dtype();
15+
16+
const auto &x_desc = input_desc_vec.at(0);
17+
const auto &y_shape = out_desc->shape();
18+
const auto &x_shape = x_desc->shape();
19+
20+
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
21+
CHECK_SAME_SHAPE(y_shape, x_shape);
22+
23+
// create CPU elementwise descriptor
24+
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
25+
26+
return INFINI_STATUS_SUCCESS;
27+
}
28+
29+
infiniStatus_t Descriptor::calculate(
30+
void *workspace,
31+
size_t workspace_size,
32+
void *output,
33+
std::vector<const void *> inputs,
34+
void *stream) const {
35+
36+
switch (_dtype) {
37+
case INFINI_DTYPE_F16:
38+
return _device_info->calculate<SigmoidOp, fp16_t>(_info, output, inputs, stream);
39+
case INFINI_DTYPE_F32:
40+
return _device_info->calculate<SigmoidOp, float>(_info, output, inputs, stream);
41+
case INFINI_DTYPE_F64:
42+
return _device_info->calculate<SigmoidOp, double>(_info, output, inputs, stream);
43+
default:
44+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
45+
}
46+
47+
return INFINI_STATUS_SUCCESS;
48+
}
49+
} // namespace op::sigmoid::cpu
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef __SIGMOID_CPU_H__
2+
#define __SIGMOID_CPU_H__
3+
4+
#include "../../../elementwise/cpu/elementwise_cpu.h"
5+
6+
ELEMENTWISE_DESCRIPTOR(sigmoid, cpu)
7+
8+
namespace op::sigmoid::cpu {
9+
typedef struct SigmoidOp {
10+
public:
11+
static constexpr size_t num_inputs = 1;
12+
template <typename T>
13+
T operator()(const T &x) const {
14+
return T(1) / (T(1) + std::exp(-x));
15+
}
16+
} SigmoidOp;
17+
} // namespace op::sigmoid::cpu
18+
19+
#endif // __SIGMOID_CPU_H__
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "sigmoid_cuda.cuh"
2+
#include "sigmoid_cuda_internal.cuh"
3+
4+
namespace op::sigmoid::cuda {
5+
6+
Descriptor::~Descriptor() = default;
7+
8+
infiniStatus_t Descriptor::create(
9+
infiniopHandle_t handle_,
10+
Descriptor **desc_ptr,
11+
infiniopTensorDescriptor_t out_desc,
12+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
13+
14+
auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
15+
auto dtype = out_desc->dtype();
16+
17+
const auto &x_desc = input_desc_vec.at(0);
18+
const auto &y_shape = out_desc->shape();
19+
const auto &x_shape = x_desc->shape();
20+
21+
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
22+
23+
CHECK_SAME_SHAPE(y_shape, x_shape);
24+
25+
// create CUDA elementwise descriptor
26+
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
27+
28+
return INFINI_STATUS_SUCCESS;
29+
}
30+
31+
infiniStatus_t Descriptor::calculate(
32+
void *workspace,
33+
size_t workspace_size,
34+
void *output,
35+
std::vector<const void *> inputs,
36+
void *stream) const {
37+
38+
if (workspace_size < _workspace_size) {
39+
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
40+
}
41+
42+
switch (_dtype) {
43+
case INFINI_DTYPE_F16:
44+
return _device_info->calculate<256, SigmoidOp, half>(_info, workspace, output, inputs, stream);
45+
case INFINI_DTYPE_F32:
46+
return _device_info->calculate<256, SigmoidOp, float>(_info, workspace, output, inputs, stream);
47+
case INFINI_DTYPE_F64:
48+
return _device_info->calculate<256, SigmoidOp, double>(_info, workspace, output, inputs, stream);
49+
default:
50+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
51+
}
52+
53+
return INFINI_STATUS_SUCCESS;
54+
}
55+
} // namespace op::sigmoid::cuda
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __SIGMOID_CUDA_API_H__
2+
#define __SIGMOID_CUDA_API_H__
3+
4+
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
5+
6+
ELEMENTWISE_DESCRIPTOR(sigmoid, cuda)
7+
8+
#endif // __SIGMOID_CUDA_API_H__
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#ifndef __SIDMOID_CUDA_H__
2+
#define __SIDMOID_CUDA_H__
3+
4+
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
5+
#include <cuda_fp16.h>
6+
7+
namespace op::sigmoid::cuda {
8+
typedef struct SigmoidOp {
9+
public:
10+
static constexpr size_t num_inputs = 1;
11+
template <typename T>
12+
__device__ __forceinline__ T operator()(const T &x) const {
13+
// sigmoid(x) = 1 / (1 + exp(-x))
14+
if constexpr (std::is_same_v<T, half2>) {
15+
half2 denominator = __hadd2(make_half2(1, 1), h2exp(__hneg2(x)));
16+
return h2rcp(denominator);
17+
} else if constexpr (std::is_same_v<T, half>) {
18+
half denominator = __hadd(__float2half(1.0f), hexp(__hneg(x)));
19+
return hrcp(denominator);
20+
} else if constexpr (std::is_same_v<T, float>) {
21+
float denominator = __fadd_rn(1.0f, __expf(-x));
22+
return __frcp_rn(denominator);
23+
} else { // double
24+
return 1.0 / (1.0 + exp(-x));
25+
}
26+
}
27+
} SigmoidOp;
28+
} // namespace op::sigmoid::cuda
29+
30+
#endif // __SIDMOID_CUDA_H__

0 commit comments

Comments
 (0)