@@ -2247,71 +2247,51 @@ DEFN_ARITH_OPERATIONS(double)
22472247DEFN_ARITH_OPERATIONS (half )
22482248#endif // defined(cl_khr_fp16)
22492249
2250- #define DEFN_WORK_GROUP_REDUCE (func , type_abbr , type , op , identity ) \
2251- type __builtin_IB_WorkGroupReduce_##func##_##type_abbr(type X) \
2252- { /* Work-group reduction of X: reduce inside each subgroup, stage one partial per subgroup in scratch, then combine the partials. */ \
2253- GET_MEMPOOL_PTR(scratch, type, true, 0) /* provides `scratch`; NOTE(review): presumably work-group shared memory - confirm against GET_MEMPOOL_PTR */ \
2254- uint sg_id = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupId, , )(); \
2255- uint sg_lid = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupLocalInvocationId, , )(); \
2256- uint sg_size = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupMaxSize, , )(); /* note: this is the maximum subgroup size, not the current one */ \
2257- /* number of values to reduce: initially one partial result per subgroup */ \
2258- uint values_num = SPIRV_BUILTIN_NO_OP (BuiltInNumSubgroups , , )(); \
2259- \
2260- type sg_x = SPIRV_BUILTIN (Group ##func , _i32_i32_ ##type_abbr , )(Subgroup , GroupOperationReduce , X ); /* step 1: reduce X within each subgroup */ \
2261- if (sg_lid == 0 ) { /* lane 0 stores its subgroup's partial result at index sg_id */ \
2262- scratch [sg_id ] = sg_x ; \
2263- } \
2264- SPIRV_BUILTIN (ControlBarrier , _i32_i32_i32 , )(Workgroup , 0 , AcquireRelease | WorkgroupMemory ); /* partials written above are read below */ \
2265- \
2266- if (sg_size == 32 ) /* SIMD32 */ \
2267- { \
2268- if (sg_id == 0 ) /* only subgroup 0 combines the partials; each lane loads two of them (indices sg_lid and sg_lid + 32) */ \
2269- { \
2270- type low_data = sg_lid < values_num ? scratch [sg_lid ] : identity ; \
2271- type high_data = sg_lid + 32 < values_num ? scratch [sg_lid + 32 ] : identity ; \
2272- type reduce = op (low_data , high_data ); \
2273- sg_x = SPIRV_BUILTIN (Group ##func , _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, reduce); \
2274- if (sg_lid == 0) \
2275- { \
2276- scratch[0] = sg_x; /* final result is left in scratch[0] */ \
2277- } \
2278- } \
2279- SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); \
2280- } \
2281- else /* SIMD16 and SIMD8 */ \
2282- { \
2283- /* Log2(8) = 3, log2(16) = 4 */ \
2284- uint sg_size_shifts = sg_size == 16 ? 4 : 3 ; \
2285- /* global ID for work-items across all subgroups */ \
2286- uint global_id = (sg_id << sg_size_shifts ) + sg_lid ; \
2287- uint SWG_shift = 0 ; \
2288- /* With subgroup reduce it will only take maximally 3 steps to get reduction. */ \
2289- for (int i = 0 ; i < 2 ; ++ i ) \
2290- { \
2291- uint cntNeededSWG = (values_num >> sg_size_shifts ) + 1 ; \
2292- bool allowSWG = \
2293- /* Allow only those subgroups which will continue reduction */ \
2294- sg_id < cntNeededSWG && \
2295- /* Allow reduction if we have more than 1 value */ \
2296- values_num > 1 ; \
2297- \
2298- if (allowSWG ) \
2299- { \
2300- uint shift_global_id = global_id << SWG_shift ; /* partials of the previous pass are spaced (1 << SWG_shift) entries apart */ \
2301- type value = global_id < values_num ? scratch [shift_global_id ] : identity ; \
2302- sg_x = SPIRV_BUILTIN (Group ##func , _i32_i32_ ##type_abbr , )(Subgroup , GroupOperationReduce , value ); /* 2 & 3 step */ \
2303- SWG_shift += sg_size_shifts ; \
2304- if (sg_lid == 0 ) \
2305- { \
2306- scratch [sg_id << SWG_shift ] = sg_x ; \
2307- } \
2308- } \
2309- values_num = cntNeededSWG ; \
2310- /* barrier for work-items, with mem sync. */ \
2311- SPIRV_BUILTIN (ControlBarrier , _i32_i32_i32 , )(Workgroup , 0 , AcquireRelease | WorkgroupMemory ); \
2312- } \
2313- } \
2314- return scratch [0 ]; \
2250+ #define DEFN_WORK_GROUP_REDUCE (func , type_abbr , type , op , identity ) \
2251+ type __builtin_IB_WorkGroupReduce_##func##_##type_abbr(type X) \
2252+ { /* Work-group reduction of X: reduce inside each subgroup, stage one partial per subgroup in scratch, then combine the partials. */ \
2253+ type sg_x = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, X); /* reduce X within each subgroup */ \
2254+ GET_MEMPOOL_PTR(scratch, type, true, 0) /* provides `scratch`; NOTE(review): presumably work-group shared memory - confirm against GET_MEMPOOL_PTR */ \
2255+ uint sg_id = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupId, , )(); \
2256+ uint num_sg = SPIRV_BUILTIN_NO_OP(BuiltInNumSubgroups, , )(); \
2257+ uint sg_lid = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupLocalInvocationId, , )(); \
2258+ uint sg_size = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupSize, , )(); /* size of the current subgroup */ \
2259+ uint sg_max_size = SPIRV_BUILTIN_NO_OP(BuiltInSubgroupMaxSize, , )(); /* maximum subgroup size */ \
2260+ \
2261+ if (sg_lid == 0) { /* lane 0 stores its subgroup's partial result at index sg_id */ \
2262+ scratch[sg_id] = sg_x; \
2263+ } \
2264+ SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); /* partials written above are read below */ \
2265+ \
2266+ uint global_id = sg_id * sg_max_size + sg_lid; /* flat work-item index built from the maximum subgroup size */ \
2267+ uint values_num = num_sg; /* partials still to combine: one per subgroup */ \
2268+ while(values_num > sg_max_size) { /* repeat while the partials do not fit into sg_max_size lanes */ \
2269+ uint max_id = ((values_num + sg_max_size - 1) / sg_max_size) * sg_max_size; /* values_num rounded up to a multiple of sg_max_size */ \
2270+ type value = global_id < values_num ? scratch[global_id] : identity; /* work-items past the last partial contribute identity */ \
2271+ SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); /* all loads from scratch happen before it is overwritten below */ \
2272+ if (global_id < max_id) { /* only work-items below max_id take part in this pass */ \
2273+ sg_x = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, value);\
2274+ if (sg_lid == 0) { \
2275+ scratch[sg_id] = sg_x; /* new partial replaces the old entry at index sg_id */ \
2276+ } \
2277+ } \
2278+ values_num = max_id / sg_max_size; /* one partial per participating subgroup remains */ \
2279+ SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); /* new partials are written before the next read of scratch */ \
2280+ } \
2281+ \
2282+ type result; \
2283+ if (values_num > sg_size) { /* more partials than lanes in this subgroup: fold them one by one with op */ \
2284+ type sg_aggregate = scratch[0]; \
2285+ for (int s = 1; s < values_num; ++s) { \
2286+ sg_aggregate = op(sg_aggregate, scratch[s]); \
2287+ } \
2288+ result = sg_aggregate; \
2289+ } else { /* otherwise a single subgroup reduce finishes the job */ \
2290+ type value = sg_lid < values_num ? scratch[sg_lid] : identity; /* lanes at or past values_num contribute identity */ \
2291+ result = SPIRV_BUILTIN(Group##func, _i32_i32_##type_abbr, )(Subgroup, GroupOperationReduce, value); \
2292+ } \
2293+ SPIRV_BUILTIN(ControlBarrier, _i32_i32_i32, )(Workgroup, 0, AcquireRelease | WorkgroupMemory); /* NOTE(review): presumably keeps scratch intact until every work-item has read it - confirm */ \
2294+ return result; \
23152295}
23162296
23172297
0 commit comments