
Commit 06442f0

mag1c-hyou-seesee-you authored and committed
[bugfix] preserve DRAM buffer lifetime to restore inference accuracy (#322)
* linear buffer for device
* check data consistency after embedding
1 parent dc454e0 commit 06442f0
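
Reading the diff below against the commit title, the likely failure mode is a buffer-lifetime hazard in the DRAM staging path: the old GetBuffer handed out a slot from a preallocated arena and the shared_ptr's deleter returned the slot to an IndexPool as soon as the caller dropped the pointer, so a slot could be re-issued and overwritten while a device transfer staged through it was still in flight. The following is a minimal, self-contained sketch of that hazard only; SlotPool, the deferred "device queue", and every other name here are illustrative assumptions, not UCM APIs.

// Minimal sketch of the suspected failure mode (illustrative only, not UCM code).
// The deleter recycles a staging slot as soon as its shared_ptr dies, so the slot
// can be handed out again and overwritten while an earlier queued transfer still
// reads from it.
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <vector>

struct SlotPool {
    std::vector<std::byte> arena;
    std::vector<std::size_t> freeList;
    std::size_t slotSize;

    SlotPool(std::size_t size, std::size_t count) : arena(size * count), slotSize(size)
    {
        for (std::size_t i = 0; i < count; ++i) { freeList.push_back(i); }
    }
    std::shared_ptr<std::byte> Acquire()
    {
        auto idx = freeList.back();
        freeList.pop_back();
        // Old pattern: destroying the shared_ptr immediately returns the slot to the pool.
        return std::shared_ptr<std::byte>(arena.data() + slotSize * idx,
                                          [this, idx](auto) { freeList.push_back(idx); });
    }
};

int main()
{
    SlotPool pool(8, 1);                      // a single 8-byte staging slot
    std::vector<std::function<void()>> queue; // stands in for an asynchronous device queue

    {
        auto slot = pool.Acquire();
        std::memcpy(slot.get(), "first", 6);
        queue.emplace_back([p = slot.get()] {
            std::printf("consumed: %s\n", reinterpret_cast<const char*>(p));
        });
    } // slot released here: the index goes back to the pool while the copy is still queued

    {
        auto slot = pool.Acquire();           // the same slot is handed out again
        std::memcpy(slot.get(), "second", 7); // overwrites the payload of the pending transfer
        queue.emplace_back([p = slot.get()] {
            std::printf("consumed: %s\n", reinterpret_cast<const char*>(p));
        });
    }

    for (auto& job : queue) { job(); }        // both jobs now print "second"
    return 0;
}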

File tree

2 files changed: +71 -15 lines


ucm/store/device/ibuffered_device.h

Lines changed: 35 additions & 15 deletions
@@ -25,11 +25,37 @@
 #define UNIFIEDCACHE_IBUFFERED_DEVICE_H

 #include "idevice.h"
-#include "thread/index_pool.h"

 namespace UC {

 class IBufferedDevice : public IDevice {
+    class LinearBuffer {
+        std::shared_ptr<std::byte> addr_{nullptr};
+        size_t index_{0};
+        size_t number_{0};
+        size_t size_{0};
+
+    public:
+        void Setup(std::shared_ptr<std::byte> addr, const size_t number, const size_t size)
+        {
+            this->addr_ = addr;
+            this->number_ = number;
+            this->size_ = size;
+            this->Reset();
+        }
+        void Reset() noexcept { this->index_ = 0; }
+        bool Full() const noexcept { return this->index_ == this->number_; }
+        bool Available(const size_t size) const noexcept { return this->size_ >= size; }
+        std::shared_ptr<std::byte> Get() noexcept
+        {
+            auto addr = this->addr_.get();
+            auto buffer = addr + this->size_ * this->index_;
+            ++this->index_;
+            return std::shared_ptr<std::byte>(buffer, [](auto) {});
+        }
+    };
+    LinearBuffer buffer_;
+
 public:
     IBufferedDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
         : IDevice{deviceId, bufferSize, bufferNumber}
@@ -39,26 +65,20 @@ class IBufferedDevice : public IDevice {
     {
         auto totalSize = this->bufferSize * this->bufferNumber;
         if (totalSize == 0) { return Status::OK(); }
-        this->_addr = this->MakeBuffer(totalSize);
-        if (!this->_addr) { return Status::OutOfMemory(); }
-        this->_indexPool.Setup(this->bufferNumber);
+        auto addr = this->MakeBuffer(totalSize);
+        if (!addr) { return Status::OutOfMemory(); }
+        this->buffer_.Setup(addr, this->bufferNumber, this->bufferSize);
         return Status::OK();
     }
     virtual std::shared_ptr<std::byte> GetBuffer(const size_t size) override
     {
-        if (!this->_addr || size > this->bufferSize) { return this->MakeBuffer(size); }
-        auto idx = this->_indexPool.Acquire();
-        if (idx != IndexPool::npos) {
-            auto ptr = this->_addr.get() + this->bufferSize * idx;
-            return std::shared_ptr<std::byte>(ptr,
-                                              [this, idx](auto) { this->_indexPool.Release(idx); });
+        if (this->buffer_.Full()) {
+            auto status = this->Synchronized();
+            if (status.Failure()) { return nullptr; }
+            this->buffer_.Reset();
         }
-        return this->MakeBuffer(size);
+        return this->buffer_.Available(size) ? this->buffer_.Get() : this->MakeBuffer(size);
     }
-
-private:
-    std::shared_ptr<std::byte> _addr{nullptr};
-    IndexPool _indexPool;
 };

 } // namespace UC
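
With this change, dropping a buffer's shared_ptr never recycles DRAM (the deleter is a no-op); slots are only reused after GetBuffer finds the arena Full(), calls Synchronized() to drain outstanding work, and Reset()s the cursor, while oversized requests still fall back to MakeBuffer. Below is a standalone mirror of that policy, assuming hypothetical names (BufferedArena, the synchronize callback) and a plain heap allocation in place of MakeBuffer; it is a sketch of the idea, not the repository class.

// Standalone mirror of the new buffer-handout policy (simplified, not the UCM class).
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>

class BufferedArena {
    std::unique_ptr<std::byte[]> arena_;
    std::size_t slotSize_{0};
    std::size_t slotCount_{0};
    std::size_t cursor_{0};

public:
    BufferedArena(std::size_t slotSize, std::size_t slotCount)
        : arena_{std::make_unique<std::byte[]>(slotSize * slotCount)},
          slotSize_{slotSize}, slotCount_{slotCount} {}

    std::shared_ptr<std::byte> Get(std::size_t size, const std::function<bool()>& synchronize)
    {
        if (cursor_ == slotCount_) {            // Full(): every slot may still be in flight
            if (!synchronize()) { return nullptr; } // drain the device queue first
            cursor_ = 0;                        // Reset(): only now may slots be reused
        }
        if (size > slotSize_) {                 // oversized request: bypass the arena
            return std::shared_ptr<std::byte>(new std::byte[size],
                                              [](std::byte* p) { delete[] p; });
        }
        auto* slot = arena_.get() + slotSize_ * cursor_++;
        return std::shared_ptr<std::byte>(slot, [](auto) {}); // no-op deleter keeps the slot alive
    }
};

int main()
{
    std::size_t syncs = 0;
    BufferedArena arena(256, 2);
    auto sync = [&] { ++syncs; return true; };

    auto a = arena.Get(128, sync);
    auto b = arena.Get(128, sync);
    assert(a && b && syncs == 0);

    a.reset();                      // dropping the pointer does NOT free the slot
    auto c = arena.Get(128, sync);  // arena full -> synchronize, then rewind and reuse slot 0
    assert(syncs == 1 && c != nullptr);

    auto big = arena.Get(4096, sync); // larger than a slot -> heap fallback
    assert(big != nullptr);
    return 0;
}

The property under test is that releasing a slot pointer alone does not make its memory reusable; only the synchronize-then-rewind step does, which is what preserves data for transfers that are still in flight.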

ucm/store/test/e2e/nfsstore_embed.py

Lines changed: 36 additions & 0 deletions
@@ -80,6 +80,39 @@ def embed(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Ten
     store.commit(hashes, True)


+def fetch(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]):
+    founds = store.lookup(hashes)
+    for found in founds:
+        assert found
+    block_ids = []
+    offsets = []
+    layers = []
+    for hash_id, block in zip(hashes, tensors):
+        offset = 0
+        for layer in block:
+            block_ids.append(hash_id)
+            offsets.append(offset)
+            layers.append(layer)
+            offset += layer.untyped_storage().size()
+    task = store.load(block_ids, offsets, layers)
+    assert task.task_id > 0
+    ret = store.wait(task)
+    assert ret == 0
+
+
+def cmp_and_print_diff(a, b, rtol=0.0, atol=0.0):
+    for r, (row_a, row_b) in enumerate(zip(a, b)):
+        for c, (ta, tb) in enumerate(zip(row_a, row_b)):
+            if not torch.allclose(ta, tb, rtol=rtol, atol=atol):
+                mask = ~torch.isclose(ta, tb, rtol=rtol, atol=atol)
+                diff_a = ta[mask].cpu()
+                diff_b = tb[mask].cpu()
+                print(f"DIFF at [{r}][{c}] total {mask.sum().item()} element(s)")
+                print(" a val:", diff_a.flatten())
+                print(" b val:", diff_b.flatten())
+                assert False
+
+
 def store_all_hashes(hashes):
     kvcache_block_hashes_file = "kvcache_block_hashes.txt"
     current_directory = os.path.dirname(__file__)
@@ -108,7 +141,10 @@ def main():
     for batch in range(total_batches):
         start = batch_size * batch
         end = min(start + batch_size, block_number)
+        tensors2 = [[torch.empty_like(t) for t in row] for row in tensors]
         embed(store, hashes[start:end], tensors)
+        fetch(store, hashes[start:end], tensors2)
+        cmp_and_print_diff(tensors, tensors2)
     store_all_hashes(hashes)
