5959 * Implements @[FIPS203, Algorithm 13 (K-PKE.KeyGen), L19]
6060 *
6161 **************************************************/
62- static void mlk_pack_pk (uint8_t r [MLKEM_INDCPA_PUBLICKEYBYTES ],
63- const mlk_polyvec * pk ,
62+ static void mlk_pack_pk (uint8_t r [MLKEM_INDCPA_PUBLICKEYBYTES ], mlk_polyvec pk ,
6463 const uint8_t seed [MLKEM_SYMBYTES ])
6564{
6665 mlk_assert_bound_2d (pk , MLKEM_K , MLKEM_N , 0 , MLKEM_Q );
@@ -84,7 +83,7 @@ static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES],
8483 * Implements @[FIPS203, Algorithm 14 (K-PKE.Encrypt), L2-3]
8584 *
8685 **************************************************/
87- static void mlk_unpack_pk (mlk_polyvec * pk , uint8_t seed [MLKEM_SYMBYTES ],
86+ static void mlk_unpack_pk (mlk_polyvec pk , uint8_t seed [MLKEM_SYMBYTES ],
8887 const uint8_t packedpk [MLKEM_INDCPA_PUBLICKEYBYTES ])
8988{
9089 mlk_polyvec_frombytes (pk , packedpk );
@@ -109,8 +108,7 @@ static void mlk_unpack_pk(mlk_polyvec *pk, uint8_t seed[MLKEM_SYMBYTES],
109108 * Implements @[FIPS203, Algorithm 13 (K-PKE.KeyGen), L20]
110109 *
111110 **************************************************/
112- static void mlk_pack_sk (uint8_t r [MLKEM_INDCPA_SECRETKEYBYTES ],
113- const mlk_polyvec * sk )
111+ static void mlk_pack_sk (uint8_t r [MLKEM_INDCPA_SECRETKEYBYTES ], mlk_polyvec sk )
114112{
115113 mlk_assert_bound_2d (sk , MLKEM_K , MLKEM_N , 0 , MLKEM_Q );
116114 mlk_polyvec_tobytes (r , sk );
@@ -130,7 +128,7 @@ static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES],
130128 * Implements @[FIPS203, Algorithm 15 (K-PKE.Decrypt), L5]
131129 *
132130 **************************************************/
133- static void mlk_unpack_sk (mlk_polyvec * sk ,
131+ static void mlk_unpack_sk (mlk_polyvec sk ,
134132 const uint8_t packedsk [MLKEM_INDCPA_SECRETKEYBYTES ])
135133{
136134 mlk_polyvec_frombytes (sk , packedsk );
@@ -151,8 +149,8 @@ static void mlk_unpack_sk(mlk_polyvec *sk,
151149 * Implements @[FIPS203, Algorithm 14 (K-PKE.Encrypt), L22-23]
152150 *
153151 **************************************************/
154- static void mlk_pack_ciphertext (uint8_t r [MLKEM_INDCPA_BYTES ],
155- const mlk_polyvec * b , mlk_poly * v )
152+ static void mlk_pack_ciphertext (uint8_t r [MLKEM_INDCPA_BYTES ], mlk_polyvec b ,
153+ mlk_poly * v )
156154{
157155 mlk_polyvec_compress_du (r , b );
158156 mlk_poly_compress_dv (r + MLKEM_POLYVECCOMPRESSEDBYTES_DU , v );
@@ -172,7 +170,7 @@ static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES],
172170 * Implements @[FIPS203, Algorithm 15 (K-PKE.Decrypt), L1-4]
173171 *
174172 **************************************************/
175- static void mlk_unpack_ciphertext (mlk_polyvec * b , mlk_poly * v ,
173+ static void mlk_unpack_ciphertext (mlk_polyvec b , mlk_poly * v ,
176174 const uint8_t c [MLKEM_INDCPA_BYTES ])
177175{
178176 mlk_polyvec_decompress_du (b , c );
@@ -203,7 +201,7 @@ __contract__(
203201 *
204202 * Not static for benchmarking */
205203MLK_INTERNAL_API
206- void mlk_gen_matrix (mlk_polymat * a , const uint8_t seed [MLKEM_SYMBYTES ],
204+ void mlk_gen_matrix (mlk_polymat a , const uint8_t seed [MLKEM_SYMBYTES ],
207205 int transposed )
208206{
209207 unsigned i , j ;
@@ -240,11 +238,7 @@ void mlk_gen_matrix(mlk_polymat *a, const uint8_t seed[MLKEM_SYMBYTES],
240238 }
241239 }
242240
243- mlk_poly_rej_uniform_x4 (& a -> vec [i / MLKEM_K ].vec [i % MLKEM_K ],
244- & a -> vec [(i + 1 ) / MLKEM_K ].vec [(i + 1 ) % MLKEM_K ],
245- & a -> vec [(i + 2 ) / MLKEM_K ].vec [(i + 2 ) % MLKEM_K ],
246- & a -> vec [(i + 3 ) / MLKEM_K ].vec [(i + 3 ) % MLKEM_K ],
247- seed_ext );
241+ mlk_poly_rej_uniform_x4 (& a [i ], & a [i + 1 ], & a [i + 2 ], & a [i + 3 ], seed_ext );
248242 }
249243
250244 /* For MLKEM_K == 3, sample the last entry individually. */
@@ -265,7 +259,7 @@ void mlk_gen_matrix(mlk_polymat *a, const uint8_t seed[MLKEM_SYMBYTES],
265259 seed_ext [0 ][MLKEM_SYMBYTES + 1 ] = x ;
266260 }
267261
268- mlk_poly_rej_uniform (& a -> vec [ i / MLKEM_K ]. vec [ i % MLKEM_K ], seed_ext [0 ]);
262+ mlk_poly_rej_uniform (& a [ i ], seed_ext [0 ]);
269263 i ++ ;
270264 }
271265
@@ -277,8 +271,7 @@ void mlk_gen_matrix(mlk_polymat *a, const uint8_t seed[MLKEM_SYMBYTES],
277271 */
278272 for (i = 0 ; i < MLKEM_K * MLKEM_K ; i ++ )
279273 {
280- mlk_poly_permute_bitrev_to_custom (
281- a -> vec [i / MLKEM_K ].vec [i % MLKEM_K ].coeffs );
274+ mlk_poly_permute_bitrev_to_custom (a [i ].coeffs );
282275 }
283276
284277 /* Specification: Partially implements
@@ -305,34 +298,48 @@ void mlk_gen_matrix(mlk_polymat *a, const uint8_t seed[MLKEM_SYMBYTES],
305298 * Specification: Implements @[FIPS203, Section 2.4.7, Eq (2.12), (2.13)]
306299 *
307300 **************************************************/
308- static void mlk_matvec_mul (mlk_polyvec * out , const mlk_polymat * a ,
309- const mlk_polyvec * v , const mlk_polyvec_mulcache * vc )
301+ static void mlk_matvec_mul (mlk_polyvec out , const mlk_polymat a ,
302+ const mlk_polyvec v , const mlk_polyvec_mulcache vc )
310303__contract__ (
311304 requires (memory_no_alias (out , sizeof (mlk_polyvec )))
312305 requires (memory_no_alias (a , sizeof (mlk_polymat )))
313306 requires (memory_no_alias (v , sizeof (mlk_polyvec )))
314307 requires (memory_no_alias (vc , sizeof (mlk_polyvec_mulcache )))
315- requires (forall (k0 , 0 , MLKEM_K ,
316- forall (k1 , 0 , MLKEM_K ,
317- array_bound (a - > vec [k0 ].vec [k1 ].coeffs , 0 , MLKEM_N , 0 , MLKEM_UINT12_LIMIT ))))
308+ requires (forall (k0 , 0 , MLKEM_K * MLKEM_K ,
309+ array_bound (a [k0 ].coeffs , 0 , MLKEM_N , 0 , MLKEM_UINT12_LIMIT )))
318310 requires (forall (k1 , 0 , MLKEM_K ,
319- array_abs_bound (v - > vec [k1 ].coeffs , 0 , MLKEM_N , MLK_NTT_BOUND )))
311+ array_abs_bound (v [k1 ].coeffs , 0 , MLKEM_N , MLK_NTT_BOUND )))
320312 requires (forall (k2 , 0 , MLKEM_K ,
321- array_abs_bound (vc - > vec [k2 ].coeffs , 0 , MLKEM_N /2 , MLKEM_Q )))
322- assigns (memory_slice (out , sizeof ( mlk_polyvec ) ))
313+ array_abs_bound (vc [k2 ].coeffs , 0 , MLKEM_N /2 , MLKEM_Q )))
314+ assigns (object_whole (out ))
323315 ensures (forall (k3 , 0 , MLKEM_K ,
324- array_abs_bound (out - > vec [k3 ].coeffs , 0 , MLKEM_N , INT16_MAX /2 ))))
316+ array_abs_bound (out [k3 ].coeffs , 0 , MLKEM_N , INT16_MAX /2 ))))
325317{
326- unsigned i ;
327- for (i = 0 ; i < MLKEM_K ; i ++ )
328- __loop__ (
329- assigns (i , memory_slice (out , sizeof (mlk_polyvec )))
330- invariant (i <= MLKEM_K )
331- invariant (forall (k , 0 , i ,
332- array_abs_bound (out -> vec [k ].coeffs , 0 , MLKEM_N , INT16_MAX /2 ))))
333- {
334- mlk_polyvec_basemul_acc_montgomery_cached (& out -> vec [i ], & a -> vec [i ], v , vc );
335- }
318+ /* Temporary on the "refine-bounds" branch - unroll to a simple
319+ * sequence of calls for each possible value of MLKEM_K to
320+ * simplify proof.
321+ */
322+ mlk_polyvec_basemul_acc_montgomery_cached (& out [0 ], & a [0 ], v , vc );
323+ mlk_polyvec_basemul_acc_montgomery_cached (& out [1 ], & a [MLKEM_K ], v , vc );
324+
325+ #if MLKEM_K == 3
326+ mlk_polyvec_basemul_acc_montgomery_cached (& out [2 ], & a [MLKEM_K * 2 ], v , vc );
327+ #elif MLKEM_K == 4
328+ mlk_polyvec_basemul_acc_montgomery_cached (& out [2 ], & a [MLKEM_K * 2 ], v , vc );
329+ mlk_polyvec_basemul_acc_montgomery_cached (& out [3 ], & a [MLKEM_K * 3 ], v , vc );
330+ #endif
331+
332+ // unsigned i;
333+ // for (i = 0; i < MLKEM_K; i++)
334+ // __loop__(
335+ // assigns(i, object_whole(out))
336+ // invariant(i <= MLKEM_K)
337+ // invariant(forall(k, 0, i,
338+ // array_abs_bound(out[k].coeffs, 0, MLKEM_N, INT16_MAX/2))))
339+ // {
340+ // mlk_polyvec_basemul_acc_montgomery_cached(&out[i], &a[MLKEM_K * i], v,
341+ // vc);
342+ // }
336343}
337344
338345/* Reference: `indcpa_keypair_derand()` in the reference implementation @[REF].
@@ -370,49 +377,47 @@ void mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
370377 */
371378 MLK_CT_TESTING_DECLASSIFY (publicseed , MLKEM_SYMBYTES );
372379
373- mlk_gen_matrix (& a , publicseed , 0 /* no transpose */ );
380+ mlk_gen_matrix (a , publicseed , 0 /* no transpose */ );
374381
375382#if MLKEM_K == 2
376- mlk_poly_getnoise_eta1_4x (& skpv . vec [0 ], & skpv . vec [1 ], & e . vec [0 ], & e . vec [1 ],
377- noiseseed , 0 , 1 , 2 , 3 );
383+ mlk_poly_getnoise_eta1_4x (& skpv [0 ], & skpv [1 ], & e [0 ], & e [1 ], noiseseed , 0 , 1 ,
384+ 2 , 3 );
378385#elif MLKEM_K == 3
379386 /*
380387 * Only the first three output buffers are needed.
381388 * The laster parameter is a dummy that's overwritten later.
382389 */
383- mlk_poly_getnoise_eta1_4x (& skpv . vec [0 ], & skpv . vec [1 ], & skpv . vec [2 ],
384- & pkpv . vec [0 ] /* irrelevant */ , noiseseed , 0 , 1 , 2 ,
390+ mlk_poly_getnoise_eta1_4x (& skpv [0 ], & skpv [1 ], & skpv [2 ],
391+ & pkpv [0 ] /* irrelevant */ , noiseseed , 0 , 1 , 2 ,
385392 0xFF /* irrelevant */ );
386393 /* Same here */
387- mlk_poly_getnoise_eta1_4x (& e .vec [0 ], & e .vec [1 ], & e .vec [2 ],
388- & pkpv .vec [0 ] /* irrelevant */ , noiseseed , 3 , 4 , 5 ,
389- 0xFF /* irrelevant */ );
394+ mlk_poly_getnoise_eta1_4x (& e [0 ], & e [1 ], & e [2 ], & pkpv [0 ] /* irrelevant */ ,
395+ noiseseed , 3 , 4 , 5 , 0xFF /* irrelevant */ );
390396#elif MLKEM_K == 4
391- mlk_poly_getnoise_eta1_4x (& skpv .vec [0 ], & skpv .vec [1 ], & skpv .vec [2 ],
392- & skpv .vec [3 ], noiseseed , 0 , 1 , 2 , 3 );
393- mlk_poly_getnoise_eta1_4x (& e .vec [0 ], & e .vec [1 ], & e .vec [2 ], & e .vec [3 ],
394- noiseseed , 4 , 5 , 6 , 7 );
395- #endif /* MLKEM_K == 4 */
397+ mlk_poly_getnoise_eta1_4x (& skpv [0 ], & skpv [1 ], & skpv [2 ], & skpv [3 ], noiseseed ,
398+ 0 , 1 , 2 , 3 );
399+ mlk_poly_getnoise_eta1_4x (& e [0 ], & e [1 ], & e [2 ], & e [3 ], noiseseed , 4 , 5 , 6 , 7 );
400+ #endif
396401
397- mlk_polyvec_ntt (& skpv );
398- mlk_polyvec_ntt (& e );
402+ mlk_polyvec_ntt (skpv );
403+ mlk_polyvec_ntt (e );
399404
400- mlk_polyvec_mulcache_compute (& skpv_cache , & skpv );
401- mlk_matvec_mul (& pkpv , & a , & skpv , & skpv_cache );
402- mlk_polyvec_tomont (& pkpv );
405+ mlk_polyvec_mulcache_compute (skpv_cache , skpv );
406+ mlk_matvec_mul (pkpv , a , skpv , skpv_cache );
407+ mlk_polyvec_tomont (pkpv );
403408
404- mlk_polyvec_add (& pkpv , & e );
405- mlk_polyvec_reduce (& pkpv );
406- mlk_polyvec_reduce (& skpv );
409+ mlk_polyvec_add (pkpv , e );
410+ mlk_polyvec_reduce (pkpv );
411+ mlk_polyvec_reduce (skpv );
407412
408- mlk_pack_sk (sk , & skpv );
409- mlk_pack_pk (pk , & pkpv , publicseed );
413+ mlk_pack_sk (sk , skpv );
414+ mlk_pack_pk (pk , pkpv , publicseed );
410415
411416 /* Specification: Partially implements
412417 * @[FIPS203, Section 3.3, Destruction of intermediate values] */
413418 mlk_zeroize (buf , sizeof (buf ));
414419 mlk_zeroize (coins_with_domain_separator , sizeof (coins_with_domain_separator ));
415- mlk_zeroize (& a , sizeof (a ));
420+ mlk_zeroize (a , sizeof (a ));
416421 mlk_zeroize (& e , sizeof (e ));
417422 mlk_zeroize (& skpv , sizeof (skpv ));
418423 mlk_zeroize (& skpv_cache , sizeof (skpv_cache ));
@@ -438,7 +443,7 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
438443 mlk_poly v , k , epp ;
439444 mlk_polyvec_mulcache sp_cache ;
440445
441- mlk_unpack_pk (& pkpv , seed , pk );
446+ mlk_unpack_pk (pkpv , seed , pk );
442447 mlk_poly_frommsg (& k , m );
443448
444449 /*
@@ -449,47 +454,44 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
449454 */
450455 MLK_CT_TESTING_DECLASSIFY (seed , MLKEM_SYMBYTES );
451456
452- mlk_gen_matrix (& at , seed , 1 /* transpose */ );
457+ mlk_gen_matrix (at , seed , 1 /* transpose */ );
453458
454459#if MLKEM_K == 2
455- mlk_poly_getnoise_eta1122_4x (& sp . vec [0 ], & sp . vec [1 ], & ep . vec [0 ], & ep . vec [1 ],
456- coins , 0 , 1 , 2 , 3 );
460+ mlk_poly_getnoise_eta1122_4x (& sp [0 ], & sp [1 ], & ep [0 ], & ep [1 ], coins , 0 , 1 , 2 ,
461+ 3 );
457462 mlk_poly_getnoise_eta2 (& epp , coins , 4 );
458463#elif MLKEM_K == 3
459464 /*
460465 * In this call, only the first three output buffers are needed.
461466 * The last parameter is a dummy that's overwritten later.
462467 */
463- mlk_poly_getnoise_eta1_4x (& sp . vec [0 ], & sp . vec [1 ], & sp . vec [2 ], & b . vec [0 ],
464- coins , 0 , 1 , 2 , 0xFF );
468+ mlk_poly_getnoise_eta1_4x (& sp [0 ], & sp [1 ], & sp [2 ], & b [0 ], coins , 0 , 1 , 2 ,
469+ 0xFF );
465470 /* The fourth output buffer in this call _is_ used. */
466- mlk_poly_getnoise_eta2_4x (& ep .vec [0 ], & ep .vec [1 ], & ep .vec [2 ], & epp , coins , 3 ,
467- 4 , 5 , 6 );
471+ mlk_poly_getnoise_eta2_4x (& ep [0 ], & ep [1 ], & ep [2 ], & epp , coins , 3 , 4 , 5 , 6 );
468472#elif MLKEM_K == 4
469- mlk_poly_getnoise_eta1_4x (& sp .vec [0 ], & sp .vec [1 ], & sp .vec [2 ], & sp .vec [3 ],
470- coins , 0 , 1 , 2 , 3 );
471- mlk_poly_getnoise_eta2_4x (& ep .vec [0 ], & ep .vec [1 ], & ep .vec [2 ], & ep .vec [3 ],
472- coins , 4 , 5 , 6 , 7 );
473+ mlk_poly_getnoise_eta1_4x (& sp [0 ], & sp [1 ], & sp [2 ], & sp [3 ], coins , 0 , 1 , 2 , 3 );
474+ mlk_poly_getnoise_eta2_4x (& ep [0 ], & ep [1 ], & ep [2 ], & ep [3 ], coins , 4 , 5 , 6 , 7 );
473475 mlk_poly_getnoise_eta2 (& epp , coins , 8 );
474- #endif /* MLKEM_K == 4 */
476+ #endif
475477
476- mlk_polyvec_ntt (& sp );
478+ mlk_polyvec_ntt (sp );
477479
478- mlk_polyvec_mulcache_compute (& sp_cache , & sp );
479- mlk_matvec_mul (& b , & at , & sp , & sp_cache );
480- mlk_polyvec_basemul_acc_montgomery_cached (& v , & pkpv , & sp , & sp_cache );
480+ mlk_polyvec_mulcache_compute (sp_cache , sp );
481+ mlk_matvec_mul (b , at , sp , sp_cache );
482+ mlk_polyvec_basemul_acc_montgomery_cached (& v , pkpv , sp , sp_cache );
481483
482- mlk_polyvec_invntt_tomont (& b );
484+ mlk_polyvec_invntt_tomont (b );
483485 mlk_poly_invntt_tomont (& v );
484486
485- mlk_polyvec_add (& b , & ep );
487+ mlk_polyvec_add (b , ep );
486488 mlk_poly_add (& v , & epp );
487489 mlk_poly_add (& v , & k );
488490
489- mlk_polyvec_reduce (& b );
491+ mlk_polyvec_reduce (b );
490492 mlk_poly_reduce (& v );
491493
492- mlk_pack_ciphertext (c , & b , & v );
494+ mlk_pack_ciphertext (c , b , & v );
493495
494496 /* Specification: Partially implements
495497 * @[FIPS203, Section 3.3, Destruction of intermediate values] */
@@ -498,7 +500,7 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
498500 mlk_zeroize (& sp_cache , sizeof (sp_cache ));
499501 mlk_zeroize (& b , sizeof (b ));
500502 mlk_zeroize (& v , sizeof (v ));
501- mlk_zeroize (& at , sizeof (at ));
503+ mlk_zeroize (at , sizeof (at ));
502504 mlk_zeroize (& k , sizeof (k ));
503505 mlk_zeroize (& ep , sizeof (ep ));
504506 mlk_zeroize (& epp , sizeof (epp ));
@@ -516,12 +518,12 @@ void mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
516518 mlk_poly v , sb ;
517519 mlk_polyvec_mulcache b_cache ;
518520
519- mlk_unpack_ciphertext (& b , & v , c );
520- mlk_unpack_sk (& skpv , sk );
521+ mlk_unpack_ciphertext (b , & v , c );
522+ mlk_unpack_sk (skpv , sk );
521523
522- mlk_polyvec_ntt (& b );
523- mlk_polyvec_mulcache_compute (& b_cache , & b );
524- mlk_polyvec_basemul_acc_montgomery_cached (& sb , & skpv , & b , & b_cache );
524+ mlk_polyvec_ntt (b );
525+ mlk_polyvec_mulcache_compute (b_cache , b );
526+ mlk_polyvec_basemul_acc_montgomery_cached (& sb , skpv , b , b_cache );
525527 mlk_poly_invntt_tomont (& sb );
526528
527529 mlk_poly_sub (& v , & sb );
0 commit comments