Skip to content

Commit 7c8c9a3

Browse files
authored
[Feat] add batch interface for device ops and implement ScatterGather with CUDA (#305)
add batch interface for device ops and implement ScatterGather with CUDA
1 parent b53b23a commit 7c8c9a3

File tree

6 files changed

+164
-28
lines changed

6 files changed

+164
-28
lines changed

ucm/store/device/ascend/ascend_device.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,24 @@ class AscendDevice : public IBufferedDevice {
120120
this->stream_);
121121
}
122122
Status Synchronized() override { return ASCEND_API(aclrtSynchronizeStream, this->stream_); }
123+
Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number,
124+
const size_t count) override
125+
{
126+
for (size_t i = 0; i < number; i++) {
127+
auto status = this->H2DSync(dArr[i], hArr[i], count);
128+
if (status.Failure()) { return status; }
129+
}
130+
return Status::OK();
131+
}
132+
Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number,
133+
const size_t count) override
134+
{
135+
for (size_t i = 0; i < number; i++) {
136+
auto status = this->D2HSync(hArr[i], dArr[i], count);
137+
if (status.Failure()) { return status; }
138+
}
139+
return Status::OK();
140+
}
123141

124142
protected:
125143
std::shared_ptr<std::byte> MakeBuffer(const size_t size) override

ucm/store/device/cuda/CMakeLists.txt

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
set(CUDA_ROOT "/usr/local/cuda/" CACHE PATH "Path to CUDA root directory")
2-
add_library(Cuda::cudart UNKNOWN IMPORTED)
3-
set_target_properties(Cuda::cudart PROPERTIES
4-
INTERFACE_INCLUDE_DIRECTORIES "${CUDA_ROOT}/include"
5-
IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcudart.so"
2+
set(CMAKE_CUDA_COMPILER ${CUDA_ROOT}/bin/nvcc)
3+
set(CMAKE_CUDA_ARCHITECTURES 75 80 86 89 90)
4+
enable_language(CUDA)
5+
add_library(storedevice STATIC cuda_device.cu)
6+
target_link_libraries(storedevice PUBLIC storeinfra)
7+
target_compile_options(storedevice PRIVATE
8+
--diag-suppress=128 --diag-suppress=2417 --diag-suppress=2597
9+
-Wall -fPIC
610
)
7-
8-
add_library(storedevice STATIC cuda_device.cc)
9-
target_link_libraries(storedevice PUBLIC storeinfra Cuda::cudart)

ucm/store/device/cuda/cuda_device.cc renamed to ucm/store/device/cuda/cuda_device.cu

Lines changed: 107 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,61 @@
2525
#include "ibuffered_device.h"
2626
#include "logger/logger.h"
2727

28+
#define CUDA_TRANS_UNIT_SIZE (sizeof(uint64_t) * 2)
29+
#define CUDA_TRANS_BLOCK_NUMBER (32)
30+
#define CUDA_TRANS_BLOCK_SIZE (256)
31+
#define CUDA_TRANS_THREAD_NUMBER (CUDA_TRANS_BLOCK_NUMBER * CUDA_TRANS_BLOCK_SIZE)
32+
33+
inline __device__ void H2DUnit(uint8_t* __restrict__ dst, const volatile uint8_t* __restrict__ src)
34+
{
35+
uint64_t a, b;
36+
asm volatile("ld.global.cs.v2.u64 {%0, %1}, [%2];" : "=l"(a), "=l"(b) : "l"(src));
37+
asm volatile("st.global.cg.v2.u64 [%0], {%1, %2};" ::"l"(dst), "l"(a), "l"(b));
38+
}
39+
40+
inline __device__ void D2HUnit(volatile uint8_t* __restrict__ dst, const uint8_t* __restrict__ src)
41+
{
42+
uint64_t a, b;
43+
asm volatile("ld.global.cs.v2.u64 {%0, %1}, [%2];" : "=l"(a), "=l"(b) : "l"(src));
44+
asm volatile("st.volatile.global.v2.u64 [%0], {%1, %2};" ::"l"(dst), "l"(a), "l"(b));
45+
}
46+
47+
__global__ void H2DKernel(uintptr_t* dst, const volatile uintptr_t* src, size_t num, size_t size)
48+
{
49+
auto length = num * size;
50+
auto offset = (blockIdx.x * blockDim.x + threadIdx.x) * CUDA_TRANS_UNIT_SIZE;
51+
while (offset + CUDA_TRANS_UNIT_SIZE <= length) {
52+
auto idx = offset / size;
53+
auto off = offset % size;
54+
H2DUnit(((uint8_t*)dst[idx]) + off, ((const uint8_t*)src[idx]) + off);
55+
offset += CUDA_TRANS_THREAD_NUMBER * CUDA_TRANS_UNIT_SIZE;
56+
}
57+
}
58+
59+
__global__ void D2HKernel(volatile uintptr_t* dst, const uintptr_t* src, size_t num, size_t size)
60+
{
61+
auto length = num * size;
62+
auto offset = (blockIdx.x * blockDim.x + threadIdx.x) * CUDA_TRANS_UNIT_SIZE;
63+
while (offset + CUDA_TRANS_UNIT_SIZE <= length) {
64+
auto idx = offset / size;
65+
auto off = offset % size;
66+
D2HUnit(((uint8_t*)dst[idx]) + off, ((const uint8_t*)src[idx]) + off);
67+
offset += CUDA_TRANS_THREAD_NUMBER * CUDA_TRANS_UNIT_SIZE;
68+
}
69+
}
70+
71+
inline __host__ void H2DBatch(uintptr_t* dst, const volatile uintptr_t* src, size_t num,
72+
size_t size, cudaStream_t stream)
73+
{
74+
H2DKernel<<<CUDA_TRANS_BLOCK_NUMBER, CUDA_TRANS_BLOCK_SIZE, 0, stream>>>(dst, src, num, size);
75+
}
76+
77+
inline __host__ void D2HBatch(volatile uintptr_t* dst, const uintptr_t* src, size_t num,
78+
size_t size, cudaStream_t stream)
79+
{
80+
D2HKernel<<<CUDA_TRANS_BLOCK_NUMBER, CUDA_TRANS_BLOCK_SIZE, 0, stream>>>(dst, src, num, size);
81+
}
82+
2883
template <>
2984
struct fmt::formatter<cudaError_t> : formatter<int32_t> {
3085
auto format(cudaError_t err, format_context& ctx) const -> format_context::iterator
@@ -39,7 +94,7 @@ template <typename Api, typename... Args>
3994
Status CudaApi(const char* caller, const char* file, const size_t line, const char* name, Api&& api,
4095
Args&&... args)
4196
{
42-
auto ret = api(args...);
97+
auto ret = std::invoke(api, args...);
4398
if (ret != cudaSuccess) {
4499
UC_ERROR("CUDA ERROR: api={}, code={}, err={}, caller={},{}:{}.", name, ret,
45100
cudaGetErrorString(ret), caller, basename(file), line);
@@ -62,6 +117,22 @@ class CudaDevice : public IBufferedDevice {
62117
c->cb(ret == cudaSuccess);
63118
delete c;
64119
}
120+
static void* MakeDeviceArray(const void* hostArray[], const size_t number)
121+
{
122+
auto size = sizeof(void*) * number;
123+
void* deviceArray = nullptr;
124+
auto ret = cudaMalloc(&deviceArray, size);
125+
if (ret != cudaSuccess) {
126+
UC_ERROR("Failed({},{}) to alloc({}) on device.", ret, cudaGetErrorString(ret), size);
127+
return nullptr;
128+
}
129+
if (CUDA_API(cudaMemcpy, deviceArray, hostArray, size, cudaMemcpyHostToDevice).Success()) {
130+
return deviceArray;
131+
}
132+
ReleaseDeviceArray(deviceArray);
133+
return nullptr;
134+
}
135+
static void ReleaseDeviceArray(void* deviceArray) { CUDA_API(cudaFree, deviceArray); }
65136

66137
public:
67138
CudaDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
@@ -88,13 +159,11 @@ class CudaDevice : public IBufferedDevice {
88159
}
89160
Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) override
90161
{
91-
return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyHostToDevice,
92-
(cudaStream_t)this->stream_);
162+
return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyHostToDevice, this->stream_);
93163
}
94164
Status D2HAsync(std::byte* dst, const std::byte* src, const size_t count) override
95165
{
96-
return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost,
97-
(cudaStream_t)this->stream_);
166+
return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, this->stream_);
98167
}
99168
Status AppendCallback(std::function<void(bool)> cb) override
100169
{
@@ -103,14 +172,42 @@ class CudaDevice : public IBufferedDevice {
103172
UC_ERROR("Failed to make closure for append cb.");
104173
return Status::OutOfMemory();
105174
}
106-
auto status =
107-
CUDA_API(cudaStreamAddCallback, (cudaStream_t)this->stream_, Trampoline, (void*)c, 0);
175+
auto status = CUDA_API(cudaStreamAddCallback, this->stream_, Trampoline, (void*)c, 0);
108176
if (status.Failure()) { delete c; }
109177
return status;
110178
}
111-
Status Synchronized() override
179+
Status Synchronized() override { return CUDA_API(cudaStreamSynchronize, this->stream_); }
180+
Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number,
181+
const size_t count) override
112182
{
113-
return CUDA_API(cudaStreamSynchronize, (cudaStream_t)this->stream_);
183+
auto src = MakeDeviceArray((const void**)hArr, number);
184+
if (!src) { return Status::OutOfMemory(); }
185+
auto dst = MakeDeviceArray((const void**)dArr, number);
186+
if (!dst) {
187+
ReleaseDeviceArray(src);
188+
return Status::OutOfMemory();
189+
}
190+
H2DBatch((uintptr_t*)dst, (const volatile uintptr_t*)src, number, count, this->stream_);
191+
auto status = this->Synchronized();
192+
ReleaseDeviceArray(src);
193+
ReleaseDeviceArray(dst);
194+
return status;
195+
}
196+
Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number,
197+
const size_t count) override
198+
{
199+
auto src = MakeDeviceArray((const void**)dArr, number);
200+
if (!src) { return Status::OutOfMemory(); }
201+
auto dst = MakeDeviceArray((const void**)hArr, number);
202+
if (!dst) {
203+
ReleaseDeviceArray(src);
204+
return Status::OutOfMemory();
205+
}
206+
D2HBatch((volatile uintptr_t*)dst, (const uintptr_t*)src, number, count, this->stream_);
207+
auto status = this->Synchronized();
208+
ReleaseDeviceArray(src);
209+
ReleaseDeviceArray(dst);
210+
return status;
114211
}
115212

116213
protected:
@@ -126,7 +223,7 @@ class CudaDevice : public IBufferedDevice {
126223
}
127224

128225
private:
129-
void* stream_;
226+
cudaStream_t stream_;
130227
};
131228

132229
std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_t bufferSize,

ucm/store/device/ibuffered_device.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,28 +38,26 @@ class IBufferedDevice : public IDevice {
3838
Status Setup() override
3939
{
4040
auto totalSize = this->bufferSize * this->bufferNumber;
41+
if (totalSize == 0) { return Status::OK(); }
4142
this->_addr = this->MakeBuffer(totalSize);
4243
if (!this->_addr) { return Status::OutOfMemory(); }
4344
this->_indexPool.Setup(this->bufferNumber);
4445
return Status::OK();
4546
}
4647
virtual std::shared_ptr<std::byte> GetBuffer(const size_t size) override
4748
{
48-
auto idx = IndexPool::npos;
49-
if (size <= this->bufferSize && (idx = this->_indexPool.Acquire()) != IndexPool::npos) {
49+
if (!this->_addr || size > this->bufferSize) { return this->MakeBuffer(size); }
50+
auto idx = this->_indexPool.Acquire();
51+
if (idx != IndexPool::npos) {
5052
auto ptr = this->_addr.get() + this->bufferSize * idx;
51-
return std::shared_ptr<std::byte>(
52-
ptr, [this, idx](std::byte*) { this->_indexPool.Release(idx); });
53+
return std::shared_ptr<std::byte>(ptr,
54+
[this, idx](auto) { this->_indexPool.Release(idx); });
5355
}
54-
auto buffer = this->MakeBuffer(size);
55-
if (buffer) { return buffer; }
56-
auto host = (std::byte*)malloc(size);
57-
if (host) { return std::shared_ptr<std::byte>(host, free); }
58-
return nullptr;
56+
return this->MakeBuffer(size);
5957
}
6058

6159
private:
62-
std::shared_ptr<std::byte> _addr;
60+
std::shared_ptr<std::byte> _addr{nullptr};
6361
IndexPool _indexPool;
6462
};
6563

ucm/store/device/idevice.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ class IDevice {
4545
virtual Status D2HAsync(std::byte* dst, const std::byte* src, const size_t count) = 0;
4646
virtual Status AppendCallback(std::function<void(bool)> cb) = 0;
4747
virtual Status Synchronized() = 0;
48+
virtual Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number,
49+
const size_t count) = 0;
50+
virtual Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number,
51+
const size_t count) = 0;
4852

4953
protected:
5054
virtual std::shared_ptr<std::byte> MakeBuffer(const size_t size) = 0;

ucm/store/device/simu/simu_device.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,31 @@ class SimuDevice : public IBufferedDevice {
7777
this->backend_.Push([=] { cb(true); });
7878
return Status::OK();
7979
}
80-
virtual Status Synchronized()
80+
Status Synchronized() override
8181
{
8282
Latch waiter{1};
8383
this->backend_.Push([&] { waiter.Done(nullptr); });
8484
waiter.Wait();
8585
return Status::OK();
8686
}
87+
Status H2DBatchSync(std::byte* dArr[], const std::byte* hArr[], const size_t number,
88+
const size_t count) override
89+
{
90+
for (size_t i = 0; i < number; i++) {
91+
auto status = this->H2DSync(dArr[i], hArr[i], count);
92+
if (status.Failure()) { return status; }
93+
}
94+
return Status::OK();
95+
}
96+
Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number,
97+
const size_t count) override
98+
{
99+
for (size_t i = 0; i < number; i++) {
100+
auto status = this->D2HSync(hArr[i], dArr[i], count);
101+
if (status.Failure()) { return status; }
102+
}
103+
return Status::OK();
104+
}
87105

88106
protected:
89107
std::shared_ptr<std::byte> MakeBuffer(const size_t size) override

0 commit comments

Comments
 (0)