Commit b4bf5f9

Adding Compute-Context-Length (CCL)

Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent 13271c6 commit b4bf5f9

2 files changed: +18 -9 lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 9 additions & 4 deletions
@@ -2235,8 +2235,9 @@ def __init__(
         self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
         self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
         ctx_len = kwargs.pop("ctx_len", None)
-
-        if self.comp_ctx_lengths_prefill:
+        prefill_seq_len = kwargs.pop("prefill_seq_len", 128)
+
+        if self.comp_ctx_lengths_prefill and prefill_seq_len > 1:
             self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
                 self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len
             )
@@ -2338,7 +2339,9 @@ def from_pretrained(
         comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
         comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
         ctx_len = kwargs.pop("ctx_len", None)
-        if comp_ctx_lengths_prefill:
+        prefill_seq_len = kwargs.pop("prefill_seq_len", 128)
+
+        if comp_ctx_lengths_prefill and prefill_seq_len > 1:
             comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations(
                 comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len
             )
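
Both code paths now pop prefill_seq_len (defaulting to 128) and only normalize the CCL lists when a real prefill graph exists (prefill_seq_len > 1). For orientation, here is a minimal, hypothetical sketch of what a helper with this call shape could do, assuming it caps each compute-context length at ctx_len, deduplicates, and sorts; the actual process_ccl_specializations implementation is not shown in this diff and may differ:

from typing import List, Optional, Tuple


def process_ccl_specializations(
    ccl_prefill: Optional[List[int]],
    ccl_decode: Optional[List[int]],
    ctx_len: Optional[int],
) -> Tuple[Optional[List[int]], Optional[List[int]]]:
    """Illustrative only: clamp CCL values to ctx_len, dedupe, and sort."""
    if ctx_len is None:
        return ccl_prefill, ccl_decode

    def _normalize(ccl: Optional[List[int]]) -> Optional[List[int]]:
        if not ccl:
            return ccl
        # Cap every compute-context length at the full context length.
        return sorted({min(c, ctx_len) for c in ccl})

    return _normalize(ccl_prefill), _normalize(ccl_decode)

Under these assumptions, comp_ctx_lengths_decode = [4096, 8192, 16384, 32768, 65536] with ctx_len = 65536 would pass through unchanged, while any entry above ctx_len would be capped.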
@@ -2356,6 +2359,7 @@ def from_pretrained(
             comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
             comp_ctx_lengths_decode=comp_ctx_lengths_decode,
             ctx_len=ctx_len,
+            prefill_seq_len=prefill_seq_len,
             kv_offload=kv_offload,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             **kwargs,
@@ -2368,6 +2372,7 @@ def from_pretrained(
             comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
             comp_ctx_lengths_decode=comp_ctx_lengths_decode,
             ctx_len=ctx_len,
+            prefill_seq_len=prefill_seq_len,
             **kwargs,
         )

@@ -2643,7 +2648,7 @@ def build_decode_specialization(
             A dictionary defining the decode specialization, or None if it would be a duplicate
             of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching).
         """
-        if prefill_seq_len == 1 and not self.continuous_batching and comp_ctx_lengths is None:
+        if prefill_seq_len == 1 and not self.continuous_batching:  # and comp_ctx_lengths is None
             return None  # Avoid duplication with prefill
         spec = {
             "batch_size": full_batch_size if self.continuous_batching else batch_size,

examples/qwen3moe_example/ccl_qwen3moe_inference.py

Lines changed: 9 additions & 5 deletions
@@ -16,20 +16,23 @@
 # We will use prompt_len=1 for compilation for both cb and non-cb inference
 """

-ctx_len = 8192
-
-comp_ctx_lengths_prefill = [4096]
-comp_ctx_lengths_decode = [6144, 8192]
+ctx_len = 65536
+prefill_seq_len = 1
+# In MoE models, when compiling with prefill_seq_len=1 in non-continuous-batching mode, prefill and decode will share the same specializations.
+comp_ctx_lengths_prefill = [4096, 8192, 16384, 32768, ctx_len]
+comp_ctx_lengths_decode = [4096, 8192, 16384, 32768, ctx_len]

 model = QEFFAutoModelForCausalLM.from_pretrained(
     model_name,
     comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
     comp_ctx_lengths_decode=comp_ctx_lengths_decode,
     ctx_len=ctx_len,
     continuous_batching=False,
+    prefill_seq_len=prefill_seq_len,
 )
+# prefill_seq_len=prefill_seq_len,
 model.compile(
-    prefill_seq_len=1,
+    prefill_seq_len=prefill_seq_len,
     ctx_len=ctx_len,
     batch_size=1,
     num_cores=16,

@@ -38,5 +41,6 @@
     mxint8_kv_cache=True,
     mos=1,
 )
+# mos=1,
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer)
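
Putting the hunks together, the example script after this commit reads roughly as below. The import lines and the model_name assignment sit above the first hunk and so are assumptions modeled on QEfficient's other example scripts, model_name is a hypothetical placeholder, the two compile options elided by the diff context are left as a marker, and the commit's leftover commented-out lines (# prefill_seq_len=prefill_seq_len, and # mos=1,) are dropped for clarity:

# Assumed preamble -- not shown in the diff; modeled on other QEfficient examples.
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM
from QEfficient.utils.constants import Constants

model_name = "Qwen/Qwen3-30B-A3B"  # hypothetical placeholder; the real value sits above the hunk

ctx_len = 65536
prefill_seq_len = 1
# In MoE models, compiling with prefill_seq_len=1 in non-continuous-batching
# mode lets prefill and decode share the same specializations.
comp_ctx_lengths_prefill = [4096, 8192, 16384, 32768, ctx_len]
comp_ctx_lengths_decode = [4096, 8192, 16384, 32768, ctx_len]

model = QEFFAutoModelForCausalLM.from_pretrained(
    model_name,
    comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
    comp_ctx_lengths_decode=comp_ctx_lengths_decode,
    ctx_len=ctx_len,
    continuous_batching=False,
    prefill_seq_len=prefill_seq_len,
)
model.compile(
    prefill_seq_len=prefill_seq_len,
    ctx_len=ctx_len,
    batch_size=1,
    num_cores=16,
    # ... (compile options on the file's lines 36-37 are elided by the diff context)
    mxint8_kv_cache=True,
    mos=1,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer)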
