
Commit 0f73dd8

issue/170: modified cpu.lua
1 parent 8948293 commit 0f73dd8

3 files changed: 12 additions, 10 deletions

src/infiniop/ops/matmul_gptq/cpu/matmul_gptq_cpu.cc
Lines changed: 8 additions & 5 deletions
@@ -240,13 +240,16 @@ void fasterquant(T *weight, T *Q, T *Err, T *b_scale, T *zero, float *Hess,
             float w = utils::cast<float>(weight[n * K + index * blocksize + i]);
             float err = (w - q) / d;
 
-            for (int j = i; j < blocksize; j++) {
-                if constexpr (std::is_same<T, fp16_t>::value) {
-                    weight[n * K + index * blocksize + j] = utils::cast<fp16_t>(utils::cast<float>(weight[n * K + index * blocksize + j]) - err * Hess[(index * blocksize + i) * K + j]);
-                } else if constexpr (std::is_same<T, float>::value) {
-                    weight[n * K + index * blocksize + j] -= err * Hess[(index * blocksize + i) * K + j];
+            if (group_size == -1) {
+                for (int j = i; j < blocksize; j++) {
+                    if constexpr (std::is_same<T, fp16_t>::value) {
+                        weight[n * K + index * blocksize + j] = utils::cast<fp16_t>(utils::cast<float>(weight[n * K + index * blocksize + j]) - err * Hess[(index * blocksize + i) * K + j]);
+                    } else if constexpr (std::is_same<T, float>::value) {
+                        weight[n * K + index * blocksize + j] -= err * Hess[(index * blocksize + i) * K + j];
+                    }
                 }
             }
+
             if constexpr (std::is_same<T, fp16_t>::value) {
                 Err[n * blocksize + i] = utils::cast<fp16_t>(err);
             } else if constexpr (std::is_same<T, float>::value) {
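The guarded loop is GPTQ's error-feedback step: once column i of a block is quantized, its rounding error err = (w - q) / d is folded, scaled by the corresponding Hessian row, into the columns that have not been quantized yet. The commit gates that in-block propagation on group_size == -1, i.e. per-channel scales; grouped quantization now skips it. Below is a minimal, self-contained sketch of the same step on a plain float row. The quantize helper, the row-major H layout, and the assumption that d equals the diagonal entry H[i][i] (as in standard GPTQ) are illustrative, not taken from this repo:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Hypothetical round-to-grid quantizer: scale s, zero point z, maxq levels.
static float quantize(float w, float s, float z, int maxq) {
    float q = std::round(w / s) + z;
    q = std::min(std::max(q, 0.0f), static_cast<float>(maxq));
    return s * (q - z);
}

// One GPTQ pass over a block of `blocksize` columns of a single row.
// H is blocksize x blocksize, row-major; H[i * blocksize + i] plays the
// role of `d` in the diff. group_size == -1 means per-channel scales,
// the only case where the in-block update still runs after this commit.
void quantize_block(float *w, const float *H, int blocksize,
                    float s, float z, int maxq, int group_size) {
    for (int i = 0; i < blocksize; ++i) {
        float d = H[i * blocksize + i];
        float q = quantize(w[i], s, z, maxq);
        float err = (w[i] - q) / d;
        w[i] = q; // the diff does this implicitly by starting its loop at j = i
        if (group_size == -1) {
            // Per-channel case: fold the error into the remaining columns.
            for (int j = i + 1; j < blocksize; ++j) {
                w[j] -= err * H[i * blocksize + j];
            }
        }
    }
}

int main() {
    float w[4] = {0.31f, -0.12f, 0.58f, -0.44f};
    float H[16] = {0}; // identity Hessian: no cross-column coupling
    for (int i = 0; i < 4; ++i) H[i * 4 + i] = 1.0f;
    quantize_block(w, H, 4, /*s=*/0.1f, /*z=*/8.0f, /*maxq=*/15, /*group_size=*/-1);
    for (int i = 0; i < 4; ++i) std::printf("%f\n", w[i]);
}

Note that the diff's inner loop starts at j = i, so its first update subtracts err * Hess[i][i] = w - q from column i itself, rewriting it to its quantized value; the sketch makes that assignment explicit and starts at j = i + 1.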

test/infiniop/matmul_gptq.py
Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def test(
     )
     torch.manual_seed(12)
     # Initialize tensors
-    a = 1e-3 * torch.randn([M, K], dtype=dtype).to(torch_device)
+    a = 1e0 * torch.randn([M, K], dtype=dtype).to(torch_device)
     layer = nn.Linear(K, N)
     b = 1e-3 * layer.weight.data.to(dtype).to(torch_device)
     c = torch.zeros([M, N], dtype=dtype).to(torch_device).t()

xmake/cpu.lua
Lines changed: 3 additions & 4 deletions
@@ -1,4 +1,4 @@
-add_requires("lapacke", {public = true})
+add_requires("lapack", {configs = {shared = true}})
 target("infiniop-cpu")
     set_kind("static")
     add_deps("infini-utils")
@@ -13,9 +13,8 @@ target("infiniop-cpu")
         end
     else
         add_cxflags("-fPIC")
-        add_includedirs("/usr/include")
-        add_linkdirs("/usr/lib64")
-        add_links("lapacke", "lapack", "blas")
+        add_packages("lapack")
+        add_links("lapacke", "lapack", "blas", "gfortran")
         if has_config("omp") then
             add_cxflags("-fopenmp")
             add_ldflags("-fopenmp")
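This change replaces the hard-coded /usr/include and /usr/lib64 paths with the xmake lapack package, and adds gfortran to the link line, which the reference (Fortran) LAPACK and BLAS builds need. As a quick sanity check that the new link line resolves, here is a standalone program calling LAPACKE's Cholesky factorization, the kind of routine a GPTQ Hessian pass leans on; whether this repo calls exactly this function is an assumption:

#include <cstdio>
#include <lapacke.h>

int main() {
    // Small symmetric positive-definite matrix, row-major.
    float a[9] = {4.0f, 2.0f, 2.0f,
                  2.0f, 5.0f, 1.0f,
                  2.0f, 1.0f, 6.0f};
    // Factor A = U^T U in place (upper triangle of `a` is overwritten).
    lapack_int info = LAPACKE_spotrf(LAPACK_ROW_MAJOR, 'U', 3, a, 3);
    std::printf("spotrf info = %d (0 means success)\n", (int)info);
    return info == 0 ? 0 : 1;
}

Built outside xmake this would be roughly g++ check.cc -llapacke -llapack -lblas -lgfortran, mirroring the add_links list above.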
