 import inspect
-from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization import disable_quantization
@@ -111,31 +110,40 @@ class AWQModifier(Modifier, QuantizationMixin):
         device. Defaults to None, so cached args are not offloaded. Consider setting
         to torch.device("cpu") if you are encountering OOM errors
     :param duo_scaling: whether to use duo scaling, which uses both input activations
-        and weights to determine the scaling factor
+        and weights to determine the scaling factor. Defaults to None.
+        If False, only activations are used.
+        If True, both activations and weights are used.
+        If None, half the grid search is performed with duo_scaling=False and the
+        other half is performed with duo_scaling=True.
+    :param n_grid: when performing the best-scales grid search for each mapping,
+        this specifies how many grid points should be used. To decrease the runtime,
+        at the possible cost of slightly worse scales, this can be decreased.
+        Defaults to 20
     """
 
     # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
     model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)
 
     # User-provided vars (in addition to QuantizationMixin args)
-    sequential_targets: Union[str, List[str], None] = None
-    mappings: Optional[List[AWQMapping]] = None
-    offload_device: Optional[torch.device] = None
-    duo_scaling: bool = True
+    sequential_targets: str | list[str] | None = None
+    mappings: list[AWQMapping] | None = None
+    offload_device: torch.device | None = None
+    duo_scaling: bool | None = None
+    n_grid: int = 20
 
     # Private vars set during validation
-    _num_bits: Optional[int] = PrivateAttr(default=None)
-    _symmetric: Optional[bool] = PrivateAttr(default=None)
-    _group_size: Optional[int] = PrivateAttr(default=None)
+    _num_bits: int | None = PrivateAttr(default=None)
+    _symmetric: bool | None = PrivateAttr(default=None)
+    _group_size: int | None = PrivateAttr(default=None)
 
     # Private vars set during initialization, cleared during finalization
-    _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
+    _resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list)
     # Cache list of forward input args for each parent module, one dict for each batch
-    _parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
+    _parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr(
         default_factory=dict
     )
     # Dict[smooth layer name, (activation means, activation counts)]
-    _smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
+    _smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
         default_factory=dict
     )
 
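For orientation, a minimal configuration sketch of the two new options (not part of this diff; the import path and the `targets`/`scheme` arguments are assumed from the existing AWQModifier / QuantizationMixin interface):

```python
# Hypothetical usage sketch -- only duo_scaling and n_grid are introduced by this
# change; the remaining arguments are assumed from the existing interface.
from llmcompressor.modifiers.awq import AWQModifier

awq = AWQModifier(
    targets=["Linear"],  # modules to quantize (QuantizationMixin argument)
    scheme="W4A16",      # quantization scheme (QuantizationMixin argument)
    duo_scaling=None,    # None: split the grid search between both scaling modes
    n_grid=10,           # fewer grid points -> faster search, possibly worse scales
)
```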
@@ -389,7 +397,7 @@ def _setup_activation_cache_hooks(self) -> None:
 
         def cache_parent_kwargs_hook(
             module: torch.nn.Module,
-            args: Tuple[torch.Tensor, ...],
+            args: tuple[torch.Tensor, ...],
             kwargs,
         ):
             values = inspect.signature(module.forward).bind(*args, **kwargs)
@@ -398,7 +406,7 @@ def cache_parent_kwargs_hook(
         def create_cache_smooth_activations_hook_fn(smooth_name):
             def cache_smooth_activations_hook(
                 _module: torch.nn.Module,
-                args: Tuple[torch.Tensor, ...],
+                args: tuple[torch.Tensor, ...],
                 _output: torch.Tensor,
             ):
                 self._smooth_activation_means[smooth_name] = _accumulate_mean(
@@ -559,13 +567,13 @@ def _smooth(module):
             v.batch_intermediates.clear()
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _run_samples(self, module: Module) -> list[torch.Tensor]:
         outputs = [
             module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
         ]
         return [
             # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
+            output[0] if isinstance(output, tuple) else output
             for output in outputs
         ]
 
@@ -574,8 +582,8 @@ def _compute_best_scale(
         x_mean: torch.Tensor,
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
-        linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        linears2scale: list[torch.nn.Linear],
+        fp16_outputs: list[torch.Tensor],
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -586,7 +594,6 @@ def _compute_best_scale(
         W: original weights in FP16 | layer
         s: per channel scaling factor | s^-1 * X
         """
-        n_grid = 20
         history = []
         best_ratio = -1
         best_scales = None
@@ -602,52 +609,61 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
 
+        if self.duo_scaling is None:
+            # if self.duo_scaling is unset, perform half the grid search with
+            # duo_scaling off and half with duo_scaling on
+            n_grid = int(self.n_grid / 2)
+            duo_scalings = [False, True]
+        else:
+            n_grid = self.n_grid
+            duo_scalings = [self.duo_scaling]
         for ratio in range(n_grid):
-            # create new scales
-            ratio = ratio / n_grid
-
-            # NOTE: s^-1 * x is fused here, according to paper
-            if self.duo_scaling:
-                scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(
-                    min=1e-4
-                )
-            else:
-                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
-            scales = scales / (scales.max() * scales.min()).sqrt()
-            _scalesview = scales.view(1, -1).to(device)
-
-            # avoid scaling values that overflow
-            scales[torch.isinf(scales)] = 1
-            scales[torch.isnan(scales)] = 1
-
-            # Q(W * s)
-            for linear in linears2scale:
-                linear.weight.mul_(_scalesview)
-                update_offload_parameter(
-                    linear,
-                    "weight",
-                    _pseudo_quantize_tensor(
-                        w=linear.weight.data,
-                        symmetric=self._symmetric,
-                        bit_width=self._num_bits,
-                        group_size=self._group_size,
-                    )[0]
-                    / _scalesview,
-                )
+            # create new scales; computed once per ratio and shared across
+            # both duo_scaling modes
+            ratio = ratio / n_grid
+
+            for duo_scaling in duo_scalings:
+                # NOTE: s^-1 * x is fused here, according to paper
+                if duo_scaling:
+                    scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(
+                        min=1e-4
+                    )
+                else:
+                    scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
+                scales = scales / (scales.max() * scales.min()).sqrt()
+                _scalesview = scales.view(1, -1).to(device)
+
+                # avoid scaling values that overflow
+                scales[torch.isinf(scales)] = 1
+                scales[torch.isnan(scales)] = 1
+
+                # Q(W * s)
+                for linear in linears2scale:
+                    linear.weight.mul_(_scalesview)
+                    update_offload_parameter(
+                        linear,
+                        "weight",
+                        _pseudo_quantize_tensor(
+                            w=linear.weight.data,
+                            symmetric=self._symmetric,
+                            bit_width=self._num_bits,
+                            group_size=self._group_size,
+                        )[0]
+                        / _scalesview,
+                    )
 
-            # W * X
-            int_w_outputs = self._run_samples(parent_module)
+                # W * X
+                int_w_outputs = self._run_samples(parent_module)
 
-            # compute mean squared error (L2 norm)
-            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
+                # compute mean squared error (L2 norm)
+                loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
 
-            history.append(loss)
-            if loss < best_error:
-                best_error = loss
-                best_ratio = ratio
-                best_scales = scales.clone()
+                history.append(loss)
+                if loss < best_error:
+                    best_error = loss
+                    best_ratio = ratio
+                    best_scales = scales.clone()
 
-            parent_module.load_state_dict(org_sd, strict=False)
+                parent_module.load_state_dict(org_sd, strict=False)
 
         if best_ratio == -1:
             logger.debug(history)
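To make the grid search above concrete, here is a standalone sketch (not code from this diff) of how one candidate scale vector is formed; `x_mean` and `w_mean` stand in for the per-channel activation and weight means the modifier has already accumulated:

```python
import torch


def candidate_scales(
    x_mean: torch.Tensor, w_mean: torch.Tensor, ratio: float, duo_scaling: bool
) -> torch.Tensor:
    """Per-channel scales for one grid point, mirroring the formula in the diff."""
    if duo_scaling:
        # balance activation and weight magnitudes: s = x^r / (w^(1-r) + eps)
        scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
    else:
        # activation-only scaling: s = x^r
        scales = x_mean.pow(ratio).clamp(min=1e-4)
    # normalize so the largest and smallest channel scales are balanced around 1
    return scales / (scales.max() * scales.min()).sqrt()


# With duo_scaling=None and the default n_grid=20, the search tries 10 ratios,
# each in both modes, so the total number of quantize-and-evaluate trials stays at 20.
x_mean, w_mean = torch.rand(8) + 0.1, torch.rand(8) + 0.1
trials = [
    candidate_scales(x_mean, w_mean, r / 10, duo)
    for r in range(10)
    for duo in (False, True)
]
```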
@@ -667,8 +683,8 @@ def _compute_best_scale(
     @torch.no_grad()
     def _compute_loss(
         self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
+        fp16_outputs: list[torch.Tensor],
+        int_w_outputs: list[torch.Tensor],
         device: torch.device,
     ) -> torch.Tensor:
         loss = 0.0
@@ -746,8 +762,8 @@ def _pseudo_quantize_tensor(
 
 def _accumulate_mean(
     inp: torch.Tensor,
-    prev_mean_and_count: Optional[Tuple[torch.FloatTensor, int]],
-) -> Tuple[torch.FloatTensor, int]:
+    prev_mean_and_count: tuple[torch.FloatTensor, int] | None,
+) -> tuple[torch.FloatTensor, int]:
     sum_added = inp.sum(dim=0)
     num_added = inp.size(0)
     if prev_mean_and_count is None:
@@ -761,7 +777,7 @@ def _accumulate_mean(
     return (prev_sum + sum_added) / new_count, new_count
 
 
-def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
+def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
     """
     Given a list of names, returns the lowest-scope common parent.
 