@@ -879,13 +879,7 @@ def __init__(
879879 self .model = model
880880 self .config = model .config
881881
882- self .comp_ctx_lengths_prefill = kwargs .pop ("comp_ctx_lengths_prefill" , None )
883- self .comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
884- ctx_len = kwargs .pop ("ctx_len" , None )
885- if self .comp_ctx_lengths_prefill :
886- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
887- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode , ctx_len
888- )
882+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode , _ , _ = process_ccl_specializations (kwargs )
889883
890884 self .vision_model = QEffVisionEncoderForTextImageToTextModel (model , ** kwargs )
891885 self .lang_model = QEffCausalLMForTextImageToTextModel (model , ** kwargs )
@@ -933,14 +927,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
933927
934928 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
935929
936- comp_ctx_lengths_prefill = kwargs .pop ("comp_ctx_lengths_prefill" , None )
937- comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
938- ctx_len = kwargs .pop ("ctx_len" , None )
939-
940- if comp_ctx_lengths_prefill :
941- comp_ctx_lengths_prefill , comp_ctx_lengths_decode = process_ccl_specializations (
942- comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len
943- )
930+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len = process_ccl_specializations (kwargs )
944931
945932 model = cls ._hf_auto_class .from_pretrained (pretrained_model_name_or_path , ** kwargs )
946933 return cls (
@@ -1498,14 +1485,7 @@ def __init__(
14981485 raise NotImplementedError ("Continuous batching is not supported for image-text-to-text models yet." )
14991486 super ().__init__ (model , ** kwargs )
15001487
1501- self .comp_ctx_lengths_prefill = kwargs .pop ("comp_ctx_lengths_prefill" , None )
1502- self .comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
1503- ctx_len = kwargs .pop ("ctx_len" , None )
1504-
1505- if self .comp_ctx_lengths_prefill :
1506- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
1507- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode , ctx_len
1508- )
1488+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode , _ , _ = process_ccl_specializations (kwargs )
15091489
15101490 # to handle internvl models
15111491 if hasattr (self .model .config , "llm_config" ) and hasattr (self .model .config , "vision_config" ):
@@ -1554,14 +1534,7 @@ def from_pretrained(
15541534
15551535 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
15561536
1557- comp_ctx_lengths_prefill = kwargs .pop ("comp_ctx_lengths_prefill" , None )
1558- comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
1559- ctx_len = kwargs .pop ("ctx_len" , None )
1560-
1561- if comp_ctx_lengths_prefill :
1562- comp_ctx_lengths_prefill , comp_ctx_lengths_decode = process_ccl_specializations (
1563- comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len
1564- )
1537+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len = process_ccl_specializations (kwargs )
15651538
15661539 from transformers import AutoConfig
15671540
@@ -2115,14 +2088,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
21152088
21162089 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
21172090
2118- comp_ctx_lengths_prefill = kwargs .pop ("comp_ctx_lengths_prefill" , None )
2119- comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
2120- ctx_len = kwargs .pop ("ctx_len" , None )
2121-
2122- if comp_ctx_lengths_prefill :
2123- comp_ctx_lengths_prefill , comp_ctx_lengths_decode = process_ccl_specializations (
2124- comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len
2125- )
2091+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len = process_ccl_specializations (kwargs )
21262092
21272093 model = cls ._hf_auto_class .from_pretrained (pretrained_model_name_or_path , ** kwargs )
21282094 return cls (
@@ -2232,15 +2198,7 @@ def __init__(
22322198 self .model , transformed = SpDTransform .apply (self .model , qaic_config , ** kwargs )
22332199 self .is_tlm = transformed
22342200
2235- self .comp_ctx_lengths_prefill = kwargs .pop ("comp_ctx_lengths_prefill" , None )
2236- self .comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
2237- ctx_len = kwargs .pop ("ctx_len" , None )
2238- prefill_seq_len = kwargs .pop ("prefill_seq_len" , 128 )
2239-
2240- if self .comp_ctx_lengths_prefill and prefill_seq_len > 1 :
2241- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode = process_ccl_specializations (
2242- self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode , ctx_len
2243- )
2201+ self .comp_ctx_lengths_prefill , self .comp_ctx_lengths_decode , _ , _ = process_ccl_specializations (kwargs )
22442202
22452203 self .hash_params ["qeff_auto_class" ] = self .__class__ .__name__
22462204
@@ -2336,15 +2294,7 @@ def from_pretrained(
23362294
23372295 kv_offload = kwargs .pop ("kv_offload" , None )
23382296
2339- comp_ctx_lengths_prefill = kwargs .pop ("comp_ctx_lengths_prefill" , None )
2340- comp_ctx_lengths_decode = kwargs .pop ("comp_ctx_lengths_decode" , None )
2341- ctx_len = kwargs .pop ("ctx_len" , None )
2342- prefill_seq_len = kwargs .pop ("prefill_seq_len" , 128 )
2343-
2344- if comp_ctx_lengths_prefill and prefill_seq_len > 1 :
2345- comp_ctx_lengths_prefill , comp_ctx_lengths_decode = process_ccl_specializations (
2346- comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len
2347- )
2297+ comp_ctx_lengths_prefill , comp_ctx_lengths_decode , ctx_len , prefill_seq_len = process_ccl_specializations (kwargs )
23482298
23492299 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
23502300 model = cls ._hf_auto_class .from_pretrained (pretrained_model_name_or_path , * args , ** kwargs )
0 commit comments