@@ -23,6 +23,7 @@
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
+    PerBlock,
     PerRow,
     PerTensor,
     UIntXWeightOnlyConfig,
@@ -44,6 +45,7 @@ def run_evaluation(
     calibration_limit: Optional[int] = None,
     calibration_seq_length: Optional[int] = None,
     pad_calibration_inputs: bool = False,
+    print_model: bool = False,
 ):
     """Runs the evaluation of a model using LM Eval."""
     print(
@@ -169,6 +171,14 @@ def run_evaluation(
             model,
             Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
         )
+    if quantization == "float8_a1x128_w128x128":
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
+            activation_value_lb=1e-12,
+        )
+        # TODO(future): all workflows in this file should be skipping quantization
+        # of `lm_head`
+        quantize_(model, config)
     if "autoround" in quantization:
         from transformers import AutoTokenizer
 
@@ -273,7 +283,16 @@ def run_evaluation(
     )
 
     if compile:
-        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+        # TODO(future PR): clean this up
+        if quantization == "float8_a1x128_w128x128":
+            # we don't need max-autotune for float8 blockwise quant
+            model = torch.compile(model)
+        else:
+            model = torch.compile(model, mode="max-autotune", fullgraph=True)
+
+    if print_model:
+        print(model)
+
     with torch.no_grad():
         print("Running evaluation ...")
         # avoid circular imports
@@ -371,6 +390,9 @@ def run_evaluation(
         default=False,
         help="pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower",
     )
+    parser.add_argument(
+        "--print_model", action="store_true", help="Whether to print the model."
+    )
 
     args = parser.parse_args()
     run_evaluation(
@@ -387,4 +409,5 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
+        args.print_model,
     )
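For context on the new option: `float8_a1x128_w128x128` applies dynamic float8 quantization with 1x128 blocks for activations and 128x128 blocks for weights. Below is a minimal standalone sketch (not part of the diff) of that configuration, assuming the `torchao.quantization` names shown in the imports above; the toy model, sizes, and device are illustrative.

# Standalone sketch of the blockwise float8 config added in this diff.
# Assumes a CUDA device; the toy model is illustrative, with dims chosen
# to be divisible by the 128 block size.
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().to(torch.bfloat16)

config = Float8DynamicActivationFloat8WeightConfig(
    # 1x128 blocks for activations, 128x128 blocks for weights
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
    # lower bound on activation values when computing scales, per the diff above
    activation_value_lb=1e-12,
)
quantize_(model, config)  # quantizes the model's Linear weights in place

x = torch.randn(1, 256, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)
print(model)  # roughly what the new --print_model flag shows after quantization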