@@ -2235,8 +2235,9 @@ def __init__(
22352235 self .comp_ctx_lengths_prefill = kwargs .pop ("comp_ctx_lengths_prefill" , None )
22362236 self .comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
22372237 ctx_len = kwargs .pop ("ctx_len" , None )
2238-
2239- if self .comp_ctx_lengths_prefill :
2238+ prefill_seq_len = kwargs .pop ("prefill_seq_len" , 128 )
2239+
2240+ if self .comp_ctx_lengths_prefill and prefill_seq_len > 1 :
22402241 self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
22412242 self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode , ctx_len
22422243 )
@@ -2338,7 +2339,9 @@ def from_pretrained(
23382339 comp_ctx_lengths_prefill = kwargs .pop ("comp_ctx_lengths_prefill" , None )
23392340 comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
23402341 ctx_len = kwargs .pop ("ctx_len" , None )
2341- if comp_ctx_lengths_prefill :
2342+ prefill_seq_len = kwargs .pop ("prefill_seq_len" , 128 )
2343+
2344+ if comp_ctx_lengths_prefill and prefill_seq_len > 1 :
23422345 comp_ctx_lengths_prefill , comp_ctx_lengths_decode = process_ccl_specializations (
23432346 comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len
23442347 )
@@ -2356,6 +2359,7 @@ def from_pretrained(
23562359 comp_ctx_lengths_prefill = comp_ctx_lengths_prefill ,
23572360 comp_ctx_lengths_decode = comp_ctx_lengths_decode ,
23582361 ctx_len = ctx_len ,
2362+ prefill_seq_len = prefill_seq_len ,
23592363 kv_offload = kv_offload ,
23602364 pretrained_model_name_or_path = pretrained_model_name_or_path ,
23612365 ** kwargs ,
@@ -2368,6 +2372,7 @@ def from_pretrained(
23682372 comp_ctx_lengths_prefill = comp_ctx_lengths_prefill ,
23692373 comp_ctx_lengths_decode = comp_ctx_lengths_decode ,
23702374 ctx_len = ctx_len ,
2375+ prefill_seq_len = prefill_seq_len ,
23712376 ** kwargs ,
23722377 )
23732378
@@ -2643,7 +2648,7 @@ def build_decode_specialization(
26432648 A dictionary defining the decode specialization, or None if it would be a duplicate
26442649 of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching).
26452650 """
2646- if prefill_seq_len == 1 and not self .continuous_batching and comp_ctx_lengths is None :
2651+ if prefill_seq_len == 1 and not self .continuous_batching : # and comp_ctx_lengths is None
26472652 return None # Avoid duplication with prefill
26482653 spec = {
26492654 "batch_size" : full_batch_size if self .continuous_batching else batch_size ,
0 commit comments