@@ -203,12 +203,6 @@ def _replace_with_custom_fn_if_matches_filter(
     Returns:
         None
     """
-    if isinstance(model, Float8Linear):
-        with torch.device("meta"):
-            new_module = nn.Linear(model.in_features, model.out_features)
-        new_module.weight = model.weight
-        new_module.bias = model.bias
-        model = new_module
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
@@ -249,12 +243,6 @@ def _replace_with_custom_fn_if_matches_filter_with_name(
     Returns:
         None
     """
-    if isinstance(model, Float8Linear):
-        with torch.device("meta"):
-            new_module = nn.Linear(model.in_features, model.out_features)
-        new_module.weight = model.weight
-        new_module.bias = model.bias
-        model = new_module
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
@@ -1687,6 +1675,10 @@ def _float8_weight_only_transform(
16871675 "applying int8 weight only quant requires module to have weight attribute"
16881676 + " but {module} does not have one"
16891677 )
1678+
1679+ if isinstance (module , Float8Linear ):
1680+ module = _unwrap_float8_linear (module )
1681+
16901682 new_weight = _float8_weight_only_quant_tensor (module .weight , config )
16911683
16921684 module .weight = torch .nn .Parameter (new_weight , requires_grad = False )
@@ -1896,6 +1888,9 @@ def _float8_dynamic_activation_float8_weight_transform(
18961888 "applying float8 dynamic activation quant requires module to have weight attribute"
18971889 + f"but { module } does not have one"
18981890 )
1891+ if isinstance (module , Float8Linear ):
1892+ module = _unwrap_float8_linear (module )
1893+
18991894 quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor (
19001895 module .weight , config
19011896 )
@@ -1931,6 +1926,9 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
 ):
     assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0"
 
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     weight = module.weight
     weight_dtype = config.weight_dtype
     activation_dtype = config.activation_dtype
@@ -1995,6 +1993,9 @@ def _float8_static_activation_float8_weight_transform(
19951993 "Float8 static activation quantization is only supported on CUDA 8.9 and above"
19961994 )
19971995
1996+ if isinstance (module , Float8Linear ):
1997+ module = _unwrap_float8_linear (module )
1998+
19981999 scale = config .scale
19992000 activation_dtype = config .activation_dtype
20002001 weight_dtype = config .weight_dtype
@@ -2364,6 +2365,9 @@ def _fpx_weight_only_transform(
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     from torchao.dtypes import to_affine_quantized_fpx
     from torchao.dtypes.floatx import FloatxTensorCoreLayout
 
@@ -2443,6 +2447,21 @@ def _module_fqn_to_config_handler(
     return module
 
 
+def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear:
+    """
+    Unwrap a torchao Float8Linear by returning a nn.Linear with the same weights and bias.
+
+    Torchao inference quantization techniques are generally only applicable to nn.Linear
+    layers, so this helper is useful for unwrapping models trained with torchao float8 training,
+    which replaces nn.Linear layers with Float8Linear layers.
+    """
+    with torch.device("meta"):
+        new_module = nn.Linear(module.in_features, module.out_features)
+    new_module.weight = module.weight
+    new_module.bias = module.bias
+    return new_module
+
+
 torch.serialization.add_safe_globals(
     [
         _int8_asymm_per_token_quant,
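
For context, a minimal sketch of the float8-training-then-quantize workflow that this unwrapping supports. It assumes the torchao APIs convert_to_float8_training, quantize_, and the Float8WeightOnlyConfig config class; exact names may differ across torchao releases, and the unwrapping itself happens inside the per-config transform functions rather than in user code.

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.quantization import Float8WeightOnlyConfig, quantize_

# Float8 training swaps nn.Linear modules for Float8Linear.
model = nn.Sequential(nn.Linear(1024, 1024, bias=False)).to(torch.bfloat16).cuda()
convert_to_float8_training(model)

# Post-training quantization accepts the Float8Linear modules directly:
# each transform unwraps them back to nn.Linear via _unwrap_float8_linear.
quantize_(model, Float8WeightOnlyConfig())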