Skip to content

Commit 1d07e94

Browse files
authored
[WebNN] Support Round op (microsoft#25810)
1 parent 7af42b8 commit 1d07e94

File tree

4 files changed

+9
-33
lines changed

4 files changed

+9
-33
lines changed

js/web/docs/webnn-operators.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s
8080
| PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | |
8181
| QuantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | quantizeLinear | The shape of x_scale should be a subsample of the shape of input |
8282
| Reciprocal | ai.onnx(7-12, 13+) | reciprocal | |
83+
| Round | ai.onnx(11-21, 22+) | roundEven | |
8384
| ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | Input 'axes' if present should be a constant |
8485
| ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | Input 'axes' if present should be a constant |
8586
| ReduceLogSum| ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSum | Input 'axes' if present should be a constant |

onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,42 +27,14 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
2727
const auto& op_type(node.OpType());
2828

2929
emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name());
30-
emscripten::val output = emscripten::val::object();
3130
emscripten::val options = emscripten::val::object();
3231
options.set("label", node.Name());
3332

34-
if (op_type == "Abs") {
35-
output = model_builder.GetBuilder().call<emscripten::val>("abs", input, options);
36-
} else if (op_type == "Ceil") {
37-
output = model_builder.GetBuilder().call<emscripten::val>("ceil", input, options);
38-
} else if (op_type == "Cos") {
39-
output = model_builder.GetBuilder().call<emscripten::val>("cos", input, options);
40-
} else if (op_type == "Erf") {
41-
output = model_builder.GetBuilder().call<emscripten::val>("erf", input, options);
42-
} else if (op_type == "Exp") {
43-
output = model_builder.GetBuilder().call<emscripten::val>("exp", input, options);
44-
} else if (op_type == "Floor") {
45-
output = model_builder.GetBuilder().call<emscripten::val>("floor", input, options);
46-
} else if (op_type == "Identity") {
47-
output = model_builder.GetBuilder().call<emscripten::val>("identity", input, options);
48-
} else if (op_type == "Log") {
49-
output = model_builder.GetBuilder().call<emscripten::val>("log", input, options);
50-
} else if (op_type == "Neg") {
51-
output = model_builder.GetBuilder().call<emscripten::val>("neg", input, options);
52-
} else if (op_type == "Reciprocal") {
53-
output = model_builder.GetBuilder().call<emscripten::val>("reciprocal", input, options);
54-
} else if (op_type == "Sign") {
55-
output = model_builder.GetBuilder().call<emscripten::val>("sign", input, options);
56-
} else if (op_type == "Sin") {
57-
output = model_builder.GetBuilder().call<emscripten::val>("sin", input, options);
58-
} else if (op_type == "Sqrt") {
59-
output = model_builder.GetBuilder().call<emscripten::val>("sqrt", input, options);
60-
} else if (op_type == "Tan") {
61-
output = model_builder.GetBuilder().call<emscripten::val>("tan", input, options);
62-
} else {
63-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
64-
"UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
65-
}
33+
const std::string_view webnn_op_type = GetWebNNOpType(op_type);
34+
ORT_RETURN_IF(webnn_op_type.empty(), "Cannot get WebNN op type");
35+
36+
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>(
37+
std::string(webnn_op_type).c_str(), input, options);
6638

6739
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
6840
return Status::OK();
@@ -84,6 +56,7 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op
8456
"Log",
8557
"Neg",
8658
"Reciprocal",
59+
"Round",
8760
"Sign",
8861
"Sin",
8962
"Sqrt",

onnxruntime/core/providers/webnn/builders/map_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ const std::unordered_map<std::string_view, WebnnOpInfo> op_inputs_map = {
174174
{"Greater", {"greater", {{0, "a"}, {1, "b"}}}},
175175
{"Reciprocal", {"reciprocal", {{0, "input"}}}},
176176
{"ReduceMean", {"reduceMean", {{0, "input"}}}},
177+
{"Round", {"roundEven", {{0, "input"}}}},
177178
{"GlobalMaxPool", {"maxPool2d", {{0, "input"}}}},
178179
{"HardSigmoid", {"hardSigmoid", {{0, "input"}}}},
179180
{"ReduceProd", {"reduceProduct", {{0, "input"}}}},

onnxruntime/core/providers/webnn/builders/op_builder_factory.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
2626
CreateUnaryOpBuilder("Log", op_registrations);
2727
CreateUnaryOpBuilder("Neg", op_registrations);
2828
CreateUnaryOpBuilder("Reciprocal", op_registrations);
29+
CreateUnaryOpBuilder("Round", op_registrations);
2930
CreateUnaryOpBuilder("Sign", op_registrations);
3031
CreateUnaryOpBuilder("Sin", op_registrations);
3132
CreateUnaryOpBuilder("Sqrt", op_registrations);

0 commit comments

Comments
 (0)