@@ -365,6 +365,15 @@ def main(
         skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
         save_model_to_disk: if True, saves quantized model to local disk
     """
+    # Test prompts
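+    # (shared by both generation paths below)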
+    prompts = [
+        "Why is Pytorch 2.0 the best machine learning compiler?",
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
     # Set seed before creating the model
     set_seed(42)
 
@@ -406,40 +415,82 @@ def main(
     # Get quantization config
     quantization_config = get_quantization_config(args)
 
-    # Load and quantize model
-    print("Loading and quantizing model...")
-    quantized_model = AutoModelForCausalLM.from_pretrained(
-        args.model_name,
-        torch_dtype="bfloat16",
-        device_map=args.device_map,
-        quantization_config=quantization_config,
-    )
-    print(quantized_model)
+    if args.model_name == "meta-llama/Llama-4-Scout-17B-16E-Instruct":
+        # TODO(future): maybe unify with the else branch, need to figure
+        # out the right syntax for preparing inputs and running generation
+        # TODO(future): make this work for multiple prompts
+        from transformers import AutoProcessor, Llama4ForConditionalGeneration
 
-    # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
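+        # Llama 4 Scout is a multimodal checkpoint, so it is loaded via the
+        # processor and Llama4ForConditionalGeneration rather than the
+        # tokenizer + AutoModelForCausalLM path in the else branch below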
+        processor = AutoProcessor.from_pretrained(args.model_name)
+        quantized_model = Llama4ForConditionalGeneration.from_pretrained(
+            args.model_name,
+            # Note: flex does not work with naive device_map="auto"
+            # attn_implementation="flex_attention",
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            quantization_config=quantization_config,
+        )
 
-    # Test prompts
-    prompts = [
-        "Why is Pytorch 2.0 the best machine learning compiler?",
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
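+        # Wrap each prompt in the chat message format expected by the
+        # processor (only the first prompt for now, per the TODO above)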
+        messages = []
+        for prompt in prompts[:1]:
+            messages.append(
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                    ],
+                },
+            )
 
-    # Test generation
-    print("\nTesting quantized model generation...")
-    input_ids = tokenizer(prompts, return_tensors="pt", padding=True).to(
-        quantized_model.device
-    )
-    outputs = quantized_model.generate(
-        **input_ids, max_new_tokens=args.max_new_tokens
-    )
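+        # apply_chat_template tokenizes the chat messages and returns a dict
+        # of tensors (input_ids, attention_mask) ready to pass to generate()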
+        inputs = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            padding="longest",
+        ).to(quantized_model.device)
+
+        outputs = quantized_model.generate(
+            **inputs,
+            max_new_tokens=args.max_new_tokens,
+        )
+
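+        # Decode only the newly generated tokens by slicing off the prompt
+        # portion of each output sequence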
+        responses = processor.batch_decode(
+            outputs[:, inputs["input_ids"].shape[-1] :]
+        )
+        for response in responses:
+            print(response)
+
+    else:
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        # breakpoint()
+
+        # Load and quantize model
+        print("Loading and quantizing model...")
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            args.model_name,
+            torch_dtype="bfloat16",
+            device_map=args.device_map,
+            quantization_config=quantization_config,
+        )
+        print(quantized_model)
+
+        # Test generation
+        print("\nTesting quantized model generation...")
+        input_ids = tokenizer(prompts, return_tensors="pt", padding=True).to(
+            quantized_model.device
+        )
+        outputs = quantized_model.generate(
+            **input_ids, max_new_tokens=args.max_new_tokens
+        )
 
-    for i, (prompt, output) in enumerate(zip(prompts, outputs, strict=False)):
-        generated_text = tokenizer.decode(output, skip_special_tokens=True)
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        for i, (prompt, output) in enumerate(
+            zip(prompts, outputs, strict=False)
+        ):
+            generated_text = tokenizer.decode(output, skip_special_tokens=True)
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
     if args.save_model_to_disk:
         # Save quantized model
@@ -457,6 +508,7 @@ def main(
         tokenizer.push_to_hub(model_name)
 
     if args.save_model_to_disk:
+        # TODO(future): support this for LLaMa 4 Scout
         # Load saved model to verify
         print("\nLoading saved quantized model to verify...")
         # TODO: do we really need `weights_only=False` here?