Skip to content

Commit 7c18d89

Browse files
authored
Fix cuda 12.9 windows build (microsoft#25317)
### Description Fix Windows build with MSVC 17.14.7 and cuda 12.9.1. The build error was like: `CUDACOMPILE : nvcc error : 'cudafe++' died with status 0xC0000005 (ACCESS_VIOLATION)` The cause is unknown (maybe cudafe bug). The code change resolved the issue. I've verified it in two machines.
1 parent 5ae3ee7 commit 7c18d89

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
#include <algorithm>
2020
#include <cfloat>
21-
#include <cuda.h>
21+
#include <cuda.h> // for CUDA_VERSION
2222
#include <cuda_fp16.h>
2323
#include <math.h>
2424
#include <sstream>
@@ -38,19 +38,12 @@
3838

3939
#include "moe_kernel.h"
4040

41-
#if CUDA_VERSION >= 11000
4241
#include <cub/cub.cuh>
4342
#include <cub/device/device_radix_sort.cuh>
4443
#include <cub/util_type.cuh>
45-
#else
46-
#include "cub/cub.cuh"
47-
#include "cub/device/device_radix_sort.cuh"
48-
#include "cub/util_type.cuh"
49-
#endif
5044

5145
namespace ort_fastertransformer {
5246
static constexpr int WARP_SIZE = 32;
53-
5447
// ====================== Softmax things ===============================
5548
// We have our own implementation of softmax here so we can support transposing the output
5649
// in the softmax kernel when we extend this module to support expert-choice routing.
@@ -65,13 +58,6 @@ __launch_bounds__(TPB) __global__
6558

6659
const int thread_row_offset = blockIdx.x * num_cols;
6760

68-
#if CUDA_VERSION >= 12090
69-
::cuda::std::plus sum;
70-
#else
71-
// Deprecated on CUDA 12.9
72-
cub::Sum sum;
73-
#endif
74-
7561
float threadData(-FLT_MAX);
7662

7763
// Don't touch finished rows.
@@ -84,7 +70,12 @@ __launch_bounds__(TPB) __global__
8470
threadData = max(static_cast<float>(input[idx]), threadData);
8571
}
8672

73+
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
74+
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, ::cuda::maximum());
75+
#else
8776
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
77+
#endif
78+
8879
if (threadIdx.x == 0) {
8980
float_max = maxElem;
9081
}
@@ -97,7 +88,12 @@ __launch_bounds__(TPB) __global__
9788
threadData += exp((static_cast<float>(input[idx]) - float_max));
9889
}
9990

100-
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
91+
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12090
92+
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, ::cuda::std::plus());
93+
#else
94+
// Deprecated on CUDA 12.9
95+
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, cub::Sum());
96+
#endif
10197

10298
if (threadIdx.x == 0) {
10399
normalizing_factor = 1.f / Z;
@@ -993,6 +989,7 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::get_total_rows_info(int64_t expe
993989
if (experts_start_index > 0) {
994990
total_past_rows = total_rows_before_expert_host_[experts_start_index - 1];
995991
}
992+
996993
total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows;
997994
}
998995

0 commit comments

Comments (0)