Skip to content

Commit b248d02

Browse files
jakemasmkannwischer
authored andcommitted
C: Public Key from Secret Key function
Refactor keygen to use a new function that derives t0 t1 tr pk from rho s1 s2 so that this function can also be called by a utility function pk_to_sk that generates the pk given the sk. We also include ct_memcmp for constant time comparison. Signed-off-by: Jake Massimo <jakemas@amazon.com>
1 parent 62c7a8d commit b248d02

File tree

18 files changed

+667
-42
lines changed

18 files changed

+667
-42
lines changed

mldsa/mldsa_native.S

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@
321321
#undef crypto_sign_keypair
322322
#undef crypto_sign_keypair_internal
323323
#undef crypto_sign_open
324+
#undef crypto_sign_pk_from_sk
324325
#undef crypto_sign_signature
325326
#undef crypto_sign_signature_extmu
326327
#undef crypto_sign_signature_internal

mldsa/mldsa_native.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@
318318
#undef crypto_sign_keypair
319319
#undef crypto_sign_keypair_internal
320320
#undef crypto_sign_open
321+
#undef crypto_sign_pk_from_sk
321322
#undef crypto_sign_signature
322323
#undef crypto_sign_signature_extmu
323324
#undef crypto_sign_signature_internal

mldsa/mldsa_native.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,25 @@ size_t MLD_API_NAMESPACE(prepare_domain_separation_prefix)(
615615
uint8_t prefix[MLD_DOMAIN_SEPARATION_MAX_BYTES], const uint8_t *ph,
616616
size_t phlen, const uint8_t *ctx, size_t ctxlen, int hashalg);
617617

618+
/*************************************************
619+
* Name: crypto_sign_pk_from_sk
620+
*
621+
* Description: Derives public key from secret key with validation.
622+
* Checks that t0 and tr stored in sk match recomputed values.
623+
*
624+
* Arguments: - uint8_t pk[CRYPTO_PUBLICKEYBYTES]: output public key
625+
* - const uint8_t sk[CRYPTO_SECRETKEYBYTES]: input secret key
626+
*
627+
* Returns 0 on success, -1 if validation fails (invalid secret key)
628+
*
629+
* Note: This function leaks whether the secret key is valid or invalid
630+
* through its return value and timing.
631+
**************************************************/
632+
MLD_API_MUST_CHECK_RETURN_VALUE
633+
int MLD_API_NAMESPACE(pk_from_sk)(
634+
uint8_t pk[MLDSA_PUBLICKEYBYTES(MLD_CONFIG_API_PARAMETER_SET)],
635+
const uint8_t sk[MLDSA_SECRETKEYBYTES(MLD_CONFIG_API_PARAMETER_SET)]);
636+
618637
/****************************** SUPERCOP API *********************************/
619638

620639
#if !defined(MLD_CONFIG_API_NO_SUPERCOP)

mldsa/src/ct.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,19 @@ __contract__(ensures(return_value == 0)) { return (int64_t)mld_ct_get_optblocker
9292
static MLD_INLINE uint32_t mld_ct_get_optblocker_u32(void)
9393
__contract__(ensures(return_value == 0)) { return (uint32_t)mld_ct_get_optblocker_u64(); }
9494

95+
static MLD_INLINE uint8_t mld_ct_get_optblocker_u8(void)
96+
__contract__(ensures(return_value == 0)) { return (uint8_t)mld_ct_get_optblocker_u64(); }
97+
9598
/* Opt-blocker based implementation of value barriers */
9699
static MLD_INLINE int64_t mld_value_barrier_i64(int64_t b)
97100
__contract__(ensures(return_value == b)) { return (b ^ mld_ct_get_optblocker_i64()); }
98101

99102
static MLD_INLINE uint32_t mld_value_barrier_u32(uint32_t b)
100103
__contract__(ensures(return_value == b)) { return (b ^ mld_ct_get_optblocker_u32()); }
101104

105+
static MLD_INLINE uint8_t mld_value_barrier_u8(uint8_t b)
106+
__contract__(ensures(return_value == b)) { return (b ^ mld_ct_get_optblocker_u8()); }
107+
102108

103109
#else /* !MLD_USE_ASM_VALUE_BARRIER */
104110
static MLD_INLINE int64_t mld_value_barrier_i64(int64_t b)
@@ -114,6 +120,13 @@ __contract__(ensures(return_value == b))
114120
__asm__("" : "+r"(b));
115121
return b;
116122
}
123+
124+
static MLD_INLINE uint8_t mld_value_barrier_u8(uint8_t b)
125+
__contract__(ensures(return_value == b))
126+
{
127+
__asm__("" : "+r"(b));
128+
return b;
129+
}
117130
#endif /* MLD_USE_ASM_VALUE_BARRIER */
118131

119132
#ifdef CBMC
@@ -240,6 +253,48 @@ __contract__(
240253
#if !defined(__ASSEMBLER__)
241254
#include <string.h>
242255

256+
/*************************************************
257+
* Name: mld_ct_memcmp
258+
*
259+
* Description: Compare two arrays for equality in constant time.
260+
*
261+
* Arguments: const void *a: pointer to first byte array
262+
* const void *b: pointer to second byte array
263+
* size_t len: length of the byte arrays
264+
*
265+
* Returns 0 if the byte arrays are equal, a non-zero value otherwise
266+
**************************************************/
267+
static MLD_INLINE uint8_t mld_ct_memcmp(const void *a, const void *b,
268+
const size_t len)
269+
__contract__(
270+
requires(len <= UINT16_MAX)
271+
requires(memory_no_alias(a, len))
272+
requires(memory_no_alias(b, len))
273+
ensures((return_value == 0) == forall(i, 0, len, (((const uint8_t *)a)[i] == ((const uint8_t *)b)[i])))
274+
)
275+
{
276+
const uint8_t *a_bytes = (const uint8_t *)a;
277+
const uint8_t *b_bytes = (const uint8_t *)b;
278+
uint8_t r = 0, s = 0;
279+
unsigned i;
280+
281+
for (i = 0; i < len; i++)
282+
__loop__(
283+
invariant(i <= len)
284+
invariant((r == 0) == (forall(k, 0, i, (a_bytes[k] == b_bytes[k])))))
285+
{
286+
r |= a_bytes[i] ^ b_bytes[i];
287+
/* s is useless, but prevents the loop from being aborted once r=0xff. */
288+
s ^= a_bytes[i] ^ b_bytes[i];
289+
}
290+
291+
/*
292+
* XOR twice with s, separated by a value barrier, to prevent the compile
293+
* from dropping the s computation in the loop.
294+
*/
295+
return (uint8_t)((mld_value_barrier_u32((uint32_t)r) ^ s) ^ s);
296+
}
297+
243298
/*************************************************
244299
* Name: mld_zeroize
245300
*

mldsa/src/sign.c

Lines changed: 137 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <string.h>
2929

3030
#include "cbmc.h"
31+
#include "ct.h"
3132
#include "debug.h"
3233
#include "packing.h"
3334
#include "poly.h"
@@ -48,6 +49,8 @@
4849
#define mld_H MLD_ADD_PARAM_SET(mld_H)
4950
#define mld_attempt_signature_generation \
5051
MLD_ADD_PARAM_SET(mld_attempt_signature_generation)
52+
#define mld_compute_t0_t1_tr_from_sk_components \
53+
MLD_ADD_PARAM_SET(mld_compute_t0_t1_tr_from_sk_components)
5154
/* End of parameter set namespacing */
5255

5356

@@ -174,6 +177,85 @@ __contract__(
174177
#endif /* !MLD_CONFIG_SERIAL_FIPS202_ONLY */
175178
}
176179

180+
/*************************************************
181+
* Name: mld_compute_t0_t1_tr_from_sk_components
182+
*
183+
* Description: Computes t0, t1, tr, and pk from secret key components
184+
* rho, s1, s2. This is the shared computation used by
185+
* both keygen and generating the public key from the
186+
* secret key.
187+
*
188+
* Arguments: - mld_polyveck *t0: output t0
189+
* - mld_polyveck *t1: output t1
190+
* - uint8_t tr[MLDSA_TRBYTES]: output tr
191+
* - uint8_t pk[CRYPTO_PUBLICKEYBYTES]: output public key
192+
* - const uint8_t rho[MLDSA_SEEDBYTES]: input rho
193+
* - const mld_polyvecl *s1: input s1
194+
* - const mld_polyveck *s2: input s2
195+
**************************************************/
196+
static void mld_compute_t0_t1_tr_from_sk_components(
197+
mld_polyveck *t0, mld_polyveck *t1, uint8_t tr[MLDSA_TRBYTES],
198+
uint8_t pk[CRYPTO_PUBLICKEYBYTES], const uint8_t rho[MLDSA_SEEDBYTES],
199+
const mld_polyvecl *s1, const mld_polyveck *s2)
200+
__contract__(
201+
requires(memory_no_alias(t0, sizeof(mld_polyveck)))
202+
requires(memory_no_alias(t1, sizeof(mld_polyveck)))
203+
requires(memory_no_alias(tr, MLDSA_TRBYTES))
204+
requires(memory_no_alias(pk, CRYPTO_PUBLICKEYBYTES))
205+
requires(memory_no_alias(rho, MLDSA_SEEDBYTES))
206+
requires(memory_no_alias(s1, sizeof(mld_polyvecl)))
207+
requires(memory_no_alias(s2, sizeof(mld_polyveck)))
208+
requires(forall(l0, 0, MLDSA_L, array_bound(s1->vec[l0].coeffs, 0, MLDSA_N, MLD_POLYETA_UNPACK_LOWER_BOUND, MLDSA_ETA + 1)))
209+
requires(forall(k0, 0, MLDSA_K, array_bound(s2->vec[k0].coeffs, 0, MLDSA_N, MLD_POLYETA_UNPACK_LOWER_BOUND, MLDSA_ETA + 1)))
210+
assigns(memory_slice(t0, sizeof(mld_polyveck)))
211+
assigns(memory_slice(t1, sizeof(mld_polyveck)))
212+
assigns(memory_slice(tr, MLDSA_TRBYTES))
213+
assigns(memory_slice(pk, CRYPTO_PUBLICKEYBYTES))
214+
ensures(forall(k1, 0, MLDSA_K, array_bound(t0->vec[k1].coeffs, 0, MLDSA_N, -(1<<(MLDSA_D-1)) + 1, (1<<(MLDSA_D-1)) + 1)))
215+
ensures(forall(k2, 0, MLDSA_K, array_bound(t1->vec[k2].coeffs, 0, MLDSA_N, 0, 1 << 10)))
216+
)
217+
{
218+
mld_polyvecl mat[MLDSA_K], s1hat;
219+
mld_polyveck t;
220+
221+
/* Expand matrix */
222+
mld_polyvec_matrix_expand(mat, rho);
223+
224+
/* Matrix-vector multiplication */
225+
s1hat = *s1;
226+
mld_polyvecl_ntt(&s1hat);
227+
mld_polyvec_matrix_pointwise_montgomery(&t, mat, &s1hat);
228+
mld_polyveck_reduce(&t);
229+
mld_polyveck_invntt_tomont(&t);
230+
231+
/* Add error vector s2 */
232+
mld_polyveck_add(&t, s2);
233+
234+
/* Reference: The following reduction is not present in the reference
235+
* implementation. Omitting this reduction requires the output of
236+
* the invntt to be small enough such that the addition of s2 does
237+
* not result in absolute values >= MLDSA_Q. While our C, x86_64,
238+
* and AArch64 invntt implementations produce small enough
239+
* values for this to work out, it complicates the bounds
240+
* reasoning. We instead add an additional reduction, and can
241+
* consequently, relax the bounds requirements for the invntt.
242+
*/
243+
mld_polyveck_reduce(&t);
244+
245+
/* Decompose to get t1, t0 */
246+
mld_polyveck_caddq(&t);
247+
mld_polyveck_power2round(t1, t0, &t);
248+
249+
/* Pack public key and compute tr */
250+
mld_pack_pk(pk, rho, t1);
251+
mld_shake256(tr, MLDSA_TRBYTES, pk, CRYPTO_PUBLICKEYBYTES);
252+
253+
/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
254+
mld_zeroize(mat, sizeof(mat));
255+
mld_zeroize(&s1hat, sizeof(s1hat));
256+
mld_zeroize(&t, sizeof(t));
257+
}
258+
177259
MLD_MUST_CHECK_RETURN_VALUE
178260
MLD_EXTERNAL_API
179261
int crypto_sign_keypair_internal(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
@@ -184,9 +266,8 @@ int crypto_sign_keypair_internal(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
184266
MLD_ALIGN uint8_t inbuf[MLDSA_SEEDBYTES + 2];
185267
MLD_ALIGN uint8_t tr[MLDSA_TRBYTES];
186268
const uint8_t *rho, *rhoprime, *key;
187-
mld_polyvecl mat[MLDSA_K];
188-
mld_polyvecl s1, s1hat;
189-
mld_polyveck s2, t2, t1, t0;
269+
mld_polyvecl s1;
270+
mld_polyveck s2, t1, t0;
190271

191272
/* Get randomness for rho, rhoprime and key */
192273
mld_memcpy(inbuf, seed, MLDSA_SEEDBYTES);
@@ -200,50 +281,23 @@ int crypto_sign_keypair_internal(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
200281

201282
/* Constant time: rho is part of the public key and, hence, public. */
202283
MLD_CT_TESTING_DECLASSIFY(rho, MLDSA_SEEDBYTES);
203-
/* Expand matrix */
204-
mld_polyvec_matrix_expand(mat, rho);
205-
mld_sample_s1_s2(&s1, &s2, rhoprime);
206284

207-
/* Matrix-vector multiplication */
208-
s1hat = s1;
209-
mld_polyvecl_ntt(&s1hat);
210-
mld_polyvec_matrix_pointwise_montgomery(&t1, mat, &s1hat);
211-
mld_polyveck_reduce(&t1);
212-
mld_polyveck_invntt_tomont(&t1);
213-
214-
/* Add error vector s2 */
215-
mld_polyveck_add(&t1, &s2);
216-
217-
/* Reference: The following reduction is not present in the reference
218-
* implementation. Omitting this reduction requires the output of
219-
* the invntt to be small enough such that the addition of s2 does
220-
* not result in absolute values >= MLDSA_Q. While our C, x86_64,
221-
* and AArch64 invntt implementations produce small enough
222-
* values for this to work out, it complicates the bounds
223-
* reasoning. We instead add an additional reduction, and can
224-
* consequently, relax the bounds requirements for the invntt.
225-
*/
226-
mld_polyveck_reduce(&t1);
285+
/* Sample s1 and s2 */
286+
mld_sample_s1_s2(&s1, &s2, rhoprime);
227287

228-
/* Extract t1 and write public key */
229-
mld_polyveck_caddq(&t1);
230-
mld_polyveck_power2round(&t2, &t0, &t1);
231-
mld_pack_pk(pk, rho, &t2);
288+
/* Compute t0, t1, tr, and pk from rho, s1, s2 */
289+
mld_compute_t0_t1_tr_from_sk_components(&t0, &t1, tr, pk, rho, &s1, &s2);
232290

233-
/* Compute H(rho, t1) and write secret key */
234-
mld_shake256(tr, MLDSA_TRBYTES, pk, CRYPTO_PUBLICKEYBYTES);
291+
/* Pack secret key */
235292
mld_pack_sk(sk, rho, tr, key, &t0, &s1, &s2);
236293

237294
/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
238295
mld_zeroize(seedbuf, sizeof(seedbuf));
239296
mld_zeroize(inbuf, sizeof(inbuf));
240297
mld_zeroize(tr, sizeof(tr));
241-
mld_zeroize(mat, sizeof(mat));
242298
mld_zeroize(&s1, sizeof(s1));
243-
mld_zeroize(&s1hat, sizeof(s1hat));
244299
mld_zeroize(&s2, sizeof(s2));
245300
mld_zeroize(&t1, sizeof(t1));
246-
mld_zeroize(&t2, sizeof(t2));
247301
mld_zeroize(&t0, sizeof(t0));
248302

249303
/* Constant time: pk is the public key, inherently public data */
@@ -1131,6 +1185,53 @@ size_t mld_prepare_domain_separation_prefix(
11311185
return 2 + ctxlen + MLD_PRE_HASH_OID_LEN + phlen;
11321186
}
11331187

1188+
MLD_EXTERNAL_API
1189+
int crypto_sign_pk_from_sk(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
1190+
const uint8_t sk[CRYPTO_SECRETKEYBYTES])
1191+
{
1192+
MLD_ALIGN uint8_t rho[MLDSA_SEEDBYTES];
1193+
MLD_ALIGN uint8_t tr[MLDSA_TRBYTES];
1194+
MLD_ALIGN uint8_t tr_computed[MLDSA_TRBYTES];
1195+
MLD_ALIGN uint8_t key[MLDSA_SEEDBYTES];
1196+
mld_polyvecl s1;
1197+
mld_polyveck s2, t0, t0_computed, t1;
1198+
uint8_t res, res0, res1;
1199+
1200+
/* Unpack secret key */
1201+
mld_unpack_sk(rho, tr, key, &t0, &s1, &s2, sk);
1202+
1203+
/* Recompute t0, t1, tr, and pk from rho, s1, s2 */
1204+
mld_compute_t0_t1_tr_from_sk_components(&t0_computed, &t1, tr_computed, pk,
1205+
rho, &s1, &s2);
1206+
1207+
/* Validate t0 and tr using constant-time comparisons */
1208+
res0 = mld_ct_memcmp(&t0, &t0_computed, sizeof(mld_polyveck));
1209+
res1 = mld_ct_memcmp(tr, tr_computed, MLDSA_TRBYTES);
1210+
res = mld_value_barrier_u8(res0 | res1);
1211+
1212+
/* Declassify the final result of the validity check. */
1213+
MLD_CT_TESTING_DECLASSIFY(&res, sizeof(res));
1214+
if (res != 0)
1215+
{
1216+
mld_zeroize(pk, CRYPTO_PUBLICKEYBYTES);
1217+
}
1218+
1219+
/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
1220+
mld_zeroize(rho, sizeof(rho));
1221+
mld_zeroize(tr, sizeof(tr));
1222+
mld_zeroize(tr_computed, sizeof(tr_computed));
1223+
mld_zeroize(key, sizeof(key));
1224+
mld_zeroize(&s1, sizeof(s1));
1225+
mld_zeroize(&s2, sizeof(s2));
1226+
mld_zeroize(&t0, sizeof(t0));
1227+
mld_zeroize(&t0_computed, sizeof(t0_computed));
1228+
mld_zeroize(&t1, sizeof(t1));
1229+
1230+
/* Constant time: pk is either the valid public key or zeroed on error */
1231+
MLD_CT_TESTING_DECLASSIFY(pk, CRYPTO_PUBLICKEYBYTES);
1232+
return (res != 0) ? -1 : 0;
1233+
}
1234+
11341235
/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
11351236
* Don't modify by hand -- this is auto-generated by scripts/autogen. */
11361237
#undef mld_check_pct
@@ -1139,5 +1240,6 @@ size_t mld_prepare_domain_separation_prefix(
11391240
#undef mld_get_hash_oid
11401241
#undef mld_H
11411242
#undef mld_attempt_signature_generation
1243+
#undef mld_compute_t0_t1_tr_from_sk_components
11421244
#undef NONCE_UB
11431245
#undef MLD_PRE_HASH_OID_LEN

mldsa/src/sign.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
MLD_NAMESPACE_KL(verify_pre_hash_shake256)
6565
#define mld_prepare_domain_separation_prefix \
6666
MLD_NAMESPACE_KL(prepare_domain_separation_prefix)
67+
#define crypto_sign_pk_from_sk MLD_NAMESPACE_KL(pk_from_sk)
6768

6869
/*************************************************
6970
* Hash algorithm constants for domain separation
@@ -686,4 +687,28 @@ __contract__(
686687
ensures(return_value <= MLD_DOMAIN_SEPARATION_MAX_BYTES)
687688
);
688689

690+
/*************************************************
691+
* Name: crypto_sign_pk_from_sk
692+
*
693+
* Description: Derives public key from secret key with validation.
694+
* Checks that t0 and tr stored in sk match recomputed values.
695+
*
696+
* Arguments: - uint8_t pk[CRYPTO_PUBLICKEYBYTES]: output public key
697+
* - const uint8_t sk[CRYPTO_SECRETKEYBYTES]: input secret key
698+
*
699+
* Returns 0 on success, -1 if validation fails (invalid secret key)
700+
*
701+
* Note: This function leaks whether the secret key is valid or invalid
702+
* through its return value and timing.
703+
**************************************************/
704+
MLD_MUST_CHECK_RETURN_VALUE
705+
MLD_EXTERNAL_API
706+
int crypto_sign_pk_from_sk(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
707+
const uint8_t sk[CRYPTO_SECRETKEYBYTES])
708+
__contract__(
709+
requires(memory_no_alias(pk, CRYPTO_PUBLICKEYBYTES))
710+
requires(memory_no_alias(sk, CRYPTO_SECRETKEYBYTES))
711+
assigns(memory_slice(pk, CRYPTO_PUBLICKEYBYTES))
712+
ensures(return_value == 0 || return_value == -1)
713+
);
689714
#endif /* !MLD_SIGN_H */

0 commit comments

Comments
 (0)