
Commit 25bd87a

fix zero points initialize

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

1 parent: 24f6104

1 file changed (+12, -11 lines)

tests/test_compressors/model_compressors/test_model_compressor.py

@@ -123,11 +123,11 @@ def __init__(self, weights, weight_scale=None, weight_zero_point=None):
         # Attach weight_scale and weight_zero_point as parameters
         if weight_scale is not None:
             self.linear.weight_scale = nn.Parameter(
-                torch.tensor(weight_scale), requires_grad=False
+                weight_scale.detach().clone(), requires_grad=False
             )
         if weight_zero_point is not None:
             self.linear.weight_zero_point = nn.Parameter(
-                torch.tensor(weight_zero_point), requires_grad=False
+                weight_zero_point.detach().clone(), requires_grad=False
             )
 
     def forward(self, x):
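
Note on the hunk above: weight_scale and weight_zero_point are already tensors here, and copy-constructing a tensor with torch.tensor() makes PyTorch emit a UserWarning recommending sourceTensor.clone().detach() instead. A minimal standalone sketch of the idiom (the scale value is a stand-in, not the test's actual data):

import torch
import torch.nn as nn

scale = torch.rand(1) * 0.1  # stand-in for a scale computed elsewhere

# Copy-constructing from an existing tensor triggers:
#   UserWarning: To copy construct from a tensor, it is recommended to use
#   sourceTensor.clone().detach() ...
# bad = nn.Parameter(torch.tensor(scale), requires_grad=False)

# Recommended: detach from any autograd graph, then take an owning copy
good = nn.Parameter(scale.detach().clone(), requires_grad=False)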
@@ -176,23 +176,21 @@ def create_quantization_config(bits=8, type="int", strategy="tensor"):
     ],
 )
 def test_composability(sparsity_config, quantization_config):
-    model_compressor = ModelCompressor(
-        sparsity_config=sparsity_config, quantization_config=quantization_config
-    )
+    model_compressor = ModelCompressor(sparsity_config, quantization_config)
     model: DummyLinearModel = _get_fake_oneshot_sparse_quantized_model(
-        sparsity_config=sparsity_config, quantization_config=quantization_config
+        quantization_config,
+        sparsity_config,
     )
     model = model.to(torch.float32)
 
     # does both sparse and quantization compression
+    original_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
     model_compressor.compress_model(model)
-    compressed_state_dict = {key: value.clone() for key, value in model.state_dict()}
-
     model_compressor.decompress_model(model)
-    decompressed_state_dict = {key: value.clone() for key, value in model.state_dict()}
+    decompressed_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
 
     # check that the decompressed model is the same as the original model
-    _check_state_dicts(compressed_state_dict, decompressed_state_dict)
+    _check_state_dicts(original_state_dict, decompressed_state_dict)
 
 
 class TwoLayerModel(nn.Module):
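
The restructure above also changes what the test asserts: instead of comparing the compressed state dict against the decompressed one, it snapshots the original state dict before compress_model() and checks that compress followed by decompress round-trips back to it, which is the property composability actually needs. A hedged sketch of such a comparison (check_state_dicts below is a hypothetical stand-in, not the suite's _check_state_dicts):

import torch

def check_state_dicts(expected: dict, actual: dict) -> None:
    # Hypothetical helper: same keys, and tensors numerically close
    assert expected.keys() == actual.keys(), "state dict keys differ"
    for name, tensor in expected.items():
        torch.testing.assert_close(actual[name], tensor, rtol=1e-3, atol=1e-3)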
@@ -252,6 +250,9 @@ def _get_fake_oneshot_sparse_quantized_model(quantization_config, sparsity_config
         args=quantization_args,
     )
 
+    if quantization_args.symmetric:
+        zero_point = None  # do not include in model
+
     fake_oneshot_model = DummyLinearModel(quantized_weights, scale, zero_point)
     fake_oneshot_model.linear.quantization_scheme = quantization_config.config_groups[
         "group_0"
@@ -306,7 +307,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
     )
     # Only stores dtype because meta model does not store values
     reference_compressor.compress_model(cpu_model)
-    expected = {k: v.dtype for k, v in cpu_model.state_dict()}
+    expected = {k: v.dtype for k, v in cpu_model.state_dict().items()}
 
     # Load model on meta device
     meta_model = AutoModelForCausalLM.from_pretrained(
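
The final hunk fixes the same iteration bug that test_composability had: Module.state_dict() returns an OrderedDict, and iterating a dict yields keys only, so `for k, v in cpu_model.state_dict()` tries to unpack each parameter name (a string) into two variables and raises ValueError. A tiny demonstration:

import torch.nn as nn

state = nn.Linear(4, 4).state_dict()

# {k: v.dtype for k, v in state}  # ValueError: too many values to unpack (expected 2)
expected = {k: v.dtype for k, v in state.items()}  # correct
print(expected)  # {'weight': torch.float32, 'bias': torch.float32}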
