@@ -373,26 +373,57 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
373373 continue ;
374374 }
375375
376- bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16;
377-
378- size_t element_size = ProgramUniformVariableDataTypeSize[static_cast <int >(uniform.data_type )];
376+ // Calculate the size and alignment of the uniform variable.
377+ //
379378 // https://www.w3.org/TR/WGSL/#alignof
380- size_t base_alignment = is_f16
381- ? (length > 4 ? 16 : length > 2 ? 8
382- : length * element_size)
383- : (length > 2 ? 16 : length * element_size);
384- size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16 ;
385-
386- current_offset = (current_offset + base_alignment - 1 ) / base_alignment * base_alignment;
379+ //
380+ // For f16:
381+ // - length > 8 : array<vec4<u32>, N> (align 16) (size 16 * N, N = ceil(length / 8))
382+ // - length == 7 or 8: vec4<u32> (align 16) (size 16)
383+ // - length == 5 or 6: vec3<u32> (align 16) (size 12)
384+ // - length == 3 or 4: vec2<u32> (align 8) (size 8)
385+ // - length == 1 or 2: u32 (align 4) (size 4)
386+ //
387+ // For other types (i32, u32, f32):
388+ // - length > 4 : array<vec4<T>, N> (align 16) (size 16 * N, N = ceil(length / 4))
389+ // - length == 4 : vec4<T> (align 16) (size 16)
390+ // - length == 3 : vec3<T> (align 16) (size 12)
391+ // - length == 2 : vec2<T> (align 8) (size 8)
392+ // - length == 1 : T (align 4) (size 4)
393+ //
394+
395+ const bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16;
396+
397+ size_t variable_alignment = 4 ; // default alignment for scalar types
398+ size_t variable_size = 4 ; // default size for scalar types
399+
400+ if (is_f16) {
401+ if (length > 6 ) {
402+ variable_alignment = 16 ;
403+ variable_size = 16 * ((length + 7 ) / 8 );
404+ } else if (length > 4 ) {
405+ variable_alignment = 16 ;
406+ variable_size = 12 ;
407+ } else if (length > 2 ) {
408+ variable_alignment = 8 ;
409+ variable_size = 8 ;
410+ }
411+ } else {
412+ if (length > 3 ) {
413+ variable_alignment = 16 ;
414+ variable_size = 16 * ((length + 3 ) / 4 );
415+ } else if (length > 2 ) {
416+ variable_alignment = 16 ;
417+ variable_size = 12 ;
418+ } else if (length > 1 ) {
419+ variable_alignment = 8 ;
420+ variable_size = 8 ;
421+ }
422+ }
423+ current_offset = (current_offset + variable_alignment - 1 ) / variable_alignment * variable_alignment;
387424 uniform_and_offsets.emplace_back (uniform, current_offset);
388425
389- // For non-float16 type, when length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
390- // N = ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N * SizeOf(vec4<i32|u32|f32>).
391- // For float16 type, when length > 4, the uniform variable is of type array<mat2x4<f16>,N>, where
392- // N = ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte length is N * SizeOf(mat2x4<f16>).
393- size_t element_per_struct = is_f16 ? 8 : 4 ;
394- current_offset +=
395- length > 4 ? (length + element_per_struct - 1 ) / element_per_struct * struct_size : length * element_size;
426+ current_offset += variable_size;
396427 }
397428
398429 // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
0 commit comments