3 changes: 2 additions & 1 deletion ucm/store/nfsstore/cc/api/nfsstore.cc
@@ -45,7 +45,7 @@ class NFSStoreImpl : public NFSStore {
status =
this->transMgr_.Setup(config.transferDeviceId, config.transferStreamNumber,
config.transferIoSize, config.transferBufferNumber,
this->spaceMgr_.GetSpaceLayout(), config.transferTimeoutMs);
this->spaceMgr_.GetSpaceLayout(), config.transferTimeoutMs, config.transferIoDirect);
if (status.Failure()) {
UC_ERROR("Failed({}) to setup TsfTaskManager.", status);
return status.Underlying();
@@ -124,6 +124,7 @@ class NFSStoreImpl : public NFSStore {
UC_INFO("Set UC::storageCapacity to {}.", config.storageCapacity);
UC_INFO("Set UC::RecycleEnable to {}.", config.recycleEnable);
UC_INFO("Set UC::RecycleThreshold to {}.", config.recycleThresholdRatio);
UC_INFO("Set UC::IoDirect to {}.", config.transferIoDirect);
}

private:
4 changes: 3 additions & 1 deletion ucm/store/nfsstore/cc/api/nfsstore.h
@@ -46,14 +46,16 @@ class NFSStore : public CCStore {
size_t storageCapacity;
bool recycleEnable;
float recycleThresholdRatio;
bool transferIoDirect;

Config(const std::vector<std::string>& storageBackends, const size_t kvcacheBlockSize,
const bool transferEnable)
: storageBackends{storageBackends}, kvcacheBlockSize{kvcacheBlockSize},
transferEnable{transferEnable}, transferDeviceId{-1}, transferStreamNumber{32},
transferIoSize{262144}, transferBufferNumber{512}, transferTimeoutMs{30000},
tempDumpDirEnable{false}, hotnessEnable{true}, hotnessInterval{60},
storageCapacity{0}, recycleEnable{true}, recycleThresholdRatio{0.7f}
storageCapacity{0}, recycleEnable{true}, recycleThresholdRatio{0.7f},
transferIoDirect{false}
{
}
};
7 changes: 4 additions & 3 deletions ucm/store/nfsstore/cc/domain/trans/posix_queue.cc
@@ -35,13 +35,14 @@ bool IsAligned(const T value)
}

Status PosixQueue::Setup(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber,
TaskSet* failureSet, const SpaceLayout* layout, const size_t timeoutMs)
TaskSet* failureSet, const SpaceLayout* layout, const size_t timeoutMs, bool useDirect)
{
this->deviceId_ = deviceId;
this->bufferSize_ = bufferSize;
this->bufferNumber_ = bufferNumber;
this->failureSet_ = failureSet;
this->layout_ = layout;
this->useDirect_ = useDirect;
auto success =
this->backend_.SetWorkerInitFn([this](auto& device) { return this->Init(device); })
.SetWorkerFn([this](auto& shard, const auto& device) { this->Work(shard, device); })
@@ -106,7 +107,7 @@ Status PosixQueue::D2S(Task::Shard& shard, const Device& device)
auto status = device->D2HSync((std::byte*)hub, (std::byte*)shard.address, shard.length);
if (status.Failure()) { return status; }
auto path = this->layout_->DataFilePath(shard.block, true);
return File::Write(path, shard.offset, shard.length, (uintptr_t)hub);
return File::Write(path, shard.offset, shard.length, (uintptr_t)hub, useDirect_);
}

Status PosixQueue::S2D(Task::Shard& shard, const Device& device)
@@ -118,7 +119,7 @@ Status PosixQueue::S2D(Task::Shard& shard, const Device& device)
}
auto hub = shard.buffer.get();
auto path = this->layout_->DataFilePath(shard.block, false);
auto status = File::Read(path, shard.offset, shard.length, (uintptr_t)hub);
auto status = File::Read(path, shard.offset, shard.length, (uintptr_t)hub, useDirect_);
if (status.Failure()) { return status; }
return device->H2DAsync((std::byte*)shard.address, (std::byte*)hub, shard.length);
}
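A note for context (not part of the diff): the useDirect flag passed down to File::Write and File::Read selects direct I/O. With O_DIRECT the kernel bypasses the page cache, and on Linux the user buffer address, the file offset, and the transfer length generally have to be multiples of the filesystem's logical block size; unaligned requests typically fail with EINVAL. Below is a minimal, hypothetical Python sketch of that constraint; it is not this repository's File API, and the 4096-byte alignment is an assumption.

# Hypothetical illustration of the alignment rules O_DIRECT imposes (Linux only).
import mmap
import os

ALIGNMENT = 4096  # assumed logical block size of the target filesystem

def direct_write(path, offset, data):
    """Write data at offset with O_DIRECT; offset and length must be block-aligned."""
    assert offset % ALIGNMENT == 0 and len(data) % ALIGNMENT == 0
    buf = mmap.mmap(-1, len(data))  # anonymous mappings are page-aligned
    buf.write(data)
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_DIRECT, 0o644)
    try:
        os.pwrite(fd, buf, offset)
    finally:
        os.close(fd)
        buf.close()

Skipping the page cache avoids double-buffering large sequential KV-block transfers, at the cost of these stricter alignment requirements on the staging buffers the queue writes through.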
3 changes: 2 additions & 1 deletion ucm/store/nfsstore/cc/domain/trans/posix_queue.h
@@ -40,11 +40,12 @@ class PosixQueue : public TaskQueue {
size_t bufferNumber_{0};
TaskSet* failureSet_{nullptr};
const SpaceLayout* layout_{nullptr};
bool useDirect_{false};
ThreadPool<Task::Shard, Device> backend_{};

public:
Status Setup(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber,
TaskSet* failureSet, const SpaceLayout* layout, const size_t timeoutMs);
TaskSet* failureSet, const SpaceLayout* layout, const size_t timeoutMs, bool useDirect = false);
void Push(std::list<Task::Shard>& shards) noexcept override;

private:
4 changes: 2 additions & 2 deletions ucm/store/nfsstore/cc/domain/trans/trans_manager.h
@@ -32,14 +32,14 @@ namespace UC {
class TransManager : public TaskManager {
public:
Status Setup(const int32_t deviceId, const size_t streamNumber, const size_t ioSize,
const size_t bufferNumber, const SpaceLayout* layout, const size_t timeoutMs)
const size_t bufferNumber, const SpaceLayout* layout, const size_t timeoutMs, bool useDirect = false)
{
this->timeoutMs_ = timeoutMs;
auto status = Status::OK();
for (size_t i = 0; i < streamNumber; i++) {
auto q = std::make_shared<PosixQueue>();
status =
q->Setup(deviceId, ioSize, bufferNumber, &this->failureSet_, layout, timeoutMs);
q->Setup(deviceId, ioSize, bufferNumber, &this->failureSet_, layout, timeoutMs, useDirect);
if (status.Failure()) { break; }
this->queues_.emplace_back(std::move(q));
}
1 change: 1 addition & 0 deletions ucm/store/nfsstore/cpy/nfsstore.py.cc
@@ -120,6 +120,7 @@ PYBIND11_MODULE(ucmnfsstore, module)
config.def_readwrite("transferDeviceId", &UC::NFSStorePy::Config::transferDeviceId);
config.def_readwrite("transferStreamNumber", &UC::NFSStorePy::Config::transferStreamNumber);
config.def_readwrite("transferIoSize", &UC::NFSStorePy::Config::transferIoSize);
config.def_readwrite("transferIoDirect", &UC::NFSStorePy::Config::transferIoDirect);
config.def_readwrite("transferBufferNumber", &UC::NFSStorePy::Config::transferBufferNumber);
config.def_readwrite("transferTimeoutMs", &UC::NFSStorePy::Config::transferTimeoutMs);
config.def_readwrite("tempDumpDirEnable", &UC::NFSStorePy::Config::tempDumpDirEnable);
1 change: 1 addition & 0 deletions ucm/store/nfsstore/nfsstore_connector.py
@@ -51,6 +51,7 @@ def __init__(self, config: Dict):
if transfer_enable:
param.transferDeviceId = config["device"]
param.transferIoSize = config["io_size"]
param.transferIoDirect = config.get("transferIoDirect", False)

# NOTE: compatible with legacy nfsstore lib
if hasattr(param, "storageCapacity"):
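For illustration, enabling the new flag from a connector config could look like the sketch below. Only device, io_size, and transferIoDirect appear in the lines above; the other keys mirror the e2e test's setup() further down, the storage path is a placeholder, and a real config likely needs additional keys (block size, transfer enable) not shown in this hunk.

# Hypothetical connector config; "transferIoDirect" defaults to False when omitted.
config = {
    "storage_backends": "/mnt/nfs/kvcache",  # placeholder NFS mount
    "device": 0,
    "io_size": 262144,
    "transferStreamNumber": 32,
    "transferIoDirect": True,
}
store = UcmNfsStore(config)  # constructed the same way in the e2e test below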
208 changes: 129 additions & 79 deletions ucm/store/test/e2e/nfsstore_embed_fetch.py
@@ -23,6 +23,7 @@
# SOFTWARE.
#
import csv
import math
import os
import secrets
import time
@@ -35,7 +36,12 @@


def setup(
storage_backends, block_size, device_id, io_size, transferStreamNumber
storage_backends,
block_size,
device_id,
io_size,
transferStreamNumber,
transferIoDirect,
) -> UcmKVStoreBase:
config = {
"storage_backends": storage_backends,
@@ -44,19 +50,41 @@ def setup(
"device": device_id,
"io_size": io_size,
"transferStreamNumber": transferStreamNumber,
"transferIoDirect": transferIoDirect,
}
return UcmNfsStore(config)


def make_aligned_tensor(shape, dtype, device, alignment=4096):
    # Allocate a padded byte buffer and carve out an alignment-aligned view of it.
    numel = math.prod(shape)
    dtype_size = torch.tensor(1, dtype=dtype).element_size()
    total_bytes = numel * dtype_size

    padded_bytes = total_bytes + alignment
    storage = torch.ByteTensor(padded_bytes).to(device)

    ptr = storage.data_ptr()
    offset = ptr % alignment
    if offset != 0:
        aligned_ptr = ptr + (alignment - offset)
    else:
        aligned_ptr = ptr

    aligned_storage = storage[(aligned_ptr - ptr) :].view(dtype)
    tensor = aligned_storage[:numel].view(shape)
    tensor.storage_ref = storage  # keep the backing allocation alive
    return tensor


def make_buffers(
block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv
):
hashes = [secrets.token_hex(16) for _ in range(block_number)]
kv_caches = {}
for i in range(block_layer):
kv_caches[i] = torch.rand(
kv_caches[i] = make_aligned_tensor(
[kv, block_number, block_len, num_head, head_dim],
dtype=torch.bfloat16,
dtype=torch.float16,
device=f"cuda:{device_id}",
)
return hashes, kv_caches
@@ -69,6 +97,14 @@ def store_all_hashes(hashes: List[str]):
f.write(h + "\n")


def load_hashes_from_file() -> List[str]:
file_path = os.path.join(os.path.dirname(__file__), "kvcache_block_hashes.txt")
if not os.path.exists(file_path):
return []
with open(file_path, "r", encoding="utf-8") as f:
return [line.strip() for line in f.readlines()]


def embed(
store: UcmKVStoreBase,
hashes: List[str],
@@ -177,6 +213,8 @@ def run(
block_elem_size: int,
kv: int,
mla: bool,
transferIoDirect: bool,
operation_mode: str = "both", # "write_only", "read_only", or "both"
) -> Tuple[float, float, float, float, float, float]:
"""
Run a single test with given parameters and return performance metrics.
@@ -196,87 +234,99 @@
w_size_sum, r_size_sum = 0.0, 0.0

store = setup(
storage_backends, block_size, device_id, io_size, transferStreamNumber
storage_backends,
block_size,
device_id,
io_size,
transferStreamNumber,
transferIoDirect,
)

for r in range(repeat):
print(f"\n--- Round {r+1} ---")

hashes, kvcaches = make_buffers(
real_blocks,
device_id,
batch_size,
head_size,
block_len,
block_layer,
num_head,
kv,
)

results = store.create(hashes[:batch_size])
assert sum(results) == 0, "Create operation failed"

w_size, w_time, w_bw = embed(
store,
hashes[:batch_size],
kvcaches,
mla,
)
store.commit(hashes[:batch_size], True)

store_all_hashes(hashes[:batch_size])

r_size, r_time, r_bw = fetch(
store,
hashes[:batch_size],
kvcaches,
mla,
)

w_bw_list.append(w_bw)
r_bw_list.append(r_bw)
w_time_list.append(w_time)
r_time_list.append(r_time)
w_size_sum += w_size
r_size_sum += r_size

# Clean up resources
del kvcaches, hashes
torch.cuda.empty_cache()
if operation_mode in ["write_only", "both"]:
hashes, kvcaches = make_buffers(
real_blocks,
device_id,
batch_size,
head_size,
block_len,
block_layer,
num_head,
kv,
)

results = store.create(hashes[:batch_size])
assert sum(results) == 0, "Create operation failed"

w_size, w_time, w_bw = embed(
store,
hashes[:batch_size],
kvcaches,
mla,
)
store.commit(hashes[:batch_size], True)

if r == 0:
store_all_hashes(hashes[:batch_size])

w_bw_list.append(w_bw)
w_time_list.append(w_time)
w_size_sum += w_size

if operation_mode == "write_only":
del kvcaches, hashes
torch.cuda.empty_cache()

if operation_mode in ["read_only", "both"]:
if operation_mode == "read_only":
saved_hashes = load_hashes_from_file()
if not saved_hashes:
raise RuntimeError("No saved hashes found for read operation")

_, kvcaches = make_buffers(
real_blocks,
device_id,
batch_size,
head_size,
block_len,
block_layer,
num_head,
kv,
)

r_size, r_time, r_bw = fetch(
store,
saved_hashes[:batch_size],
kvcaches,
mla,
)
else:
r_size, r_time, r_bw = fetch(
store,
hashes[:batch_size],
kvcaches,
mla,
)

r_bw_list.append(r_bw)
r_time_list.append(r_time)
r_size_sum += r_size

if operation_mode == "read_only":
del kvcaches
torch.cuda.empty_cache()
else:
del kvcaches, hashes
torch.cuda.empty_cache()

del store
avg_w_bw = sum(w_bw_list) / repeat
avg_r_bw = sum(r_bw_list) / repeat
avg_w_time = sum(w_time_list) / repeat
avg_r_time = sum(r_time_list) / repeat
avg_w_size = w_size_sum / (1024**3) / repeat
avg_r_size = r_size_sum / (1024**3) / repeat
avg_w_bw = sum(w_bw_list) / len(w_bw_list) if w_bw_list else 0.0
avg_r_bw = sum(r_bw_list) / len(r_bw_list) if r_bw_list else 0.0
avg_w_time = sum(w_time_list) / len(w_time_list) if w_time_list else 0.0
avg_r_time = sum(r_time_list) / len(r_time_list) if r_time_list else 0.0
avg_w_size = w_size_sum / (1024**3) / len(w_time_list) if w_time_list else 0.0
avg_r_size = r_size_sum / (1024**3) / len(r_time_list) if r_time_list else 0.0

return avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size


if __name__ == "__main__":
os.environ["UC_LOGGER_LEVEL"] = "debug"

try:
result = run(
storage_backends="/home/nfs/zht_data",
device_id=1,
repeat=1,
num_head=1,
block_len=128,
transferStreamNumber=32,
num_tokens=4096,
block_layer=61,
head_size=576,
block_elem_size=2,
kv=1,
mla=True,
transferIoDirect=False,
)

avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size = result

except Exception as e:
print(f"Error: {e}")
import traceback

traceback.print_exc()
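The run() driver now accepts operation_mode so the embed (write) and fetch (read) phases can be benchmarked in separate invocations, with block hashes persisted by store_all_hashes() and replayed by load_hashes_from_file(). A usage sketch, reusing the parameter values from the __main__ block above; only transferIoDirect and operation_mode are new, and all values are illustrative.

# Sketch: run the dump pass and the fetch pass as two separate invocations.
common = dict(
    storage_backends="/home/nfs/zht_data",
    device_id=1,
    repeat=1,
    num_head=1,
    block_len=128,
    transferStreamNumber=32,
    num_tokens=4096,
    block_layer=61,
    head_size=576,
    block_elem_size=2,
    kv=1,
    mla=True,
    transferIoDirect=True,  # exercise the new direct-I/O path
)

# First pass: write the KV blocks and persist their hashes.
run(operation_mode="write_only", **common)

# Later pass (e.g. after clearing caches): reload the saved hashes and measure reads only.
run(operation_mode="read_only", **common)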