Skip to content

Commit 69b4e2c

Browse files
committed
Extended API
Signed-off-by: Matthias J. Kannwischer <[email protected]>
1 parent e319849 commit 69b4e2c

File tree

5 files changed

+480
-141
lines changed

5 files changed

+480
-141
lines changed

mlkem/indcpa.c

Lines changed: 79 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,15 @@
2020
* This is to facilitate building multiple instances
2121
* of mlkem-native (e.g. with varying security levels)
2222
* within a single compilation unit. */
23-
#define mlk_pack_pk MLK_ADD_LEVEL(mlk_pack_pk)
2423
#define mlk_unpack_pk MLK_ADD_LEVEL(mlk_unpack_pk)
25-
#define mlk_pack_sk MLK_ADD_LEVEL(mlk_pack_sk)
2624
#define mlk_unpack_sk MLK_ADD_LEVEL(mlk_unpack_sk)
2725
#define mlk_pack_ciphertext MLK_ADD_LEVEL(mlk_pack_ciphertext)
2826
#define mlk_unpack_ciphertext MLK_ADD_LEVEL(mlk_unpack_ciphertext)
2927
#define mlk_matvec_mul MLK_ADD_LEVEL(mlk_matvec_mul)
3028
/* End of level namespacing */
3129

3230
/*************************************************
33-
* Name: mlk_pack_pk
31+
* Name: mlk_indcpa_marshal_pk
3432
*
3533
* Description: Serialize the public key as concatenation of the
3634
* serialized vector of polynomials pk
@@ -45,16 +43,18 @@
4543
* Implements [FIPS 203, Algorithm 13 (K-PKE.KeyGen), L19]
4644
*
4745
**************************************************/
48-
static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec *pk,
49-
const uint8_t seed[MLKEM_SYMBYTES])
46+
MLK_INTERNAL_API
47+
void mlk_indcpa_marshal_pk(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
48+
const mlk_indcpa_public_key *pks)
5049
{
51-
mlk_assert_bound_2d(pk, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
52-
mlk_polyvec_tobytes(r, pk);
53-
memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES);
50+
mlk_assert_bound_2d(pks->pkpv, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
51+
mlk_polyvec_tobytes(pk, &pks->pkpv);
52+
memcpy(pk + MLKEM_POLYVECBYTES, pks->seed, MLKEM_SYMBYTES);
5453
}
5554

55+
5656
/*************************************************
57-
* Name: mlk_unpack_pk
57+
* Name: mlk_indcpa_parse_pk
5858
*
5959
* Description: De-serialize public key from a byte array;
6060
* approximate inverse of mlk_pack_pk
@@ -69,11 +69,13 @@ static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec *pk,
6969
* Implements [FIPS 203, Algorithm 14 (K-PKE.Encrypt), L2-3]
7070
*
7171
**************************************************/
72-
static void mlk_unpack_pk(mlk_polyvec *pk, uint8_t seed[MLKEM_SYMBYTES],
73-
const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES])
72+
MLK_INTERNAL_API
73+
void mlk_indcpa_parse_pk(mlk_indcpa_public_key *pks,
74+
const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES])
7475
{
75-
mlk_polyvec_frombytes(pk, packedpk);
76-
memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES);
76+
mlk_polyvec_frombytes(&pks->pkpv, pk);
77+
memcpy(pks->seed, pk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES);
78+
mlk_gen_matrix(pks->at, pks->seed, 1);
7779

7880
/* NOTE: If a modulus check was conducted on the PK, we know at this
7981
* point that the coefficients of `pk` are unsigned canonical. The
@@ -82,7 +84,7 @@ static void mlk_unpack_pk(mlk_polyvec *pk, uint8_t seed[MLKEM_SYMBYTES],
8284
}
8385

8486
/*************************************************
85-
* Name: mlk_pack_sk
87+
* Name: mlk_indcpa_marshal_sk
8688
*
8789
* Description: Serialize the secret key
8890
*
@@ -94,14 +96,16 @@ static void mlk_unpack_pk(mlk_polyvec *pk, uint8_t seed[MLKEM_SYMBYTES],
9496
* Implements [FIPS 203, Algorithm 13 (K-PKE.KeyGen), L20]
9597
*
9698
**************************************************/
97-
static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec *sk)
99+
MLK_INTERNAL_API
100+
void mlk_indcpa_marshal_sk(uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES],
101+
const mlk_indcpa_secret_key *sks)
98102
{
99-
mlk_assert_bound_2d(sk, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
100-
mlk_polyvec_tobytes(r, sk);
103+
mlk_assert_bound_2d(&sks->skpv, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
104+
mlk_polyvec_tobytes(sk, &sks->skpv);
101105
}
102106

103107
/*************************************************
104-
* Name: mlk_unpack_sk
108+
* Name: mlk_indcpa_parse_sk
105109
*
106110
* Description: De-serialize the secret key; inverse of mlk_pack_sk
107111
*
@@ -114,10 +118,11 @@ static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec *sk)
114118
* Implements [FIPS 203, Algorithm 15 (K-PKE.Decrypt), L5]
115119
*
116120
**************************************************/
117-
static void mlk_unpack_sk(mlk_polyvec *sk,
118-
const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES])
121+
MLK_INTERNAL_API
122+
void mlk_indcpa_parse_sk(mlk_indcpa_secret_key *sks,
123+
const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES])
119124
{
120-
mlk_polyvec_frombytes(sk, packedsk);
125+
mlk_polyvec_frombytes(&sks->skpv, sk);
121126
}
122127

123128
/*************************************************
@@ -179,6 +184,24 @@ __contract__(
179184
ensures(array_bound(data, 0, MLKEM_N, 0, MLKEM_Q))) { ((void)data); }
180185
#endif /* MLK_USE_NATIVE_NTT_CUSTOM_ORDER */
181186

187+
static void transpose_matrix(mlk_polyvec a[MLKEM_K])
188+
{
189+
unsigned int i, j, k;
190+
int16_t t;
191+
for (i = 0; i < MLKEM_K; i++)
192+
{
193+
for (j = i + 1; j < MLKEM_K; j++)
194+
{
195+
for (k = 0; k < MLKEM_N; k++)
196+
{
197+
t = a[i].vec[j].coeffs[k];
198+
a[i].vec[j].coeffs[k] = a[j].vec[i].coeffs[k];
199+
a[j].vec[i].coeffs[k] = t;
200+
}
201+
}
202+
}
203+
}
204+
182205
/* Reference: `gen_matrix()` in the reference implementation.
183206
* - We use a special subroutine to generate 4 polynomials
184207
* at a time, to be able to leverage batched Keccak-f1600
@@ -332,14 +355,14 @@ __contract__(
332355
* - We include buffer zeroization.
333356
*/
334357
MLK_INTERNAL_API
335-
void mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
336-
uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES],
358+
void mlk_indcpa_keypair_derand(mlk_indcpa_public_key *pk,
359+
mlk_indcpa_secret_key *sk,
337360
const uint8_t coins[MLKEM_SYMBYTES])
338361
{
339362
MLK_ALIGN uint8_t buf[2 * MLKEM_SYMBYTES];
340363
const uint8_t *publicseed = buf;
341364
const uint8_t *noiseseed = buf + MLKEM_SYMBYTES;
342-
mlk_polyvec a[MLKEM_K], e, pkpv, skpv;
365+
mlk_polyvec e;
343366
mlk_polyvec_mulcache skpv_cache;
344367

345368
MLK_ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1];
@@ -357,51 +380,54 @@ void mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
357380
*/
358381
MLK_CT_TESTING_DECLASSIFY(publicseed, MLKEM_SYMBYTES);
359382

360-
mlk_gen_matrix(a, publicseed, 0 /* no transpose */);
383+
mlk_gen_matrix(pk->at, publicseed, 0 /* no transpose */);
361384

362385
#if MLKEM_K == 2
363-
mlk_poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1,
364-
noiseseed, 0, 1, 2, 3);
386+
mlk_poly_getnoise_eta1_4x(sk->skpv.vec + 0, sk->skpv.vec + 1, e.vec + 0,
387+
e.vec + 1, noiseseed, 0, 1, 2, 3);
365388
#elif MLKEM_K == 3
366389
/*
367390
* Only the first three output buffers are needed.
368391
* The laster parameter is a dummy that's overwritten later.
369392
*/
370-
mlk_poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2,
371-
pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2,
372-
0xFF /* irrelevant */);
393+
mlk_poly_getnoise_eta1_4x(sk->skpv.vec + 0, sk->skpv.vec + 1,
394+
sk->skpv.vec + 2, pk->pkpv.vec + 0 /* irrelevant */,
395+
noiseseed, 0, 1, 2, 0xFF /* irrelevant */);
373396
/* Same here */
374397
mlk_poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2,
375-
pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5,
376-
0xFF /* irrelevant */);
398+
pk->pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4,
399+
5, 0xFF /* irrelevant */);
377400
#elif MLKEM_K == 4
378-
mlk_poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2,
379-
skpv.vec + 3, noiseseed, 0, 1, 2, 3);
401+
mlk_poly_getnoise_eta1_4x(sk->skpv.vec + 0, sk->skpv.vec + 1,
402+
sk->skpv.vec + 2, sk->skpv.vec + 3, noiseseed, 0, 1,
403+
2, 3);
380404
mlk_poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3,
381405
noiseseed, 4, 5, 6, 7);
382406
#endif
383407

384-
mlk_polyvec_ntt(&skpv);
408+
mlk_polyvec_ntt(&sk->skpv);
385409
mlk_polyvec_ntt(&e);
386410

387-
mlk_polyvec_mulcache_compute(&skpv_cache, &skpv);
388-
mlk_matvec_mul(&pkpv, a, &skpv, &skpv_cache);
389-
mlk_polyvec_tomont(&pkpv);
390-
391-
mlk_polyvec_add(&pkpv, &e);
392-
mlk_polyvec_reduce(&pkpv);
393-
mlk_polyvec_reduce(&skpv);
411+
mlk_polyvec_mulcache_compute(&skpv_cache, &sk->skpv);
412+
mlk_matvec_mul(&pk->pkpv, pk->at, &sk->skpv, &skpv_cache);
413+
mlk_polyvec_tomont(&pk->pkpv);
414+
415+
mlk_polyvec_add(&pk->pkpv, &e);
416+
mlk_polyvec_reduce(&pk->pkpv);
417+
mlk_polyvec_reduce(&sk->skpv);
418+
memcpy(pk->seed, publicseed, MLKEM_SYMBYTES);
419+
/**
420+
* TODO: We can eliminate this by moving the transpose into the matvec_mul
421+
* Then we can also eliminate the flag from the from the gen_matrix function
422+
*/
423+
transpose_matrix(pk->at);
394424

395-
mlk_pack_sk(sk, &skpv);
396-
mlk_pack_pk(pk, &pkpv, publicseed);
397425

398426
/* Specification: Partially implements
399427
* [FIPS 203, Section 3.3, Destruction of intermediate values] */
400428
mlk_zeroize(buf, sizeof(buf));
401429
mlk_zeroize(coins_with_domain_separator, sizeof(coins_with_domain_separator));
402-
mlk_zeroize(a, sizeof(a));
403430
mlk_zeroize(&e, sizeof(e));
404-
mlk_zeroize(&skpv, sizeof(skpv));
405431
mlk_zeroize(&skpv_cache, sizeof(skpv_cache));
406432
}
407433

@@ -416,27 +442,14 @@ void mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
416442
MLK_INTERNAL_API
417443
void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
418444
const uint8_t m[MLKEM_INDCPA_MSGBYTES],
419-
const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
445+
const mlk_indcpa_public_key *pk,
420446
const uint8_t coins[MLKEM_SYMBYTES])
421447
{
422-
MLK_ALIGN uint8_t seed[MLKEM_SYMBYTES];
423-
mlk_polyvec sp, pkpv, ep, at[MLKEM_K], b;
448+
mlk_polyvec sp, ep, b;
424449
mlk_poly v, k, epp;
425450
mlk_polyvec_mulcache sp_cache;
426-
427-
mlk_unpack_pk(&pkpv, seed, pk);
428451
mlk_poly_frommsg(&k, m);
429452

430-
/*
431-
* Declassify the public seed.
432-
* Required to use it in conditional-branches in rejection sampling.
433-
* This is needed because in re-encryption the publicseed originated from sk
434-
* which is marked undefined.
435-
*/
436-
MLK_CT_TESTING_DECLASSIFY(seed, MLKEM_SYMBYTES);
437-
438-
mlk_gen_matrix(at, seed, 1 /* transpose */);
439-
440453
#if MLKEM_K == 2
441454
mlk_poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1,
442455
coins, 0, 1, 2, 3);
@@ -462,8 +475,8 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
462475
mlk_polyvec_ntt(&sp);
463476

464477
mlk_polyvec_mulcache_compute(&sp_cache, &sp);
465-
mlk_matvec_mul(&b, at, &sp, &sp_cache);
466-
mlk_polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache);
478+
mlk_matvec_mul(&b, pk->at, &sp, &sp_cache);
479+
mlk_polyvec_basemul_acc_montgomery_cached(&v, &pk->pkpv, &sp, &sp_cache);
467480

468481
mlk_polyvec_invntt_tomont(&b);
469482
mlk_poly_invntt_tomont(&v);
@@ -479,12 +492,10 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
479492

480493
/* Specification: Partially implements
481494
* [FIPS 203, Section 3.3, Destruction of intermediate values] */
482-
mlk_zeroize(seed, sizeof(seed));
483495
mlk_zeroize(&sp, sizeof(sp));
484496
mlk_zeroize(&sp_cache, sizeof(sp_cache));
485497
mlk_zeroize(&b, sizeof(b));
486498
mlk_zeroize(&v, sizeof(v));
487-
mlk_zeroize(at, sizeof(at));
488499
mlk_zeroize(&k, sizeof(k));
489500
mlk_zeroize(&ep, sizeof(ep));
490501
mlk_zeroize(&epp, sizeof(epp));
@@ -496,18 +507,17 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
496507
MLK_INTERNAL_API
497508
void mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
498509
const uint8_t c[MLKEM_INDCPA_BYTES],
499-
const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES])
510+
const mlk_indcpa_secret_key *sk)
500511
{
501-
mlk_polyvec b, skpv;
512+
mlk_polyvec b;
502513
mlk_poly v, sb;
503514
mlk_polyvec_mulcache b_cache;
504515

505516
mlk_unpack_ciphertext(&b, &v, c);
506-
mlk_unpack_sk(&skpv, sk);
507517

508518
mlk_polyvec_ntt(&b);
509519
mlk_polyvec_mulcache_compute(&b_cache, &b);
510-
mlk_polyvec_basemul_acc_montgomery_cached(&sb, &skpv, &b, &b_cache);
520+
mlk_polyvec_basemul_acc_montgomery_cached(&sb, &sk->skpv, &b, &b_cache);
511521
mlk_poly_invntt_tomont(&sb);
512522

513523
mlk_poly_sub(&v, &sb);
@@ -517,7 +527,6 @@ void mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
517527

518528
/* Specification: Partially implements
519529
* [FIPS 203, Section 3.3, Destruction of intermediate values] */
520-
mlk_zeroize(&skpv, sizeof(skpv));
521530
mlk_zeroize(&b, sizeof(b));
522531
mlk_zeroize(&b_cache, sizeof(b_cache));
523532
mlk_zeroize(&v, sizeof(v));
@@ -526,9 +535,7 @@ void mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
526535

527536
/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
528537
* Don't modify by hand -- this is auto-generated by scripts/autogen. */
529-
#undef mlk_pack_pk
530538
#undef mlk_unpack_pk
531-
#undef mlk_pack_sk
532539
#undef mlk_unpack_sk
533540
#undef mlk_pack_ciphertext
534541
#undef mlk_unpack_ciphertext

0 commit comments

Comments
 (0)