2424#include < cuda_runtime.h>
2525#include " ibuffered_device.h"
2626#include " logger/logger.h"
27+ #include < cufile.h>
28+ #include < mutex>
29+ #include < fcntl.h>
30+ #include < unistd.h>
31+ #include < cerrno>
32+ #include < cstring>
33+ #include " sharded_handle_recorder.h"
2734
2835template <>
2936struct fmt ::formatter<cudaError_t> : formatter<int32_t > {
@@ -35,6 +42,28 @@ struct fmt::formatter<cudaError_t> : formatter<int32_t> {
3542
3643namespace UC {
3744
45+ static Status CreateCuFileHandle (const std::string& path, int flags, CUfileHandle_t& cuFileHandle, int & fd)
46+ {
47+ fd = open (path.c_str (), flags, 0644 );
48+ if (fd < 0 ) {
49+ UC_ERROR (" Failed to open file {}: {}" , path, strerror (errno));
50+ return Status::Error ();
51+ }
52+
53+ CUfileDescr_t cfDescr{};
54+ cfDescr.handle .fd = fd;
55+ cfDescr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
56+ CUfileError_t err = cuFileHandleRegister (&cuFileHandle, &cfDescr);
57+ if (err.err != CU_FILE_SUCCESS) {
58+ UC_ERROR (" Failed to register cuFile handle for {}: error {}" ,
59+ path, static_cast <int >(err.err ));
60+ close (fd);
61+ fd = -1 ;
62+ return Status::Error ();
63+ }
64+
65+ return Status::OK ();
66+ }
3867template <typename Api, typename ... Args>
3968Status CudaApi (const char * caller, const char * file, const size_t line, const char * name, Api&& api,
4069 Args&&... args)
@@ -62,17 +91,32 @@ class CudaDevice : public IBufferedDevice {
6291 c->cb (ret == cudaSuccess);
6392 delete c;
6493 }
94+ static std::once_flag gdsOnce_;
95+ static void InitGdsOnce ();
6596
6697public:
6798 CudaDevice (const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
6899 : IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr }
69100 {
70101 }
71- Status Setup () override
102+ ~CudaDevice () {
103+ CuFileHandleRecorder::Instance ().ClearAll ([](CUfileHandle_t h, int fd) {
104+ cuFileHandleDeregister (h);
105+ if (fd >= 0 ) {
106+ close (fd);
107+ }
108+ });
109+
110+ if (stream_ != nullptr ) {
111+ cudaStreamDestroy ((cudaStream_t)stream_);
112+ }
113+ }
114+ Status Setup (bool transferUseDirect) override
72115 {
116+ if (transferUseDirect) {InitGdsOnce ();}
73117 auto status = Status::OK ();
74118 if ((status = CUDA_API (cudaSetDevice, this ->deviceId )).Failure ()) { return status; }
75- if ((status = IBufferedDevice::Setup ()).Failure ()) { return status; }
119+ if ((status = IBufferedDevice::Setup (transferUseDirect )).Failure ()) { return status; }
76120 if ((status = CUDA_API (cudaStreamCreate, (cudaStream_t*)&this ->stream_ )).Failure ()) {
77121 return status;
78122 }
@@ -96,6 +140,40 @@ class CudaDevice : public IBufferedDevice {
96140 return CUDA_API (cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost,
97141 (cudaStream_t)this ->stream_ );
98142 }
143+ Status S2DSync (const std::string& path, void * address, const size_t length, const size_t file_offset, const size_t dev_offset) override
144+ {
145+ CUfileHandle_t cuFileHandle = nullptr ;
146+ auto status = CuFileHandleRecorder::Instance ().Get (path, cuFileHandle,
147+ [&path](CUfileHandle_t& handle, int & fd) -> Status {
148+ return CreateCuFileHandle (path, O_RDONLY | O_DIRECT, handle, fd);
149+ });
150+ if (status.Failure ()) {
151+ return status;
152+ }
153+ ssize_t bytesRead = cuFileRead (cuFileHandle, address, length, file_offset, dev_offset);
154+ if (bytesRead < 0 || (size_t )bytesRead != length) {
155+ UC_ERROR (" cuFileRead failed for {}: expected {}, got {}" , path, length, bytesRead);
156+ return Status::Error ();
157+ }
158+ return Status::OK ();
159+ }
160+ Status D2SSync (const std::string& path, void * address, const size_t length, const size_t file_offset, const size_t dev_offset) override
161+ {
162+ CUfileHandle_t cuFileHandle = nullptr ;
163+ auto status = CuFileHandleRecorder::Instance ().Get (path, cuFileHandle,
164+ [&path](CUfileHandle_t& handle, int & fd) -> Status {
165+ return CreateCuFileHandle (path, O_WRONLY | O_CREAT | O_DIRECT, handle, fd);
166+ });
167+ if (status.Failure ()) {
168+ return status;
169+ }
170+ ssize_t bytesWrite = cuFileWrite (cuFileHandle, address, length, file_offset, dev_offset);
171+ if (bytesWrite < 0 || (size_t )bytesWrite != length) {
172+ UC_ERROR (" cuFileWrite failed for {}: expected {}, got {}" , path, length, bytesWrite);
173+ return Status::Error ();
174+ }
175+ return Status::OK ();
176+ }
99177 Status AppendCallback (std::function<void (bool )> cb) override
100178 {
101179 auto * c = new (std::nothrow) Closure (cb);
@@ -140,5 +218,17 @@ std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_
140218 return nullptr ;
141219 }
142220}
221+ std::once_flag CudaDevice::gdsOnce_{};
222+ void CudaDevice::InitGdsOnce ()
223+ {
224+ std::call_once (gdsOnce_, [] (){
225+ CUfileError_t ret = cuFileDriverOpen ();
226+ if (ret.err == CU_FILE_SUCCESS) {
227+ UC_INFO (" GDS driver initialized successfully" );
228+ } else {
229+ UC_ERROR (" GDS driver initialized unsuccessfully" );
230+ }
231+ });
232+ }
143233
144234} // namespace UC
0 commit comments