diff --git a/hf_torchao_vllm/quantize_hf_model_with_torchao.py b/hf_torchao_vllm/quantize_hf_model_with_torchao.py
index 071c95c..8b8ad79 100644
--- a/hf_torchao_vllm/quantize_hf_model_with_torchao.py
+++ b/hf_torchao_vllm/quantize_hf_model_with_torchao.py
@@ -83,6 +83,7 @@ def get_quantization_config(args):
                 # TODO tool to find this (I used bisect on this tiny model).
                 activation_value_lb=1.0e-12,
             )
+
             if args.experts_only_qwen_1_5_moe_a_2_7b:
                 expert_fqn_to_config = {}
                 # TODO(future PR): this is annoying, I should be able to use a regex here
@@ -125,6 +126,32 @@ def get_quantization_config(args):
                 return TorchAoConfig(
                     quant_type=module_fqn_to_config,
                 )
+            elif args.ffn_only_llama_4_scout:
+                # TODO gate this properly
+                expert_3d_weight_single_config = Float8DynamicActivationFloat8WeightConfig(
+                    # the weights of this model are stored in (B, K, N) layout, and we
+                    # need to quantize rowwise across the K axis, which is `PerRow(1)`.
+                    granularity=[PerRow(), PerRow(1)],
+                    # the 125m model has a lot of activation zeroes for some
+                    # prompts, need to set a lower bound to prevent scales from
+                    # being 0.
+                    # TODO seems like torchao should do this for me.
+                    # TODO tool to find this (I used bisect on this tiny model).
+                    activation_value_lb=1.0e-12,
+                )
+                module_fqn_to_config = ModuleFqnToConfig(
+                    {
+                        r"re:.*\.feed_forward\.experts\.gate_up_proj": expert_3d_weight_single_config,
+                        r"re:.*\.feed_forward\.experts\.down_proj": expert_3d_weight_single_config,
+                        r"re:.*\.shared_expert\.down_proj": single_config,
+                        r"re:.*\.shared_expert\.up_proj": single_config,
+                        r"re:.*\.shared_expert\.gate_proj": single_config,
+                    }
+                )
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+
             else:
                 return TorchAoConfig(single_config)
         case "int4_weight_only":
@@ -318,6 +345,44 @@ def benchmark_model(model, input_ids, max_new_tokens, name=""):
     return elapsed
 
 
+def _inference_with_processor(
+    model,
+    processor,
+    prompts,
+    args,
+) -> None:
+    messages = []
+    for prompt in prompts[:1]:
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                ],
+            },
+        )
+
+    inputs = processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+        padding="longest",
+    ).to(model.device)
+
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=args.max_new_tokens,
+    )
+
+    responses = processor.batch_decode(
+        outputs[:, inputs["input_ids"].shape[-1] :]
+    )
+    for response in responses:
+        print(response)
+
+
 def main(
     model_name: str = "facebook/opt-125m",
     output_dir: str | None = None,
@@ -345,6 +410,7 @@ def main(
     device_map: str = "cuda",
     experts_only_qwen_1_5_moe_a_2_7b: bool = False,
     skip_gate_qwen_1_5_moe_a_2_7b: bool = False,
+    ffn_only_llama_4_scout: bool = False,
     save_model_to_disk: bool = True,
 ):
     """
@@ -363,6 +429,7 @@ def main(
         device_map: Device mapping strategy
         experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
         skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
+        ffn_only_llama_4_scout: if True, quantizes FFN modules only for meta-llama/Llama-4-Scout-17B-16E-Instruct
         save_model_to_disk: if True, saves quantized model to local disk
     """
     # Test prompts
@@ -397,6 +464,7 @@ def main(
         experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
         save_model_to_disk=save_model_to_disk,
         skip_gate_qwen_1_5_moe_a_2_7b=skip_gate_qwen_1_5_moe_a_2_7b,
+        ffn_only_llama_4_scout=ffn_only_llama_4_scout,
     )
 
     print(f"{args=}")
@@ -430,37 +498,26 @@ def main(
             torch_dtype=torch.bfloat16,
             quantization_config=quantization_config,
         )
+        print(quantized_model)
 
-        messages = []
-        for prompt in prompts[:1]:
-            messages.append(
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": prompt},
-                    ],
-                },
-            )
-
-        inputs = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-            padding="longest",
-        ).to(quantized_model.device)
-
-        outputs = quantized_model.generate(
-            **inputs,
-            max_new_tokens=args.max_new_tokens,
+        print(
+            "quantized_model.language_model.model.layers[47].feed_forward.experts.down_proj",
+            type(
+                quantized_model.language_model.model.layers[
+                    47
+                ].feed_forward.experts.down_proj
+            ),
         )
-
-        responses = processor.batch_decode(
-            outputs[:, inputs["input_ids"].shape[-1] :]
+        print(
+            "quantized_model.language_model.model.layers[47].feed_forward.experts.gate_up_proj",
+            type(
+                quantized_model.language_model.model.layers[
+                    47
+                ].feed_forward.experts.gate_up_proj
+            ),
         )
-        for response in responses:
-            print(response)
+
+        _inference_with_processor(quantized_model, processor, prompts, args)
 
     else:
         # Load tokenizer
@@ -495,8 +552,16 @@ def main(
     if args.save_model_to_disk:
         # Save quantized model
         print(f"\nSaving quantized model to: {output_dir}")
-        quantized_model.save_pretrained(output_dir, safe_serialization=False)
-        tokenizer.save_pretrained(output_dir)
+        if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
+            quantized_model.save_pretrained(
+                output_dir, safe_serialization=False
+            )
+            processor.save_pretrained(output_dir)
+        else:
+            quantized_model.save_pretrained(
+                output_dir, safe_serialization=False
+            )
+            tokenizer.save_pretrained(output_dir)
 
     # Push to HuggingFace hub if requested
     if args.push_to_hub:
@@ -509,28 +574,35 @@ def main(
 
     if args.save_model_to_disk:
         # TODO(future): support this for LLaMa 4 Scout
-        # Load saved model to verify
         print("\nLoading saved quantized model to verify...")
-        # TODO: do we really need `weights_only=False` here?
-        loaded_model = AutoModelForCausalLM.from_pretrained(
-            output_dir,
-            device_map=args.device_map,
-            torch_dtype="auto",
-            weights_only=False,
-        )
+        if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
+            _inference_with_processor(quantized_model, processor, prompts, args)
+            print("Verified inference with reloaded model")
+
+        else:
+            # Load saved model to verify
+            # TODO: do we really need `weights_only=False` here?
+            loaded_model = AutoModelForCausalLM.from_pretrained(
+                output_dir,
+                device_map=args.device_map,
+                torch_dtype="auto",
+                weights_only=False,
+            )
 
-        # Test loaded model with first prompt
-        test_prompt = prompts[0]
-        input_ids = tokenizer(test_prompt, return_tensors="pt").to(
-            loaded_model.device
-        )
-        output = loaded_model.generate(
-            **input_ids, max_new_tokens=args.max_new_tokens
-        )
-        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-        print(
-            f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}"
-        )
+            # Test loaded model with first prompt
+            test_prompt = prompts[0]
+            input_ids = tokenizer(test_prompt, return_tensors="pt").to(
+                loaded_model.device
+            )
+            output = loaded_model.generate(
+                **input_ids, max_new_tokens=args.max_new_tokens
+            )
+            generated_text = tokenizer.decode(
+                output[0], skip_special_tokens=True
+            )
+            print(
+                f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}"
+            )
 
     # Benchmark if requested
     if args.benchmark:
diff --git a/hf_torchao_vllm/run_quantized_model_in_vllm.py b/hf_torchao_vllm/run_quantized_model_in_vllm.py
index 61733b3..43312a3 100644
--- a/hf_torchao_vllm/run_quantized_model_in_vllm.py
+++ b/hf_torchao_vllm/run_quantized_model_in_vllm.py
@@ -75,6 +75,7 @@ def main(
         model=model_name,
         tensor_parallel_size=tp_size,
         enforce_eager=not compile,
+        max_model_len=max_tokens,
     )
 
     # Print diagnostic information
@@ -86,6 +87,8 @@ def main(
     print(f"model_config: {model_config}")
     print(f"hf_config: {model_config.hf_config}")
     if print_model:
+        # TODO: fix this for latest vllm, lines below crash when building from
+        # source with https://www.internalfb.com/phabricator/paste/view/P2028278010
         model = llm.llm_engine.model_executor.driver_worker.model_runner.model
         print(f"model: {model}")
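
For readers who want to see the new Llama-4-Scout quantization path in isolation, here is a minimal sketch of the config that the `ffn_only_llama_4_scout` branch builds. It only reuses names that appear in the diff (`Float8DynamicActivationFloat8WeightConfig`, `PerRow`, `ModuleFqnToConfig`, `TorchAoConfig`); the import paths and the exact construction of `single_config` are assumptions, since the diff only shows the tail of that call.

# Sketch only: mirrors the ffn_only_llama_4_scout branch above.
# Assumptions: these names are importable from torchao.quantization /
# transformers, and single_config is the per-row float8 config built earlier
# in get_quantization_config (only its tail is visible in the diff).
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    ModuleFqnToConfig,
    PerRow,
)
from transformers import TorchAoConfig

# Config for regular 2D nn.Linear weights (the shared experts).
single_config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),
    # lower bound on activation values so scales never collapse to 0
    activation_value_lb=1.0e-12,
)

# Routed-expert weights are stored as 3D (B, K, N) tensors, so the weight
# granularity is PerRow(1): rowwise scales along the K axis of each expert.
expert_3d_weight_single_config = Float8DynamicActivationFloat8WeightConfig(
    granularity=[PerRow(), PerRow(1)],
    activation_value_lb=1.0e-12,
)

# Map module FQNs (regex keys, "re:" prefix) to their quantization configs.
quantization_config = TorchAoConfig(
    quant_type=ModuleFqnToConfig(
        {
            r"re:.*\.feed_forward\.experts\.gate_up_proj": expert_3d_weight_single_config,
            r"re:.*\.feed_forward\.experts\.down_proj": expert_3d_weight_single_config,
            r"re:.*\.shared_expert\.down_proj": single_config,
            r"re:.*\.shared_expert\.up_proj": single_config,
            r"re:.*\.shared_expert\.gate_proj": single_config,
        }
    )
)

Per the comment in the diff, the two-element `granularity` list is read as (activation granularity, weight granularity), which is why only the weight entry carries the axis override.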
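The only functional change to `run_quantized_model_in_vllm.py` is passing `max_model_len` through to the `LLM` constructor (plus the TODO about `print_model` crashing on newer vLLM builds). A rough usage sketch follows, using vLLM's standard offline-inference API with a hypothetical checkpoint path and prompt:

# Sketch only: the LLM(...) kwargs mirror run_quantized_model_in_vllm.py;
# the checkpoint path and prompt are hypothetical.
from vllm import LLM, SamplingParams

max_tokens = 128

llm = LLM(
    model="data/llama4_scout_fp8_torchao",  # hypothetical output_dir of the quantization script
    tensor_parallel_size=1,
    enforce_eager=True,        # i.e. `not compile`
    max_model_len=max_tokens,  # new in this diff: caps context length and KV-cache size
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=max_tokens),
)
for out in outputs:
    print(out.outputs[0].text)

Reusing `max_tokens` as `max_model_len` keeps the allocated KV cache small for this smoke test; longer prompts would need a larger (or explicit) `max_model_len`.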