Skip to content

Commit f58f7eb

Browse files
authored
[webgpu] Expand Unsqueeze version to 23 (microsoft#25858)
### Description The phi4 mini model in Edge uses ai.onnx v21. Without this change, a `MemcpyToHost` is inserted, which slows down generation.
1 parent 0de1c01 commit f58f7eb

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,31 @@ namespace webgpu {
1111
ONNX_OPERATOR_KERNEL_EX(
1212
Unsqueeze,
1313
kOnnxDomain,
14-
13,
14+
23,
15+
kWebGpuExecutionProvider,
16+
(*KernelDefBuilder::Create())
17+
.TypeConstraint("T", WebGpuSupportedNumberTypes())
18+
.TypeConstraint("axes", DataTypeImpl::GetTensorType<int64_t>())
19+
.Alias(0, 0)
20+
.InputMemoryType(OrtMemTypeCPU, 1),
21+
Unsqueeze);
22+
23+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
24+
Unsqueeze,
25+
kOnnxDomain,
26+
21, 22,
27+
kWebGpuExecutionProvider,
28+
(*KernelDefBuilder::Create())
29+
.TypeConstraint("T", WebGpuSupportedNumberTypes())
30+
.TypeConstraint("axes", DataTypeImpl::GetTensorType<int64_t>())
31+
.Alias(0, 0)
32+
.InputMemoryType(OrtMemTypeCPU, 1),
33+
Unsqueeze);
34+
35+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
36+
Unsqueeze,
37+
kOnnxDomain,
38+
13, 20,
1539
kWebGpuExecutionProvider,
1640
(*KernelDefBuilder::Create())
1741
.TypeConstraint("T", WebGpuSupportedNumberTypes())

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13,
249249

250250
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze);
251251
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze);
252-
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Unsqueeze);
252+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Unsqueeze);
253+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Unsqueeze);
254+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Unsqueeze);
253255

254256
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 15, Where);
255257
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, Where);
@@ -548,7 +550,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
548550

549551
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze)>,
550552
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze)>,
551-
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Unsqueeze)>,
553+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Unsqueeze)>,
554+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Unsqueeze)>,
555+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Unsqueeze)>,
552556

553557
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMin)>,
554558
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMin)>,

0 commit comments

Comments
 (0)