
Commit dee863c

NickLucche authored and DarkLight1337 committed
[Bugfix][NIXL] Fix block_size_ratio when logical != physical blocks (#28925)
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
1 parent 436a594 commit dee863c
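
For context, the ratio named in the title compares KV-cache block sizes when mapping blocks between engines. Below is a minimal, hypothetical sketch of the logical-vs-physical distinction the fix addresses; the numbers and the exact shape of the ratio are assumptions for illustration, not code taken from the connector:

# Hypothetical numbers: the engine is configured with 128-token logical
# blocks, but the attention kernel lays out the KV cache in 64-token
# physical blocks.
logical_block_size = 128
kernel_block_size = 64

# One logical block then spans several physical blocks.
blocks_per_logical = logical_block_size // kernel_block_size  # 2

# Assumed shape of the ratio: if transfer sizes are derived from the
# stale *logical* block size after the cache was registered with the
# *physical* one, block-count calculations are off by that factor.
remote_block_size = 64
correct_ratio = remote_block_size / kernel_block_size  # 1.0
stale_ratio = remote_block_size / logical_block_size   # 0.5, the bug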

File tree

1 file changed: +12 -6 lines changed


vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 12 additions & 6 deletions
@@ -677,12 +677,13 @@ class TpKVTopology:
     mapping between local and remote TP workers.
     """
 
-    tp_size: int
     tp_rank: int
     remote_tp_size: dict[EngineId, int]
     is_mla: bool
     total_num_kv_heads: int
     attn_backend: type[AttentionBackend]
+    engine_id: EngineId
+    remote_block_size: dict[EngineId, int]
 
     def __post_init__(self):
         # Figure out whether the first dimension of the cache is K/V
@@ -710,8 +711,13 @@ def split_k_and_v(self) -> bool:
             self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
         )
 
-    block_size: int
-    remote_block_size: dict[EngineId, int]
+    @property
+    def tp_size(self) -> int:
+        return self.remote_tp_size[self.engine_id]
+
+    @property
+    def block_size(self) -> int:
+        return self.remote_block_size[self.engine_id]
 
     def tp_ratio(
         self,
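
These two properties are the heart of the fix: tp_size and block_size are no longer frozen at construction time but looked up in shared dicts keyed by the engine's own ID, so later updates to the shared state are observed immediately. A self-contained toy sketch of the pattern (a simplified stand-in class, not vLLM's actual TpKVTopology):

from dataclasses import dataclass

EngineId = str  # stand-in for vLLM's engine identifier type

@dataclass
class Topology:
    engine_id: EngineId
    remote_block_size: dict[EngineId, int]  # shared state, mutated elsewhere

    @property
    def block_size(self) -> int:
        # Reads through to the shared dict on every access.
        return self.remote_block_size[self.engine_id]

shared = {"local": 128}
topo = Topology(engine_id="local", remote_block_size=shared)
assert topo.block_size == 128

shared["local"] = 64          # e.g. the kernel block size is discovered later
assert topo.block_size == 64  # the property sees the update; a plain field would not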
@@ -957,13 +963,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.xfer_stats = NixlKVConnectorStats()
 
         self.kv_topo = self.TpKVTopology(
-            tp_size=self.world_size,
             tp_rank=self.tp_rank,
+            engine_id=self.engine_id,
             remote_tp_size=self._tp_size,  # shared state
+            remote_block_size=self._block_size,  # shared state
             is_mla=self.use_mla,
             total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
-            block_size=self.block_size,
-            remote_block_size=self._block_size,
             attn_backend=backend,
         )
         self._use_pallas = self.kv_topo._use_pallas
@@ -1185,6 +1190,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     self.block_size // kernel_block_size
                 )
                 self.block_size = kernel_block_size
+                self._block_size[self.engine_id] = kernel_block_size
 
             seen_base_addresses.append(base_addr)
             curr_tensor_size_bytes = cache.numel() * cache.element_size()
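
Putting the hunks together: register_kv_caches may discover that the kernel's physical block size differs from the configured logical one, and with this commit it writes the physical size back into the shared _block_size dict, which the new block_size property reads. A toy end-to-end sketch of that flow follows; the function names and the ratio formula are illustrative assumptions, not the connector's real API:

shared_block_size: dict[str, int] = {}

def init_connector(engine_id: str, logical_block_size: int) -> None:
    shared_block_size[engine_id] = logical_block_size

def register_kv_caches(engine_id: str, kernel_block_size: int) -> None:
    # The fix: propagate the physical block size into the shared dict
    # so every later reader sees it.
    shared_block_size[engine_id] = kernel_block_size

def block_size_ratio(local_id: str, remote_id: str) -> float:
    # Assumed shape: remote blocks per local block.
    return shared_block_size[remote_id] / shared_block_size[local_id]

init_connector("decode", logical_block_size=128)
init_connector("prefill", logical_block_size=128)
register_kv_caches("decode", kernel_block_size=64)  # logical != physical
print(block_size_ratio("decode", "prefill"))  # 2.0 after the fix; 1.0 before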
