Skip to content

Commit 94f76b9

Browse files
apsonawaneJaswanth51
authored andcommitted
Fix MoE CPP tests (microsoft#25877)
This change adds skip test for QMoE CPU tests when running on TensorRT or CUDA EP. In the QMoE kernel there was a memory overwrite bug in the accumulate part, updated that and this fixed the python tests back
1 parent bfaa2ff commit 94f76b9

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,13 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
331331
const int64_t token_idx = route_idx / k_;
332332
const float weight = route_scale[route_idx];
333333

334-
float* dest = thread_local_outputs + static_cast<size_t>(thread_id) * output_buffer_size + token_idx * hidden_size;
334+
const size_t buffer_offset = static_cast<size_t>(token_idx) * static_cast<size_t>(hidden_size);
335+
if (buffer_offset + static_cast<size_t>(hidden_size) > output_buffer_size) {
336+
// Skip this token to prevent buffer overflow
337+
continue;
338+
}
339+
340+
float* dest = thread_local_outputs + static_cast<size_t>(thread_id) * output_buffer_size + buffer_offset;
335341
const float* src = C2 + i * hidden_size;
336342
for (int64_t j = 0; j < hidden_size; ++j) {
337343
dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f));
@@ -344,8 +350,9 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
344350
auto accumulate = [&](float* buffer) {
345351
memset(buffer, 0, output_buffer_size * sizeof(float));
346352
for (int i = 0; i < num_expert_threads; ++i) {
353+
const size_t thread_offset = static_cast<size_t>(i) * output_buffer_size;
347354
for (size_t j = 0; j < output_buffer_size; ++j) {
348-
buffer[j] += thread_local_outputs[static_cast<size_t>(i) * output_buffer_size + j];
355+
buffer[j] += thread_local_outputs[thread_offset + j];
349356
}
350357
}
351358
};

onnxruntime/test/contrib_ops/moe_test.cc

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ static void RunQMoETest(const std::vector<float>& input, const std::vector<float
144144
// Test CPU execution provider (always available)
145145
// Skip CPU test if FC3 weights are provided since CPU doesn't support FC3
146146
if (fc3_experts_weights.empty()) {
147+
// Ensure CPU EP is available before running CPU tests
148+
auto cpu_ep = DefaultCpuExecutionProvider();
149+
if (!cpu_ep) {
150+
return; // Skip CPU test if CPU EP is not available
151+
}
152+
147153
OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain);
148154
cpu_tester.AddAttribute<int64_t>("k", static_cast<int64_t>(top_k));
149155
cpu_tester.AddAttribute<std::string>("activation_type", activation_type);
@@ -1323,6 +1329,13 @@ TEST(MoETest, QMoETest_Mixtral_Int4) {
13231329

13241330
// CPU-specific QMoE tests
13251331
TEST(MoETest, QMoETest_CPU_Int4_MLAS) {
1332+
#ifdef USE_MLAS
1333+
// Skip this test if we're not testing CPU execution provider
1334+
auto cpu_ep = DefaultCpuExecutionProvider();
1335+
if (!cpu_ep) {
1336+
GTEST_SKIP() << "CPU execution provider not available";
1337+
}
1338+
13261339
int num_rows = 2;
13271340
int num_experts = 2;
13281341
int hidden_size = 32;
@@ -1387,9 +1400,19 @@ TEST(MoETest, QMoETest_CPU_Int4_MLAS) {
13871400
std::vector<std::unique_ptr<IExecutionProvider>> cpu_execution_providers;
13881401
cpu_execution_providers.push_back(DefaultCpuExecutionProvider());
13891402
cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers);
1403+
#else
1404+
GTEST_SKIP() << "Skipping CPU QMoE test";
1405+
#endif
13901406
}
13911407

13921408
TEST(MoETest, QMoETest_CPU_Int8_MLAS) {
1409+
#ifdef USE_MLAS
1410+
// Skip this test if we're not testing CPU execution provider
1411+
auto cpu_ep = DefaultCpuExecutionProvider();
1412+
if (!cpu_ep) {
1413+
GTEST_SKIP() << "CPU execution provider not available";
1414+
}
1415+
13931416
// Test CPU implementation with 8-bit quantization - CPU ONLY
13941417
int num_rows = 1;
13951418
int num_experts = 2;
@@ -1446,9 +1469,19 @@ TEST(MoETest, QMoETest_CPU_Int8_MLAS) {
14461469
std::vector<std::unique_ptr<IExecutionProvider>> cpu_execution_providers;
14471470
cpu_execution_providers.push_back(DefaultCpuExecutionProvider());
14481471
cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers);
1472+
#else
1473+
GTEST_SKIP() << "Skipping CPU QMoE test";
1474+
#endif
14491475
}
14501476

14511477
TEST(MoETest, QMoETest_CPU_FC3_Error) {
1478+
#ifdef USE_MLAS
1479+
// Skip this test if we're not testing CPU execution provider
1480+
auto cpu_ep = DefaultCpuExecutionProvider();
1481+
if (!cpu_ep) {
1482+
GTEST_SKIP() << "CPU execution provider not available";
1483+
}
1484+
14521485
// Test that CPU throws error when FC3 gating is provided - CPU ONLY
14531486
int num_rows = 1;
14541487
int num_experts = 2;
@@ -1506,9 +1539,19 @@ TEST(MoETest, QMoETest_CPU_FC3_Error) {
15061539

15071540
// Expect this to fail with FC3 not implemented error
15081541
cpu_tester.Run(OpTester::ExpectResult::kExpectFailure, "FC3 gating is not yet implemented", {}, nullptr, &cpu_execution_providers);
1542+
#else
1543+
GTEST_SKIP() << "Skipping CPU QMoE test";
1544+
#endif
15091545
}
15101546

15111547
TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) {
1548+
#ifdef USE_MLAS
1549+
// Skip this test if we're not testing CPU execution provider
1550+
auto cpu_ep = DefaultCpuExecutionProvider();
1551+
if (!cpu_ep) {
1552+
GTEST_SKIP() << "CPU execution provider not available";
1553+
}
1554+
15121555
// Test CPU implementation with 4-bit quantization and SwiGLU activation
15131556
int num_rows = 2;
15141557
int num_experts = 2;
@@ -1573,9 +1616,18 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) {
15731616
std::vector<std::unique_ptr<IExecutionProvider>> cpu_execution_providers;
15741617
cpu_execution_providers.push_back(DefaultCpuExecutionProvider());
15751618
cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers);
1619+
#else
1620+
GTEST_SKIP() << "Skipping CPU QMoE test";
1621+
#endif
15761622
}
15771623

15781624
TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) {
1625+
#ifdef USE_MLAS
1626+
// Skip this test if we're not testing CPU execution provider
1627+
auto cpu_ep = DefaultCpuExecutionProvider();
1628+
if (!cpu_ep) {
1629+
GTEST_SKIP() << "CPU execution provider not available";
1630+
}
15791631
// Test CPU implementation with 8-bit quantization and SwiGLU activation
15801632
int num_rows = 1;
15811633
int num_experts = 2;
@@ -1633,6 +1685,9 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) {
16331685
std::vector<std::unique_ptr<IExecutionProvider>> cpu_execution_providers;
16341686
cpu_execution_providers.push_back(DefaultCpuExecutionProvider());
16351687
cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers);
1688+
#else
1689+
GTEST_SKIP() << "Skipping CPU QMoE test";
1690+
#endif
16361691
}
16371692

16381693
#endif

0 commit comments

Comments
 (0)