
Commit 50170c6

[Optimizer] DQ + MatMul to MatMulNBits support: kernel changes (microsoft#21342)
### Description

This is a partial change ported from fajin/qdqmatmulnbitstoolchain; that branch has issues resolving the web CI.

MatMulNBits is a heavily optimized matmul operation. Currently a MatMul can be converted to MatMulNBits to speed up model inference. However, MatMulNBits is an ORT-only op. To keep the graph compatible with ONNX ops while still utilizing MatMulNBits, we introduce Q/DQ support for MatMulNBits.

To convert the MatMul ops in a model to MatMulNBits:

1. Use matmul_4bits_quantizer.py to convert MatMul to DQ + MatMul using QDQ mode.
2. In the ORT session, DQ + MatMul is fused into MatMulNBits (a usage sketch follows below).

#### Note

MatMulNBits assumes the B weight is uint4. When no zero point is provided, the zero point defaults to 8, which differs from DQ: DQ defaults the zero point to 0 when none is provided, and DQ supports int4. Therefore some conversions are introduced during the DQ + MatMul --> MatMulNBits step.

#### Perf

Using the QDQ format increases model initialization time and memory consumption. With the current implementation, model init time increased from ~4s to ~9s, and memory consumption increased from ~2.8GB to ~4.8GB. The memory increase is due to:

1. In the optimizer, after transposing the B weight, an in-memory tensor proto is created using protobuf's arena.
2. In the finalize step, when saving initializers and prepacking, the ORT arena is used to create buffers for initializers.

The memory allocated by the arenas cannot be fully deallocated. If ORT arena memory allocation is disabled, the memory consumption of both the QDQ format and the original format is ~2.2GB. The time increase is mainly due to multiple memory copies, but can be further optimized.

### Motivation and Context

Please see the description for details.
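The sketch below illustrates the two-step workflow above: offline quantization to DQ + MatMul in QDQ mode, then session creation where the fusion happens. The exact `MatMul4BitsQuantizer` arguments (in particular `quant_format`) belong to the follow-up toolchain change rather than this kernel PR, so treat them as assumptions that may differ between releases.

```python
# Hedged sketch of the DQ + MatMul -> MatMulNBits workflow described above.
import onnx
import onnxruntime as ort
from onnxruntime.quantization import QuantFormat
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

# Step 1 (offline): convert MatMul -> DequantizeLinear + MatMul using QDQ mode.
model = onnx.load("model.onnx")
quantizer = MatMul4BitsQuantizer(
    model,
    block_size=32,                 # typical block size; not mandated by this PR
    is_symmetric=True,
    quant_format=QuantFormat.QDQ,  # assumed switch selecting DQ + MatMul output
)
quantizer.process()
quantizer.model.save_model_to_file("model_qdq_int4.onnx")

# Step 2 (runtime): session creation fuses DQ + MatMul into MatMulNBits.
session = ort.InferenceSession("model_qdq_int4.onnx")
```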
1 parent c03e6ff commit 50170c6

5 files changed: +197 additions, −145 deletions


onnxruntime/core/mlas/inc/mlas_q4.h

Lines changed: 18 additions & 8 deletions
@@ -360,12 +360,12 @@ MlasDequantizeBlockwise(
     );
 
 /**
- * @brief Blockwise 2 bits or 4 bits quantization. After quantization, the weights and zero points
- * are packed row-wise. In terms of the qbits type, dst and src have the same shape, and
- * scales and zero_points have the same shape.
- * columns must be multiple of 8 / qbits.
+ * @brief Blockwise 4 bits quantization. After quantization, the weights and zero points
+ * are packed row-wise. If zero_points is null, quantized type is int4 with default
+ * zero point 0, to align with DQ schema. Otherwise, quantized type is uint4.
+ * In int4/uint4, dst have the same shape as src, and zero_points have the same shape as scales.
  * @tparam Tin
- * @tparam qbits number of bits used for quantization, 2 or 4
+ * @tparam qbits number of bits used for quantization, only 4 is supported
  * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns]
  * @param scales points to the scales matrix, row major
  * @param zero_points points to the zero_points matrix, row major
@@ -376,9 +376,10 @@ MlasDequantizeBlockwise(
  * @param columns
  * @param quant_block_size number of elements in a quantize block
  * @param thread_pool
+ * @return the quantized type is signed.
  */
 template <typename Tin, int qbits>
-void
+bool
 MlasQDQQuantizeBlockwise(
     const Tin* src,
     Tin* scales,
@@ -395,8 +396,17 @@ MlasQDQQuantizeBlockwise(
  * @brief Transpose blockwise quantized tensors. The src tensors are row major. src weights and zero
  * points are packed row-wise. The dst tensors are column major. dst weights and zero points
  * are packed column-wise.
+ * dst_weights and dst_zero_points are in uint4.
+ * If src_weights is int4 and has src_zero_points, src_weights and src_zero_points are
+ * converted to uint4 by adding 8.
+ * If src_weights is int4 and no src_zero_points, src_weights is converted to uint4 by adding 8.
+ * src_zero_points is 0 and dst_zero_points is 8.
+ * If src_weights is uint4 and has src_zero_points, just transpose.
+ * If src_weights is uint4 and no src_zero_points, caller must allocate dst_zero_points with
+ * 0 values. Otherwise exception is thrown.
  * @tparam Tin
- * @tparam qbits number of bits used for quantization, 2 or 4
+ * @tparam qbits number of bits used for quantization, only 4 is supported
+ * @tparam signed_quant true when quantized type is signed, false when quantized type is unsigned
  * @param src_weights points to the quantized matrix, row major, shape [rows, columns] in qbits type.
  *                    In uint8_t type, shape is [rows, columns * qbits / 8].
  * @param src_scales points to the scales matrix, row major
@@ -410,7 +420,7 @@ MlasQDQQuantizeBlockwise(
  * @param quant_block_size number of elements in a quantize block
  * @param thread_pool
  */
-template <typename Tin, int qbits>
+template <typename Tin, int qbits, bool signed_quant>
 void
 MlasQDQTransposeBlockwiseQuantized(
     const uint8_t* src_weights,
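As a side note on the conversions documented above: adding 8 to an int4 quantized value and to its zero point yields the uint4 representation that MatMulNBits expects, without changing the dequantized result. A minimal plain-Python check of that identity (illustrative only, not ORT code):

```python
# int4 -> uint4 conversion: shifting the quantized value and the zero point
# by 8 leaves the dequantized value (q - zero_point) * scale unchanged.
def dequant(q, scale, zero_point):
    return (q - zero_point) * scale

scale = 0.05
for q_int4 in range(-8, 8):                  # int4 range; DQ default zero point is 0
    v_dq = dequant(q_int4, scale, 0)
    v_nbits = dequant(q_int4 + 8, scale, 8)  # uint4 with MatMulNBits default zero point 8
    assert abs(v_dq - v_nbits) < 1e-12
```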
