Commit c909d0e

issue/170: debug marlin
1 parent 8c2b7ec commit c909d0e

2 files changed: +57 −44 lines changed

src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu

Lines changed: 2 additions & 2 deletions
@@ -71,14 +71,14 @@ infiniStatus_t Descriptor::calculate(
     int group_size = int(_info.group_size);
     int num_groups = int(_info.num_groups);
     bool is_weight_transposed = _info.is_weight_transposed;
-    if (_info.atype == INFINI_DTYPE_F16 && !is_weight_transposed) {
+    if (_info.atype == INFINI_DTYPE_F16 && is_weight_transposed) {
         gptq_marlin::gptq_marlin_mm_fp16(c, a, packed_weights, b_scale,
                                          m, n, k,
                                          workspace, bits,
                                          num_groups, group_size,
                                          this->device_id, (cudaStream_t)stream);

-    } else if (_info.atype == INFINI_DTYPE_BF16 && !is_weight_transposed) {
+    } else if (_info.atype == INFINI_DTYPE_BF16 && is_weight_transposed) {
         gptq_marlin::gptq_marlin_mm_bf16(c, a, packed_weights, b_scale,
                                          m, n, k,
                                          workspace, bits,
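Note on this hunk: the two conditions were inverted, so the F16 and BF16 Marlin paths (gptq_marlin_mm_fp16 / gptq_marlin_mm_bf16) are now taken only when _info.is_weight_transposed is true. The test changes below mirror this by switching to the transposed layout on non-CPU devices. A minimal sketch of the selection logic, using hypothetical string dtype tags and a hypothetical helper name rather than the operator's real enums:

    # Sketch only: mirrors the post-commit dispatch in quantize_gptq_cuda.cu.
    # F16/BF16 activations hit the Marlin GEMM kernels only with a transposed weight.
    def uses_marlin_path(atype: str, is_weight_transposed: bool) -> bool:
        return atype in ("f16", "bf16") and is_weight_transposed

    assert uses_marlin_path("f16", True) and not uses_marlin_path("bf16", False)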

test/infiniop/quantize_gptq.py

Lines changed: 55 additions & 42 deletions
@@ -23,7 +23,6 @@
 # Configuration (Internal Use Only)
 # ==============================================================================
 # These are not meant to be imported from other modules
-is_weight_transposed = False

 _TEST_CASES = []

@@ -38,7 +37,7 @@
 for _, layers in MODELS.items():
     for layer in layers:
         for batch in [1, 16]:
-            _TEST_CASES.append(((batch, layer[0], layer[1], is_weight_transposed)))
+            _TEST_CASES.append(((batch, layer[0], layer[1])))

 # Data types used for testing
 _TENSOR_DTYPES = [torch.float16]
@@ -324,14 +323,14 @@ def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1):
         self.zero = zero.to(self.weight.dtype)


-def get_scale_zero(b, a, c, group_size, bits, sign_ed):
+def get_scale_zero(b, a, c, group_size, bits, sym, sign_ed):
     weight = b.clone()
     inp = a.clone()
     out = c.clone()
     gptq = GPTQ(weight)
     gptq.quantizer = Quantizer()
     gptq.quantizer.configure(
-        bits=bits, perchannel=True, sym=False, mse=False, sign_ed=sign_ed
+        bits=bits, perchannel=True, sym=sym, mse=False, sign_ed=sign_ed
     )
     gptq.add_batch(inp, out)
     gptq.fasterquant(group_size=group_size)
@@ -370,7 +369,6 @@ def test(
     M,
     K,
     N,
-    is_weight_transposed,
     dtype=torch.float16,
     sync=None,
 ):
@@ -381,8 +379,13 @@ def test(
     # Initialize tensors
     a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device)
     layer = nn.Linear(K, N)
-    b = 1e0 * layer.weight.data.to(dtype).to(torch_device)
+    b = 1e-3 * layer.weight.data.to(dtype).to(torch_device)
     c = torch.zeros([N, M], dtype=dtype).to(torch_device)
+    is_weight_transposed = False
+    sign_ed = False
+    sym = False
+    if torch_device != "cpu":
+        is_weight_transposed = True

     group_size = -1
     num_groups = 1
@@ -397,37 +400,45 @@ def test(
     packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device)
     s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
     z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
-    sign_ed = False
+
     bits = 4
     maxq = 2**bits - 1
     minq = 0
     if sign_ed:  # signed quantization, range is [-8, 7]
         maxq = 2 ** (bits - 1) - 1
         minq = -(2 ** (bits - 1))
-    sym = False

     if torch_device == "cuda":
         b_ref, s, z = get_scale_zero(
-            b, a.t(), c, group_size, bits, sign_ed=sign_ed
+            b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed
         )  # unsigned quantization

         packed_weights = pack(b_ref, s, z, minq, maxq)

-    if torch_device == "cpu":
-        b_ref, s, z = get_scale_zero(
-            b, a.t(), c, group_size, bits, sign_ed=sign_ed
-        )  # unsigned quantization
-
-        packed_weights = pack(b_ref, s, z, minq, maxq)
+    # if torch_device == "cpu":
+    #     b_ref, s, z = get_scale_zero(
+    #         b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed
+    #     )  # unsigned quantization

-    a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
-        to_tensor(a, lib),
-        to_tensor(b, lib),
-        to_tensor(c, lib),
-        to_tensor(s, lib),
-        to_tensor(z, lib),
-        to_tensor(packed_weights, lib),
-    )
+    #     packed_weights = pack(b_ref, s, z, minq, maxq)
+    if is_weight_transposed:
+        a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
+            to_tensor(a.t(), lib),
+            to_tensor(b.t(), lib),
+            to_tensor(c.t(), lib),
+            to_tensor(s.t(), lib),
+            to_tensor(z.t(), lib),
+            to_tensor(packed_weights.t(), lib),
+        )
+    else:
+        a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
+            to_tensor(a, lib),
+            to_tensor(b, lib),
+            to_tensor(c, lib),
+            to_tensor(s, lib),
+            to_tensor(z, lib),
+            to_tensor(packed_weights, lib),
+        )

     descriptor = infiniopQuantizeGPTQDescriptor_t()
     check_error(
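For context on the constants and shapes used in this hunk: with bits = 4 the unsigned code range is [minq, maxq] = [0, 15] and the signed range is [-8, 7], and packed_weights is allocated as [N, K // 8] int32 because eight 4-bit codes fit in one 32-bit word. The test's real get_scale_zero and pack helpers are defined earlier in quantize_gptq.py and are not part of this diff; the snippet below is only an illustrative sketch of per-row asymmetric 4-bit quantization and nibble packing under those shape assumptions (both function names are hypothetical):

    import torch

    def quantize_rows_sketch(w: torch.Tensor, bits: int = 4):
        # Per-output-channel asymmetric quantization to unsigned codes in [0, 2**bits - 1].
        maxq = 2**bits - 1
        w_min = w.min(dim=1, keepdim=True).values.clamp(max=0.0)
        w_max = w.max(dim=1, keepdim=True).values.clamp(min=0.0)
        scale = (w_max - w_min) / maxq
        scale[scale == 0] = 1.0  # guard rows that are all zeros
        zero = torch.round(-w_min / scale)
        q = torch.clamp(torch.round(w / scale) + zero, 0, maxq).to(torch.int32)
        return q, scale, zero

    def pack_nibbles_sketch(q: torch.Tensor) -> torch.Tensor:
        # Pack eight 4-bit codes per int32 along K: [N, K] -> [N, K // 8],
        # the same shape the test allocates for packed_weights.
        n, k = q.shape
        shifts = torch.arange(8, dtype=torch.int32, device=q.device) * 4
        return (q.reshape(n, k // 8, 8) << shifts).sum(dim=-1, dtype=torch.int32)

Dequantization for a reference check would then recover roughly (q - zero) * scale per row; the actual GPTQ scale/zero search is what get_scale_zero delegates to the GPTQ and Quantizer classes shown earlier in the diff.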
@@ -463,19 +474,19 @@ def test(
     workspace = create_workspace(workspace_size.value, a.device)

     # Execute infiniop quantize_gptq operator
-    # check_error(
-    #     lib.infiniopQuantizeGPTQ(
-    #         descriptor,
-    #         workspace.data_ptr() if workspace is not None else None,
-    #         workspace_size.value,
-    #         packed_weights_tensor.data,
-    #         s_tensor.data,
-    #         z_tensor.data,
-    #         a_tensor.data,
-    #         b_tensor.data,
-    #         None,
-    #     )
-    # )
+    check_error(
+        lib.infiniopQuantizeGPTQ(
+            descriptor,
+            workspace.data_ptr() if workspace is not None else None,
+            workspace_size.value,
+            packed_weights_tensor.data,
+            s_tensor.data,
+            z_tensor.data,
+            a_tensor.data,
+            b_tensor.data,
+            None,
+        )
+    )

     def lib_quantize_gptq():
         check_error(
@@ -495,13 +506,15 @@ def lib_quantize_gptq():
     lib_quantize_gptq()

     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
-    tmpa = ans.flatten()
-    tmpc = c.flatten()
-    for i in range(tmpa.shape[0]):
-        if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]):
-            print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i]))
-            break
+    # tmpa = ans.flatten()
+    # tmpc = c.flatten()
+    # for i in range(tmpa.shape[0]):
+    #     if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]):
+    #         print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i]))
+    #         break

+    if is_weight_transposed:
+        c = c.t()
     if DEBUG:
         debug(c, ans, atol=atol, rtol=rtol)
     assert torch.allclose(c, ans, atol=atol, rtol=rtol)
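If the element-wise probe commented out above needs to come back while debugging, it can be restored as below; the bound is the same atol + rtol * |ans| criterion that torch.allclose(c, ans, atol=atol, rtol=rtol) enforces, and with the transposed layout c has to be transposed back first (as this hunk already does) so the two tensors line up:

    # Sketch of the disabled probe: report the first element where
    # |ans - c| exceeds atol + rtol * |ans|, the bound torch.allclose uses.
    tmpa = ans.flatten()
    tmpc = c.flatten()
    for i in range(tmpa.shape[0]):
        if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]):
            print(i, tmpa[i].item(), tmpc[i].item())
            break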
