     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
-from torchao.quantization import ModuleFqnToConfig
+from torchao.quantization import (
+    ModuleFqnToConfig,
+    PerBlock,
+    PerRow,
+    PerTensor,
+)
 from torchao.quantization.quant_api import (
     CutlassInt4PackedLayout,
     Float8DynamicActivationFloat8WeightConfig,
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
-    PerRow,
-    PerTensor,
 )
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
@@ -62,6 +65,7 @@ def get_quantization_config(args):
     granularity_mapping = {
         "per_row": PerRow(),
         "per_tensor": PerTensor(),
+        "a1x128_w128x128": [PerBlock([1, 128]), PerBlock([128, 128])],
     }

     gran = granularity_mapping[args.granularity]
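
Note: a minimal sketch of how this granularity list is consumed downstream (assumes a recent torchao build with PerBlock float8 support and a GPU with float8 kernels; layer dims must be multiples of the 128 block sizes):

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    quantize_,
)

# Toy model whose dims are multiples of the 128 block sizes.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256, dtype=torch.bfloat16, device="cuda")
)
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        # activations scaled per 1x128 block, weights per 128x128 block
        granularity=[PerBlock([1, 128]), PerBlock([128, 128])],
    ),
)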
@@ -71,7 +75,13 @@ def get_quantization_config(args):
             return TorchAoConfig("autoquant", min_sqnr=args.min_sqnr)
         case "fp8":
             single_config = Float8DynamicActivationFloat8WeightConfig(
-                granularity=gran
+                granularity=gran,
+                # The 125m model produces all-zero activations for some
+                # prompts; set a lower bound so the scales cannot
+                # collapse to 0.
+                # TODO: torchao should arguably do this automatically.
+                # TODO: add a tool to find this value (found here by
+                # bisection on this tiny model).
+                activation_value_lb=1.0e-12,
             )
             if args.experts_only_qwen_1_5_moe_a_2_7b:
                 expert_fqn_to_config = {}
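
For context on activation_value_lb, an illustrative sketch of the failure mode it guards against (this is just the scaling arithmetic the comment refers to, not torchao's actual kernel code):

import torch

FP8_E4M3_MAX = 448.0  # max magnitude representable in float8_e4m3fn

def fp8_dynamic_scale(x, value_lb=None):
    # Dynamic scaling derives the scale from the max magnitude, so an
    # all-zero activation block yields scale == 0, and quantization then
    # computes x / scale == 0 / 0 == NaN.
    amax = x.abs().amax()
    if value_lb is not None:
        amax = torch.clamp(amax, min=value_lb)  # role of activation_value_lb
    return amax / FP8_E4M3_MAX

zeros = torch.zeros(1, 128)
print(fp8_dynamic_scale(zeros))         # tensor(0.) -> NaNs downstream
print(fp8_dynamic_scale(zeros, 1e-12))  # tiny but nonzero scale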
@@ -325,7 +335,9 @@ def main(
325335 "mxfp4" ,
326336 "nvfp4" ,
327337 ] = "fp8" ,
328- granularity : Literal ["per_row" , "per_tensor" ] = "per_row" ,
338+ granularity : Literal [
339+ "per_row" , "per_tensor" , "a1x128_w128x128"
340+ ] = "per_row" ,
329341 min_sqnr : float | None = None ,
330342 max_new_tokens : int = 64 ,
331343 benchmark : bool = False ,
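
A hypothetical call exercising the new option (the CLI wrapper is outside these hunks, so the exact invocation is an assumption; the new granularity only pairs with the fp8 path above):

main(quant="fp8", granularity="a1x128_w128x128")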