Skip to content

Commit 71c5182

Browse files
committed
improving handling of CCL lists
Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent 5f047b4 commit 71c5182

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
927927

928928
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
929929

930-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs)
930+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(
931+
kwargs
932+
)
931933

932934
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
933935
return cls(
@@ -1534,7 +1536,9 @@ def from_pretrained(
15341536

15351537
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
15361538

1537-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs)
1539+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(
1540+
kwargs
1541+
)
15381542

15391543
from transformers import AutoConfig
15401544

@@ -2088,7 +2092,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
20882092

20892093
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
20902094

2091-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs)
2095+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(
2096+
kwargs
2097+
)
20922098

20932099
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
20942100
return cls(
@@ -2294,7 +2300,9 @@ def from_pretrained(
22942300

22952301
kv_offload = kwargs.pop("kv_offload", None)
22962302

2297-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs)
2303+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(
2304+
kwargs
2305+
)
22982306

22992307
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
23002308
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

QEfficient/utils/check_ccl_specializations.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,11 @@
55
#
66
# -----------------------------------------------------------------------------
77

8-
from typing import List, Optional
9-
108

119
# def process_ccl_specializations(
1210
# ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None
1311
# ):
14-
def process_ccl_specializations(
15-
kwargs
16-
):
12+
def process_ccl_specializations(kwargs):
1713
ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
1814
ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None)
1915
ctx_len = kwargs.pop("ctx_len", None)
@@ -24,9 +20,9 @@ def process_ccl_specializations(
2420

2521
if ccl_prefill is None or ccl_decode is None:
2622
return None, None, ctx_len, prefill_seq_len
27-
23+
2824
if prefill_seq_len == 1:
29-
#both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them.
25+
# both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them.
3026
ccl_union_all = sorted(set(ccl_prefill + ccl_decode))
3127
ccl_union_all = [min(x, ctx_len) for x in ccl_union_all]
3228
return ccl_union_all, ccl_union_all, ctx_len, prefill_seq_len

0 commit comments

Comments
 (0)