50 | 50 | parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str,
51 | 51 | help="Task list for accuracy validation.")
52 | 52 | # ============WeightOnlyQuant configs=============== |
| 53 | +parser.add_argument("--bits", type=int, default=4, choices=[4]) |
53 | 54 | parser.add_argument("--woq", action="store_true") |
54 | 55 | parser.add_argument("--woq_algo", default="RTN", choices=['RTN', 'GPTQ'], |
55 | 56 | help="Weight-only parameter.") |
56 | | -parser.add_argument("--woq_dtype", type=str, default="int4_fullrange", |
| 57 | +parser.add_argument("--weight_dtype", type=str, default="int4_fullrange", |
57 | 58 | choices=["int4_fullrange"]) |
58 | | -parser.add_argument("--woq_group_size", type=int, default=32) |
59 | | -parser.add_argument("--woq_scheme", default="sym") |
| 59 | +parser.add_argument("--group_size", type=int, default=32) |
| 60 | +parser.add_argument("--scheme", default="sym") |
60 | 61 | parser.add_argument("--woq_enable_mse_search", action="store_true") |
61 | 62 | parser.add_argument("--device", default="xpu") |
62 | 63 | parser.add_argument("--compute_dtype", default="fp16") |
| 64 | +# ============GPTQ configs============== |
63 | 65 | parser.add_argument( |
64 | | - "--gptq_percdamp", |
| 66 | + "--desc_act", |
| 67 | + action="store_true", |
| 68 | + help="Whether to apply the activation order GPTQ heuristic.", |
| 69 | +) |
| 70 | +parser.add_argument( |
| 71 | + "--damp_percent", |
65 | 72 | type=float, |
66 | 73 | default=0.01, |
67 | 74 | help="Percent of the average Hessian diagonal to use for dampening.", |
68 | 75 | ) |
69 | 76 | parser.add_argument( |
70 | | - "--gptq_block_size", |
| 77 | + "--blocksize", |
71 | 78 | type=int, |
72 | 79 | default=128, |
73 | 80 | help="Block size. sub weight matrix size to run GPTQ.", |
74 | 81 | ) |
75 | 82 | parser.add_argument( |
76 | | - "--gptq_nsamples", type=int, default=128, help="Number of calibration data samples." |
| 83 | + "--nsamples", type=int, default=128, help="Number of calibration data samples." |
77 | 84 | ) |
78 | 85 | parser.add_argument( |
79 | 86 | "--max_input_length", |
80 | 87 | type=int, |
81 | 88 | default=2048, |
82 | 89 | help="Calibration dataset sequence max length, this should align with your model config", |
83 | 90 | ) |
| 91 | +parser.add_argument( |
| 92 | + "--static_groups", |
| 93 | + action="store_true", |
| 94 | + help="Use determined group to do quantization", |
| 95 | +) |
| 96 | +parser.add_argument("--calib_iters", default=100, type=int, help="Calibration iters.") |
84 | 97 | # ============BitsAndBytes configs============== |
85 | 98 | parser.add_argument("--bitsandbytes", action="store_true") |
86 | 99 | parser.add_argument("--load_in_4bit", type=bool, default=False) |
|
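For reviewers, a quick way to sanity-check the renamed flags in isolation: a minimal sketch reproducing just the arguments touched above (the sample command line is illustrative, not from the repo):

```python
# Minimal sketch (not part of the change): how the renamed WOQ/GPTQ flags
# compose after this rename. Only the flags touched above are reproduced.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--woq", action="store_true")
parser.add_argument("--woq_algo", default="RTN", choices=["RTN", "GPTQ"])
parser.add_argument("--weight_dtype", default="int4_fullrange", choices=["int4_fullrange"])
parser.add_argument("--group_size", type=int, default=32)
parser.add_argument("--scheme", default="sym")
parser.add_argument("--desc_act", action="store_true")
parser.add_argument("--damp_percent", type=float, default=0.01)
parser.add_argument("--blocksize", type=int, default=128)
parser.add_argument("--nsamples", type=int, default=128)
parser.add_argument("--static_groups", action="store_true")

# Example invocation of the GPTQ path with the new flag names:
args = parser.parse_args(
    "--woq --woq_algo GPTQ --group_size 128 --damp_percent 0.05".split()
)
assert args.group_size == 128 and args.scheme == "sym"
```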
118 | 131 | dataset=args.dataset, |
119 | 132 | bits=args.bits, |
120 | 133 | desc_act=args.desc_act, |
121 | | - damp_percent=args.gptq_percdamp, |
122 | | - sym=True if args.woq_scheme == "sym" else False, |
123 | | - blocksize=args.gptq_block_size, |
124 | | - nsamples=args.gptq_nsamples, |
| 134 | + damp_percent=args.damp_percent, |
| 135 | + sym=args.scheme == "sym",
| 136 | + blocksize=args.blocksize, |
| 137 | + nsamples=args.nsamples, |
125 | 138 | static_groups=args.static_groups, |
126 | | - group_size=args.woq_group_size, |
| 139 | + group_size=args.group_size, |
127 | 140 | max_input_length=args.max_input_length, |
128 | 141 | compute_dtype=args.compute_dtype, |
129 | 142 | scale_dtype=args.compute_dtype, |
130 | | - weight_dtype=args.woq_dtype, |
| 143 | + weight_dtype=args.weight_dtype, |
131 | 144 | calib_iters=args.calib_iters, |
132 | 145 | ) |
133 | 146 | else: |
134 | 147 | quantization_config = RtnConfig( |
135 | | - compute_dtype=args.compute_dtype, weight_dtype=args.woq_dtype, |
136 | | - group_size=args.woq_group_size, scale_dtype=args.compute_dtype |
| 148 | + compute_dtype=args.compute_dtype, weight_dtype=args.weight_dtype, |
| 149 | + group_size=args.group_size, scale_dtype=args.compute_dtype |
137 | 150 | ) # default is A16W4G16
138 | 151 |
|
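The hunk above selects between the two config classes based on `--woq_algo`; a minimal standalone sketch of that dispatch, with `GptqConfig`/`RtnConfig` stubbed as plain dataclasses (the real classes live in the quantization library and take more arguments):

```python
# Standalone sketch of the GPTQ-vs-RTN dispatch above; GptqConfig/RtnConfig
# are stubbed here so the control flow runs without the real library.
from dataclasses import dataclass
from types import SimpleNamespace

@dataclass
class GptqConfig:
    damp_percent: float
    sym: bool
    blocksize: int
    nsamples: int
    group_size: int
    weight_dtype: str

@dataclass
class RtnConfig:
    weight_dtype: str
    group_size: int

def build_config(args):
    if args.woq_algo == "GPTQ":
        return GptqConfig(
            damp_percent=args.damp_percent,
            sym=args.scheme == "sym",  # boolean view of --scheme
            blocksize=args.blocksize,
            nsamples=args.nsamples,
            group_size=args.group_size,
            weight_dtype=args.weight_dtype,
        )
    return RtnConfig(weight_dtype=args.weight_dtype, group_size=args.group_size)

args = SimpleNamespace(woq_algo="RTN", scheme="sym", damp_percent=0.01,
                       blocksize=128, nsamples=128, group_size=32,
                       weight_dtype="int4_fullrange")
assert isinstance(build_config(args), RtnConfig)
```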
139 | 152 | # get model |
|
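In the hunk below, the old `WeightOnlyQuantConfig.from_pretrained` call gives way to reading a `quantization_config` already attached to the loaded model; a minimal sketch of that fallback pattern, with stand-in model classes:

```python
# Standalone sketch of the quantization_config fallback (stand-in classes).
class FakeQuantizedModel:
    quantization_config = {"weight_dtype": "int4_fullrange"}

class FakePlainModel:
    pass

def resolve_qconfig(model, explicit=None):
    # Prefer an explicitly built config; otherwise fall back to whatever
    # the loaded checkpoint carries, defaulting to an empty dict.
    return explicit if explicit is not None else getattr(model, "quantization_config", {})

assert resolve_qconfig(FakeQuantizedModel())["weight_dtype"] == "int4_fullrange"
assert resolve_qconfig(FakePlainModel()) == {}
```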
260 | 273 | args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \ |
261 | 274 | if user_model is None else user_model |
262 | 275 | if quantization_config is None: |
263 | | - quantization_config = WeightOnlyQuantConfig.from_pretrained(args.model) |
| 276 | + quantization_config = getattr(user_model, "quantization_config", {})
264 | 277 | if not args.disable_optimize_transformers: |
265 | 278 | print("Optimize with IPEX...") |
266 | 279 | user_model = ipex.optimize_transformers( |
267 | 280 | user_model.eval(), device=args.device, inplace=True, quantization_config=quantization_config, dtype=torch_dtype) |
268 | 281 | else: |
269 | 282 | print("Disabled optimization with IPEX...") |
| 283 | + |
270 | 284 | results = evaluate( |
271 | 285 | model="hf-causal", |
272 | | - model_args='pretrained='+args.model+',tokenizer=' + args.model + \ |
| 286 | + model_args='pretrained=' + args.model + ',tokenizer=' + args.model + \
273 | 287 | ',dtype=float32,trust_remote_code=' + str(args.trust_remote_code), |
274 | 288 | user_model=user_model, |
275 | 289 | batch_size=args.batch_size, |