Skip to content

Commit 08e18b2

Browse files
authored
[webgpu] support GatherND operator (microsoft#25632)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 5d77b73 commit 08e18b2

File tree

3 files changed

+184
-0
lines changed

3 files changed

+184
-0
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include <vector>
5+
6+
#include "core/providers/webgpu/tensor/gather_nd.h"
7+
#include "core/providers/webgpu/shader_helper.h"
8+
#include "core/providers/webgpu/webgpu_supported_types.h"
9+
10+
namespace onnxruntime {
11+
namespace webgpu {
12+
13+
// Emits the shader body for GatherND.
//
// One invocation handles one output element (`global_idx`). Its output indices
// are turned into a full index into `data` in four steps:
//   1. the leading `batch_dims_` axes are copied to both `data_indices` and
//      `indices_indices`;
//   2. the remaining axes of `input_indices`, except its last one, are copied
//      from the output indices at the same positions;
//   3. the `indices_innerest_dim_` values stored along the last axis of
//      `input_indices` are read (negative values are wrapped by adding the
//      size of the corresponding `data` axis) and become the next
//      `indices_innerest_dim_` components of `data_indices`;
//   4. the remaining (slice) axes of `data` are copied from the trailing
//      output indices.
//
// @param shader  Helper that collects the inputs/outputs and the generated
//                main-function body.
// @return Status::OK() in all cases visible here.
Status GatherNDProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);

  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
                            << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
                            << " var data_indices: data_indices_t;\n"
                            << " var indices_indices: input_indices_indices_t;\n";

  // `data_dim` is a cursor over the axes of `data` that have been filled in so far.
  uint32_t data_dim = 0;

  // Step 1: batch axes are shared by data, indices and output.
  for (uint32_t i = data_dim; i < batch_dims_; i++) {
    shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", i)) << "\n"
                              << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", i)) << "\n";
  }
  data_dim += batch_dims_;

  // Step 2: non-batch axes of `input_indices`, excluding its last axis.
  for (uint32_t i = data_dim; i < static_cast<uint32_t>(indices.Rank() - 1); i++) {
    shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", i)) << "\n";
  }

  // Step 3: walk the last axis of `input_indices`; each entry selects one axis of `data`.
  shader.MainFunctionBody() << " var indice_value = i32(0);\n";
  for (uint32_t i = 0; i < indices_innerest_dim_; i++) {
    shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", indices.Rank() - 1, std::to_string(i)) << "\n"
                              << " indice_value = " << indices.GetByIndices("indices_indices") << ";\n"
                              << " if (indice_value < 0) {\n"
                              << " indice_value += i32(" << data.IndicesGet("uniforms.data_shape", data_dim + i) << ");\n"
                              << " }\n"
                              << " " << data.IndicesSet("data_indices", data_dim + i, "u32(indice_value)") << "\n";
  }
  data_dim += indices_innerest_dim_;

  // Step 4: copy the trailing slice axes. Each slice axis `data_dim + i` must get its own
  // output index; writing every value to `data_dim` (as before) left all slice axes after
  // the first unset whenever the gathered slice had rank >= 2.
  const uint32_t slice_rank = static_cast<uint32_t>(data.Rank() - data_dim);
  for (uint32_t i = 0; i < slice_rank; i++) {
    shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", data_dim + i, output.IndicesGet("output_indices", indices.Rank() - 1 + i)) << "\n";
  }

  shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", data.GetByIndices("data_indices"));

  return Status::OK();
}
53+
54+
// Verifies that the first `num_batch_dimensions` axes of `input_shape` and
// `indices_shape` have identical sizes.
//
// @param num_batch_dimensions  Number of leading axes treated as batch axes.
// @param input_shape           Shape of the data tensor.
// @param indices_shape         Shape of the indices tensor.
// @return Status::OK() when the batch count fits both ranks and every batch
//         axis matches; otherwise the status produced by ORT_RETURN_IF_NOT.
//
// NOTE(review): the condition expressions passed to ORT_RETURN_IF_NOT are kept
// token-for-token because the macro presumably embeds the condition text in the
// error message - confirm before renaming anything inside them.
Status CheckBatchDimensionsMatch(size_t num_batch_dimensions, const TensorShape& input_shape,
                                 const TensorShape& indices_shape) {
  // The batch count may not exceed the rank of either tensor.
  ORT_RETURN_IF_NOT(
      num_batch_dimensions <= input_shape.NumDimensions() && num_batch_dimensions <= indices_shape.NumDimensions(),
      "Number of batch dimensions exceeds tensor rank. ", "Batch dimension count: ", num_batch_dimensions,
      ", input tensor rank: ", input_shape.NumDimensions(), ", indices tensor rank: ", indices_shape.NumDimensions());

  // Compare the batch axes one by one, stopping at the first mismatch.
  size_t batch_dimension_idx = 0;
  while (batch_dimension_idx < num_batch_dimensions) {
    ORT_RETURN_IF_NOT(
        input_shape[batch_dimension_idx] == indices_shape[batch_dimension_idx],
        "Batch dimensions differ at index ", batch_dimension_idx, ": ",
        input_shape[batch_dimension_idx], " != ", indices_shape[batch_dimension_idx]);
    ++batch_dimension_idx;
  }

  return Status::OK();
}
70+
71+
// Runs GatherND: validates the input/indices shapes, allocates the output and
// dispatches a GatherNDProgram over every output element.
//
// @param context  Compute context; input 0 is the data tensor, input 1 the
//                 indices tensor, output 0 receives the gathered values.
// @return INVALID_ARGUMENT when the indices tensor has rank 0 or when
//         batch_dims_ plus the last indices dimension exceeds the data rank;
//         the status of CheckBatchDimensionsMatch on a batch mismatch;
//         Status::OK() without dispatching when the output is empty;
//         otherwise the result of RunProgram.
Status GatherND::ComputeInternal(ComputeContext& context) const {
  const auto* input_tensor = context.Input(0);
  const auto* indices_tensor = context.Input(1);
  const TensorShape& input_shape = input_tensor->Shape();
  const TensorShape& indices_shape = indices_tensor->Shape();

  const size_t indices_rank = indices_shape.NumDimensions();
  if (indices_rank == 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "indices tensor must has rank larger than 0");
  }

  // Size of the last indices axis, i.e. how many data axes each index tuple addresses.
  auto indices_innerest_dim = indices_shape[indices_rank - 1];
  // First data axis that is NOT consumed by batch axes or by the index tuple.
  auto last_indices_dimension = batch_dims_ + indices_innerest_dim;
  if (last_indices_dimension > static_cast<int64_t>(input_shape.NumDimensions())) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "last dimension of indices must not be larger than rank of input tensor");
  }

  ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch(static_cast<size_t>(batch_dims_),
                                                input_shape, indices_shape));

  // Output shape = indices shape without its last axis, followed by the
  // data axes starting at `last_indices_dimension`.
  const auto indices_dims = indices_shape.GetDims();
  const auto input_dims = input_shape.GetDims();
  std::vector<int64_t> shape(indices_dims.begin(), indices_dims.end() - 1);
  shape.insert(shape.end(), input_dims.begin() + static_cast<size_t>(last_indices_dimension), input_dims.end());

  auto output_tensor = context.Output(0, TensorShape(shape));
  uint32_t data_size = onnxruntime::narrow<uint32_t>(output_tensor->Shape().Size());
  if (data_size == 0) {
    // Nothing to gather.
    return Status::OK();
  }

  GatherNDProgram program{static_cast<uint32_t>(batch_dims_), static_cast<uint32_t>(indices_innerest_dim)};
  program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank},
                     {indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
  program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank});
  // One invocation per output element, rounded up to whole workgroups.
  program.SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  // The generated shader depends on both values, so both go into the cache hint.
  program.CacheHint(std::to_string(batch_dims_), std::to_string(indices_innerest_dim));
  program.AddUniformVariables({{data_size}});

  return context.RunProgram(program);
}
111+
112+
// Registers KERNEL_CLASS as the WebGPU execution provider kernel for OP_TYPE in
// the ONNX domain at opset VERSION. "T" is constrained to TYPE and "indices" to
// int64 tensors.
#define WEBGPU_GATHERND_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
113+
ONNX_OPERATOR_KERNEL_EX( \
114+
OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \
115+
KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("indices", DataTypeImpl::GetTensorType<int64_t>()), \
116+
KERNEL_CLASS);
117+
118+
// Same as WEBGPU_GATHERND_KERNEL, but registers the kernel for the opset range
// VERSION_FROM..VERSION_TO.
#define WEBGPU_GATHERND_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
119+
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
120+
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \
121+
KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("indices", DataTypeImpl::GetTensorType<int64_t>()), \
122+
KERNEL_CLASS);
123+
124+
// GatherND registrations: opset 11, opset 12, and opset 13 (non-versioned),
// all using the GatherND kernel class with WebGpuSupportedNumberTypes() for "T".
WEBGPU_GATHERND_VERSIONED_KERNEL(GatherND, 11, 11, GatherND, WebGpuSupportedNumberTypes())
125+
WEBGPU_GATHERND_VERSIONED_KERNEL(GatherND, 12, 12, GatherND, WebGpuSupportedNumberTypes())
126+
WEBGPU_GATHERND_KERNEL(GatherND, 13, GatherND, WebGpuSupportedNumberTypes())
127+
128+
} // namespace webgpu
129+
} // namespace onnxruntime
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/program.h"
7+
#include "core/providers/webgpu/webgpu_kernel.h"
8+
9+
namespace onnxruntime {
10+
namespace webgpu {
11+
12+
// Shader program for GatherND. The two constructor arguments are stored and
// made available to GenerateShaderCode().
class GatherNDProgram final : public Program<GatherNDProgram> {
 public:
  // @param batch_dims            Number of leading batch axes.
  // @param indices_innerest_dim  Size of the last axis of the indices tensor.
  GatherNDProgram(const uint32_t batch_dims, const uint32_t indices_innerest_dim)
      : Program{"GatherND"},
        batch_dims_{batch_dims},
        indices_innerest_dim_{indices_innerest_dim} {
  }

  // Writes the shader body for this program into `sh`.
  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Single uniform: `data_size`, a Uint32.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32});

 private:
  uint32_t batch_dims_;
  uint32_t indices_innerest_dim_;
};
26+
27+
class GatherNDBase : public WebGpuKernel {
28+
public:
29+
explicit GatherNDBase(const OpKernelInfo& info) : WebGpuKernel(info) {
30+
info.GetAttrOrDefault("batch_dims", &batch_dims_, static_cast<int64_t>(0));
31+
ORT_ENFORCE(batch_dims_ >= 0);
32+
}
33+
34+
protected:
35+
int64_t batch_dims_;
36+
};
37+
38+
class GatherND final : public GatherNDBase {
39+
public:
40+
GatherND(const OpKernelInfo& info) : GatherNDBase(info) {}
41+
42+
protected:
43+
Status ComputeInternal(ComputeContext& context) const override;
44+
};
45+
46+
} // namespace webgpu
47+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13,
349349
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
350350
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements);
351351

352+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, GatherND);
353+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, GatherND);
354+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherND);
355+
352356
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 9, Slice);
353357
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Slice);
354358
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Slice);
@@ -676,6 +680,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
676680
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
677681
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements)>,
678682

683+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, GatherND)>,
684+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, GatherND)>,
685+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherND)>,
686+
679687
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
680688
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
681689
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Resize)>,

0 commit comments

Comments
 (0)