3030#include < unistd.h>
3131#include < cerrno>
3232#include < cstring>
33- #include " sharded_handle_recorder.h"
33+ #include < unordered_map>
34+ #include < cstdlib>
35+ #include " infra/template/sharded_handle_recorder.h"
3436
3537#define CUDA_TRANS_UNIT_SIZE (sizeof (uint64_t ) * 2 )
3638#define CUDA_TRANS_BLOCK_NUMBER (32 )
@@ -97,11 +99,10 @@ struct fmt::formatter<cudaError_t> : formatter<int32_t> {
9799
98100namespace UC {
99101
100- static Status CreateCuFileHandle (const std::string& path, int flags , CUfileHandle_t& cuFileHandle, int & fd )
102+ static Status CreateCuFileHandle (int fd , CUfileHandle_t& cuFileHandle)
101103{
102- fd = open (path.c_str (), flags, 0644 );
103104 if (fd < 0 ) {
104- UC_ERROR (" Failed to open file {} : {}" , path, strerror (errno) );
105+ UC_ERROR (" Invalid file descriptor : {}" , fd );
105106 return Status::Error ();
106107 }
107108
@@ -110,10 +111,8 @@ static Status CreateCuFileHandle(const std::string& path, int flags, CUfileHandl
110111 cfDescr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
111112 CUfileError_t err = cuFileHandleRegister (&cuFileHandle, &cfDescr);
112113 if (err.err != CU_FILE_SUCCESS) {
113- UC_ERROR (" Failed to register cuFile handle for {}: error {}" ,
114- path, static_cast <int >(err.err ));
115- close (fd);
116- fd = -1 ;
114+ UC_ERROR (" Failed to register cuFile handle for fd {}: error {}" ,
115+ fd, static_cast <int >(err.err ));
117116 return Status::Error ();
118117 }
119118
@@ -163,31 +162,27 @@ class CudaDevice : public IBufferedDevice {
163162 }
164163 static void ReleaseDeviceArray (void * deviceArray) { CUDA_API (cudaFree, deviceArray); }
165164 static std::once_flag gdsOnce_;
166- static void InitGdsOnce ();
167165
168166public:
167+ static void InitGdsOnce ();
169168 CudaDevice (const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
170169 : IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr }
171170 {
172171 }
173172 ~CudaDevice () {
174- CuFileHandleRecorder ::Instance ().ClearAll ([](CUfileHandle_t h, int fd ) {
173+ HandlePool< int , CUfileHandle_t> ::Instance ().ClearAll ([](CUfileHandle_t h) {
175174 cuFileHandleDeregister (h);
176- if (fd >= 0 ) {
177- close (fd);
178- }
179175 });
180176
181177 if (stream_ != nullptr ) {
182178 cudaStreamDestroy ((cudaStream_t)stream_);
183179 }
184180 }
185- Status Setup (bool transferUseDirect ) override
181+ Status Setup () override
186182 {
187- if (transferUseDirect) {InitGdsOnce ();}
188183 auto status = Status::OK ();
189184 if ((status = CUDA_API (cudaSetDevice, this ->deviceId )).Failure ()) { return status; }
190- if ((status = IBufferedDevice::Setup (transferUseDirect )).Failure ()) { return status; }
185+ if ((status = IBufferedDevice::Setup ()).Failure ()) { return status; }
191186 if ((status = CUDA_API (cudaStreamCreate, (cudaStream_t*)&this ->stream_ )).Failure ()) {
192187 return status;
193188 }
@@ -209,36 +204,48 @@ public:
209204 {
210205 return CUDA_API (cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, this ->stream_ );
211206 }
212- Status S2DSync (const std::string& path , void * address, const size_t length, const size_t fileOffset, const size_t devOffset) override
207+ Status S2DSync (int fd , void * address, const size_t length, const size_t fileOffset, const size_t devOffset) override
213208 {
214209 CUfileHandle_t cuFileHandle = nullptr ;
215- auto status = CuFileHandleRecorder ::Instance ().Get (path , cuFileHandle,
216- [&path ](CUfileHandle_t& handle, int & fd ) -> Status {
217- return CreateCuFileHandle (path, O_RDONLY | O_DIRECT, handle, fd );
210+ auto status = HandlePool< int , CUfileHandle_t> ::Instance ().Get (fd , cuFileHandle,
211+ [fd ](CUfileHandle_t& handle) -> Status {
212+ return CreateCuFileHandle (fd, handle);
218213 });
219214 if (status.Failure ()) {
220215 return status;
221216 }
222217 ssize_t bytesRead = cuFileRead (cuFileHandle, address, length, fileOffset, devOffset);
218+ HandlePool<int , CUfileHandle_t>::Instance ().Put (fd, [](CUfileHandle_t h) {
219+ if (h != nullptr ) {
220+ cuFileHandleDeregister (h);
221+ }
222+ });
223+
223224 if (bytesRead < 0 || (size_t )bytesRead != length) {
224- UC_ERROR (" cuFileRead failed for {}: expected {}, got {}" , path , length, bytesRead);
225+ UC_ERROR (" cuFileRead failed for fd {}: expected {}, got {}" , fd , length, bytesRead);
225226 return Status::Error ();
226227 }
227228 return Status::OK ();
228229 }
229- Status D2SSync (const std::string& path , void * address, const size_t length, const size_t fileOffset, const size_t devOffset) override
230+ Status D2SSync (int fd , void * address, const size_t length, const size_t fileOffset, const size_t devOffset) override
230231 {
231232 CUfileHandle_t cuFileHandle = nullptr ;
232- auto status = CuFileHandleRecorder ::Instance ().Get (path , cuFileHandle,
233- [&path ](CUfileHandle_t& handle, int & fd ) -> Status {
234- return CreateCuFileHandle (path, O_WRONLY | O_CREAT | O_DIRECT, handle, fd );
233+ auto status = HandlePool< int , CUfileHandle_t> ::Instance ().Get (fd , cuFileHandle,
234+ [fd ](CUfileHandle_t& handle) -> Status {
235+ return CreateCuFileHandle (fd, handle);
235236 });
236237 if (status.Failure ()) {
237238 return status;
238239 }
239240 ssize_t bytesWrite = cuFileWrite (cuFileHandle, address, length, fileOffset, devOffset);
241+ HandlePool<int , CUfileHandle_t>::Instance ().Put (fd, [](CUfileHandle_t h) {
242+ if (h != nullptr ) {
243+ cuFileHandleDeregister (h);
244+ }
245+ });
246+
240247 if (bytesWrite < 0 || (size_t )bytesWrite != length) {
241- UC_ERROR (" cuFileWrite failed for {}: expected {}, got {}" , path , length, bytesWrite);
248+ UC_ERROR (" cuFileWrite failed for fd {}: expected {}, got {}" , fd , length, bytesWrite);
242249 return Status::Error ();
243250 }
244251 return Status::OK ();
@@ -305,9 +312,12 @@ private:
305312};
306313
307314std::unique_ptr<IDevice> DeviceFactory::Make (const int32_t deviceId, const size_t bufferSize,
308- const size_t bufferNumber)
315+ const size_t bufferNumber, bool transferUseDirect )
309316{
310317 try {
318+ if (transferUseDirect) {
319+ CudaDevice::InitGdsOnce ();
320+ }
311321 return std::make_unique<CudaDevice>(deviceId, bufferSize, bufferNumber);
312322 } catch (const std::exception& e) {
313323 UC_ERROR (" Failed({}) to make cuda device({},{},{})." , e.what (), deviceId, bufferSize,
0 commit comments