 import inspect
-from typing import Dict, List, Optional, Tuple, Union
+from itertools import product
+from typing import Literal

 import torch
 from compressed_tensors.quantization import disable_quantization
@@ -111,31 +112,40 @@ class AWQModifier(Modifier, QuantizationMixin):
         device. Defaults to None, so cached args are not offloaded. Consider setting
         to torch.device("cpu") if you are encountering OOM errors
     :param duo_scaling: whether to use duo scaling, which uses both input activations
-        and weights to determine the scaling factor
+        and weights to determine the scaling factor. Defaults to True.
+        If True, both activations and weights are used.
+        If False, only activations are used.
+        If "both", half the grid search is performed with duo_scaling=False and the
+        other half is performed with duo_scaling=True.
+    :param n_grid: when performing the best scales grid search for each mapping,
+        this specifies how many grid points should be used. To decrease the runtime,
+        at the possible cost of slightly worse scales, this can be decreased.
+        Defaults to 20.
     """

     # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
     model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

     # User-provided vars (in addition to QuantizationMixin args)
-    sequential_targets: Union[str, List[str], None] = None
-    mappings: Optional[List[AWQMapping]] = None
-    offload_device: Optional[torch.device] = None
-    duo_scaling: bool = True
+    sequential_targets: str | list[str] | None = None
+    mappings: list[AWQMapping] | None = None
+    offload_device: torch.device | None = None
+    duo_scaling: bool | Literal["both"] = True
+    n_grid: int = 20

     # Private vars set during validation
-    _num_bits: Optional[int] = PrivateAttr(default=None)
-    _symmetric: Optional[bool] = PrivateAttr(default=None)
-    _group_size: Optional[int] = PrivateAttr(default=None)
+    _num_bits: int | None = PrivateAttr(default=None)
+    _symmetric: bool | None = PrivateAttr(default=None)
+    _group_size: int | None = PrivateAttr(default=None)

     # Private vars set during initialization, cleared during finalization
-    _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
+    _resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list)
     # Cache list of forward input args for each parent module, one dict for each batch
-    _parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
+    _parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr(
         default_factory=dict
     )
     # Dict[smooth layer name, (activation means, activation counts)]
-    _smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
+    _smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
         default_factory=dict
     )

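As a usage sketch for the two new options: the import paths, oneshot entry point, and the scheme/targets/ignore arguments below follow common llm-compressor conventions and are assumptions rather than part of this change.

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

recipe = AWQModifier(
    targets="Linear",       # assumed quantization targets/scheme, as in other AWQ recipes
    scheme="W4A16",
    ignore=["lm_head"],
    duo_scaling="both",     # run half the grid search with duo scaling off, half with it on
    n_grid=10,              # coarser grid than the default of 20, trading scale quality for speed
)
oneshot(model="path/to/model", dataset="open_platypus", recipe=recipe)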
@@ -389,7 +399,7 @@ def _setup_activation_cache_hooks(self) -> None:

         def cache_parent_kwargs_hook(
             module: torch.nn.Module,
-            args: Tuple[torch.Tensor, ...],
+            args: tuple[torch.Tensor, ...],
             kwargs,
         ):
             values = inspect.signature(module.forward).bind(*args, **kwargs)
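The hook above uses only the standard library: inspect.signature(...).bind normalizes whatever mix of positional and keyword arguments the parent module receives into a single name-to-value mapping. A minimal, self-contained illustration with a hypothetical forward signature:

import inspect

def forward(hidden_states, attention_mask=None, position_ids=None):
    ...

bound = inspect.signature(forward).bind("hs", position_ids=[0, 1])
bound.apply_defaults()
print(dict(bound.arguments))
# {'hidden_states': 'hs', 'attention_mask': None, 'position_ids': [0, 1]}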
@@ -398,7 +408,7 @@ def cache_parent_kwargs_hook(
         def create_cache_smooth_activations_hook_fn(smooth_name):
             def cache_smooth_activations_hook(
                 _module: torch.nn.Module,
-                args: Tuple[torch.Tensor, ...],
+                args: tuple[torch.Tensor, ...],
                 _output: torch.Tensor,
             ):
                 self._smooth_activation_means[smooth_name] = _accumulate_mean(
@@ -454,9 +464,11 @@ def _apply_smoothing(self, model: Module) -> None:
             balance_layers = mapping.balance_layers
             parent_module = mapping.parent

-            with align_modules(
-                [parent_module, smooth_layer, *balance_layers]
-            ), calibration_forward_context(model), HooksMixin.disable_hooks():
+            with (
+                align_modules([parent_module, smooth_layer, *balance_layers]),
+                calibration_forward_context(model),
+                HooksMixin.disable_hooks(),
+            ):
                 # [STEP 1]: Compute per-channel mean of normalised weights
                 # All layer weights are concatted together
                 weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
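The rest of [STEP 1] sits outside this hunk. As a hedged sketch of what "per-channel mean of normalised weights" usually means in AWQ-style implementations (the group handling and epsilon below are assumptions, not lines from this file): each weight is scaled by the maximum absolute value of its quantization group, and the result is averaged over output channels so one value remains per input channel.

org_shape = weight.shape                          # concatenated balance-layer weights
w = weight.view(-1, group_size)                   # group_size: assumed quantization group size
w = w.abs() / (w.abs().amax(dim=1, keepdim=True) + 1e-6)
w_mean = w.view(org_shape).mean(dim=0)            # per-input-channel mean of normalized weights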
@@ -559,13 +571,13 @@ def _smooth(module):
             v.batch_intermediates.clear()
         self._assert_all_activations_consumed()

-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _run_samples(self, module: Module) -> list[torch.Tensor]:
         outputs = [
             module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
         ]
         return [
             # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
+            output[0] if isinstance(output, tuple) else output
             for output in outputs
         ]

@@ -574,8 +586,8 @@ def _compute_best_scale(
         x_mean: torch.Tensor,
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
-        linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        linears2scale: list[torch.nn.Linear],
+        fp16_outputs: list[torch.Tensor],
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -586,7 +598,6 @@ def _compute_best_scale(
         W: original weights in FP16 | layer
         s: per channel scaling factor | s^-1 * X
         """
-        n_grid = 20
         history = []
         best_ratio = -1
         best_scales = None
@@ -602,12 +613,21 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)

-        for ratio in range(n_grid):
+        match self.duo_scaling:
+            # if self.duo_scaling is "both", perform half the grid search with
+            # duo_scaling off and half with duo_scaling on
+            case "both":
+                n_grid = int(self.n_grid / 2)
+                duo_scalings = [False, True]
+            case _:
+                n_grid = self.n_grid
+                duo_scalings = [self.duo_scaling]
+        for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
             # create new scales
-            ratio = ratio / n_grid
+            ratio = grid_idx / n_grid

             # NOTE: s^-1 * x is fused here, according to paper
-            if self.duo_scaling:
+            if use_duo_scaling:
                 scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(
                     min=1e-4
                 )
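A small standalone sketch of what the product(...) loop enumerates when duo_scaling="both": the ratio resolution is halved, and every ratio is tried with duo scaling both off and on (the settings here are illustrative).

from itertools import product

configured_n_grid, duo_scaling = 4, "both"   # illustrative values
if duo_scaling == "both":
    n_grid, duo_scalings = configured_n_grid // 2, [False, True]
else:
    n_grid, duo_scalings = configured_n_grid, [duo_scaling]

for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
    print(grid_idx / n_grid, use_duo_scaling)
# 0.0 False
# 0.0 True
# 0.5 False
# 0.5 True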
@@ -667,8 +687,8 @@ def _compute_best_scale(
     @torch.no_grad()
     def _compute_loss(
         self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
+        fp16_outputs: list[torch.Tensor],
+        int_w_outputs: list[torch.Tensor],
         device: torch.device,
     ) -> torch.Tensor:
         loss = 0.0
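The body of _compute_loss is not shown in this hunk. As an assumption-labeled sketch, the loss is typically the mean squared error between the fp16 outputs and the outputs produced with pseudo-quantized weights, accumulated over all cached batches:

loss, num_elements = 0.0, 0
for fp16_out, int_w_out in zip(fp16_outputs, int_w_outputs):
    loss += (fp16_out.to(device) - int_w_out.to(device)).float().pow(2).sum().item()
    num_elements += fp16_out.numel()
loss /= num_elements    # lower is better; the grid search keeps the ratio with the lowest loss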
@@ -746,8 +766,8 @@ def _pseudo_quantize_tensor(

 def _accumulate_mean(
     inp: torch.Tensor,
-    prev_mean_and_count: Optional[Tuple[torch.FloatTensor, int]],
-) -> Tuple[torch.FloatTensor, int]:
+    prev_mean_and_count: tuple[torch.FloatTensor, int] | None,
+) -> tuple[torch.FloatTensor, int]:
     sum_added = inp.sum(dim=0)
     num_added = inp.size(0)
     if prev_mean_and_count is None:
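_accumulate_mean keeps only a running (mean, count) pair per smooth layer, so per-batch activations never need to be stored. A minimal numeric sketch of the update performed by the return statement in the next hunk, assuming prev_sum is reconstructed as prev_mean * prev_count:

prev_mean, prev_count = torch.tensor([0.5]), 10   # hypothetical running state
batch = torch.tensor([[1.0], [2.0], [3.0]])       # new batch of per-channel activations
sum_added, num_added = batch.sum(dim=0), batch.size(0)
new_count = prev_count + num_added
new_mean = (prev_mean * prev_count + sum_added) / new_count
# new_mean == tensor([0.8462]), i.e. (0.5 * 10 + 6.0) / 13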
@@ -761,7 +781,7 @@ def _accumulate_mean(
     return (prev_sum + sum_added) / new_count, new_count


-def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
+def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
     """
     Given a list of names, returns the lowest-scope common parent.

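For illustration only (this is not the repository's implementation), one way to realize "lowest-scope common parent" is to take the longest shared dotted prefix of the module names and trim it back until it names a real submodule:

import os

def lowest_common_parent_sketch(names: list[str], model: Module) -> tuple[str, Module]:
    # e.g. ["model.layers.0.mlp.gate_proj", "model.layers.0.mlp.up_proj"]
    # share the prefix "model.layers.0.mlp"
    prefix = os.path.commonprefix(names).rstrip(".")
    submodules = dict(model.named_modules())
    # trim partially matched segments such as "layers.1" vs "layers.10"
    while prefix and prefix not in submodules:
        prefix = prefix.rsplit(".", 1)[0] if "." in prefix else ""
    return prefix, (submodules[prefix] if prefix else model)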