Commit 8c2b7ec

issue/170: modified pack py
1 parent 096233a commit 8c2b7ec

File tree

1 file changed: +48 −29 lines changed


test/infiniop/quantize_gptq.py

Lines changed: 48 additions & 29 deletions
@@ -317,20 +317,22 @@ def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1):

             W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

-        # print('error', torch.sum(Losses).item())
+        print("error", torch.sum(Losses).item())

         self.weight = Q.reshape(self.weight.shape).to(self.weight.dtype)
         self.scale = scale.to(self.weight.dtype)
         self.zero = zero.to(self.weight.dtype)


-def get_scale_zero(b, a, c, group_size, sign_ed):
+def get_scale_zero(b, a, c, group_size, bits, sign_ed):
     weight = b.clone()
     inp = a.clone()
     out = c.clone()
     gptq = GPTQ(weight)
     gptq.quantizer = Quantizer()
-    gptq.quantizer.configure(perchannel=True, sym=False, mse=False, signed=sign_ed)
+    gptq.quantizer.configure(
+        bits=bits, perchannel=True, sym=False, mse=False, sign_ed=sign_ed
+    )
     gptq.add_batch(inp, out)
     gptq.fasterquant(group_size=group_size)

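The configure() call now threads the bit width and the repo's sign_ed flag into the quantizer. In upstream GPTQ, configure()/find_params() drive a per-channel min/max search for scale and zero-point; below is a minimal sketch of the unsigned asymmetric case, assuming this repo's Quantizer follows the upstream behavior (find_params_ref is an illustrative name, not part of this file):

    import torch

    def find_params_ref(w, bits=4):
        # per-channel (per-output-row) asymmetric search, GPTQ-style:
        # the scale spans the observed range, the zero-point maps xmin to code 0
        maxq = 2**bits - 1
        xmin = torch.clamp(w.min(dim=1, keepdim=True).values, max=0.0)
        xmax = torch.clamp(w.max(dim=1, keepdim=True).values, min=0.0)
        scale = (xmax - xmin) / maxq
        scale = torch.clamp(scale, min=torch.finfo(w.dtype).eps)  # guard all-zero rows
        zero = torch.round(-xmin / scale)
        return scale, zero

    scale, zero = find_params_ref(torch.randn(8, 64))  # one (scale, zero) pair per row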
@@ -341,8 +343,10 @@ def get_scale_zero(b, a, c, group_size, sign_ed):
     )


-def pack(weight, scale, zero):
-    intweight = torch.round((weight + zero) / scale).to(torch.int32)
+def pack(weight, scale, zero, minq, maxq):
+    intweight = torch.clamp(torch.round(weight / scale + zero), minq, maxq).to(
+        torch.int32
+    )
     qweight = torch.zeros(
         [weight.shape[0], weight.shape[1] // 8], dtype=torch.int32, device=weight.device
     )
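The rewritten pack() uses the standard asymmetric mapping q = clamp(round(w / s + z), minq, maxq); the old version folded the zero-point into the division, which double-scales z, and never clamped to the representable range. The [N, K // 8] shape of qweight implies eight 4-bit codes per int32 word. A self-contained sketch of the round trip (the helper names and the nibble order are assumptions, not taken from the file):

    import torch

    def quantize_ref(w, scale, zero, minq, maxq):
        # same formula as the new intweight: q = clamp(round(w / s + z), minq, maxq)
        return torch.clamp(torch.round(w / scale + zero), minq, maxq).to(torch.int32)

    def dequantize_ref(q, scale, zero):
        # inverse mapping: w ≈ (q - z) * s
        return (q.to(scale.dtype) - zero) * scale

    def pack_int4_ref(q):
        # eight 4-bit codes per int32 word, matching qweight's [rows, cols // 8] shape;
        # the nibble order is assumed, and the top nibble lands in the sign bits,
        # which is harmless for pure bit storage
        rows, cols = q.shape
        packed = torch.zeros(rows, cols // 8, dtype=torch.int32)
        for j in range(8):
            packed |= (q[:, j::8] & 0xF) << (4 * j)
        return packed

    w = torch.randn(4, 16)
    scale = (w.amax(dim=1, keepdim=True) - w.amin(dim=1, keepdim=True)) / 15
    zero = torch.round(-w.amin(dim=1, keepdim=True) / scale)
    q = quantize_ref(w, scale, zero, minq=0, maxq=15)
    print((dequantize_ref(q, scale, zero) - w).abs().max())  # bounded by ~scale / 2
    packed = pack_int4_ref(q)  # shape [4, 2]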
@@ -377,7 +381,7 @@ def test(
     # Initialize tensors
     a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device)
     layer = nn.Linear(K, N)
-    b = 1e-3 * layer.weight.data.to(dtype).to(torch_device)
+    b = 1e0 * layer.weight.data.to(dtype).to(torch_device)
     c = torch.zeros([N, M], dtype=dtype).to(torch_device)

     group_size = -1
@@ -393,13 +397,28 @@ def test(
     packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device)
     s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
     z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
+    sign_ed = False
+    bits = 4
+    maxq = 2**bits - 1
+    minq = 0
+    if sign_ed:  # signed quantization: range is [-8, 7]
+        maxq = 2 ** (bits - 1) - 1
+        minq = -(2 ** (bits - 1))
+        sym = False
+
     if torch_device == "cuda":
         b_ref, s, z = get_scale_zero(
-            b, a.t(), c, group_size, signed=False
+            b, a.t(), c, group_size, bits, sign_ed=sign_ed
         )  # unsigned quantization
-        z = torch.zeros_like(s)
-        packed_weights = pack(b_ref, s, z)
-        # print(s)
+
+        packed_weights = pack(b_ref, s, z, minq, maxq)
+
+    if torch_device == "cpu":
+        b_ref, s, z = get_scale_zero(
+            b, a.t(), c, group_size, bits, sign_ed=sign_ed
+        )  # unsigned quantization
+
+        packed_weights = pack(b_ref, s, z, minq, maxq)

     a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
         to_tensor(a, lib),
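The minq/maxq setup in this hunk is the usual two's-complement split: `bits` bits give [0, 2**bits − 1] unsigned or [−2**(bits − 1), 2**(bits − 1) − 1] signed. A quick standalone check (int_range is illustrative, not in the test file):

    def int_range(bits, sign_ed=False):
        # representable code range for a `bits`-wide integer
        if sign_ed:
            return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
        return 0, 2**bits - 1

    assert int_range(4) == (0, 15)
    assert int_range(4, sign_ed=True) == (-8, 7)  # the [-8, 7] from the diff's comment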
@@ -444,19 +463,19 @@ def test(
     workspace = create_workspace(workspace_size.value, a.device)

     # Execute infiniop quantize_gptq operator
-    check_error(
-        lib.infiniopQuantizeGPTQ(
-            descriptor,
-            workspace.data_ptr() if workspace is not None else None,
-            workspace_size.value,
-            packed_weights_tensor.data,
-            s_tensor.data,
-            z_tensor.data,
-            a_tensor.data,
-            b_tensor.data,
-            None,
-        )
-    )
+    # check_error(
+    #     lib.infiniopQuantizeGPTQ(
+    #         descriptor,
+    #         workspace.data_ptr() if workspace is not None else None,
+    #         workspace_size.value,
+    #         packed_weights_tensor.data,
+    #         s_tensor.data,
+    #         z_tensor.data,
+    #         a_tensor.data,
+    #         b_tensor.data,
+    #         None,
+    #     )
+    # )

     def lib_quantize_gptq():
         check_error(
@@ -476,12 +495,12 @@ def lib_quantize_gptq():
     lib_quantize_gptq()

     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
-    # tmpa = ans.flatten()
-    # tmpc = c.flatten()
-    # for i in range(tmpa.shape[0]):
-    #     if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]):
-    #         print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i]))
-    #         break
+    tmpa = ans.flatten()
+    tmpc = c.flatten()
+    for i in range(tmpa.shape[0]):
+        if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]):
+            print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i]))
+            break

     if DEBUG:
         debug(c, ans, atol=atol, rtol=rtol)
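The re-enabled loop uses the standard mixed tolerance |ans − c| > atol + rtol·|ans|, essentially the complement of what torch.isclose(c, ans, rtol=rtol, atol=atol) computes with ans as the reference. A vectorized equivalent that finds the first failing element without a Python loop (illustrative only; it assumes the surrounding ans, c, atol, rtol from the test):

    diff = (ans - c).abs().flatten()
    bound = atol + rtol * ans.abs().flatten()
    bad = (diff > bound).nonzero()
    if bad.numel() > 0:
        i = int(bad[0])  # index of the first out-of-tolerance element
        print(ans.flatten()[i], c.flatten()[i], diff[i], bound[i])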
