@@ -44,7 +44,7 @@ def quantize_model_(
         if qembedding_config == "8w":
             assert (
                 qembedding_group_size == 0
-            ), "8-bit embedding quantization only supports per-channel at the moment, please use qembedding_group_size = 0."
+            ), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
         if qembedding_group_size == 0:
             embedding_weight_granularity = PerAxis(0)
         else:
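The wording change above reflects that PerAxis(0) on an embedding weight of shape [vocab_size, embed_dim] yields one scale per vocabulary row, i.e. per token. A minimal standalone sketch of that granularity (plain torch, not the torchao kernel; all names are illustrative):

import torch

vocab_size, embed_dim = 8, 4
weight = torch.randn(vocab_size, embed_dim)

# Symmetric int8 scales computed along axis 0: one scale per token row.
scales = weight.abs().amax(dim=1) / 127.0
assert scales.shape == (vocab_size,)
quantized = torch.round(weight / scales.unsqueeze(1)).clamp(-128, 127).to(torch.int8)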
@@ -71,42 +71,99 @@ def quantize_model_(
         )
 
     if qlinear_config:
+
+        def build_linear_config(quant_config_key: str, granularity, packing_format: Optional[str] = None):
+            if quant_config_key == "8da4w":
+                return Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int4,
+                    weight_granularity=granularity,
+                )
+            if quant_config_key == "4w":
+                # Int4WeightOnlyConfig is required when an int4 packing format is requested.
+                if packing_format:
+                    return Int4WeightOnlyConfig(
+                        group_size=qlinear_group_size,
+                        int4_packing_format=packing_format,
+                        int4_choose_qparams_algorithm="hqq",
+                    )
+                return IntxWeightOnlyConfig(
+                    weight_dtype=torch.int4,
+                    granularity=granularity,
+                )
+            if quant_config_key == "8w":
+                return IntxWeightOnlyConfig(
+                    weight_dtype=torch.int8,
+                    granularity=granularity,
+                )
+            if quant_config_key == "8da8w":
+                # Weight granularity is fixed to per-axis for this config.
+                return Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int8,
+                    weight_granularity=PerAxis(0),
+                )
+            raise ValueError(f"Unsupported linear quantization config '{quant_config_key}'.")
+        qlinear_configs = [cfg.strip() for cfg in qlinear_config.split(",")]
+        if any(cfg == "" for cfg in qlinear_configs):
+            raise ValueError("Linear quantization config entries must be non-empty.")
+        if len(qlinear_configs) > 2:
+            raise ValueError("Expected at most one fallback linear quantization config, got more than one comma.")
+
+        primary_linear_config_key = qlinear_configs[0]
+        fallback_linear_config_key = qlinear_configs[1] if len(qlinear_configs) == 2 else None
+
         if qlinear_group_size == 0:
             linear_weight_granularity = PerAxis(0)
+            if fallback_linear_config_key is not None:
+                logging.warning(
+                    "qlinear_group_size is 0, fallback linear config will not be used as all layers will be quantized with per-axis granularity."
+                )
+                fallback_linear_config_key = None
         else:
-            assert qlinear_group_size % 2 == 0, "Linear quantization group size must be a multiple of 2."
+            assert (
+                qlinear_group_size % 2 == 0
+            ), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
             linear_weight_granularity = PerGroup(qlinear_group_size)
 
         logging.info("Quantizing linear layers.")
+        primary_linear_config = build_linear_config(
+            primary_linear_config_key, linear_weight_granularity, qlinear_packing_format
+        )
 
-        # Determine if we need to use Int4WeightOnlyConfig with int4_packing_format
-        if qlinear_config == "4w" and qlinear_packing_format:
-            linear_config = Int4WeightOnlyConfig(
-                group_size=qlinear_group_size,
-                int4_packing_format=qlinear_packing_format,
-                int4_choose_qparams_algorithm="hqq",
-            )
-        else:
-            linear_config = {
-                "8da4w": Int8DynamicActivationIntxWeightConfig(
-                    weight_dtype=torch.int4,
-                    weight_granularity=linear_weight_granularity,
-                ),
-                "4w": IntxWeightOnlyConfig(
-                    weight_dtype=torch.int4,
-                    granularity=linear_weight_granularity,
-                ),
-                "8w": IntxWeightOnlyConfig(
-                    weight_dtype=torch.int8,
-                    granularity=linear_weight_granularity,
-                ),
-            }[qlinear_config]
+        # First, quantize layers that are compatible with group quantization.
+        def per_group_filter(module, fqn):
+            if isinstance(module, torch.nn.Linear):
+                # Linear weights have shape [out_features, in_features]; group quantization
+                # applies along in_features (dim=1), so the group size must divide it evenly.
+                return qlinear_group_size == 0 or (module.weight.shape[1] % qlinear_group_size == 0)
+            return False
 
         quantize_(
             eager_model,
-            linear_config,
+            primary_linear_config,
+            filter_fn=per_group_filter,
         )
 
+        # Then, quantize incompatible layers using the fallback per-axis config.
+        if fallback_linear_config_key is not None:
+            fallback_linear_config = build_linear_config(fallback_linear_config_key, PerAxis(0))
+
+            def per_axis_filter(module, fqn):
+                # Complement of per_group_filter: Linear layers whose in_features
+                # is not divisible by the group size.
+                if isinstance(module, torch.nn.Linear):
+                    return module.weight.shape[1] % qlinear_group_size != 0
+                return False
+
+            logging.info(
+                f"Applying fallback linear config '{fallback_linear_config_key}' (per-axis)"
+                f" to layers incompatible with group size {qlinear_group_size}."
+            )
+            quantize_(
+                eager_model,
+                fallback_linear_config,
+                filter_fn=per_axis_filter,
+            )
+
     # TODO: remove after ExecuTorch dep on Torch >= 2.10.0.
     if parse(torch_version) < parse("2.10.0.dev20251104"):
         unwrap_tensor_subclass(eager_model)
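For reference, a hypothetical call exercising the new fallback path. quantize_model_, the comma-separated qlinear_config syntax, and the per-group/per-axis split come from this diff; the model and the assumption that qlinear_group_size is a keyword argument of quantize_model_ are stand-ins:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),  # 64 % 32 == 0 -> primary "8da4w" per-group config
    torch.nn.Linear(50, 10),  # 50 % 32 != 0 -> fallback "8da8w" per-axis config
)

quantize_model_(
    model,
    qlinear_config="8da4w,8da8w",  # primary config, then optional per-axis fallback
    qlinear_group_size=32,         # assumed keyword, per the variable used in the diff
)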