// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/tensor/slice.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

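// Slice is registered for four opset ranges. For opsets 1-9 the
// starts/ends/axes values are node attributes; from opset 10 onward they
// arrive as inputs 1-4 (starts, ends, axes, steps), which are pinned to CPU
// memory (InputMemoryType) so the kernel can read them on the host while
// computing the output shape.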
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Slice,
    kOnnxDomain,
    1, 9,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    Slice);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Slice,
    kOnnxDomain,
    10, 10,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        .InputMemoryType(OrtMemTypeCPU, 1)
        .InputMemoryType(OrtMemTypeCPU, 2)
        .InputMemoryType(OrtMemTypeCPU, 3)
        .InputMemoryType(OrtMemTypeCPU, 4),
    Slice);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Slice,
    kOnnxDomain,
    11, 12,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        .InputMemoryType(OrtMemTypeCPU, 1)
        .InputMemoryType(OrtMemTypeCPU, 2)
        .InputMemoryType(OrtMemTypeCPU, 3)
        .InputMemoryType(OrtMemTypeCPU, 4),
    Slice);

ONNX_OPERATOR_KERNEL_EX(
    Slice,
    kOnnxDomain,
    13,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        .InputMemoryType(OrtMemTypeCPU, 1)
        .InputMemoryType(OrtMemTypeCPU, 2)
        .InputMemoryType(OrtMemTypeCPU, 3)
        .InputMemoryType(OrtMemTypeCPU, 4),
    Slice);

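// The generated shader maps each output element back to its input element.
// Walking dimensions from innermost to outermost, it computes
//   input_index_i = output_index_i * steps_i + starts_i + carry
// propagates the carry and wraps with the dimension size, and, when the
// original step was negative (uniforms.signs[i] < 0), mirrors the index so
// the positive-step iteration reads the axis in reverse.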
Status SliceProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
                            << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
                            << "var input_indices: input_indices_t;\n"
                            << "var carry = 0u;\n";

  for (int i = input.Rank() - 1; i >= 0; i--) {
    std::string input_shape_i = absl::StrCat("input_shape_", i);
    std::string steps_i = absl::StrCat("steps_", i);
    std::string starts_i = absl::StrCat("starts_", i);
    std::string output_index_i = absl::StrCat("output_index_", i);
    std::string input_index_i = absl::StrCat("input_index_", i);

    shader.MainFunctionBody() << "let " << input_shape_i << " = " << input.IndicesGet("uniforms.input_shape", i) << ";\n"
                              << "let " << steps_i << " = " << input.IndicesGet("uniforms.steps", i) << ";\n"
                              << "let " << starts_i << " = " << input.IndicesGet("uniforms.starts", i) << ";\n"
                              << "var " << output_index_i << " = " << output.IndicesGet("output_indices", i) << ";\n"
                              << "var " << input_index_i << " = " << output_index_i << " * " << steps_i << " + " << starts_i << " + carry;\n"
                              << "carry = " << input_index_i << " / " << input_shape_i << ";\n"
                              << input_index_i << " = " << input_index_i << " % " << input_shape_i << ";\n"
                              << "if (" << input.IndicesGet("uniforms.signs", i) << " < 0) {\n"
                              << "  " << input_index_i << " = " << input_shape_i << " - " << input_index_i << " - 1u + " << starts_i << ";\n"
                              << "}\n"
                              << input.IndicesSet("input_indices", i, input_index_i) << ";\n";
  }

  shader.MainFunctionBody() << output.SetByOffset("global_idx", input.GetByIndices("input_indices"));

  return Status::OK();
}

Status Slice::ComputeInternal(ComputeContext& context) const {
  // READ INPUTS
  const Tensor* input_tensor = context.Input(0);
  const TensorShape& input_shape = input_tensor->Shape();
  int64_t input_rank = static_cast<int64_t>(input_shape.NumDimensions());

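  // starts/ends come from the attributes when present (opset < 10);
  // otherwise from input tensors 1 and 2, which the kernel registrations
  // above pinned to CPU memory.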
  auto starts_raw = attr_starts_.empty() ? context.Input(1)->DataAsSpan<int64_t>() : gsl::make_span(attr_starts_);
  auto ends_raw = attr_ends_.empty() ? context.Input(2)->DataAsSpan<int64_t>() : gsl::make_span(attr_ends_);

  ORT_ENFORCE(starts_raw.size() == ends_raw.size(), "starts and ends must have the same size");

  int input_count = context.InputCount();

  const Tensor* axes_tensor = nullptr;
  const Tensor* steps_tensor = nullptr;

  if (input_count >= 4) {
    // axes provided as input
    axes_tensor = context.Input(3);
  }

  if (input_count == 5) {
    // steps provided as input
    steps_tensor = context.Input(4);
  }

  // Inject defaults if axes or steps not provided
  std::vector<int64_t> axes_default;
  if (axes_tensor == nullptr) {
    // if axes not provided, set to [0, ..., len(starts)-1]
    for (size_t i = 0; i < starts_raw.size(); i++) {
      axes_default.push_back(i);
    }
  }
  auto axes_raw = attr_axes_.empty() ? (axes_tensor == nullptr ? gsl::make_span(axes_default) : axes_tensor->DataAsSpan<int64_t>()) : gsl::make_span(attr_axes_);

  std::vector<int64_t> steps_default;
  if (steps_tensor == nullptr) {
    // if steps not provided, set to [1, ..., 1] of len(starts)
    for (size_t i = 0; i < starts_raw.size(); i++) {
      steps_default.push_back(1);
    }
  }
  auto steps_raw = steps_tensor == nullptr ? gsl::make_span(steps_default) : steps_tensor->DataAsSpan<int64_t>();

  // PROCESS INPUTS
  std::vector<uint32_t> axes;
  for (unsigned int i = 0; i < axes_raw.size(); i++) {
    int64_t val = axes_raw[i];
    if (val < 0) {
      val += input_rank;
    }
    axes.push_back(static_cast<uint32_t>(val));
  }

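  // Normalize negative starts/ends against the dimension size (e.g. with
  // dim 10, -1 becomes 9), then clamp: to [0, dim - 1] when the step is
  // negative (the value must index an element) and to [0, dim] otherwise.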
  std::vector<uint32_t> starts;
  for (unsigned int i = 0; i < starts_raw.size(); i++) {
    int64_t val = starts_raw[i];
    if (val < 0) {
      val += input_shape[axes[i]];
    }

    if (steps_raw[i] < 0) {
      val = std::max(static_cast<int64_t>(0), std::min(val, static_cast<int64_t>(input_shape[axes[i]] - 1)));
    } else {
      val = std::max(static_cast<int64_t>(0), std::min(val, static_cast<int64_t>(input_shape[axes[i]])));
    }
    starts.push_back(static_cast<uint32_t>(val));
  }

  std::vector<uint32_t> ends;
  for (unsigned int i = 0; i < ends_raw.size(); i++) {
    int64_t val = ends_raw[i];
    if (val < 0) {
      val += input_shape[axes[i]];
    }
    if (steps_raw[i] < 0) {
      val = std::max(static_cast<int64_t>(0), std::min(val, static_cast<int64_t>(input_shape[axes[i]] - 1)));
    } else {
      val = std::max(static_cast<int64_t>(0), std::min(val, static_cast<int64_t>(input_shape[axes[i]])));
    }
    ends.push_back(static_cast<uint32_t>(val));
  }

  // temporary steps vector to handle negative steps
  std::vector<int32_t> steps_tmp;
  for (unsigned int i = 0; i < steps_raw.size(); i++) {
    if (steps_raw[i] >= std::numeric_limits<int32_t>::max()) {
      steps_tmp.push_back(std::numeric_limits<int32_t>::max());
    } else {
      steps_tmp.push_back(static_cast<int32_t>(steps_raw[i]));
    }
  }
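  // Values are saturated at INT32_MAX so that very large step values (which
  // select at most one element along an axis) survive the narrowing cast to
  // 32 bits.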

  // Insert missing dimensions
  if (static_cast<int64_t>(axes.size()) != input_rank) {
    for (uint32_t i = 0; i < input_rank; i++) {
      int idx = -1;
      for (unsigned int j = 0; j < axes_raw.size(); j++) {
        if (axes_raw[j] == i) {
          idx = j;
          break;
        }
      }
      if (idx == -1) {
        axes.insert(axes.begin() + i, i);
        starts.insert(starts.begin() + i, 0);
        ends.insert(ends.begin() + i, static_cast<uint32_t>(input_shape[i]));
        steps_tmp.insert(steps_tmp.begin() + i, 1);
      }
    }
  }
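  // (e.g. for a rank-3 input with axes = {2}, axes 0 and 1 receive the
  // identity slice: start 0, end dim, step 1)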

  // retain the sign of the steps
  std::vector<int32_t> signs;
  for (unsigned int i = 0; i < steps_tmp.size(); i++) {
    signs.push_back(steps_tmp[i] < 0 ? -1 : (steps_tmp[i] > 0 ? 1 : 0));
  }

  // Convert negative steps to positive steps and reverse starts and ends
  for (unsigned int i = 0; i < steps_tmp.size(); i++) {
    if (steps_tmp[i] < 0) {
      float numSteps = static_cast<float>((static_cast<float>(ends[i]) - static_cast<float>(starts[i])) / static_cast<float>(steps_tmp[i]));
      float newEnd = static_cast<float>(starts[i]);
      float newStart = newEnd + numSteps * static_cast<float>(steps_tmp[i]);

      starts[i] = static_cast<uint32_t>(newStart);
      ends[i] = static_cast<uint32_t>(newEnd);
      steps_tmp[i] = static_cast<int32_t>(-steps_tmp[i]);
    }
  }
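  // Worked example: dim = 10, starts = 9, ends = 3, steps = -2 selects
  // elements 9, 7, 5. numSteps = (3 - 9) / -2 = 3, newEnd = 9, and
  // newStart = 9 + 3 * -2 = 3, so the slice becomes starts = 3, ends = 9,
  // steps = 2, and signs[i] = -1 makes the shader read that range in
  // reverse.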

  // final steps vector of type unsigned int
  std::vector<uint32_t> steps;
  for (unsigned int i = 0; i < steps_tmp.size(); i++) {
    steps.push_back(static_cast<uint32_t>(steps_tmp[i]));
  }

  // Reorder inputs in order of axis
  std::vector<int32_t> signs_reordered;
  std::vector<uint32_t> steps_reordered, starts_reordered;
  for (unsigned int i = 0; i < axes.size(); i++) {
    signs_reordered.push_back(0);
    steps_reordered.push_back(0);
    starts_reordered.push_back(0);
  }
  for (unsigned int i = 0; i < axes.size(); i++) {
    int32_t dim = axes[i];
    signs_reordered[dim] = signs[i];
    steps_reordered[dim] = steps[i];
    starts_reordered[dim] = starts[i];
  }
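  // (e.g. axes = {2, 0} places starts[0] at position 2 and starts[1] at
  // position 0 of the reordered vectors)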

  // calculate output dims
  std::vector<int64_t> output_dims;
  for (unsigned int i = 0; i < axes.size(); i++) {
    int32_t dim = axes[i];
    float tmp = ceil((static_cast<float>(ends[dim]) - static_cast<float>(starts[dim])) / static_cast<float>(steps[dim]));
    if (tmp < 0)
      output_dims.push_back(0);
    else
      output_dims.push_back(static_cast<int64_t>(tmp));
  }
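  // e.g. starts = 3, ends = 9, steps = 2 gives ceil((9 - 3) / 2) = 3 elements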

  TensorShape output_shape(output_dims);

  auto* output_tensor = context.Output(0, output_shape);
  uint32_t output_size = static_cast<uint32_t>(output_shape.Size());

  if (output_size == 0) {
    return Status::OK();
  }

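  // one shader invocation per output element; the dispatch size is the
  // ceiling of output_size / WORKGROUP_SIZE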
  SliceProgram program{};
  program
      .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddOutputs({output_tensor})
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariables({{output_size}, {starts_reordered}, {steps_reordered}, {signs_reordered}});
  return context.RunProgram(program);
}

}  // namespace webgpu
}  // namespace onnxruntime