Commit 98f9a7e
[BE] Remove Float8Linear from quant_api.py (#3085)
* remove Float8Linear handling from quant_api.py's generic replacement helpers
* move the Float8Linear -> nn.Linear conversion into the Float8WeightOnlyConfig handler
* do the conversion in the fp8/fpx related handlers as well
1 parent bb65dbc commit 98f9a7e
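
The change summarized above relocates where Float8Linear modules are swapped back to nn.Linear: the float8/fpx config handlers now do the unwrap themselves instead of the generic module-replacement helpers. Below is a minimal, hypothetical sketch (not part of this commit) of the end-to-end flow this supports, assuming a recent torchao build that exports convert_to_float8_training, quantize_, and Float8WeightOnlyConfig, plus a CUDA device meeting torchao's float8 hardware requirements:

```python
import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.quantization import Float8WeightOnlyConfig, quantize_

# Toy model; float8 kernels generally expect dimensions that are multiples of 16.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
model = model.to(torch.bfloat16).cuda()

# Training-time conversion replaces nn.Linear layers with Float8Linear.
convert_to_float8_training(model)
# ... training loop would run here ...

# After this commit, the Float8WeightOnlyConfig handler unwraps Float8Linear back
# to nn.Linear itself (via _unwrap_float8_linear) before quantizing the weights,
# rather than relying on the generic replacement helpers to do it.
quantize_(model, Float8WeightOnlyConfig())
```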

File tree

1 file changed: +31 -12 lines changed

torchao/quantization/quant_api.py

Lines changed: 31 additions & 12 deletions
@@ -203,12 +203,6 @@ def _replace_with_custom_fn_if_matches_filter(
     Returns:
         None
     """
-    if isinstance(model, Float8Linear):
-        with torch.device("meta"):
-            new_module = nn.Linear(model.in_features, model.out_features)
-        new_module.weight = model.weight
-        new_module.bias = model.bias
-        model = new_module
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
@@ -249,12 +243,6 @@ def _replace_with_custom_fn_if_matches_filter_with_name(
     Returns:
         None
     """
-    if isinstance(model, Float8Linear):
-        with torch.device("meta"):
-            new_module = nn.Linear(model.in_features, model.out_features)
-        new_module.weight = model.weight
-        new_module.bias = model.bias
-        model = new_module
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
@@ -1687,6 +1675,10 @@ def _float8_weight_only_transform(
         "applying int8 weight only quant requires module to have weight attribute"
         + " but {module} does not have one"
     )
+
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     new_weight = _float8_weight_only_quant_tensor(module.weight, config)

     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
@@ -1896,6 +1888,9 @@ def _float8_dynamic_activation_float8_weight_transform(
         "applying float8 dynamic activation quant requires module to have weight attribute"
         + f"but {module} does not have one"
     )
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor(
         module.weight, config
     )
@@ -1931,6 +1926,9 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
 ):
     assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0"

+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     weight = module.weight
     weight_dtype = config.weight_dtype
     activation_dtype = config.activation_dtype
@@ -1995,6 +1993,9 @@ def _float8_static_activation_float8_weight_transform(
         "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     )

+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     scale = config.scale
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -2364,6 +2365,9 @@ def _fpx_weight_only_transform(
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()

+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     from torchao.dtypes import to_affine_quantized_fpx
     from torchao.dtypes.floatx import FloatxTensorCoreLayout

@@ -2443,6 +2447,21 @@ def _module_fqn_to_config_handler(
     return module


+def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear:
+    """
+    Unwrap a torchao Float8Linear by returning a nn.Linear with the same weights and bias.
+
+    Torchao inference quantization techniques are generally only applicable to nn.Linear
+    layers, so this helper is useful for unwrapping models trained with torchao float8 training,
+    which replaces nn.Linear layers with Float8Linear layers.
+    """
+    with torch.device("meta"):
+        new_module = nn.Linear(module.in_features, module.out_features)
+    new_module.weight = module.weight
+    new_module.bias = module.bias
+    return new_module
+
+
 torch.serialization.add_safe_globals(
     [
         _int8_asymm_per_token_quant,
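
For reference, a standalone sketch of the unwrapping pattern the new _unwrap_float8_linear helper uses: constructing the replacement nn.Linear under torch.device("meta") allocates no real storage, and the trained parameters are then reattached. The unwrap_linear_like name and the plain nn.Linear stand-in below are illustrative only, not part of the commit:

```python
import torch
import torch.nn as nn


def unwrap_linear_like(module: nn.Module) -> nn.Linear:
    # Shape-only construction: under the meta device the new weight and bias
    # are placeholders with no allocated storage.
    with torch.device("meta"):
        new_module = nn.Linear(module.in_features, module.out_features)
    # Reattach the original Parameter objects, replacing the meta placeholders,
    # so nothing is copied and the trained values are preserved.
    new_module.weight = module.weight
    new_module.bias = module.bias
    return new_module


# Using a plain nn.Linear as a stand-in for Float8Linear:
lin = nn.Linear(8, 4)
unwrapped = unwrap_linear_like(lin)
assert unwrapped.weight is lin.weight and unwrapped.bias is lin.bias
```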
