@@ -95,7 +95,12 @@ __device__ inline void Softmax(const int total_sequence_length,
       }
     }
   }
+
+#if CUDART_VERSION >= 12090
+  const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, ::cuda::maximum());
+#else
   const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max());
+#endif
 
   // Store max value
   if (threadIdx.x == 0) {
@@ -114,7 +119,12 @@ __device__ inline void Softmax(const int total_sequence_length,
     }
   }
 
+#if CUDART_VERSION >= 12090
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, ::cuda::std::plus());
+#else
   const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, cub::Sum());
+#endif
+
   if (threadIdx.x == 0) {
     sum_reverse_block = 1.f / sum;
   }
@@ -171,7 +181,11 @@ __device__ inline void SoftmaxSmall(const int total_sequence_length,
   // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
   // a math transform as below is leveraged to get a stable softmax:
   // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
+#if CUDART_VERSION >= 12090
+  const auto max = BlockReduce(tmp_storage).Reduce(input_data, ::cuda::maximum(), end);
+#else
   const auto max = BlockReduce(tmp_storage).Reduce(input_data, cub::Max(), end);
+#endif
 
   // Store max value
   if (threadIdx.x == 0) {
@@ -184,7 +198,11 @@ __device__ inline void SoftmaxSmall(const int total_sequence_length,
     thread_data_exp = expf(input_data - max_block);
   }
 
+#if CUDART_VERSION >= 12090
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, ::cuda::std::plus(), end);
+#else
   const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end);
+#endif
 
   // Store value of 1.0/sum.
   if (threadIdx.x == 0) {
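
The comment in SoftmaxSmall above motivates the max-subtraction trick: expf saturates to +inf for large inputs, and inf/inf is NaN, while e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) is algebraically identical and keeps every exponent at or below zero. A minimal host-side illustration of the same transform (hypothetical values, not part of this change):

#include <math.h>
#include <stdio.h>

int main() {
  float x[3] = {1000.f, 1000.f, 1000.f};
  // Naive form: expf(1000.f) overflows to +inf, so expf(x[0]) / sum is NaN.
  float m = x[0];
  for (float v : x) m = fmaxf(m, v);     // m = max(x1, ..., xn)
  float sum = 0.f;
  for (float v : x) sum += expf(v - m);  // every term is finite (<= 1)
  printf("softmax[0] = %f\n", expf(x[0] - m) / sum);  // prints 0.333333, not NaN
  return 0;
}
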
@@ -240,7 +258,12 @@ __global__ void SoftmaxLargeKernel(const int total_sequence_length,
     cached_data[i] = input_data;
     thread_data_max = max(thread_data_max, input_data);
   }
+
+#if CUDART_VERSION >= 12090
+  const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, ::cuda::maximum(), end);
+#else
   const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
+#endif
 
   // Store max value
   if (threadIdx.x == 0) {
@@ -254,7 +277,12 @@ __global__ void SoftmaxLargeKernel(const int total_sequence_length,
     cached_data[i] = is_valid ? expf(cached_data[i] - max_block) : 0.0f;
     thread_data_exp += cached_data[i];
   }
+
+#if CUDART_VERSION >= 12090
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, ::cuda::std::plus(), end);
+#else
   const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end);
+#endif
 
   // Store value of 1.0/sum.
   if (threadIdx.x == 0) {
@@ -343,7 +371,11 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length,
     return;
   }
 
+#if CUDART_VERSION >= 12090
+  const float max = BlockReduce(tmp_storage).Reduce(max_thread_data, ::cuda::maximum(), total_sequence_length);
+#else
   const float max = BlockReduce(tmp_storage).Reduce(max_thread_data, cub::Max(), total_sequence_length);
+#endif
 
   // Store max value
   if (threadIdx.x == 0) {
@@ -357,7 +389,12 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length,
     cached_data[i] = ev;
     sum_thread_data_exp += ev;
   }
+
+#if CUDART_VERSION >= 12090
+  const auto sum = BlockReduce(tmp_storage).Reduce(sum_thread_data_exp, ::cuda::std::plus(), TPB);
+#else
   const auto sum = BlockReduce(tmp_storage).Reduce(sum_thread_data_exp, cub::Sum(), TPB);
+#endif
 
   // Store value of 1.0/sum
   if (threadIdx.x == 0) {
@@ -441,7 +478,11 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length,
     return;
   }
 
+#if CUDART_VERSION >= 12090
+  const float max = BlockReduce(tmp_storage).Reduce(thread_data, ::cuda::maximum(), total_sequence_length);
+#else
   const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), total_sequence_length);
+#endif
 
   // Store max value
   if (threadIdx.x == 0) {
@@ -450,7 +491,12 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length,
   __syncthreads();
 
   float thread_data_exp = threadIdx.x < total_sequence_length ? expf(thread_data - max_block) : 0.0f;
+
+#if CUDART_VERSION >= 12090
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, ::cuda::std::plus(), total_sequence_length);
+#else
   const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), total_sequence_length);
+#endif
 
   // Store value of 1.0/sum
   if (threadIdx.x == 0) {
@@ -596,7 +642,12 @@ __device__ inline void SoftmaxSmallPacked(const int total_sequence_length,
   float input_data = HAS_BIAS ? float(input[index]) + float(attn_bias[bias_offset + threadIdx.x]) : float(input[index]);
 
   float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F);
+
+#if CUDART_VERSION >= 12090
+  const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, ::cuda::maximum(), end);
+#else
   const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
+#endif
 
   // Store max value
   if (threadIdx.x == 0) {
@@ -609,7 +660,11 @@ __device__ inline void SoftmaxSmallPacked(const int total_sequence_length,
     thread_data_exp = expf(input_data - max_block);
   }
 
+#if CUDART_VERSION >= 12090
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, ::cuda::std::plus(), end);
+#else
   const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end);
+#endif
 
   // Store value of 1.0/sum.
   if (threadIdx.x == 0) {
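
Every hunk in this diff applies the same pattern: on CUDA 12.9 and newer, where the bundled CCCL deprecates the cub::Max and cub::Sum function objects, the CUB block reduction switches to the libcu++ functors ::cuda::maximum() and ::cuda::std::plus(); older toolkits keep the original operators. A minimal standalone sketch of that guard follows, assuming a compile-time block size TPB and valid <= TPB (required by CUB's partial-tile Reduce); the kernel name and parameters here are illustrative, not from this change:

#include <cuda_runtime.h>              // CUDART_VERSION
#include <cub/block/block_reduce.cuh>  // cub::BlockReduce
#if CUDART_VERSION >= 12090
#include <cuda/functional>             // ::cuda::maximum
#include <cuda/std/functional>         // ::cuda::std::plus
#endif

// Hypothetical kernel: block-wide max and sum of the first `valid` elements.
template <int TPB>
__global__ void BlockMaxSum(const float* in, int valid, float* out_max, float* out_sum) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmp_storage;

  // Out-of-range lanes contribute nothing: Reduce() with a valid-item count
  // only consumes inputs from the first `valid` threads.
  const float v = (threadIdx.x < valid) ? in[threadIdx.x] : 0.0f;

#if CUDART_VERSION >= 12090
  // CUDA 12.9+ ships CCCL versions that deprecate cub::Max/cub::Sum in favor
  // of the libcu++ function objects, mirroring the guard used in this diff.
  const float mx = BlockReduce(tmp_storage).Reduce(v, ::cuda::maximum(), valid);
#else
  const float mx = BlockReduce(tmp_storage).Reduce(v, cub::Max(), valid);
#endif

  __syncthreads();  // TempStorage is reused by the second reduction.

#if CUDART_VERSION >= 12090
  const float sm = BlockReduce(tmp_storage).Reduce(v, ::cuda::std::plus(), valid);
#else
  const float sm = BlockReduce(tmp_storage).Reduce(v, cub::Sum(), valid);
#endif

  if (threadIdx.x == 0) {  // Reduce() returns the full result only on thread 0.
    *out_max = mx;
    *out_sum = sm;
  }
}

A launch such as BlockMaxSum<256><<<1, 256>>>(d_in, n, d_max, d_sum) compiles on both sides of the guard, which is what lets the kernels above build unchanged across toolkit versions.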