Skip to content

Commit 63cb532

Browse files
shiyi9801Honry
andauthored
[WebNN] Support steps >= 1 for slice operator (microsoft#22708)
Co-authored-by: Wanming Lin <wanming.lin@intel.com>
1 parent b9b1a03 commit 63cb532

File tree

3 files changed

+22
-17
lines changed

3 files changed

+22
-17
lines changed

js/web/docs/webnn-operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
9393
| Softplus | ai.onnx(7+) | softplus ||| |
9494
| Softsign | ai.onnx(7+) | softsign ||| |
9595
| Sin | ai.onnx(7+) | sin ||| |
96-
| Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice ||| Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value 1 |
96+
| Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice ||| Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value >= 1 |
9797
| Softmax | ai.onnx(7-10, 11-12, 13+) | softmax ||| |
9898
| Split | ai.onnx(7-10, 11-12, 13-17, 18+) | split ||| Input 'split' if present should be a constant |
9999
| Sqrt | ai.onnx(7-12, 13+) | sqrt ||| |

js/web/test/suite-test-list.jsonc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2362,14 +2362,14 @@
23622362
// "test_sinh",
23632363
// // "test_size_example",
23642364
// // "test_size",
2365-
// "test_slice_default_axes",
2366-
// "test_slice_default_steps",
2367-
// "test_slice_end_out_of_bounds",
2368-
// "test_slice_neg_steps",
2369-
// "test_slice_neg",
2370-
// "test_slice_negative_axes",
2371-
// "test_slice_start_out_of_bounds",
2372-
// "test_slice",
2365+
"test_slice_default_axes",
2366+
"test_slice_default_steps",
2367+
"test_slice_end_out_of_bounds",
2368+
"test_slice_neg_steps",
2369+
"test_slice_neg",
2370+
"test_slice_negative_axes",
2371+
"test_slice_start_out_of_bounds",
2372+
"test_slice",
23732373
// "test_softmax_axis_0_expanded",
23742374
"test_softmax_axis_0",
23752375
// "test_softmax_axis_1_expanded",

onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
5252
emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name());
5353
std::vector<int32_t> starts(rank);
5454
std::vector<int32_t> sizes(rank);
55+
std::vector<int32_t> steps(rank);
5556

5657
// Copy the data from the starts/ends/axes/steps initializers.
5758
std::vector<int64_t> input_starts;
@@ -94,8 +95,11 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
9495
std::transform(compute_metadata.ends_.cbegin(), compute_metadata.ends_.cend(), compute_metadata.starts_.cbegin(),
9596
sizes.begin(),
9697
[](int64_t i, int64_t j) { return SafeInt<uint32_t>(i - j); });
98+
std::transform(compute_metadata.steps_.cbegin(), compute_metadata.steps_.cend(), steps.begin(),
99+
[](int64_t i) { return SafeInt<uint32_t>(i); });
97100

98101
emscripten::val options = emscripten::val::object();
102+
options.set("strides", emscripten::val::array(steps));
99103
options.set("label", node.Name());
100104
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("slice", inputs,
101105
emscripten::val::array(starts),
@@ -144,18 +148,19 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
144148
return false;
145149
}
146150
const auto data_type = steps_tensor.data_type();
147-
// WebNN doesn't support steps other than 1.
151+
// WebNN doesn't support steps less than 1.
148152
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
149-
if (!std::all_of(reinterpret_cast<int64_t*>(unpacked_tensor.data()),
150-
reinterpret_cast<int64_t*>(unpacked_tensor.data() + unpacked_tensor.size()),
151-
[](int64_t i) { return i == 1; })) {
153+
if (std::any_of(reinterpret_cast<int64_t*>(unpacked_tensor.data()),
154+
reinterpret_cast<int64_t*>(unpacked_tensor.data() + unpacked_tensor.size()),
155+
[](int64_t i) { return i < 1; })) {
156+
LOGS(logger, VERBOSE) << "WebNN slice doesn't support steps less than 1";
152157
return false;
153158
}
154159
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
155-
if (!std::all_of(reinterpret_cast<int32_t*>(unpacked_tensor.data()),
156-
reinterpret_cast<int32_t*>(unpacked_tensor.data()) +
157-
unpacked_tensor.size() / sizeof(int32_t),
158-
[](int32_t i) { return i == 1; })) {
160+
if (std::any_of(reinterpret_cast<int32_t*>(unpacked_tensor.data()),
161+
reinterpret_cast<int32_t*>(unpacked_tensor.data()) + unpacked_tensor.size() / sizeof(int32_t),
162+
[](int32_t i) { return i < 1; })) {
163+
LOGS(logger, VERBOSE) << "WebNN slice doesn't support steps less than 1";
159164
return false;
160165
}
161166
}

0 commit comments

Comments
 (0)