
Commit 8c82207

Adding Compute-Context-Length (CCL)
Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent 0182d95 commit 8c82207

32 files changed: +2099 −211 lines

QEfficient/customop/ctx_scatter_gather.py

Lines changed: 11 additions & 5 deletions
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
 
 
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
-def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
-    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
+def CtxGather(
+    data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
+) -> onnxscript.FLOAT:
+    # Create a shape tensor based on comp_ctx_len
+    shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)
+
+    # Directly use the shape tensor without validation
+    ctx_indices = ops.Expand(ctx_indices, shape_tensor)
     ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
     return ops.GatherND(data, ctx_indices, batch_dims=2)

@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]
@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
         pass
 
     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
+    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
+        return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
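For intuition, here is a minimal eager-mode sketch (ours, not part of the commit) of the gather that `CtxGatherFunc.forward` performs when `ctx_indices` covers only the first `comp_ctx_len` cache positions; the tensor shapes are illustrative assumptions.

```python
# Minimal sketch with assumed shapes; mirrors the indexing in CtxGatherFunc.forward above.
import torch

batch, heads, ctx_len, head_dim = 1, 2, 8, 4
comp_ctx_len = 4  # compute-context-length: attend over only the first 4 cache slots

data = torch.randn(batch, heads, ctx_len, head_dim)  # full KV cache
ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1).expand(batch, heads, comp_ctx_len)

batch_indices = torch.arange(batch).view(-1, 1, 1)
head_indices = torch.arange(heads).view(1, -1, 1)
gathered = data[batch_indices, head_indices, ctx_indices]  # CCL-sized slice of the cache

assert gathered.shape == (batch, heads, comp_ctx_len, head_dim)
```

The ONNX path builds the same result with `Expand`/`GatherND`, but now sizes the expanded index tensor from `comp_ctx_len` rather than from `ops.Shape(data)`, which is what lets the compiled graph run attention over a shorter, cheaper window.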

QEfficient/customop/ctx_scatter_gather_cb.py

Lines changed: 12 additions & 6 deletions
@@ -97,16 +97,20 @@ def symbolic(
 
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
 def CtxGatherCB(
-    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
+    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
 ) -> onnxscript.FLOAT:
     batch_size = ops.Gather(ops.Shape(batch_index), [0])
     num_heads = ops.Gather(ops.Shape(data), [1])
-    ctx_len = ops.Gather(ops.Shape(data), [2])
+    # Use the compute-context-length (CCL) instead of the full context length, so both this
+    # gather and the attention computed on its output cover only comp_ctx_len positions.
+    ctx_len = ops.Reshape(comp_ctx_len, [1])
 
     # Expanded shape to create indices
     zero = ops.Constant(value_ints=[0])
     one = ops.Constant(value_ints=[1])
-    exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
+    exp_shape = ops.Concat(
+        ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0
+    )
 
     # Create indices
     batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)
@@ -119,7 +123,7 @@ def CtxGatherCB(
 
 class CtxGatherFuncCB(torch.autograd.Function):
     @staticmethod
-    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = batch_index.view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]
@@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs):
         pass
 
     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
+    def symbolic(
+        g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
+    ) -> torch.Value:
+        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)
 
 
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
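The continuous-batching variant adds a `batch_index` indirection: each active decode slot reads its own row of the shared cache before the CCL-sized gather. A small sketch (ours, not from the commit; the shapes and the slot-to-row mapping are made up for illustration):

```python
# Sketch assuming a full_batch_size-4 cache and two active decode slots.
import torch

full_batch, heads, ctx_len, head_dim = 4, 2, 8, 3
comp_ctx_len = 4
data = torch.randn(full_batch, heads, ctx_len, head_dim)

batch_index = torch.tensor([[2], [0]])  # active slots read cache rows 2 and 0
ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1).expand(2, heads, comp_ctx_len)

batch_indices = batch_index.view(-1, 1, 1)  # as in CtxGatherFuncCB.forward
head_indices = torch.arange(heads).view(1, -1, 1)
out = data[batch_indices, head_indices, ctx_indices]
assert out.shape == (2, heads, comp_ctx_len, head_dim)
```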

QEfficient/generation/text_generation_inference.py

Lines changed: 78 additions & 0 deletions
@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
     prompts_txt_file_path: Optional[str] = None,
     device_id: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
+    comp_ctx_lengths: Optional[List[int]] = None,
+    prefill_ccl_len: Optional[int] = 1,
     enable_debug_logs: bool = False,
     stream: bool = True,
     write_io_dir: Optional[str] = None,
@@ -382,6 +384,8 @@ def cloud_ai_100_exec_kv(
         qpc_path=qpc_path,
         device_id=device_id,
         ctx_len=ctx_len,
+        comp_ctx_lengths=comp_ctx_lengths,
+        prefill_ccl_len=prefill_ccl_len,
         enable_debug_logs=enable_debug_logs,
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
@@ -424,6 +428,8 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths: Optional[List[int]] = None,
+        prefill_ccl_len: Optional[int] = 1,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
@@ -433,6 +439,8 @@ def __init__(
         sampling_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths = comp_ctx_lengths
+        self.prefill_ccl_len = prefill_ccl_len
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm
         self.return_pdfs = return_pdfs
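Putting the plumbing above together, a caller opts in by passing the two new keywords through `cloud_ai_100_exec_kv`. A hedged usage sketch: the tokenizer, QPC path, prompt, and bucket values below are placeholders, and the surrounding keywords are assumed to keep their existing behavior.

```python
# Hypothetical usage sketch: model name, qpc path, and CCL buckets are placeholders.
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
exec_info = cloud_ai_100_exec_kv(
    tokenizer,
    qpc_path="/path/to/qpc",            # placeholder QPC directory
    prompt=["My name is"],
    generation_len=128,
    comp_ctx_lengths=[256, 512, 1024],  # CCL buckets, ascending; presumably the largest matches ctx_len
    prefill_ccl_len=1,                  # how many leading buckets prefill may use
)
```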
@@ -791,7 +799,23 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
             batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
             inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
 
+        if self.comp_ctx_lengths is not None:
+            self.list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
+            prefill_ccl_id = 0
+            inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[prefill_ccl_id]
+
         for i in range(num_chunks):
+            if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths[prefill_ccl_id]:
+                prefill_ccl_id += 1
+                if prefill_ccl_id >= self.prefill_ccl_len:
+                    prefill_ccl_id = (
+                        (self.prefill_ccl_len - 1)
+                        if self.prefill_ccl_len != 0
+                        else min(prefill_ccl_id, len(self.comp_ctx_lengths) - 1)
+                    )
+
+            inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[prefill_ccl_id]
+
             chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][
                 :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
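The chunk loop above walks prefill through the CCL buckets: a chunk that would write past the current bucket's edge bumps `prefill_ccl_id`, and `prefill_ccl_len` caps how far prefill may climb. A standalone trace (ours) with assumed sizes:

```python
# Standalone trace with assumed values; mirrors the bucket-stepping logic above.
prefill_seq_len = 128
comp_ctx_lengths = [256, 512, 1024]  # CCL buckets
prefill_ccl_len = 2                  # prefill may climb through buckets 0 and 1 only
num_chunks = 6                       # prompt padded to 6 * 128 = 768 tokens

prefill_ccl_id = 0
for i in range(num_chunks):
    if (i + 1) * prefill_seq_len > comp_ctx_lengths[prefill_ccl_id]:
        prefill_ccl_id += 1
        if prefill_ccl_id >= prefill_ccl_len:
            prefill_ccl_id = (
                prefill_ccl_len - 1
                if prefill_ccl_len != 0
                else min(prefill_ccl_id, len(comp_ctx_lengths) - 1)
            )
    print(f"chunk {i}: bucket {prefill_ccl_id} ({comp_ctx_lengths[prefill_ccl_id]} positions)")
# chunks 0-1 run with the 256 bucket; chunks 2-5 are capped at the 512 bucket
```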
@@ -810,6 +834,18 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
             generation_len,
         )
 
+    def initialize_ccl(self, decode_inputs):
+        max_ccl_id = len(self.comp_ctx_lengths) - 1
+        max_position_id = np.max(decode_inputs["position_ids"])
+        ccl_id_initial = self.prefill_ccl_len
+        ccl_id = ccl_id_initial
+        for i in range(ccl_id_initial, len(self.comp_ctx_lengths)):
+            if max_position_id < self.comp_ctx_lengths[i]:
+                ccl_id = i
+                break
+
+        return ccl_id, max_ccl_id
+
     def run_continuous_batching_decode(self, prompt_queue, generation_len):
         """
         Runs continuous batching decode for the given prompt queue and generation length.
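`initialize_ccl` picks the first decode-eligible bucket (index `prefill_ccl_len` onward) large enough to hold the longest position seen so far. A sketch (ours) with assumed values:

```python
# Sketch with assumed values; same selection rule as initialize_ccl above.
import numpy as np

comp_ctx_lengths = [256, 512, 1024]
prefill_ccl_len = 1                      # buckets below this index are prefill-only
position_ids = np.array([[300], [700]])  # next positions for two sequences after prefill

max_ccl_id = len(comp_ctx_lengths) - 1
max_position_id = np.max(position_ids)   # 700
ccl_id = prefill_ccl_len
for i in range(prefill_ccl_len, len(comp_ctx_lengths)):
    if max_position_id < comp_ctx_lengths[i]:
        ccl_id = i
        break

print(ccl_id, max_ccl_id)  # 2 2 -> decoding starts with the 1024 bucket
```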
@@ -841,6 +877,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
         # Prepare decode inputs.
         decode_inputs = self.prepare_decode_inputs()
 
+        if self.comp_ctx_lengths is not None:
+            ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
+
         while prompt_queue or current_decode_ongoing.any():
             outputs = self._session.run(decode_inputs)
 
@@ -878,6 +918,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                        batch_id_map[decode_batch_id]
                    ]
 
+                    if self.comp_ctx_lengths is not None:
+                        # Recalculate ccl_id based on position_ids:
+                        # determine the maximum value of position_ids across all batch elements
+                        max_position_id = np.max(decode_inputs["position_ids"])
+
+                        # Update ccl_id and comp_ctx_lengths based on the maximum position id
+                        ccl_id_initial = self.prefill_ccl_len
+                        ccl_id = ccl_id_initial
+                        for i in range(ccl_id_initial, len(self.comp_ctx_lengths)):
+                            if max_position_id < self.comp_ctx_lengths[i]:
+                                ccl_id = i
+                                break
+                        decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
+
                 else:
                     current_decode_ongoing[decode_batch_id] = False
             else:
@@ -890,6 +944,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
             if self.include_sampler:
                 decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
 
+            if self.comp_ctx_lengths is not None:
+                # Update ccl_id and comp_ctx_lengths based on the maximum position id
+                if decode_inputs["position_ids"][decode_batch_id, -1] >= self.comp_ctx_lengths[ccl_id] - 1:
+                    ccl_id = min(ccl_id + 1, max_ccl_id)
+                    decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
+
             generated_id_current_index[decode_batch_id] += 1
 
         return decode_pause_time
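During decode, both paths grow the window the same way: once a sequence's position reaches the last usable slot of the current bucket, the bucket index is promoted, clamped at the largest bucket. A compact trace (ours, with assumed values):

```python
# Sketch with assumed values; same promotion rule as the decode loops above.
comp_ctx_lengths = [256, 512, 1024]
max_ccl_id = len(comp_ctx_lengths) - 1

ccl_id = 1  # currently decoding against the 512 bucket
for position in range(509, 514):
    if position >= comp_ctx_lengths[ccl_id] - 1:
        ccl_id = min(ccl_id + 1, max_ccl_id)
    print(position, comp_ctx_lengths[ccl_id])
# positions 509-510 stay at 512; from position 511 onward the 1024 bucket is used
```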
@@ -914,7 +974,18 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None):
         self._session.set_buffers({"logits": logits_out_placeholder})
         finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
         num_token = 0
+
+        if self.comp_ctx_lengths is not None:
+            ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
+
+        cache_index = np.max(decode_inputs["position_ids"])
         for num_token in range(1, generation_len):
+            if self.comp_ctx_lengths is not None:
+                if cache_index >= self.comp_ctx_lengths[ccl_id] - 1:
+                    ccl_id = min(ccl_id + 1, max_ccl_id)
+                    decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
+
             if streamer:
                 streamer.put(decode_inputs["input_ids"][0])
             outputs = self._session.run(decode_inputs)
@@ -926,6 +997,7 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None):
             # Prepare inputs for next iteration
             decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
             decode_inputs["position_ids"][:, -1] += 1
+            cache_index += 1
             self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
             if self.include_sampler:
@@ -975,6 +1047,8 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths: Optional[List[int]] = None,
+        prefill_ccl_len: Optional[int] = 1,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
@@ -988,6 +1062,8 @@ def __init__(
             qpc_path=qpc_path,
             full_batch_size=full_batch_size,
             ctx_len=ctx_len,
+            comp_ctx_lengths=comp_ctx_lengths,
+            prefill_ccl_len=prefill_ccl_len,
             device_id=device_id,
             enable_debug_logs=enable_debug_logs,
             write_io_dir=write_io_dir,
@@ -999,6 +1075,8 @@ def __init__(
         self._full_batch_size = self._qaic_model.full_batch_size
         self._tokenizer = self._qaic_model.tokenizer
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths = comp_ctx_lengths
+        self.prefill_ccl_len = prefill_ccl_len
         self._perf_metrics = None
         self._prompt_queue = None
         self._text_streamer = None
