Skip to content

Commit 8b16e80

Browse files
committed
Add runtime dispatch (mld_polyvecl_pointwise_acc_montgomery_l4/l5/l7_native)
Signed-off-by: willieyz <willie.zhao@chelpis.com>
1 parent: 5d286e4 · commit: 8b16e80

File tree

6 files changed

+91 additions, -46 deletions

dev/aarch64_clean/meta.h

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -170,28 +170,31 @@ static MLD_INLINE int mld_poly_pointwise_montgomery_native(
170170
return MLD_NATIVE_FUNC_SUCCESS;
171171
}
172172

173-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
173+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l4_native(
174174
int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
175175
const int32_t v[4][MLDSA_N])
176176
{
177177
mld_polyvecl_pointwise_acc_montgomery_l4_asm(w, (const int32_t *)u,
178178
(const int32_t *)v);
179+
return MLD_NATIVE_FUNC_SUCCESS;
179180
}
180181

181-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
182+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l5_native(
182183
int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
183184
const int32_t v[5][MLDSA_N])
184185
{
185186
mld_polyvecl_pointwise_acc_montgomery_l5_asm(w, (const int32_t *)u,
186187
(const int32_t *)v);
188+
return MLD_NATIVE_FUNC_SUCCESS;
187189
}
188190

189-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
191+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l7_native(
190192
int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
191193
const int32_t v[7][MLDSA_N])
192194
{
193195
mld_polyvecl_pointwise_acc_montgomery_l7_asm(w, (const int32_t *)u,
194196
(const int32_t *)v);
197+
return MLD_NATIVE_FUNC_SUCCESS;
195198
}
196199

197200
#endif /* !__ASSEMBLER__ */

dev/x86_64/meta.h

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -227,28 +227,43 @@ static MLD_INLINE int mld_poly_pointwise_montgomery_native(
227227
return MLD_NATIVE_FUNC_SUCCESS;
228228
}
229229

230-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
230+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l4_native(
231231
int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
232232
const int32_t v[4][MLDSA_N])
233233
{
234+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
235+
{
236+
return MLD_NATIVE_FUNC_FALLBACK;
237+
}
234238
mld_pointwise_acc_l4_avx2((__m256i *)w, (const __m256i *)u,
235239
(const __m256i *)v, mld_qdata.vec);
240+
return MLD_NATIVE_FUNC_SUCCESS;
236241
}
237242

238-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
243+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l5_native(
239244
int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
240245
const int32_t v[5][MLDSA_N])
241246
{
247+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
248+
{
249+
return MLD_NATIVE_FUNC_FALLBACK;
250+
}
242251
mld_pointwise_acc_l5_avx2((__m256i *)w, (const __m256i *)u,
243252
(const __m256i *)v, mld_qdata.vec);
253+
return MLD_NATIVE_FUNC_SUCCESS;
244254
}
245255

246-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
256+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l7_native(
247257
int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
248258
const int32_t v[7][MLDSA_N])
249259
{
260+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
261+
{
262+
return MLD_NATIVE_FUNC_FALLBACK;
263+
}
250264
mld_pointwise_acc_l7_avx2((__m256i *)w, (const __m256i *)u,
251265
(const __m256i *)v, mld_qdata.vec);
266+
return MLD_NATIVE_FUNC_SUCCESS;
252267
}
253268

254269
#endif /* !__ASSEMBLER__ */

mldsa/src/native/aarch64/meta.h

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -170,28 +170,31 @@ static MLD_INLINE int mld_poly_pointwise_montgomery_native(
170170
return MLD_NATIVE_FUNC_SUCCESS;
171171
}
172172

173-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
173+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l4_native(
174174
int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
175175
const int32_t v[4][MLDSA_N])
176176
{
177177
mld_polyvecl_pointwise_acc_montgomery_l4_asm(w, (const int32_t *)u,
178178
(const int32_t *)v);
179+
return MLD_NATIVE_FUNC_SUCCESS;
179180
}
180181

181-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
182+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l5_native(
182183
int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
183184
const int32_t v[5][MLDSA_N])
184185
{
185186
mld_polyvecl_pointwise_acc_montgomery_l5_asm(w, (const int32_t *)u,
186187
(const int32_t *)v);
188+
return MLD_NATIVE_FUNC_SUCCESS;
187189
}
188190

189-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
191+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l7_native(
190192
int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
191193
const int32_t v[7][MLDSA_N])
192194
{
193195
mld_polyvecl_pointwise_acc_montgomery_l7_asm(w, (const int32_t *)u,
194196
(const int32_t *)v);
197+
return MLD_NATIVE_FUNC_SUCCESS;
195198
}
196199

197200
#endif /* !__ASSEMBLER__ */

mldsa/src/native/api.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -346,7 +346,7 @@ static MLD_INLINE int mld_poly_pointwise_montgomery_native(
346346
* - const int32_t u[MLDSA_L][MLDSA_N]: first input vector
347347
* - const int32_t v[MLDSA_L][MLDSA_N]: second input vector
348348
**************************************************/
349-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
349+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l4_native(
350350
int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
351351
const int32_t v[4][MLDSA_N]);
352352
#endif /* MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 */
@@ -366,7 +366,7 @@ static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
366366
* - const int32_t u[MLDSA_L][MLDSA_N]: first input vector
367367
* - const int32_t v[MLDSA_L][MLDSA_N]: second input vector
368368
**************************************************/
369-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
369+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l5_native(
370370
int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
371371
const int32_t v[5][MLDSA_N]);
372372
#endif /* MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 */
@@ -386,7 +386,7 @@ static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
386386
* - const int32_t u[MLDSA_L][MLDSA_N]: first input vector
387387
* - const int32_t v[MLDSA_L][MLDSA_N]: second input vector
388388
**************************************************/
389-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
389+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l7_native(
390390
int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
391391
const int32_t v[7][MLDSA_N]);
392392
#endif /* MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 */

mldsa/src/native/x86_64/meta.h

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -227,28 +227,43 @@ static MLD_INLINE int mld_poly_pointwise_montgomery_native(
227227
return MLD_NATIVE_FUNC_SUCCESS;
228228
}
229229

230-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
230+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l4_native(
231231
int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
232232
const int32_t v[4][MLDSA_N])
233233
{
234+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
235+
{
236+
return MLD_NATIVE_FUNC_FALLBACK;
237+
}
234238
mld_pointwise_acc_l4_avx2((__m256i *)w, (const __m256i *)u,
235239
(const __m256i *)v, mld_qdata.vec);
240+
return MLD_NATIVE_FUNC_SUCCESS;
236241
}
237242

238-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
243+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l5_native(
239244
int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
240245
const int32_t v[5][MLDSA_N])
241246
{
247+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
248+
{
249+
return MLD_NATIVE_FUNC_FALLBACK;
250+
}
242251
mld_pointwise_acc_l5_avx2((__m256i *)w, (const __m256i *)u,
243252
(const __m256i *)v, mld_qdata.vec);
253+
return MLD_NATIVE_FUNC_SUCCESS;
244254
}
245255

246-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
256+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l7_native(
247257
int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
248258
const int32_t v[7][MLDSA_N])
249259
{
260+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
261+
{
262+
return MLD_NATIVE_FUNC_FALLBACK;
263+
}
250264
mld_pointwise_acc_l7_avx2((__m256i *)w, (const __m256i *)u,
251265
(const __m256i *)v, mld_qdata.vec);
266+
return MLD_NATIVE_FUNC_SUCCESS;
252267
}
253268

254269
#endif /* !__ASSEMBLER__ */

mldsa/src/polyvec.c

Lines changed: 40 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -328,42 +328,57 @@ MLD_INTERNAL_API
328328
void mld_polyvecl_pointwise_acc_montgomery(mld_poly *w, const mld_polyvecl *u,
329329
const mld_polyvecl *v)
330330
{
331-
#if defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4) && \
332-
MLD_CONFIG_PARAMETER_SET == 44
333-
/* TODO: proof */
331+
unsigned int i, j;
334332
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
335333
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
336-
mld_polyvecl_pointwise_acc_montgomery_l4_native(
337-
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
338-
(const int32_t(*)[MLDSA_N])v->vec);
339-
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
334+
#if defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4) && \
335+
MLD_CONFIG_PARAMETER_SET == 44
336+
{
337+
/* TODO: proof */
338+
int ret;
339+
ret = mld_polyvecl_pointwise_acc_montgomery_l4_native(
340+
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
341+
(const int32_t(*)[MLDSA_N])v->vec);
342+
if (ret == MLD_NATIVE_FUNC_SUCCESS)
343+
{
344+
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
345+
return;
346+
}
347+
}
340348
#elif defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5) && \
341349
MLD_CONFIG_PARAMETER_SET == 65
342-
/* TODO: proof */
343-
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
344-
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
345-
mld_polyvecl_pointwise_acc_montgomery_l5_native(
346-
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
347-
(const int32_t(*)[MLDSA_N])v->vec);
348-
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
350+
{
351+
/* TODO: proof */
352+
int ret;
353+
ret = mld_polyvecl_pointwise_acc_montgomery_l5_native(
354+
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
355+
(const int32_t(*)[MLDSA_N])v->vec);
356+
if (ret == MLD_NATIVE_FUNC_SUCCESS)
357+
{
358+
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
359+
return;
360+
}
361+
}
349362
#elif defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7) && \
350363
MLD_CONFIG_PARAMETER_SET == 87
351-
/* TODO: proof */
352-
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
353-
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
354-
mld_polyvecl_pointwise_acc_montgomery_l7_native(
355-
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
356-
(const int32_t(*)[MLDSA_N])v->vec);
357-
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
358-
#else /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \
364+
{
365+
/* TODO: proof */
366+
int ret;
367+
ret = mld_polyvecl_pointwise_acc_montgomery_l7_native(
368+
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
369+
(const int32_t(*)[MLDSA_N])v->vec);
370+
if (ret == MLD_NATIVE_FUNC_SUCCESS)
371+
{
372+
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
373+
return;
374+
}
375+
}
376+
#endif /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \
359377
MLD_CONFIG_PARAMETER_SET == 44) && \
360378
!(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 && \
361379
MLD_CONFIG_PARAMETER_SET == 65) && \
362380
MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 && \
363381
MLD_CONFIG_PARAMETER_SET == 87 */
364-
unsigned int i, j;
365-
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
366-
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
367382
/* The first input is bounded by [0, Q-1] inclusive
368383
* The second input is bounded by [-9Q+1, 9Q-1] inclusive . Hence, we can
369384
* safely accumulate in 64-bits without intermediate reductions as
@@ -398,12 +413,6 @@ void mld_polyvecl_pointwise_acc_montgomery(mld_poly *w, const mld_polyvecl *u,
398413
}
399414

400415
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
401-
#endif /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \
402-
MLD_CONFIG_PARAMETER_SET == 44) && \
403-
!(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 && \
404-
MLD_CONFIG_PARAMETER_SET == 65) && \
405-
!(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 && \
406-
MLD_CONFIG_PARAMETER_SET == 87) */
407416
}
408417

409418
MLD_INTERNAL_API

0 commit comments

Comments
 (0)