Skip to content

Commit 7b31a10

Browse files
authored
Add patch for WebGPU on Android to handle fp16 in uniforms (microsoft#25349)
### Motivation and Context Android devices (like S24) doesn't seem to allow fp16 in uniforms so the WebGPU EP has to manually handle passing an fp32 in the uniform and converting to fp16 before using.
1 parent d5d3b28 commit 7b31a10

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

cmake/external/onnxruntime_external_deps.cmake

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,11 @@ if (onnxruntime_USE_WEBGPU)
747747
#
748748
# - (private) Fulfill the BinSkim requirements
749749
# Some build warnings are not allowed to be disabled in project level.
750-
${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_binskim.patch)
750+
${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_binskim.patch &&
751+
752+
# Android devices doesn't seem to allow fp16 in uniforms so the WebGPU EP has to manually handle passing an fp32
753+
# in the uniform and converting to fp16 before using.
754+
${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/uniform_and_storage_buffer_16_bit_access.patch)
751755

752756
onnxruntime_fetchcontent_declare(
753757
dawn
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
diff --git a/src/dawn/native/vulkan/DeviceVk.cpp b/src/dawn/native/vulkan/DeviceVk.cpp
2+
index c01d64e40f..0f1f4beae4 100644
3+
--- a/src/dawn/native/vulkan/DeviceVk.cpp
4+
+++ b/src/dawn/native/vulkan/DeviceVk.cpp
5+
@@ -464,13 +464,15 @@ ResultOrError<VulkanDeviceKnobs> Device::CreateDevice(VkPhysicalDevice vkPhysica
6+
DAWN_ASSERT(usedKnobs.HasExt(DeviceExt::ShaderFloat16Int8) &&
7+
mDeviceInfo.shaderFloat16Int8Features.shaderFloat16 == VK_TRUE &&
8+
usedKnobs.HasExt(DeviceExt::_16BitStorage) &&
9+
- mDeviceInfo._16BitStorageFeatures.storageBuffer16BitAccess == VK_TRUE &&
10+
+ mDeviceInfo._16BitStorageFeatures.storageBuffer16BitAccess == VK_TRUE /*&&
11+
mDeviceInfo._16BitStorageFeatures.uniformAndStorageBuffer16BitAccess ==
12+
- VK_TRUE);
13+
+ VK_TRUE*/);
14+
15+
usedKnobs.shaderFloat16Int8Features.shaderFloat16 = VK_TRUE;
16+
usedKnobs._16BitStorageFeatures.storageBuffer16BitAccess = VK_TRUE;
17+
- usedKnobs._16BitStorageFeatures.uniformAndStorageBuffer16BitAccess = VK_TRUE;
18+
+ if (mDeviceInfo._16BitStorageFeatures.uniformAndStorageBuffer16BitAccess == VK_TRUE) {
19+
+ usedKnobs._16BitStorageFeatures.uniformAndStorageBuffer16BitAccess = VK_TRUE;
20+
+ }
21+
if (mDeviceInfo._16BitStorageFeatures.storageInputOutput16 == VK_TRUE) {
22+
usedKnobs._16BitStorageFeatures.storageInputOutput16 = VK_TRUE;
23+
}
24+
diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
25+
index a324c101ed..8d64da750f 100644
26+
--- a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
27+
+++ b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp
28+
@@ -269,8 +269,9 @@ void PhysicalDevice::InitializeSupportedFeaturesImpl() {
29+
if (mDeviceInfo.HasExt(DeviceExt::ShaderFloat16Int8) &&
30+
mDeviceInfo.HasExt(DeviceExt::_16BitStorage) &&
31+
mDeviceInfo.shaderFloat16Int8Features.shaderFloat16 == VK_TRUE &&
32+
- mDeviceInfo._16BitStorageFeatures.storageBuffer16BitAccess == VK_TRUE &&
33+
- mDeviceInfo._16BitStorageFeatures.uniformAndStorageBuffer16BitAccess == VK_TRUE) {
34+
+ mDeviceInfo._16BitStorageFeatures.storageBuffer16BitAccess == VK_TRUE /*&&
35+
+ WebGPU EP needs to ensure we don't put fp16 values in uniforms when this patch is applied.
36+
+ mDeviceInfo._16BitStorageFeatures.uniformAndStorageBuffer16BitAccess == VK_TRUE*/) {
37+
// ONNX Runtime Patch: enable shaderF16 on all devices.
38+
EnableFeature(Feature::ShaderF16);
39+
shaderF16Enabled = true;

tools/ci_build/github/android/default_full_aar_build_settings.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
"--build_shared_lib",
1717
"--use_nnapi",
1818
"--use_xnnpack",
19+
"--use_webgpu",
20+
"--cmake_extra_defines=CMAKE_CXX_SCAN_FOR_MODULES=OFF",
1921
"--skip_tests"
2022
]
2123
}

0 commit comments

Comments
 (0)