|
16 | 16 | ## - The second list, comp_ctx_lengths_decode, is used during decoding. At each decode step, the position_id or cache index determines which compute-context-length in the list is active: decoding starts from the smallest compute-context-length that covers the input prompt length and moves to the next entry whenever the cache index passes the current compute-context-length. ## |
17 | 17 |
|
18 | 18 | ctx_len = 1024 |
19 | | -comp_ctx_lengths_prefill = [256] |
20 | | -comp_ctx_lengths_decode = [512, ctx_len] |
| 19 | +comp_ctx_lengths_prefill = [256] # None |
| 20 | +comp_ctx_lengths_decode = [ctx_len] # None |
21 | 21 |
|
22 | 22 | # model_name = "google/gemma-7b" |
23 | 23 | # model_name = "google/gemma-2-2b" |
|
27 | 27 | # model_name = "microsoft/phi-1_5" |
28 | 28 | # model_name = "microsoft/Phi-3-mini-4k-instruct" |
29 | 29 | # model_name = "Qwen/Qwen2.5-7B-Instruct" |
30 | | -model_name = "meta-llama/Llama-3.2-1B" |
| 30 | +# model_name = "meta-llama/Llama-3.2-1B" |
31 | 31 | # model_name = "Qwen/Qwen3-1.7B" |
32 | 32 | # model_name = "allenai/OLMo-2-0425-1B" |
33 | | -# model_name = "ibm-granite/granite-3.3-2b-base" |
| 33 | +model_name = "ibm-granite/granite-3.3-2b-base" |
| 34 | +# model_name = "ibm-granite/granite-3.2-8b-instruct" |
34 | 35 | # model_name = "meta-llama/Llama-3.3-70B-Instruct" |
35 | 36 | # model_name = "Salesforce/codegen-350M-mono" |
36 | 37 | # model_name = "tiiuae/falcon-7b-instruct" |
37 | 38 | # model_name = "openai-community/gpt2" |
38 | 39 | # model_name = "EleutherAI/gpt-j-6b" |
39 | | -# model_name = "EleutherAI/gpt-j-6b" |
40 | 40 |
|
41 | 41 | model = QEFFAutoModelForCausalLM.from_pretrained( |
42 | 42 | model_name, |
43 | 43 | continuous_batching=True, |
| 44 | + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, |
| 45 | + comp_ctx_lengths_decode=comp_ctx_lengths_decode, |
| 46 | + ctx_len=ctx_len, |
44 | 47 | ) |
45 | 48 |
|
46 | 49 | # Model compilation for either continuous or static batching; full_batch_size is required for continuous batching. |
47 | 50 | model.compile( |
48 | 51 | prefill_seq_len=128, |
49 | 52 | ctx_len=ctx_len, |
50 | 53 | num_cores=16, |
51 | | - num_devices=1, |
| 54 | + num_devices=4, |
52 | 55 | full_batch_size=1, |
53 | 56 | mxint8_kv_cache=True, |
54 | 57 | mxfp6_matmul=True, |
|
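The comment above describes how decoding steps through comp_ctx_lengths_decode. A minimal sketch of that selection logic is shown below; it is illustrative only (the helper name select_ccl and the exact switching rule are assumptions, not QEfficient's actual implementation):

```python
# Illustrative sketch: how a decode step might pick its compute-context-length
# from comp_ctx_lengths_decode based on the current cache index. This is an
# assumption for explanation, not the library's internal code.
def select_ccl(cache_index: int, comp_ctx_lengths_decode: list[int]) -> int:
    """Return the smallest compute-context-length that still covers cache_index."""
    for ccl in sorted(comp_ctx_lengths_decode):
        if cache_index < ccl:
            return ccl
    # Past the largest entry: stay at the full context length.
    return max(comp_ctx_lengths_decode)

# Example: with comp_ctx_lengths_decode = [512, 1024], decoding runs with a
# 512-token compute window and switches to 1024 once the cache index reaches 512.
print(select_ccl(100, [512, 1024]))  # 512
print(select_ccl(600, [512, 1024]))  # 1024
```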