// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <vector>

#include "core/providers/webgpu/tensor/gather_nd.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

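// Builds a WGSL shader that computes one output element per invocation: each
// thread converts its flat output offset into output indices, reconstructs the
// corresponding data/indices coordinates, and copies the gathered element.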
Status GatherNDProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);

  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
                            << "  let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
                            << "  var data_indices: data_indices_t;\n"
                            << "  var indices_indices: input_indices_indices_t;\n";

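  // Step 1: batch dimensions are shared by data, indices, and output, so they
  // map through unchanged.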
  uint32_t data_dim = 0;
  for (uint32_t i = data_dim; i < batch_dims_; i++) {
    shader.MainFunctionBody() << "  " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", i)) << "\n"
                              << "  " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", i)) << "\n";
  }
  data_dim += batch_dims_;

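  // Step 2: the non-innermost indices dimensions also correspond one-to-one to
  // output dimensions.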
  for (uint32_t i = data_dim; i < static_cast<uint32_t>(indices.Rank() - 1); i++) {
    shader.MainFunctionBody() << "  " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", i)) << "\n";
  }

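  // Step 3: the innermost indices dimension holds the coordinates to gather
  // from data; negative values wrap around, per the ONNX GatherND spec.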
  shader.MainFunctionBody() << "  var index_value = i32(0);\n";
  for (uint32_t i = 0; i < indices_innerest_dim_; i++) {
    shader.MainFunctionBody() << "  " << indices.IndicesSet("indices_indices", indices.Rank() - 1, std::to_string(i)) << "\n"
                              << "  index_value = " << indices.GetByIndices("indices_indices") << ";\n"
                              << "  if (index_value < 0) {\n"
                              << "    index_value += i32(" << data.IndicesGet("uniforms.data_shape", data_dim + i) << ");\n"
                              << "  }\n"
                              << "  " << data.IndicesSet("data_indices", data_dim + i, "u32(index_value)") << "\n";
  }
  data_dim += indices_innerest_dim_;

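  // Step 4: the remaining trailing data dimensions (the gathered slice) map to
  // the trailing output dimensions.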
  for (uint32_t i = 0; i < static_cast<uint32_t>(data.Rank() - data_dim); i++) {
    shader.MainFunctionBody() << "  " << data.IndicesSet("data_indices", data_dim + i, output.IndicesGet("output_indices", indices.Rank() - 1 + i)) << "\n";
  }

  shader.MainFunctionBody() << "  " << output.SetByOffset("global_idx", data.GetByIndices("data_indices"));

  return Status::OK();
}

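// Verifies that the leading num_batch_dimensions entries of the data and
// indices shapes match, as GatherND with batch_dims > 0 requires.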
Status CheckBatchDimensionsMatch(size_t num_batch_dimensions, const TensorShape& input_shape,
                                 const TensorShape& indices_shape) {
  ORT_RETURN_IF_NOT(
      num_batch_dimensions <= input_shape.NumDimensions() && num_batch_dimensions <= indices_shape.NumDimensions(),
      "Number of batch dimensions exceeds tensor rank. ", "Batch dimension count: ", num_batch_dimensions,
      ", input tensor rank: ", input_shape.NumDimensions(), ", indices tensor rank: ", indices_shape.NumDimensions());

  for (size_t batch_dimension_idx = 0; batch_dimension_idx < num_batch_dimensions; ++batch_dimension_idx) {
    ORT_RETURN_IF_NOT(
        input_shape[batch_dimension_idx] == indices_shape[batch_dimension_idx],
        "Batch dimensions differ at index ", batch_dimension_idx, ": ",
        input_shape[batch_dimension_idx], " != ", indices_shape[batch_dimension_idx]);
  }

  return Status::OK();
}

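// GatherND: output shape is indices.shape[:-1] ++ data.shape[batch_dims + indices.shape[-1]:].
// For example, data of shape [2, 3, 4] with indices of shape [2, 2] and
// batch_dims = 0 yields an output of shape [2, 4].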
Status GatherND::ComputeInternal(ComputeContext& context) const {
  const auto* input_tensor = context.Input(0);
  const TensorShape& input_shape = input_tensor->Shape();
  const auto* indices_tensor = context.Input(1);
  const TensorShape& indices_shape = indices_tensor->Shape();

  if (indices_shape.NumDimensions() == 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "indices tensor must have rank greater than 0");
  }

  auto indices_innerest_dim = indices_shape[indices_shape.NumDimensions() - 1];
  auto last_indices_dimension = batch_dims_ + indices_innerest_dim;
  if (last_indices_dimension > static_cast<int64_t>(input_shape.NumDimensions())) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "batch_dims + the last dimension of indices must not exceed the rank of the input tensor");
  }

  ORT_RETURN_IF_ERROR(CheckBatchDimensionsMatch(static_cast<size_t>(batch_dims_),
                                                input_shape, indices_shape));

  // Output shape: indices.shape[:-1] followed by data.shape[last_indices_dimension:].
  std::vector<int64_t> shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1);
  shape.insert(shape.end(), input_shape.GetDims().begin() + static_cast<size_t>(last_indices_dimension), input_shape.GetDims().end());
  auto output_tensor = context.Output(0, TensorShape(shape));
  uint32_t data_size = onnxruntime::narrow<uint32_t>(output_tensor->Shape().Size());
  if (data_size == 0) {
    // Nothing to gather for an empty output.
    return Status::OK();
  }

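  // Launch one invocation per output element; the dispatch size is rounded up
  // to whole workgroups covering data_size.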
  GatherNDProgram program{static_cast<uint32_t>(batch_dims_), static_cast<uint32_t>(indices_innerest_dim)};
  program
      .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank},
                  {indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank})
      .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .CacheHint(std::to_string(batch_dims_), std::to_string(indices_innerest_dim))
      .AddUniformVariables({{data_size}});
  return context.RunProgram(program);
}

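// Kernel registration. The ONNX GatherND spec constrains indices to int64.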
#define WEBGPU_GATHERND_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE)                                                  \
  ONNX_OPERATOR_KERNEL_EX(                                                                                            \
      OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider,                                                        \
      KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("indices", DataTypeImpl::GetTensorType<int64_t>()), \
      KERNEL_CLASS);

#define WEBGPU_GATHERND_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE)                       \
  ONNX_OPERATOR_VERSIONED_KERNEL_EX(                                                                                  \
      OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider,                                       \
      KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("indices", DataTypeImpl::GetTensorType<int64_t>()), \
      KERNEL_CLASS);

WEBGPU_GATHERND_VERSIONED_KERNEL(GatherND, 11, 11, GatherND, WebGpuSupportedNumberTypes())
WEBGPU_GATHERND_VERSIONED_KERNEL(GatherND, 12, 12, GatherND, WebGpuSupportedNumberTypes())
WEBGPU_GATHERND_KERNEL(GatherND, 13, GatherND, WebGpuSupportedNumberTypes())

}  // namespace webgpu
}  // namespace onnxruntime