@@ -677,12 +677,13 @@ class TpKVTopology:
     mapping between local and remote TP workers.
     """
 
-    tp_size: int
     tp_rank: int
     remote_tp_size: dict[EngineId, int]
     is_mla: bool
     total_num_kv_heads: int
     attn_backend: type[AttentionBackend]
+    engine_id: EngineId
+    remote_block_size: dict[EngineId, int]
 
     def __post_init__(self):
         # Figure out whether the first dimension of the cache is K/V
@@ -710,8 +711,13 @@ def split_k_and_v(self) -> bool:
             self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
         )
 
-    block_size: int
-    remote_block_size: dict[EngineId, int]
+    @property
+    def tp_size(self) -> int:
+        return self.remote_tp_size[self.engine_id]
+
+    @property
+    def block_size(self) -> int:
+        return self.remote_block_size[self.engine_id]
 
     def tp_ratio(
         self,
@@ -957,13 +963,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.xfer_stats = NixlKVConnectorStats()
 
         self.kv_topo = self.TpKVTopology(
-            tp_size=self.world_size,
             tp_rank=self.tp_rank,
+            engine_id=self.engine_id,
             remote_tp_size=self._tp_size,  # shared state
+            remote_block_size=self._block_size,  # shared state
             is_mla=self.use_mla,
             total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
-            block_size=self.block_size,
-            remote_block_size=self._block_size,
             attn_backend=backend,
         )
         self._use_pallas = self.kv_topo._use_pallas
@@ -1185,6 +1190,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     self.block_size // kernel_block_size
                 )
                 self.block_size = kernel_block_size
+                self._block_size[self.engine_id] = kernel_block_size
 
                 seen_base_addresses.append(base_addr)
                 curr_tensor_size_bytes = cache.numel() * cache.element_size()
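
The diff replaces the stored tp_size/block_size fields on TpKVTopology with properties that look up the local engine's entry in the shared remote_tp_size/remote_block_size dicts, so the topology always reflects later updates (e.g. when register_kv_caches rewrites self._block_size[self.engine_id] with the kernel block size). Below is a minimal standalone sketch of that pattern, not the vLLM code itself; the names Topology and "local" are illustrative only.

    # Sketch of the refactor pattern: derive local values from shared
    # per-engine dicts via properties instead of copying them at init time.
    from dataclasses import dataclass

    EngineId = str

    @dataclass
    class Topology:
        tp_rank: int
        engine_id: EngineId
        remote_tp_size: dict[EngineId, int]      # shared state, mutated elsewhere
        remote_block_size: dict[EngineId, int]   # shared state, mutated elsewhere

        @property
        def tp_size(self) -> int:
            # the local TP size is just this engine's entry in the shared mapping
            return self.remote_tp_size[self.engine_id]

        @property
        def block_size(self) -> int:
            return self.remote_block_size[self.engine_id]

    tp_sizes = {"local": 4}
    block_sizes = {"local": 16}
    topo = Topology(tp_rank=0, engine_id="local",
                    remote_tp_size=tp_sizes, remote_block_size=block_sizes)
    assert topo.block_size == 16
    block_sizes["local"] = 32   # e.g. a kernel block size discovered later
    assert topo.block_size == 32  # the property sees the updated shared state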