
Commit bdb2dee

Merge branch 'quic:main' into CCL-main
2 parents 6305b99 + e592774 · commit bdb2dee

File tree: 9 files changed, +295 -42 lines


QEfficient/cloud/export.py

Lines changed: 16 additions & 3 deletions
@@ -11,18 +11,20 @@

 from QEfficient.base.common import QEFFCommonLoader
 from QEfficient.utils import check_and_assign_cache_dir
+from QEfficient.utils.custom_yaml import generate_custom_io
 from QEfficient.utils.logging_utils import logger

 # Specifically for Docker images.
 ROOT_DIR = os.path.dirname(os.path.abspath(""))


-def get_onnx_model_path(
+def get_onnx_path_and_setup_customIO(
     model_name: str,
     cache_dir: Optional[str] = None,
     hf_token: Optional[str] = None,
     full_batch_size: Optional[int] = None,
     local_model_dir: Optional[str] = None,
+    mxint8_kv_cache: Optional[int] = False,
 ):
     """
     Exports the PyTorch model to ONNX format if a pre-exported file is not found,

@@ -63,6 +65,9 @@ def get_onnx_model_path(
     )
     onnx_model_path = qeff_model.export()
     logger.info(f"Generated onnx_path: {onnx_model_path}")
+
+    # Generating Custom IO for the compile.
+    generate_custom_io(qeff_model, mxint8_kv_cache=mxint8_kv_cache)
     return onnx_model_path


@@ -72,13 +77,14 @@ def main(
     hf_token: Optional[str] = None,
     local_model_dir: Optional[str] = None,
     full_batch_size: Optional[int] = None,
+    mxint8_kv_cache: Optional[bool] = False,
 ) -> None:
     """
     Main function for the QEfficient ONNX export CLI application.

     This function serves as the entry point for exporting a PyTorch model, loaded
     via QEFFCommonLoader, to the ONNX format. It prepares the necessary
-    paths and calls `get_onnx_model_path`.
+    paths and calls `get_onnx_path_and_setup_customIO`.

     Parameters
     ----------

@@ -106,12 +112,13 @@ def main(

     """
     cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir)
-    get_onnx_model_path(
+    get_onnx_path_and_setup_customIO(
         model_name=model_name,
         cache_dir=cache_dir,
         hf_token=hf_token,
         full_batch_size=full_batch_size,
         local_model_dir=local_model_dir,
+        mxint8_kv_cache=mxint8_kv_cache,
     )


@@ -137,5 +144,11 @@ def main(
         default=None,
         help="Set full batch size to enable continuous batching mode, default is None",
     )
+    parser.add_argument(
+        "--mxint8_kv_cache",
+        "--mxint8-kv-cache",
+        required=False,
+        help="Compress Present/Past KV to MXINT8 using CustomIO config, default is False",
+    )
     args = parser.parse_args()
     main(**args.__dict__)
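
A minimal usage sketch, not part of the commit: the renamed helper now both exports the ONNX graph and writes a Custom IO YAML for the later compile step (via `generate_custom_io`). The model id below is a placeholder; any model supported by `QEFFCommonLoader` should work.

from QEfficient.cloud.export import get_onnx_path_and_setup_customIO

# Placeholder model id, for illustration only.
onnx_path = get_onnx_path_and_setup_customIO(
    model_name="gpt2",
    mxint8_kv_cache=True,  # also emit a Custom IO config that compresses the KV cache to MXINT8
)
print(onnx_path)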

QEfficient/cloud/infer.py

Lines changed: 4 additions & 0 deletions
@@ -248,6 +248,8 @@ def main(

     image_path = kwargs.pop("image_path", None)
     image_url = kwargs.pop("image_url", None)
+    iteration = kwargs.pop("iteration", 1)
+    automation = kwargs.pop("automation", False)

     config = qeff_model.model.config
     architecture = config.architectures[0] if config.architectures else None

@@ -310,6 +312,8 @@ def main(
         device_id=device_group,
         prompts_txt_file_path=prompts_txt_file_path,
         generation_len=generation_len,
+        iteration=iteration,
+        automation=automation,
     )


QEfficient/compile/compile_helper.py

Lines changed: 13 additions & 3 deletions
@@ -270,6 +270,7 @@ def compile(
     This method will be removed soon; use `QEFFAutoModelForCausalLM.compile` instead.

     """
+
     if full_batch_size and batch_size != 1:
         raise ValueError("Only either batch_size or full_batch_size should be greater than one")

@@ -284,11 +285,20 @@ def compile(
         full_batch_size=full_batch_size,
     )

-    # Select the customIO config based on the mx flag.
-    custom_io_file_name = "custom_io_int8.yaml" if mxint8 else "custom_io_fp16.yaml"
+    dtype_suffix = "int8" if mxint8 else "fp16"
+    source_path = f"./custom_io_{dtype_suffix}.yaml"
+    destination_path = os.path.join(os.path.dirname(qpc_path), f"custom_io_{dtype_suffix}.yaml")
+
+    # Move the custom YAML file to the cache/qeff_model directory
+    try:
+        shutil.move(source_path, destination_path)
+        print(f"Successfully moved '{source_path}' to '{destination_path}'.")
+    except Exception as e:
+        print(f"Error while moving file '{source_path}': {e}")

+    custom_io_file_name = f"custom_io_{dtype_suffix}.yaml"
     if custom_io_file_path is None:
-        custom_io_file_path = os.path.join(os.path.dirname(onnx_path), custom_io_file_name)
+        custom_io_file_path = os.path.join(os.path.dirname(qpc_path), custom_io_file_name)

     if not os.path.isfile(custom_io_file_path):
         raise FileNotFoundError(
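
For clarity, a standalone sketch (illustration only, not code from this commit) of where the Custom IO config is now expected: the YAML generated at export time is moved next to the QPC output, and the default `custom_io_file_path` is resolved relative to `qpc_path` rather than `onnx_path`. The paths below are hypothetical.

import os

# Hypothetical QPC output directory.
qpc_path = "/root/.cache/qeff_models/my_model/qpc_16cores_1bs/qpcs"
mxint8 = True

dtype_suffix = "int8" if mxint8 else "fp16"
# The default Custom IO config is now looked up alongside the QPC directory.
custom_io_file_path = os.path.join(os.path.dirname(qpc_path), f"custom_io_{dtype_suffix}.yaml")
print(custom_io_file_path)  # /root/.cache/qeff_models/my_model/qpc_16cores_1bs/custom_io_int8.yaml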

QEfficient/generation/text_generation_inference.py

Lines changed: 41 additions & 30 deletions
@@ -324,6 +324,7 @@ def cloud_ai_100_exec_kv(
     stream: bool = True,
     write_io_dir: Optional[str] = None,
     automation=False,
+    iteration: int = 1,
     prompt_to_lora_id_mapping: Optional[List[int]] = None,
     is_tlm: bool = False,
     include_sampler: bool = False,

@@ -348,6 +349,7 @@ def cloud_ai_100_exec_kv(
     :stream (bool): If True, enable streamer, which returns tokens one by one as the model generates them. ``Defaults to True``.
     :Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
     :automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
+    :iteration (int): Number of iterations to run the inference. ``Defaults to 1``.
     :prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
     :include_sampler (bool, default=False): Enable/Disable sampling of next tokens.
     :return_pdfs (bool, default=False): Return probability distributions along with sampled

@@ -394,30 +396,34 @@ def cloud_ai_100_exec_kv(
         return_pdfs=return_pdfs,
         sampling_params=sampling_params,
     )
-    if full_batch_size is None:
-        exec_info = [
-            generate_text.generate(prompt[i : i + batch_size], generation_len, stream, prompt_to_lora_id_mapping)
-            for i in range(0, len(prompt), batch_size)
-        ]
-        prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info])
-        decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info])
-        total_perf = np.average([info.perf_metrics.total_perf for info in exec_info])
-        total_time = np.average([info.perf_metrics.total_time for info in exec_info])
-        generated_texts = [info.generated_texts for info in exec_info]
-        generated_ids = [info.generated_ids for info in exec_info]
-
-        exec_info = CloudAI100ExecInfo(
-            batch_size=batch_size,
-            generated_texts=generated_texts,
-            generated_ids=generated_ids,
-            perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
-        )
-    else:
-        exec_info = generate_text.generate(
-            prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping
-        )

-    print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
+    for _ in range(0, int(iteration)):
+        if full_batch_size is None:
+            exec_info = [
+                generate_text.generate(prompt[i : i + batch_size], generation_len, stream, prompt_to_lora_id_mapping)
+                for i in range(0, len(prompt), batch_size)
+            ]
+            prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info])
+            decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info])
+            total_perf = np.average([info.perf_metrics.total_perf for info in exec_info])
+            total_time = np.average([info.perf_metrics.total_time for info in exec_info])
+            generated_texts = [info.generated_texts for info in exec_info]
+            generated_ids = [info.generated_ids for info in exec_info]
+
+            exec_info = CloudAI100ExecInfo(
+                batch_size=batch_size,
+                generated_texts=generated_texts,
+                generated_ids=generated_ids,
+                perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
+            )
+        else:
+            exec_info = generate_text.generate(
+                prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping
+            )
+
+        print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
+
+    # TODO: Need to handle the case where exec_info if given for n iterations
     return exec_info

@@ -951,7 +957,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

         return decode_pause_time

-    def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None):
+    def run_decode(
+        self, decode_inputs, generation_len, automation, streamer: Optional[transformers.TextStreamer] = None
+    ):
         """
         Default method for running decode. Executes the decoding process for a given set of inputs and a specified generation length.

@@ -1000,11 +1008,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
             if self.include_sampler:
                 decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

-            if finished_sequences.all():
+            if finished_sequences.all() and not automation:
                 break
         return num_token

-    def generate_decode_stream(self, decode_inputs, generation_len):
+    def generate_decode_stream(self, decode_inputs, generation_len, automation):
        """
        Generator method for yielding decode tokens. Executes the decoding process for a given set of inputs and a specified generation length.

@@ -1032,7 +1040,7 @@ def generate_decode_stream(self, decode_inputs, generation_len):
             self.generated_ids[:, num_token] = decode_inputs["input_ids"].squeeze(1)
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id

-            if finished_sequences.all():
+            if finished_sequences.all() and not automation:
                 break
         yield decode_inputs["input_ids"]  # yield the last token

@@ -1115,6 +1123,7 @@ def _regular_model_execution(
         prompt: List[str],
         generation_len: Optional[int] = None,
         stream: Optional[bool] = True,
+        automation: Optional[bool] = False,
         prompt_to_lora_id_mapping: Optional[List[int]] = None,
     ):
         """

@@ -1142,7 +1151,7 @@ def _regular_model_execution(
         decode_inputs = self._qaic_model.prepare_decode_inputs()

         loop_start = perf_counter()  # Start decode loop timer
-        num_token = self._qaic_model.run_decode(decode_inputs, generation_len, self._text_streamer)
+        num_token = self._qaic_model.run_decode(decode_inputs, generation_len, automation, self._text_streamer)
         end = perf_counter()
         generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True)

@@ -1196,6 +1205,7 @@ def generate_stream_tokens(
         self,
         prompt: List[str],
         generation_len: Optional[int] = None,
+        automation: Optional[bool] = False,
         prompt_to_lora_id_mapping: Optional[List[int]] = None,
     ):
         """

@@ -1225,7 +1235,7 @@ def generate_stream_tokens(

         loop_start = perf_counter()  # Start decode loop timer
         num_token = 0
-        for token_id in self._qaic_model.generate_decode_stream(decode_inputs, generation_len):
+        for token_id in self._qaic_model.generate_decode_stream(decode_inputs, generation_len, automation):
             decoded_tokens = []
             for idx in range(self._qaic_model.batch_size):
                 decoded_tokens.append(self._tokenizer.decode(token_id[idx], skip_special_tokens=True))

@@ -1244,6 +1254,7 @@ def generate(
         prompt: List[str],
         generation_len: Optional[int] = None,
         stream: bool = True,
+        automation: Optional[bool] = False,
         prompt_to_lora_id_mapping: Optional[List[int]] = None,
     ):
         """

@@ -1267,7 +1278,7 @@ def generate(
         if stream:
             print("\nPrompt : " + prompt[0] + "\nCompletion :", flush=True, end="")
         perf_metrics, generated_texts = self._regular_model_execution(
-            prompt, generation_len, stream, prompt_to_lora_id_mapping
+            prompt, generation_len, stream, automation, prompt_to_lora_id_mapping
         )

         if stream:
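
A hedged usage sketch of the two parameters, not taken from the commit: `iteration=N` repeats the whole prefill-plus-decode loop N times and prints latency stats on each pass, while `automation=True` keeps the decode loops from breaking early at EOS. Parameter names other than `iteration` and `automation` are assumed from the existing `cloud_ai_100_exec_kv` signature, and the tokenizer id and QPC path are placeholders. Per the in-code TODO, the returned `exec_info` currently reflects only the last iteration.

from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer id
exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="/path/to/qpcs",  # placeholder compiled QPC directory
    prompt=["My name is"],
    generation_len=32,
    iteration=3,      # new: run the generation loop three times
    automation=True,  # decode runs the full generation_len and stats are printed
)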

QEfficient/transformers/models/llama/modeling_llama.py

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ def forward(
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output, **kwargs)
         return attn_output, attn_weights, past_key_value


QEfficient/transformers/models/modeling_auto.py

Lines changed: 2 additions & 0 deletions
@@ -2914,6 +2914,8 @@ def generate(
             comp_ctx_lengths_decode=self.comp_ctx_lengths_decode,
             device_id=device_id,
             generation_len=generation_len,
+            automation=kwargs.pop("automation", False),
+            iteration=kwargs.pop("iteration", 1),
             is_tlm=self.is_tlm,
             **kwargs,
         )
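
At the high-level API the two options are simply popped from `**kwargs` and forwarded to `cloud_ai_100_exec_kv`. A hedged end-to-end sketch follows; the model id, compile options, and the `tokenizer`/`prompts` argument names are assumptions based on the library's usual examples, not values from this commit.

from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model id
model.compile(num_cores=14)  # assumed compile options
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.generate(
    tokenizer=tokenizer,
    prompts=["My name is"],
    iteration=2,      # forwarded via kwargs.pop("iteration", 1)
    automation=True,  # forwarded via kwargs.pop("automation", False)
)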
