Skip to content

Commit 0b15200

Browse files
authored
DequantizeLinear should support non-zero zero_point when input type is int32 (microsoft#25646)
### Description This PR makes DequantizeLinear support non-zero zero_point when input data type is int32. ### Motivation and Context For WebNN use case, we have some scenarios that input data type is int32 and the zero_point is not zero for DequantizeLinear.
1 parent d4e31dc commit 0b15200

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

onnxruntime/core/providers/cpu/quantization/quantize_linear.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,14 +520,12 @@ Status DequantizeLinear<T>::Compute(OpKernelContext* ctx) const {
520520
const T* zero_point = x_zero_point ? x_zero_point->Data<T>() : nullptr;
521521

522522
#if !defined(DISABLE_FLOAT8_TYPES)
523-
if constexpr (boost::mp11::mp_contains<boost::mp11::mp_append<element_type_lists::AllFloat8,
524-
TypeList<int32_t>>,
525-
T>::value) {
523+
if constexpr (boost::mp11::mp_contains<element_type_lists::AllFloat8, T>::value) {
526524
ORT_ENFORCE(zero_point == nullptr ||
527525
std::all_of(zero_point,
528526
zero_point + x_zero_point->Shape().Size(),
529527
[](T zp) { return zp == T{0}; }),
530-
"DequantizeLinear with type int32 or float8 should have no zero point or all zero points should be 0");
528+
"DequantizeLinear with type float8 should have no zero point or all zero points should be 0");
531529
}
532530
#endif
533531

onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ TEST(DequantizeLinearOpTest, Uint16) {
137137
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
138138
}
139139

140-
// scalar zero & scale with int8
140+
// scalar zero & scale with int32
141141
TEST(DequantizeLinearOpTest, Int32) {
142142
OpTester test("DequantizeLinear", 10);
143143
std::vector<int64_t> dims{4};
@@ -147,6 +147,17 @@ TEST(DequantizeLinearOpTest, Int32) {
147147
test.Run();
148148
}
149149

150+
// non-zero zero point with int32
151+
TEST(DequantizeLinearOpTest, Int32_Non_Zero_Zero_Point) {
152+
OpTester test("DequantizeLinear", 10);
153+
std::vector<int64_t> dims{4};
154+
test.AddInput<int32_t>("x", dims, {-30, -3, 100, 127});
155+
test.AddInput<float>("x_scale", {}, {2.0f}, true);
156+
test.AddInput<int32_t>("x_zero_point", {}, {1}, true);
157+
test.AddOutput<float>("y", dims, {-62.f, -8.f, 198.f, 252.f});
158+
test.Run();
159+
}
160+
150161
TEST(DequantizeLinearOpTest_BroadcastTensor, Int32) {
151162
OpTester test("DequantizeLinear", 13);
152163
test.AddInput<int32_t>("x", {4}, {-30, -3, 100, 127});

0 commit comments

Comments
 (0)