File tree Expand file tree Collapse file tree 3 files changed +36
-12
lines changed
modifiers/quantization/gptq
tests/llmcompressor/modifiers/quantization Expand file tree Collapse file tree 3 files changed +36
-12
lines changed Original file line number Diff line number Diff line change @@ -120,17 +120,6 @@ class GPTQModifier(Modifier, QuantizationMixin):
120120 _hessians : Dict [torch .nn .Module , torch .Tensor ] = PrivateAttr (default_factory = dict )
121121 _num_samples : Dict [torch .nn .Module , int ] = PrivateAttr (default_factory = dict )
122122
123- @field_validator ("sequential_update" , mode = "before" )
124- def validate_sequential_update (cls , value : bool ) -> bool :
125- if not value :
126- warnings .warn (
127- "`sequential_update=False` is no longer supported, setting "
128- "sequential_update=True" ,
129- DeprecationWarning ,
130- )
131-
132- return True
133-
134123 def resolve_quantization_config (self ) -> QuantizationConfig :
135124 config = super ().resolve_quantization_config ()
136125
@@ -317,3 +306,14 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
317306 if self .offload_hessians :
318307 if module in self ._hessians : # may have been deleted in context
319308 self ._hessians [module ] = self ._hessians [module ].to (device = "cpu" )
309+
310+ @field_validator ("sequential_update" , mode = "before" )
311+ def validate_sequential_update (cls , value : bool ) -> bool :
312+ if not value :
313+ warnings .warn (
314+ "`sequential_update=False` is no longer supported, setting "
315+ "sequential_update=True" ,
316+ DeprecationWarning ,
317+ )
318+
319+ return True
Original file line number Diff line number Diff line change @@ -45,7 +45,13 @@ def __reduce__(self):
4545
4646 @classmethod
4747 def __get_pydantic_core_schema__ (cls , _source_type , _handler ):
48- return core_schema .no_info_plain_validator_function (cls .validate )
48+ return core_schema .no_info_after_validator_function (
49+ cls .validate ,
50+ schema = core_schema .str_schema (),
51+ serialization = core_schema .plain_serializer_function_ser_schema (
52+ lambda v : str (v )
53+ ),
54+ )
4955
5056 @classmethod
5157 def validate (cls , value : "Sentinel" ) -> "Sentinel" :
Original file line number Diff line number Diff line change @@ -142,3 +142,21 @@ def test_config_resolution(strategies, actorder):
142142 for config_group in modifier .config_groups .values ():
143143 if config_group .weights .strategy == "group" :
144144 assert config_group .weights .actorder == actorder
145+
146+
147+ @pytest .mark .parametrize (
148+ "has_actorder,actorder,exp_actorder" ,
149+ [
150+ (False , "N/A" , "static" ),
151+ (True , None , None ),
152+ (True , "static" , "static" ),
153+ (True , "group" , "group" ),
154+ ],
155+ )
156+ def test_serialize_actorder (has_actorder , actorder , exp_actorder ):
157+ if has_actorder :
158+ modifier = GPTQModifier (targets = ["Linear" ], actorder = actorder )
159+ else :
160+ modifier = GPTQModifier (targets = ["Linear" ])
161+
162+ assert modifier .model_dump ()["actorder" ] == exp_actorder
You can’t perform that action at this time.
0 commit comments