Skip to content

Commit 2e0834b

Browse files
author
FengHao
committed
GDS-fixed
1 parent baa2e53 commit 2e0834b

File tree

13 files changed

+180
-141
lines changed

13 files changed

+180
-141
lines changed

ucm/store/device/cuda/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ set_target_properties(Cuda::cudart PROPERTIES
1414
IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcudart.so"
1515
IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcufile.so"
1616
)
17+
target_link_libraries(storedevice PUBLIC Cuda::cudart)

ucm/store/device/cuda/cuda_device.cu

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
#include <unistd.h>
3131
#include <cerrno>
3232
#include <cstring>
33-
#include "sharded_handle_recorder.h"
33+
#include <unordered_map>
34+
#include <cstdlib>
35+
#include "infra/template/sharded_handle_recorder.h"
3436

3537
#define CUDA_TRANS_UNIT_SIZE (sizeof(uint64_t) * 2)
3638
#define CUDA_TRANS_BLOCK_NUMBER (32)
@@ -97,11 +99,10 @@ struct fmt::formatter<cudaError_t> : formatter<int32_t> {
9799

98100
namespace UC {
99101

100-
static Status CreateCuFileHandle(const std::string& path, int flags, CUfileHandle_t& cuFileHandle, int& fd)
102+
static Status CreateCuFileHandle(int fd, CUfileHandle_t& cuFileHandle)
101103
{
102-
fd = open(path.c_str(), flags, 0644);
103104
if (fd < 0) {
104-
UC_ERROR("Failed to open file {}: {}", path, strerror(errno));
105+
UC_ERROR("Invalid file descriptor: {}", fd);
105106
return Status::Error();
106107
}
107108

@@ -110,10 +111,8 @@ static Status CreateCuFileHandle(const std::string& path, int flags, CUfileHandl
110111
cfDescr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
111112
CUfileError_t err = cuFileHandleRegister(&cuFileHandle, &cfDescr);
112113
if (err.err != CU_FILE_SUCCESS) {
113-
UC_ERROR("Failed to register cuFile handle for {}: error {}",
114-
path, static_cast<int>(err.err));
115-
close(fd);
116-
fd = -1;
114+
UC_ERROR("Failed to register cuFile handle for fd {}: error {}",
115+
fd, static_cast<int>(err.err));
117116
return Status::Error();
118117
}
119118

@@ -163,31 +162,27 @@ class CudaDevice : public IBufferedDevice {
163162
}
164163
static void ReleaseDeviceArray(void* deviceArray) { CUDA_API(cudaFree, deviceArray); }
165164
static std::once_flag gdsOnce_;
166-
static void InitGdsOnce();
167165

168166
public:
167+
static void InitGdsOnce();
169168
CudaDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
170169
: IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr}
171170
{
172171
}
173172
~CudaDevice() {
174-
CuFileHandleRecorder::Instance().ClearAll([](CUfileHandle_t h, int fd) {
173+
HandlePool<int, CUfileHandle_t>::Instance().ClearAll([](CUfileHandle_t h) {
175174
cuFileHandleDeregister(h);
176-
if (fd >= 0) {
177-
close(fd);
178-
}
179175
});
180176

181177
if (stream_ != nullptr) {
182178
cudaStreamDestroy((cudaStream_t)stream_);
183179
}
184180
}
185-
Status Setup(bool transferUseDirect) override
181+
Status Setup() override
186182
{
187-
if(transferUseDirect) {InitGdsOnce();}
188183
auto status = Status::OK();
189184
if ((status = CUDA_API(cudaSetDevice, this->deviceId)).Failure()) { return status; }
190-
if ((status = IBufferedDevice::Setup(transferUseDirect)).Failure()) { return status; }
185+
if ((status = IBufferedDevice::Setup()).Failure()) { return status; }
191186
if ((status = CUDA_API(cudaStreamCreate, (cudaStream_t*)&this->stream_)).Failure()) {
192187
return status;
193188
}
@@ -209,36 +204,48 @@ public:
209204
{
210205
return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, this->stream_);
211206
}
212-
Status S2DSync(const std::string& path, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
207+
Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
213208
{
214209
CUfileHandle_t cuFileHandle = nullptr;
215-
auto status = CuFileHandleRecorder::Instance().Get(path, cuFileHandle,
216-
[&path](CUfileHandle_t& handle, int& fd) -> Status {
217-
return CreateCuFileHandle(path, O_RDONLY | O_DIRECT, handle, fd);
210+
auto status = HandlePool<int, CUfileHandle_t>::Instance().Get(fd, cuFileHandle,
211+
[fd](CUfileHandle_t& handle) -> Status {
212+
return CreateCuFileHandle(fd, handle);
218213
});
219214
if (status.Failure()) {
220215
return status;
221216
}
222217
ssize_t bytesRead = cuFileRead(cuFileHandle, address, length, fileOffset, devOffset);
218+
HandlePool<int, CUfileHandle_t>::Instance().Put(fd, [](CUfileHandle_t h) {
219+
if (h != nullptr) {
220+
cuFileHandleDeregister(h);
221+
}
222+
});
223+
223224
if (bytesRead < 0 || (size_t)bytesRead != length) {
224-
UC_ERROR("cuFileRead failed for {}: expected {}, got {}", path, length, bytesRead);
225+
UC_ERROR("cuFileRead failed for fd {}: expected {}, got {}", fd, length, bytesRead);
225226
return Status::Error();
226227
}
227228
return Status::OK();
228229
}
229-
Status D2SSync(const std::string& path, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
230+
Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
230231
{
231232
CUfileHandle_t cuFileHandle = nullptr;
232-
auto status = CuFileHandleRecorder::Instance().Get(path, cuFileHandle,
233-
[&path](CUfileHandle_t& handle, int& fd) -> Status {
234-
return CreateCuFileHandle(path, O_WRONLY | O_CREAT | O_DIRECT, handle, fd);
233+
auto status = HandlePool<int, CUfileHandle_t>::Instance().Get(fd, cuFileHandle,
234+
[fd](CUfileHandle_t& handle) -> Status {
235+
return CreateCuFileHandle(fd, handle);
235236
});
236237
if (status.Failure()) {
237238
return status;
238239
}
239240
ssize_t bytesWrite = cuFileWrite(cuFileHandle, address, length, fileOffset, devOffset);
241+
HandlePool<int, CUfileHandle_t>::Instance().Put(fd, [](CUfileHandle_t h) {
242+
if (h != nullptr) {
243+
cuFileHandleDeregister(h);
244+
}
245+
});
246+
240247
if (bytesWrite < 0 || (size_t)bytesWrite != length) {
241-
UC_ERROR("cuFileWrite failed for {}: expected {}, got {}", path, length, bytesWrite);
248+
UC_ERROR("cuFileWrite failed for fd {}: expected {}, got {}", fd, length, bytesWrite);
242249
return Status::Error();
243250
}
244251
return Status::OK();
@@ -305,9 +312,12 @@ private:
305312
};
306313

307314
std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_t bufferSize,
308-
const size_t bufferNumber)
315+
const size_t bufferNumber, bool transferUseDirect)
309316
{
310317
try {
318+
if (transferUseDirect) {
319+
CudaDevice::InitGdsOnce();
320+
}
311321
return std::make_unique<CudaDevice>(deviceId, bufferSize, bufferNumber);
312322
} catch (const std::exception& e) {
313323
UC_ERROR("Failed({}) to make cuda device({},{},{}).", e.what(), deviceId, bufferSize,

ucm/store/device/cuda/sharded_handle_recorder.h

Lines changed: 0 additions & 84 deletions
This file was deleted.

ucm/store/device/ibuffered_device.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@ class IBufferedDevice : public IDevice {
3535
: IDevice{deviceId, bufferSize, bufferNumber}
3636
{
3737
}
38-
Status Setup(bool transferUseDirect) override
38+
Status Setup() override
3939
{
40-
if(transferUseDirect) {return Status::OK();}
4140
auto totalSize = this->bufferSize * this->bufferNumber;
4241
if (totalSize == 0) { return Status::OK(); }
4342
this->_addr = this->MakeBuffer(totalSize);

ucm/store/device/idevice.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class IDevice {
3737
{
3838
}
3939
virtual ~IDevice() = default;
40-
virtual Status Setup(bool transferUseDirect) = 0;
40+
virtual Status Setup() = 0;
4141
virtual std::shared_ptr<std::byte> GetBuffer(const size_t size) = 0;
4242
virtual Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) = 0;
4343
virtual Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) = 0;
@@ -49,8 +49,8 @@ class IDevice {
4949
const size_t count) = 0;
5050
virtual Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number,
5151
const size_t count) = 0;
52-
virtual Status S2DSync(const std::string& path, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
53-
virtual Status D2SSync(const std::string& path, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
52+
virtual Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
53+
virtual Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
5454

5555
protected:
5656
virtual std::shared_ptr<std::byte> MakeBuffer(const size_t size) = 0;
@@ -62,7 +62,7 @@ class IDevice {
6262
class DeviceFactory {
6363
public:
6464
static std::unique_ptr<IDevice> Make(const int32_t deviceId, const size_t bufferSize,
65-
const size_t bufferNumber);
65+
const size_t bufferNumber, bool transferUseDirect = false);
6666
};
6767

6868
} // namespace UC
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#ifndef UC_INFRA_HANDLE_POOL_H
2+
#define UC_INFRA_HANDLE_POOL_H
3+
4+
#include <functional>
5+
#include "status/status.h"
6+
#include "hashmap.h"
7+
8+
namespace UC {
9+
10+
// HandlePool: a per-instantiation singleton that associates a key with a handle
// and a reference count, stored in the project HashMap. Callers acquire a handle
// with Get(), release it with Put(), and tear everything down with ClearAll().
// The pool never creates or destroys handles itself: creation and cleanup are
// supplied by the caller as callbacks.
template <typename KeyType, typename HandleType>
11+
class HandlePool {
12+
private:
13+
// One pooled handle together with the number of outstanding references.
struct PoolEntry {
14+
HandleType handle;
15+
// Unsigned counter; decrementing it at zero wraps around (see Put()).
// NOTE(review): uint64_t needs <cstdint>; this header only includes
// <functional>, "status/status.h" and "hashmap.h" directly - confirm it
// arrives transitively.
uint64_t refCount;
16+
};
17+
// NOTE(review): the meaning of the trailing template argument (10) is defined
// by HashMap, which is not visible here - presumably a shard/bucket
// parameter; confirm.
using PoolMap = HashMap<KeyType, PoolEntry, std::hash<KeyType>, 10>;
18+
PoolMap pool_;
19+
// Construction is private and copying is deleted, so Instance() is the only
// way to obtain a pool.
HandlePool() = default;
20+
HandlePool(const HandlePool&) = delete;
21+
HandlePool& operator=(const HandlePool&) = delete;
22+
23+
public:
24+
// Returns the single pool object for this <KeyType, HandleType> pair
// (a function-local static).
static HandlePool& Instance()
25+
{
26+
static HandlePool instance;
27+
return instance;
28+
}
29+
30+
// Looks up (or creates) the entry for `key` and returns its handle.
//   key         - identifies the pooled handle.
//   handle      - out: receives a copy of the pooled handle on success.
//   instantiate - invoked from the GetOrCreate callback to build a new handle;
//                 if it reports Failure() the callback returns false.
// Returns Status::Error() when GetOrCreate yields no value, otherwise
// Status::OK(). Every successful call increments the entry's refCount.
Status Get(const KeyType& key, HandleType& handle,
31+
std::function<Status(HandleType&)> instantiate)
32+
{
33+
auto result = pool_.GetOrCreate(key, [&instantiate](PoolEntry& entry) -> bool {
34+
HandleType h{};
35+
36+
auto status = instantiate(h);
37+
if (status.Failure()) {
38+
return false;
39+
}
40+
41+
entry.handle = h;
42+
// A freshly created entry starts with a count of 1 ...
entry.refCount = 1;
43+
return true;
44+
});
45+
46+
if (!result.has_value()) {
47+
return Status::Error();
48+
}
49+
50+
auto& entry = result.value().get();
51+
// ... and is then incremented unconditionally here as well.
// NOTE(review): assuming GetOrCreate only runs the callback for a missing
// key, the first Get() leaves refCount at 2, so a balanced Get()/Put()
// pair never brings it to 0 and Put()'s cleanup would not run until
// ClearAll(). Confirm whether the extra reference is intentional.
// NOTE(review): this increment happens after GetOrCreate has returned;
// whether it is still covered by HashMap's synchronization is not visible
// from this file - verify against concurrent Put().
entry.refCount++;
52+
handle = entry.handle;
53+
return Status::OK();
54+
}
55+
56+
// Drops one reference to the entry for `key`.
//   cleanup - called with the entry's handle once refCount reaches 0.
// The Upsert callback returns false while references remain and true after
// cleanup has run.
// NOTE(review): presumably `true` asks Upsert to erase the entry - confirm
// against HashMap::Upsert. Also confirm what Upsert does for a key that is not
// present: refCount is unsigned, so decrementing an entry whose count is 0
// wraps to a huge value instead of triggering cleanup.
void Put(const KeyType& key,
57+
std::function<void(HandleType)> cleanup)
58+
{
59+
pool_.Upsert(key, [&cleanup](PoolEntry& entry) -> bool {
60+
entry.refCount--;
61+
if (entry.refCount > 0) {
62+
return false;
63+
}
64+
cleanup(entry.handle);
65+
return true;
66+
});
67+
}
68+
69+
// Invokes `cleanup` on the handle of every entry, without consulting refCount,
// and then empties the pool.
void ClearAll(std::function<void(HandleType)> cleanup)
70+
{
71+
pool_.ForEach([&cleanup](const KeyType& key, PoolEntry& entry) {
72+
// The key is not needed for cleanup; the cast silences unused-parameter warnings.
(void)key;
73+
cleanup(entry.handle);
74+
});
75+
pool_.Clear();
76+
}
77+
};
78+
79+
} // namespace UC
80+
81+
#endif
82+

0 commit comments

Comments
 (0)