Skip to content

Commit 0407d34

Browse files
committed
Adding Compute-Context-Length(CCL)
Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent 1b22bf3 commit 0407d34

File tree

18 files changed

+372
-273
lines changed

18 files changed

+372
-273
lines changed

QEfficient/generation/text_generation_inference.py

Lines changed: 44 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ def cloud_ai_100_exec_kv(
318318
prompts_txt_file_path: Optional[str] = None,
319319
device_id: Optional[List[int]] = None,
320320
generation_len: Optional[int] = None,
321-
comp_ctx_lengths: Optional[List[int]] = None,
322-
prefill_ccl_len: Optional[int] = 1,
321+
comp_ctx_lengths_prefill: Optional[List[int]] = None,
322+
comp_ctx_lengths_decode: Optional[List[int]] = None,
323323
enable_debug_logs: bool = False,
324324
stream: bool = True,
325325
write_io_dir: Optional[str] = None,
@@ -384,8 +384,8 @@ def cloud_ai_100_exec_kv(
384384
qpc_path=qpc_path,
385385
device_id=device_id,
386386
ctx_len=ctx_len,
387-
comp_ctx_lengths=comp_ctx_lengths,
388-
prefill_ccl_len=prefill_ccl_len,
387+
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
388+
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
389389
enable_debug_logs=enable_debug_logs,
390390
write_io_dir=write_io_dir,
391391
full_batch_size=full_batch_size,
@@ -428,8 +428,8 @@ def __init__(
428428
qpc_path: str,
429429
full_batch_size: Optional[int] = None,
430430
ctx_len: Optional[int] = None,
431-
comp_ctx_lengths: Optional[List[int]] = None,
432-
prefill_ccl_len: Optional[int] = 1,
431+
comp_ctx_lengths_prefill: Optional[List[int]] = None,
432+
comp_ctx_lengths_decode: Optional[List[int]] = None,
433433
device_id: Optional[List[int]] = None,
434434
enable_debug_logs: bool = False,
435435
write_io_dir: Optional[str] = None,
@@ -439,8 +439,8 @@ def __init__(
439439
sampling_params: Optional[Dict[str, Any]] = None,
440440
) -> None:
441441
self._ctx_len = ctx_len
442-
self.comp_ctx_lengths = comp_ctx_lengths
443-
self.prefill_ccl_len = prefill_ccl_len
442+
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
443+
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
444444
self._write_io_dir = write_io_dir
445445
self.is_tlm = is_tlm
446446
self.return_pdfs = return_pdfs
@@ -799,22 +799,15 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
799799
batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
800800
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
801801

802-
if self.comp_ctx_lengths is not None:
803-
self.list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
802+
if self.comp_ctx_lengths_prefill is not None:
803+
self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
804804
prefill_ccl_id = 0
805-
inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[prefill_ccl_id]
805+
inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
806806

807807
for i in range(num_chunks):
808-
if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths[prefill_ccl_id]:
809-
prefill_ccl_id += 1
810-
if prefill_ccl_id >= self.prefill_ccl_len:
811-
prefill_ccl_id = (
812-
(self.prefill_ccl_len - 1)
813-
if self.prefill_ccl_len != 0
814-
else min(prefill_ccl_id, len(self.comp_ctx_lengths) - 1)
815-
)
816-
817-
inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[prefill_ccl_id]
808+
if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
809+
prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
810+
inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
818811

819812
chunk_inputs = inputs.copy()
820813
chunk_inputs["input_ids"] = inputs["input_ids"][
@@ -835,12 +828,13 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
835828
)
836829

837830
def initialize_ccl(self, decode_inputs):
838-
max_ccl_id = len(self.comp_ctx_lengths) - 1
831+
self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
832+
max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
839833
max_position_id = np.max(decode_inputs["position_ids"])
840-
ccl_id_initial = self.prefill_ccl_len
834+
ccl_id_initial = 0
841835
ccl_id = ccl_id_initial
842-
for i in range(ccl_id_initial, len(self.comp_ctx_lengths)):
843-
if max_position_id < self.comp_ctx_lengths[i]:
836+
for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
837+
if max_position_id < self.comp_ctx_lengths_decode[i]:
844838
ccl_id = i
845839
break
846840

@@ -877,9 +871,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
877871
# Prepare decode inputs.
878872
decode_inputs = self.prepare_decode_inputs()
879873

880-
if self.comp_ctx_lengths is not None:
874+
if self.comp_ctx_lengths_decode is not None:
881875
ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
882-
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
876+
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
883877

884878
while prompt_queue or current_decode_ongoing.any():
885879
outputs = self._session.run(decode_inputs)
@@ -918,19 +912,19 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
918912
batch_id_map[decode_batch_id]
919913
]
920914

921-
if self.comp_ctx_lengths is not None:
915+
if self.comp_ctx_lengths_decode is not None:
922916
###Recalculate ccl_id based on position ids###
923917
# Determine the maximum value of position_ids across all batch elements
924918
max_position_id = np.max(decode_inputs["position_ids"])
925919

926-
# Update ccl_id and comp_ctx_lengths based on the maximum position id
920+
# Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
927921
ccl_id_initial = self.prefill_ccl_len
928922
ccl_id = ccl_id_initial
929-
for i in range(ccl_id_initial, len(self.comp_ctx_lengths)):
930-
if max_position_id < self.comp_ctx_lengths[i]:
923+
for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
924+
if max_position_id < self.comp_ctx_lengths_decode[i]:
931925
ccl_id = i
932926
break
933-
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
927+
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
934928

935929
else:
936930
current_decode_ongoing[decode_batch_id] = False
@@ -944,11 +938,14 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
944938
if self.include_sampler:
945939
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
946940

947-
if self.comp_ctx_lengths is not None:
948-
# Update ccl_id and comp_ctx_lengths based on the maximum position id
949-
if decode_inputs["position_ids"][decode_batch_id, -1] >= self.comp_ctx_lengths[ccl_id] - 1:
941+
if self.comp_ctx_lengths_decode is not None:
942+
# Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
943+
if (
944+
decode_inputs["position_ids"][decode_batch_id, -1]
945+
>= self.comp_ctx_lengths_decode[ccl_id] - 1
946+
):
950947
ccl_id = min(ccl_id + 1, max_ccl_id)
951-
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
948+
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
952949

953950
generated_id_current_index[decode_batch_id] += 1
954951

@@ -975,16 +972,16 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
975972
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
976973
num_token = 0
977974

978-
if self.comp_ctx_lengths is not None:
975+
if self.comp_ctx_lengths_decode is not None:
979976
ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
980-
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
977+
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
981978

982979
cache_index = np.max(decode_inputs["position_ids"])
983980
for num_token in range(1, generation_len):
984-
if self.comp_ctx_lengths is not None:
985-
if cache_index >= self.comp_ctx_lengths[ccl_id] - 1:
981+
if self.comp_ctx_lengths_decode is not None:
982+
if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1:
986983
ccl_id = min(ccl_id + 1, max_ccl_id)
987-
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
984+
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
988985

989986
if streamer:
990987
streamer.put(decode_inputs["input_ids"][0])
@@ -1047,8 +1044,8 @@ def __init__(
10471044
qpc_path: str,
10481045
full_batch_size: Optional[int] = None,
10491046
ctx_len: Optional[int] = None,
1050-
comp_ctx_lengths: Optional[List[int]] = None,
1051-
prefill_ccl_len: Optional[int] = 1,
1047+
comp_ctx_lengths_prefill: Optional[List[int]] = None,
1048+
comp_ctx_lengths_decode: Optional[List[int]] = None,
10521049
device_id: Optional[List[int]] = None,
10531050
enable_debug_logs: bool = False,
10541051
write_io_dir: Optional[str] = None,
@@ -1062,8 +1059,8 @@ def __init__(
10621059
qpc_path=qpc_path,
10631060
full_batch_size=full_batch_size,
10641061
ctx_len=ctx_len,
1065-
comp_ctx_lengths=comp_ctx_lengths,
1066-
prefill_ccl_len=prefill_ccl_len,
1062+
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
1063+
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
10671064
device_id=device_id,
10681065
enable_debug_logs=enable_debug_logs,
10691066
write_io_dir=write_io_dir,
@@ -1075,8 +1072,8 @@ def __init__(
10751072
self._full_batch_size = self._qaic_model.full_batch_size
10761073
self._tokenizer = self._qaic_model.tokenizer
10771074
self._ctx_len = ctx_len
1078-
self.comp_ctx_lengths = comp_ctx_lengths
1079-
self.prefill_ccl_len = prefill_ccl_len
1075+
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
1076+
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
10801077
self._perf_metrics = None
10811078
self._prompt_queue = None
10821079
self._text_streamer = None

QEfficient/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -672,8 +672,8 @@ def get_specializations(
672672
prefill_seq_len: int,
673673
ctx_len: int,
674674
img_size: int,
675-
comp_ctx_lengths: List[int] = None,
676-
prefill_ccl_len: int = None,
675+
comp_ctx_lengths_prefill: List[int] = None,
676+
comp_ctx_lengths_decode: List[int] = None,
677677
kv_offload: bool = False,
678678
**compiler_options,
679679
):
@@ -694,31 +694,29 @@ def get_specializations(
694694
"ctx_len": ctx_len,
695695
}
696696
]
697-
if comp_ctx_lengths is not None:
697+
if comp_ctx_lengths_prefill is not None:
698698
lang = []
699699

700-
# prefill_ccl_len elements of comp_ctx_lengths will be used for prefilling
701-
for i in range(0, prefill_ccl_len):
700+
for i in range(0, len(comp_ctx_lengths_prefill)):
702701
lang.append(
703702
{
704703
"batch_size": batch_size,
705704
"seq_len": prefill_seq_len,
706705
"ctx_len": ctx_len,
707-
"comp_ctx_lengths": comp_ctx_lengths[i],
706+
"comp_ctx_lengths": comp_ctx_lengths_prefill[i],
708707
"sliding_window": self.language_model.config.sliding_window,
709708
"img_size": img_size,
710709
"mm_tokens_per_image": mm_tokens_per_image,
711710
}
712711
)
713712

714-
# Remaining elements use comp_ctx_lengths[1:] in a loop
715-
for i in range(prefill_ccl_len, len(comp_ctx_lengths)):
713+
for i in range(0, len(comp_ctx_lengths_decode)):
716714
lang.append(
717715
{
718716
"batch_size": batch_size,
719717
"seq_len": "1",
720718
"ctx_len": ctx_len,
721-
"comp_ctx_lengths": comp_ctx_lengths[i],
719+
"comp_ctx_lengths": comp_ctx_lengths_decode[i],
722720
"sliding_window": self.language_model.config.sliding_window,
723721
"img_size": img_size,
724722
"mm_tokens_per_image": mm_tokens_per_image,

QEfficient/transformers/models/internvl/modeling_internvl.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def get_specializations(
6969
prefill_seq_len: int,
7070
ctx_len: int,
7171
img_size: int,
72-
comp_ctx_lengths: List[int],
73-
prefill_ccl_len: int = None,
72+
comp_ctx_lengths_prefill: List[int] = None,
73+
comp_ctx_lengths_decode: List[int] = None,
7474
kv_offload: bool = False,
7575
**compiler_options,
7676
):
@@ -100,31 +100,29 @@ def get_specializations(
100100
"img_size": img_size,
101101
}
102102
]
103-
if comp_ctx_lengths is not None:
103+
if comp_ctx_lengths_prefill is not None:
104104
lang = []
105105

106-
# prefill_ccl_len elements of comp_ctx_lengths will be used for prefilling
107-
for i in range(0, prefill_ccl_len):
106+
for i in range(0, len(comp_ctx_lengths_prefill)):
108107
lang.append(
109108
{
110109
"batch_size": batch_size,
111110
"seq_len": prefill_seq_len,
112111
"ctx_len": ctx_len,
113-
"comp_ctx_lengths": comp_ctx_lengths[i],
112+
"comp_ctx_lengths": comp_ctx_lengths_prefill[i],
114113
"num_patches": num_patches,
115114
"img_size": img_size,
116115
"vision_size": vision_size,
117116
}
118117
)
119118

120-
# Remaining elements use comp_ctx_lengths[1:] in a loop
121-
for i in range(prefill_ccl_len, len(comp_ctx_lengths)):
119+
for i in range(0, len(comp_ctx_lengths_decode)):
122120
lang.append(
123121
{
124122
"batch_size": batch_size,
125123
"seq_len": "1",
126124
"ctx_len": ctx_len,
127-
"comp_ctx_lengths": comp_ctx_lengths[i],
125+
"comp_ctx_lengths": comp_ctx_lengths_decode[i],
128126
"num_patches": num_patches,
129127
"img_size": img_size,
130128
"vision_size": vision_size,

QEfficient/transformers/models/llama4/modeling_llama4.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -908,8 +908,8 @@ def get_specializations(
908908
prefill_seq_len: int,
909909
ctx_len: int,
910910
img_size: int,
911-
comp_ctx_lengths: List[int] = None,
912-
prefill_ccl_len: int = None,
911+
comp_ctx_lengths_prefill: List[int] = None,
912+
comp_ctx_lengths_decode: List[int] = None,
913913
kv_offload: bool = False,
914914
**compiler_options,
915915
):
@@ -959,17 +959,16 @@ def get_specializations(
959959
"img_size": img_size,
960960
}
961961
]
962-
if comp_ctx_lengths is not None:
962+
if comp_ctx_lengths_prefill is not None:
963963
lang = []
964964

965-
# prefill_ccl_len elements of comp_ctx_lengths will be used for prefilling
966-
for i in range(0, prefill_ccl_len):
965+
for i in range(0, len(comp_ctx_lengths_prefill)):
967966
lang.append(
968967
{
969968
"batch_size": batch_size,
970969
"seq_len": prefill_seq_len,
971970
"ctx_len": ctx_len,
972-
"comp_ctx_lengths": comp_ctx_lengths[i],
971+
"comp_ctx_lengths": comp_ctx_lengths_prefill[i],
973972
"max_num_tiles": max_num_tiles,
974973
"img_size": img_size,
975974
"vision_size": vision_size,
@@ -978,14 +977,13 @@ def get_specializations(
978977
}
979978
)
980979

981-
# Remaining elements use comp_ctx_lengths[1:] in a loop
982-
for i in range(prefill_ccl_len, len(comp_ctx_lengths)):
980+
for i in range(0, len(comp_ctx_lengths_decode)):
983981
lang.append(
984982
{
985983
"batch_size": batch_size,
986984
"seq_len": "1",
987985
"ctx_len": ctx_len,
988-
"comp_ctx_lengths": comp_ctx_lengths[i],
986+
"comp_ctx_lengths": comp_ctx_lengths_decode[i],
989987
"max_num_tiles": max_num_tiles,
990988
"img_size": img_size,
991989
"vision_size": vision_size,

QEfficient/transformers/models/llava/modeling_llava.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ def get_specializations(
162162
prefill_seq_len: int,
163163
ctx_len: int,
164164
img_size: int,
165-
comp_ctx_lengths: List[int] = None,
166-
prefill_ccl_len: int = None,
165+
comp_ctx_lengths_prefill: List[int] = None,
166+
comp_ctx_lengths_decode: List[int] = None,
167167
kv_offload: bool = False,
168168
**compiler_options,
169169
):
@@ -186,31 +186,29 @@ def get_specializations(
186186
}
187187
]
188188

189-
if comp_ctx_lengths is not None:
189+
if comp_ctx_lengths_prefill is not None:
190190
lang = []
191191

192-
# prefill_ccl_len elements of comp_ctx_lengths will be used for prefilling
193-
for i in range(0, prefill_ccl_len):
192+
for i in range(0, len(comp_ctx_lengths_prefill)):
194193
lang.append(
195194
{
196195
"batch_size": batch_size,
197196
"seq_len": prefill_seq_len,
198197
"ctx_len": ctx_len,
199-
"comp_ctx_lengths": comp_ctx_lengths[i],
198+
"comp_ctx_lengths": comp_ctx_lengths_prefill[i],
200199
"max_num_images": max_num_images,
201200
"img_size": img_size,
202201
"vision_size": vision_size,
203202
}
204203
)
205204

206-
# Remaining elements use comp_ctx_lengths[1:] in a loop
207-
for i in range(prefill_ccl_len, len(comp_ctx_lengths)):
205+
for i in range(0, len(comp_ctx_lengths_decode)):
208206
lang.append(
209207
{
210208
"batch_size": batch_size,
211209
"seq_len": "1",
212210
"ctx_len": ctx_len,
213-
"comp_ctx_lengths": comp_ctx_lengths[i],
211+
"comp_ctx_lengths": comp_ctx_lengths_decode[i],
214212
"max_num_images": max_num_images,
215213
"img_size": img_size,
216214
"vision_size": vision_size,

0 commit comments

Comments
 (0)