Skip to content

Commit c578037

Browse files
authored
refactor: reusable transport abstraction & optimized NSFStore pipeline (#296)
1 parent d9b68aa commit c578037

32 files changed

+693
-623
lines changed

ucm/store/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ add_subdirectory(nfsstore)
99
add_subdirectory(dramstore)
1010
add_subdirectory(localstore)
1111
add_subdirectory(mooncakestore)
12+
add_subdirectory(task)
1213
add_subdirectory(test)

ucm/store/device/ascend/ascend_device.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ class AscendDevice : public IBufferedDevice {
9191
}
9292
return Status::OK();
9393
}
94+
Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) override
95+
{
96+
return ASCEND_API(aclrtMemcpy, dst, count, src, count, ACL_MEMCPY_HOST_TO_DEVICE);
97+
}
98+
Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) override
99+
{
100+
return ASCEND_API(aclrtMemcpy, dst, count, src, count, ACL_MEMCPY_DEVICE_TO_HOST);
101+
}
94102
Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) override
95103
{
96104
return ASCEND_API(aclrtMemcpyAsync, dst, count, src, count, ACL_MEMCPY_HOST_TO_DEVICE,
@@ -111,6 +119,7 @@ class AscendDevice : public IBufferedDevice {
111119
return ASCEND_API(aclrtLaunchCallback, Trampoline, (void*)c, ACL_CALLBACK_NO_BLOCK,
112120
this->stream_);
113121
}
122+
Status Synchronized() override { return ASCEND_API(aclrtSynchronizeStream, this->stream_); }
114123

115124
protected:
116125
std::shared_ptr<std::byte> MakeBuffer(const size_t size) override

ucm/store/device/cuda/cuda_device.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ class CudaDevice : public IBufferedDevice {
7878
}
7979
return status;
8080
}
81+
virtual Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) override
82+
{
83+
return CUDA_API(cudaMemcpy, dst, src, count, cudaMemcpyHostToDevice);
84+
}
85+
virtual Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) override
86+
{
87+
return CUDA_API(cudaMemcpy, dst, src, count, cudaMemcpyDeviceToHost);
88+
}
8189
Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) override
8290
{
8391
return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyHostToDevice,
@@ -100,6 +108,10 @@ class CudaDevice : public IBufferedDevice {
100108
if (status.Failure()) { delete c; }
101109
return status;
102110
}
111+
Status Synchronized() override
112+
{
113+
return CUDA_API(cudaStreamSynchronize, (cudaStream_t)this->stream_);
114+
}
103115

104116
protected:
105117
std::shared_ptr<std::byte> MakeBuffer(const size_t size) override

ucm/store/device/idevice.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,12 @@ class IDevice {
3939
virtual ~IDevice() = default;
4040
virtual Status Setup() = 0;
4141
virtual std::shared_ptr<std::byte> GetBuffer(const size_t size) = 0;
42+
virtual Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) = 0;
43+
virtual Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) = 0;
4244
virtual Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) = 0;
4345
virtual Status D2HAsync(std::byte* dst, const std::byte* src, const size_t count) = 0;
4446
virtual Status AppendCallback(std::function<void(bool)> cb) = 0;
47+
virtual Status Synchronized() = 0;
4548

4649
protected:
4750
virtual std::shared_ptr<std::byte> MakeBuffer(const size_t size) = 0;

ucm/store/device/simu/simu_device.cc

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
* */
2424
#include "ibuffered_device.h"
2525
#include "logger/logger.h"
26+
#include "thread/latch.h"
2627
#include "thread/thread_pool.h"
2728

2829
namespace UC {
@@ -38,7 +39,19 @@ class SimuDevice : public IBufferedDevice {
3839
{
3940
auto status = IBufferedDevice::Setup();
4041
if (status.Failure()) { return status; }
41-
if (!this->backend_.Setup([](auto& task) { task(); })) { return Status::Error(); }
42+
if (!this->backend_.SetWorkerFn([](auto& task, const auto&) { task(); }).Run()) {
43+
return Status::Error();
44+
}
45+
return Status::OK();
46+
}
47+
Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) override
48+
{
49+
std::copy(src, src + count, dst);
50+
return Status::OK();
51+
}
52+
Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) override
53+
{
54+
std::copy(src, src + count, dst);
4255
return Status::OK();
4356
}
4457
Status H2DAsync(std::byte* dst, const std::byte* src, const size_t count) override
@@ -64,6 +77,13 @@ class SimuDevice : public IBufferedDevice {
6477
this->backend_.Push([=] { cb(true); });
6578
return Status::OK();
6679
}
80+
virtual Status Synchronized()
81+
{
82+
Latch waiter{1};
83+
this->backend_.Push([&] { waiter.Done(nullptr); });
84+
waiter.Wait();
85+
return Status::OK();
86+
}
6787

6888
protected:
6989
std::shared_ptr<std::byte> MakeBuffer(const size_t size) override

ucm/store/dramstore/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
file(GLOB_RECURSE UCMSTORE_DRAM_CC_SOURCE_FILES "./cc/*.cc")
22
add_library(dramstore STATIC ${UCMSTORE_DRAM_CC_SOURCE_FILES})
33
target_include_directories(dramstore PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cc)
4-
target_link_libraries(dramstore PUBLIC storeinfra)
4+
target_link_libraries(dramstore PUBLIC storeinfra storetask)
55

66
file(GLOB_RECURSE UCMSTORE_DRAM_CPY_SOURCE_FILES "./cpy/*.cc")
77
pybind11_add_module(ucmdramstore ${UCMSTORE_DRAM_CPY_SOURCE_FILES})

ucm/store/dramstore/cpy/dramstore.py.cc

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,31 +56,30 @@ class DRAMStorePy : public DRAMStore {
5656
size_t Load(const py::list& blockIds, const py::list& offsets, const py::list& addresses,
5757
const py::list& lengths)
5858
{
59-
return this->SubmitPy(blockIds, offsets, addresses, lengths, CCStore::Task::Type::LOAD,
60-
CCStore::Task::Location::DEVICE, "DRAM::H2D");
59+
return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::LOAD,
60+
Task::Location::DEVICE, "DRAM::H2D");
6161
}
6262
size_t Dump(const py::list& blockIds, const py::list& offsets, const py::list& addresses,
6363
const py::list& lengths)
6464
{
65-
return this->SubmitPy(blockIds, offsets, addresses, lengths, CCStore::Task::Type::DUMP,
66-
CCStore::Task::Location::DEVICE, "DRAM::D2H");
65+
return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::DUMP,
66+
Task::Location::DEVICE, "DRAM::D2H");
6767
}
6868

6969
private:
7070
size_t SubmitPy(const py::list& blockIds, const py::list& offsets, const py::list& addresses,
71-
const py::list& lengths, const CCStore::Task::Type type,
72-
const CCStore::Task::Location location, const std::string& brief)
71+
const py::list& lengths, Task::Type&& type, Task::Location&& location,
72+
std::string&& brief)
7373
{
74-
CCStore::Task task{type, location, brief};
74+
Task task{std::move(type), std::move(location), std::move(brief)};
7575
auto blockId = blockIds.begin();
7676
auto offset = offsets.begin();
7777
auto address = addresses.begin();
7878
auto length = lengths.begin();
7979
while ((blockId != blockIds.end()) && (offset != offsets.end()) &&
8080
(address != addresses.end()) && (length != lengths.end())) {
81-
auto ret = task.Append(blockId->cast<std::string>(), offset->cast<size_t>(),
82-
address->cast<uintptr_t>(), length->cast<size_t>());
83-
if (ret != 0) { return CCStore::invalidTaskId; }
81+
task.Append(blockId->cast<std::string>(), offset->cast<size_t>(),
82+
address->cast<uintptr_t>(), length->cast<size_t>());
8483
blockId++;
8584
offset++;
8685
address++;

ucm/store/infra/file/file.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,21 +54,25 @@ Status File::Access(const std::string& path, const int32_t mode)
5454
}
5555

5656
Status File::Read(const std::string& path, const size_t offset, const size_t length,
57-
uintptr_t address)
57+
uintptr_t address, const bool directIo)
5858
{
5959
FileImpl file{path};
6060
Status status = Status::OK();
61-
if ((status = file.Open(IFile::OpenFlag::READ_ONLY)).Failure()) { return status; }
61+
auto flags = directIo ? IFile::OpenFlag::READ_ONLY | IFile::OpenFlag::DIRECT
62+
: IFile::OpenFlag::READ_ONLY;
63+
if ((status = file.Open(flags)).Failure()) { return status; }
6264
if ((status = file.Read((void*)address, length, offset)).Failure()) { return status; }
6365
return status;
6466
}
6567

6668
Status File::Write(const std::string& path, const size_t offset, const size_t length,
67-
const uintptr_t address)
69+
const uintptr_t address, const bool directIo)
6870
{
6971
FileImpl file{path};
7072
Status status = Status::OK();
71-
if ((status = file.Open(IFile::OpenFlag::WRITE_ONLY)).Failure()) { return status; }
73+
auto flags = directIo ? IFile::OpenFlag::WRITE_ONLY | IFile::OpenFlag::DIRECT
74+
: IFile::OpenFlag::WRITE_ONLY;
75+
if ((status = file.Open(flags)).Failure()) { return status; }
7276
if ((status = file.Write((const void*)address, length, offset)).Failure()) { return status; }
7377
return status;
7478
}

ucm/store/infra/file/file.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ class File {
3737
static Status Rename(const std::string& path, const std::string& newName);
3838
static Status Access(const std::string& path, const int32_t mode);
3939
static Status Read(const std::string& path, const size_t offset, const size_t length,
40-
uintptr_t address);
40+
uintptr_t address, const bool directIo = false);
4141
static Status Write(const std::string& path, const size_t offset, const size_t length,
42-
const uintptr_t address);
42+
const uintptr_t address, const bool directIo = false);
4343
static void MUnmap(void* addr, size_t size);
4444
static void ShmUnlink(const std::string& path);
4545
};

ucm/store/infra/thread/latch.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include <atomic>
2828
#include <condition_variable>
29+
#include <functional>
2930
#include <mutex>
3031

3132
namespace UC {
@@ -34,8 +35,22 @@ class Latch {
3435
public:
3536
explicit Latch(const size_t expected = 0) : counter_{expected} {}
3637
void Up() { ++this->counter_; }
37-
size_t Done() { return --this->counter_; }
38-
void Notify() { this->cv_.notify_all(); }
38+
void Done(std::function<void(void)> finish) noexcept
39+
{
40+
auto counter = this->counter_.load(std::memory_order_acquire);
41+
while (counter > 0) {
42+
auto desired = counter - 1;
43+
if (this->counter_.compare_exchange_weak(counter, desired, std::memory_order_acq_rel)) {
44+
if (desired == 0) {
45+
if (finish) { finish(); }
46+
std::lock_guard<std::mutex> lg(this->mutex_);
47+
this->cv_.notify_all();
48+
}
49+
return;
50+
}
51+
counter = this->counter_.load(std::memory_order_acquire);
52+
}
53+
}
3954
void Wait()
4055
{
4156
std::unique_lock<std::mutex> lk(this->mutex_);

0 commit comments

Comments
 (0)