Skip to content

Commit ebddfee

Browse files
authored
update scout example (#2310)
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
1 parent 065e4d0 commit ebddfee

File tree

4 files changed

+106
-7
lines changed

4 files changed

+106
-7
lines changed

examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ docker run -d --gpus all -v ... --shm-size=100g --name llama4 -it nvcr.io/nvidia
1111
docker exec -it llama4 bash
1212
git clone https://github.com/intel/neural-compressor.git
1313
cd neural-compressor/examples/3.x_api/pytorch/multimodal-modeling/quantization/auto_round/llama4
14+
# Use `INC_PT_ONLY=1 pip install git+https://github.com/intel/neural-compressor.git@v3.6rc` for the latest updates before neural-compressor v3.6 release
15+
pip install neural-compressor-pt==3.6
16+
# Use `pip install git+https://github.com/intel/auto-round.git@v0.8.0rc2` for the latest updates before auto-round v0.8.0 release
17+
pip install auto-round==0.8.0
1418
bash setup.sh
1519
```
1620

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
17+
import torch
18+
19+
torch.use_deterministic_algorithms(True, warn_only=True)
20+
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, AutoProcessor
21+
from neural_compressor.torch.quantization import (
22+
AutoRoundConfig,
23+
convert,
24+
prepare,
25+
)
26+
27+
28+
class BasicArgumentParser(argparse.ArgumentParser):
29+
def __init__(self, *args, **kwargs):
30+
super().__init__(*args, **kwargs)
31+
self.add_argument("--model", "--model_name", "--model_name_or_path",
32+
help="model name or path")
33+
34+
self.add_argument('--scheme', default="MXFP4", type=str,
35+
help="quantizaion scheme.")
36+
37+
self.add_argument("--device", "--devices", default="auto", type=str,
38+
help="the device to be used for tuning. The default is set to auto,"
39+
"allowing for automatic detection."
40+
"Currently, device settings support CPU, GPU, and HPU.")
41+
42+
self.add_argument("--export_format", default="llm_compressor", type=str,
43+
help="the format to save the model"
44+
)
45+
46+
self.add_argument("--output_dir", default="./tmp_autoround", type=str,
47+
help="the directory to save quantized model")
48+
49+
self.add_argument("--fp_layers", default="", type=str,
50+
help="layers to maintain original data type")
51+
52+
53+
def setup_parser():
54+
parser = BasicArgumentParser()
55+
56+
parser.add_argument("--iters", "--iter", default=0, type=int,
57+
help=" iters")
58+
59+
args = parser.parse_args()
60+
return args
61+
62+
63+
def tune(args):
64+
model_name = args.model
65+
if model_name[-1] == "/":
66+
model_name = model_name[:-1]
67+
print(f"start to quantize {model_name}")
68+
69+
layer_config = {}
70+
model = Llama4ForConditionalGeneration.from_pretrained(args.model, device_map=None, torch_dtype="auto", trust_remote_code=True)
71+
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
72+
processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
73+
fp_layers = args.fp_layers.replace(" ", "").split(",")
74+
if len(fp_layers) > 0:
75+
for n, m in model.named_modules():
76+
if not isinstance(m, (torch.nn.Linear)):
77+
continue
78+
for name in fp_layers:
79+
if name in n:
80+
layer_config[n] = {"bits": 16, "act_bits": 16}
81+
break
82+
83+
qconfig = AutoRoundConfig(
84+
tokenizer=tokenizer,
85+
iters=args.iters,
86+
scheme=args.scheme,
87+
layer_config=layer_config,
88+
export_format="llm_compressor",
89+
output_dir=args.output_dir,
90+
processor=processor,
91+
)
92+
model = prepare(model, qconfig)
93+
model = convert(model, qconfig)
94+
95+
if __name__ == '__main__':
96+
args = setup_parser()
97+
tune(args)

examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
auto-round @ git+https://github.com/intel/auto-round@v0.8.0rc
21
lm-eval==0.4.9.1
32
setuptools_scm
43
torchao==0.12.0

examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_quant.sh

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,11 @@ function run_tuning {
4747
extra_cmd="--fp_layers lm-head,self_attn,router,vision_model,multi_modal_projector,shared_expert --scheme MXFP4"
4848
fi
4949

50-
python3 -m auto_round \
51-
--model ${input_model} \
52-
--iters ${iters} \
53-
--format "llm_compressor" \
54-
--output_dir ${tuned_checkpoint} \
55-
${extra_cmd}
50+
python3 main.py \
51+
--model ${input_model} \
52+
--iters ${iters} \
53+
--output_dir ${tuned_checkpoint} \
54+
${extra_cmd}
5655
}
5756

5857
main "$@"

0 commit comments

Comments
 (0)