@@ -2236,7 +2236,7 @@ def __init__(
22362236 self .comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
22372237 ctx_len = kwargs .pop ("ctx_len" , None )
22382238 prefill_seq_len = kwargs .pop ("prefill_seq_len" , 128 )
2239-
2239+
22402240 if self .comp_ctx_lengths_prefill and prefill_seq_len > 1 :
22412241 self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
22422242 self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode , ctx_len
@@ -2340,7 +2340,7 @@ def from_pretrained(
23402340 comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
23412341 ctx_len = kwargs .pop ("ctx_len" , None )
23422342 prefill_seq_len = kwargs .pop ("prefill_seq_len" , 128 )
2343-
2343+
23442344 if comp_ctx_lengths_prefill and prefill_seq_len > 1 :
23452345 comp_ctx_lengths_prefill , comp_ctx_lengths_decode = process_ccl_specializations (
23462346 comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len
@@ -2649,9 +2649,9 @@ def build_decode_specialization(
26492649 of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching).
26502650 """
26512651 if prefill_seq_len == 1 :
2652- if not self .continuous_batching or batch_size == 1 :
2652+ if not self .continuous_batching or batch_size == 1 :
26532653 return None # Avoid duplication with prefill
2654-
2654+
26552655 spec = {
26562656 "batch_size" : full_batch_size if self .continuous_batching else batch_size ,
26572657 "seq_len" : (num_speculative_tokens + 1 ) if self .is_tlm else 1 ,
0 commit comments