1818
1919#include < algorithm>
2020#include < cfloat>
21- #include < cuda.h>
21+ #include < cuda.h> // for CUDA_VERSION
2222#include < cuda_fp16.h>
2323#include < math.h>
2424#include < sstream>
3838
3939#include " moe_kernel.h"
4040
41- #if CUDA_VERSION >= 11000
4241#include < cub/cub.cuh>
4342#include < cub/device/device_radix_sort.cuh>
4443#include < cub/util_type.cuh>
45- #else
46- #include " cub/cub.cuh"
47- #include " cub/device/device_radix_sort.cuh"
48- #include " cub/util_type.cuh"
49- #endif
5044
5145namespace ort_fastertransformer {
5246static constexpr int WARP_SIZE = 32 ;
53-
5447// ====================== Softmax things ===============================
5548// We have our own implementation of softmax here so we can support transposing the output
5649// in the softmax kernel when we extend this module to support expert-choice routing.
@@ -65,13 +58,6 @@ __launch_bounds__(TPB) __global__
6558
6659 const int thread_row_offset = blockIdx .x * num_cols;
6760
68- #if CUDA_VERSION >= 12090
69- ::cuda::std::plus sum;
70- #else
71- // Deprecated on CUDA 12.9
72- cub::Sum sum;
73- #endif
74-
7561 float threadData (-FLT_MAX);
7662
7763 // Don't touch finished rows.
@@ -84,7 +70,12 @@ __launch_bounds__(TPB) __global__
8470 threadData = max (static_cast <float >(input[idx]), threadData);
8571 }
8672
73+ #if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
74+ const float maxElem = BlockReduce (tmpStorage).Reduce (threadData, ::cuda::maximum ());
75+ #else
8776 const float maxElem = BlockReduce (tmpStorage).Reduce (threadData, cub::Max ());
77+ #endif
78+
8879 if (threadIdx .x == 0 ) {
8980 float_max = maxElem;
9081 }
@@ -97,7 +88,12 @@ __launch_bounds__(TPB) __global__
9788 threadData += exp ((static_cast <float >(input[idx]) - float_max));
9889 }
9990
100- const auto Z = BlockReduce (tmpStorage).Reduce (threadData, sum);
91+ #if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
92+ const auto Z = BlockReduce (tmpStorage).Reduce (threadData, ::cuda::std::plus ());
93+ #else
94+ // cub::Sum is deprecated as of CUDA 12.9; it is used here only for older toolkits.
95+ const auto Z = BlockReduce (tmpStorage).Reduce (threadData, cub::Sum ());
96+ #endif
10197
10298 if (threadIdx .x == 0 ) {
10399 normalizing_factor = 1 .f / Z;
@@ -993,6 +989,7 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::get_total_rows_info(int64_t expe
993989 if (experts_start_index > 0 ) {
994990 total_past_rows = total_rows_before_expert_host_[experts_start_index - 1 ];
995991 }
992+
996993 total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows;
997994}
998995
0 commit comments