Skip to content

Commit 80b6e93

Browse files
authored
[WebNN] Fallback int64 indices to int32 (microsoft#26308)
ONNX's ScatterND and ScatterElements limit their indices input to int64, but some WebNN backends only support int32 indices. As a workaround for such backends, we can insert a Cast operation to convert the data type.
1 parent df418e2 commit 80b6e93

File tree

4 files changed

+54
-7
lines changed

4 files changed

+54
-7
lines changed

onnxruntime/core/providers/webnn/builders/helper.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,5 +439,16 @@ uint16_t PackFloat32ToUint16AsFloat16(float value) {
439439
return sign_float16 | (exponent_float16 << 10) | mantissa_float16;
440440
}
441441

442+
// Check if it can fallback to int32 if the input of WebNN op doesn't support int64.
443+
bool CanFallbackInt64ToInt32(const emscripten::val& wnn_limits,
444+
const std::string& webnn_op_type,
445+
const std::string& input_name) {
446+
emscripten::val supported_data_types = wnn_limits[webnn_op_type][input_name]["dataTypes"];
447+
448+
return !supported_data_types.isUndefined() &&
449+
!supported_data_types.call<emscripten::val>("includes", emscripten::val("int64")).as<bool>() &&
450+
supported_data_types.call<emscripten::val>("includes", emscripten::val("int32")).as<bool>();
451+
}
452+
442453
} // namespace webnn
443454
} // namespace onnxruntime

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,5 +297,9 @@ bool IsMLTensorSupported();
297297
uint8_t PackInt8ToUint8DoubledNibbles(int8_t value, const int32_t& data_type);
298298
uint16_t PackFloat32ToUint16AsFloat16(float value);
299299

300+
bool CanFallbackInt64ToInt32(const emscripten::val& wnn_limits,
301+
const std::string& webnn_op_type,
302+
const std::string& input_name);
303+
300304
} // namespace webnn
301305
} // namespace onnxruntime

onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class ScatterElementsOpBuilder : public BaseOpBuilder {
2424
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
2525
bool HasSupportedInputsImpl(const GraphViewer& graph_viewer, const Node& node,
2626
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
27+
mutable bool can_fallback_int64_to_int32_ = false;
2728
};
2829

2930
// Add operator related.
@@ -35,14 +36,21 @@ Status ScatterElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build
3536
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
3637
emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name());
3738
emscripten::val options = emscripten::val::object();
38-
options.set("label", node.Name());
39+
40+
// ONNX specifies that indices must use int64, but some WebNN backends only support int32.
41+
// As a workaround for such backends, we can insert a Cast operation to convert the data type.
42+
if (can_fallback_int64_to_int32_) {
43+
options.set("label", node.Name() + "_cast_indices_to_int32");
44+
indices = model_builder.GetBuilder().call<emscripten::val>("cast", indices, emscripten::val("int32"), options);
45+
}
3946

4047
std::vector<int64_t> input_shape;
4148
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
4249
const size_t rank = input_shape.size();
4350
NodeAttrHelper helper(node);
4451
const uint32_t axis = static_cast<uint32_t>(HandleNegativeAxis(helper.Get("axis", 0), rank));
4552
options.set("axis", axis);
53+
options.set("label", node.Name());
4654

4755
emscripten::val output =
4856
model_builder.GetBuilder().call<emscripten::val>("scatterElements", data, indices, updates, options);
@@ -86,9 +94,16 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const
8694

8795
const std::string_view op_type = node.OpType();
8896

89-
return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
90-
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) &&
91-
IsInputRankSupportedByOp(node, wnn_limits, logger);
97+
if (!IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) ||
98+
!IsInputRankSupportedByOp(node, wnn_limits, logger)) {
99+
return false;
100+
}
101+
102+
// ONNX specifies that indices must use int64, but some WebNN backends only support int32.
103+
// Allows to use int32 as a workaround for such backends.
104+
can_fallback_int64_to_int32_ = CanFallbackInt64ToInt32(wnn_limits, "scatterElements", "indices");
105+
return can_fallback_int64_to_int32_ ||
106+
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
92107
}
93108

94109
void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class ScatterNDOpBuilder : public BaseOpBuilder {
2424
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
2525
bool HasSupportedInputsImpl(const GraphViewer& graph_viewer, const Node& node,
2626
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
27+
mutable bool can_fallback_int64_to_int32_ = false;
2728
};
2829

2930
// Add operator related.
@@ -35,6 +36,14 @@ Status ScatterNDOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
3536
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
3637
emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name());
3738
emscripten::val options = emscripten::val::object();
39+
40+
// ONNX specifies that indices must use int64, but some WebNN backends only support int32.
41+
// As a workaround for such backends, we can insert a Cast operation to convert the data type.
42+
if (can_fallback_int64_to_int32_) {
43+
options.set("label", node.Name() + "_cast_indices_to_int32");
44+
indices = model_builder.GetBuilder().call<emscripten::val>("cast", indices, emscripten::val("int32"), options);
45+
}
46+
3847
options.set("label", node.Name());
3948
emscripten::val output =
4049
model_builder.GetBuilder().call<emscripten::val>("scatterND", data, indices, updates, options);
@@ -76,9 +85,17 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node&
7685
return false;
7786
}
7887
const std::string_view op_type = node.OpType();
79-
return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
80-
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) &&
81-
IsInputRankSupportedByOp(node, wnn_limits, logger);
88+
89+
if (!IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) ||
90+
!IsInputRankSupportedByOp(node, wnn_limits, logger)) {
91+
return false;
92+
}
93+
94+
// ONNX specifies that indices must use int64, but some WebNN backends only support int32.
95+
// Allows to use int32 as a workaround for such backends.
96+
can_fallback_int64_to_int32_ = CanFallbackInt64ToInt32(wnn_limits, "scatterND", "indices");
97+
return can_fallback_int64_to_int32_ ||
98+
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
8299
}
83100

84101
void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

0 commit comments

Comments
 (0)