Skip to content

Commit 4a4fad1

Browse files
committed
issue/170: modified cpu inv
1 parent a1e7195 commit 4a4fad1

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -256,36 +256,36 @@ bool cholesky_decompose(float *A, int n, bool upper) {
256256
return true;
257257
}
258258

259-
// Compute A^{-1} from Cholesky(L)
260-
void invert_symmetric_from_cholesky(float *L, int n, float *invA, float *temp) {
259+
// Compute A^{-1} from Cholesky decomposition (A = L L^T)
260+
// A: lower-triangular Cholesky factor L (n x n, row-major); note this is the factor, not the matrix being inverted
261+
// invA: output inverse matrix (n x n), symmetric
262+
// temp_row: scratch buffer of n * n floats; the solve for column col uses its own n-float slice at temp_row + col * n
263+
void invert_symmetric_from_cholesky(const float *A, int n, float *invA, float *temp_row) {
261264
#pragma omp parallel for
262265
for (int col = 0; col < n; ++col) {
263-
float *row_buf = temp + col * n;
266+
float *row_buf = temp_row + col * n;
264267

265-
// Forward substitution: solve L * y = e_col
268+
// Forward solve: L y = e_col
266269
for (int i = 0; i < n; ++i) {
267270
float sum = (i == col) ? 1.0f : 0.0f;
268-
for (int k = 0; k < i; ++k) {
269-
sum -= L[i * n + k] * row_buf[k];
271+
if (i > 0) {
272+
sum -= dot_product(&A[i * n], row_buf, i);
270273
}
271-
row_buf[i] = sum / L[i * n + i];
274+
row_buf[i] = sum / A[i * n + i];
272275
}
273276

274-
// Backward substitution: solve L^T * x = y
277+
// Backward solve: L^T x = y
275278
for (int i = n - 1; i >= 0; --i) {
276279
float sum = row_buf[i];
277-
for (int k = i + 1; k < n; ++k) {
278-
sum -= L[k * n + i] * invA[k * n + col];
280+
for (int j = i + 1; j < n; ++j) {
281+
sum -= A[j * n + i] * invA[j * n + col];
279282
}
280-
invA[i * n + col] = sum / L[i * n + i];
283+
invA[i * n + col] = sum / A[i * n + i];
281284
}
282-
}
283285

284-
// Fill upper triangle using symmetry: invA[i][j] = invA[j][i]
285-
#pragma omp parallel for collapse(2)
286-
for (int i = 0; i < n; ++i) {
287-
for (int j = i + 1; j < n; ++j) {
288-
invA[i * n + j] = invA[j * n + i];
286+
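// Exploit symmetry: mirror this column's above-diagonal entries into row col (invA[col][row] = invA[row][col]). NOTE(review): the targets lie in columns that other iterations of the parallel loop also write; confirm this is race-free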
287+
for (int row = 0; row < col; ++row) {
288+
invA[col * n + row] = invA[row * n + col];
289289
}
290290
}
291291
}

0 commit comments

Comments
 (0)