@@ -715,7 +715,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # are non-contiguous (it's not locally guaranteed that they will be)
         # Disadvantage is that the encoded NixlAgentMetadata is now larger
         # (roughly 8KB vs 5KB).
-        # Conversely for FlashInfer, K and V are transferred in the same tensor
+        # Conversely for FlashInfer, K and V are registered in the same region
         # to better exploit the memory layout (ie num_blocks is the first dim).
         split_k_and_v = not (self.use_mla or self._use_pallas_v1
                              or self._use_flashinfer)
@@ -758,12 +758,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         assert tensor_size_bytes % self.num_blocks == 0
         self.block_len = tensor_size_bytes // self.num_blocks
         self.slot_size_bytes = self.block_len // self.block_size
+        self.device_kv_caches = kv_caches
+        self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
             assert self.slot_size_bytes % 2 == 0
             self.slot_size_bytes /= 2
-        self.device_kv_caches = kv_caches
-        self.dst_num_blocks[self.engine_id] = self.num_blocks
 
+            # NOTE (NickLucche) When FlashInfer is used, memory is registered
+            # with joint KV for each block. This minimizes the overhead in
+            # registerMem, allowing faster descs queries. In order to be able
+            # to split on the kv_heads dim as required by heterogeneous TP,
+            # one must be able to index K/V separately. Hence we double the
+            # number of 'virtual' regions here and halve `block_len` below.
+            self.num_regions *= 2
+
+        kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
         for base_addr in seen_base_addresses:
@@ -776,8 +785,18 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 block_offset = block_id * self.block_len
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                # TODO: does device_id matter to DRAM?
-                blocks_data.append((addr, self.block_len, self.tp_rank))
+                blocks_data.append((addr, kv_block_len, self.tp_rank))
+
+            if self._use_flashinfer:
+                # Separate and interleave K/V regions to maintain the same
+                # descs ordering. This is needed for selecting contiguous heads
+                # when split across TP ranks.
+                for block_id in range(self.num_blocks):
+                    block_offset = block_id * self.block_len
+                    addr = base_addr + block_offset
+                    # Register addresses for V cache (K registered first).
+                    v_addr = addr + kv_block_len
+                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
         logger.debug("Created %s blocks for src engine %s and rank %s",
                      len(blocks_data), self.engine_id, self.tp_rank)
 
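To make the resulting descriptor ordering concrete, here is a minimal standalone sketch (illustrative only, not code from this PR; `build_src_descs` is a hypothetical helper) of how the source descriptors end up laid out once FlashInfer's joint-KV blocks are addressed as K and V halves of `kv_block_len = block_len // 2` bytes each:

```python
# Hypothetical helper, not part of vLLM: mirrors the descriptor ordering built
# above for a FlashInfer-style joint-KV layout (all K halves of a region
# first, then all of its V halves).
def build_src_descs(base_addrs: list[int], num_blocks: int, block_len: int,
                    use_flashinfer: bool,
                    tp_rank: int) -> list[tuple[int, int, int]]:
    kv_block_len = block_len // 2 if use_flashinfer else block_len
    descs: list[tuple[int, int, int]] = []
    for base in base_addrs:
        # K halves (or whole blocks, for backends with separate K/V regions).
        for block_id in range(num_blocks):
            descs.append((base + block_id * block_len, kv_block_len, tp_rank))
        if use_flashinfer:
            # V halves follow all K halves of the same region, so the order
            # stays region-major: [K-region, V-region] per base address.
            for block_id in range(num_blocks):
                descs.append((base + block_id * block_len + kv_block_len,
                              kv_block_len, tp_rank))
    return descs

# Toy usage: one region, two joint-KV blocks of 64 bytes each.
assert build_src_descs([0x1000], 2, 64, True, 0) == [
    (0x1000, 32, 0), (0x1040, 32, 0),  # K halves of blocks 0 and 1
    (0x1020, 32, 0), (0x1060, 32, 0),  # V halves of blocks 0 and 1
]
```

Keeping each region's K halves ahead of its V halves preserves the region-major ordering that `_get_block_descs_ids` assumes (`reg_id * num_blocks + block_id`), which is why `num_regions` is doubled above rather than interleaving K/V per block.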
@@ -903,7 +922,7 @@ def add_remote_agent(self,
             remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes * tp_ratio)
             if self._use_flashinfer:
-                # Account for joint KV in FlashInfer.
+                # With FlashInfer, K and V are sent in the same message.
                 remote_block_size //= 2
             if tp_ratio > 1:
                 # Heterogeneous TP expects same kv_cache_layout.
@@ -929,10 +948,10 @@ def add_remote_agent(self,
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        # Only register the remote's descriptors if current rank pulls from it.
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        rank_offset = self.tp_rank % tp_ratio * self.block_len \
+        kv_block_len = self.get_backend_aware_kv_block_len()
+        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
             if not (self.use_mla or is_kv_replicated) else 0
         # Register all remote blocks, but only the corresponding kv heads.
         for base_addr in nixl_agent_meta.kv_caches_base_addr:
@@ -943,7 +962,16 @@ def add_remote_agent(self,
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, self.block_len, remote_tp_rank))
+                blocks_data.append((addr, kv_block_len, remote_tp_rank))
+
+            if self._use_flashinfer:
+                # With FlashInfer, index V separately to allow head splitting.
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    addr = base_addr + block_offset + rank_offset
+                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
+
         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and "
             "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
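As a rough model of the remote-side arithmetic (a sketch under the assumptions stated in the comments above; `remote_addrs_for_rank` is a hypothetical name, and the real code builds NIXL descriptors rather than returning a list): each decode rank reads a contiguous slice of the prefill worker's kv heads, so its byte offset into every remote block is `(tp_rank % tp_ratio) * kv_block_len`, and with FlashInfer the V slice additionally starts `remote_block_len // 2` bytes into the joint block.

```python
# Hypothetical sketch, not vLLM code: remote addresses one decode rank would
# read from a single remote KV region under heterogeneous TP. Assumes both
# sides use the same kv_cache_layout, as asserted above for tp_ratio > 1, and
# ignores the MLA / replicated-KV case where rank_offset is 0.
def remote_addrs_for_rank(base_addr: int, num_blocks: int,
                          remote_block_len: int, tp_rank: int, tp_ratio: int,
                          use_flashinfer: bool) -> list[tuple[int, int]]:
    # Reader-side block length, per the comment above:
    # self.block_len == remote_block_len // tp_ratio.
    local_block_len = remote_block_len // tp_ratio
    kv_block_len = local_block_len // 2 if use_flashinfer else local_block_len
    rank_offset = (tp_rank % tp_ratio) * kv_block_len
    addrs = []
    for block_id in range(num_blocks):
        # K slice (or the whole per-rank chunk for separate-K/V backends).
        addrs.append((base_addr + block_id * remote_block_len + rank_offset,
                      kv_block_len))
    if use_flashinfer:
        for block_id in range(num_blocks):
            # V slice starts halfway through the remote joint-KV block.
            addrs.append((base_addr + block_id * remote_block_len
                          + remote_block_len // 2 + rank_offset, kv_block_len))
    return addrs

# Example: prefill TP=1, decode TP=2 (tp_ratio=2), joint-KV blocks of 1024 B.
# Decode rank 1 reads bytes [256:512) (its K heads) and [768:1024) (its V
# heads) of every remote block.
```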
@@ -1249,6 +1277,22 @@ def _get_block_descs_ids(self,
                 descs_ids.append(reg_id * num_blocks + block_id)
         return descs_ids
 
+    def get_backend_aware_kv_block_len(self):
+        """
+        Get the block length for one K/V element (K and V have the same size).
+
+        For FA and other backends, this is equal to the length of the whole
+        block, as K and V are in separate regions.
+        For FlashInfer, this is half the length of the whole block, as K and V
+        share the same region.
+        """
+        if self._use_flashinfer:
+            # For indexing only half (either just the K or V part).
+            block_len = self.block_len // 2
+        else:
+            block_len = self.block_len
+        return block_len
+
 
 @contextlib.contextmanager
 def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
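For a sense of scale (illustrative numbers, not taken from this PR): with a block size of 16 tokens, 8 kv heads per rank, head size 128 and fp16, one K or V block is 16 * 8 * 128 * 2 = 32768 bytes. A backend with separate K and V regions registers 32 KiB blocks, and `get_backend_aware_kv_block_len()` simply returns `block_len`; with FlashInfer's joint-KV layout, `block_len` is 64 KiB and the helper returns the 32 KiB half used to address K or V on their own.

```python
# Illustrative numbers only: how the joint-KV layout doubles block_len while
# the per-descriptor length (one K or V half) stays the same.
block_size, num_kv_heads, head_size, dtype_bytes = 16, 8, 128, 2
k_or_v_bytes = block_size * num_kv_heads * head_size * dtype_bytes  # 32768

block_len_separate = k_or_v_bytes          # e.g. FlashAttention-style regions
block_len_joint = 2 * k_or_v_bytes         # FlashInfer joint-KV block: 65536
kv_block_len_joint = block_len_joint // 2  # what the helper returns: 32768

assert block_len_separate == kv_block_len_joint == 32768
```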