Skip to content

Commit 2714d15

Browse files
authored
roll back dataoffset (#201)
1 parent 264f2cc commit 2714d15

File tree

2 files changed

+31
-41
lines changed

2 files changed

+31
-41
lines changed

test/test_uc_connector.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,12 @@ def setUp(self):
8383
self.total_blocks_num = 40
8484
self.total_tp_size = 2
8585
self.kv_caches = {}
86-
self.k_data_offsets = {}
8786
for i in range(self.num_layers):
8887
layer_name = f"model.layers.{i}.self_attn.attn"
8988
kv_tensor = torch.rand(
9089
(2, self.total_blocks_num, self.block_size, 4, 8), dtype=torch.bfloat16
9190
)
9291
self.kv_caches[layer_name] = kv_tensor
93-
for layer_id in range(self.num_layers):
94-
self.k_data_offsets[layer_id] = {}
95-
for i in range(self.total_tp_size):
96-
self.k_data_offsets[layer_id][i] = 0
9792

9893
def init_uc(
9994
self, mock_connector, metadata=Mock(), use_layerwise=True
@@ -116,8 +111,6 @@ def init_uc(
116111
ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
117112
ucconnector._load_failed_reqs: set[str] = set()
118113
ucconnector._load_req_to_blocks: dict[str, set[int]] = {}
119-
ucconnector.k_data_offsets = self.k_data_offsets
120-
ucconnector.min_block_size = 0
121114
return ucconnector
122115

123116
def test_get_num_new_matched_tokens_hit_all_on_storage(self):

ucm/integration/vllm/uc_connector.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
113113
vllm_config.parallel_config
114114
)
115115
self.head_size = vllm_config.model_config.get_head_size()
116-
if role == KVConnectorRole.WORKER:
117-
self._initialize_dataoffset(vllm_config)
118116
if (
119117
self._vllm_config.kv_transfer_config is not None
120118
and "ucm_connector_name"
@@ -176,35 +174,37 @@ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"
176174
forward_context.virtual_engine
177175
]
178176

179-
def _initialize_dataoffset(self, vllm_config: "VllmConfig"):
180-
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
181-
vllm_config.parallel_config
182-
)
183-
head_size = vllm_config.model_config.get_head_size()
184-
self.min_block_size = (
185-
self.block_size * num_kv_heads * head_size * self.element_size
177+
def DataOffset(self, kv_layer, rank, layer_id, is_v):
178+
# Non-MLA scene: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
179+
# MLA scene: one layer shape is (num_blocks, block_size, head_size)
180+
# Element size
181+
elem_size = kv_layer[0].element_size()
182+
logger.debug(
183+
f"total_tp_size = {self.total_tp_size},\n" f"element size = {elem_size}."
186184
)
185+
# One block size
186+
k_min_data_block_size = (
187+
kv_layer[0][0].numel() if not self.is_mla else kv_layer[0].numel()
188+
) * elem_size
189+
v_min_data_block_size = (
190+
kv_layer[1][0].numel() if not self.is_mla else 0
191+
) * elem_size
192+
# When tp > 1 layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size
187193
layer_size = (
188-
self.min_block_size * 2 * self.total_tp_size
189-
if not self.is_mla
190-
else self.min_block_size
191-
)
192-
# layer_id -> rank -> k_offset
193-
self.k_data_offsets: dict[int, dict[int, int]] = {}
194-
195-
pp_size = vllm_config.parallel_config.pipeline_parallel_size
196-
for layer_id in range(self.num_layers * pp_size):
197-
self.k_data_offsets[layer_id] = {}
198-
for rank in range(self.total_tp_size):
199-
if self.is_mla:
200-
self.k_data_offsets[layer_id][0] = layer_size * layer_id
201-
break
202-
else:
203-
offset = (
204-
layer_size * layer_id
205-
+ (layer_size // self.total_tp_size) * rank
206-
)
207-
self.k_data_offsets[layer_id][rank] = offset
194+
k_min_data_block_size + v_min_data_block_size
195+
) * self.total_tp_size
196+
if is_v:
197+
# Offset of v = Offset of k + k_min_data_block_size
198+
return int(
199+
self.DataOffset(kv_layer, rank, layer_id, False) + k_min_data_block_size
200+
)
201+
if self.is_mla:
202+
return int(layer_size * layer_id)
203+
else:
204+
# Offset of k = layer_size * layer_id + layer_size / tp_size * current rank
205+
return int(
206+
layer_size * layer_id + layer_size / self.total_tp_size * self.rank
207+
)
208208

209209
def get_tensor_and_offset_layerwise(
210210
self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str
@@ -216,17 +216,14 @@ def get_tensor_and_offset_layerwise(
216216
layer_id = self._extract_layer_index(layer_name)
217217

218218
for blk_id in vllm_block_ids:
219+
k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
219220
if self.is_mla:
220-
k_data_offset = self.k_data_offsets[layer_id][0]
221221
k_tensors.append(kv_layer[blk_id])
222222
else:
223-
k_data_offset = self.k_data_offsets[layer_id][self.rank]
224223
k_tensors.append(kv_layer[0][blk_id])
225224
k_offsets.append(k_data_offset)
226225
if not self.is_mla:
227-
v_data_offset = (
228-
self.k_data_offsets[layer_id][self.rank] + self.min_block_size
229-
)
226+
v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
230227
v_tensors.append(kv_layer[1][blk_id])
231228
v_offsets.append(v_data_offset)
232229
return k_tensors + v_tensors, k_offsets + v_offsets

0 commit comments

Comments
 (0)