Skip to content

Commit 42fcd71

Browse files
authored
[CPU] Fix compilation errors because of unused variables (microsoft#26147)
This PR fixes a few unused variables that caused compilation errors.
1 parent b160e8c commit 42fcd71

File tree

2 files changed

+77
-77
lines changed

2 files changed

+77
-77
lines changed

onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "core/framework/float16.h"
1313
#include "core/framework/allocator.h"
1414
#include "core/platform/threadpool.h"
15+
#include "core/common/narrow.h"
1516

1617
#include <algorithm>
1718
#include <vector>
@@ -120,7 +121,7 @@ Status MoE<T>::ComputeMoE(const OpKernelContext* context,
120121
int num_routing_threads = 1;
121122
if (tp != nullptr && num_tokens >= 1024) {
122123
int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp);
123-
num_routing_threads = std::min(static_cast<int>(num_tokens / 512), max_threads);
124+
num_routing_threads = std::min(narrow<int>(num_tokens / 512), max_threads);
124125
num_routing_threads = std::max(1, num_routing_threads);
125126
}
126127

@@ -133,7 +134,7 @@ Status MoE<T>::ComputeMoE(const OpKernelContext* context,
133134
}
134135

135136
concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) {
136-
auto work = concurrency::ThreadPool::PartitionWork(static_cast<int>(thread_id), num_routing_threads, static_cast<std::ptrdiff_t>(num_tokens));
137+
auto work = concurrency::ThreadPool::PartitionWork(narrow<int>(thread_id), num_routing_threads, static_cast<std::ptrdiff_t>(num_tokens));
137138
auto& local_expert_token_map = thread_local_expert_token_maps[thread_id];
138139

139140
std::vector<std::pair<float, int64_t>> sorted_logits(static_cast<size_t>(num_experts));
@@ -173,7 +174,7 @@ Status MoE<T>::ComputeMoE(const OpKernelContext* context,
173174
int64_t route_idx = i * k_ + j;
174175
float normalized_weight = sorted_logits[static_cast<size_t>(j)].first * inv_top_k_sum;
175176

176-
route_expert[route_idx] = static_cast<int>(expert_idx);
177+
route_expert[route_idx] = narrow<int>(expert_idx);
177178
route_scale[route_idx] = normalized_weight;
178179
if (normalized_weight > 0.0f) {
179180
local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
@@ -185,7 +186,7 @@ Status MoE<T>::ComputeMoE(const OpKernelContext* context,
185186
int64_t route_idx = i * k_ + j;
186187
float weight = sorted_logits[static_cast<size_t>(j)].first;
187188

188-
route_expert[route_idx] = static_cast<int>(expert_idx);
189+
route_expert[route_idx] = narrow<int>(expert_idx);
189190
route_scale[route_idx] = weight;
190191
if (weight > 0.0f) {
191192
local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
@@ -319,7 +320,7 @@ Status MoE<T>::ComputeMoE(const OpKernelContext* context,
319320

320321
// Optimized expert processing with thread-local buffer reuse
321322
concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) {
322-
int thread_id = static_cast<int>(thread_id_pd);
323+
int thread_id = narrow<int>(thread_id_pd);
323324
auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast<std::ptrdiff_t>(num_experts));
324325

325326
float* local_output = thread_local_outputs + static_cast<size_t>(thread_id) * output_buffer_size;
@@ -440,6 +441,11 @@ Status MoE<T>::ProcessExpertBatch(const T* input_tokens,
440441
int64_t inter_size,
441442
T* fc1_output_buffer,
442443
T* activation_output_buffer) const {
444+
ORT_UNUSED_PARAMETER(token_expert_ids);
445+
ORT_UNUSED_PARAMETER(token_weights);
446+
ORT_UNUSED_PARAMETER(expert_id);
447+
ORT_UNUSED_PARAMETER(fc1_output_buffer);
448+
ORT_UNUSED_PARAMETER(activation_output_buffer);
443449
const bool is_swiglu = activation_type_ == ActivationType::SwiGLU;
444450
const int64_t fc1_output_size = is_swiglu ? (inter_size * 2) : inter_size;
445451

0 commit comments

Comments
 (0)