Skip to content

Commit 5047569

Browse files
[Fixed] Skip size calculation during async copy (#1152)
1 parent dbde641 commit 5047569

File tree

2 files changed

+154
-108
lines changed

2 files changed

+154
-108
lines changed

tpu_inference/kernels/ragged_paged_attention/v3/kernel.py

Lines changed: 77 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -440,42 +440,54 @@ def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
440440
debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
441441
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
442442

443-
# Fetch effective kv from kv cache.
444-
def loop_body(i, offset):
445-
sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
446-
_async_copy(
447-
cache_hbm_ref.at[pl.ds(
448-
page_indices_ref[page_indices_offset + i] * page_size,
449-
sz)],
450-
vmem_ref.at[pl.ds(i * page_size, sz)],
451-
sem,
452-
wait,
443+
if not wait:
444+
# Fetch effective kv from kv cache.
445+
def loop_body(i, offset):
446+
sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
447+
_async_copy(
448+
cache_hbm_ref.at[pl.ds(
449+
page_indices_ref[page_indices_offset + i] * page_size,
450+
sz)],
451+
vmem_ref.at[pl.ds(i * page_size, sz)],
452+
sem,
453+
wait=False,
454+
)
455+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
456+
return offset + sz
457+
458+
offset = lax.fori_loop(
459+
0,
460+
bkv_p_frm_cache,
461+
loop_body,
462+
0, # offset
463+
unroll=False,
453464
)
454-
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
455-
return offset + sz
456-
457-
offset = lax.fori_loop(
458-
0,
459-
bkv_p_frm_cache,
460-
loop_body,
461-
0, # offset
462-
unroll=False,
463-
)
464465

465-
# Fetch kv directly from new kv.
466-
@pl.when(bkv_sz_frm_new > 0)
467-
def _fetch_bkv_from_new_kv():
468-
new_kv_len_start = q_end - kv_left_frm_new
469-
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
470-
debug_print("[RPA debug] offset_in_bkv={}", offset)
466+
# Fetch kv directly from new kv.
467+
@pl.when(bkv_sz_frm_new > 0)
468+
def _fetch_bkv_from_new_kv():
469+
new_kv_len_start = q_end - kv_left_frm_new
470+
debug_print("[RPA debug] new_kv_len_start={}",
471+
new_kv_len_start)
472+
debug_print("[RPA debug] offset_in_bkv={}", offset)
473+
_async_copy(
474+
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
475+
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
476+
sem,
477+
wait,
478+
)
479+
480+
return kv_len_start + offset, bkv_sz_frm_new
481+
else:
482+
offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
483+
dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
471484
_async_copy(
472-
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
473-
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
474-
sem,
475-
wait,
485+
src=dst,
486+
dst=dst,
487+
sem=sem,
488+
wait=True,
476489
)
477-
478-
return kv_len_start + offset, bkv_sz_frm_new
490+
return kv_len_start + offset, bkv_sz_frm_new
479491

480492
def _update_kv_cache(seq_idx,
481493
bkv_sem_idx,
@@ -511,30 +523,41 @@ def _update_kv_cache(seq_idx,
511523
debug_print("[RPA debug] p_ignore={}", p_ignore)
512524
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
513525

514-
def loop_body(i, states):
515-
update_sz, ignore = states
516-
sz = jnp.minimum(page_size - ignore, update_sz)
517-
526+
if not wait:
527+
528+
def loop_body(i, states):
529+
update_sz, ignore = states
530+
sz = jnp.minimum(page_size - ignore, update_sz)
531+
532+
_async_copy(
533+
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
534+
sz)],
535+
cache_hbm_ref.at[pl.ds(
536+
page_indices_ref[page_indices_offset + i] * page_size +
537+
ignore,
538+
sz,
539+
)],
540+
sem,
541+
wait=False,
542+
)
543+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
544+
return update_sz - sz, 0
545+
546+
lax.fori_loop(
547+
0,
548+
kv_p_end - kv_p_start,
549+
loop_body,
550+
(update_sz, ignore), # total transfer size
551+
unroll=False,
552+
)
553+
else:
554+
dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
518555
_async_copy(
519-
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
520-
cache_hbm_ref.at[pl.ds(
521-
page_indices_ref[page_indices_offset + i] * page_size +
522-
ignore,
523-
sz,
524-
)],
525-
sem,
526-
wait,
556+
src=dst,
557+
dst=dst,
558+
sem=sem,
559+
wait=True,
527560
)
528-
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
529-
return update_sz - sz, 0
530-
531-
lax.fori_loop(
532-
0,
533-
kv_p_end - kv_p_start,
534-
loop_body,
535-
(update_sz, ignore), # total transfer size
536-
unroll=False,
537-
)
538561

539562
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
540563
sem = sems.at[1, bq_sem_idx]

tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py

Lines changed: 77 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -475,42 +475,54 @@ def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
475475
debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
476476
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
477477

478-
# Fetch effective kv from kv cache.
479-
def loop_body(i, offset):
480-
sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
481-
_async_copy(
482-
cache_hbm_ref.at[pl.ds(
483-
page_indices_ref[page_indices_offset + i] * page_size,
484-
sz)],
485-
vmem_ref.at[pl.ds(i * page_size, sz)],
486-
sem,
487-
wait,
478+
if not wait:
479+
# Fetch effective kv from kv cache.
480+
def loop_body(i, offset):
481+
sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
482+
_async_copy(
483+
cache_hbm_ref.at[pl.ds(
484+
page_indices_ref[page_indices_offset + i] * page_size,
485+
sz)],
486+
vmem_ref.at[pl.ds(i * page_size, sz)],
487+
sem,
488+
wait=False,
489+
)
490+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
491+
return offset + sz
492+
493+
offset = lax.fori_loop(
494+
0,
495+
bkv_p_frm_cache,
496+
loop_body,
497+
0, # offset
498+
unroll=False,
488499
)
489-
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
490-
return offset + sz
491-
492-
offset = lax.fori_loop(
493-
0,
494-
bkv_p_frm_cache,
495-
loop_body,
496-
0, # offset
497-
unroll=False,
498-
)
499500

500-
# Fetch kv directly from new kv.
501-
@pl.when(bkv_sz_frm_new > 0)
502-
def _fetch_bkv_from_new_kv():
503-
new_kv_len_start = q_end - kv_left_frm_new
504-
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
505-
debug_print("[RPA debug] offset_in_bkv={}", offset)
501+
# Fetch kv directly from new kv.
502+
@pl.when(bkv_sz_frm_new > 0)
503+
def _fetch_bkv_from_new_kv():
504+
new_kv_len_start = q_end - kv_left_frm_new
505+
debug_print("[RPA debug] new_kv_len_start={}",
506+
new_kv_len_start)
507+
debug_print("[RPA debug] offset_in_bkv={}", offset)
508+
_async_copy(
509+
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
510+
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
511+
sem,
512+
wait,
513+
)
514+
515+
return kv_len_start + offset, bkv_sz_frm_new
516+
else:
517+
offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
518+
dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
506519
_async_copy(
507-
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
508-
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
509-
sem,
510-
wait,
520+
src=dst,
521+
dst=dst,
522+
sem=sem,
523+
wait=True,
511524
)
512-
513-
return kv_len_start + offset, bkv_sz_frm_new
525+
return kv_len_start + offset, bkv_sz_frm_new
514526

515527
def _update_kv_cache(seq_idx,
516528
bkv_sem_idx,
@@ -546,30 +558,41 @@ def _update_kv_cache(seq_idx,
546558
debug_print("[RPA debug] p_ignore={}", p_ignore)
547559
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
548560

549-
def loop_body(i, states):
550-
update_sz, ignore = states
551-
sz = jnp.minimum(page_size - ignore, update_sz)
552-
561+
if not wait:
562+
563+
def loop_body(i, states):
564+
update_sz, ignore = states
565+
sz = jnp.minimum(page_size - ignore, update_sz)
566+
567+
_async_copy(
568+
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
569+
sz)],
570+
cache_hbm_ref.at[pl.ds(
571+
page_indices_ref[page_indices_offset + i] * page_size +
572+
ignore,
573+
sz,
574+
)],
575+
sem,
576+
wait=False,
577+
)
578+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
579+
return update_sz - sz, 0
580+
581+
lax.fori_loop(
582+
0,
583+
kv_p_end - kv_p_start,
584+
loop_body,
585+
(update_sz, ignore), # total transfer size
586+
unroll=False,
587+
)
588+
else:
589+
dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
553590
_async_copy(
554-
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
555-
cache_hbm_ref.at[pl.ds(
556-
page_indices_ref[page_indices_offset + i] * page_size +
557-
ignore,
558-
sz,
559-
)],
560-
sem,
561-
wait,
591+
src=dst,
592+
dst=dst,
593+
sem=sem,
594+
wait=True,
562595
)
563-
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
564-
return update_sz - sz, 0
565-
566-
lax.fori_loop(
567-
0,
568-
kv_p_end - kv_p_start,
569-
loop_body,
570-
(update_sz, ignore), # total transfer size
571-
unroll=False,
572-
)
573596

574597
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
575598
sem = sems.at[1, bq_sem_idx]

0 commit comments

Comments
 (0)