Skip to content

Commit fa3c2f6

Browse files
vjanfaza and quic-rishinr
authored and committed
Improve handling of CCL lists
Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent 42b4b7f commit fa3c2f6

File tree

2 files changed

+26
-63
lines changed

2 files changed

+26
-63
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 7 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -879,13 +879,7 @@ def __init__(
879879
self.model = model
880880
self.config = model.config
881881

882-
self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
883-
self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
884-
ctx_len = kwargs.pop("ctx_len", None)
885-
if self.comp_ctx_lengths_prefill:
886-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
887-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len
888-
)
882+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs)
889883

890884
self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
891885
self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
@@ -933,14 +927,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
933927

934928
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
935929

936-
comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
937-
comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
938-
ctx_len = kwargs.pop("ctx_len", None)
939-
940-
if comp_ctx_lengths_prefill:
941-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations(
942-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len
943-
)
930+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs)
944931

945932
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
946933
return cls(
@@ -1498,14 +1485,7 @@ def __init__(
14981485
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
14991486
super().__init__(model, **kwargs)
15001487

1501-
self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
1502-
self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
1503-
ctx_len = kwargs.pop("ctx_len", None)
1504-
1505-
if self.comp_ctx_lengths_prefill:
1506-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
1507-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len
1508-
)
1488+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs)
15091489

15101490
# to handle internvl models
15111491
if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"):
@@ -1554,14 +1534,7 @@ def from_pretrained(
15541534

15551535
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
15561536

1557-
comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
1558-
comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
1559-
ctx_len = kwargs.pop("ctx_len", None)
1560-
1561-
if comp_ctx_lengths_prefill:
1562-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations(
1563-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len
1564-
)
1537+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs)
15651538

15661539
from transformers import AutoConfig
15671540

@@ -2115,14 +2088,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
21152088

21162089
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
21172090

2118-
comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
2119-
comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
2120-
ctx_len = kwargs.pop("ctx_len", None)
2121-
2122-
if comp_ctx_lengths_prefill:
2123-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations(
2124-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len
2125-
)
2091+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs)
21262092

21272093
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
21282094
return cls(
@@ -2232,15 +2198,7 @@ def __init__(
22322198
self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs)
22332199
self.is_tlm = transformed
22342200

2235-
self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
2236-
self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
2237-
ctx_len = kwargs.pop("ctx_len", None)
2238-
prefill_seq_len = kwargs.pop("prefill_seq_len", 128)
2239-
2240-
if self.comp_ctx_lengths_prefill and prefill_seq_len > 1:
2241-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
2242-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len
2243-
)
2201+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs)
22442202

22452203
self.hash_params["qeff_auto_class"] = self.__class__.__name__
22462204

@@ -2336,15 +2294,7 @@ def from_pretrained(
23362294

23372295
kv_offload = kwargs.pop("kv_offload", None)
23382296

2339-
comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
2340-
comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
2341-
ctx_len = kwargs.pop("ctx_len", None)
2342-
prefill_seq_len = kwargs.pop("prefill_seq_len", 128)
2343-
2344-
if comp_ctx_lengths_prefill and prefill_seq_len > 1:
2345-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations(
2346-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len
2347-
)
2297+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs)
23482298

23492299
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
23502300
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

QEfficient/utils/check_ccl_specializations.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,28 @@
88
from typing import List, Optional
99

1010

11+
# def process_ccl_specializations(
12+
# ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None
13+
# ):
1114
def process_ccl_specializations(
12-
ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None
15+
kwargs
1316
):
17+
ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
18+
ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None)
19+
ctx_len = kwargs.pop("ctx_len", None)
20+
prefill_seq_len = kwargs.pop("prefill_seq_len", 128)
21+
1422
if ctx_len is None:
1523
raise TypeError("`ctx_len` is required when loading the model.")
16-
if ccl_prefill is None:
17-
ccl_prefill = [ctx_len]
18-
if ccl_decode is None:
19-
ccl_decode = [ctx_len]
24+
25+
if ccl_prefill is None or ccl_decode is None:
26+
return None, None, ctx_len, prefill_seq_len
27+
28+
if prefill_seq_len == 1:
29+
# Both prefill and decode CCL lists can share the same specializations since prefill_seq_len=1, so a sorted union of both lists can be used for both.
30+
ccl_union_all = sorted(set(ccl_prefill + ccl_decode))
31+
ccl_union_all = [min(x, ctx_len) for x in ccl_union_all]
32+
return ccl_union_all, ccl_union_all, ctx_len, prefill_seq_len
2033

2134
# Step 1: Cap values to ctx_len
2235
ccl_prefill = [min(x, ctx_len) for x in ccl_prefill]
@@ -40,4 +53,4 @@ def process_ccl_specializations(
4053
updated_prefill.sort()
4154
ccl_decode.sort()
4255

43-
return updated_prefill, ccl_decode
56+
return updated_prefill, ccl_decode, ctx_len, prefill_seq_len

0 commit comments

Comments
 (0)