
Commit c011a0f

Merge pull request #61 from vkuzo/20250922_exploration
add support for quantizing Qwen/Qwen1.5-MoE-A2.7B experts
2 parents cb93311 + 29417c3 commit c011a0f

File tree

1 file changed: +49 −22 lines

torchao_hf_vllm/torchao_hf_script.py

Lines changed: 49 additions & 22 deletions
@@ -65,9 +65,25 @@ def get_quantization_config(args):
         case "autoquant":
             return TorchAoConfig("autoquant", min_sqnr=args.min_sqnr)
         case "fp8":
-            return TorchAoConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=gran)
-            )
+            single_config = Float8DynamicActivationFloat8WeightConfig(granularity=gran)
+            if args.experts_only_qwen_1_5_moe_a_2_7b:
+                from torchao.quantization import ModuleFqnToConfig
+                expert_fqn_to_config = {}
+                # TODO(future PR): this is annoying, I should be able to use a regex here
+                for layer_idx in range(24):
+                    for expert_idx in range(60):
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj"] = single_config
+                module_fqn_to_config = ModuleFqnToConfig({
+                    "_default": None,
+                    **expert_fqn_to_config,
+                })
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+            else:
+                return TorchAoConfig(single_config)
         case "int4_weight_only":
             return TorchAoConfig(Int4WeightOnlyConfig(group_size=128))
         case "int8_weight_only":
@@ -148,12 +164,14 @@ def main(
     benchmark: bool = False,
     bench_tokens: int = 100,
     device_map: str = "cuda",
+    experts_only_qwen_1_5_moe_a_2_7b: bool = False,
+    save_model_to_disk: bool = True,
 ):
     """
     Quantize a model with TorchAO and test its performance.

     Args:
-        model_name: Model to quantize (e.g., meta-llama/Meta-Llama-3-8B, facebook/opt-125m)
+        model_name: Model to quantize (e.g., meta-llama/Meta-Llama-3-8B, facebook/opt-125m, Qwen/Qwen1.5-MoE-A2.7B)
         output_dir: Directory to save the quantized model
         push_to_hub: HF Hub repo name to push the model (e.g., 'your-username/model-name')
         quant_type: Quantization type to use
@@ -163,6 +181,8 @@ def main(
         benchmark: Run benchmarking comparison
         bench_tokens: Number of tokens to generate for benchmarking
         device_map: Device mapping strategy
+        experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
+        save_model_to_disk: if True, saves quantized model to local disk
     """
     # Set seed before creating the model
     set_seed(42)
@@ -184,9 +204,13 @@ def main(
         benchmark=benchmark,
         bench_tokens=bench_tokens,
         device_map=device_map,
+        experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
+        save_model_to_disk=save_model_to_disk,
     )
-    print(f"Using Model name: {args.model_name}")
-    print(f"Quantization type: {args.quant_type}")
+    print(f"{args=}")
+
+    if args.experts_only_qwen_1_5_moe_a_2_7b:
+        assert args.quant_type == "fp8", "unsupported"

     # Create output directory
     output_dir = Path(args.output_dir)
@@ -228,10 +252,11 @@ def main(
         generated_text = tokenizer.decode(output, skip_special_tokens=True)
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

-    # Save quantized model
-    print(f"\n📁Saving quantized model to: {output_dir}")
-    quantized_model.save_pretrained(output_dir, safe_serialization=False)
-    tokenizer.save_pretrained(output_dir)
+    if args.save_model_to_disk:
+        # Save quantized model
+        print(f"\nSaving quantized model to: {output_dir}")
+        quantized_model.save_pretrained(output_dir, safe_serialization=False)
+        tokenizer.save_pretrained(output_dir)

     # Push to HuggingFace hub if requested
     if args.push_to_hub:
@@ -242,22 +267,24 @@ def main(
         quantized_model.push_to_hub(model_name, safe_serialization=False)
         tokenizer.push_to_hub(model_name)

-    # Load saved model to verify
-    print("\nLoading saved quantized model to verify...")
-    # TODO: do we really need `weights_only=False` here?
-    loaded_model = AutoModelForCausalLM.from_pretrained(
-        output_dir, device_map=args.device_map, torch_dtype="auto", weights_only=False,
-    )
+    if args.save_model_to_disk:
+        # Load saved model to verify
+        print("\nLoading saved quantized model to verify...")
+        # TODO: do we really need `weights_only=False` here?
+        loaded_model = AutoModelForCausalLM.from_pretrained(
+            output_dir, device_map=args.device_map, torch_dtype="auto", weights_only=False,
+        )

-    # Test loaded model with first prompt
-    test_prompt = prompts[0]
-    input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device)
-    output = loaded_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    print(f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}")
+        # Test loaded model with first prompt
+        test_prompt = prompts[0]
+        input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device)
+        output = loaded_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)
+        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+        print(f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}")

     # Benchmark if requested
     if args.benchmark:
+        assert args.save_model_to_disk, "unsupported"
         print("\nBenchmarking models...")
         # Benchmark quantized model
         print("Benchmarking quantized model:")
