Skip to content

Commit d9b68aa

Browse files
authored
[Misc] add store intf with tensor addr ptr (#288)
* add store intf with tensor addr ptr * fix interface doxy
1 parent 1f0f228 commit d9b68aa

File tree

4 files changed

+107
-1
lines changed

4 files changed

+107
-1
lines changed

ucm/store/dramstore/dramstore_connector.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,28 @@ def dump(
165165
logger.debug(f"dump block {block_ids} finished.")
166166
return task
167167

168+
def fetch_data(
169+
self,
170+
block_ids: List[str],
171+
offset: List[int],
172+
dst_addr: List[int],
173+
size: List[int],
174+
) -> Task:
175+
raise NotImplementedError(
176+
"Method(fetch_data) not yet implemented in this version"
177+
)
178+
179+
def dump_data(
180+
self,
181+
block_ids: List[str],
182+
offset: List[int],
183+
src_addr: List[int],
184+
size: List[int],
185+
) -> Task:
186+
raise NotImplementedError(
187+
"Method(dump_data) not yet implemented in this version"
188+
)
189+
168190
def wait(self, task: DramTask) -> int:
169191
"""
170192
wait kv cache kv transfer task finished.

ucm/store/mooncakestore/mooncake_connector.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,28 @@ async def _dump_impl(
259259
raise TypeError("Mooncake Store Put Type Error.") from err
260260
return 0
261261

262+
def fetch_data(
263+
self,
264+
block_ids: List[str],
265+
offset: List[int],
266+
dst_addr: List[int],
267+
size: List[int],
268+
) -> Task:
269+
raise NotImplementedError(
270+
"Method(fetch_data) not yet implemented in this version"
271+
)
272+
273+
def dump_data(
274+
self,
275+
block_ids: List[str],
276+
offset: List[int],
277+
src_addr: List[int],
278+
size: List[int],
279+
) -> Task:
280+
raise NotImplementedError(
281+
"Method(dump_data) not yet implemented in this version"
282+
)
283+
262284
def wait(self, task: Task) -> int:
263285
"""
264286
wait kv cache kv transfer task finished.

ucm/store/nfsstore/nfsstore_connector.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,26 @@ def dump(
8888
)
8989
return NfsTask(task_id=task_id)
9090

91+
def fetch_data(
92+
self,
93+
block_ids: List[str],
94+
offset: List[int],
95+
dst_addr: List[int],
96+
size: List[int],
97+
) -> Task:
98+
task_id = self.store.LoadToDevice(block_ids, offset, dst_addr, size)
99+
return NfsTask(task_id=task_id)
100+
101+
def dump_data(
102+
self,
103+
block_ids: List[str],
104+
offset: List[int],
105+
src_addr: List[int],
106+
size: List[int],
107+
) -> Task:
108+
task_id = self.store.DumpFromDevice(block_ids, offset, src_addr, size)
109+
return NfsTask(task_id=task_id)
110+
91111
def wait(self, task: Task) -> int:
92112
return self.store.Wait(task.task_id)
93113

ucm/store/ucmstore.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def dump(
113113
self, block_ids: List[str], offset: List[int], src_tensor: List[torch.Tensor]
114114
) -> Task:
115115
"""
116-
dump kv cache to device.
116+
dump kv cache from device.
117117
118118
Args:
119119
block_ids (List[str]): vLLM block hash.
@@ -124,6 +124,48 @@ def dump(
124124
"""
125125
pass
126126

127+
@abstractmethod
128+
def fetch_data(
129+
self,
130+
block_ids: List[str],
131+
offset: List[int],
132+
dst_addr: List[int],
133+
size: List[int],
134+
) -> Task:
135+
"""
136+
load kv cache data to device.
137+
138+
Args:
139+
block_ids (List[str]): vLLM block hash.
140+
offset(List[int]): tp > 1 scene
141+
dst_addr: List[int]: device tensor addr ptr.
142+
size: List[int]: device tensor size.
143+
Returns:
144+
task(Task).
145+
"""
146+
pass
147+
148+
@abstractmethod
149+
def dump_data(
150+
self,
151+
block_ids: List[str],
152+
offset: List[int],
153+
src_addr: List[int],
154+
size: List[int],
155+
) -> Task:
156+
"""
157+
dump kv cache data from device.
158+
159+
Args:
160+
block_ids (List[str]): vLLM block hash.
161+
offset(List[int]): tp > 1 scene
162+
src_addr: List[int]: device tensor addr ptr.
163+
size: List[int]: device tensor size.
164+
Returns:
165+
task(Task).
166+
"""
167+
pass
168+
127169
@abstractmethod
128170
def wait(self, task: Task) -> int:
129171
"""

0 commit comments

Comments
 (0)