Commit c07afd3

slice operator implementation for webgpu native (microsoft#23264)
Increases operator coverage for the WebGPU native EP.
1 parent 5c3c764 commit c07afd3
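
For context: ONNX Slice extracts a strided sub-tensor. Along each sliced axis, output index i reads input index start + i * step, and the output extent is ceil((end - start) / step). A minimal single-axis CPU reference sketch (illustrative code, not part of this commit; assumes start and end are already normalized to valid indices):

    #include <iostream>
    #include <vector>

    // Reference 1-D slice: collects data[start], data[start + step], ... while
    // the index stays before `end` (after it, for negative steps).
    std::vector<float> SliceReference(const std::vector<float>& data,
                                      int start, int end, int step) {
      std::vector<float> out;
      if (step > 0) {
        for (int i = start; i < end; i += step) out.push_back(data[i]);
      } else {
        for (int i = start; i > end; i += step) out.push_back(data[i]);
      }
      return out;
    }

    int main() {
      std::vector<float> data{0.f, 1.f, 2.f, 3.f};
      for (float v : SliceReference(data, 3, 0, -1)) std::cout << v << ' ';  // 3 2 1
      std::cout << '\n';
      return 0;
    }

The kernel below implements the same semantics on the GPU: negative starts/ends/steps are normalized on the host, and the per-element index mapping runs in the generated shader.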

File tree

5 files changed: +350 −5 lines changed

onnxruntime/core/providers/webgpu/tensor/slice.cc

Lines changed: 278 additions & 0 deletions
@@ -0,0 +1,278 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/tensor/slice.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Slice,
    kOnnxDomain,
    1, 9,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    Slice);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Slice,
    kOnnxDomain,
    10, 10,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        .InputMemoryType(OrtMemTypeCPU, 1)
        .InputMemoryType(OrtMemTypeCPU, 2)
        .InputMemoryType(OrtMemTypeCPU, 3)
        .InputMemoryType(OrtMemTypeCPU, 4),
    Slice);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Slice,
    kOnnxDomain,
    11, 12,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        .InputMemoryType(OrtMemTypeCPU, 1)
        .InputMemoryType(OrtMemTypeCPU, 2)
        .InputMemoryType(OrtMemTypeCPU, 3)
        .InputMemoryType(OrtMemTypeCPU, 4),
    Slice);

ONNX_OPERATOR_KERNEL_EX(
    Slice,
    kOnnxDomain,
    13,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        .InputMemoryType(OrtMemTypeCPU, 1)
        .InputMemoryType(OrtMemTypeCPU, 2)
        .InputMemoryType(OrtMemTypeCPU, 3)
        .InputMemoryType(OrtMemTypeCPU, 4),
    Slice);

Status SliceProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
                            << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
                            << "var input_indices: input_indices_t;\n"
                            << "var carry = 0u;\n";

  for (int i = input.Rank() - 1; i >= 0; i--) {
    std::string input_shape_i = absl::StrCat("input_shape_", i);
    std::string steps_i = absl::StrCat("steps_", i);
    std::string starts_i = absl::StrCat("starts_", i);
    std::string output_index_i = absl::StrCat("output_index_", i);
    std::string input_index_i = absl::StrCat("input_index_", i);

    shader.MainFunctionBody() << "let " << input_shape_i << " = " << input.IndicesGet("uniforms.input_shape", i) << ";\n"
                              << "let " << steps_i << " = " << input.IndicesGet("uniforms.steps", i) << ";\n"
                              << "let " << starts_i << " = " << input.IndicesGet("uniforms.starts", i) << ";\n"
                              << "var " << output_index_i << " = " << output.IndicesGet("output_indices", i) << ";\n"
                              << "var " << input_index_i << " = " << output_index_i << " * " << steps_i << " + " << starts_i << " + carry;\n"
                              << "carry = " << input_index_i << " / " << input_shape_i << ";\n"
                              << input_index_i << " = " << input_index_i << " % " << input_shape_i << ";\n"
                              << "if (" << input.IndicesGet("uniforms.signs", i) << " < 0) {\n"
                              << " " << input_index_i << " = " << input_shape_i << " - " << input_index_i << " - 1u + " << starts_i << ";\n"
                              << "}\n"
                              << input.IndicesSet("input_indices", i, input_index_i) << ";\n";
  }

  shader.MainFunctionBody() << output.SetByOffset("global_idx", input.GetByIndices("input_indices"));

  return Status::OK();
}

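// For orientation (illustrative sketch, not part of the committed file): for a
// rank-2 input, each iteration of the loop above appends WGSL of roughly this
// shape for one axis, with the exact uniform/index accessor text produced by
// ShaderVariableHelper:
//
//   let input_shape_1 = <uniforms.input_shape, axis 1>;
//   let steps_1 = <uniforms.steps, axis 1>;
//   let starts_1 = <uniforms.starts, axis 1>;
//   var output_index_1 = <output_indices, axis 1>;
//   var input_index_1 = output_index_1 * steps_1 + starts_1 + carry;
//   carry = input_index_1 / input_shape_1;
//   input_index_1 = input_index_1 % input_shape_1;
//   if (<uniforms.signs, axis 1> < 0) {
//     input_index_1 = input_shape_1 - input_index_1 - 1u + starts_1;
//   }
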
Status Slice::ComputeInternal(ComputeContext& context) const {
  // READ INPUTS
  const Tensor* input_tensor = context.Input(0);
  const TensorShape& input_shape = input_tensor->Shape();
  int64_t input_rank = static_cast<int64_t>(input_shape.NumDimensions());

  auto starts_raw = attr_starts_.empty() ? context.Input(1)->DataAsSpan<int64_t>() : gsl::make_span(attr_starts_);
  auto ends_raw = attr_ends_.empty() ? context.Input(2)->DataAsSpan<int64_t>() : gsl::make_span(attr_ends_);

  ORT_ENFORCE(starts_raw.size() == ends_raw.size(), "starts and ends must have the same size");

  int input_count = context.InputCount();

  const Tensor* axes_tensor = nullptr;
  const Tensor* steps_tensor = nullptr;

  if (input_count >= 4) {
    // axes provided as input
    axes_tensor = context.Input(3);
  }

  if (input_count == 5) {
    // steps provided as input
    steps_tensor = context.Input(4);
  }

  // Inject defaults if axes or steps not provided
  std::vector<int64_t> axes_default;
  if (axes_tensor == nullptr) {
    // if axes not provided, set to [0, ..., len(starts)-1]
    for (size_t i = 0; i < starts_raw.size(); i++) {
      axes_default.push_back(i);
    }
  }
  auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? gsl::make_span(axes_default) : axes_tensor->DataAsSpan<int64_t>()) : gsl::make_span(attr_axes_);

  std::vector<int64_t> steps_default;
  if (steps_tensor == nullptr) {
    // if steps not provided, set to [1, ..., 1] of len(starts)
    for (size_t i = 0; i < starts_raw.size(); i++) {
      steps_default.push_back(1);
    }
  }
  auto steps_raw = steps_tensor == nullptr ? gsl::make_span(steps_default) : steps_tensor->DataAsSpan<int64_t>();

  // PROCESS INPUTS
  std::vector<uint32_t> axes;
  for (unsigned int i = 0; i < axes_raw.size(); i++) {
    int64_t val = axes_raw[i];
    if (val < 0) {
      val += input_rank;
    }
    axes.push_back(static_cast<int32_t>(val));
  }

  std::vector<uint32_t> starts;
  for (unsigned int i = 0; i < starts_raw.size(); i++) {
    int64_t val = starts_raw[i];
    if (val < 0) {
      val += input_shape[axes[i]];
    }

    if (steps_raw[i] < 0) {
      val = std::max(static_cast<int64_t>(0), std::min(val, static_cast<int64_t>(input_shape[axes[i]] - 1)));
    } else {
      val = std::max(static_cast<int64_t>(0), std::min(val, static_cast<int64_t>(input_shape[axes[i]])));
    }
    starts.push_back(static_cast<uint32_t>(val));
  }

  std::vector<uint32_t> ends;
  for (unsigned int i = 0; i < ends_raw.size(); i++) {
    int64_t val = ends_raw[i];
    if (val < 0) {
      val += input_shape[axes[i]];
    }
    if (steps_raw[i] < 0) {
      val = std::max(static_cast<int64_t>(0), std::min(val, static_cast<int64_t>(input_shape[axes[i]] - 1)));
    } else {
      val = std::max(static_cast<int64_t>(0), std::min(val, static_cast<int64_t>(input_shape[axes[i]])));
    }
    ends.push_back(static_cast<uint32_t>(val));
  }

  // temporary steps vector to handle negative steps
  std::vector<int32_t> steps_tmp;
  for (unsigned int i = 0; i < steps_raw.size(); i++) {
    if (steps_raw[i] >= std::numeric_limits<int32_t>::max()) {
      steps_tmp.push_back(std::numeric_limits<int32_t>::max());
    } else {
      steps_tmp.push_back(static_cast<int32_t>(steps_raw[i]));
    }
  }

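  // Note: "axes" may cover only a subset of the input dimensions; the block
  // below gives every unlisted dimension an identity slice (start = 0,
  // end = input_shape[i], step = 1) so starts/ends/steps become rank-sized.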
  // Insert missing dimensions
  if (static_cast<int64_t>(axes.size()) != input_rank) {
    for (uint32_t i = 0; i < input_rank; i++) {
      int idx = -1;
      for (unsigned int j = 0; j < axes_raw.size(); j++) {
        if (axes_raw[j] == i) {
          idx = j;
          break;
        }
      }
      if (idx == -1) {
        axes.insert(axes.begin() + i, i);
        starts.insert(starts.begin() + i, 0);
        ends.insert(ends.begin() + i, static_cast<uint32_t>(input_shape[i]));
        steps_tmp.insert(steps_tmp.begin() + i, 1);
      }
    }
  }

  // retain the sign of the steps
  std::vector<int32_t> signs;
  for (unsigned int i = 0; i < steps_tmp.size(); i++) {
    signs.push_back(steps_tmp[i] < 0 ? -1 : (steps_tmp[i] > 0 ? 1 : 0));
  }

  // Convert negative steps to positive steps and reverse starts and ends
  for (unsigned int i = 0; i < steps_tmp.size(); i++) {
    if (steps_tmp[i] < 0) {
      float numSteps = static_cast<float>((static_cast<float>(ends[i]) - static_cast<float>(starts[i])) / static_cast<float>(steps_tmp[i]));
      float newEnd = static_cast<float>(starts[i]);
      float newStart = newEnd + numSteps * static_cast<float>(steps_tmp[i]);

      starts[i] = static_cast<uint32_t>(newStart);
      ends[i] = static_cast<uint32_t>(newEnd);
      steps_tmp[i] = static_cast<int32_t>(-steps_tmp[i]);
    }
  }

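  // Worked example (illustrative, not in the original source): slicing a
  // dimension of size 4 as [3:0:-1]. After clamping: start = 3, end = 0,
  // step = -1, sign = -1. Then numSteps = (0 - 3) / -1 = 3, newEnd = 3,
  // newStart = 3 + 3 * (-1) = 0, so the program runs with start = 0, end = 3,
  // step = 1, and the shader's sign branch remaps input_index to
  // input_shape - input_index - 1 + start, reading input indices 3, 2, 1
  // for output positions 0, 1, 2.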
  // final steps vector of type unsigned int
  std::vector<uint32_t> steps;
  for (unsigned int i = 0; i < steps_tmp.size(); i++) {
    steps.push_back(static_cast<uint32_t>(steps_tmp[i]));
  }

  // Reorder inputs in order of axis
  std::vector<int32_t> signs_reordered;
  std::vector<uint32_t> steps_reordered, starts_reordered;
  for (unsigned int i = 0; i < axes.size(); i++) {
    signs_reordered.push_back(0);
    steps_reordered.push_back(0);
    starts_reordered.push_back(0);
  }
  for (unsigned int i = 0; i < axes.size(); i++) {
    int32_t dim = axes[i];
    signs_reordered[dim] = signs[i];
    steps_reordered[dim] = steps[i];
    starts_reordered[dim] = starts[i];
  }

  // calculate output dims
  std::vector<int64_t> output_dims;
  for (unsigned int i = 0; i < axes.size(); i++) {
    int32_t dim = axes[i];
    float tmp = ceil((static_cast<float>(ends[dim]) - static_cast<float>(starts[dim])) / static_cast<float>(steps[dim]));
    if (tmp < 0)
      output_dims.push_back(0);
    else
      output_dims.push_back(static_cast<int64_t>(tmp));
  }

  TensorShape output_shape(output_dims);

  auto* output_tensor = context.Output(0, output_shape);
  uint32_t output_size = static_cast<uint32_t>(output_shape.Size());

  if (output_size == 0) {
    return Status::OK();
  }

  SliceProgram program{};
  program
      .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddOutputs({output_tensor})
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariables({{output_size}, {starts_reordered}, {steps_reordered}, {signs_reordered}});
  return context.RunProgram(program);
}

}  // namespace webgpu
}  // namespace onnxruntime
onnxruntime/core/providers/webgpu/tensor/slice.h

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

class SliceProgram final : public Program<SliceProgram> {
 public:
  SliceProgram() : Program{"Slice"} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"starts", ProgramUniformVariableDataType::Uint32},
                                          {"steps", ProgramUniformVariableDataType::Uint32},
                                          {"signs", ProgramUniformVariableDataType::Int32});
};

class Slice final : public WebGpuKernel {
 public:
  Slice(const OpKernelInfo& info) : WebGpuKernel(info) {
    // Only opsets 1-9 provide starts/ends/axes as attributes, so the return
    // value can be safely ignored here; a failed fetch leaves the vector empty
    // and is handled in ComputeInternal.
    (void)info.GetAttrs("starts", attr_starts_);
    (void)info.GetAttrs("ends", attr_ends_);
    (void)info.GetAttrs("axes", attr_axes_);
  }

  Status ComputeInternal(ComputeContext& context) const override;

 private:
  std::vector<int64_t> attr_starts_, attr_ends_, attr_axes_;
};

}  // namespace webgpu
}  // namespace onnxruntime
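
Note on the opset split visible above: for opsets 1-9, Slice takes starts/ends/axes as node attributes, read once in the constructor; from opset 10 onward they arrive as tensor inputs, with steps added as a fifth input. Because ComputeInternal reads those inputs on the host (via DataAsSpan<int64_t>()) to compute shapes and uniforms before dispatching the shader, the opset >= 10 kernel registrations pin inputs 1-4 to OrtMemTypeCPU; the opset 1-9 registration needs no such pinning.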

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 4 additions & 4 deletions
@@ -664,10 +664,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
   // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize)>,
   // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 19, Resize)>,

-  // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 9, Slice)>,
-  // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Slice)>,
-  // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Slice)>,
-  // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Slice)>,
+  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 9, Slice)>,
+  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Slice)>,
+  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Slice)>,
+  BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Slice)>,

   BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
   BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
