@@ -329,7 +329,7 @@ from torchao.dtypes import Int4XPULayout
 from torchao.quantization.quant_primitives import ZeroPointDomain


-quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT)
+quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT, int4_packing_format="plain_int32")
 quantization_config = TorchAoConfig(quant_type=quant_config)

 # Load and quantize the model
@@ -342,7 +342,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(

 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
 input_text = "What are we having for dinner?"
-input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
+input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

 # auto-compile the quantized model with `cache_implementation="static"` to get speed up
 output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
@@ -395,7 +395,7 @@ from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
 from torchao.quantization import Int4WeightOnlyConfig
 from torchao.dtypes import Int4CPULayout

-quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
+quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout(), int4_packing_format="opaque")
 quantization_config = TorchAoConfig(quant_type=quant_config)

 # Load and quantize the model
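
For reference, a minimal sketch of how the CPU int4 example reads once these hunks are applied. The `from_pretrained` arguments (`device_map`, `quantization_config`) and the final decode step are assumed from the surrounding doc page rather than shown in this diff, and the snippet assumes a torchao version whose `Int4WeightOnlyConfig` accepts `int4_packing_format`.

```python
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4CPULayout

# "opaque" is the int4 packing format this change selects for CPU
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout(), int4_packing_format="opaque")
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model on CPU (assumed from_pretrained arguments)
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    device_map="cpu",
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```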