@@ -113,8 +113,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             vllm_config.parallel_config
         )
         self.head_size = vllm_config.model_config.get_head_size()
-        if role == KVConnectorRole.WORKER:
-            self._initialize_dataoffset(vllm_config)
         if (
             self._vllm_config.kv_transfer_config is not None
             and "ucm_connector_name"
@@ -176,35 +174,37 @@ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"
             forward_context.virtual_engine
         ]

-    def _initialize_dataoffset(self, vllm_config: "VllmConfig"):
-        num_kv_heads = vllm_config.model_config.get_num_kv_heads(
-            vllm_config.parallel_config
-        )
-        head_size = vllm_config.model_config.get_head_size()
-        self.min_block_size = (
-            self.block_size * num_kv_heads * head_size * self.element_size
+    def DataOffset(self, kv_layer, rank, layer_id, is_v):
+        # Non-MLA case: one layer's shape is (2, num_blocks, block_size, num_kv_heads, head_size)
+        # MLA case: one layer's shape is (num_blocks, block_size, head_size)
+        # Element size in bytes
+        elem_size = kv_layer[0].element_size()
+        logger.debug(
+            f"total_tp_size = {self.total_tp_size}, \n" f"element size = {elem_size}."
         )
+        # Byte size of one K block and one V block
+        k_min_data_block_size = (
+            kv_layer[0][0].numel() if not self.is_mla else kv_layer[0].numel()
+        ) * elem_size
+        v_min_data_block_size = (
+            kv_layer[1][0].numel() if not self.is_mla else 0
+        ) * elem_size
+        # When tp > 1, layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size
         layer_size = (
-            self.min_block_size * 2 * self.total_tp_size
-            if not self.is_mla
-            else self.min_block_size
-        )
-        # layer_id -> rank -> k_offset
-        self.k_data_offsets: dict[int, dict[int, int]] = {}
-
-        pp_size = vllm_config.parallel_config.pipeline_parallel_size
-        for layer_id in range(self.num_layers * pp_size):
-            self.k_data_offsets[layer_id] = {}
-            for rank in range(self.total_tp_size):
-                if self.is_mla:
-                    self.k_data_offsets[layer_id][0] = layer_size * layer_id
-                    break
-                else:
-                    offset = (
-                        layer_size * layer_id
-                        + (layer_size // self.total_tp_size) * rank
-                    )
-                    self.k_data_offsets[layer_id][rank] = offset
+            k_min_data_block_size + v_min_data_block_size
+        ) * self.total_tp_size
+        if is_v:
+            # Offset of V = offset of K + k_min_data_block_size
+            return int(
+                self.DataOffset(kv_layer, rank, layer_id, False) + k_min_data_block_size
+            )
+        if self.is_mla:
+            return int(layer_size * layer_id)
+        else:
+            # Offset of K = layer_size * layer_id + layer_size / tp_size * current rank
+            return int(
+                layer_size * layer_id + layer_size / self.total_tp_size * self.rank
+            )

     def get_tensor_and_offset_layerwise(
         self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str
@@ -216,17 +216,14 @@ def get_tensor_and_offset_layerwise(
         layer_id = self._extract_layer_index(layer_name)

         for blk_id in vllm_block_ids:
+            k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
             if self.is_mla:
-                k_data_offset = self.k_data_offsets[layer_id][0]
                 k_tensors.append(kv_layer[blk_id])
             else:
-                k_data_offset = self.k_data_offsets[layer_id][self.rank]
                 k_tensors.append(kv_layer[0][blk_id])
             k_offsets.append(k_data_offset)
             if not self.is_mla:
-                v_data_offset = (
-                    self.k_data_offsets[layer_id][self.rank] + self.min_block_size
-                )
+                v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
                 v_tensors.append(kv_layer[1][blk_id])
                 v_offsets.append(v_data_offset)
         return k_tensors + v_tensors, k_offsets + v_offsets
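As a sanity check on how `get_tensor_and_offset_layerwise` now pairs block tensors with offsets, a toy non-MLA example follows; the tensor shape and dimensions are assumed for illustration and do not come from the connector itself.

```python
import torch

# Assumed toy dimensions; in the connector the layer comes from the forward context.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
kv_layer = torch.zeros(
    2, num_blocks, block_size, num_kv_heads, head_size, dtype=torch.float16
)

vllm_block_ids = [0, 2]
elem_size = kv_layer[0].element_size()
k_block = kv_layer[0][0].numel() * elem_size      # 32768 bytes

# For layer 0, rank 0: every K block shares offset 0, every V block offset k_block.
k_off, v_off = 0, k_block
tensors = [kv_layer[0][b] for b in vllm_block_ids] + [kv_layer[1][b] for b in vllm_block_ids]
offsets = [k_off] * len(vllm_block_ids) + [v_off] * len(vllm_block_ids)
assert len(tensors) == len(offsets) == 2 * len(vllm_block_ids)
```

Note that the offsets depend only on the layer id and rank, not on the block id, so all blocks of a given layer share the same offset on a given rank, matching the `k_tensors + v_tensors, k_offsets + v_offsets` return convention.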