
Commit 310c223

hanno-becker authored and mkannwischer committed
RV64: Use lazy reduction in invNTT
Previously, the invNTT kept the coefficients in the unsigned canonical range [0, MLKEM_Q). This commit switches to a lazy reduction strategy, following the AArch64 and x86_64 backends: bounds are tracked coarsely and reductions are introduced where necessary. The reduction placement is certainly not optimal, but further improvements yield diminishing returns while complicating review. Debug-only bounds assertions are introduced to check the bounds at runtime.

Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
1 parent 79a6b61 commit 310c223
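
As context for review: the commit message describes the lazy-reduction idea only at a high level. The scalar sketch below is illustrative only (it is not the RVV backend code; the helper names gs_butterfly_lazy and montgomery_reduce and the sample values are made up for this sketch). It shows the core idea: a Gentleman-Sande butterfly keeps the sum unreduced, so its absolute bound roughly doubles per layer, while the Montgomery multiplication pulls the other output back below q.

/*
 * Illustrative scalar sketch of lazy reduction in a Gentleman-Sande
 * butterfly. NOT the RVV backend code from the diff below; helper
 * names and sample values are invented for this sketch.
 */
#include <assert.h>
#include <stdint.h>

#define MLKEM_Q 3329
#define QINV -3327 /* q^-1 mod 2^16 */

/* Signed Montgomery reduction: for |a| < q * 2^15, returns r with
 * r == a * 2^-16 (mod q) and |r| < q. */
static int16_t montgomery_reduce(int32_t a)
{
  int16_t t = (int16_t)((int16_t)a * QINV);
  return (int16_t)((a - (int32_t)t * MLKEM_Q) >> 16);
}

/* Lazy Gentleman-Sande butterfly:
 * - the sum a+b stays unreduced, so its bound grows from max(|a|,|b|)
 *   to |a|+|b| (coarsely, the bound doubles per layer);
 * - the difference is multiplied by the twiddle via Montgomery
 *   multiplication, which brings it back to |.| < q. */
static void gs_butterfly_lazy(int16_t *a, int16_t *b, int16_t zeta)
{
  int16_t t = (int16_t)(*a - *b);
  *a = (int16_t)(*a + *b);                   /* bound doubles */
  *b = montgomery_reduce((int32_t)t * zeta); /* back below q  */
}

int main(void)
{
  int16_t a = 3000, b = -2500; /* |.| < q, as after a full reduction */
  gs_butterfly_lazy(&a, &b, 17 /* placeholder twiddle value */);
  assert(a > -2 * MLKEM_Q && a < 2 * MLKEM_Q); /* lazy: only < 2*q */
  assert(b > -MLKEM_Q && b < MLKEM_Q);
  return 0;
}

In the vectorized invNTT in the diff, the MLK_RV64V_ABS_BOUNDS16 assertions track exactly this coarse doubling layer by layer, and fq_mulq_vx reductions are inserted where a further butterfly would push a bound past the int16_t range.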

File tree

2 files changed: +107 -24 lines


.clang-format

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ IncludeBlocks: Preserve
 # as "attributes" so they don't get increasingly indented line after line
 BreakBeforeBraces: Allman
 InsertBraces: true
-WhitespaceSensitiveMacros: ['__contract__', '__loop__' ]
+WhitespaceSensitiveMacros: ['__contract__', '__loop__', 'MLK_RV64V_ABS_BOUNDS16' ]
 Macros:
 # Make this artifically long to avoid function bodies after short contracts
 - __contract__(x)={ void a; void b; void c; void d; void e; void f; } void abcdefghijklmnopqrstuvw()

mlkem/src/native/riscv64/src/rv64v_poly.c

Lines changed: 106 additions & 23 deletions
@@ -121,9 +121,9 @@ static inline vint16m1_t fq_mulq_vx(vint16m1_t rx, int16_t ry, size_t vl)
 {
   vint16m1_t result;
 
-  result = fq_cadd(fq_mul_vx(rx, ry, vl), vl);
+  result = fq_mul_vx(rx, ry, vl);
 
-  mlk_assert_bound_int16m1(result, vl, 0, MLKEM_Q);
+  mlk_assert_abs_bound_int16m1(result, vl, MLKEM_Q);
   return result;
 }
 
@@ -343,28 +343,18 @@ void mlk_rv64v_poly_ntt(int16_t *r)
 
 /* Reverse / Gentleman-Sande butterfly operation */
 
-#define MLK_RVV_GS_BFLY_RX(u0, u1, ut, uc, vl)     \
-  {                                                \
-    ut = __riscv_vsub_vv_i16m1(u0, u1, vl);        \
-    u0 = __riscv_vadd_vv_i16m1(u0, u1, vl);        \
-    u0 = fq_csub(u0, vl);                          \
-    u1 = fq_mul_vx(ut, uc, vl);                    \
-    u1 = fq_cadd(u1, vl);                          \
-                                                   \
-    mlk_assert_bound_int16m1(u0, vl, 0, MLKEM_Q);  \
-    mlk_assert_bound_int16m1(u1, vl, 0, MLKEM_Q);  \
+#define MLK_RVV_GS_BFLY_RX(u0, u1, ut, uc, vl) \
+  {                                            \
+    ut = __riscv_vsub_vv_i16m1(u0, u1, vl);    \
+    u0 = __riscv_vadd_vv_i16m1(u0, u1, vl);    \
+    u1 = fq_mul_vx(ut, uc, vl);                \
   }
 
-#define MLK_RVV_GS_BFLY_RV(u0, u1, ut, uc, vl)     \
-  {                                                \
-    ut = __riscv_vsub_vv_i16m1(u0, u1, vl);        \
-    u0 = __riscv_vadd_vv_i16m1(u0, u1, vl);        \
-    u0 = fq_csub(u0, vl);                          \
-    u1 = fq_mul_vv(ut, uc, vl);                    \
-    u1 = fq_cadd(u1, vl);                          \
-                                                   \
-    mlk_assert_bound_int16m1(u0, vl, 0, MLKEM_Q);  \
-    mlk_assert_bound_int16m1(u1, vl, 0, MLKEM_Q);  \
+#define MLK_RVV_GS_BFLY_RV(u0, u1, ut, uc, vl) \
+  {                                            \
+    ut = __riscv_vsub_vv_i16m1(u0, u1, vl);    \
+    u0 = __riscv_vadd_vv_i16m1(u0, u1, vl);    \
+    u1 = fq_mul_vv(ut, uc, vl);                \
   }
 
 static vint16m2_t mlk_rv64v_intt2(vint16m2_t vp, vint16m1_t cz)
@@ -395,13 +385,21 @@ static vint16m2_t mlk_rv64v_intt2(vint16m2_t vp, vint16m1_t cz)
   t0 = __riscv_vget_v_i16m2_i16m1(vp, 0);
   t1 = __riscv_vget_v_i16m2_i16m1(vp, 1);
 
-  /* pre-scale and move to positive range [0, q-1] for inverse transform */
+  /* pre-scale */
   t0 = fq_mulq_vx(t0, MLK_RVV_MONT_NR, vl);
   t1 = fq_mulq_vx(t1, MLK_RVV_MONT_NR, vl);
 
+  /* absolute bounds: t0 < q, t1 < q */
+  mlk_assert_abs_bound_int16m1(t0, vl, MLKEM_Q);
+  mlk_assert_abs_bound_int16m1(t1, vl, MLKEM_Q);
+
   c0 = __riscv_vrgather_vv_i16m1(cz, cs2, vl);
   MLK_RVV_GS_BFLY_RV(t0, t1, vt, c0, vl);
 
+  /* absolute bounds: t0 < 2*q, t1 < q */
+  mlk_assert_abs_bound_int16m1(t0, vl, 2 * MLKEM_Q);
+  mlk_assert_abs_bound_int16m1(t1, vl, MLKEM_Q);
+
   /* swap 2 */
   vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1);
   vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p2, vl2);
@@ -410,6 +408,10 @@ static vint16m2_t mlk_rv64v_intt2(vint16m2_t vp, vint16m1_t cz)
   c0 = __riscv_vrgather_vv_i16m1(cz, cs4, vl);
   MLK_RVV_GS_BFLY_RV(t0, t1, vt, c0, vl);
 
+  /* absolute bounds: t0 < 4*q, t1 < q */
+  mlk_assert_abs_bound_int16m1(t0, vl, 4 * MLKEM_Q);
+  mlk_assert_abs_bound_int16m1(t1, vl, MLKEM_Q);
+
   /* swap 4 */
   vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1);
   vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p4, vl2);
@@ -418,13 +420,47 @@ static vint16m2_t mlk_rv64v_intt2(vint16m2_t vp, vint16m1_t cz)
   c0 = __riscv_vrgather_vv_i16m1(cz, cs8, vl);
   MLK_RVV_GS_BFLY_RV(t0, t1, vt, c0, vl);
 
+  /* absolute bounds: t0 < 8*q, t1 < q */
+  mlk_assert_abs_bound_int16m1(t0, vl, 8 * MLKEM_Q);
+  mlk_assert_abs_bound_int16m1(t1, vl, MLKEM_Q);
+
+  t0 = fq_mulq_vx(t0, MLK_RVV_MONT_R1, vl);
+
+  /* absolute bounds: < q */
+  mlk_assert_abs_bound_int16m1(t0, vl, MLKEM_Q);
+  mlk_assert_abs_bound_int16m1(t1, vl, MLKEM_Q);
+
   /* swap 8 */
   vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1);
   vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p8, vl2);
 
   return vp;
 }
 
+#define MLK_RV64V_ABS_BOUNDS16(vl, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, \
+                               vb, vc, vd, ve, vf, b0, b1, b2, b3, b4, b5, b6, \
+                               b7, b8, b9, ba, bb, bc, bd, be, bf)             \
+  do                                                                           \
+  {                                                                            \
+    mlk_assert_abs_bound_int16m1(v0, vl, (b0) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(v1, vl, (b1) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(v2, vl, (b2) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(v3, vl, (b3) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(v4, vl, (b4) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(v5, vl, (b5) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(v6, vl, (b6) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(v7, vl, (b7) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(v8, vl, (b8) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(v9, vl, (b9) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(va, vl, (ba) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(vb, vl, (bb) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(vc, vl, (bc) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(vd, vl, (bd) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(ve, vl, (be) * MLKEM_Q);                      \
+    mlk_assert_abs_bound_int16m1(vf, vl, (bf) * MLKEM_Q);                      \
+  } while (0)
+
+
 /* Only for VLEN=256 for now */
 void mlk_rv64v_poly_invntt_tomont(int16_t *r)
 {
@@ -479,6 +515,11 @@ void mlk_rv64v_poly_invntt_tomont(int16_t *r)
   ve = __riscv_vget_v_i16m2_i16m1(vp, 0);
   vf = __riscv_vget_v_i16m2_i16m1(vp, 1);
 
+  /* absolute bounds < q (see mlk_rv64v_intt2) */
+  MLK_RV64V_ABS_BOUNDS16(vl,
+      v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf,
+      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
+
   MLK_RVV_GS_BFLY_RX(v0, v1, vt, izeta[0x40], vl);
   MLK_RVV_GS_BFLY_RX(v2, v3, vt, izeta[0x41], vl);
   MLK_RVV_GS_BFLY_RX(v4, v5, vt, izeta[0x50], vl);
@@ -488,6 +529,14 @@ void mlk_rv64v_poly_invntt_tomont(int16_t *r)
   MLK_RVV_GS_BFLY_RX(vc, vd, vt, izeta[0x70], vl);
   MLK_RVV_GS_BFLY_RX(ve, vf, vt, izeta[0x71], vl);
 
+  /* absolute bounds:
+   * - v{0,2,4,6,8,a,c,e}: < 2*q
+   * - v{1,3,5,7,9,b,d,f}: < 1*q
+   */
+  MLK_RV64V_ABS_BOUNDS16(vl,
+      v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf,
+      2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1);
+
   MLK_RVV_GS_BFLY_RX(v0, v2, vt, izeta[0x20], vl);
   MLK_RVV_GS_BFLY_RX(v1, v3, vt, izeta[0x20], vl);
   MLK_RVV_GS_BFLY_RX(v4, v6, vt, izeta[0x21], vl);
@@ -497,6 +546,15 @@ void mlk_rv64v_poly_invntt_tomont(int16_t *r)
   MLK_RVV_GS_BFLY_RX(vc, ve, vt, izeta[0x31], vl);
   MLK_RVV_GS_BFLY_RX(vd, vf, vt, izeta[0x31], vl);
 
+  /* absolute bounds:
+   * - v{0,4,8,c}: < 4*q
+   * - v{1,5,9,d}: < 2*q
+   * - v{2,3,6,7,a,b,e,f}: < 1*q
+   */
+  MLK_RV64V_ABS_BOUNDS16(vl,
+      v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf,
+      4, 2, 1, 1, 4, 2, 1, 1, 4, 2, 1, 1, 4, 2, 1, 1);
+
   MLK_RVV_GS_BFLY_RX(v0, v4, vt, izeta[0x10], vl);
   MLK_RVV_GS_BFLY_RX(v1, v5, vt, izeta[0x10], vl);
   MLK_RVV_GS_BFLY_RX(v2, v6, vt, izeta[0x10], vl);
@@ -506,6 +564,25 @@ void mlk_rv64v_poly_invntt_tomont(int16_t *r)
   MLK_RVV_GS_BFLY_RX(va, ve, vt, izeta[0x11], vl);
   MLK_RVV_GS_BFLY_RX(vb, vf, vt, izeta[0x11], vl);
 
+  /* absolute bounds:
+   * - v{0,8}: < 8*q
+   * - v{1,9}: < 4*q
+   * - v{2,3,a,b}: < 2*q
+   * - v{4,5,6,7,c,d,e,f}: < 1*q
+   */
+  MLK_RV64V_ABS_BOUNDS16(vl,
+      v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf,
+      8, 4, 2, 2, 1, 1, 1, 1, 8, 4, 2, 2, 1, 1, 1, 1);
+
+  /* Reduce v0, v8 to avoid overflow */
+  v0 = fq_mulq_vx(v0, MLK_RVV_MONT_R1, vl);
+  v8 = fq_mulq_vx(v8, MLK_RVV_MONT_R1, vl);
+
+  /* absolute bounds: < 4*q */
+  MLK_RV64V_ABS_BOUNDS16(vl,
+      v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf,
+      4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4);
+
   MLK_RVV_GS_BFLY_RX(v0, v8, vt, izeta[0x01], vl);
   MLK_RVV_GS_BFLY_RX(v1, v9, vt, izeta[0x01], vl);
   MLK_RVV_GS_BFLY_RX(v2, va, vt, izeta[0x01], vl);
@@ -515,6 +592,11 @@ void mlk_rv64v_poly_invntt_tomont(int16_t *r)
   MLK_RVV_GS_BFLY_RX(v6, ve, vt, izeta[0x01], vl);
   MLK_RVV_GS_BFLY_RX(v7, vf, vt, izeta[0x01], vl);
 
+  /* absolute bounds: < 8*q */
+  MLK_RV64V_ABS_BOUNDS16(vl,
+      v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf,
+      8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8);
+
   __riscv_vse16_v_i16m1(&r[0x00], v0, vl);
   __riscv_vse16_v_i16m1(&r[0x10], v1, vl);
   __riscv_vse16_v_i16m1(&r[0x20], v2, vl);
@@ -708,3 +790,4 @@ MLK_EMPTY_CU(rv64v_poly)
 #undef MLK_RVV_CT_BFLY_FV
 #undef MLK_RVV_GS_BFLY_RX
 #undef MLK_RVV_GS_BFLY_RV
+#undef MLK_RV64V_ABS_BOUNDS16
