@@ -559,34 +559,41 @@ def forward(
559559 self .kv_cache_dtype ,
560560 k_scale , v_scale )
561561
562- if attn_type != AttentionType .ENCODER :
563- # Decoder self-attention supports chunked prefill.
564- # Encoder/decoder cross-attention requires no chunked
565- # prefill (100% prefill or 100% decode tokens, no mix)
566- num_prefill_tokens = attn_metadata .num_prefill_tokens
567- num_decode_tokens = attn_metadata .num_decode_tokens
568- else :
562+ if attn_type == AttentionType .ENCODER :
569563 # Encoder attention - chunked prefill is not applicable;
570564 # derive token-count from query shape and treat them
571565 # as 100% prefill tokens
572566 assert attn_metadata .num_encoder_tokens is not None
573567 num_prefill_tokens = attn_metadata .num_encoder_tokens
568+ num_encoder_tokens = attn_metadata .num_encoder_tokens
574569 num_decode_tokens = 0
575-
576- if attn_type == AttentionType .DECODER :
570+ elif attn_type == AttentionType .DECODER :
571+ # Decoder self-attention supports chunked prefill.
572+ num_prefill_tokens = attn_metadata .num_prefill_tokens
573+ num_encoder_tokens = attn_metadata .num_prefill_tokens
574+ num_decode_tokens = attn_metadata .num_decode_tokens
577575 # Only enforce this shape-constraint for decoder
578576 # self-attention
579577 assert key .shape [0 ] == num_prefill_tokens + num_decode_tokens
580578 assert value .shape [0 ] == num_prefill_tokens + num_decode_tokens
579+ else : # attn_type == AttentionType.ENCODER_DECODER
580+ # Encoder/decoder cross-attention requires no chunked
581+ # prefill (100% prefill or 100% decode tokens, no mix)
582+ num_prefill_tokens = attn_metadata .num_prefill_tokens
583+ if attn_metadata .num_encoder_tokens is not None :
584+ num_encoder_tokens = attn_metadata .num_encoder_tokens
585+ else :
586+ num_encoder_tokens = attn_metadata .num_prefill_tokens
587+ num_decode_tokens = attn_metadata .num_decode_tokens
581588
582589 output = torch .empty_like (query )
583590 # Query for decode. KV is not needed because it is already cached.
584591 decode_query = query [num_prefill_tokens :]
585592 # QKV for prefill.
586593 query = query [:num_prefill_tokens ]
587594 if key is not None and value is not None :
588- key = key [:num_prefill_tokens ]
589- value = value [:num_prefill_tokens ]
595+ key = key [:num_encoder_tokens ]
596+ value = value [:num_encoder_tokens ]
590597
591598 assert query .shape [0 ] == num_prefill_tokens
592599 assert decode_query .shape [0 ] == num_decode_tokens
0 commit comments