From a5d192446099796e33bfa88bccaee0b11aab5bca Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Fri, 21 Mar 2025 09:59:36 +0800 Subject: [PATCH 1/8] issue/170: quantize_gptq --- include/infinicore.h | 1 + include/infiniop.h | 1 + include/infiniop/ops/quantize_gptq.h | 40 + .../quantize_gptq/cpu/quantize_gptq_cpu.cc | 612 ++++++ .../ops/quantize_gptq/cpu/quantize_gptq_cpu.h | 8 + .../ops/quantize_gptq/cuda/gptq_marlin.cu | 1845 +++++++++++++++++ .../ops/quantize_gptq/cuda/gptq_marlin.cuh | 32 + .../quantize_gptq/cuda/quantize_gptq_cuda.cu | 94 + .../quantize_gptq/cuda/quantize_gptq_cuda.cuh | 8 + src/infiniop/ops/quantize_gptq/operator.cc | 126 ++ .../ops/quantize_gptq/quantize_gptq.h | 156 ++ test/infiniop/quantize_gptq.py | 554 +++++ xmake/cpu.lua | 7 + 13 files changed, 3484 insertions(+) create mode 100644 include/infiniop/ops/quantize_gptq.h create mode 100644 src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc create mode 100644 src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.h create mode 100644 src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cu create mode 100644 src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cuh create mode 100644 src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu create mode 100644 src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cuh create mode 100644 src/infiniop/ops/quantize_gptq/operator.cc create mode 100644 src/infiniop/ops/quantize_gptq/quantize_gptq.h create mode 100644 test/infiniop/quantize_gptq.py diff --git a/include/infinicore.h b/include/infinicore.h index a74af91d2..3260bd78c 100644 --- a/include/infinicore.h +++ b/include/infinicore.h @@ -70,6 +70,7 @@ typedef enum { INFINI_DTYPE_C64 = 17, INFINI_DTYPE_C128 = 18, INFINI_DTYPE_BF16 = 19, + INFINI_DTYPE_I4 = 20, } infiniDtype_t; #endif // __INFINICORE_API_H__ diff --git a/include/infiniop.h b/include/infiniop.h index d51b8d92e..fc57b269a 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -9,6 +9,7 @@ #include "infiniop/ops/conv.h" #include "infiniop/ops/gemm.h" #include "infiniop/ops/mul.h" +#include "infiniop/ops/quantize_gptq.h" #include "infiniop/ops/random_sample.h" #include "infiniop/ops/rearrange.h" #include "infiniop/ops/relu.h" diff --git a/include/infiniop/ops/quantize_gptq.h b/include/infiniop/ops/quantize_gptq.h new file mode 100644 index 000000000..ef4d5921b --- /dev/null +++ b/include/infiniop/ops/quantize_gptq.h @@ -0,0 +1,40 @@ +#ifndef __INFINIOP_QUANTIZE_GPTQ_API_H__ +#define __INFINIOP_QUANTIZE_GPTQ_API_H__ + +#include "../operator_descriptor.h" + +typedef InfiniopDescriptor *infiniopQuantizeGPTQDescriptor_t; + +__C __export infiniStatus_t infiniopCreateQuantizeGPTQDescriptor(infiniopHandle_t handle, + infiniopQuantizeGPTQDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t packed_weights_desc, + infiniopTensorDescriptor_t b_scale_desc, + infiniopTensorDescriptor_t zero_desc); + +__C __export infiniStatus_t infiniopGetQuantizeGPTQWorkspaceSize(infiniopQuantizeGPTQDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopQuantizeGPTQ(infiniopQuantizeGPTQDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *packed_weights, + void *b_scale, + void *zero, + const void *a, + const void *b, + void *stream); + +__C __export infiniStatus_t infiniopQuantizeLinearGPTQ(infiniopQuantizeGPTQDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + void *packed_weights, + void *b_scale, + void *zero, + void *stream); + +__C __export infiniStatus_t infiniopDestroyQuantizeGPTQDescriptor(infiniopQuantizeGPTQDescriptor_t desc); + +#endif diff --git a/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc b/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc new file mode 100644 index 000000000..e23ef1177 --- /dev/null +++ b/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc @@ -0,0 +1,612 @@ +#include "quantize_gptq_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../../handle.h" +#include +// 防止 __C 与 AVX512 冲突 +// 判断是否为 ARM 架构(如 ARM64/AArch64) +#if defined(__aarch64__) || defined(__arm__) +#include +// 判断是否为 x86/x86_64 架构 +#elif defined(__x86_64__) || defined(_M_X64) || defined(i386) || defined(__i386__) || defined(__i386) || defined(_M_IX86) +#pragma push_macro("__C") +#undef __C +#include +#pragma pop_macro("__C") +#else +#error "Unsupported architecture: Neither ARM nor x86 detected." +#endif +#include +#ifdef NDEBUG +#define SAFE_ASSERT(x) ((void)(x)) +#else +#define SAFE_ASSERT(x) assert(x) +#endif + +namespace op::quantize_gptq::cpu { +Descriptor::~Descriptor() {} + +infiniStatus_t Descriptor::create(infiniopHandle_t handle_, Descriptor **desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t packed_weights_desc, + infiniopTensorDescriptor_t b_scale_desc, + infiniopTensorDescriptor_t zero_desc) { + auto handle = reinterpret_cast(handle_); + auto result = MatmulGptqInfo::createMatmulGptqInfo(c_desc, a_desc, packed_weights_desc, b_scale_desc, zero_desc); + CHECK_RESULT(result); + MatmulGptqInfo info = result.take(); + size_t min_workspace_size + = (info.k * info.k + info.n * info.block_size) * sizeof(float) + (2 * info.n * info.k) * infiniSizeOf(info.atype); + + *desc_ptr = new Descriptor(info, nullptr, min_workspace_size, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +float quantize(float x, float s, float z, float maxq) { + float q = std::roundf(x / s + z); + q = std::max(0.0f, std::min(maxq, q)); + return s * (q - z); +} + +template +void find_params(T *x, T *b_scale, T *zero, int N, int K, + int bits = 4, bool sym = false, bool mse = false, + float norm = 2.4f, int grid = 100, float maxshrink = 0.8f) { + float maxq = static_cast(std::pow(2, bits) - 1); +#pragma omp parallel for + for (int n = 0; n < N; n++) { + float x_min = FLT_MAX; + float x_max = -FLT_MAX; + for (int k = 0; k < K; k++) { + if (utils::cast(x[n * K + k]) < x_min) { + x_min = utils::cast(x[n * K + k]); + } + if (utils::cast(x[n * K + k]) > x_max) { + x_max = utils::cast(x[n * K + k]); + } + } + if (sym) { + x_max = std::fmax(std::abs(x_min), x_max); + if (x_min < 0) { + x_min = -x_max; + } + } + if (x_min == 0 && x_max == 0) { + x_min = -1; + x_max = 1; + } + if constexpr (std::is_same::value) { + b_scale[n] = utils::cast((x_max - x_min) / maxq); + if (sym) { + zero[n] = utils::cast((maxq + 1.0f) * 0.5f); + } else { + zero[n] = utils::cast(-x_min * maxq / (x_max - x_min)); + } + } else if constexpr (std::is_same::value) { + b_scale[n] = (x_max - x_min) / maxq; + if (sym) { + zero[n] = (maxq + 1.0f) * 0.5f; + } else { + zero[n] = -x_min / b_scale[n]; + } + } + if (mse) { + float best = FLT_MAX; + for (int i = 0; i < int(maxshrink * grid); i++) { + float p = 1 - static_cast(i) / static_cast(grid); + float x_min_1 = p * x_min; + float x_max_1 = p * x_max; + float scale_1 = (x_max_1 - x_min_1) / maxq; + float zero_1 = (sym ? utils::cast(zero[n]) : std::roundf(-x_min_1 / scale_1)); + float err = 0.0f; + for (int k = 0; k < K; k++) { + float q = quantize(utils::cast(x[n * K + k]), scale_1, zero_1, maxq); + q -= utils::cast(x[n * K + k]); + q = std::abs(q); + q = static_cast(std::pow(q, norm)); + err += q; + } + if (err < best) { + best = err; + if constexpr (std::is_same::value) { + b_scale[n] = utils::cast(scale_1); + zero[n] = utils::cast(zero_1); + } else if constexpr (std::is_same::value) { + b_scale[n] = scale_1; + zero[n] = zero_1; + } + } + } + } + } +} + +inline float dot_product(const float *a, const float *b, int len) { +#if defined(__aarch64__) || defined(__arm__) + float32x4_t sum0 = vdupq_n_f32(0.0f); + float32x4_t sum1 = vdupq_n_f32(0.0f); + int i = 0; + + for (; i + 8 <= len; i += 8) { + float32x4_t va0 = vld1q_f32(a + i); + float32x4_t vb0 = vld1q_f32(b + i); + float32x4_t va1 = vld1q_f32(a + i + 4); + float32x4_t vb1 = vld1q_f32(b + i + 4); + sum0 = vfmaq_f32(sum0, va0, vb0); + sum1 = vfmaq_f32(sum1, va1, vb1); + } + + float32x4_t sum = vaddq_f32(sum0, sum1); + float total = vaddvq_f32(sum); // Requires ARMv8.1-A and above + + for (; i < len; ++i) { + total += a[i] * b[i]; + } + + return total; +#elif defined(__x86_64__) || defined(_M_X64) || defined(i386) || defined(__i386__) || defined(__i386) || defined(_M_IX86) + __m256 sum = _mm256_setzero_ps(); + int i = 0; + for (; i + 8 <= len; i += 8) { + __m256 va = _mm256_loadu_ps(a + i); + __m256 vb = _mm256_loadu_ps(b + i); + sum = _mm256_fmadd_ps(va, vb, sum); + } + float result[8]; + _mm256_storeu_ps(result, sum); + float total = result[0] + result[1] + result[2] + result[3] + result[4] + result[5] + result[6] + result[7]; + + for (; i < len; ++i) { + total += a[i] * b[i]; + } + return total; +#else +#error "Unsupported architecture" +#endif +} +template +void add_batch(const T *inp, float *Hess, float nsamples, int M, int K) { // Hess, nsamples默认是0 + const int tmp = M; + const float ns_new = nsamples + tmp; + const float w_old = nsamples / ns_new; + const float w_new = 2.0f / ns_new; + + // 1. Scale existing Hessian +#pragma omp parallel for + for (int i = 0; i < K * K; ++i) { + Hess[i] *= w_old; + } + + // 2. Cast input to float buffer + std::vector buffer(K * M); +#pragma omp parallel for + for (int i = 0; i < K * M; ++i) { + if constexpr (std::is_same::value) { + buffer[i] = utils::cast(inp[i]); + } else { + buffer[i] = inp[i]; + } + } + + // 3. Compute upper triangle of Hessian (without collapse) +#pragma omp parallel for schedule(dynamic) + for (int i = 0; i < K; ++i) { + for (int j = i; j < K; ++j) { + float s = dot_product(buffer.data() + i * M, buffer.data() + j * M, M); + Hess[i * K + j] += s * w_new; + } + } + + // 4. Mirror to lower triangle (no collapse) +#pragma omp parallel for + for (int i = 0; i < K; ++i) { + for (int j = i + 1; j < K; ++j) { + Hess[j * K + i] = Hess[i * K + j]; + } + } +} + +// Cholesky 分解 (in-place),只支持 lower (第一步) 或 upper (第三步) + +// Cholesky decomposition (lower or upper) +bool cholesky_decompose(float *A, int n, bool upper) { + if (upper) { + for (int i = 0; i < n; ++i) { + for (int j = 0; j <= i; ++j) { + float sum = A[i * n + j]; + if (j > 0) { + sum -= dot_product(&A[i * n], &A[j * n], j); + } + if (i == j) { + if (sum <= 0.0f) { + return false; + } + A[i * n + j] = std::sqrt(sum); + } else { + A[i * n + j] = sum / A[j * n + j]; + } + } +#pragma omp parallel for + for (int j = i + 1; j < n; ++j) { + A[i * n + j] = 0.0f; + } + } + } else { + for (int i = 0; i < n; ++i) { + for (int j = 0; j <= i; ++j) { + float sum = A[i * n + j]; + if (j > 0) { + sum -= dot_product(&A[i * n], &A[j * n], j); + } + if (i == j) { + if (sum <= 0.0f) { + return false; + } + A[i * n + j] = std::sqrt(sum); + } else { + A[i * n + j] = sum / A[j * n + j]; + } + } +#pragma omp parallel for + for (int j = i + 1; j < n; ++j) { + A[i * n + j] = 0.0f; + } + } + } + return true; +} + +// Compute A^{-1} from Cholesky decomposition (A = L L^T) +// A: lower-triangular Cholesky factor (n x n) +// invA: output inverse matrix (n x n), symmetric +// temp_row: temporary buffer of size n * n +void invert_symmetric_from_cholesky(const float *A, int n, float *invA, float *temp_row) { +#pragma omp parallel for + for (int col = 0; col < n; ++col) { + float *row_buf = temp_row + col * n; + + // Forward solve: L y = e_col + for (int i = 0; i < n; ++i) { + float sum = (i == col) ? 1.0f : 0.0f; + if (i > 0) { + sum -= dot_product(&A[i * n], row_buf, i); + } + row_buf[i] = sum / A[i * n + i]; + } + + // Backward solve: L^T x = y + for (int i = n - 1; i >= 0; --i) { + float sum = row_buf[i]; + for (int j = i + 1; j < n; ++j) { + sum -= A[j * n + i] * invA[j * n + col]; + } + invA[i * n + col] = sum / A[i * n + i]; + } + + // Exploit symmetry: copy upper triangle to lower + for (int row = 0; row < col; ++row) { + invA[col * n + row] = invA[row * n + col]; + } + } +} + +// Clear lower triangle for upper triangular result +void clear_lower_triangle(float *A, int n) { +#if defined(__aarch64__) || defined(__arm__) + float32x4_t zero = vdupq_n_f32(0.0f); +#pragma omp parallel for + for (int i = 0; i < n; ++i) { + int j = 0; + int row_start = i * n; + // 每次清 4 个 float + for (; j + 3 < i; j += 4) { + vst1q_f32(&A[row_start + j], zero); + } + // 处理剩余 1~3 个 + for (; j < i; ++j) { + A[row_start + j] = 0.0f; + } + } +#elif defined(__x86_64__) || defined(_M_X64) || defined(i386) || defined(__i386__) || defined(__i386) || defined(_M_IX86) + __m256 zero = _mm256_setzero_ps(); +#pragma omp parallel for + for (int i = 0; i < n; ++i) { + int j = 0; + for (; j + 7 < i; j += 8) { + _mm256_storeu_ps(&A[i * n + j], zero); + } + for (; j < i; ++j) { + A[i * n + j] = 0.0f; + } + } +#else +#error "Unsupported architecture" +#endif +} + +void cholesky_inverse_then_upper_cholesky(float *Hess, int K) { + cholesky_decompose(Hess, K, false); + + char *compute_workspace = (char *)malloc(2 * sizeof(float) * K * K); + float *temp = (float *)compute_workspace; // 内存要求和32字节对齐 + float *invA = temp + K * K; + invert_symmetric_from_cholesky(Hess, K, invA, temp); + memcpy(Hess, invA, sizeof(float) * K * K); + + free(compute_workspace); + + cholesky_decompose(Hess, K, true); + clear_lower_triangle(Hess, K); +} + +template +void fasterquant(T *weight, T *Q, float *Err, T *b_scale, T *zero, float *Hess, + int M, int K, int N, + int block_size = 128, float percdamp = 0.01, int group_size = -1, + int bits = 4, bool sym = false, bool mse = false, + float norm = 2.4, int grid = 100, float maxshrink = 0.8) { + float maxq = static_cast(std::pow(2, bits) - 1); + int num_groups = (group_size == -1 ? 1 : K / group_size); + + if (group_size == -1) { + find_params(weight, b_scale, zero, N, K, bits, sym, mse, norm, grid, maxshrink); + } + float damp = 0.0f; + +#pragma omp parallel for reduction(+ : damp) + for (int dead = 0; dead < K; ++dead) { + bool condition = false; + if (Hess[dead * K + dead] == 0.0f) { + Hess[dead * K + dead] = 1.0f; + condition = true; + } + damp += Hess[dead * K + dead]; + + if (condition) { + for (int n = 0; n < N; ++n) { + if constexpr (std::is_same::value) { + weight[n * K + dead] = utils::cast(0.0f); + } else if constexpr (std::is_same::value) { + weight[n * K + dead] = 0.0f; + } + } + } + } + + damp = percdamp * damp / K; +#pragma omp parallel for + for (int dead = 0; dead < K; dead++) { + Hess[dead * K + dead] += damp; + } + cholesky_inverse_then_upper_cholesky(Hess, K); + + for (int index = 0; index < K / block_size; index++) { + for (int i = 0; i < block_size; i++) { + float d = Hess[(index * block_size + i) * K + index * block_size + i]; + + if (group_size != -1) { + if ((index * block_size + i) % group_size == 0) { + int ind = (index * block_size + i) / group_size; + for (int n = 0; n < N; n++) { + find_params(&weight[n * K + index * block_size + i], &b_scale[n * num_groups + ind], &zero[n * num_groups + ind], 1, group_size, bits, sym, mse, norm, grid, maxshrink); + } + } + } + int ind = (group_size != -1 ? (index * block_size + i) / group_size : 0); + for (int n = 0; n < N; n++) { + float q = quantize(utils::cast(weight[n * K + index * block_size + i]), utils::cast(b_scale[n * num_groups + ind]), utils::cast(zero[n * num_groups + ind]), maxq); + if constexpr (std::is_same::value) { + Q[n * K + index * block_size + i] = utils::cast(q); + } else if constexpr (std::is_same::value) { + Q[n * K + index * block_size + i] = q; + } + + float w = utils::cast(weight[n * K + index * block_size + i]); + float err = (w - q) / d; + + if (group_size == -1) { + for (int j = i; j < block_size; j++) { + if constexpr (std::is_same::value) { + weight[n * K + index * block_size + j] = utils::cast(utils::cast(weight[n * K + index * block_size + j]) - err * Hess[(index * block_size + i) * K + j]); + } else if constexpr (std::is_same::value) { + weight[n * K + index * block_size + j] -= err * Hess[(index * block_size + i) * K + j]; + } + } + } + + Err[n * block_size + i] = err; + } + } + int i_2 = std::min((index + 1) * block_size, K); + for (int n = 0; n < N; n++) { + for (int j = i_2; j < K; j++) { + float s = 0.0f; + for (int b = 0; b < block_size; b++) { + s += Err[n * block_size + b] * Hess[(index * block_size + b) * K + j]; + } + if constexpr (std::is_same::value) { + weight[n * K + j] = utils::cast(utils::cast(weight[n * K + j]) - s); + } else if constexpr (std::is_same::value) { + weight[n * K + j] -= s; + } + } + } + } +} + +void PackQuantizedWeight(fp16_t *Q, fp16_t *b_scale, fp16_t *zero, + int32_t *packed_weight, int K, int N, int group_size, int bits = 4) { + int maxq = int(std::pow(2, bits) - 1); + int num_groups = (group_size == -1) ? 1 : K / group_size; + int blocks_per_group = (group_size == -1) ? K / 8 : group_size / 8; + +#pragma omp parallel for + for (int index = 0; index < N * num_groups * blocks_per_group; ++index) { + int n = index / (num_groups * blocks_per_group); + int rem = index % (num_groups * blocks_per_group); + int g = rem / blocks_per_group; + int b = rem % blocks_per_group; + + float scale = utils::cast(b_scale[n * num_groups + g]); + float zero_f = utils::cast(zero[n * num_groups + g]); + + int row_base = (group_size == -1) ? b * 8 : g * group_size + b * 8; + int row_block_idx = row_base / 8; + + int32_t packed = 0; + for (int i = 0; i < 8; ++i) { + int k = row_base + i; + float val = utils::cast(Q[n * K + k]); // Q: [N, K] + int q = static_cast(std::roundf(val / scale + zero_f)); + q = std::max(0, std::min(maxq, q)); // clamp to [0, maxq] + packed |= (q & 0xF) << (i * 4); + } + + packed_weight[n * (K / 8) + row_block_idx] = packed; + } +} + +void MatmulPackedWeight(fp16_t *C, const fp16_t *A, int32_t *packed_weight, + fp16_t *b_scale, fp16_t *zero, + int M, int K, int N, int group_size) { + int num_groups = (group_size == -1) ? 1 : K / group_size; + int blocks_per_group = (group_size == -1) ? K / 8 : group_size / 8; +#pragma omp parallel for + for (int index = 0; index < N * M; index++) { + int m = index % M; + int n = index / M; + float acc = 0.0f; + + for (int g = 0; g < num_groups; ++g) { + float scale = utils::cast(b_scale[n * num_groups + g]); + float zero_f = utils::cast(zero[n * num_groups + g]); + + for (int b = 0; b < blocks_per_group; ++b) { + int row_base = (group_size == -1) ? b * 8 : g * group_size + b * 8; + int row_block_idx = row_base / 8; + int32_t packed = packed_weight[n * (K / 8) + row_block_idx]; + + for (int i = 0; i < 8; ++i) { + int k = row_base + i; + int q = (packed >> (i * 4)) & 0xF; + float w = (q - zero_f) * scale; + + float a_val = utils::cast(A[k * M + m]); // A: [K, M] + acc += w * a_val; + } + } + } + + C[index] = utils::cast(acc); + } +} + +void quantWeights(void *workspace, int32_t *packed_weights, + fp16_t *b_scale, + fp16_t *zero, + const fp16_t *A, + const fp16_t *B, + int M, int K, int N, + int group_size, int block_size = 128) { + + float percdamp = 0.01f; + + int bits = 4; + bool sym = false; + bool mse = false; + float norm = 2.4f; + int grid = 100; + float maxshrink = 0.8f; + float nsamples = 0.0f; + + char *tmp = (char *)workspace + (K * K + N * block_size) * sizeof(float); + float *Hess = (float *)workspace; //[K, K] + float *Err = Hess + K * K; //[N, block_size=128] + fp16_t *Q = (fp16_t *)tmp; //[N, K] + fp16_t *weight = Q + N * K; //[N, K] + + memset(Hess, 0, sizeof(float) * K * K); + + memcpy(weight, B, N * K * sizeof(fp16_t)); + + add_batch(A, Hess, nsamples, M, K); + + fasterquant(weight, Q, Err, b_scale, zero, Hess, + M, K, N, + block_size, percdamp, group_size, + bits, sym, mse, + norm, grid, maxshrink); + + PackQuantizedWeight(Q, b_scale, zero, packed_weights, K, N, group_size, bits); +} + +void caculate(void *workspace, fp16_t *C, const fp16_t *A, + int32_t *packed_weights, fp16_t *b_scale, fp16_t *zero, + int M, int K, int N, int group_size) { + + MatmulPackedWeight(C, A, packed_weights, b_scale, zero, M, K, N, group_size); +} + +infiniStatus_t Descriptor::quant( + void *workspace, + size_t workspace_size, + void *packed_weights, + void *b_scale, + void *zero, + const void *a, + const void *b, + void *stream) const { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + int m = int(_info.m); + int n = int(_info.n); + int k = int(_info.k); + int group_size = int(_info.group_size); + int block_size = int(_info.block_size); + bool is_weight_transposed = _info.is_weight_transposed; + if (_info.atype == INFINI_DTYPE_F16 && !is_weight_transposed) { + quantWeights(workspace, (int32_t *)packed_weights, + (fp16_t *)b_scale, + (fp16_t *)zero, + (fp16_t *)a, (fp16_t *)b, m, k, n, group_size, block_size); + + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *c, + const void *a, + void *packed_weights, + void *b_scale, + void *zero, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + int m = int(_info.m); + int n = int(_info.n); + int k = int(_info.k); + int group_size = int(_info.group_size); + bool is_weight_transposed = _info.is_weight_transposed; + if (_info.atype == INFINI_DTYPE_F16 && !is_weight_transposed) { + caculate(workspace, (fp16_t *)c, (fp16_t *)a, (int32_t *)packed_weights, (fp16_t *)b_scale, (fp16_t *)zero, + m, k, n, group_size); + + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::quantize_gptq::cpu diff --git a/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.h b/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.h new file mode 100644 index 000000000..0bdf64359 --- /dev/null +++ b/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.h @@ -0,0 +1,8 @@ +#ifndef __QUANTIZE_GPTQ_CPU_H__ +#define __QUANTIZE_GPTQ_CPU_H__ + +#include "../quantize_gptq.h" + +DESCRIPTOR(cpu) + +#endif // __QUANTIZE_GPTQ_CPU_H__ diff --git a/src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cu b/src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cu new file mode 100644 index 000000000..24dd2f4ce --- /dev/null +++ b/src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cu @@ -0,0 +1,1845 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ +/* + * Adapted from https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin + */ +#include "gptq_marlin.cuh" +#include +#include +#include +#include +#include +#include + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace gptq_marlin { + +template +struct Vec { + T elems[n]; + __device__ T &operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +__host__ __device__ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__) +// No support for async +#else + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +template +class ScalarType {}; + +template <> +class ScalarType { +public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType { +public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } +#endif +}; + +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__) + +__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) {} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) {} + +} // namespace gptq_marlin + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void mma(const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + typename ScalarType::FragC &frag_c) { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + float *c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm4(typename ScalarType::FragA &frag_a, + const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +template +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_4bit(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +template +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + typename ScalarType::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_8bit(int q) { + typename ScalarType::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t *fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t *bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s_1, + typename ScalarType::FragS &frag_s_2, + typename ScalarType::FragS &frag_s_3, + typename ScalarType::FragS &frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float *c, + typename ScalarType::FragS &s) { + scalar_t *s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int *lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do { + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + } while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int *lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const *a_row_half = reinterpret_cast(a_int4_ptr + offset); + half *out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + + constexpr int pack_factor = 32 / num_bits; + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) { + slice_iters = 0; + } + if (slice_iters == 0) { + return; + } + if (slice_row + slice_iters > k_tiles) { + slice_iters = k_tiles - slice_row; + } + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) { + slice_count++; + } + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) { + slice_idx = slice_count - 1; + } else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) { + slice_idx--; + } + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } else { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + } + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + } + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_g_idx = sh_b + (stages * b_sh_stage); + int4 *sh_s = sh_g_idx + (stages * g_idx_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const *cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4 *sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + + } else { + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float *c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float *c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up +// the compiler and lead to slowdowns, hence we also use async-copies even +// though these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] + += Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&c)[j] = Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS &s) { + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { + res = __hmul2(res, s[0]); + } + + ((scalar_t2 *)sh)[idx] = res; + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + +#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (num_bits == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (num_bits == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) { // only the last block in a slice actually writes the result + write_result(); + } + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + } + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + B_ptr[i] -= b_gl_stride; + } + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ + prob_k, locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, + +}; + +int get_scales_cache_size(thread_config_t const &th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + // if (max_m_blocks == 0) { + // TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + // } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = float((a_size + b_size) * pipe_stages); + + // TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const &th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; +} + +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{0, {-1, -1, -1}}; +} + +#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +template +void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, + void *g_idx, void *perm, void *a_tmp, int prob_m, + int prob_n, int prob_k, void *workspace, int num_bits, + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k, + int thread_n, int sms, int max_par) { + // TORCH_CHECK(num_bits == 4 || num_bits == 8, + // "num_bits must be 4 or 8. Got = ", num_bits); + // TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + // ", ", prob_n, ", ", prob_k, "]"); + + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + // TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; + } else { + // Auto config + exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); + } + + // TORCH_CHECK(exec_cfg.max_m_blocks > 0 && is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, max_shared_mem), + // "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + // ", thread_k = ", exec_cfg.tb_cfg.thread_k, + // ", thread_n = ", exec_cfg.tb_cfg.thread_n, + // ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + // prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + // ", group_size = ", group_size, + // ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + // ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + // TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + // " is not divisible by thread_n = ", thread_n); + // TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + // " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + // TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + // TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + // " is not divisible by group_blocks = ", group_blocks); + } else { + // TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + // TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + // " is not divisible by group_blocks = ", group_blocks); + } + } + + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + int4 *C_ptr = (int4 *)C; + const int4 *s_ptr = (const int4 *)s; + const int *g_idx_ptr = (const int *)g_idx; + const int *perm_ptr = (const int *)perm; + int4 *a_tmp_ptr = (int4 *)a_tmp; + + int *locks = (int *)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by having + // a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + // Main loop + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > exec_cfg.max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); + if (par > max_par) { + par = max_par; + } + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; + } + + // Define kernel configurations + // #define undefined_error TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + ", has_act_order = " + str(has_act_order) + ", num_groups = " + str(num_groups) + ", group_size = " + str(group_size) + ", thread_m_blocks = " + str(thread_m_blocks) + ", thread_n_blocks = " + str(thread_n_blocks) + ", thread_k_blocks = " + str(thread_k_blocks)); + + if (num_bits == 4 && num_threads == 256) { + if (false) { + } + CALL_IF(4, 32, 2, 256) + CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 8, 256) + // else { + // undefined_error + // } + } else if (num_bits == 4 && num_threads == 128) { + if (false) { + } + CALL_IF(4, 8, 4, 128) + CALL_IF(4, 4, 8, 128) + // else { + // undefined_error + // } + } else if (num_bits == 8 && num_threads == 256) { + if (false) { + } + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) + // else { + // undefined_error + // } + } else if (num_bits == 8 && num_threads == 128) { + if (false) { + } + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) + // else { + // undefined_error + // } + } + // else { + // undefined_error + // } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +void gptq_marlin_mm_fp16(void *c, const void *a, const void *b, const void *scale, + int m, int n, int k, + void *workspace, int num_bits, + int num_groups, int group_size, + int device_id, cudaStream_t stream) { + marlin_mm_f16i4(a, b, c, (void *)scale, nullptr, nullptr, + workspace, m, n, k, (char *)workspace + m * k * 2, + num_bits, false, true, num_groups, group_size, device_id, stream, -1, -1, -1, int(max_par)); +} + +void gptq_marlin_mm_bf16(void *c, const void *a, const void *b, const void *scale, + int m, int n, int k, + void *workspace, int num_bits, + int num_groups, int group_size, + int device_id, cudaStream_t stream) { + marlin_mm_f16i4(a, b, c, (void *)scale, nullptr, nullptr, + workspace, m, n, k, (char *)workspace + m * k * 2, + num_bits, false, true, num_groups, group_size, device_id, stream, -1, -1, -1, int(max_par)); +} +} // namespace gptq_marlin + +#endif diff --git a/src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cuh b/src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cuh new file mode 100644 index 000000000..42fd2ef9f --- /dev/null +++ b/src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cuh @@ -0,0 +1,32 @@ +#ifndef GPTQ_MARLIN_CUH +#define GPTQ_MARLIN_CUH +#include + +namespace gptq_marlin { +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +void gptq_marlin_mm_bf16(void *c, const void *a, const void *b, const void *scale, + int m, int n, int k, + void *workspace, int num_bits, + int num_groups, int group_size, + int device_id, cudaStream_t stream); +void gptq_marlin_mm_fp16(void *c, const void *a, const void *b, const void *scale, + int m, int n, int k, + void *workspace, int num_bits, + int num_groups, int group_size, + int device_id, cudaStream_t stream); + +} // namespace gptq_marlin + +#endif diff --git a/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu b/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu new file mode 100644 index 000000000..1550a96c4 --- /dev/null +++ b/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu @@ -0,0 +1,94 @@ +#include "../../../devices/cuda/cuda_common.cuh" +#include "gptq_marlin.cuh" +#include "quantize_gptq_cuda.cuh" +#include +#ifdef NDEBUG +#define SAFE_ASSERT(x) ((void)(x)) +#else +#define SAFE_ASSERT(x) assert(x) +#endif +namespace op::quantize_gptq::cuda { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t packed_weights_desc, + infiniopTensorDescriptor_t b_scale_desc, + infiniopTensorDescriptor_t zero_desc) { + auto result = MatmulGptqInfo::createMatmulGptqInfo(c_desc, a_desc, packed_weights_desc, b_scale_desc, zero_desc); + CHECK_RESULT(result); + MatmulGptqInfo info = result.take(); + int max_par = gptq_marlin::max_par; + size_t min_workspace_size = info.n / gptq_marlin::min_thread_n * max_par * sizeof(int) + info.m * info.k * infiniSizeOf(info.atype); + + *desc_ptr = new Descriptor(info, new Opaque{reinterpret_cast(handle)->internal()}, min_workspace_size, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::quant( + void *workspace, + size_t workspace_size, + void *packed_weights, + void *b_scale, + void *zero, + const void *a, + const void *b, + void *stream) const { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *c, + const void *a, + void *packed_weights, + void *b_scale, + void *zero, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + int m = int(_info.m); + int n = int(_info.n); + int k = int(_info.k); + int bits = 4; + int group_size = int(_info.group_size); + int num_groups = int(_info.num_groups); + bool is_weight_transposed = _info.is_weight_transposed; + if (_info.atype == INFINI_DTYPE_F16 && !is_weight_transposed) { + gptq_marlin::gptq_marlin_mm_fp16(c, a, packed_weights, b_scale, + m, n, k, + workspace, bits, + num_groups, group_size, + this->device_id, (cudaStream_t)stream); + + } else if (_info.atype == INFINI_DTYPE_BF16 && !is_weight_transposed) { + gptq_marlin::gptq_marlin_mm_bf16(c, a, packed_weights, b_scale, + m, n, k, + workspace, bits, + num_groups, group_size, + this->device_id, (cudaStream_t)stream); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::quantize_gptq::cuda diff --git a/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cuh b/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cuh new file mode 100644 index 000000000..4de0fc109 --- /dev/null +++ b/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cuh @@ -0,0 +1,8 @@ +#ifndef __QUANTIZE_GPTQ_CUDA_H__ +#define __QUANTIZE_GPTQ_CUDA_H__ + +#include "../quantize_gptq.h" + +DESCRIPTOR(cuda) + +#endif // __QUANTIZE_GPTQ_CUDA_H__ diff --git a/src/infiniop/ops/quantize_gptq/operator.cc b/src/infiniop/ops/quantize_gptq/operator.cc new file mode 100644 index 000000000..6393b46da --- /dev/null +++ b/src/infiniop/ops/quantize_gptq/operator.cc @@ -0,0 +1,126 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/quantize_gptq.h" + +#ifdef ENABLE_CPU_API +#include "cpu/quantize_gptq_cpu.h" +#endif +#ifdef ENABLE_CUDA_API +#include "cuda/quantize_gptq_cuda.cuh" +#endif + +__C infiniStatus_t infiniopCreateQuantizeGPTQDescriptor(infiniopHandle_t handle, + infiniopQuantizeGPTQDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t packed_weights_desc, + infiniopTensorDescriptor_t b_scale_desc, + infiniopTensorDescriptor_t zero_desc) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::quantize_gptq::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + c_desc, \ + a_desc, \ + packed_weights_desc, \ + b_scale_desc, \ + zero_desc); + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_CUDA_API + CREATE(INFINI_DEVICE_NVIDIA, cuda) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__C infiniStatus_t infiniopGetQuantizeGPTQWorkspaceSize(infiniopQuantizeGPTQDescriptor_t desc, size_t *size) { + switch (desc->device_type) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->minWorkspaceSize(); \ + return INFINI_STATUS_SUCCESS; +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_CUDA_API + GET(INFINI_DEVICE_NVIDIA, cuda) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__C infiniStatus_t infiniopQuantizeGPTQ(infiniopQuantizeGPTQDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *packed_weights, + void *b_scale, + void *zero, + const void *a, + const void *b, + void *stream) { +#define QUANT(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->quant( \ + workspace, workspace_size, packed_weights, b_scale, zero, a, b, stream); + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + QUANT(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_CUDA_API + QUANT(INFINI_DEVICE_NVIDIA, cuda) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__C infiniStatus_t infiniopQuantizeLinearGPTQ(infiniopQuantizeGPTQDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + void *packed_weights, + void *b_scale, + void *zero, + void *stream) { +#define CACULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, c, a, packed_weights, b_scale, zero, stream); + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CACULATE(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_CUDA_API + CACULATE(INFINI_DEVICE_NVIDIA, cuda) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__C infiniStatus_t infiniopDestroyQuantizeGPTQDescriptor(infiniopQuantizeGPTQDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DESTROY(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_CUDA_API + DESTROY(INFINI_DEVICE_NVIDIA, cuda) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} diff --git a/src/infiniop/ops/quantize_gptq/quantize_gptq.h b/src/infiniop/ops/quantize_gptq/quantize_gptq.h new file mode 100644 index 000000000..9a226629a --- /dev/null +++ b/src/infiniop/ops/quantize_gptq/quantize_gptq.h @@ -0,0 +1,156 @@ +#ifndef __QUANTIZE_GPTQ_H__ +#define __QUANTIZE_GPTQ_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::quantize_gptq::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + MatmulGptqInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor(MatmulGptqInfo info, Opaque *opaque, \ + size_t workspace_size, \ + infiniDevice_t device_type, int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), _info(info), _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t minWorkspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t c_desc, \ + infiniopTensorDescriptor_t a_desc, \ + infiniopTensorDescriptor_t packed_weights_desc, \ + infiniopTensorDescriptor_t b_scale_desc, \ + infiniopTensorDescriptor_t zero_desc); \ + \ + infiniStatus_t quant( \ + void *workspace, size_t workspace_size, \ + void *packed_weights, void *b_scale, void *zero, \ + const void *a, const void *b, void *stream) const; \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *c, const void *a, \ + void *packed_weights, void *b_scale, \ + void *zero, void *stream) const; \ + }; \ + } + +class MatmulGptqInfo { +private: + MatmulGptqInfo() = default; + +public: + infiniDtype_t atype, packed_weights_type; + size_t m, k, n, num_groups, block_size; + ptrdiff_t group_size; + bool is_weight_transposed; + + static utils::Result createMatmulGptqInfo( + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t packed_weights_desc, + infiniopTensorDescriptor_t b_scale_desc, + infiniopTensorDescriptor_t zero_desc) { + + CHECK_OR_RETURN( + c_desc != nullptr && a_desc != nullptr && packed_weights_desc != nullptr && b_scale_desc != nullptr && zero_desc != nullptr, + INFINI_STATUS_NULL_POINTER); + + const infiniDtype_t atype = a_desc->dtype(); + const infiniDtype_t packed_weights_type = packed_weights_desc->dtype(); + CHECK_OR_RETURN(atype == c_desc->dtype() && atype == b_scale_desc->dtype() && atype == zero_desc->dtype(), + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(atype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); + CHECK_DTYPE(packed_weights_type, INFINI_DTYPE_I32); + + CHECK_OR_RETURN(c_desc->ndim() == 2 + && a_desc->ndim() == 2 + && packed_weights_desc->ndim() == 2 + && b_scale_desc->ndim() == 2 + && zero_desc->ndim() == 2, + INFINI_STATUS_BAD_TENSOR_SHAPE); + bool is_weight_transposed = false; + size_t m = 1; + size_t k = 1; + size_t n = 1; + size_t num_groups = 1; + CHECK_OR_RETURN(c_desc->dim(0) == a_desc->dim(0) + || c_desc->dim(1) == a_desc->dim(1), + INFINI_STATUS_BAD_TENSOR_SHAPE); + + if (c_desc->dim(0) == a_desc->dim(0)) { + if (c_desc->dim(1) == a_desc->dim(1)) { + if (packed_weights_desc->dim(0) * 8 == packed_weights_desc->dim(1)) { + is_weight_transposed = true; + m = c_desc->dim(0); + n = c_desc->dim(1); + k = a_desc->dim(1); + num_groups = b_scale_desc->dim(0); + } else if (packed_weights_desc->dim(0) == packed_weights_desc->dim(1) * 8) { + is_weight_transposed = false; + m = c_desc->dim(1); + n = c_desc->dim(0); + k = a_desc->dim(0); + num_groups = b_scale_desc->dim(1); + } + } else { + is_weight_transposed = true; + m = c_desc->dim(0); + n = c_desc->dim(1); + k = a_desc->dim(1); + num_groups = b_scale_desc->dim(0); + } + + } else { // c_desc->dim(0) != a_desc->dim(0) + if (c_desc->dim(1) == a_desc->dim(1)) { + is_weight_transposed = false; + m = c_desc->dim(1); + n = c_desc->dim(0); + k = a_desc->dim(0); + num_groups = b_scale_desc->dim(1); + } + } + + size_t block_size = 128; + ptrdiff_t group_size = num_groups > 1 ? static_cast(k) / static_cast(num_groups) : -1; + const size_t k_8 = k / 8; + if (is_weight_transposed) { + CHECK_OR_RETURN(m == a_desc->dim(0) + && num_groups == zero_desc->dim(0) + && n == b_scale_desc->dim(1) && n == zero_desc->dim(1) + && n == packed_weights_desc->dim(1) && k_8 == packed_weights_desc->dim(0), + INFINI_STATUS_BAD_TENSOR_SHAPE); + } else { + CHECK_OR_RETURN(m == a_desc->dim(1) + && num_groups == zero_desc->dim(1) + && n == b_scale_desc->dim(0) && n == zero_desc->dim(0) + && n == packed_weights_desc->dim(0) && k_8 == packed_weights_desc->dim(1), + INFINI_STATUS_BAD_TENSOR_SHAPE); + } + + return utils::Result(MatmulGptqInfo{ + atype, + packed_weights_type, + m, + k, + n, + num_groups, + block_size, + group_size, + is_weight_transposed, + }); + } +}; + +#endif // __QUANTIZE_GPTQ_H__ diff --git a/test/infiniop/quantize_gptq.py b/test/infiniop/quantize_gptq.py new file mode 100644 index 000000000..5d855773c --- /dev/null +++ b/test/infiniop/quantize_gptq.py @@ -0,0 +1,554 @@ +import torch +import torch.nn as nn +import math +import ctypes +from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float +from libinfiniop import ( + infiniopHandle_t, + infiniopTensorDescriptor_t, + open_lib, + to_tensor, + get_test_devices, + check_error, + rearrange_if_needed, + create_workspace, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, +) + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +is_weight_transposed = False + +_TEST_CASES = [] + +MODELS = { + "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], + # "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], + # "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], + # "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], +} + +# Loop through models and layers to generate the new _TEST_CASES +for _, layers in MODELS.items(): + for layer in layers: + for batch in [1, 16]: + _TEST_CASES.append(((batch, layer[0], layer[1], is_weight_transposed))) + +# Data types used for testing +_TENSOR_DTYPES = [torch.float16] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + torch.float16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +# ============================================================================== +# Definitions +# ============================================================================== +class QuantizeGPTQDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopQuantizeGPTQDescriptor_t = POINTER(QuantizeGPTQDescriptor) + + +def quantize(x, scale, zero, maxq): + if scale.shape[1] == 1: + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + else: + group_size = x.shape[1] // scale.shape[1] + y = torch.zeros_like(x) + for j in range(scale.shape[1]): + q = torch.clamp( + torch.round( + x[:, j * group_size : (j + 1) * group_size] / scale[:, j : j + 1] + ) + + zero[:, j : j + 1], + 0, + maxq, + ) + y[:, j * group_size : (j + 1) * group_size] = scale[:, j : j + 1] * ( + q - zero[:, j : j + 1] + ) + return y + + +class Quantizer(nn.Module): + + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer("maxq", torch.tensor(0)) + self.register_buffer("scale", torch.zeros(shape)) + self.register_buffer("zero", torch.zeros(shape)) + + def configure( + self, + bits=4, + perchannel=False, + sym=True, + mse=False, + norm=2.4, + grid=100, + maxshrink=0.8, + ): + self.maxq = torch.tensor(2**bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float("inf"), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + # self.scale = self.scale.unsqueeze(1) + # self.zero = self.zero.unsqueeze(1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return quantize(x, self.scale, self.zero, self.maxq) + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + + +class GPTQ: + + def __init__(self, weight): + self.weight = weight + self.dev = self.weight.device + + self.rows = self.weight.shape[0] + self.columns = self.weight.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + + def add_batch(self, inp, out): + + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) + + def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1): + W = self.weight.clone() + + W = W.float() + + # tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H.to("cpu")).to( + H.device + ) # 对于CUDA来说,这个地方直接在CUDA上做cholesky分解可能会失败 + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + num_groups = self.columns // group_size + if group_size == -1: + scale = self.quantizer.scale.clone() + zero = self.quantizer.zero.clone() + else: + scale = torch.zeros(self.rows, num_groups) + zero = torch.zeros(self.rows, num_groups) + for index in range(self.columns // blocksize): + i1 = index * blocksize + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if group_size != -1: + if (i1 + i) % group_size == 0: + self.quantizer.find_params( + W[:, (i1 + i) : (i1 + i + group_size)], weight=True + ) + ind = index * blocksize // group_size + i // group_size + + scale[:, ind : ind + 1] = self.quantizer.scale + zero[:, ind : ind + 1] = self.quantizer.zero + + q = quantize( + w.unsqueeze(1), + self.quantizer.scale, + self.quantizer.zero, + self.quantizer.maxq, + ).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + # print('error', torch.sum(Losses).item()) + + self.weight = Q.reshape(self.weight.shape).to(self.weight.dtype) + self.scale = scale.to(self.weight.dtype) + self.zero = zero.to(self.weight.dtype) + + +def get_scale_zero(b, a, c, group_size): + weight = b.clone() + inp = a.clone() + out = c.clone() + gptq = GPTQ(weight) + gptq.quantizer = Quantizer() + gptq.quantizer.configure(perchannel=True, sym=False, mse=False) + gptq.add_batch(inp, out) + gptq.fasterquant(group_size=group_size) + + return ( + gptq.weight.to(weight.device), + gptq.scale.to(weight.device), + gptq.zero.to(weight.device), + ) + + +def pack(weight, scale, zero): + intweight = torch.round((weight + zero) / scale).to(torch.int32) + qweight = torch.zeros( + [weight.shape[0], weight.shape[1] // 8], dtype=torch.int32, device=weight.device + ) + for i in range(intweight.shape[1]): + qweight[:, i // 8] |= intweight[:, i] << (4 * (i % 8)) + return qweight + + +# PyTorch implementation for matrix multiplication +def quantize_gptq(a, b): # 昇腾芯片的CPU不支持转置计算 + ans = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(b.dtype) + return ans + + +# The argument list should be (lib, handle, torch_device, , dtype) +# The should keep the same order as the one specified in _TEST_CASES +def test( + lib, + handle, + torch_device, + M, + K, + N, + is_weight_transposed, + dtype=torch.float16, + sync=None, +): + print( + f"Testing QuantizeGPTQ on {torch_device}" f" M:{M}, K:{K}, N:{N}, dtype:{dtype}" + ) + torch.manual_seed(12) + # Initialize tensors + a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device) + layer = nn.Linear(K, N) + b = 1e-3 * layer.weight.data.to(dtype).to(torch_device) + c = torch.zeros([N, M], dtype=dtype).to(torch_device) + + group_size = -1 + num_groups = 1 + if group_size == -1: + num_groups = 1 + else: + num_groups = K // group_size + if is_weight_transposed: + ans = quantize_gptq(a.t(), b.t()) + else: + ans = quantize_gptq(b, a) + packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device) + s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) + z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) + if torch_device == "cuda": + b_ref, s, z = get_scale_zero(b, a.t(), c, group_size) + z = torch.zeros_like(s) + packed_weights = pack(b_ref, s, z) + # print(s) + + a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = ( + to_tensor(a, lib), + to_tensor(b, lib), + to_tensor(c, lib), + to_tensor(s, lib), + to_tensor(z, lib), + to_tensor(packed_weights, lib), + ) + + descriptor = infiniopQuantizeGPTQDescriptor_t() + check_error( + lib.infiniopCreateQuantizeGPTQDescriptor( + handle, + ctypes.byref(descriptor), + c_tensor.descriptor, + a_tensor.descriptor, + packed_weights_tensor.descriptor, + s_tensor.descriptor, + z_tensor.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [ + a_tensor, + b_tensor, + c_tensor, + s_tensor, + z_tensor, + packed_weights_tensor, + ]: + tensor.destroyDesc(lib) + + # Get workspace size and create workspace + workspace_size = c_uint64(0) + check_error( + lib.infiniopGetQuantizeGPTQWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = create_workspace(workspace_size.value, a.device) + + # Execute infiniop quantize_gptq operator + check_error( + lib.infiniopQuantizeGPTQ( + descriptor, + workspace.data_ptr() if workspace is not None else None, + workspace_size.value, + packed_weights_tensor.data, + s_tensor.data, + z_tensor.data, + a_tensor.data, + b_tensor.data, + None, + ) + ) + + def lib_quantize_gptq(): + check_error( + lib.infiniopQuantizeLinearGPTQ( + descriptor, + workspace.data_ptr() if workspace is not None else None, + workspace_size.value, + c_tensor.data, + a_tensor.data, + packed_weights_tensor.data, + s_tensor.data, + z_tensor.data, + None, + ) + ) + + lib_quantize_gptq() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + # tmpa = ans.flatten() + # tmpc = c.flatten() + # for i in range(tmpa.shape[0]): + # if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]): + # print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i])) + # break + + if DEBUG: + debug(c, ans, atol=atol, rtol=rtol) + assert torch.allclose(c, ans, atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + if(is_weight_transposed): + profile_operation("PyTorch", lambda: quantize_gptq(a.t(), b.t()), torch_device, NUM_PRERUN, NUM_ITERATIONS) + else: + profile_operation("PyTorch", lambda: quantize_gptq(b, a), torch_device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_quantize_gptq(), torch_device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + check_error(lib.infiniopDestroyQuantizeGPTQDescriptor(descriptor)) + + +# ============================================================================== +# Main Execution +# ============================================================================== +if __name__ == "__main__": + args = get_args() + lib = open_lib() + + lib.infiniopCreateQuantizeGPTQDescriptor.restype = c_int32 + lib.infiniopCreateQuantizeGPTQDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopQuantizeGPTQDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetQuantizeGPTQWorkspaceSize.restype = c_int32 + lib.infiniopGetQuantizeGPTQWorkspaceSize.argtypes = [ + infiniopQuantizeGPTQDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopQuantizeGPTQ.restype = c_int32 + lib.infiniopQuantizeGPTQ.argtypes = [ + infiniopQuantizeGPTQDescriptor_t, + c_void_p, + c_uint64, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopQuantizeLinearGPTQ.restype = c_int32 + lib.infiniopQuantizeLinearGPTQ.argtypes = [ + infiniopQuantizeGPTQDescriptor_t, + c_void_p, + c_uint64, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyQuantizeGPTQDescriptor.restype = c_int32 + lib.infiniopDestroyQuantizeGPTQDescriptor.argtypes = [ + infiniopQuantizeGPTQDescriptor_t, + ] + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + # Execute tests + for device in get_test_devices(args): + test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/xmake/cpu.lua b/xmake/cpu.lua index 22dc8f8e7..d1ee10dbf 100644 --- a/xmake/cpu.lua +++ b/xmake/cpu.lua @@ -11,6 +11,13 @@ target("infiniop-cpu") add_cxflags("/openmp") end else + if is_arch("x86_64", "i386") then + -- x86 架构(启用 AVX2/FMA) + add_cxxflags("-mavx2", "-mfma", "-O3") + elseif is_arch("arm64", "arm.*") then + -- ARM 架构(启用 NEON) + add_cxxflags("-O3", "-mcpu=generic+simd") -- ARMv8+NEON + end add_cxflags("-fPIC", "-Wno-unknown-pragmas") if has_config("omp") then add_cxflags("-fopenmp") From 096233a108bc0612a1d32ff6ca921ed03590a1fb Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Thu, 29 May 2025 14:06:40 +0800 Subject: [PATCH 2/8] issue/170: add signed quant --- .../quantize_gptq/cpu/quantize_gptq_cpu.cc | 67 +++++++++++++------ test/infiniop/quantize_gptq.py | 37 ++++++---- 2 files changed, 71 insertions(+), 33 deletions(-) diff --git a/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc b/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc index e23ef1177..6f021df10 100644 --- a/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc +++ b/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc @@ -42,17 +42,25 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle_, Descriptor **desc_pt return INFINI_STATUS_SUCCESS; } -float quantize(float x, float s, float z, float maxq) { +float quantize(float x, float s, float z, float minq, float maxq) { float q = std::roundf(x / s + z); - q = std::max(0.0f, std::min(maxq, q)); + q = std::max(minq, std::min(maxq, q)); return s * (q - z); } template void find_params(T *x, T *b_scale, T *zero, int N, int K, int bits = 4, bool sym = false, bool mse = false, - float norm = 2.4f, int grid = 100, float maxshrink = 0.8f) { - float maxq = static_cast(std::pow(2, bits) - 1); + float norm = 2.4f, int grid = 100, float maxshrink = 0.8f, bool sign_ed = false) { + float maxq; + float minq; + if (sign_ed) { // 如果有符号量化 + maxq = static_cast(std::pow(2, bits - 1) - 1); + minq = -static_cast(std::pow(2, bits - 1)); + } else { + maxq = static_cast(std::pow(2, bits) - 1); + minq = 0.0f; + } #pragma omp parallel for for (int n = 0; n < N; n++) { float x_min = FLT_MAX; @@ -76,16 +84,16 @@ void find_params(T *x, T *b_scale, T *zero, int N, int K, x_max = 1; } if constexpr (std::is_same::value) { - b_scale[n] = utils::cast((x_max - x_min) / maxq); + b_scale[n] = utils::cast((x_max - x_min) / (maxq - minq)); if (sym) { - zero[n] = utils::cast((maxq + 1.0f) * 0.5f); + zero[n] = utils::cast((maxq + minq + 1.0f) * 0.5f); } else { - zero[n] = utils::cast(-x_min * maxq / (x_max - x_min)); + zero[n] = utils::cast(-x_min * (maxq - minq) / (x_max - x_min)); } } else if constexpr (std::is_same::value) { - b_scale[n] = (x_max - x_min) / maxq; + b_scale[n] = (x_max - x_min) / (maxq - minq); if (sym) { - zero[n] = (maxq + 1.0f) * 0.5f; + zero[n] = (maxq + minq + 1.0f) * 0.5f; } else { zero[n] = -x_min / b_scale[n]; } @@ -96,11 +104,11 @@ void find_params(T *x, T *b_scale, T *zero, int N, int K, float p = 1 - static_cast(i) / static_cast(grid); float x_min_1 = p * x_min; float x_max_1 = p * x_max; - float scale_1 = (x_max_1 - x_min_1) / maxq; + float scale_1 = (x_max_1 - x_min_1) / (maxq - minq); float zero_1 = (sym ? utils::cast(zero[n]) : std::roundf(-x_min_1 / scale_1)); float err = 0.0f; for (int k = 0; k < K; k++) { - float q = quantize(utils::cast(x[n * K + k]), scale_1, zero_1, maxq); + float q = quantize(utils::cast(x[n * K + k]), scale_1, zero_1, minq, maxq); q -= utils::cast(x[n * K + k]); q = std::abs(q); q = static_cast(std::pow(q, norm)); @@ -344,12 +352,20 @@ void fasterquant(T *weight, T *Q, float *Err, T *b_scale, T *zero, float *Hess, int M, int K, int N, int block_size = 128, float percdamp = 0.01, int group_size = -1, int bits = 4, bool sym = false, bool mse = false, - float norm = 2.4, int grid = 100, float maxshrink = 0.8) { - float maxq = static_cast(std::pow(2, bits) - 1); + float norm = 2.4, int grid = 100, float maxshrink = 0.8, bool sign_ed = false) { + float maxq; + float minq; + if (sign_ed) { // 如果有符号量化 + maxq = static_cast(std::pow(2, bits - 1) - 1); + minq = -static_cast(std::pow(2, bits - 1)); + } else { + maxq = static_cast(std::pow(2, bits) - 1); + minq = 0.0f; + } int num_groups = (group_size == -1 ? 1 : K / group_size); if (group_size == -1) { - find_params(weight, b_scale, zero, N, K, bits, sym, mse, norm, grid, maxshrink); + find_params(weight, b_scale, zero, N, K, bits, sym, mse, norm, grid, maxshrink, sign_ed); } float damp = 0.0f; @@ -388,13 +404,13 @@ void fasterquant(T *weight, T *Q, float *Err, T *b_scale, T *zero, float *Hess, if ((index * block_size + i) % group_size == 0) { int ind = (index * block_size + i) / group_size; for (int n = 0; n < N; n++) { - find_params(&weight[n * K + index * block_size + i], &b_scale[n * num_groups + ind], &zero[n * num_groups + ind], 1, group_size, bits, sym, mse, norm, grid, maxshrink); + find_params(&weight[n * K + index * block_size + i], &b_scale[n * num_groups + ind], &zero[n * num_groups + ind], 1, group_size, bits, sym, mse, norm, grid, maxshrink, sign_ed); } } } int ind = (group_size != -1 ? (index * block_size + i) / group_size : 0); for (int n = 0; n < N; n++) { - float q = quantize(utils::cast(weight[n * K + index * block_size + i]), utils::cast(b_scale[n * num_groups + ind]), utils::cast(zero[n * num_groups + ind]), maxq); + float q = quantize(utils::cast(weight[n * K + index * block_size + i]), utils::cast(b_scale[n * num_groups + ind]), utils::cast(zero[n * num_groups + ind]), minq, maxq); if constexpr (std::is_same::value) { Q[n * K + index * block_size + i] = utils::cast(q); } else if constexpr (std::is_same::value) { @@ -435,8 +451,16 @@ void fasterquant(T *weight, T *Q, float *Err, T *b_scale, T *zero, float *Hess, } void PackQuantizedWeight(fp16_t *Q, fp16_t *b_scale, fp16_t *zero, - int32_t *packed_weight, int K, int N, int group_size, int bits = 4) { - int maxq = int(std::pow(2, bits) - 1); + int32_t *packed_weight, int K, int N, int group_size, int bits = 4, bool sign_ed = false) { + int maxq; + int minq; + if (sign_ed) { // 如果有符号量化 + maxq = int(std::pow(2, bits - 1) - 1); + minq = -int(std::pow(2, bits - 1)); + } else { + maxq = int(std::pow(2, bits) - 1); + minq = 0; + } int num_groups = (group_size == -1) ? 1 : K / group_size; int blocks_per_group = (group_size == -1) ? K / 8 : group_size / 8; @@ -458,7 +482,7 @@ void PackQuantizedWeight(fp16_t *Q, fp16_t *b_scale, fp16_t *zero, int k = row_base + i; float val = utils::cast(Q[n * K + k]); // Q: [N, K] int q = static_cast(std::roundf(val / scale + zero_f)); - q = std::max(0, std::min(maxq, q)); // clamp to [0, maxq] + q = std::max(minq, std::min(maxq, q)); // clamp to [minq, maxq] packed |= (q & 0xF) << (i * 4); } @@ -518,6 +542,7 @@ void quantWeights(void *workspace, int32_t *packed_weights, int grid = 100; float maxshrink = 0.8f; float nsamples = 0.0f; + bool sign_ed = false; char *tmp = (char *)workspace + (K * K + N * block_size) * sizeof(float); float *Hess = (float *)workspace; //[K, K] @@ -535,9 +560,9 @@ void quantWeights(void *workspace, int32_t *packed_weights, M, K, N, block_size, percdamp, group_size, bits, sym, mse, - norm, grid, maxshrink); + norm, grid, maxshrink, sign_ed); - PackQuantizedWeight(Q, b_scale, zero, packed_weights, K, N, group_size, bits); + PackQuantizedWeight(Q, b_scale, zero, packed_weights, K, N, group_size, bits, sign_ed); } void caculate(void *workspace, fp16_t *C, const fp16_t *A, diff --git a/test/infiniop/quantize_gptq.py b/test/infiniop/quantize_gptq.py index 5d855773c..e490f948e 100644 --- a/test/infiniop/quantize_gptq.py +++ b/test/infiniop/quantize_gptq.py @@ -64,9 +64,9 @@ class QuantizeGPTQDescriptor(Structure): infiniopQuantizeGPTQDescriptor_t = POINTER(QuantizeGPTQDescriptor) -def quantize(x, scale, zero, maxq): +def quantize(x, scale, zero, minq, maxq): if scale.shape[1] == 1: - q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + q = torch.clamp(torch.round(x / scale) + zero, minq, maxq) return scale * (q - zero) else: group_size = x.shape[1] // scale.shape[1] @@ -77,7 +77,7 @@ def quantize(x, scale, zero, maxq): x[:, j * group_size : (j + 1) * group_size] / scale[:, j : j + 1] ) + zero[:, j : j + 1], - 0, + minq, maxq, ) y[:, j * group_size : (j + 1) * group_size] = scale[:, j : j + 1] * ( @@ -91,6 +91,7 @@ class Quantizer(nn.Module): def __init__(self, shape=1): super(Quantizer, self).__init__() self.register_buffer("maxq", torch.tensor(0)) + self.register_buffer("minq", torch.tensor(0)) self.register_buffer("scale", torch.zeros(shape)) self.register_buffer("zero", torch.zeros(shape)) @@ -103,8 +104,14 @@ def configure( norm=2.4, grid=100, maxshrink=0.8, + sign_ed=False, ): - self.maxq = torch.tensor(2**bits - 1) + if sign_ed: # 有符号量化,范围是[-8,7] + self.maxq = torch.tensor(2 ** (bits - 1) - 1) + self.minq = -torch.tensor(2 ** (bits - 1)) + else: # 无符号量化,范围是[0,15] + self.maxq = torch.tensor(2**bits - 1) + self.minq = -torch.tensor(0) self.perchannel = perchannel self.sym = sym self.mse = mse @@ -115,6 +122,7 @@ def configure( def find_params(self, x, weight=False): dev = x.device self.maxq = self.maxq.to(dev) + self.minq = self.minq.to(dev) shape = x.shape if self.perchannel: @@ -139,9 +147,9 @@ def find_params(self, x, weight=False): xmin[tmp] = -1 xmax[tmp] = +1 - self.scale = (xmax - xmin) / self.maxq + self.scale = (xmax - xmin) / (self.maxq - self.minq) if self.sym: - self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + self.zero = torch.full_like(self.scale, (self.maxq + self.minq + 1) / 2) else: self.zero = torch.round(-xmin / self.scale) @@ -151,9 +159,11 @@ def find_params(self, x, weight=False): p = 1 - i / self.grid xmin1 = p * xmin xmax1 = p * xmax - scale1 = (xmax1 - xmin1) / self.maxq + scale1 = (xmax1 - xmin1) / (self.maxq - self.minq) zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q = quantize( + x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.minq, self.maxq + ) q -= x q.abs_() q.pow_(self.norm) @@ -190,7 +200,7 @@ def find_params(self, x, weight=False): def quantize(self, x): if self.ready(): - return quantize(x, self.scale, self.zero, self.maxq) + return quantize(x, self.scale, self.zero, self.minq, self.maxq) return x def enabled(self): @@ -292,6 +302,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1): w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, + self.quantizer.minq, self.quantizer.maxq, ).flatten() Q1[:, i] = q @@ -313,13 +324,13 @@ def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1): self.zero = zero.to(self.weight.dtype) -def get_scale_zero(b, a, c, group_size): +def get_scale_zero(b, a, c, group_size, sign_ed): weight = b.clone() inp = a.clone() out = c.clone() gptq = GPTQ(weight) gptq.quantizer = Quantizer() - gptq.quantizer.configure(perchannel=True, sym=False, mse=False) + gptq.quantizer.configure(perchannel=True, sym=False, mse=False, signed=sign_ed) gptq.add_batch(inp, out) gptq.fasterquant(group_size=group_size) @@ -383,7 +394,9 @@ def test( s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) if torch_device == "cuda": - b_ref, s, z = get_scale_zero(b, a.t(), c, group_size) + b_ref, s, z = get_scale_zero( + b, a.t(), c, group_size, signed=False + ) # 无符号量化 z = torch.zeros_like(s) packed_weights = pack(b_ref, s, z) # print(s) From 8c2b7ec3021fd591124cb57a5ab5792a2f8da1a5 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Fri, 30 May 2025 11:40:02 +0800 Subject: [PATCH 3/8] issue/170: modified pack py --- test/infiniop/quantize_gptq.py | 77 +++++++++++++++++++++------------- 1 file changed, 48 insertions(+), 29 deletions(-) diff --git a/test/infiniop/quantize_gptq.py b/test/infiniop/quantize_gptq.py index e490f948e..63515c9d3 100644 --- a/test/infiniop/quantize_gptq.py +++ b/test/infiniop/quantize_gptq.py @@ -317,20 +317,22 @@ def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1): W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - # print('error', torch.sum(Losses).item()) + print("error", torch.sum(Losses).item()) self.weight = Q.reshape(self.weight.shape).to(self.weight.dtype) self.scale = scale.to(self.weight.dtype) self.zero = zero.to(self.weight.dtype) -def get_scale_zero(b, a, c, group_size, sign_ed): +def get_scale_zero(b, a, c, group_size, bits, sign_ed): weight = b.clone() inp = a.clone() out = c.clone() gptq = GPTQ(weight) gptq.quantizer = Quantizer() - gptq.quantizer.configure(perchannel=True, sym=False, mse=False, signed=sign_ed) + gptq.quantizer.configure( + bits=bits, perchannel=True, sym=False, mse=False, sign_ed=sign_ed + ) gptq.add_batch(inp, out) gptq.fasterquant(group_size=group_size) @@ -341,8 +343,10 @@ def get_scale_zero(b, a, c, group_size, sign_ed): ) -def pack(weight, scale, zero): - intweight = torch.round((weight + zero) / scale).to(torch.int32) +def pack(weight, scale, zero, minq, maxq): + intweight = torch.clamp(torch.round(weight / scale + zero), minq, maxq).to( + torch.int32 + ) qweight = torch.zeros( [weight.shape[0], weight.shape[1] // 8], dtype=torch.int32, device=weight.device ) @@ -377,7 +381,7 @@ def test( # Initialize tensors a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device) layer = nn.Linear(K, N) - b = 1e-3 * layer.weight.data.to(dtype).to(torch_device) + b = 1e0 * layer.weight.data.to(dtype).to(torch_device) c = torch.zeros([N, M], dtype=dtype).to(torch_device) group_size = -1 @@ -393,13 +397,28 @@ def test( packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device) s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) + sign_ed = False + bits = 4 + maxq = 2**bits - 1 + minq = 0 + if sign_ed: # 有符号量化,范围是[-8,7] + maxq = 2 ** (bits - 1) - 1 + minq = -(2 ** (bits - 1)) + sym = False + if torch_device == "cuda": b_ref, s, z = get_scale_zero( - b, a.t(), c, group_size, signed=False + b, a.t(), c, group_size, bits, sign_ed=sign_ed ) # 无符号量化 - z = torch.zeros_like(s) - packed_weights = pack(b_ref, s, z) - # print(s) + + packed_weights = pack(b_ref, s, z, minq, maxq) + + if torch_device == "cpu": + b_ref, s, z = get_scale_zero( + b, a.t(), c, group_size, bits, sign_ed=sign_ed + ) # 无符号量化 + + packed_weights = pack(b_ref, s, z, minq, maxq) a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = ( to_tensor(a, lib), @@ -444,19 +463,19 @@ def test( workspace = create_workspace(workspace_size.value, a.device) # Execute infiniop quantize_gptq operator - check_error( - lib.infiniopQuantizeGPTQ( - descriptor, - workspace.data_ptr() if workspace is not None else None, - workspace_size.value, - packed_weights_tensor.data, - s_tensor.data, - z_tensor.data, - a_tensor.data, - b_tensor.data, - None, - ) - ) + # check_error( + # lib.infiniopQuantizeGPTQ( + # descriptor, + # workspace.data_ptr() if workspace is not None else None, + # workspace_size.value, + # packed_weights_tensor.data, + # s_tensor.data, + # z_tensor.data, + # a_tensor.data, + # b_tensor.data, + # None, + # ) + # ) def lib_quantize_gptq(): check_error( @@ -476,12 +495,12 @@ def lib_quantize_gptq(): lib_quantize_gptq() atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) - # tmpa = ans.flatten() - # tmpc = c.flatten() - # for i in range(tmpa.shape[0]): - # if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]): - # print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i])) - # break + tmpa = ans.flatten() + tmpc = c.flatten() + for i in range(tmpa.shape[0]): + if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]): + print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i])) + break if DEBUG: debug(c, ans, atol=atol, rtol=rtol) From c909d0eaaea9a5175e8b7c2cb25c1a033a3389f9 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Fri, 30 May 2025 13:51:09 +0800 Subject: [PATCH 4/8] issue/170: debug marlin --- .../quantize_gptq/cuda/quantize_gptq_cuda.cu | 4 +- test/infiniop/quantize_gptq.py | 97 +++++++++++-------- 2 files changed, 57 insertions(+), 44 deletions(-) diff --git a/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu b/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu index 1550a96c4..50320b455 100644 --- a/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu +++ b/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu @@ -71,14 +71,14 @@ infiniStatus_t Descriptor::calculate( int group_size = int(_info.group_size); int num_groups = int(_info.num_groups); bool is_weight_transposed = _info.is_weight_transposed; - if (_info.atype == INFINI_DTYPE_F16 && !is_weight_transposed) { + if (_info.atype == INFINI_DTYPE_F16 && is_weight_transposed) { gptq_marlin::gptq_marlin_mm_fp16(c, a, packed_weights, b_scale, m, n, k, workspace, bits, num_groups, group_size, this->device_id, (cudaStream_t)stream); - } else if (_info.atype == INFINI_DTYPE_BF16 && !is_weight_transposed) { + } else if (_info.atype == INFINI_DTYPE_BF16 && is_weight_transposed) { gptq_marlin::gptq_marlin_mm_bf16(c, a, packed_weights, b_scale, m, n, k, workspace, bits, diff --git a/test/infiniop/quantize_gptq.py b/test/infiniop/quantize_gptq.py index 63515c9d3..7daeafd8b 100644 --- a/test/infiniop/quantize_gptq.py +++ b/test/infiniop/quantize_gptq.py @@ -23,7 +23,6 @@ # Configuration (Internal Use Only) # ============================================================================== # These are not meant to be imported from other modules -is_weight_transposed = False _TEST_CASES = [] @@ -38,7 +37,7 @@ for _, layers in MODELS.items(): for layer in layers: for batch in [1, 16]: - _TEST_CASES.append(((batch, layer[0], layer[1], is_weight_transposed))) + _TEST_CASES.append(((batch, layer[0], layer[1]))) # Data types used for testing _TENSOR_DTYPES = [torch.float16] @@ -324,14 +323,14 @@ def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1): self.zero = zero.to(self.weight.dtype) -def get_scale_zero(b, a, c, group_size, bits, sign_ed): +def get_scale_zero(b, a, c, group_size, bits, sym, sign_ed): weight = b.clone() inp = a.clone() out = c.clone() gptq = GPTQ(weight) gptq.quantizer = Quantizer() gptq.quantizer.configure( - bits=bits, perchannel=True, sym=False, mse=False, sign_ed=sign_ed + bits=bits, perchannel=True, sym=sym, mse=False, sign_ed=sign_ed ) gptq.add_batch(inp, out) gptq.fasterquant(group_size=group_size) @@ -370,7 +369,6 @@ def test( M, K, N, - is_weight_transposed, dtype=torch.float16, sync=None, ): @@ -381,8 +379,13 @@ def test( # Initialize tensors a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device) layer = nn.Linear(K, N) - b = 1e0 * layer.weight.data.to(dtype).to(torch_device) + b = 1e-3 * layer.weight.data.to(dtype).to(torch_device) c = torch.zeros([N, M], dtype=dtype).to(torch_device) + is_weight_transposed = False + sign_ed = False + sym = False + if torch_device != "cpu": + is_weight_transposed = True group_size = -1 num_groups = 1 @@ -397,37 +400,45 @@ def test( packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device) s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) - sign_ed = False + bits = 4 maxq = 2**bits - 1 minq = 0 if sign_ed: # 有符号量化,范围是[-8,7] maxq = 2 ** (bits - 1) - 1 minq = -(2 ** (bits - 1)) - sym = False if torch_device == "cuda": b_ref, s, z = get_scale_zero( - b, a.t(), c, group_size, bits, sign_ed=sign_ed + b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed ) # 无符号量化 packed_weights = pack(b_ref, s, z, minq, maxq) - if torch_device == "cpu": - b_ref, s, z = get_scale_zero( - b, a.t(), c, group_size, bits, sign_ed=sign_ed - ) # 无符号量化 - - packed_weights = pack(b_ref, s, z, minq, maxq) + # if torch_device == "cpu": + # b_ref, s, z = get_scale_zero( + # b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed + # ) # 无符号量化 - a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = ( - to_tensor(a, lib), - to_tensor(b, lib), - to_tensor(c, lib), - to_tensor(s, lib), - to_tensor(z, lib), - to_tensor(packed_weights, lib), - ) + # packed_weights = pack(b_ref, s, z, minq, maxq) + if is_weight_transposed: + a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = ( + to_tensor(a.t(), lib), + to_tensor(b.t(), lib), + to_tensor(c.t(), lib), + to_tensor(s.t(), lib), + to_tensor(z.t(), lib), + to_tensor(packed_weights.t(), lib), + ) + else: + a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = ( + to_tensor(a, lib), + to_tensor(b, lib), + to_tensor(c, lib), + to_tensor(s, lib), + to_tensor(z, lib), + to_tensor(packed_weights, lib), + ) descriptor = infiniopQuantizeGPTQDescriptor_t() check_error( @@ -463,19 +474,19 @@ def test( workspace = create_workspace(workspace_size.value, a.device) # Execute infiniop quantize_gptq operator - # check_error( - # lib.infiniopQuantizeGPTQ( - # descriptor, - # workspace.data_ptr() if workspace is not None else None, - # workspace_size.value, - # packed_weights_tensor.data, - # s_tensor.data, - # z_tensor.data, - # a_tensor.data, - # b_tensor.data, - # None, - # ) - # ) + check_error( + lib.infiniopQuantizeGPTQ( + descriptor, + workspace.data_ptr() if workspace is not None else None, + workspace_size.value, + packed_weights_tensor.data, + s_tensor.data, + z_tensor.data, + a_tensor.data, + b_tensor.data, + None, + ) + ) def lib_quantize_gptq(): check_error( @@ -495,13 +506,15 @@ def lib_quantize_gptq(): lib_quantize_gptq() atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) - tmpa = ans.flatten() - tmpc = c.flatten() - for i in range(tmpa.shape[0]): - if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]): - print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i])) - break + # tmpa = ans.flatten() + # tmpc = c.flatten() + # for i in range(tmpa.shape[0]): + # if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]): + # print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i])) + # break + if is_weight_transposed: + c = c.t() if DEBUG: debug(c, ans, atol=atol, rtol=rtol) assert torch.allclose(c, ans, atol=atol, rtol=rtol) From 830daeb1a468e63f28619a8a118c1eb87d2f23e1 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Mon, 16 Jun 2025 14:09:31 +0800 Subject: [PATCH 5/8] issue/170: success marlin --- test/infiniop/quantize_gptq.py | 215 +++++++++++++++++++++++++++------ 1 file changed, 181 insertions(+), 34 deletions(-) diff --git a/test/infiniop/quantize_gptq.py b/test/infiniop/quantize_gptq.py index 7daeafd8b..16c47bd52 100644 --- a/test/infiniop/quantize_gptq.py +++ b/test/infiniop/quantize_gptq.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import numpy as np import math import ctypes from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float @@ -354,9 +355,170 @@ def pack(weight, scale, zero, minq, maxq): return qweight +def _get_perms(): + perm = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm) + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + perm = perm.reshape((-1, 8))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +_perm, _scale_perm, _scale_perm_single = _get_perms() + + +class MarlinLayer(nn.Module): + """PyTorch compatible Marlin layer; 4-bit (symmetric grouped) linear layer without bias.""" + + def __init__(self, infeatures, outfeatures, groupsize=-1): + """Create an empty Marlin layer. + @infeatures: number of input features (must be divisible by 128) + @outfeatures: number of output features (must be divisible by 256) + @groupsize: quantization groupsize (must be -1 or 128) + """ + super().__init__() + if groupsize not in [-1, 128]: + raise ValueError("Only groupsize -1 and 128 are supported.") + if infeatures % 128 != 0 or outfeatures % 256 != 0: + raise ValueError( + "`infeatures` must be divisible by 128 and `outfeatures` by 256." + ) + if groupsize == -1: + groupsize = infeatures + if infeatures % groupsize != 0: + raise ValueError("`infeatures` must be divisible by `groupsize`.") + self.k = infeatures + self.n = outfeatures + self.groupsize = groupsize + self.register_buffer( + "B", torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int) + ) + self.register_buffer( + "s", torch.empty((self.k // groupsize, self.n), dtype=torch.half) + ) + + def forward(self, A): + C = torch.empty( + A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device + ) + marlin_matmul( + A.view((-1, A.shape[-1])), + self.B, + C.view((-1, C.shape[-1])), + self.s, + ) + return C + + def pack(self, linear, scales): + """Pack a fake-quantized linear layer into this actual Marlin representation. + @linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`) + @scales: corresponding quantization scales of shape `(infeatures, groups)` + """ + if linear.weight.dtype != torch.half: + raise ValueError("Only `torch.half` weights are supported.") + tile = 16 + maxq = 2**4 - 1 + s = scales.t() + w = linear.weight.data.t() + if self.groupsize != self.k: + w = w.reshape((-1, self.groupsize, self.n)) + w = w.permute(1, 0, 2) + w = w.reshape((self.groupsize, -1)) + s = s.reshape((1, -1)) + w = torch.round(w / s).int() + w += (maxq + 1) // 2 + w = torch.clamp(w, 0, maxq) + if self.groupsize != self.k: + w = w.reshape((self.groupsize, -1, self.n)) + w = w.permute(1, 0, 2) + w = w.reshape((self.k, self.n)).contiguous() + s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] + else: + s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] + s = s.reshape((-1, self.n)).contiguous() + w = w.reshape((self.k // tile, tile, self.n // tile, tile)) + w = w.permute((0, 2, 1, 3)) + w = w.reshape((self.k // tile, self.n * tile)) + res = w + res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape) + q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32) + res = res.cpu().numpy().astype(np.uint32) + for i in range(8): + q |= res[:, i::8] << 4 * i + q = torch.from_numpy(q.astype(np.int32)).to(w.device) + self.B[:, :] = q.to(self.B.device) + self.s[:, :] = s.to(self.s.device) + + +def gen_quant4(m, n, groupsize=-1): + DEV = torch.device("cuda:0") + tile = 16 + maxq = 2**4 - 1 + w = torch.randn((m, n), dtype=torch.half, device=DEV) + if groupsize != -1: + w = w.reshape((-1, groupsize, n)) + w = w.permute(1, 0, 2) + w = w.reshape((groupsize, -1)) + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / maxq + w = torch.round(w / s).int() + w += (maxq + 1) // 2 + w = torch.clamp(w, 0, maxq) + ref = (w - (maxq + 1) // 2).half() * s + if groupsize != -1: + + def reshape(w): + w = w.reshape((groupsize, -1, n)) + w = w.permute(1, 0, 2) + w = w.reshape((m, n)).contiguous() + return w + + ref = reshape(ref) + w = reshape(w) + s = s.reshape((-1, n)).contiguous() + linear = nn.Linear(m, n) + linear.weight.data = ref.t() + # Workaround to test some special cases that are forbidden by the API + layer = MarlinLayer(256, 256, groupsize=groupsize) + if groupsize == -1: + groupsize = m + layer.k = m + layer.n = n + layer.groupsize = groupsize + layer.B = torch.empty((m // 16, n * 16 // 8), dtype=torch.int, device=DEV) + layer.s = torch.empty((m // groupsize, n), dtype=torch.half, device=DEV) + layer.pack(linear, s.t()) + q = layer.B.reshape(m // 8, n) + s = layer.s + return ref, q, s + + # PyTorch implementation for matrix multiplication -def quantize_gptq(a, b): # 昇腾芯片的CPU不支持转置计算 - ans = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(b.dtype) +def quantize_gptq(a, b, is_weight_transposed): # 昇腾芯片的CPU不支持转置计算 + if is_weight_transposed: + ans = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(b.dtype) + else: + ans = torch.matmul(b.to(torch.float32), a.to(torch.float32)).to(b.dtype) return ans @@ -379,7 +541,7 @@ def test( # Initialize tensors a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device) layer = nn.Linear(K, N) - b = 1e-3 * layer.weight.data.to(dtype).to(torch_device) + b = 1e0 * layer.weight.data.to(dtype).to(torch_device) c = torch.zeros([N, M], dtype=dtype).to(torch_device) is_weight_transposed = False sign_ed = False @@ -393,10 +555,6 @@ def test( num_groups = 1 else: num_groups = K // group_size - if is_weight_transposed: - ans = quantize_gptq(a.t(), b.t()) - else: - ans = quantize_gptq(b, a) packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device) s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) @@ -409,11 +567,12 @@ def test( minq = -(2 ** (bits - 1)) if torch_device == "cuda": - b_ref, s, z = get_scale_zero( - b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed - ) # 无符号量化 - - packed_weights = pack(b_ref, s, z, minq, maxq) + b, packed_weights, s = gen_quant4(K, N, groupsize=group_size) + a = 1e0 * torch.randn([M, K], dtype=dtype).to( + torch_device + ) # 不知道为什么,不能使用a = a.t(), c = c.t() + c = torch.zeros([M, N], dtype=dtype).to(torch_device) + z = torch.zeros_like(s).to(torch_device) # if torch_device == "cpu": # b_ref, s, z = get_scale_zero( @@ -421,24 +580,15 @@ def test( # ) # 无符号量化 # packed_weights = pack(b_ref, s, z, minq, maxq) - if is_weight_transposed: - a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = ( - to_tensor(a.t(), lib), - to_tensor(b.t(), lib), - to_tensor(c.t(), lib), - to_tensor(s.t(), lib), - to_tensor(z.t(), lib), - to_tensor(packed_weights.t(), lib), - ) - else: - a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = ( - to_tensor(a, lib), - to_tensor(b, lib), - to_tensor(c, lib), - to_tensor(s, lib), - to_tensor(z, lib), - to_tensor(packed_weights, lib), - ) + ans = quantize_gptq(a, b, is_weight_transposed) + a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = ( + to_tensor(a, lib), + to_tensor(b, lib), + to_tensor(c, lib), + to_tensor(s, lib), + to_tensor(z, lib), + to_tensor(packed_weights, lib), + ) descriptor = infiniopQuantizeGPTQDescriptor_t() check_error( @@ -522,10 +672,7 @@ def lib_quantize_gptq(): # Profiling workflow if PROFILE: # fmt: off - if(is_weight_transposed): - profile_operation("PyTorch", lambda: quantize_gptq(a.t(), b.t()), torch_device, NUM_PRERUN, NUM_ITERATIONS) - else: - profile_operation("PyTorch", lambda: quantize_gptq(b, a), torch_device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation("PyTorch", lambda: quantize_gptq(a, b, is_weight_transposed), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_quantize_gptq(), torch_device, NUM_PRERUN, NUM_ITERATIONS) # fmt: on check_error(lib.infiniopDestroyQuantizeGPTQDescriptor(descriptor)) From d829fc1a0af1fa2ed6492b133985c1be1ec5f42a Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Thu, 7 Aug 2025 15:56:31 +0800 Subject: [PATCH 6/8] issue/170: error --- include/infinicore.h | 2 +- .../quantize_gptq/cpu/quantize_gptq_cpu.cc | 8 +- .../{cuda => nvidia}/gptq_marlin.cu | 0 .../{cuda => nvidia}/gptq_marlin.cuh | 0 .../quantize_gptq_nvidia.cu} | 24 +- .../quantize_gptq_nvidia.cuh} | 2 +- src/infiniop/ops/quantize_gptq/operator.cc | 24 +- .../ops/quantize_gptq/quantize_gptq.h | 12 +- test/infiniop/libinfiniop/utils.py | 3 +- test/infiniop/quantize_gptq.py | 318 ++++++++---------- 10 files changed, 186 insertions(+), 207 deletions(-) rename src/infiniop/ops/quantize_gptq/{cuda => nvidia}/gptq_marlin.cu (100%) rename src/infiniop/ops/quantize_gptq/{cuda => nvidia}/gptq_marlin.cuh (100%) rename src/infiniop/ops/quantize_gptq/{cuda/quantize_gptq_cuda.cu => nvidia/quantize_gptq_nvidia.cu} (74%) rename src/infiniop/ops/quantize_gptq/{cuda/quantize_gptq_cuda.cuh => nvidia/quantize_gptq_nvidia.cuh} (87%) diff --git a/include/infinicore.h b/include/infinicore.h index 3260bd78c..e07e33e58 100644 --- a/include/infinicore.h +++ b/include/infinicore.h @@ -70,7 +70,7 @@ typedef enum { INFINI_DTYPE_C64 = 17, INFINI_DTYPE_C128 = 18, INFINI_DTYPE_BF16 = 19, - INFINI_DTYPE_I4 = 20, + } infiniDtype_t; #endif // __INFINICORE_API_H__ diff --git a/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc b/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc index 6f021df10..01fe27fa3 100644 --- a/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc +++ b/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc @@ -32,12 +32,12 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle_, Descriptor **desc_pt infiniopTensorDescriptor_t b_scale_desc, infiniopTensorDescriptor_t zero_desc) { auto handle = reinterpret_cast(handle_); - auto result = MatmulGptqInfo::createMatmulGptqInfo(c_desc, a_desc, packed_weights_desc, b_scale_desc, zero_desc); + auto result = QuantizeGptqInfo::createQuantizeGptqInfo(c_desc, a_desc, packed_weights_desc, b_scale_desc, zero_desc); CHECK_RESULT(result); - MatmulGptqInfo info = result.take(); + QuantizeGptqInfo info = result.take(); size_t min_workspace_size = (info.k * info.k + info.n * info.block_size) * sizeof(float) + (2 * info.n * info.k) * infiniSizeOf(info.atype); - + std::cout << "kernel workspace:" << min_workspace_size << std::endl; *desc_ptr = new Descriptor(info, nullptr, min_workspace_size, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } @@ -624,8 +624,10 @@ infiniStatus_t Descriptor::calculate( int group_size = int(_info.group_size); bool is_weight_transposed = _info.is_weight_transposed; if (_info.atype == INFINI_DTYPE_F16 && !is_weight_transposed) { + std::cout << "cpu start" << std::endl; caculate(workspace, (fp16_t *)c, (fp16_t *)a, (int32_t *)packed_weights, (fp16_t *)b_scale, (fp16_t *)zero, m, k, n, group_size); + std::cout << "cpu end" << std::endl; } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; diff --git a/src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cu b/src/infiniop/ops/quantize_gptq/nvidia/gptq_marlin.cu similarity index 100% rename from src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cu rename to src/infiniop/ops/quantize_gptq/nvidia/gptq_marlin.cu diff --git a/src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cuh b/src/infiniop/ops/quantize_gptq/nvidia/gptq_marlin.cuh similarity index 100% rename from src/infiniop/ops/quantize_gptq/cuda/gptq_marlin.cuh rename to src/infiniop/ops/quantize_gptq/nvidia/gptq_marlin.cuh diff --git a/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu b/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu similarity index 74% rename from src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu rename to src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu index 50320b455..4a5069506 100644 --- a/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cu +++ b/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu @@ -1,16 +1,16 @@ -#include "../../../devices/cuda/cuda_common.cuh" +#include "../../../devices/nvidia/nvidia_common.cuh" #include "gptq_marlin.cuh" -#include "quantize_gptq_cuda.cuh" +#include "quantize_gptq_nvidia.cuh" #include #ifdef NDEBUG #define SAFE_ASSERT(x) ((void)(x)) #else #define SAFE_ASSERT(x) assert(x) #endif -namespace op::quantize_gptq::cuda { +namespace op::quantize_gptq::nvidia { struct Descriptor::Opaque { - std::shared_ptr internal; + std::shared_ptr internal; }; Descriptor::~Descriptor() { @@ -18,19 +18,20 @@ Descriptor::~Descriptor() { } infiniStatus_t Descriptor::create( - infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopHandle_t handle_, Descriptor **desc_ptr, infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t packed_weights_desc, infiniopTensorDescriptor_t b_scale_desc, infiniopTensorDescriptor_t zero_desc) { - auto result = MatmulGptqInfo::createMatmulGptqInfo(c_desc, a_desc, packed_weights_desc, b_scale_desc, zero_desc); + auto handle = reinterpret_cast(handle_); + auto result = QuantizeGptqInfo::createQuantizeGptqInfo(c_desc, a_desc, packed_weights_desc, b_scale_desc, zero_desc); CHECK_RESULT(result); - MatmulGptqInfo info = result.take(); + QuantizeGptqInfo info = result.take(); int max_par = gptq_marlin::max_par; size_t min_workspace_size = info.n / gptq_marlin::min_thread_n * max_par * sizeof(int) + info.m * info.k * infiniSizeOf(info.atype); - - *desc_ptr = new Descriptor(info, new Opaque{reinterpret_cast(handle)->internal()}, min_workspace_size, handle->device, handle->device_id); + std::cout << "kernel workspace:" << min_workspace_size << std::endl; + *desc_ptr = new Descriptor(info, new Opaque{handle->internal()}, min_workspace_size, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } @@ -71,12 +72,15 @@ infiniStatus_t Descriptor::calculate( int group_size = int(_info.group_size); int num_groups = int(_info.num_groups); bool is_weight_transposed = _info.is_weight_transposed; + if (_info.atype == INFINI_DTYPE_F16 && is_weight_transposed) { + std::cout << "gpu start: " << "group_size: " << group_size << "num_groups: " << num_groups << std::endl; gptq_marlin::gptq_marlin_mm_fp16(c, a, packed_weights, b_scale, m, n, k, workspace, bits, num_groups, group_size, this->device_id, (cudaStream_t)stream); + std::cout << "gpu end" << std::endl; } else if (_info.atype == INFINI_DTYPE_BF16 && is_weight_transposed) { gptq_marlin::gptq_marlin_mm_bf16(c, a, packed_weights, b_scale, @@ -91,4 +95,4 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_SUCCESS; } -} // namespace op::quantize_gptq::cuda +} // namespace op::quantize_gptq::nvidia diff --git a/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cuh b/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cuh similarity index 87% rename from src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cuh rename to src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cuh index 4de0fc109..acf138d2f 100644 --- a/src/infiniop/ops/quantize_gptq/cuda/quantize_gptq_cuda.cuh +++ b/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cuh @@ -3,6 +3,6 @@ #include "../quantize_gptq.h" -DESCRIPTOR(cuda) +DESCRIPTOR(nvidia) #endif // __QUANTIZE_GPTQ_CUDA_H__ diff --git a/src/infiniop/ops/quantize_gptq/operator.cc b/src/infiniop/ops/quantize_gptq/operator.cc index 6393b46da..a978498c9 100644 --- a/src/infiniop/ops/quantize_gptq/operator.cc +++ b/src/infiniop/ops/quantize_gptq/operator.cc @@ -5,8 +5,8 @@ #ifdef ENABLE_CPU_API #include "cpu/quantize_gptq_cpu.h" #endif -#ifdef ENABLE_CUDA_API -#include "cuda/quantize_gptq_cuda.cuh" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/quantize_gptq_nvidia.cuh" #endif __C infiniStatus_t infiniopCreateQuantizeGPTQDescriptor(infiniopHandle_t handle, @@ -30,8 +30,8 @@ __C infiniStatus_t infiniopCreateQuantizeGPTQDescriptor(infiniopHandle_t handle, #ifdef ENABLE_CPU_API CREATE(INFINI_DEVICE_CPU, cpu) #endif -#ifdef ENABLE_CUDA_API - CREATE(INFINI_DEVICE_NVIDIA, cuda) +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -47,8 +47,8 @@ __C infiniStatus_t infiniopGetQuantizeGPTQWorkspaceSize(infiniopQuantizeGPTQDesc #ifdef ENABLE_CPU_API GET(INFINI_DEVICE_CPU, cpu) #endif -#ifdef ENABLE_CUDA_API - GET(INFINI_DEVICE_NVIDIA, cuda) +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -73,8 +73,8 @@ __C infiniStatus_t infiniopQuantizeGPTQ(infiniopQuantizeGPTQDescriptor_t desc, #ifdef ENABLE_CPU_API QUANT(INFINI_DEVICE_CPU, cpu) #endif -#ifdef ENABLE_CUDA_API - QUANT(INFINI_DEVICE_NVIDIA, cuda) +#ifdef ENABLE_NVIDIA_API + QUANT(INFINI_DEVICE_NVIDIA, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -99,8 +99,8 @@ __C infiniStatus_t infiniopQuantizeLinearGPTQ(infiniopQuantizeGPTQDescriptor_t d #ifdef ENABLE_CPU_API CACULATE(INFINI_DEVICE_CPU, cpu) #endif -#ifdef ENABLE_CUDA_API - CACULATE(INFINI_DEVICE_NVIDIA, cuda) +#ifdef ENABLE_NVIDIA_API + CACULATE(INFINI_DEVICE_NVIDIA, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -117,8 +117,8 @@ __C infiniStatus_t infiniopDestroyQuantizeGPTQDescriptor(infiniopQuantizeGPTQDes #ifdef ENABLE_CPU_API DESTROY(INFINI_DEVICE_CPU, cpu) #endif -#ifdef ENABLE_CUDA_API - DESTROY(INFINI_DEVICE_NVIDIA, cuda) +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/quantize_gptq/quantize_gptq.h b/src/infiniop/ops/quantize_gptq/quantize_gptq.h index 9a226629a..a16688a3d 100644 --- a/src/infiniop/ops/quantize_gptq/quantize_gptq.h +++ b/src/infiniop/ops/quantize_gptq/quantize_gptq.h @@ -11,10 +11,10 @@ class Descriptor final : public InfiniopDescriptor { \ struct Opaque; \ Opaque *_opaque; \ - MatmulGptqInfo _info; \ + QuantizeGptqInfo _info; \ size_t _workspace_size; \ \ - Descriptor(MatmulGptqInfo info, Opaque *opaque, \ + Descriptor(QuantizeGptqInfo info, Opaque *opaque, \ size_t workspace_size, \ infiniDevice_t device_type, int device_id) \ : InfiniopDescriptor{device_type, device_id}, \ @@ -46,9 +46,9 @@ }; \ } -class MatmulGptqInfo { +class QuantizeGptqInfo { private: - MatmulGptqInfo() = default; + QuantizeGptqInfo() = default; public: infiniDtype_t atype, packed_weights_type; @@ -56,7 +56,7 @@ class MatmulGptqInfo { ptrdiff_t group_size; bool is_weight_transposed; - static utils::Result createMatmulGptqInfo( + static utils::Result createQuantizeGptqInfo( infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t packed_weights_desc, @@ -139,7 +139,7 @@ class MatmulGptqInfo { INFINI_STATUS_BAD_TENSOR_SHAPE); } - return utils::Result(MatmulGptqInfo{ + return utils::Result(QuantizeGptqInfo{ atype, packed_weights_type, m, diff --git a/test/infiniop/libinfiniop/utils.py b/test/infiniop/libinfiniop/utils.py index 5c8e7f80a..a2b3a6228 100644 --- a/test/infiniop/libinfiniop/utils.py +++ b/test/infiniop/libinfiniop/utils.py @@ -81,7 +81,8 @@ def __init__( elif mode == "manual": assert set_tensor is not None assert torch_shape == list(set_tensor.shape) - assert torch_strides == list(set_tensor.stride()) + if torch_strides is not None: + assert torch_strides == list(set_tensor.stride()) self._torch_tensor = set_tensor.to(to_torch_dtype(dt)).to( torch_device_map[device] ) diff --git a/test/infiniop/quantize_gptq.py b/test/infiniop/quantize_gptq.py index 16c47bd52..6c25cf08d 100644 --- a/test/infiniop/quantize_gptq.py +++ b/test/infiniop/quantize_gptq.py @@ -1,23 +1,25 @@ import torch import torch.nn as nn -import numpy as np import math +import numpy as np import ctypes -from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float +from ctypes import c_uint64 from libinfiniop import ( - infiniopHandle_t, - infiniopTensorDescriptor_t, - open_lib, - to_tensor, + LIBINFINIOP, + TestTensor, get_test_devices, check_error, - rearrange_if_needed, - create_workspace, test_operator, get_args, debug, get_tolerance, profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + InfiniDeviceEnum, + infiniopOperatorDescriptor_t, ) # ============================================================================== @@ -25,27 +27,30 @@ # ============================================================================== # These are not meant to be imported from other modules -_TEST_CASES = [] +_TEST_CASES = [(1, 128, 128)] +# _TEST_CASES = [] -MODELS = { - "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], - # "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], - # "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], - # "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], -} +# MODELS = { +# "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], +# # "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], +# # "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], +# # "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], +# } -# Loop through models and layers to generate the new _TEST_CASES -for _, layers in MODELS.items(): - for layer in layers: - for batch in [1, 16]: - _TEST_CASES.append(((batch, layer[0], layer[1]))) +# # Loop through models and layers to generate the new _TEST_CASES +# for _, layers in MODELS.items(): +# for layer in layers: +# for batch in [1, 16]: +# _TEST_CASES.append(((batch, layer[0], layer[1]))) # Data types used for testing -_TENSOR_DTYPES = [torch.float16] - +#_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] +_TENSOR_DTYPES = [InfiniDtype.F16] # Tolerance map for different data types _TOLERANCE_MAP = { - torch.float16: {"atol": 1e-2, "rtol": 1e-2}, + InfiniDtype.F16: {"atol": 0, "rtol": 1e-2}, + # InfiniDtype.F32: {"atol": 0, "rtol": 1e-3}, + # InfiniDtype.BF16: {"atol": 0, "rtol": 5e-2}, } DEBUG = False @@ -53,17 +58,6 @@ NUM_PRERUN = 10 NUM_ITERATIONS = 1000 - -# ============================================================================== -# Definitions -# ============================================================================== -class QuantizeGPTQDescriptor(Structure): - _fields_ = [("device", c_int32)] - - -infiniopQuantizeGPTQDescriptor_t = POINTER(QuantizeGPTQDescriptor) - - def quantize(x, scale, zero, minq, maxq): if scale.shape[1] == 1: q = torch.clamp(torch.round(x / scale) + zero, minq, maxq) @@ -514,39 +508,37 @@ def reshape(w): # PyTorch implementation for matrix multiplication -def quantize_gptq(a, b, is_weight_transposed): # 昇腾芯片的CPU不支持转置计算 +def quantize_gptq(ans, a, b, is_weight_transposed): # 昇腾芯片的CPU不支持转置计算 if is_weight_transposed: ans = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(b.dtype) else: ans = torch.matmul(b.to(torch.float32), a.to(torch.float32)).to(b.dtype) - return ans - - -# The argument list should be (lib, handle, torch_device, , dtype) + +# The argument list should be (lib, handle, device, , dtype) # The should keep the same order as the one specified in _TEST_CASES def test( - lib, handle, - torch_device, + device, M, K, N, - dtype=torch.float16, + dtype=InfiniDtype.F16, sync=None, ): print( - f"Testing QuantizeGPTQ on {torch_device}" f" M:{M}, K:{K}, N:{N}, dtype:{dtype}" + f"Testing QuantizeGPTQ on {InfiniDeviceNames[device]}" f" M:{M}, K:{K}, N:{N}, dtype:{InfiniDtypeNames[dtype]}" ) - torch.manual_seed(12) + # Initialize tensors - a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device) - layer = nn.Linear(K, N) - b = 1e0 * layer.weight.data.to(dtype).to(torch_device) - c = torch.zeros([N, M], dtype=dtype).to(torch_device) + a = TestTensor((K, M), None, dtype, device) + b = TestTensor((N, K), None, dtype, device) + c = TestTensor((N, M), None, dtype, device, mode="zeros") + ans = TestTensor((N, M), None, dtype, device, mode="zeros") + is_weight_transposed = False sign_ed = False sym = False - if torch_device != "cpu": + if device != InfiniDeviceEnum.CPU: is_weight_transposed = True group_size = -1 @@ -555,9 +547,11 @@ def test( num_groups = 1 else: num_groups = K // group_size - packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device) - s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) - z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device) + + packed_weights = TestTensor((N, K // 8), None, InfiniDtype.I32, device, mode="zeros") + print(packed_weights.torch_tensor().dtype) + s = TestTensor((N, num_groups), None, dtype, device, mode="zeros") + z = TestTensor((N, num_groups), None, dtype, device, mode="zeros") bits = 4 maxq = 2**bits - 1 @@ -565,117 +559,144 @@ def test( if sign_ed: # 有符号量化,范围是[-8,7] maxq = 2 ** (bits - 1) - 1 minq = -(2 ** (bits - 1)) + + if device == InfiniDeviceEnum.NVIDIA: + b_data, packed_weights_data, s_data = gen_quant4(K, N, groupsize=group_size) + a = TestTensor((M, K), None, dtype, device) + b = TestTensor((K, N), None, dtype, device, mode="manual", set_tensor=b_data) + c = TestTensor((M, N), None, dtype, device, mode="zeros") + ans = TestTensor((M, N), None, dtype, device, mode="zeros") + packed_weights = TestTensor((K // 8, N), None, InfiniDtype.I32, device, mode="manual", set_tensor=packed_weights_data) + s = TestTensor((num_groups, N), None, dtype, device, mode="manual", set_tensor=s_data) + z = TestTensor((num_groups, N), None, dtype, device, mode="zeros") + + if device == InfiniDeviceEnum.CPU: + b_ref_data, s_data, z_data = get_scale_zero( + b.torch_tensor(), a.torch_tensor().t(), c.torch_tensor(), group_size, bits, sym, sign_ed=sign_ed + ) # 无符号量化 + + packed_weights_data = pack(b_ref_data, s_data, z_data, minq, maxq) + packed_weights = TestTensor((N, K // 8), None, InfiniDtype.I32, device, mode="manual", set_tensor=packed_weights_data) + s = TestTensor((N, num_groups), None, dtype, device, mode="manual", set_tensor=s_data) + z = TestTensor((N, num_groups), None, dtype, device, mode="manual", set_tensor=z_data) + + def torch_quantize_gptq(): + quantize_gptq( + ans.torch_tensor(), + a.torch_tensor(), + b.torch_tensor(), + is_weight_transposed, + ) - if torch_device == "cuda": - b, packed_weights, s = gen_quant4(K, N, groupsize=group_size) - a = 1e0 * torch.randn([M, K], dtype=dtype).to( - torch_device - ) # 不知道为什么,不能使用a = a.t(), c = c.t() - c = torch.zeros([M, N], dtype=dtype).to(torch_device) - z = torch.zeros_like(s).to(torch_device) - - # if torch_device == "cpu": - # b_ref, s, z = get_scale_zero( - # b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed - # ) # 无符号量化 - - # packed_weights = pack(b_ref, s, z, minq, maxq) - ans = quantize_gptq(a, b, is_weight_transposed) - a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = ( - to_tensor(a, lib), - to_tensor(b, lib), - to_tensor(c, lib), - to_tensor(s, lib), - to_tensor(z, lib), - to_tensor(packed_weights, lib), - ) + torch_quantize_gptq() - descriptor = infiniopQuantizeGPTQDescriptor_t() + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() check_error( - lib.infiniopCreateQuantizeGPTQDescriptor( + LIBINFINIOP.infiniopCreateQuantizeGPTQDescriptor( handle, ctypes.byref(descriptor), - c_tensor.descriptor, - a_tensor.descriptor, - packed_weights_tensor.descriptor, - s_tensor.descriptor, - z_tensor.descriptor, + c.descriptor, + a.descriptor, + packed_weights.descriptor, + s.descriptor, + z.descriptor, ) ) # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel for tensor in [ - a_tensor, - b_tensor, - c_tensor, - s_tensor, - z_tensor, - packed_weights_tensor, + a, + b, + c, + s, + z, + packed_weights, ]: - tensor.destroyDesc(lib) + tensor.destroy_desc() # Get workspace size and create workspace workspace_size = c_uint64(0) check_error( - lib.infiniopGetQuantizeGPTQWorkspaceSize( + LIBINFINIOP.infiniopGetQuantizeGPTQWorkspaceSize( descriptor, ctypes.byref(workspace_size) ) ) - workspace = create_workspace(workspace_size.value, a.device) - + workspace = TestWorkspace(workspace_size.value, device) + print("work python", workspace_size.value) # Execute infiniop quantize_gptq operator - check_error( - lib.infiniopQuantizeGPTQ( - descriptor, - workspace.data_ptr() if workspace is not None else None, - workspace_size.value, - packed_weights_tensor.data, - s_tensor.data, - z_tensor.data, - a_tensor.data, - b_tensor.data, - None, - ) - ) - + # check_error( + # LIBINFINIOP.infiniopQuantizeGPTQ( + # descriptor, + # workspace.data(), + # workspace_size.value, + # packed_weights.data(), + # s.data(), + # z.data(), + # a.data(), + # b.data(), + # None, + # ) + # ) + def check_tensor_valid(name, t): + tt = t.torch_tensor() + ptr = tt.data_ptr() + print(f"[{name}] shape: {tt.shape}, dtype: {tt.dtype}, device: {tt.device}, data_ptr: {hex(ptr)}") + + if ptr % 16 != 0: + print(f"⚠️ Warning: {name} is NOT aligned to 16 bytes, kernel may crash!") + + check_tensor_valid("a", a) + check_tensor_valid("c", c) + check_tensor_valid("packed_weights", packed_weights) + check_tensor_valid("s", s) + check_tensor_valid("z", z) + + print(ans.torch_tensor()) + # ad = TestTensor((M, N), None, dtype, device, mode="zeros") + # print(ad.torch_tensor()) def lib_quantize_gptq(): check_error( - lib.infiniopQuantizeLinearGPTQ( + LIBINFINIOP.infiniopQuantizeLinearGPTQ( descriptor, - workspace.data_ptr() if workspace is not None else None, + workspace.data(), workspace_size.value, - c_tensor.data, - a_tensor.data, - packed_weights_tensor.data, - s_tensor.data, - z_tensor.data, + c.data(), + a.data(), + packed_weights.data(), + s.data(), + z.data(), None, ) ) lib_quantize_gptq() - + ad = TestTensor((M, N), None, dtype, device, mode="zeros") + print(ad.torch_tensor()) + print(ans.torch_tensor()) + print(c.torch_tensor()) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) - # tmpa = ans.flatten() - # tmpc = c.flatten() + # tmpa = ans.torch_tensor().flatten() + # tmpc = c.actual_tensor().flatten() # for i in range(tmpa.shape[0]): # if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]): # print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i])) # break - if is_weight_transposed: - c = c.t() + if DEBUG: - debug(c, ans, atol=atol, rtol=rtol) - assert torch.allclose(c, ans, atol=atol, rtol=rtol) + debug(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol) + assert torch.allclose(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol) # Profiling workflow if PROFILE: # fmt: off - profile_operation("PyTorch", lambda: quantize_gptq(a, b, is_weight_transposed), torch_device, NUM_PRERUN, NUM_ITERATIONS) - profile_operation(" lib", lambda: lib_quantize_gptq(), torch_device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation("PyTorch", lambda: torch_quantize_gptq(), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_quantize_gptq(), device, NUM_PRERUN, NUM_ITERATIONS) # fmt: on - check_error(lib.infiniopDestroyQuantizeGPTQDescriptor(descriptor)) + check_error(LIBINFINIOP.infiniopDestroyQuantizeGPTQDescriptor(descriptor)) # ============================================================================== @@ -683,55 +704,6 @@ def lib_quantize_gptq(): # ============================================================================== if __name__ == "__main__": args = get_args() - lib = open_lib() - - lib.infiniopCreateQuantizeGPTQDescriptor.restype = c_int32 - lib.infiniopCreateQuantizeGPTQDescriptor.argtypes = [ - infiniopHandle_t, - POINTER(infiniopQuantizeGPTQDescriptor_t), - infiniopTensorDescriptor_t, - infiniopTensorDescriptor_t, - infiniopTensorDescriptor_t, - infiniopTensorDescriptor_t, - infiniopTensorDescriptor_t, - ] - - lib.infiniopGetQuantizeGPTQWorkspaceSize.restype = c_int32 - lib.infiniopGetQuantizeGPTQWorkspaceSize.argtypes = [ - infiniopQuantizeGPTQDescriptor_t, - POINTER(c_size_t), - ] - - lib.infiniopQuantizeGPTQ.restype = c_int32 - lib.infiniopQuantizeGPTQ.argtypes = [ - infiniopQuantizeGPTQDescriptor_t, - c_void_p, - c_uint64, - c_void_p, - c_void_p, - c_void_p, - c_void_p, - c_void_p, - c_void_p, - ] - - lib.infiniopQuantizeLinearGPTQ.restype = c_int32 - lib.infiniopQuantizeLinearGPTQ.argtypes = [ - infiniopQuantizeGPTQDescriptor_t, - c_void_p, - c_uint64, - c_void_p, - c_void_p, - c_void_p, - c_void_p, - c_void_p, - c_void_p, - ] - - lib.infiniopDestroyQuantizeGPTQDescriptor.restype = c_int32 - lib.infiniopDestroyQuantizeGPTQDescriptor.argtypes = [ - infiniopQuantizeGPTQDescriptor_t, - ] # Configure testing options DEBUG = args.debug @@ -741,6 +713,6 @@ def lib_quantize_gptq(): # Execute tests for device in get_test_devices(args): - test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES) + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) print("\033[92mTest passed!\033[0m") From 1ed1a259d2856b9091fff4b8a252108f27e9c53a Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Fri, 8 Aug 2025 12:32:55 +0800 Subject: [PATCH 7/8] issue/170: success register --- .../quantize_gptq/cpu/quantize_gptq_cpu.cc | 4 - .../nvidia/quantize_gptq_nvidia.cu | 10 +- src/infiniop/ops/quantize_gptq/operator.cc | 5 + test/infiniop/libinfiniop/op_register.py | 52 +++++++ test/infiniop/quantize_gptq.py | 140 +++++++++--------- 5 files changed, 131 insertions(+), 80 deletions(-) diff --git a/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc b/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc index 01fe27fa3..690418350 100644 --- a/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc +++ b/src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc @@ -37,7 +37,6 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle_, Descriptor **desc_pt QuantizeGptqInfo info = result.take(); size_t min_workspace_size = (info.k * info.k + info.n * info.block_size) * sizeof(float) + (2 * info.n * info.k) * infiniSizeOf(info.atype); - std::cout << "kernel workspace:" << min_workspace_size << std::endl; *desc_ptr = new Descriptor(info, nullptr, min_workspace_size, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } @@ -568,7 +567,6 @@ void quantWeights(void *workspace, int32_t *packed_weights, void caculate(void *workspace, fp16_t *C, const fp16_t *A, int32_t *packed_weights, fp16_t *b_scale, fp16_t *zero, int M, int K, int N, int group_size) { - MatmulPackedWeight(C, A, packed_weights, b_scale, zero, M, K, N, group_size); } @@ -624,10 +622,8 @@ infiniStatus_t Descriptor::calculate( int group_size = int(_info.group_size); bool is_weight_transposed = _info.is_weight_transposed; if (_info.atype == INFINI_DTYPE_F16 && !is_weight_transposed) { - std::cout << "cpu start" << std::endl; caculate(workspace, (fp16_t *)c, (fp16_t *)a, (int32_t *)packed_weights, (fp16_t *)b_scale, (fp16_t *)zero, m, k, n, group_size); - std::cout << "cpu end" << std::endl; } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; diff --git a/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu b/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu index 4a5069506..93a5fd4cd 100644 --- a/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu +++ b/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu @@ -1,12 +1,7 @@ #include "../../../devices/nvidia/nvidia_common.cuh" #include "gptq_marlin.cuh" #include "quantize_gptq_nvidia.cuh" -#include -#ifdef NDEBUG -#define SAFE_ASSERT(x) ((void)(x)) -#else -#define SAFE_ASSERT(x) assert(x) -#endif + namespace op::quantize_gptq::nvidia { struct Descriptor::Opaque { @@ -30,7 +25,6 @@ infiniStatus_t Descriptor::create( QuantizeGptqInfo info = result.take(); int max_par = gptq_marlin::max_par; size_t min_workspace_size = info.n / gptq_marlin::min_thread_n * max_par * sizeof(int) + info.m * info.k * infiniSizeOf(info.atype); - std::cout << "kernel workspace:" << min_workspace_size << std::endl; *desc_ptr = new Descriptor(info, new Opaque{handle->internal()}, min_workspace_size, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } @@ -74,13 +68,11 @@ infiniStatus_t Descriptor::calculate( bool is_weight_transposed = _info.is_weight_transposed; if (_info.atype == INFINI_DTYPE_F16 && is_weight_transposed) { - std::cout << "gpu start: " << "group_size: " << group_size << "num_groups: " << num_groups << std::endl; gptq_marlin::gptq_marlin_mm_fp16(c, a, packed_weights, b_scale, m, n, k, workspace, bits, num_groups, group_size, this->device_id, (cudaStream_t)stream); - std::cout << "gpu end" << std::endl; } else if (_info.atype == INFINI_DTYPE_BF16 && is_weight_transposed) { gptq_marlin::gptq_marlin_mm_bf16(c, a, packed_weights, b_scale, diff --git a/src/infiniop/ops/quantize_gptq/operator.cc b/src/infiniop/ops/quantize_gptq/operator.cc index a978498c9..a6137d879 100644 --- a/src/infiniop/ops/quantize_gptq/operator.cc +++ b/src/infiniop/ops/quantize_gptq/operator.cc @@ -36,6 +36,7 @@ __C infiniStatus_t infiniopCreateQuantizeGPTQDescriptor(infiniopHandle_t handle, default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } +#undef CREATE } __C infiniStatus_t infiniopGetQuantizeGPTQWorkspaceSize(infiniopQuantizeGPTQDescriptor_t desc, size_t *size) { @@ -53,6 +54,7 @@ __C infiniStatus_t infiniopGetQuantizeGPTQWorkspaceSize(infiniopQuantizeGPTQDesc default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } +#undef GET } __C infiniStatus_t infiniopQuantizeGPTQ(infiniopQuantizeGPTQDescriptor_t desc, @@ -79,6 +81,7 @@ __C infiniStatus_t infiniopQuantizeGPTQ(infiniopQuantizeGPTQDescriptor_t desc, default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } +#undef QUANT } __C infiniStatus_t infiniopQuantizeLinearGPTQ(infiniopQuantizeGPTQDescriptor_t desc, @@ -105,6 +108,7 @@ __C infiniStatus_t infiniopQuantizeLinearGPTQ(infiniopQuantizeGPTQDescriptor_t d default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } +#undef CACULATE } __C infiniStatus_t infiniopDestroyQuantizeGPTQDescriptor(infiniopQuantizeGPTQDescriptor_t desc) { @@ -123,4 +127,5 @@ __C infiniStatus_t infiniopDestroyQuantizeGPTQDescriptor(infiniopQuantizeGPTQDes default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } +#undef DESTROY } diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index e92e77105..e232eafcf 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -237,6 +237,58 @@ def mul_(lib): ] + +@OpRegister.operator +def quantize_gptq_(lib): + lib.infiniopCreateQuantizeGPTQDescriptor.restype = c_int32 + lib.infiniopCreateQuantizeGPTQDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetQuantizeGPTQWorkspaceSize.restype = c_int32 + lib.infiniopGetQuantizeGPTQWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopQuantizeLinearGPTQ.restype = c_int32 + lib.infiniopQuantizeLinearGPTQ.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopQuantizeGPTQ.restype = c_int32 + lib.infiniopQuantizeGPTQ.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyQuantizeGPTQDescriptor.restype = c_int32 + lib.infiniopDestroyQuantizeGPTQDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def random_sample_(lib): lib.infiniopCreateRandomSampleDescriptor.restype = c_int32 diff --git a/test/infiniop/quantize_gptq.py b/test/infiniop/quantize_gptq.py index 6c25cf08d..70b39ec8f 100644 --- a/test/infiniop/quantize_gptq.py +++ b/test/infiniop/quantize_gptq.py @@ -27,24 +27,23 @@ # ============================================================================== # These are not meant to be imported from other modules -_TEST_CASES = [(1, 128, 128)] -# _TEST_CASES = [] - -# MODELS = { -# "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], -# # "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], -# # "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], -# # "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], -# } - -# # Loop through models and layers to generate the new _TEST_CASES -# for _, layers in MODELS.items(): -# for layer in layers: -# for batch in [1, 16]: -# _TEST_CASES.append(((batch, layer[0], layer[1]))) +_TEST_CASES = [] + +MODELS = { + "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], + # "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], + # "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], + # "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], +} + +# Loop through models and layers to generate the new _TEST_CASES +for _, layers in MODELS.items(): + for layer in layers: + for batch in [1, 16]: + _TEST_CASES.append(((batch, layer[0], layer[1]))) # Data types used for testing -#_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] +# _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] _TENSOR_DTYPES = [InfiniDtype.F16] # Tolerance map for different data types _TOLERANCE_MAP = { @@ -58,6 +57,7 @@ NUM_PRERUN = 10 NUM_ITERATIONS = 1000 + def quantize(x, scale, zero, minq, maxq): if scale.shape[1] == 1: q = torch.clamp(torch.round(x / scale) + zero, minq, maxq) @@ -508,12 +508,14 @@ def reshape(w): # PyTorch implementation for matrix multiplication -def quantize_gptq(ans, a, b, is_weight_transposed): # 昇腾芯片的CPU不支持转置计算 +def quantize_gptq(a, b, is_weight_transposed): # 昇腾芯片的CPU不支持转置计算 if is_weight_transposed: ans = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(b.dtype) else: ans = torch.matmul(b.to(torch.float32), a.to(torch.float32)).to(b.dtype) - + return ans + + # The argument list should be (lib, handle, device, , dtype) # The should keep the same order as the one specified in _TEST_CASES def test( @@ -526,15 +528,15 @@ def test( sync=None, ): print( - f"Testing QuantizeGPTQ on {InfiniDeviceNames[device]}" f" M:{M}, K:{K}, N:{N}, dtype:{InfiniDtypeNames[dtype]}" + f"Testing QuantizeGPTQ on {InfiniDeviceNames[device]}" + f" M:{M}, K:{K}, N:{N}, dtype:{InfiniDtypeNames[dtype]}" ) # Initialize tensors a = TestTensor((K, M), None, dtype, device) b = TestTensor((N, K), None, dtype, device) c = TestTensor((N, M), None, dtype, device, mode="zeros") - ans = TestTensor((N, M), None, dtype, device, mode="zeros") - + is_weight_transposed = False sign_ed = False sym = False @@ -548,8 +550,10 @@ def test( else: num_groups = K // group_size - packed_weights = TestTensor((N, K // 8), None, InfiniDtype.I32, device, mode="zeros") - print(packed_weights.torch_tensor().dtype) + packed_weights = TestTensor( + (N, K // 8), None, InfiniDtype.I32, device, mode="zeros" + ) + s = TestTensor((N, num_groups), None, dtype, device, mode="zeros") z = TestTensor((N, num_groups), None, dtype, device, mode="zeros") @@ -559,40 +563,62 @@ def test( if sign_ed: # 有符号量化,范围是[-8,7] maxq = 2 ** (bits - 1) - 1 minq = -(2 ** (bits - 1)) - + if device == InfiniDeviceEnum.NVIDIA: b_data, packed_weights_data, s_data = gen_quant4(K, N, groupsize=group_size) a = TestTensor((M, K), None, dtype, device) b = TestTensor((K, N), None, dtype, device, mode="manual", set_tensor=b_data) c = TestTensor((M, N), None, dtype, device, mode="zeros") - ans = TestTensor((M, N), None, dtype, device, mode="zeros") - packed_weights = TestTensor((K // 8, N), None, InfiniDtype.I32, device, mode="manual", set_tensor=packed_weights_data) - s = TestTensor((num_groups, N), None, dtype, device, mode="manual", set_tensor=s_data) + + packed_weights = TestTensor( + (K // 8, N), + None, + InfiniDtype.I32, + device, + mode="manual", + set_tensor=packed_weights_data, + ) + s = TestTensor( + (num_groups, N), None, dtype, device, mode="manual", set_tensor=s_data + ) z = TestTensor((num_groups, N), None, dtype, device, mode="zeros") - + if device == InfiniDeviceEnum.CPU: b_ref_data, s_data, z_data = get_scale_zero( - b.torch_tensor(), a.torch_tensor().t(), c.torch_tensor(), group_size, bits, sym, sign_ed=sign_ed + b.torch_tensor(), + a.torch_tensor().t(), + c.torch_tensor(), + group_size, + bits, + sym, + sign_ed=sign_ed, ) # 无符号量化 packed_weights_data = pack(b_ref_data, s_data, z_data, minq, maxq) - packed_weights = TestTensor((N, K // 8), None, InfiniDtype.I32, device, mode="manual", set_tensor=packed_weights_data) - s = TestTensor((N, num_groups), None, dtype, device, mode="manual", set_tensor=s_data) - z = TestTensor((N, num_groups), None, dtype, device, mode="manual", set_tensor=z_data) - - def torch_quantize_gptq(): - quantize_gptq( - ans.torch_tensor(), - a.torch_tensor(), - b.torch_tensor(), - is_weight_transposed, + packed_weights = TestTensor( + (N, K // 8), + None, + InfiniDtype.I32, + device, + mode="manual", + set_tensor=packed_weights_data, + ) + s = TestTensor( + (N, num_groups), None, dtype, device, mode="manual", set_tensor=s_data + ) + z = TestTensor( + (N, num_groups), None, dtype, device, mode="manual", set_tensor=z_data ) - torch_quantize_gptq() + ans = quantize_gptq( + a.torch_tensor(), + b.torch_tensor(), + is_weight_transposed, + ) if sync is not None: sync() - + descriptor = infiniopOperatorDescriptor_t() check_error( LIBINFINIOP.infiniopCreateQuantizeGPTQDescriptor( @@ -625,7 +651,7 @@ def torch_quantize_gptq(): ) ) workspace = TestWorkspace(workspace_size.value, device) - print("work python", workspace_size.value) + # Execute infiniop quantize_gptq operator # check_error( # LIBINFINIOP.infiniopQuantizeGPTQ( @@ -640,23 +666,7 @@ def torch_quantize_gptq(): # None, # ) # ) - def check_tensor_valid(name, t): - tt = t.torch_tensor() - ptr = tt.data_ptr() - print(f"[{name}] shape: {tt.shape}, dtype: {tt.dtype}, device: {tt.device}, data_ptr: {hex(ptr)}") - - if ptr % 16 != 0: - print(f"⚠️ Warning: {name} is NOT aligned to 16 bytes, kernel may crash!") - - check_tensor_valid("a", a) - check_tensor_valid("c", c) - check_tensor_valid("packed_weights", packed_weights) - check_tensor_valid("s", s) - check_tensor_valid("z", z) - - print(ans.torch_tensor()) - # ad = TestTensor((M, N), None, dtype, device, mode="zeros") - # print(ad.torch_tensor()) + def lib_quantize_gptq(): check_error( LIBINFINIOP.infiniopQuantizeLinearGPTQ( @@ -673,10 +683,7 @@ def lib_quantize_gptq(): ) lib_quantize_gptq() - ad = TestTensor((M, N), None, dtype, device, mode="zeros") - print(ad.torch_tensor()) - print(ans.torch_tensor()) - print(c.torch_tensor()) + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) # tmpa = ans.torch_tensor().flatten() # tmpc = c.actual_tensor().flatten() @@ -685,15 +692,14 @@ def lib_quantize_gptq(): # print(tmpa[i], tmpc[i], abs(tmpa[i] - tmpc[i]), rtol * abs(tmpa[i])) # break - if DEBUG: - debug(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol) - assert torch.allclose(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol) + debug(c.actual_tensor(), ans, atol=atol, rtol=rtol) + assert torch.allclose(c.actual_tensor(), ans, atol=atol, rtol=rtol) # Profiling workflow if PROFILE: # fmt: off - profile_operation("PyTorch", lambda: torch_quantize_gptq(), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation("PyTorch", lambda: quantize_gptq(a.torch_tensor(), b.torch_tensor(), is_weight_transposed), device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_quantize_gptq(), device, NUM_PRERUN, NUM_ITERATIONS) # fmt: on check_error(LIBINFINIOP.infiniopDestroyQuantizeGPTQDescriptor(descriptor)) From 757bbeb0b63383bce7b4535fd9e1bd969b612aef Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Tue, 12 Aug 2025 10:23:26 +0800 Subject: [PATCH 8/8] issue/170: success marlin, workspace=0 --- src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu | 4 ++++ test/infiniop/libinfiniop/op_register.py | 2 +- test/infiniop/quantize_gptq.py | 4 ++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu b/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu index 93a5fd4cd..a2424bbf0 100644 --- a/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu +++ b/src/infiniop/ops/quantize_gptq/nvidia/quantize_gptq_nvidia.cu @@ -66,6 +66,10 @@ infiniStatus_t Descriptor::calculate( int group_size = int(_info.group_size); int num_groups = int(_info.num_groups); bool is_weight_transposed = _info.is_weight_transposed; + cudaError_t err = cudaMemset(workspace, 0, workspace_size); + if (err != cudaSuccess) { + printf("cudaMemset failed: %s\n", cudaGetErrorString(err)); + } if (_info.atype == INFINI_DTYPE_F16 && is_weight_transposed) { gptq_marlin::gptq_marlin_mm_fp16(c, a, packed_weights, b_scale, diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index e232eafcf..414a8f6f8 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -237,7 +237,6 @@ def mul_(lib): ] - @OpRegister.operator def quantize_gptq_(lib): lib.infiniopCreateQuantizeGPTQDescriptor.restype = c_int32 @@ -506,6 +505,7 @@ def swiglu_(lib): infiniopOperatorDescriptor_t, ] + @OpRegister.operator def conv_(lib): lib.infiniopCreateConvDescriptor.restype = c_int32 diff --git a/test/infiniop/quantize_gptq.py b/test/infiniop/quantize_gptq.py index 70b39ec8f..54d52cc86 100644 --- a/test/infiniop/quantize_gptq.py +++ b/test/infiniop/quantize_gptq.py @@ -526,7 +526,7 @@ def test( N, dtype=InfiniDtype.F16, sync=None, -): +): # device=nvidia的时候精度不足 print( f"Testing QuantizeGPTQ on {InfiniDeviceNames[device]}" f" M:{M}, K:{K}, N:{N}, dtype:{InfiniDtypeNames[dtype]}" @@ -685,7 +685,7 @@ def lib_quantize_gptq(): lib_quantize_gptq() atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) - # tmpa = ans.torch_tensor().flatten() + # tmpa = ans.flatten() # tmpc = c.actual_tensor().flatten() # for i in range(tmpa.shape[0]): # if abs(tmpa[i] - tmpc[i]) > atol + rtol * abs(tmpa[i]):