Commit b67983c

[WebNN] Support RotaryEmbedding op (microsoft#23283)
WebNN doesn't provide a dedicated op for RotaryEmbedding. Instead, we implement it using a combination of WebNN ops. The decomposed graph is based on the DML EP implementation at: onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
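For reference (this note is not part of the change, and the notation below is ours, not the commit's): in the non-interleaved case, with d = rotary_embedding_dim, each head vector x at position p is rotated pairwise for j < d/2 as

$$
y_j = x_j \cos\theta_{p,j} - x_{j+d/2}\,\sin\theta_{p,j}, \qquad
y_{j+d/2} = x_{j+d/2}\,\cos\theta_{p,j} + x_j\,\sin\theta_{p,j},
$$

where the cos/sin values are rows of cos_cache/sin_cache selected by position_ids; interleaved mode applies the same rotation to adjacent pairs (x_{2j}, x_{2j+1}). In the decomposition below this corresponds to the multiply with the gathered CosCache, the swapped-halves multiply with the gathered SinCache scaled by a {-1, 1} sign constant, and the final Add.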
1 parent c07afd3 commit b67983c

File tree: 5 files changed, +323 -0 lines changed

js/web/docs/webnn-operators.md

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
 | Relu | ai.onnx(7-12, 13, 14+) | relu ||| |
 | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape ||| Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
 | Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d ||| Only supports 4-D input, antialias == 0, exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant |
+| RotaryEmbedding | com.microsoft(1+) | add, concat, gather, mul, reshape, split ||| |
 | ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements ||| Only supports 'reduction' == 'none' |
 | ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND ||| Only supports 'reduction' == 'none' |
 | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice ||| |

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 3 additions & 0 deletions
@@ -194,6 +194,8 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
                                                        const WebnnDeviceType device_type,
                                                        const emscripten::val& wnn_limits,
                                                        const logging::Logger& logger);
+// TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops,
+// we need to check the support of the decomposed ops.
 static const InlinedHashMap<std::string, std::string> op_map = {
     {"Abs", "abs"},
     {"Add", "add"},
@@ -273,6 +275,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
     {"Relu", "relu"},
     {"Reshape", "reshape"},
     {"Resize", "resample2d"},
+    {"RotaryEmbedding", "gather"},
     {"ScatterElements", "scatterElements"},
     {"ScatterND", "scatterND"},
     {"Shape", "slice"},
Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "base_op_builder.h"

// WebNN doesn't provide a dedicated op for RotaryEmbedding. Instead, we implement it by using a
// combination of WebNN ops. The decomposed graph is referenced from DML EP at:
// onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
/*
Input CosCache PositionIds SinCache
| | | |
| | +--------+-----------+ |
Split | | | |
| | Gather Gather
+-------+ | | |
| | | |
| Identity----------+ | |
| | | | |
| | | | |
| --Split-- | | |
| \ / | +-----------------+ |
| \ / | | |
| \ / Mul |
| \ / | |
| X | |
| / \ | |
| / \ | |
| Join | |
| | | |
| | +---------------------------------------------------------+
| | | |
| Mul |
| | |
| +-----+ +------+
| | |
| Add
| |
+-------------+ |
| |
Join
*/
namespace onnxruntime {
namespace webnn {

class RotaryEmbeddingOpBuilder : public BaseOpBuilder {
  // Add operator related.
 private:
  Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                               const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

  // Operator support related.
 private:
  bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                         const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
};

Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                                       const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();
  int32_t input_data_type;
  ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_data_type, logger), "Cannot get input type");
  std::vector<int64_t> input_shape;
  std::vector<int64_t> position_ids_shape;
  std::vector<int64_t> cos_cache_shape;
  ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
  ORT_RETURN_IF_NOT(GetShape(*input_defs[1], position_ids_shape, logger), "Cannot get position_ids shape");
  ORT_RETURN_IF_NOT(GetShape(*input_defs[2], cos_cache_shape, logger), "Cannot get cos_cache shape");
  const bool input_is_4d = input_shape.size() == 4;
  // When position_ids is a 1D tensor, it represents the start offset for each sequence.
  const bool position_ids_is_offset = position_ids_shape.size() == 1;

  emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
  emscripten::val position_ids = model_builder.GetOperand(input_defs[1]->Name());
  emscripten::val cos_cache = model_builder.GetOperand(input_defs[2]->Name());
  emscripten::val sin_cache = model_builder.GetOperand(input_defs[3]->Name());

  const auto node_name = node.Name();
  emscripten::val wnn_builder = model_builder.GetBuilder();

  NodeAttrHelper helper(node);
  const bool interleaved = gsl::narrow_cast<bool>(helper.Get("interleaved", 0));
  uint32_t num_heads = helper.Get("num_heads", 0);
  uint32_t rotary_embedding_dim = helper.Get("rotary_embedding_dim", 0);

  // The input has either a 3D tensor shape (batch_size, sequence_length, hidden_size) or a
  // 4D tensor shape (batch_size, num_heads, sequence_length, head_size).
  const uint32_t batch_size = static_cast<uint32_t>(input_shape[0]);
  const uint32_t sequence_length = input_is_4d ? static_cast<uint32_t>(input_shape[2])
                                               : static_cast<uint32_t>(input_shape[1]);
  const uint32_t hidden_size = input_is_4d ? static_cast<uint32_t>(input_shape[1] * input_shape[3])
                                           : static_cast<uint32_t>(input_shape[2]);
  const uint32_t head_size = num_heads == 0 ? static_cast<uint32_t>(cos_cache_shape[1]) * 2
                                            : hidden_size / num_heads;
  if (num_heads == 0) {
    num_heads = hidden_size / head_size;
  }
  if (rotary_embedding_dim == 0) {
    rotary_embedding_dim = head_size;
  }

  // First ensure the input has shape (batch_size, num_heads, sequence_length, head_size).
  if (!input_is_4d) {
    const std::vector<uint32_t> new_shape{batch_size, num_heads, sequence_length, head_size};
    emscripten::val reshape_input_options = emscripten::val::object();
    reshape_input_options.set("label", node_name + "_reshape_input");
    input = wnn_builder.call<emscripten::val>(
        "reshape", input, emscripten::val::array(new_shape), reshape_input_options);
  }

  // Split the input to perform the rotary embedding only on a subregion of the tensor if needed.
  // The split inputs will be joined back together at the end.
  emscripten::val partial_input0 = input;
  emscripten::val partial_input1 = emscripten::val::undefined();
  if (head_size != rotary_embedding_dim) {
    const std::vector<uint32_t> splits{rotary_embedding_dim, head_size - rotary_embedding_dim};
    emscripten::val split_input_options = emscripten::val::object();
    split_input_options.set("label", node_name + "_split_input");
    split_input_options.set("axis", 3);
    emscripten::val split = wnn_builder.call<emscripten::val>(
        "split", input, emscripten::val::array(splits), split_input_options);
    partial_input0 = split[0];
    partial_input1 = split[1];
  }

  // Split the partial input0 data into 2 equal parts.
  // Firstly reshape the partial input0.
  const std::vector<uint32_t> new_partial_input0_shape =
      interleaved ? std::vector<uint32_t>({batch_size, sequence_length, num_heads, rotary_embedding_dim / 2, 2})
                  : std::vector<uint32_t>({batch_size, sequence_length, num_heads, 2, rotary_embedding_dim / 2});
  emscripten::val reshape_partial_input0_options = emscripten::val::object();
  reshape_partial_input0_options.set("label", node_name + "_reshape_partial_input0");
  partial_input0 = wnn_builder.call<emscripten::val>(
      "reshape", partial_input0, emscripten::val::array(new_partial_input0_shape), reshape_partial_input0_options);
  // Split partial input0.
  const int split_axis = interleaved ? 4 : 3;
  emscripten::val split_partial_input0_options = emscripten::val::object();
  split_partial_input0_options.set("label", node_name + "_split_partial_input0");
  split_partial_input0_options.set("axis", split_axis);
  emscripten::val split_partial_input0 = wnn_builder.call<emscripten::val>(
      "split", partial_input0, 2, split_partial_input0_options);

  // Swap the two halves and join them together.
  emscripten::val concat_partial_input0_options = emscripten::val::object();
  concat_partial_input0_options.set("label", node_name + "_concat_partial_input0");
  emscripten::val concated_partial_input0 = wnn_builder.call<emscripten::val>(
      "concat", split_partial_input0.call<emscripten::val>("reverse"), split_axis, concat_partial_input0_options);

  if (position_ids_is_offset) {
    // We generate a sequence from 0 to sequence_length and add the offset to it.
    const std::vector<uint32_t> position_ids_range_shape = {1, sequence_length};
    emscripten::val position_ids_range_buffer = emscripten::val::global("BigInt64Array").new_(sequence_length);
    for (uint32_t i = 0; i < sequence_length; i++) {
      position_ids_range_buffer.set(i, emscripten::val::global("BigInt")(i));
    }
    emscripten::val position_ids_range_desc = emscripten::val::object();
    position_ids_range_desc.set("shape", emscripten::val::array(position_ids_range_shape));
    position_ids_range_desc.set("dimensions", emscripten::val::array(position_ids_range_shape));
    position_ids_range_desc.set("dataType", emscripten::val("int64"));
    emscripten::val position_ids_range = wnn_builder.call<emscripten::val>(
        "constant", position_ids_range_desc, position_ids_range_buffer);
    // Add the offset to the sequence.
    emscripten::val position_ids_add_range_options = emscripten::val::object();
    position_ids_add_range_options.set("label", node_name + "_position_ids_add_range");
    position_ids = wnn_builder.call<emscripten::val>(
        "add", position_ids, position_ids_range, position_ids_add_range_options);
  }

  // Gather the cosine/sine values based on the position_ids.
  emscripten::val gather_cos_sin_options = emscripten::val::object();
  gather_cos_sin_options.set("label", node_name + "_gather_cos_sin");
  gather_cos_sin_options.set("axis", 0);
  emscripten::val gather_cos = wnn_builder.call<emscripten::val>(
      "gather", cos_cache, position_ids, gather_cos_sin_options);
  emscripten::val gather_sin = wnn_builder.call<emscripten::val>(
      "gather", sin_cache, position_ids, gather_cos_sin_options);

  // After gathering cosine/sine, reshape and broadcast them to match the number of heads of the input data.
  const std::vector<uint32_t> reshaped_cos_sin_shape =
      interleaved ? std::vector<uint32_t>({batch_size, sequence_length, 1, rotary_embedding_dim / 2, 1})
                  : std::vector<uint32_t>({batch_size, sequence_length, 1, 1, rotary_embedding_dim / 2});
  emscripten::val reshape_gather_cos_sin_options = emscripten::val::object();
  reshape_gather_cos_sin_options.set("label", node_name + "_reshape_gather_cos_sin");
  gather_cos = wnn_builder.call<emscripten::val>(
      "reshape", gather_cos, emscripten::val::array(reshaped_cos_sin_shape), reshape_gather_cos_sin_options);
  gather_sin = wnn_builder.call<emscripten::val>(
      "reshape", gather_sin, emscripten::val::array(reshaped_cos_sin_shape), reshape_gather_cos_sin_options);

  // Multiply the non-rotated data with the cosine and the rotated data with the sine.
  emscripten::val mul_cos_options = emscripten::val::object();
  mul_cos_options.set("label", node_name + "_mul_cos");
  emscripten::val mul_cos = wnn_builder.call<emscripten::val>(
      "mul", partial_input0, gather_cos, mul_cos_options);
  emscripten::val mul_sin_options = emscripten::val::object();
  mul_sin_options.set("label", node_name + "_mul_sin");
  emscripten::val mul_sin = wnn_builder.call<emscripten::val>(
      "mul", concated_partial_input0, gather_sin, mul_sin_options);

  // Create a vector that contains the sign values {-1, 1}.
  emscripten::val sign_buffer = emscripten::val::undefined();
  const std::vector<uint32_t> sign_shape = interleaved ? std::vector<uint32_t>({1, 1, 1, 2})
                                                       : std::vector<uint32_t>({1, 1, 2, 1});
  emscripten::val sign_constant_desc = emscripten::val::object();
  sign_constant_desc.set("shape", emscripten::val::array(sign_shape));
  sign_constant_desc.set("dimensions", emscripten::val::array(sign_shape));
  ORT_RETURN_IF_NOT(SetWebnnDataType(sign_constant_desc, input_data_type), "Unsupported data type");
  if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
    sign_buffer = emscripten::val::global("Float32Array").new_(2);
    sign_buffer.set(0, -1.0f);
    sign_buffer.set(1, 1.0f);
  } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
    sign_buffer = emscripten::val::global("Uint16Array").new_(2);
    sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f));
    sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f));
  } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input data type: ", input_data_type);
  }
  emscripten::val sign_constant = wnn_builder.call<emscripten::val>("constant", sign_constant_desc, sign_buffer);

  // Multiply the broadcasted sign values with the rotated input.
  emscripten::val mul_sign_options = emscripten::val::object();
  mul_sign_options.set("label", node_name + "_mul_sign");
  mul_sin = wnn_builder.call<emscripten::val>("mul", mul_sin, sign_constant, mul_sign_options);

  // Reshape mul_cos and mul_sin to (batch_size, sequence_length, num_heads, rotary_embedding_dim).
  const std::vector<uint32_t> reshaped_mul_cos_sin_shape =
      {batch_size, sequence_length, num_heads, rotary_embedding_dim};
  emscripten::val reshape_mul_cos_sin_options = emscripten::val::object();
  reshape_mul_cos_sin_options.set("label", node_name + "_reshape_mul_cos_sign");
  mul_cos = wnn_builder.call<emscripten::val>(
      "reshape", mul_cos, emscripten::val::array(reshaped_mul_cos_sin_shape), reshape_mul_cos_sin_options);
  mul_sin = wnn_builder.call<emscripten::val>(
      "reshape", mul_sin, emscripten::val::array(reshaped_mul_cos_sin_shape), reshape_mul_cos_sin_options);

  // Add the multiplied cos and sin values together.
  emscripten::val add_mul_cos_sin_options = emscripten::val::object();
  add_mul_cos_sin_options.set("label", node_name + "_add_mul_cos_sin");
  emscripten::val output = wnn_builder.call<emscripten::val>(
      "add", mul_cos, mul_sin, add_mul_cos_sin_options);

  // Join the added values with the rest of the input.
  if (head_size != rotary_embedding_dim) {
    emscripten::val concat_back_input_options = emscripten::val::object();
    concat_back_input_options.set("label", node_name + "_concat_back_input");
    emscripten::val concat_inputs = emscripten::val::array();
    concat_inputs.call<void>("push", output);
    concat_inputs.call<void>("push", partial_input1);
    output = wnn_builder.call<emscripten::val>("concat", concat_inputs, 3, concat_back_input_options);
  }

  // Reshape the output to the original shape. The output shape is the same as the input shape.
  const std::vector<uint32_t> output_shape = GetVecUint32FromVecInt64(input_shape);
  emscripten::val reshape_output_options = emscripten::val::object();
  reshape_output_options.set("label", node_name + "_reshape_output");
  output = wnn_builder.call<emscripten::val>(
      "reshape", output, emscripten::val::array(output_shape), reshape_output_options);

  model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
  return Status::OK();
}

// Operator support related.
bool RotaryEmbeddingOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                                                 const WebnnDeviceType /* device_type */,
                                                 const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();
  std::vector<int64_t> input_shape;
  std::vector<int64_t> cos_cache_shape;
  if (!GetShape(*input_defs[0], input_shape, logger)) return false;
  if (!GetShape(*input_defs[2], cos_cache_shape, logger)) return false;
  const auto input_size = input_shape.size();
  if (input_size != 3 && input_size != 4) {
    LOGS(logger, VERBOSE) << "RotaryEmbedding only supports 3D or 4D input shape, input is " << input_size << "D shape";
    return false;
  }

  NodeAttrHelper helper(node);
  const int is_packed_batching = helper.Get("is_packed_batching", 0);
  const int num_heads = helper.Get("num_heads", 0);
  const int rotary_embedding_dim = helper.Get("rotary_embedding_dim", 0);

  const auto sequence_length = input_size == 4 ? input_shape[2] : input_shape[1];
  if (is_packed_batching == 0 && sequence_length > cos_cache_shape[0]) {
    LOGS(logger, VERBOSE) << "RotaryEmbedding: updating cos_cache and sin_cache is not currently supported";
    return false;
  }

  if (input_size == 4 && num_heads != 0 && num_heads != input_shape[1]) {
    LOGS(logger, VERBOSE) << "RotaryEmbedding: when input has 4 dimensions, num_heads must be 0 or have the same value "
                             "as the second dimension of the input";
    return false;
  }

  if (rotary_embedding_dim > 0 && num_heads == 0) {
    LOGS(logger, VERBOSE) << "RotaryEmbedding: num_heads must be provided if rotary_embedding_dim is specified";
    return false;
  }

  return true;
}

void CreateRotaryEmbeddingOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.builders.push_back(std::make_unique<RotaryEmbeddingOpBuilder>());
  op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
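
To make the decomposed graph above easier to verify, here is a minimal standalone scalar sketch of the non-interleaved case it computes (batch_size == 1, head_size == rotary_embedding_dim, so no pass-through region). This is illustrative only and not ONNX Runtime code; the cache contents use the common 10000-base RoPE schedule purely as example data.

// Scalar reference for non-interleaved RotaryEmbedding (illustrative sketch only).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int seq_len = 2, num_heads = 1, head_size = 4, half = head_size / 2;
  // Input laid out as (sequence_length, num_heads, head_size); values are arbitrary example data.
  std::vector<float> x = {0.1f, 0.2f, 0.3f, 0.4f,
                          0.5f, 0.6f, 0.7f, 0.8f};
  // cos_cache / sin_cache laid out as (max_position, head_size / 2), filled with an
  // assumed 10000-base schedule; in the real op these arrive as precomputed inputs.
  std::vector<float> cos_cache(seq_len * half), sin_cache(seq_len * half);
  for (int p = 0; p < seq_len; ++p) {
    for (int j = 0; j < half; ++j) {
      const float theta = p / std::pow(10000.0f, 2.0f * j / head_size);
      cos_cache[p * half + j] = std::cos(theta);
      sin_cache[p * half + j] = std::sin(theta);
    }
  }
  std::vector<float> y(x.size());
  for (int p = 0; p < seq_len; ++p) {
    for (int h = 0; h < num_heads; ++h) {
      const float* in = &x[(p * num_heads + h) * head_size];
      float* out = &y[(p * num_heads + h) * head_size];
      for (int j = 0; j < half; ++j) {
        const float c = cos_cache[p * half + j], s = sin_cache[p * half + j];
        // Same result as mul_cos + sign * mul_sin in the builder:
        // the first half uses sign -1, the second half uses sign +1.
        out[j] = in[j] * c - in[j + half] * s;
        out[j + half] = in[j + half] * c + in[j] * s;
      }
    }
  }
  for (float v : y) std::printf("%f\n", v);
  return 0;
}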

onnxruntime/core/providers/webnn/builders/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
@@ -196,6 +196,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
     CreateResizeOpBuilder("Resize", op_registrations);
   }
 
+  { // RotaryEmbedding
+    CreateRotaryEmbeddingOpBuilder("RotaryEmbedding", op_registrations);
+  }
+
   { // ScatterElements
     CreateScatterElementsOpBuilder("ScatterElements", op_registrations);
   }
