 from tokenizer import get_tokenizer
 import time
 from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
-from torchao._models.llama.model import prepare_inputs_for_model
+from torchao._models.llama.model import prepare_inputs_for_model, TransformerBlock
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 def run_evaluation(
@@ -122,6 +122,51 @@ def run_evaluation(
         else:
             if not TORCH_VERSION_AT_LEAST_2_5:
                 unwrap_tensor_subclass(model)
125+ if "autoround" in quantization :
126+ from torchao .prototype .autoround .autoround_llm import quantize_model_with_autoround_
127+ from transformers import AutoTokenizer
128+
129+ _tokenizer = AutoTokenizer .from_pretrained (checkpoint_path .parent )
130+ # parse args from quantization string:
131+ # autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>
132+ _quant_args = quantization .split ("-" )
133+ _default_quant_args = [False , 200 , 128 , 8 , 2048 , 128 ]
134+ _model_devie = _quant_args [1 ] if len (_quant_args ) > 1 else device
135+ _quant_args = _quant_args [2 :]
136+ quant_lm_head , iters , groupsize , batch_size , seqlen , nsamples = [
137+ int (x ) for x in _quant_args
138+ ] + _default_quant_args [len (_quant_args ) :]
139+ model = model .to (_model_devie )
140+ print (
141+ (
142+ f"Quantizing model with autoround(iters={ iters } , groupsize={ groupsize } , "
143+ f"quant_lm_head={ quant_lm_head } , batch_size={ batch_size } , seqlen={ seqlen } , nsamples={ nsamples } )"
144+ )
145+ )
146+ with torch .device (_model_devie ):
147+ model .setup_caches (
148+ max_batch_size = batch_size , max_seq_length = seqlen , training = True
149+ )
150+
151+ if quant_lm_head :
152+ is_target_module = (
153+ lambda mod , fqn : isinstance (mod , TransformerBlock )
154+ or "output" in fqn
155+ )
156+ else :
157+ is_target_module = lambda mod , fqn : isinstance (mod , TransformerBlock )
158+ quantize_model_with_autoround_ (
159+ model = model ,
160+ tokenizer = _tokenizer ,
161+ is_target_module = is_target_module ,
162+ bits = 4 ,
163+ seqlen = seqlen ,
164+ bs = batch_size ,
165+ iters = iters ,
166+ nsamples = nsamples ,
167+ )
168+ model .to (device )
169+ model .reset_caches ()
125170
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -145,11 +190,15 @@ def run_evaluation(
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('-q', '--quantization', type=str,
+    parser.add_argument(
+        "-q",
+        "--quantization",
+        type=str,
         help=(
-            'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-gptq, autoquant, autoquant-int4, ' +
-            'int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
-        )
+            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-gptq, "
+            "autoquant, autoquant-int4, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, "
+            "sparse-marlin, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>"
+        ),
     )
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
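
For context, the new `autoround-...` quantization string is parsed positionally, with any missing fields padded from `_default_quant_args`. Below is a minimal standalone sketch of that parsing, using a hypothetical example string; the defaults mirror the diff above.

```python
# Standalone sketch of the autoround argument parsing shown in the diff above.
# The example string is hypothetical; defaults mirror _default_quant_args.
quantization = "autoround-cuda-1-200-128-8-2048-128"

_quant_args = quantization.split("-")
_default_quant_args = [False, 200, 128, 8, 2048, 128]
_model_device = _quant_args[1] if len(_quant_args) > 1 else "cuda"  # falls back to --device
_quant_args = _quant_args[2:]
# Convert the provided fields to ints, then pad the remainder with defaults.
quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [
    int(x) for x in _quant_args
] + _default_quant_args[len(_quant_args):]

print(_model_device, quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples)
# cuda 1 200 128 8 2048 128
```

Passing just `-q autoround` keeps every field at its default: lm-head unquantized, 200 iterations, group size 128, batch size 8, sequence length 2048, and 128 samples (`nsamples`).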