Commit 5e44f81

adding Context Length Specialization (CCL)

Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>

1 parent 6aaa75a

32 files changed: +2062 -210 lines

QEfficient/customop/ctx_scatter_gather.py

Lines changed: 11 additions & 5 deletions
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:


 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
-def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
-    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
+def CtxGather(
+    data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
+) -> onnxscript.FLOAT:
+    # Create a shape tensor based on comp_ctx_len
+    shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)
+
+    # Directly use the shape tensor without validation
+    ctx_indices = ops.Expand(ctx_indices, shape_tensor)
     ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
     return ops.GatherND(data, ctx_indices, batch_dims=2)

@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
     """

     @staticmethod
-    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]

@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
         pass

     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
+    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
+        return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
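For orientation, here is a minimal eager-mode sketch (not part of the commit; shapes and values are illustrative) of what the updated CtxGather computes: with the new comp_ctx_len input, the gather covers only the first comp_ctx_len cache positions instead of the full context length.

import torch

batch, heads, ctx_len, head_dim = 2, 4, 8, 16
comp_ctx_len = 4  # compute context length (CCL), smaller than ctx_len

data = torch.randn(batch, heads, ctx_len, head_dim)           # KV cache block
ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1)       # indices limited to CCL
ctx_indices = ctx_indices.expand(batch, heads, comp_ctx_len)  # shape driven by comp_ctx_len

# Mirrors CtxGatherFunc.forward: advanced indexing over batch, head, and position
batch_indices = torch.arange(batch).view(-1, 1, 1)
head_indices = torch.arange(heads).view(1, -1, 1)
gathered = data[batch_indices, head_indices, ctx_indices]
assert gathered.shape == (batch, heads, comp_ctx_len, head_dim)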

QEfficient/customop/ctx_scatter_gather_cb.py

Lines changed: 12 additions & 6 deletions
@@ -97,16 +97,20 @@ def symbolic(

 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
 def CtxGatherCB(
-    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
+    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
 ) -> onnxscript.FLOAT:
     batch_size = ops.Gather(ops.Shape(batch_index), [0])
     num_heads = ops.Gather(ops.Shape(data), [1])
-    ctx_len = ops.Gather(ops.Shape(data), [2])
+    # Use the compute-context-length (CCL) instead of the full context length, so the gather and the attention computations that follow both run over CCL positions.
+    ctx_len = ops.Reshape(comp_ctx_len, [1])

     # Expanded shape to create indices
     zero = ops.Constant(value_ints=[0])
     one = ops.Constant(value_ints=[1])
-    exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
+    # exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
+    exp_shape = ops.Concat(
+        ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0
+    )

     # Create indices
     batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)

@@ -119,7 +123,7 @@ def CtxGatherCB(

 class CtxGatherFuncCB(torch.autograd.Function):
     @staticmethod
-    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = batch_index.view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]

@@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs):
         pass

     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
+    def symbolic(
+        g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
+    ) -> torch.Value:
+        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)


 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
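Likewise, a hedged sketch of the continuous-batching variant (illustrative shapes; not from the commit): batch_index selects which cache rows the currently scheduled sequences read, while ctx_indices is again limited to comp_ctx_len positions.

import torch

cache_batch, heads, ctx_len, head_dim = 4, 2, 8, 16
comp_ctx_len = 4
data = torch.randn(cache_batch, heads, ctx_len, head_dim)

batch_index = torch.tensor([[2], [0]])  # two active slots mapped to cache rows 2 and 0
ctx_indices = torch.arange(comp_ctx_len).expand(2, heads, comp_ctx_len)

# Mirrors CtxGatherFuncCB.forward
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(heads).view(1, -1, 1)
gathered = data[batch_indices, head_indices, ctx_indices]
assert gathered.shape == (2, heads, comp_ctx_len, head_dim)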

QEfficient/generation/text_generation_inference.py

Lines changed: 74 additions & 0 deletions
@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
     prompts_txt_file_path: Optional[str] = None,
     device_id: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
+    comp_ctx_lengths: Optional[List[int]] = None,
+    prefill_ccl_len: Optional[int] = 1,
     enable_debug_logs: bool = False,
     stream: bool = True,
     write_io_dir: Optional[str] = None,
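A hedged usage sketch for the new parameters. Only comp_ctx_lengths and prefill_ccl_len are new in this commit; the tokenizer, qpc_path, and prompt arguments are assumed from the surrounding signature, and the model name, QPC path, and values below are placeholders.

from transformers import AutoTokenizer
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
exec_info = cloud_ai_100_exec_kv(
    tokenizer,
    qpc_path="path/to/qpc",                    # placeholder: QPC compiled with CCL specializations
    prompt="My name is",
    generation_len=128,
    comp_ctx_lengths=[512, 1024, 2048, 4096],  # CCL ladder; the last entry covers the full context
    prefill_ccl_len=1,                         # the first entry of the ladder is reserved for prefill
)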
@@ -382,6 +384,8 @@ def cloud_ai_100_exec_kv(
         qpc_path=qpc_path,
         device_id=device_id,
         ctx_len=ctx_len,
+        comp_ctx_lengths=comp_ctx_lengths,
+        prefill_ccl_len=prefill_ccl_len,
         enable_debug_logs=enable_debug_logs,
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,

@@ -424,6 +428,8 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths: Optional[List[int]] = None,
+        prefill_ccl_len: Optional[int] = 1,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,

@@ -433,6 +439,8 @@ def __init__(
         sampling_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths = comp_ctx_lengths
+        self.prefill_ccl_len = prefill_ccl_len
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm
         self.return_pdfs = return_pdfs
@@ -791,7 +799,18 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
             batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
             inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

+        if self.comp_ctx_lengths is not None:
+            self.list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
+            prefill_ccl_id = 0
+            inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[prefill_ccl_id]
+
         for i in range(num_chunks):
+            if self.comp_ctx_lengths is not None and (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths[prefill_ccl_id]:
+                prefill_ccl_id += 1
+                if prefill_ccl_id >= self.prefill_ccl_len:
+                    prefill_ccl_id = self.prefill_ccl_len - 1
+                inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[prefill_ccl_id]
+
             chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][
                 :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
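Note that each comp_ctx_lengths entry is materialized as a NumPy zeros buffer whose length encodes the chosen CCL, and chunked prefill walks up the prefill buckets as the processed tokens outgrow them. A standalone sketch of that walk with illustrative numbers (the min() form condenses the increment-and-clamp logic):

comp_ctx_lengths = [512, 1024, 2048, 4096]
prefill_ccl_len = 2       # the first two entries are prefill buckets
prefill_seq_len = 128     # prefill chunk size
num_chunks = 8            # a 1024-token prompt

prefill_ccl_id = 0
for i in range(num_chunks):
    # Advance to the next prefill bucket once the chunks processed so far exceed it
    if (i + 1) * prefill_seq_len > comp_ctx_lengths[prefill_ccl_id]:
        prefill_ccl_id = min(prefill_ccl_id + 1, prefill_ccl_len - 1)
    print("chunk", i, "-> CCL", comp_ctx_lengths[prefill_ccl_id])
# chunks 0-3 run with CCL 512; chunks 4-7 run with CCL 1024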
@@ -810,6 +829,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
             generation_len,
         )

+    def initialize_ccl(self, decode_inputs):
+        max_ccl_id = len(self.comp_ctx_lengths) - 1
+        max_position_id = np.max(decode_inputs["position_ids"])
+        ccl_id_initial = self.prefill_ccl_len
+        ccl_id = ccl_id_initial
+        for i in range(ccl_id_initial, len(self.comp_ctx_lengths)):
+            if max_position_id < self.comp_ctx_lengths[i]:
+                ccl_id = i
+                break
+
+        print(f"Decode CCL: {self.comp_ctx_lengths[ccl_id]}")
+        return ccl_id, max_ccl_id
+
     def run_continuous_batching_decode(self, prompt_queue, generation_len):
         """
         Runs continuous batching decode for the given prompt queue and generation length.
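A worked example of the initialize_ccl scan (illustrative values, mirroring the logic above): decode starts at the first bucket after the prefill-only entries whose CCL still covers the largest position id.

comp_ctx_lengths = [512, 1024, 2048, 4096]
prefill_ccl_len = 1            # entry 0 is reserved for prefill
max_position_id = 1500         # largest position id in the decode batch

ccl_id = prefill_ccl_len       # default: the first decode bucket
for i in range(prefill_ccl_len, len(comp_ctx_lengths)):
    if max_position_id < comp_ctx_lengths[i]:
        ccl_id = i
        break
max_ccl_id = len(comp_ctx_lengths) - 1

print(ccl_id, comp_ctx_lengths[ccl_id])  # -> 2 2048, the smallest bucket above 1500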
@@ -841,6 +873,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
         # Prepare decode inputs.
         decode_inputs = self.prepare_decode_inputs()

+        if self.comp_ctx_lengths is not None:
+            ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
+
         while prompt_queue or current_decode_ongoing.any():
             outputs = self._session.run(decode_inputs)
@@ -878,6 +914,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                         batch_id_map[decode_batch_id]
                     ]

+                    if self.comp_ctx_lengths is not None:
+                        ### Recalculate ccl_id based on position ids ###
+                        # Determine the maximum value of position_ids across all batch elements
+                        max_position_id = np.max(decode_inputs["position_ids"])
+
+                        # Update ccl_id and comp_ctx_lengths based on the maximum position id
+                        ccl_id_initial = self.prefill_ccl_len
+                        ccl_id = ccl_id_initial
+                        for i in range(ccl_id_initial, len(self.comp_ctx_lengths)):
+                            if max_position_id < self.comp_ctx_lengths[i]:
+                                ccl_id = i
+                                break
+                        decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
+
                 else:
                     current_decode_ongoing[decode_batch_id] = False
             else:

@@ -890,6 +940,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                 if self.include_sampler:
                     decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

+                if self.comp_ctx_lengths is not None:
+                    # Update ccl_id and comp_ctx_lengths based on the maximum position id
+                    if decode_inputs["position_ids"][decode_batch_id, -1] >= self.comp_ctx_lengths[ccl_id] - 1:
+                        ccl_id = min(ccl_id + 1, max_ccl_id)
+                        decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
+
                 generated_id_current_index[decode_batch_id] += 1

         return decode_pause_time
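A condensed sketch of the decode-time bucket bump above (illustrative values): once a sequence's position reaches the current CCL boundary, decoding steps up to the next bucket, clamped at the largest one, so attention keeps covering the growing cache.

comp_ctx_lengths = [512, 1024, 2048]
max_ccl_id = len(comp_ctx_lengths) - 1
ccl_id = 1  # currently decoding with CCL 1024

for position_id in range(1020, 1026):  # decode steps crossing the 1024 boundary
    if position_id >= comp_ctx_lengths[ccl_id] - 1:
        ccl_id = min(ccl_id + 1, max_ccl_id)
    print(position_id, "-> CCL", comp_ctx_lengths[ccl_id])
# positions 1020-1022 stay at CCL 1024; from position 1023 onward decoding uses CCL 2048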
@@ -914,7 +970,18 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None):
         self._session.set_buffers({"logits": logits_out_placeholder})
         finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
         num_token = 0
+
+        if self.comp_ctx_lengths is not None:
+            ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
+
+        cache_index = np.max(decode_inputs["position_ids"])
         for num_token in range(1, generation_len):
+            if self.comp_ctx_lengths is not None:
+                if cache_index >= self.comp_ctx_lengths[ccl_id] - 1:
+                    ccl_id = min(ccl_id + 1, max_ccl_id)
+                    decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths[ccl_id]
+
             if streamer:
                 streamer.put(decode_inputs["input_ids"][0])
             outputs = self._session.run(decode_inputs)

@@ -926,6 +993,7 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None):
             # Prepare inputs for next iteration
             decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
             decode_inputs["position_ids"][:, -1] += 1
+            cache_index += 1
             self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
             if self.include_sampler:
@@ -975,6 +1043,8 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths: Optional[List[int]] = None,
+        prefill_ccl_len: Optional[int] = 1,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,

@@ -988,6 +1058,8 @@ def __init__(
             qpc_path=qpc_path,
             full_batch_size=full_batch_size,
             ctx_len=ctx_len,
+            comp_ctx_lengths=comp_ctx_lengths,
+            prefill_ccl_len=prefill_ccl_len,
             device_id=device_id,
             enable_debug_logs=enable_debug_logs,
             write_io_dir=write_io_dir,

@@ -999,6 +1071,8 @@ def __init__(
         self._full_batch_size = self._qaic_model.full_batch_size
         self._tokenizer = self._qaic_model.tokenizer
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths = comp_ctx_lengths
+        self.prefill_ccl_len = prefill_ccl_len
         self._perf_metrics = None
         self._prompt_queue = None
         self._text_streamer = None
