Skip to content

Commit c29737d

Browse files
authored
[webgpu] use u32 to represent f16 in uniform (microsoft#25391)
### Description For f16 uniform variables, use u32 to bit-wise represent them. ### Motivation and Context Some devices supports f16 in shader/storage buffer, but not in uniform buffers. Dawn will set the f16_support to false for them. However, we don't necessarily have to use f16 in uniform. This change together with microsoft#25349 will enable using f16 models on some Android devices.
1 parent 131cf40 commit c29737d

File tree

3 files changed

+95
-35
lines changed

3 files changed

+95
-35
lines changed

onnxruntime/core/providers/webgpu/shader_helper.cc

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -491,16 +491,29 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& sha
491491
ss << ",";
492492
}
493493

494-
auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : "";
495-
ss << "\n " << alignment << name << ": ";
494+
// The actual variable type for the uniform variable depends on the data type (T) and length (N).
495+
//
496+
// For T in [i32, u32, f32]:
497+
// - If N == 1, the type is simply i32, u32, or f32.
498+
// - If 2 < N <= 4, the type is vecN<i32>, vecN<u32>, or vecN<f32> where N is the length.
499+
// - If N > 4, the type is array<vec4<T>, ceil(N / 4)>.
500+
//
501+
// For T is f16:
502+
// - If N == 1 or N == 2, the type is u32.
503+
// - If 2 < N <= 8, the type is vecX<u32> where X is ceil(N / 2).
504+
// - If N > 8, the type is array<vec4<u32>, X> where X is ceil(N / 8).
505+
//
506+
// Note: Using f16 type in uniforms is not generally supported on all devices. We use a u32 variable to represent
507+
// 2 f16 values.
508+
509+
if (data_type == ProgramUniformVariableDataType::Float16) {
510+
data_type = ProgramUniformVariableDataType::Uint32; // f16 is represented as u32
511+
length = (length + 1) / 2; // each u32 can hold 2 f16 values
512+
}
513+
ss << "\n " << name << ": ";
496514
if (length > 4) {
497-
if (data_type == ProgramUniformVariableDataType::Float16) {
498-
size_t array_size = (length + 7) / 8;
499-
ss << "array<mat2x4<" << data_type << ">, " << array_size << ">";
500-
} else {
501-
size_t array_size = (length + 3) / 4;
502-
ss << "array<vec4<" << data_type << ">, " << array_size << ">";
503-
}
515+
size_t array_size = (length + 3) / 4;
516+
ss << "array<vec4<" << data_type << ">, " << array_size << ">";
504517
} else if (length > 1) {
505518
ss << "vec" << length << "<" << data_type << ">";
506519
} else {

onnxruntime/core/providers/webgpu/shader_variable.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,34 @@ template <typename TIdx,
1717
typename TRank,
1818
typename = std::enable_if_t<std::is_same_v<TRank, int> || std::is_same_v<TRank, size_t>>>
1919
std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) {
20-
// "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20.
21-
if (var.rfind("uniforms.", 0) == 0) {
22-
if (rank > 4) {
23-
if constexpr (std::is_integral_v<TIdx>) {
24-
if (is_f16) {
25-
return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]");
20+
if (var.starts_with("uniforms.")) {
21+
if (is_f16) {
22+
if (rank > 8) {
23+
// array<vec4<u32>, N>
24+
if constexpr (std::is_integral_v<TIdx>) {
25+
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, "[", idx / 8, "][", (idx % 8) / 2, "])[", (idx % 8) % 2, "]");
2626
} else {
27-
return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]");
27+
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, "[(", idx, ") / 8][((", idx, ") % 8) / 2])[((", idx, ") % 8) % 2]");
28+
}
29+
} else if (rank > 2) {
30+
// vecN<u32>
31+
if constexpr (std::is_integral_v<TIdx>) {
32+
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, "[", idx / 2, "])[", idx % 2, "]");
33+
} else {
34+
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, "[(", idx, ") / 2])[(", idx, ") % 2]");
2835
}
2936
} else {
30-
if (is_f16) {
31-
return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]");
37+
// u32
38+
if constexpr (std::is_integral_v<TIdx>) {
39+
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, ")[", idx % 2, "]");
40+
} else {
41+
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, ")[(", idx, ") % 2]");
42+
}
43+
}
44+
} else {
45+
if (rank > 4) {
46+
if constexpr (std::is_integral_v<TIdx>) {
47+
return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]");
3248
} else {
3349
return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]");
3450
}

onnxruntime/core/providers/webgpu/webgpu_context.cc

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -373,26 +373,57 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
373373
continue;
374374
}
375375

376-
bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16;
377-
378-
size_t element_size = ProgramUniformVariableDataTypeSize[static_cast<int>(uniform.data_type)];
376+
// Calculate the size and alignment of the uniform variable.
377+
//
379378
// https://www.w3.org/TR/WGSL/#alignof
380-
size_t base_alignment = is_f16
381-
? (length > 4 ? 16 : length > 2 ? 8
382-
: length * element_size)
383-
: (length > 2 ? 16 : length * element_size);
384-
size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16;
385-
386-
current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment;
379+
//
380+
// For f16:
381+
// - length > 8 : array<vec4<u32>, N> (align 16) (size 16 * N, N = ceil(length / 8))
382+
// - length == 7 or 8: vec4<u32> (align 16) (size 16)
383+
// - length == 5 or 6: vec3<u32> (align 16) (size 12)
384+
// - length == 3 or 4: vec2<u32> (align 8) (size 8)
385+
// - length == 1 or 2: u32 (align 4) (size 4)
386+
//
387+
// For other types (i32, u32, f32):
388+
// - length > 4 : array<vec4<T>, N> (align 16) (size 16 * N, N = ceil(length / 4))
389+
// - length == 4 : vec4<T> (align 16) (size 16)
390+
// - length == 3 : vec3<T> (align 16) (size 12)
391+
// - length == 2 : vec2<T> (align 8) (size 8)
392+
// - length == 1 : T (align 4) (size 4)
393+
//
394+
395+
const bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16;
396+
397+
size_t variable_alignment = 4; // default alignment for scalar types
398+
size_t variable_size = 4; // default size for scalar types
399+
400+
if (is_f16) {
401+
if (length > 6) {
402+
variable_alignment = 16;
403+
variable_size = 16 * ((length + 7) / 8);
404+
} else if (length > 4) {
405+
variable_alignment = 16;
406+
variable_size = 12;
407+
} else if (length > 2) {
408+
variable_alignment = 8;
409+
variable_size = 8;
410+
}
411+
} else {
412+
if (length > 3) {
413+
variable_alignment = 16;
414+
variable_size = 16 * ((length + 3) / 4);
415+
} else if (length > 2) {
416+
variable_alignment = 16;
417+
variable_size = 12;
418+
} else if (length > 1) {
419+
variable_alignment = 8;
420+
variable_size = 8;
421+
}
422+
}
423+
current_offset = (current_offset + variable_alignment - 1) / variable_alignment * variable_alignment;
387424
uniform_and_offsets.emplace_back(uniform, current_offset);
388425

389-
// For non-float16 type, when length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
390-
// N = ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N * SizeOf(vec4<i32|u32|f32>).
391-
// For float16 type, when length > 4, the uniform variable is of type array<mat2x4<f16>,N>, where
392-
// N = ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte length is N * SizeOf(mat2x4<f16>).
393-
size_t element_per_struct = is_f16 ? 8 : 4;
394-
current_offset +=
395-
length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size;
426+
current_offset += variable_size;
396427
}
397428

398429
// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set

0 commit comments

Comments
 (0)