
Commit d5d3b28

Authored by carzh, HectorSVC, and github-actions[bot]
Enable 2bit CPU matmul fallback (microsoft#25582)
### Description

- Enable 2-bit MatMulNBits; it falls back to ComputeBUnpacked (dequantizes B to fp32).
- Adapt the quantize script to support 2 bits.
- Add 2-bit unit tests.
- [Blockwise quantization for 2 bits is already implemented](https://github.com/microsoft/onnxruntime/blob/b9575476e94daa9c6578aba92d8f04324dd15815/onnxruntime/core/mlas/lib/q4_dq.cpp#L407).

### Motivation and Context

- Working on enabling BitNet and other low-bit LLMs.

---------

Co-authored-by: Hector Li <hecli@microsoft.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 539d0ed commit d5d3b28

File tree

11 files changed: +542 −13 lines changed

onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

Lines changed: 15 additions & 3 deletions

```diff
@@ -112,8 +112,8 @@ class MatMulNBits final : public OpKernel {
     has_unquantized_zero_point_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8;
   }

-  ORT_ENFORCE(nbits_ == 4 || nbits_ == 8,
-              "Only 4b and 8b quantization is supported for MatMulNBits op, additional bits support is planned.");
+  ORT_ENFORCE(nbits_ == 2 || nbits_ == 4 || nbits_ == 8,
+              "Only 2b, 4b and 8b quantization is supported for MatMulNBits op, additional bits support is planned.");
   const Tensor* tensor_zero_point = nullptr;
   has_zp_input_ = info.TryGetConstantInput(InputIndex::zero_points, &tensor_zero_point);
 }
@@ -458,7 +458,19 @@ Status MatMulNBits<float>::ComputeBUnpacked(const Tensor* a,
   auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_, true);

   if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType<float>())) {
-    if (nbits_ == 4) {
+      // dequantize b, only 2b, 4b, and 8b quantization is supported for now
+    if (this->nbits_ == 2) {
+      MlasDequantizeBlockwise<float, 2>(
+          tmp_b_data_ptr.get(),                           // dequantized output
+          b_data,                                         // quantized input
+          scales_data,                                    // quantization scales
+          static_cast<const uint8_t*>(zero_points_data),  // quantization zero points
+          static_cast<int32_t>(block_size_),              // quantization block size
+          column_wise_quant_,                             // columnwise quantization or row-wise
+          static_cast<int32_t>(K_),                       // number of rows in quantized input
+          static_cast<int32_t>(N_),                       // number of columns in quantized input
+          thread_pool);
+    } else if (this->nbits_ == 4) {
       MlasDequantizeBlockwise<float, 4>(
           tmp_b_data_ptr.get(),  // dequantized output
           b_data,                // quantized input
```
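For intuition, here is a minimal NumPy sketch of what this fallback path does conceptually: the packed weights are unpacked into integer codes (four 2-bit values per byte), dequantized to fp32 with a per-block scale and zero point, and the result is fed into an ordinary fp32 matmul. The packing order and helper name below are illustrative assumptions, not the exact layout used by MlasDequantizeBlockwise<float, 2>.

```python
import numpy as np

def dequantize_2bit_block(packed: np.ndarray, scale: float, zero_point: int,
                          block_size: int) -> np.ndarray:
    """Illustrative 2-bit blockwise dequantization (assumed packing, not the
    exact MLAS layout): four 2-bit codes per byte, lowest bits first."""
    shifts = np.array([0, 2, 4, 6], dtype=np.uint8)
    codes = (packed[:, None] >> shifts) & 0x3            # shape: (num_bytes, 4)
    codes = codes.reshape(-1)[:block_size].astype(np.float32)
    # Dequantize: real_value = scale * (code - zero_point).
    return scale * (codes - zero_point)

# Example: one block of 16 weights quantized to 2 bits (codes in 0..3).
rng = np.random.default_rng(0)
codes = rng.integers(0, 4, size=16, dtype=np.uint8).reshape(-1, 4)
packed = (codes[:, 0] | (codes[:, 1] << 2) | (codes[:, 2] << 4) | (codes[:, 3] << 6)).astype(np.uint8)
recovered = dequantize_2bit_block(packed, scale=0.05, zero_point=2, block_size=16)
print(recovered)  # fp32 weights the fallback would use in a regular matmul
```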

onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h

Lines changed: 2 additions & 2 deletions

```diff
@@ -31,8 +31,8 @@ Status CheckInputs(const T* /*activation*/,
   // group_index : (K) or (k_blocks * block_size), or null
   // bias : (N), or null
   // Note that scales and zero_points can be 1D for backward compatibility.
-  if (bits != 4 && bits != 8) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bits should be 4 or 8, got ", bits);
+  if (bits != 2 && bits != 4 && bits != 8) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bits should be 2, 4 or 8, got ", bits);
   }

   if (block_size < 16 || (block_size & (block_size - 1)) != 0) {
```
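CheckInputs now accepts bits of 2, 4, or 8 and still requires block_size to be a power of two no smaller than 16. The following is a small Python sketch of the same constraints, useful for validating a configuration before quantizing; it is an illustrative helper, not part of the ONNX Runtime API.

```python
def validate_nbits_config(bits: int, block_size: int) -> None:
    """Mirror of the CheckInputs constraints: bits in {2, 4, 8} and
    block_size a power of two that is at least 16."""
    if bits not in (2, 4, 8):
        raise ValueError(f"bits should be 2, 4 or 8, got {bits}")
    if block_size < 16 or (block_size & (block_size - 1)) != 0:
        raise ValueError(f"block_size must be a power of two >= 16, got {block_size}")

validate_nbits_config(2, 32)    # passes
# validate_nbits_config(2, 24)  # would raise: 24 is not a power of two
```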

onnxruntime/core/mlas/inc/mlas_q4.h

Lines changed: 1 addition & 1 deletion

```diff
@@ -277,7 +277,7 @@ MlasBlockwiseQuantizedShape(
  *
  * If the qbits or block_size values are unsupported the output sizes will be zero.
  */
-template <int qbits>
+template<int qbits>
 void MLASCALL
 MlasBlockwiseQuantizedBufferSizes(
     int block_size,
```

onnxruntime/python/onnxruntime_pybind_quant.cc

Lines changed: 2 additions & 0 deletions

```diff
@@ -126,6 +126,8 @@ void QuantizeMatMulBnb4Blockwise(
 }

 void CreateQuantPybindModule(py::module& m) {
+  m.def("quantize_matmul_2bits", &QuantizeMatMulNBitsBlockwise<float, 2>);
+  m.def("quantize_matmul_2bits", &QuantizeMatMulNBitsBlockwise<MLFloat16, 2>);
   m.def("quantize_matmul_4bits", &QuantizeMatMulNBitsBlockwise<float, 4>);
   m.def("quantize_matmul_4bits", &QuantizeMatMulNBitsBlockwise<MLFloat16, 4>);
   m.def("quantize_matmul_8bits", &QuantizeMatMulNBitsBlockwise<float, 8>);
```

onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py

Lines changed: 17 additions & 5 deletions

```diff
@@ -16,7 +16,12 @@
 import onnx
 from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto

-from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_matmul_8bits, quantize_qdq_matmul_4bits
+from onnxruntime.capi._pybind_state import (
+    quantize_matmul_2bits,
+    quantize_matmul_4bits,
+    quantize_matmul_8bits,
+    quantize_qdq_matmul_4bits,
+)

 from .calibrate import CalibrationDataReader
 from .neural_compressor import gptq_quantize, rtn_quantize
@@ -818,7 +823,11 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n
         packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
         zero_point = np.zeros(cols * ((k_blocks + kpack - 1) // kpack), dtype="uint8")
         scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
-        if qbits == 8:
+        if qbits == 2:
+            quantize_matmul_2bits(
+                packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
+            )
+        elif qbits == 8:
             quantize_matmul_8bits(
                 packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
             )
@@ -1206,7 +1215,7 @@ class MatMulNBitsQuantizer:
         MatMul MatMulNBits DeQuantizeLinear -> MatMul
         Gather GatherBlockQuantized Gather, Gather, Gather (optional) -> DequantizeLinear

-    Perform 4/8 bits quantization of constant weights for target nodes.
+    Perform 2/4/8 bits quantization of constant weights for target nodes.
     If algo_config.quant_format is QOperator:
     - nodes are replaced by the corresponding QOperator nodes.
     - quantized weights are stored in the contrib ops.
@@ -1224,6 +1233,7 @@ class MatMulNBitsQuantizer:
     def __init__(
         self,
         model: ModelProto | str,
+        bits: int = 4,  # default to 4bit
         block_size: int = 128,
         is_symmetric: bool = False,
         accuracy_level: int | None = None,
@@ -1239,6 +1249,7 @@ def __init__(
             nodes_to_exclude = []
         self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
         self.model_path = model if isinstance(model, str) else None
+        self.bits = bits
         self.block_size = block_size
         self.is_symmetric = is_symmetric
         self.accuracy_level = accuracy_level
@@ -1254,13 +1265,13 @@ def __init__(
                 quant_format=quant_format,
                 op_types_to_quantize=op_types_to_quantize,
                 quant_axes=quant_axes,
-                bits=4,  # default to 4 bits
+                bits=bits,
                 channel_wised_quantize=channel_wised_quantize,
             )

         self.algo_config = algo_config
         if hasattr(self.algo_config, "bits"):
-            assert self.algo_config.bits in [4, 8], "Only support 4 or 8 bits quantization"
+            assert self.algo_config.bits in [2, 4, 8], "Only support 2, 4 or 8 bits quantization"

         if algo_config.algorithm == "HQQ":
             self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
@@ -1609,6 +1620,7 @@ def parse_args():

     quant = MatMulNBitsQuantizer(
         model=model,
+        bits=args.bits,
         accuracy_level=args.accuracy_level,
         nodes_to_exclude=args.nodes_to_exclude,
         nodes_to_include=args.nodes_to_include,
```
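With the new bits argument, 2-bit weight-only quantization can be requested from the same entry point as 4- and 8-bit. A short usage sketch, assuming the usual process()/save flow of MatMulNBitsQuantizer and a placeholder model path:

```python
import onnx
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer

model = onnx.load("model_fp32.onnx")   # placeholder path to an fp32 ONNX model

quant = MatMulNBitsQuantizer(
    model=model,
    bits=2,              # new: request 2-bit weight-only quantization
    block_size=32,
    is_symmetric=True,
)
quant.process()
quant.model.save_model_to_file("model_int2.onnx", use_external_data_format=True)
```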

onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -670,6 +670,8 @@ def get_args():
     blockwise_group = parser.add_argument_group("blockwise (4-bit quantization)")

+    parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight")
+
     blockwise_group.add_argument(
         "--block_size",
         required=False,
@@ -988,6 +990,7 @@ def main():
         model = onnx.load_model(fp_path, load_external_data=True)
         quant = MatMulNBitsQuantizer(
             model=model,
+            bits=args.bits,
             block_size=args.block_size,
             is_symmetric=True,
             accuracy_level=args.int4_accuracy_level,
```

onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -168,6 +168,7 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
         assert self.precision == Precision.INT4
         quant = MatMulNBitsQuantizer(
             model=optimizer.model,
+            bits=4,
             block_size=self.block_size,
             is_symmetric=True,
             accuracy_level=self.accuracy_level,
```
