Commit bfaa2ff

JonathanC-ARM authored and Jaswanth51 committed
Fixes for DynamicQuantizeMatMul and Attention3D tests (microsoft#25814)
### Description

This change fixes correctness issues in two areas that were causing failures in onnxruntime_test_all:

- DynamicQuantizeMatMul.WithConstantBInputs
- AttentionTest.Attention3DDefault
- AttentionTest.Attention3DWithPastAndPresentQkMatmul

What was wrong and how it’s fixed

1) DynamicQuantizeMatMul.WithConstantBInputs
   - Root cause: The Kleidi dynamic quantization GEMM path could be selected even when the B scales contained invalid values (zero, negative, or non-finite). That violates kernel assumptions and can lead to incorrect results.
   - Fix: In `onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc`, we now explicitly validate that all B scales are finite and strictly positive before enabling the Kleidi/MLAS dynamic path. If any scale is invalid, we disable that path.

2) Attention tests (Attention3DDefault, Attention3DWithPastAndPresentQkMatmul)
   - Root causes in `onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp`:
     - Incorrect handling of GEMM corner cases for alpha/beta and K==0 (e.g., not respecting C = beta*C when alpha==0 or K==0).
     - Unnecessary or premature fallbacks for small shapes.
   - Fixes:
     - Add early-outs for degenerate sizes: if M==0 or N==0, return handled.
     - Correctly implement alpha/beta semantics: when alpha==0 or K==0, scale C by beta (zeroing it when beta==0) and return handled; otherwise compute C = alpha*A*B + beta*C in the tile epilogue (a reference sketch of these semantics follows the description).

---------

Signed-off-by: Jonathan Clohessy <jonathan.clohessy@arm.com>
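For reference, the corner-case behaviour described above can be written out as a plain scalar SGEMM. This is a minimal sketch of the intended semantics only, not the MLAS/KleidiAI code; the function name, row-major layout, and no-transpose assumption are illustrative.

```cpp
// Reference semantics sketch: C = alpha * A * B + beta * C.
// When alpha == 0 or K == 0 there is no A*B contribution, but the
// beta update to C must still be applied (zeroing C when beta == 0).
#include <cstddef>

void ReferenceSgemm(std::size_t M, std::size_t N, std::size_t K,
                    float alpha, const float* A, std::size_t lda,
                    const float* B, std::size_t ldb,
                    float beta, float* C, std::size_t ldc) {
  if (M == 0 || N == 0) {
    return;  // nothing to write
  }
  if (alpha == 0.0f || K == 0) {
    for (std::size_t i = 0; i < M; ++i) {
      for (std::size_t j = 0; j < N; ++j) {
        C[i * ldc + j] = (beta == 0.0f) ? 0.0f : beta * C[i * ldc + j];
      }
    }
    return;
  }
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) {
        acc += A[i * lda + k] * B[k * ldb + j];
      }
      C[i * ldc + j] = alpha * acc + beta * C[i * ldc + j];
    }
  }
}
```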
1 parent fa81996 commit bfaa2ff

File tree

2 files changed: +56 -34 lines changed

onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
Lines changed: 15 additions & 2 deletions

```diff
@@ -200,6 +200,19 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
 
     can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available);
 
+    // Kleidi dynamic path requires strictly positive, finite scales.
+    // Disable if any invalid scale is detected.
+    if (can_use_dynamic_quant_mlas_) {
+      const auto bs = b_scale_tensor->DataAsSpan<float>();
+      const bool has_invalid =
+          std::any_of(bs.begin(), bs.end(),
+                      [](float s) { return !std::isfinite(s) || s <= 0.0f; });
+
+      if (has_invalid) {
+        can_use_dynamic_quant_mlas_ = false;
+      }
+    }
+
     // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
     // We check that here too before attempting to use them.
     if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) {
@@ -379,7 +392,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
   if (y->Shape().Size() == 0)
     return Status::OK();
 
-  auto a_data = static_cast<const uint8_t*>(ctx->Input<Tensor>(IN_A)->DataRaw());
+  const float* a_data = ctx->Input<Tensor>(IN_A)->Data<float>();
   auto* y_data = y->MutableData<float>();
 
   // batch gemm
@@ -393,7 +406,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
 
   for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) {
     auto& params = gemm_data_vec[gemm_idx];
-    params.A = reinterpret_cast<const float*>(a_data + helper.LeftOffsets()[gemm_idx]);
+    params.A = a_data + helper.LeftOffsets()[gemm_idx];
     params.lda = gemm_shape.K;
     params.PackedB = packed_b_.get();
     params.C = y_data + helper.OutputOffsets()[gemm_idx];
```
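As a standalone illustration of the new guard (not code from this change), the predicate rejects zero, negative, and non-finite scales before the Kleidi dynamic path is kept enabled. `HasInvalidBScale` is a hypothetical helper and the sample values are made up.

```cpp
// Demonstrates the scale-validation predicate on a few example B scales.
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <vector>

bool HasInvalidBScale(const std::vector<float>& scales) {
  // Same predicate as the guard above: reject non-finite or non-positive scales.
  return std::any_of(scales.begin(), scales.end(),
                     [](float s) { return !std::isfinite(s) || s <= 0.0f; });
}

int main() {
  std::cout << HasInvalidBScale({0.02f, 0.5f, 1.0f}) << "\n";  // 0: all valid, dynamic path stays enabled
  std::cout << HasInvalidBScale({0.02f, 0.0f}) << "\n";        // 1: zero scale -> path disabled
  std::cout << HasInvalidBScale({0.02f, -0.5f}) << "\n";       // 1: negative scale -> path disabled
  std::cout << HasInvalidBScale({std::numeric_limits<float>::quiet_NaN()}) << "\n";  // 1: non-finite
  return 0;
}
```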

onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
Lines changed: 41 additions & 32 deletions

```diff
@@ -153,28 +153,23 @@ ArmKleidiAI::MlasGemmBatch(
     MLAS_THREADPOOL* ThreadPool
     )
 {
-    if(TransA == CblasTrans)
-    {
-        return false;
+    if (M == 0 || N == 0) {
+        return true;
     }
-    if (TransA == CblasNoTrans && K == 0) {
-        if (Data->beta != 1.0f) {
+
+    if (Data->alpha == 0.0f || K == 0) {
+        if (Data->beta == 0.0f) {
+            for (size_t i = 0; i < M; ++i) {
+                std::fill_n(Data->C + i * Data->ldc, N, 0.0f);
+            }
+        } else if (Data->beta != 1.0f) {
             for (size_t i = 0; i < M; ++i) {
                 for (size_t j = 0; j < N; ++j) {
                     Data->C[i * Data->ldc + j] *= Data->beta;
                 }
             }
         }
-    }
-    if (Data->beta == 0.0f){
-        std::fill_n(Data->C, M * Data->ldc, 0.0f);
-    }
-    //Fallback in the case of unsupported cases
-    if (M == 0 || N == 0 || K == 0 ||
-        TransA != CblasNoTrans ||
-        (TransB != CblasNoTrans && !Data[0].BIsPacked))
-    {
-        return false;
+        return true;
     }
 
     if (TransA == CblasNoTrans) {
@@ -185,11 +180,9 @@ ArmKleidiAI::MlasGemmBatch(
         auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
         auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
 
-        if (M < m_step || N < n_step) {
-            if (GetMlasPlatform().MlasGemmBatchOverride != ArmKleidiAI::MlasGemmBatch){
-                //Fallback to MLAS
-                return false;
-            }
+        if (M < m_step && N < n_step && !Data->BIsPacked) {
+            // Fallback to MLAS
+            return false;
         }
 
         std::vector<MLAS_SGEMM_DATA_PARAMS> KaiPackedData;
@@ -316,7 +309,7 @@ ArmKleidiAI::MlasGemmBatch(
             float* dst_tile = reinterpret_cast<float*>(CTile);
 
             // quick copy of data in cases where we are not scaling or accumulating anything
-            // with bounds checking on tile sizing to ensure the data fits in the memory block
+            // with bounds checking on tile sizing to ensure the data fits in the memory block
             bool can_memcpy = (
                 Data[BIdx].alpha == 1.0f &&
                 Data[BIdx].beta == 0.0f &&
@@ -328,21 +321,37 @@ ArmKleidiAI::MlasGemmBatch(
 
             if (can_memcpy) {
                 std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float));
-            }else {
-                // apply alpha scaling and beta to output files
-                for (size_t i = 0; i < TileSizeM; ++i) {
-                    for (size_t j = 0; j < TileSizeN; ++j) {
-                        const size_t idx = i * TileSizeN + j;
-                        const size_t dst_idx = i * Data[BIdx].ldc + j;
-
-                        float ab = temp_tile[idx];
-                        float c_orig = dst_tile[dst_idx];
+                return;
+            }
 
-                        dst_tile[dst_idx] = Data[BIdx].alpha * ab + Data[BIdx].beta * c_orig;
+            float alpha = Data[BIdx].alpha;
+            float beta = Data[BIdx].beta;
+            size_t ldc = Data[BIdx].ldc;
+
+            for (size_t i = 0; i < TileSizeM; ++i) {
+                for (size_t j = 0; j < TileSizeN; ++j) {
+                    const size_t temp_idx = i * TileSizeN + j;
+                    const size_t dst_idx = i * ldc + j;
+
+                    float ab = temp_tile[temp_idx];
+                    float c_orig = dst_tile[dst_idx];
+
+                    if (alpha == 1.0f && beta == 0.0f) {
+                        dst_tile[dst_idx] = ab;
+                    } else if (alpha == 1.0f) {
+                        dst_tile[dst_idx] = ab + beta * c_orig;
+                    } else if (beta == 0.0f) {
+                        dst_tile[dst_idx] = alpha * ab;
+                    } else {
+                        dst_tile[dst_idx] = alpha * ab + beta * c_orig;
                     }
                 }
             }
+            return;
         });
+        return true;
+    }
+    else {
+        return false;
     }
-    return true;
 }
```
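A quick sanity sketch (not part of the change) showing that the four specialised branches in the new tile epilogue agree with the general expression `alpha * ab + beta * c_orig`; `BranchyEpilogue` and the test values below are illustrative only.

```cpp
// Checks the branch structure of the epilogue against the general formula.
#include <cassert>
#include <cmath>

float BranchyEpilogue(float alpha, float beta, float ab, float c_orig) {
  // Same branch structure as the tile epilogue in the diff above.
  if (alpha == 1.0f && beta == 0.0f) return ab;
  if (alpha == 1.0f) return ab + beta * c_orig;
  if (beta == 0.0f) return alpha * ab;
  return alpha * ab + beta * c_orig;
}

int main() {
  const float alphas[] = {0.0f, 1.0f, 2.5f};
  const float betas[] = {0.0f, 1.0f, -0.75f};
  const float ab = 3.0f;
  const float c_orig = -2.0f;
  for (float alpha : alphas) {
    for (float beta : betas) {
      const float expected = alpha * ab + beta * c_orig;
      assert(std::fabs(BranchyEpilogue(alpha, beta, ab, c_orig) - expected) < 1e-6f);
    }
  }
  return 0;
}
```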
