@@ -111,6 +111,7 @@ class PallasMetadata:
111111 context_lens : torch .Tensor
112112 query_start_loc : torch .Tensor
113113 num_seqs : torch .Tensor
114+ num_kv_update_slices : torch .Tensor
114115 num_slices_per_kv_cache_update_block : int
115116
116117
@@ -219,7 +220,8 @@ def forward(
219220 slot_mapping = attn_metadata .slot_mapping
220221 write_to_kv_cache (
221222 key , value , kv_cache , slot_mapping ,
222- attn_metadata .num_slices_per_kv_cache_update_block )
223+ attn_metadata .num_slices_per_kv_cache_update_block ,
224+ attn_metadata .num_kv_update_slices )
223225
224226 output = torch .ops .xla .ragged_paged_attention (
225227 query ,
@@ -252,6 +254,7 @@ def write_to_kv_cache(
252254 kv_cache : torch .Tensor ,
253255 slot_mapping : torch .Tensor ,
254256 num_slices_per_kv_cache_update_block : int ,
257+ num_kv_update_slices : torch .Tensor ,
255258) -> None :
256259 """ Write the key and values to the KV cache.
257260
@@ -271,40 +274,47 @@ def write_to_kv_cache(
271274
272275 kv_cache = kv_cache .flatten (0 , 1 )
273276 new_kv_cache = torch .ops .xla .kv_cache_update_op (
274- kv , slot_mapping , kv_cache , page_size ,
277+ kv , slot_mapping , kv_cache , num_kv_update_slices , page_size ,
275278 num_slices_per_kv_cache_update_block )
276279 # NOTE: the in-place copy will be optimized away by XLA compiler.
277280 kv_cache .copy_ (new_kv_cache )
278281
279282
280283@requires_jax
281284def kv_cache_update_op_impl (kv : torch .Tensor , slot_mapping : torch .Tensor ,
282- kv_cache : torch .Tensor , page_size : int ,
285+ kv_cache : torch .Tensor ,
286+ num_kv_update_slices : torch .Tensor , page_size : int ,
283287 num_slices_per_block : int ):
284288 from vllm .attention .ops .pallas_kv_cache_update import kv_cache_update
285- new_kv_cache = xb .call_jax (kv_cache_update , (kv , slot_mapping , kv_cache ), {
286- "page_size" : page_size ,
287- "num_slices_per_block" : num_slices_per_block
288- })
289+ new_kv_cache = xb .call_jax (
290+ kv_cache_update , (kv , slot_mapping , kv_cache , num_kv_update_slices ), {
291+ "page_size" : page_size ,
292+ "num_slices_per_block" : num_slices_per_block
293+ })
289294 return new_kv_cache
290295
291296
292297XLA_LIB .define (
293- "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
294- "int page_size, int num_slices_per_block) -> Tensor" , )
298+ "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
299+ "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
300+ "-> Tensor" , )
295301
296302
297303@impl (XLA_LIB , "kv_cache_update_op" , "XLA" )
298304def kv_cache_update_op_xla (kv : torch .Tensor , slot_mapping : torch .Tensor ,
299- kv_cache : torch .Tensor , page_size : int ,
305+ kv_cache : torch .Tensor ,
306+ num_kv_update_slices : torch .Tensor , page_size : int ,
300307 num_slices_per_block : int ) -> torch .Tensor :
301308 new_kv_cache = kv_cache_update_op_impl (kv , slot_mapping , kv_cache ,
302- page_size , num_slices_per_block )
309+ num_kv_update_slices , page_size ,
310+ num_slices_per_block )
303311 return new_kv_cache
304312
305313
306314@impl (XLA_LIB , "kv_cache_update_op" , "CompositeExplicitAutograd" )
307315def kv_cache_update_op_non_xla (kv : torch .Tensor , slot_mapping : torch .Tensor ,
308- kv_cache : torch .Tensor , page_size : int ,
316+ kv_cache : torch .Tensor ,
317+ num_kv_update_slices : torch .Tensor ,
318+ page_size : int ,
309319 num_slices_per_block : int ) -> torch .Tensor :
310320 return kv_cache
0 commit comments