Skip to content

Commit c97c822

Browse files
committed
Add runtime dispatch (mld_poly_pointwise_montgomery_native)
Signed-off-by: willieyz <willie.zhao@chelpis.com>
1 parent 73dfe55 commit c97c822

File tree

6 files changed

+29
-15
lines changed

6 files changed

+29
-15
lines changed

dev/aarch64_clean/meta.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,12 @@ static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *buf)
163163
return MLD_NATIVE_FUNC_SUCCESS;
164164
}
165165

166-
static MLD_INLINE void mld_poly_pointwise_montgomery_native(
166+
static MLD_INLINE int mld_poly_pointwise_montgomery_native(
167167
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],
168168
const int32_t in1[MLDSA_N])
169169
{
170170
mld_poly_pointwise_montgomery_asm(out, in0, in1);
171+
return MLD_NATIVE_FUNC_SUCCESS;
171172
}
172173

173174
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(

dev/x86_64/meta.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,16 @@ static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a)
216216
return MLD_NATIVE_FUNC_SUCCESS;
217217
}
218218

219-
static MLD_INLINE void mld_poly_pointwise_montgomery_native(
219+
static MLD_INLINE int mld_poly_pointwise_montgomery_native(
220220
int32_t c[MLDSA_N], const int32_t a[MLDSA_N], const int32_t b[MLDSA_N])
221221
{
222+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
223+
{
224+
return MLD_NATIVE_FUNC_FALLBACK;
225+
}
222226
mld_pointwise_avx2((__m256i *)c, (const __m256i *)a, (const __m256i *)b,
223227
mld_qdata.vec);
228+
return MLD_NATIVE_FUNC_SUCCESS;
224229
}
225230

226231
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(

mldsa/src/native/aarch64/meta.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,12 @@ static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *buf)
163163
return MLD_NATIVE_FUNC_SUCCESS;
164164
}
165165

166-
static MLD_INLINE void mld_poly_pointwise_montgomery_native(
166+
static MLD_INLINE int mld_poly_pointwise_montgomery_native(
167167
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],
168168
const int32_t in1[MLDSA_N])
169169
{
170170
mld_poly_pointwise_montgomery_asm(out, in0, in1);
171+
return MLD_NATIVE_FUNC_SUCCESS;
171172
}
172173

173174
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(

mldsa/src/native/api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a);
320320
* - const int32_t a[MLDSA_N]: first input polynomial
321321
* - const int32_t b[MLDSA_N]: second input polynomial
322322
**************************************************/
323-
static MLD_INLINE void mld_poly_pointwise_montgomery_native(
323+
static MLD_INLINE int mld_poly_pointwise_montgomery_native(
324324
int32_t c[MLDSA_N], const int32_t a[MLDSA_N], const int32_t b[MLDSA_N]);
325325
#endif /* MLD_USE_NATIVE_POINTWISE_MONTGOMERY */
326326

mldsa/src/native/x86_64/meta.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,16 @@ static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a)
216216
return MLD_NATIVE_FUNC_SUCCESS;
217217
}
218218

219-
static MLD_INLINE void mld_poly_pointwise_montgomery_native(
219+
static MLD_INLINE int mld_poly_pointwise_montgomery_native(
220220
int32_t c[MLDSA_N], const int32_t a[MLDSA_N], const int32_t b[MLDSA_N])
221221
{
222+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
223+
{
224+
return MLD_NATIVE_FUNC_FALLBACK;
225+
}
222226
mld_pointwise_avx2((__m256i *)c, (const __m256i *)a, (const __m256i *)b,
223227
mld_qdata.vec);
228+
return MLD_NATIVE_FUNC_SUCCESS;
224229
}
225230

226231
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(

mldsa/src/poly.c

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,18 +184,21 @@ MLD_INTERNAL_API
184184
void mld_poly_pointwise_montgomery(mld_poly *c, const mld_poly *a,
185185
const mld_poly *b)
186186
{
187-
#if defined(MLD_USE_NATIVE_POINTWISE_MONTGOMERY)
188-
/* TODO: proof */
189-
mld_assert_abs_bound(a->coeffs, MLDSA_N, MLD_NTT_BOUND);
190-
mld_assert_abs_bound(b->coeffs, MLDSA_N, MLD_NTT_BOUND);
191-
mld_poly_pointwise_montgomery_native(c->coeffs, a->coeffs, b->coeffs);
192-
mld_assert_abs_bound(c->coeffs, MLDSA_N, MLDSA_Q);
193-
#else /* MLD_USE_NATIVE_POINTWISE_MONTGOMERY */
194187
unsigned int i;
195-
196188
mld_assert_abs_bound(a->coeffs, MLDSA_N, MLD_NTT_BOUND);
197189
mld_assert_abs_bound(b->coeffs, MLDSA_N, MLD_NTT_BOUND);
198-
190+
#if defined(MLD_USE_NATIVE_POINTWISE_MONTGOMERY)
191+
{
192+
/* TODO: proof */
193+
int ret;
194+
ret = mld_poly_pointwise_montgomery_native(c->coeffs, a->coeffs, b->coeffs);
195+
if (ret == MLD_NATIVE_FUNC_SUCCESS)
196+
{
197+
mld_assert_abs_bound(c->coeffs, MLDSA_N, MLDSA_Q);
198+
return;
199+
}
200+
}
201+
#endif /* MLD_USE_NATIVE_POINTWISE_MONTGOMERY */
199202
for (i = 0; i < MLDSA_N; ++i)
200203
__loop__(
201204
invariant(i <= MLDSA_N)
@@ -206,7 +209,6 @@ void mld_poly_pointwise_montgomery(mld_poly *c, const mld_poly *a,
206209
}
207210

208211
mld_assert_abs_bound(c->coeffs, MLDSA_N, MLDSA_Q);
209-
#endif /* !MLD_USE_NATIVE_POINTWISE_MONTGOMERY */
210212
}
211213

212214
MLD_INTERNAL_API

0 commit comments

Comments
 (0)