@@ -109,6 +109,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
109109 self ._need_load_reqs : dict [str , Union [list [int ], list [Task ]]] = {}
110110 self ._load_failed_reqs : set [str ] = set ()
111111 self ._load_req_to_blocks : dict [str , set [int ]] = {}
112+ self .num_head = vllm_config .model_config .get_num_kv_heads (
113+ vllm_config .parallel_config
114+ )
115+ self .head_size = vllm_config .model_config .get_head_size ()
112116 if role == KVConnectorRole .WORKER :
113117 self ._initialize_dataoffset (vllm_config )
114118 if (
@@ -131,6 +135,20 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
131135 config ["role" ] = (
132136 "scheduler" if role == KVConnectorRole .SCHEDULER else "worker"
133137 )
138+ config_base = self .block_size * self .element_size * self .head_size
139+ config ["kv_block_size" ] = (
140+ config_base
141+ * self .num_layers
142+ * (1 if self .is_mla else self .num_head * self .total_tp_size * 2 )
143+ )
144+ config ["transferIoSize" ] = config_base * (
145+ 1 if self .is_mla else self .num_head
146+ )
147+ logger .info (
148+ "kv_block_size = %d, transferIoSize = %d," ,
149+ config ["kv_block_size" ],
150+ config ["transferIoSize" ],
151+ )
134152 logger .info ("init UCConnectorImpl, connector: %s" , name )
135153 self .connector = UcmConnectorFactory .create_connector (name , config )
136154 else :
0 commit comments