Skip to content

Commit cfbac91

Browse files
authored
make fqn_matches_fqn_config public and fix module device loading (#3302)
Summary: This PR does two things:
* makes `fqn_matches_fqn_config` public for use by transformers
* fixes device assignment to happen after quantization, to fix a broken test

Test Plan: pytest test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent a257166 commit cfbac91

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
     float8_static_activation_float8_weight,
     float8_weight_only,
     fpx_weight_only,
+    fqn_matches_fqn_config,
     gemlite_uintx_weight_only,
     int4_dynamic_activation_int4_weight,
     int4_weight_only,
@@ -142,6 +143,7 @@
     "float8_static_activation_float8_weight",
     "uintx_weight_only",
     "fpx_weight_only",
+    "fqn_matches_fqn_config",
     "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",

torchao/quantization/quant_api.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -480,15 +480,17 @@ def quantize_(

         for module_fqn, module in model.named_modules():
             if (
-                _fqn_matches_fqn_config(module_fqn, config)
+                fqn_matches_fqn_config(module_fqn, config)
                 or _module_param_matches_fqn_config(module, module_fqn, config)
                 or ("_default" in config.fqn_to_config and _is_linear(module))
             ):
                 module_name = (
                     module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
                 )
                 # this replaces inplace, so no need to reassign
-                _fqn_to_config_handler(module, module_name, config, device)
+                _fqn_to_config_handler(module, module_name, config)
+                if device is not None:
+                    module.to(device=device)
         return
     if isinstance(config, AOBaseConfig):
         filter_fn = _is_linear if filter_fn is None else filter_fn
@@ -2467,7 +2469,6 @@ def _fqn_to_config_handler(
     module: torch.nn.Module,
     fqn: str,
     config: FqnToConfig,
-    device: Optional[torch.device] = None,
 ):
     """This function expects a module that either is specified in FqnToConfig or has a parameter that is specified in FqnToConfig.

@@ -2476,17 +2477,13 @@ def _fqn_to_config_handler(
24762477
fqn (str): The fully qualified name of the module containing the parameters.
24772478
config (FqnToConfig): Configuration object containing regex patterns / fqn mapped
24782479
to quantization configurations.
2479-
device (Optional[torch.device]): The device to move the module to as part of quantization
24802480
24812481
Returns:
24822482
torch.nn.Module: The modified module with quantized parameters.
24832483
24842484
Raises:
24852485
NotImplementedError: If the quantization configuration is not yet supported for parameter quantization.
24862486
"""
2487-
if device is not None:
2488-
module = module.to(device)
2489-
24902487
parameter_config_found = False
24912488
top_level_params = []
24922489
for i, (parameter_name, param) in enumerate(list(module.named_parameters())):
@@ -2560,7 +2557,7 @@
     return module


-def _fqn_matches_fqn_config(
+def fqn_matches_fqn_config(
     fqn: str,
     config: FqnToConfig,
 ):
@@ -2605,7 +2602,7 @@ def _module_param_matches_fqn_config(
     for name, param in module.named_parameters():
         if name in dir(module):
             parameter_fqn = f"{fqn}.{name}" if len(fqn) > 0 else name
-            if _fqn_matches_fqn_config(parameter_fqn, config):
+            if fqn_matches_fqn_config(parameter_fqn, config):
                 return True

     return False

0 commit comments

Comments
 (0)