@@ -480,15 +480,17 @@ def quantize_(
480480
481481 for module_fqn , module in model .named_modules ():
482482 if (
483- _fqn_matches_fqn_config (module_fqn , config )
483+ fqn_matches_fqn_config (module_fqn , config )
484484 or _module_param_matches_fqn_config (module , module_fqn , config )
485485 or ("_default" in config .fqn_to_config and _is_linear (module ))
486486 ):
487487 module_name = (
488488 module_fqn .rsplit ("." , 1 ) if "." in module_fqn else module_fqn
489489 )
490490 # this replaces inplace, so no need to reassign
491- _fqn_to_config_handler (module , module_name , config , device )
491+ _fqn_to_config_handler (module , module_name , config )
492+ if device is not None :
493+ module .to (device = device )
492494 return
493495 if isinstance (config , AOBaseConfig ):
494496 filter_fn = _is_linear if filter_fn is None else filter_fn
@@ -2467,7 +2469,6 @@ def _fqn_to_config_handler(
24672469 module : torch .nn .Module ,
24682470 fqn : str ,
24692471 config : FqnToConfig ,
2470- device : Optional [torch .device ] = None ,
24712472):
24722473 """This function expects a module that either is specified in FqnToConfig or has a parameter that is specified in FqnToConfig.
24732474
@@ -2476,17 +2477,13 @@ def _fqn_to_config_handler(
24762477 fqn (str): The fully qualified name of the module containing the parameters.
24772478 config (FqnToConfig): Configuration object containing regex patterns / fqn mapped
24782479 to quantization configurations.
2479- device (Optional[torch.device]): The device to move the module to as part of quantization
24802480
24812481 Returns:
24822482 torch.nn.Module: The modified module with quantized parameters.
24832483
24842484 Raises:
24852485 NotImplementedError: If the quantization configuration is not yet supported for parameter quantization.
24862486 """
2487- if device is not None :
2488- module = module .to (device )
2489-
24902487 parameter_config_found = False
24912488 top_level_params = []
24922489 for i , (parameter_name , param ) in enumerate (list (module .named_parameters ())):
@@ -2560,7 +2557,7 @@ def _fqn_to_config_handler(
25602557 return module
25612558
25622559
2563- def _fqn_matches_fqn_config (
2560+ def fqn_matches_fqn_config (
25642561 fqn : str ,
25652562 config : FqnToConfig ,
25662563):
@@ -2605,7 +2602,7 @@ def _module_param_matches_fqn_config(
26052602 for name , param in module .named_parameters ():
26062603 if name in dir (module ):
26072604 parameter_fqn = f"{ fqn } .{ name } " if len (fqn ) > 0 else name
2608- if _fqn_matches_fqn_config (parameter_fqn , config ):
2605+ if fqn_matches_fqn_config (parameter_fqn , config ):
26092606 return True
26102607
26112608 return False
0 commit comments