Commit f00737f (1 parent: a379e6e)

Adding Compute-Context-Length (CCL)

Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>

42 files changed (+2685, -257 lines)

QEfficient/cloud/infer.py
Lines changed: 12 additions & 0 deletions

@@ -340,6 +340,18 @@ def main(
         "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation."
     )
     parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.")
+    parser.add_argument(
+        "--comp-ctx-lengths-prefill",
+        type=lambda comp_ctx_lengths_prefill: [int(x) for x in comp_ctx_lengths_prefill.split(",")],
+        default=[512],
+        help="Define ccl list in csv format (e.g., --comp-ctx-lengths 512,1024,2048).",
+    )
+    parser.add_argument(
+        "--comp-ctx-lengths-decode",
+        type=lambda comp_ctx_lengths_decode: [int(x) for x in comp_ctx_lengths_decode.split(",")],
+        default=[2048],
+        help="Define ccl list in csv format (e.g., --comp-ctx-lengths 512,1024,2048).",
+    )
     parser.add_argument(
         "--mxfp6",
         "--mxfp6_matmul",

QEfficient/customop/ctx_scatter_gather.py
Lines changed: 11 additions & 5 deletions

@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor


 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
-def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
-    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
+def CtxGather(
+    data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
+) -> onnxscript.FLOAT:
+    # Create a shape tensor based on comp_ctx_len
+    shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)
+
+    # Directly use the shape tensor without validation
+    ctx_indices = ops.Expand(ctx_indices, shape_tensor)
     ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
     return ops.GatherND(data, ctx_indices, batch_dims=2)

@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
     """

     @staticmethod
-    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]

@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
         pass

     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
+    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
+        return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
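
As a sanity check on what the patched gather computes, here is a toy eager-mode sketch (assuming only torch; the shapes are invented for illustration and not taken from the repo):

import torch

# Toy KV cache laid out as (batch, heads, ctx_len, head_dim).
data = torch.arange(2 * 2 * 8 * 4, dtype=torch.float32).reshape(2, 2, 8, 4)

# With CCL, the gather covers only the first comp_ctx_len positions
# instead of the full context length of 8.
comp_ctx_len = 4
ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1).expand(2, 2, -1)

# Same indexing as CtxGatherFunc.forward in the hunk above.
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
gathered = data[batch_indices, head_indices, ctx_indices]
print(gathered.shape)  # torch.Size([2, 2, 4, 4]) -- ctx axis shrunk to CCL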

QEfficient/customop/ctx_scatter_gather_cb.py
Lines changed: 12 additions & 6 deletions

@@ -97,16 +97,20 @@ def symbolic(


 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
 def CtxGatherCB(
-    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
+    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
 ) -> onnxscript.FLOAT:
     batch_size = ops.Gather(ops.Shape(batch_index), [0])
     num_heads = ops.Gather(ops.Shape(data), [1])
-    ctx_len = ops.Gather(ops.Shape(data), [2])
+    # Use the compute-context-length (CCL) instead of the full context length, so both this
+    # gather and the attention computed on its output run over only the first CCL positions.
+    ctx_len = ops.Reshape(comp_ctx_len, [1])

     # Expanded shape to create indices
     zero = ops.Constant(value_ints=[0])
     one = ops.Constant(value_ints=[1])
-    exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
+    # exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
+    exp_shape = ops.Concat(
+        ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0
+    )

     # Create indices
     batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)

@@ -119,7 +123,7 @@ def CtxGatherCB(

 class CtxGatherFuncCB(torch.autograd.Function):
     @staticmethod
-    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = batch_index.view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]

@@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs):
         pass

     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
+    def symbolic(
+        g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
+    ) -> torch.Value:
+        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)


 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
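
The continuous-batching variant differs only in that batch_index maps each active sequence onto its cache slot. A toy torch sketch of that indexing (shapes invented for illustration; only the forward logic from the hunk above is reproduced):

import torch

data = torch.randn(4, 2, 8, 16)         # cache: (batch_slots, heads, ctx_len, head_dim)
batch_index = torch.tensor([[2], [0]])  # two active sequences living in slots 2 and 0
comp_ctx_len = 4                        # gather only the first CCL positions
ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1).expand(2, 2, -1)

# Same indexing as CtxGatherFuncCB.forward above.
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
out = data[batch_indices, head_indices, ctx_indices]
print(out.shape)  # torch.Size([2, 2, 4, 16])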

QEfficient/generation/text_generation_inference.py
Lines changed: 76 additions & 0 deletions

@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
     prompts_txt_file_path: Optional[str] = None,
     device_id: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
+    comp_ctx_lengths_prefill: Optional[List[int]] = None,
+    comp_ctx_lengths_decode: Optional[List[int]] = None,
     enable_debug_logs: bool = False,
     stream: bool = True,
     write_io_dir: Optional[str] = None,

@@ -384,6 +386,8 @@ def cloud_ai_100_exec_kv(
         qpc_path=qpc_path,
         device_id=device_id,
         ctx_len=ctx_len,
+        comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+        comp_ctx_lengths_decode=comp_ctx_lengths_decode,
         enable_debug_logs=enable_debug_logs,
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,

@@ -430,6 +434,8 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,

@@ -439,6 +445,8 @@ def __init__(
         sampling_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
+        self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm
         self.return_pdfs = return_pdfs
@@ -797,7 +805,17 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
             inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

+        if self.comp_ctx_lengths_prefill is not None:
+            self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+            prefill_ccl_id = 0
+            inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
         for i in range(num_chunks):
+            if self.comp_ctx_lengths_prefill is not None:
+                if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
+                    prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
+                    inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
             chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][
                 :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
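
To make the chunk-loop promotion concrete, here is a minimal sketch of how the prefill CCL bucket advances as chunks are processed (a reconstruction assuming only plain Python; names mirror the hunk above and the bucket values are made up):

comp_ctx_lengths_prefill = [256, 512, 1024]  # ascending CCL buckets
prefill_seq_len = 128                        # tokens consumed per prefill chunk
num_chunks = 6                               # e.g. a 768-token prompt

prefill_ccl_id = 0
for i in range(num_chunks):
    # Once the next chunk would overrun the current bucket, step up to the
    # next one (capped at the largest bucket), exactly as in the diff.
    if (i + 1) * prefill_seq_len > comp_ctx_lengths_prefill[prefill_ccl_id]:
        prefill_ccl_id = min(prefill_ccl_id + 1, len(comp_ctx_lengths_prefill) - 1)
    print(f"chunk {i}: CCL = {comp_ctx_lengths_prefill[prefill_ccl_id]}")
# chunks 0-1 run at 256, chunks 2-3 at 512, chunks 4-5 at 1024
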
@@ -816,6 +834,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             generation_len,
         )

+    def initialize_ccl(self, decode_inputs):
+        self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
+        max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
+        max_position_id = np.max(decode_inputs["position_ids"])
+        ccl_id_initial = 0
+        ccl_id = ccl_id_initial
+        for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
+            if max_position_id < self.comp_ctx_lengths_decode[i]:
+                ccl_id = i
+                break
+
+        return ccl_id, max_ccl_id
+
     def run_continuous_batching_decode(self, prompt_queue, generation_len):
         """
         Runs continuous batching decode for the given prompt queue and generation length.
@@ -847,6 +878,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
         # Prepare decode inputs inputs.
         decode_inputs = self.prepare_decode_inputs()

+        if self.comp_ctx_lengths_decode is not None:
+            ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
         while prompt_queue or current_decode_ongoing.any():
             outputs = self._session.run(decode_inputs)

@@ -884,6 +919,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                         batch_id_map[decode_batch_id]
                     ]

+                    if self.comp_ctx_lengths_decode is not None:
+                        ### Recalculate ccl_id based on position ids ###
+                        # Determine the maximum value of position_ids across all batch elements
+                        max_position_id = np.max(decode_inputs["position_ids"])
+
+                        # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
+                        ccl_id_initial = 0
+                        ccl_id = ccl_id_initial
+                        for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
+                            if max_position_id < self.comp_ctx_lengths_decode[i]:
+                                ccl_id = i
+                                break
+                        decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
                 else:
                     current_decode_ongoing[decode_batch_id] = False
             else:

@@ -896,6 +945,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                 if self.include_sampler:
                     decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

+                if self.comp_ctx_lengths_decode is not None:
+                    # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
+                    if (
+                        decode_inputs["position_ids"][decode_batch_id, -1]
+                        >= self.comp_ctx_lengths_decode[ccl_id] - 1
+                    ):
+                        ccl_id = min(ccl_id + 1, max_ccl_id)
+                        decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
                 generated_id_current_index[decode_batch_id] += 1

         return decode_pause_time
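
The decode-side bucket logic in these hunks reduces to two small rules: start in the smallest bucket that covers the current maximum position id, then promote to the next bucket once the position reaches the current bucket's limit. A standalone sketch (a reconstruction assuming only numpy; helper names mirror the diff and the bucket values are made up):

import numpy as np

comp_ctx_lengths_decode = [512, 1024, 2048]  # ascending CCL buckets
max_ccl_id = len(comp_ctx_lengths_decode) - 1

def initialize_ccl(position_ids):
    # Smallest bucket whose length still exceeds the largest position id.
    max_position_id = np.max(position_ids)
    ccl_id = 0
    for i, ccl in enumerate(comp_ctx_lengths_decode):
        if max_position_id < ccl:
            ccl_id = i
            break
    return ccl_id

# A sequence entering decode with 600 cached tokens starts in bucket 1.
ccl_id = initialize_ccl(np.array([[600]]))
print(comp_ctx_lengths_decode[ccl_id])  # 1024

# As decode advances, promote once the position hits the bucket boundary,
# capping at the largest bucket.
position = 1023
if position >= comp_ctx_lengths_decode[ccl_id] - 1:
    ccl_id = min(ccl_id + 1, max_ccl_id)
print(comp_ctx_lengths_decode[ccl_id])  # 2048
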
@@ -922,7 +980,18 @@ def run_decode(
         self._session.set_buffers({"logits": logits_out_placeholder})
         finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
         num_token = 0
+
+        if self.comp_ctx_lengths_decode is not None:
+            ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
+        cache_index = np.max(decode_inputs["position_ids"])
         for num_token in range(1, generation_len):
+            if self.comp_ctx_lengths_decode is not None:
+                if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1:
+                    ccl_id = min(ccl_id + 1, max_ccl_id)
+                    decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
             if streamer:
                 streamer.put(decode_inputs["input_ids"][0])
             outputs = self._session.run(decode_inputs)

@@ -934,6 +1003,7 @@ def run_decode(
             # Prepare inputs for next iteration
             decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
             decode_inputs["position_ids"][:, -1] += 1
+            cache_index += 1
             self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
             if self.include_sampler:

@@ -983,6 +1053,8 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,

@@ -996,6 +1068,8 @@ def __init__(
             qpc_path=qpc_path,
             full_batch_size=full_batch_size,
             ctx_len=ctx_len,
+            comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+            comp_ctx_lengths_decode=comp_ctx_lengths_decode,
             device_id=device_id,
             enable_debug_logs=enable_debug_logs,
             write_io_dir=write_io_dir,

@@ -1007,6 +1081,8 @@ def __init__(
         self._full_batch_size = self._qaic_model.full_batch_size
         self._tokenizer = self._qaic_model.tokenizer
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
+        self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
         self._perf_metrics = None
         self._prompt_queue = None
         self._text_streamer = None
