@@ -123,11 +123,11 @@ def __init__(self, weights, weight_scale=None, weight_zero_point=None):
123123 # Attach weight_scale and weight_zero_point as parameters
124124 if weight_scale is not None :
125125 self .linear .weight_scale = nn .Parameter (
126- torch . tensor ( weight_scale ), requires_grad = False
126+ weight_scale . detach (). clone ( ), requires_grad = False
127127 )
128128 if weight_zero_point is not None :
129129 self .linear .weight_zero_point = nn .Parameter (
130- torch . tensor ( weight_zero_point ), requires_grad = False
130+ weight_zero_point . detach (). clone ( ), requires_grad = False
131131 )
132132
133133 def forward (self , x ):
@@ -176,23 +176,21 @@ def create_quantization_config(bits=8, type="int", strategy="tensor"):
176176 ],
177177)
178178def test_composability (sparsity_config , quantization_config ):
179- model_compressor = ModelCompressor (
180- sparsity_config = sparsity_config , quantization_config = quantization_config
181- )
179+ model_compressor = ModelCompressor (sparsity_config , quantization_config )
182180 model : DummyLinearModel = _get_fake_oneshot_sparse_quantized_model (
183- sparsity_config = sparsity_config , quantization_config = quantization_config
181+ quantization_config ,
182+ sparsity_config ,
184183 )
185184 model = model .to (torch .float32 )
186185
187186 # does both sparse and quantization compression
187+ original_state_dict = {k : v .clone () for k , v in model .state_dict ().items ()}
188188 model_compressor .compress_model (model )
189- compressed_state_dict = {key : value .clone () for key , value in model .state_dict ()}
190-
191189 model_compressor .decompress_model (model )
192- decompressed_state_dict = {key : value .clone () for key , value in model .state_dict ()}
190+ decompressed_state_dict = {k : v .clone () for k , v in model .state_dict (). items ()}
193191
194192 # check that the decompressed model is the same as the original model
195- _check_state_dicts (compressed_state_dict , decompressed_state_dict )
193+ _check_state_dicts (original_state_dict , decompressed_state_dict )
196194
197195
198196class TwoLayerModel (nn .Module ):
@@ -252,6 +250,9 @@ def _get_fake_oneshot_sparse_quantized_model(quantization_config, sparsity_confi
252250 args = quantization_args ,
253251 )
254252
253+ if quantization_args .symmetric :
254+ zero_point = None # do not include in model
255+
255256 fake_oneshot_model = DummyLinearModel (quantized_weights , scale , zero_point )
256257 fake_oneshot_model .linear .quantization_scheme = quantization_config .config_groups [
257258 "group_0"
@@ -306,7 +307,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
306307 )
307308 # Only stores dtype because meta model does not store values
308309 reference_compressor .compress_model (cpu_model )
309- expected = {k : v .dtype for k , v in cpu_model .state_dict ()}
310+ expected = {k : v .dtype for k , v in cpu_model .state_dict (). items () }
310311
311312 # Load model on meta device
312313 meta_model = AutoModelForCausalLM .from_pretrained (
0 commit comments