Commit 1f45838

tianleiwu and Copilot authored
[MLAS] Fix rotary avx2 kernel invalid access (microsoft#26389)
This fixes an invalid memory access caused by the _mm256_maskload_ps intrinsic in the remainder-handling logic introduced in microsoft#23694.

The core of the problem is that _mm256_maskload_ps (and its store equivalent) can touch memory beyond the masked elements. Even if the mask correctly specifies that only, say, 3 floats should be loaded, the intrinsic may still read the full 32 bytes (8 floats) from the provided memory address. The invalid access occurs when one of the buffers (input, sin_data, or cos_data) ends near the boundary of a memory page, and the masked-off part of the 32-byte read falls on an unmapped page. This causes a segmentation fault (invalid access).

The solution: use a scalar remainder loop. The simplest, safest, and most robust fix is to replace the masked AVX remainder logic with a plain scalar loop. This is the same strategy already used by the RopeKernel_Avx2_fp16_Impl functions, which are unaffected by this bug. The performance impact is negligible, since this loop only processes the final 1-15 elements.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
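For illustration (not part of the commit), here is a minimal, hypothetical sketch of the failure scenario the message describes: a short float buffer sits flush against the end of a mapped page, so the masked load's full 32-byte window crosses into unmapped memory. The mmap setup and all names are illustrative.

// Hypothetical sketch (not from the commit) of the page-boundary scenario:
// 3 valid floats sit at the very end of a mapped page, and the masked
// 32-byte load window extends into a PROT_NONE page right after them.
#include <immintrin.h>
#include <sys/mman.h>
#include <unistd.h>
#include <cstddef>

int main() {
    const size_t page = static_cast<size_t>(sysconf(_SC_PAGESIZE));
    // Map two pages, then make the second one inaccessible.
    char* base = static_cast<char*>(mmap(nullptr, 2 * page, PROT_READ | PROT_WRITE,
                                         MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
    mprotect(base + page, page, PROT_NONE);

    // Place 3 floats flush against the end of the accessible page.
    float* tail = reinterpret_cast<float*>(base + page - 3 * sizeof(float));
    tail[0] = tail[1] = tail[2] = 1.0f;

    // The mask selects only lanes 0..2, but the 32-byte access window still
    // covers 20 bytes of the protected page; per the commit message, this
    // masked-off part of the read is what can fault in practice.
    const __m256i mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1);
    __m256 v = _mm256_maskload_ps(tail, mask);
    (void)v;

    munmap(base, 2 * page);
    return 0;
}

Compile with -mavx2. Whether this particular snippet actually faults is exactly the implementation-dependent behavior the commit message warns about.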
1 parent 6e67664 · commit 1f45838

File tree

1 file changed: +27 −56 lines


onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp

Lines changed: 27 additions & 56 deletions
@@ -25,7 +25,6 @@ namespace rope_avx2 {
 namespace {
 
 typedef __m256 float32x8_t;
-static constexpr int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0};
 
 template <bool interleaved>
 void
@@ -123,20 +122,15 @@ RopeKernel_Avx2_fp16_Impl<true>(
         convert_to_fp16_and_store(output + i + 8, y1);
     }
 
-    for (; i < dim; i++) {
+    // Scalar remainder loop to safely handle trailing elements in pairs.
+    for (; i + 1 < dim; i += 2) {
         size_t cache_idx = i / 2;
-        bool sign = i & 1;
-        size_t j = sign ? i - 1 : i + 1;
-
-        float output_data_i = input[i].ToFloat() * cos_data[cache_idx].ToFloat();
-        float input_data_j = input[j].ToFloat();
-        float sin_data_cache_idx = sin_data[cache_idx].ToFloat();
-        if (sign) {
-            output_data_i += input_data_j * sin_data_cache_idx;
-        } else {
-            output_data_i -= input_data_j * sin_data_cache_idx;
-        }
-        output[i] = MLAS_FP16(output_data_i);
+        float input0 = input[i].ToFloat();
+        float input1 = input[i + 1].ToFloat();
+        float sin_val = sin_data[cache_idx].ToFloat();
+        float cos_val = cos_data[cache_idx].ToFloat();
+        output[i] = MLAS_FP16(input0 * cos_val - input1 * sin_val);
+        output[i + 1] = MLAS_FP16(input0 * sin_val + input1 * cos_val);
     }
 }
 
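As a cross-check (not part of the commit), the old per-element loop and the new pair loop compute the same rotation. With c = cos_data[i / 2].ToFloat() and s = sin_data[i / 2].ToFloat(), the old loop evaluates to:

    even i (sign == 0, j == i + 1):  output[i] = input[i] * c - input[i + 1] * s
    odd  i (sign == 1, j == i - 1):  output[i] = input[i] * c + input[i - 1] * s

which is exactly the new pair update with x0 = input[i] and x1 = input[i + 1] applied at each even i. Both versions assume the rotary dimension dim is even.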

@@ -173,20 +167,15 @@ RopeKernel_Avx2_fp32_Impl<false>(
         _mm256_storeu_ps(output + i, real_out);
         _mm256_storeu_ps(output + j, imag_out);
     }
-    if (half_dim - i != 0) {
-        size_t rem = half_dim - i;
-        const __m256i mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - rem));
-        //Use a mask to load the remaining input values
-        float32x8_t real = _mm256_maskload_ps(input + i, mask);
-        float32x8_t imag = _mm256_maskload_ps(input + j, mask);
-        float32x8_t sin_val = _mm256_maskload_ps(sin_data + i, mask);
-        float32x8_t cos_val = _mm256_maskload_ps(cos_data + i, mask);
-        //Compute Real and Imaginary output values
-        float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
-        float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));
-        //Store back into non interleaved format
-        _mm256_maskstore_ps(output + i, mask, real_out);
-        _mm256_maskstore_ps(output + j, mask, imag_out);
+
+    // Scalar remainder loop to safely handle trailing elements
+    for (; i < half_dim; i++, j++) {
+        float real = input[i];
+        float imag = input[j];
+        float sin_val = sin_data[i];
+        float cos_val = cos_data[i];
+        output[i] = real * cos_val - imag * sin_val;
+        output[j] = real * sin_val + imag * cos_val;
    }
 }
 
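For context (not in the commit): in this non-interleaved variant the real components occupy the first half of the buffer and the imaginary components the second half, so the scalar tail walks both halves in lockstep and indexes sin/cos by i directly (the interleaved variants index by i / 2 because pairs sit next to each other). A minimal stand-alone sketch of the same pattern, with hypothetical names and the assumption that j starts at half_dim + i:

#include <cstddef>

// Hypothetical free function mirroring the remainder pattern inside
// RopeKernel_Avx2_fp32_Impl<false>. Layout assumption:
// input = [re(0)..re(half_dim-1), im(0)..im(half_dim-1)]
void rope_tail_noninterleaved(const float* input, const float* sin_data,
                              const float* cos_data, float* output,
                              size_t i, size_t half_dim) {
    for (size_t j = half_dim + i; i < half_dim; i++, j++) {
        float real = input[i];
        float imag = input[j];
        // Rotate the (real, imag) pair by the angle cached at index i.
        output[i] = real * cos_data[i] - imag * sin_data[i];
        output[j] = real * sin_data[i] + imag * cos_data[i];
    }
}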

@@ -223,34 +212,16 @@ RopeKernel_Avx2_fp32_Impl<true>(
         _mm256_storeu_ps(output + i, y0);
         _mm256_storeu_ps(output + i + 8, y1);
     }
-    if (dim - i != 0) {
-        size_t rem = dim - i;
-        const __m256i mask0 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?8:rem)));
-        const __m256i mask1 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?(rem-8):0)));
-        float32x8_t x0 = _mm256_maskload_ps(input + i, mask0); //Load the first set of data using mask
-        float32x8_t x1 = _mm256_maskload_ps(input + i + 8, mask1); //Load the remainder of data using a second mask
-        //Load imaginary and real values to separate non-interleaved vectors
-        float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000);
-        float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101);
-        __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
-        float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec);
-        float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec);
-        // Use masked loads for sin/cos data to avoid reading beyond buffer bounds
-        size_t cos_sin_rem = rem / 2;
-        const __m256i cos_sin_mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - cos_sin_rem));
-        float32x8_t sin_val = _mm256_maskload_ps(sin_data + i / 2, cos_sin_mask);
-        float32x8_t cos_val = _mm256_maskload_ps(cos_data + i / 2, cos_sin_mask);
-        //Compute Real and Imaginary output values
-        float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
-        float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));
-        //Store back into interleaved format
-        __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
-        float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec);
-        float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec);
-        float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s);
-        float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s);
-        _mm256_maskstore_ps(output + i, mask0, y0);
-        _mm256_maskstore_ps(output + i + 8, mask1, y1);
+
+    // Scalar remainder loop to safely handle trailing elements in pairs
+    for (; i + 1 < dim; i += 2) {
+        size_t cache_idx = i / 2;
+        float input0 = input[i];
+        float input1 = input[i + 1];
+        float sin_val = sin_data[cache_idx];
+        float cos_val = cos_data[cache_idx];
+        output[i] = input0 * cos_val - input1 * sin_val;
+        output[i + 1] = input0 * sin_val + input1 * cos_val;
     }
 }
 
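To illustrate the "final 1-15 elements" claim for this interleaved fp32 path, a small hypothetical driver (names and dim chosen for illustration only): the vectorized loop consumes 16 floats per iteration, so for dim = 22 it stops at i = 16 and the scalar tail finishes the remaining 6.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const size_t dim = 22;  // not a multiple of 16, so a scalar tail remains
    std::vector<float> input(dim), output(dim);
    std::vector<float> sin_data(dim / 2), cos_data(dim / 2);
    for (size_t k = 0; k < dim; ++k) input[k] = static_cast<float>(k);
    for (size_t k = 0; k < dim / 2; ++k) {
        sin_data[k] = std::sin(0.01f * static_cast<float>(k));
        cos_data[k] = std::cos(0.01f * static_cast<float>(k));
    }

    size_t i = (dim / 16) * 16;     // where the 16-wide vector loop would stop
    for (; i + 1 < dim; i += 2) {   // the scalar remainder pattern from the diff
        size_t cache_idx = i / 2;
        float input0 = input[i];
        float input1 = input[i + 1];
        output[i] = input0 * cos_data[cache_idx] - input1 * sin_data[cache_idx];
        output[i + 1] = input0 * sin_data[cache_idx] + input1 * cos_data[cache_idx];
    }
    std::printf("scalar tail handled %zu of %zu elements\n",
                dim - (dim / 16) * 16, dim);
    return 0;
}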
