Commit dc80b36

Updates for readability following review.

Introduces local gs_butterfly_reduce() and gs_butterfly_defer() functions, which are inlined for both compilation and proof, but significantly simplify and improve the readability of the calling functions.

Signed-off-by: Rod Chapman <[email protected]>

1 parent: b85afe8

File tree (2 files changed: +73, -127 lines):

examples/monolithic_build/mlkem_native_monobuild.c
mlkem/ntt.c

examples/monolithic_build/mlkem_native_monobuild.c (10 additions, 0 deletions)

@@ -1562,6 +1562,16 @@
 #undef ct_butterfly
 #endif
 
+/* mlkem/ntt.c */
+#if defined(gs_butterfly_defer)
+#undef gs_butterfly_defer
+#endif
+
+/* mlkem/ntt.c */
+#if defined(gs_butterfly_reduce)
+#undef gs_butterfly_reduce
+#endif
+
 /* mlkem/ntt.c */
 #if defined(invntt_layer321)
 #undef invntt_layer321
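These #undefs follow the usual monolithic-build discipline: every unit is textually included into a single translation unit, so each unit's file-local macro names must be cleared before the next unit is pulled in. A minimal sketch of the pattern (the second include below is hypothetical, for illustration only):

/* Monolithic build sketch: concatenate units into one translation unit, */
/* undefining each unit's local macros between inclusions.               */
#include "mlkem/ntt.c"

#if defined(gs_butterfly_reduce)
#undef gs_butterfly_reduce /* file-local helper name from ntt.c */
#endif
#if defined(gs_butterfly_defer)
#undef gs_butterfly_defer
#endif

#include "mlkem/poly.c" /* hypothetical next unit */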

mlkem/ntt.c (63 additions, 127 deletions)
@@ -22,6 +22,8 @@
 #define ntt_layer7_butterfly MLKEM_NAMESPACE(ntt_layer7_butterfly)
 #define ntt_layer7 MLKEM_NAMESPACE(ntt_layer7)
 
+#define gs_butterfly_reduce MLKEM_NAMESPACE(gs_butterfly_reduce)
+#define gs_butterfly_defer MLKEM_NAMESPACE(gs_butterfly_defer)
 #define invntt_layer7_invert_butterfly \
   MLKEM_NAMESPACE(invntt_layer7_invert_butterfly)
 #define invntt_layer7_invert MLKEM_NAMESPACE(invntt_layer7_invert)
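The new helpers are routed through MLKEM_NAMESPACE like every other file-local symbol. The macro's definition is outside this diff; a sketch of what such a namespacing macro typically looks like (the prefix and concat helper below are hypothetical):

/* Hypothetical namespacing macro: prefix each symbol so that several */
/* builds or parameter sets can be linked into one binary.            */
#define MLKEM_CONCAT(a, b) a##b
#define MLKEM_NAMESPACE(s) MLKEM_CONCAT(mlkem_ref_, s)

/* With this, gs_butterfly_reduce names mlkem_ref_gs_butterfly_reduce, */
/* so compilations with different prefixes cannot clash at link time.  */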
@@ -108,7 +110,7 @@
 
 /* ct_butterfly() performs a single CT Butterfly step */
 /* in polynomial denoted by r, using the coefficients */
-/* index by coeff1_index and coeff2_index, and the */
+/* indexed by coeff1_index and coeff2_index, and the */
 /* given value of zeta. */
 /* */
 /* NOTE that this function is marked INLINE for */
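The hunk above only corrects a comment typo. For orientation, a Cooley-Tukey butterfly maps (a, b) to (a + zeta*b, a - zeta*b) mod q; a self-contained plain-integer sketch, ignoring the Montgomery arithmetic the real fqmul() performs, and assuming q = 3329:

#include <stdio.h>

#define Q 3329 /* ML-KEM modulus (assumption) */

/* Plain-integer CT butterfly: (a, b) -> (a + zeta*b, a - zeta*b) mod Q. */
static void ct_plain(int *a, int *b, int zeta)
{
  const int t = (zeta * *b) % Q;
  *b = ((*a - t) % Q + Q) % Q; /* uses the old value of a */
  *a = (*a + t) % Q;
}

int main(void)
{
  int a = 1, b = 2;
  ct_plain(&a, &b, 17);
  printf("a=%d b=%d\n", a, b); /* a=35, b=3296 (= -33 mod Q) */
  return 0;
}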
@@ -445,6 +447,37 @@ STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound)
 /* Used to invert and reduce coefficients in the Inverse NTT. */
 #define MONT_F 1441
 
+/* gs_butterfly_reduce() performs a single GS Butterfly */
+/* step in polynomial denoted by r, using the */
+/* coefficients indexed by coeff1_index and coeff2_index */
+/* and the given value of zeta. */
+/* */
+/* Like ct_butterfly(), this function is inlined */
+/* for both compilation and proof. */
+static INLINE void gs_butterfly_reduce(int16_t r[MLKEM_N],
+                                       const int coeff1_index,
+                                       const int coeff2_index,
+                                       const int16_t zeta)
+{
+  const int16_t t1 = r[coeff1_index];
+  const int16_t t2 = r[coeff2_index];
+  r[coeff1_index] = barrett_reduce(t1 + t2);
+  r[coeff2_index] = fqmul((t2 - t1), zeta);
+}
+
+/* As gs_butterfly_reduce(), but does not reduce the */
+/* coefficient denoted by coeff1_index */
+static INLINE void gs_butterfly_defer(int16_t r[MLKEM_N],
+                                      const int coeff1_index,
+                                      const int coeff2_index,
+                                      const int16_t zeta)
+{
+  const int16_t t1 = r[coeff1_index];
+  const int16_t t2 = r[coeff2_index];
+  r[coeff1_index] = t1 + t2;
+  r[coeff2_index] = fqmul((t2 - t1), zeta);
+}
+
 static INLINE void invntt_layer7_invert_butterfly(int16_t r[MLKEM_N],
                                                   int zeta_index, int start)
 __contract__(
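In plain modular arithmetic both helpers map (a, b) to (a + b, (b - a)*zeta mod q) and differ only in whether the sum is reduced immediately (via barrett_reduce) or left for a later layer. A self-contained sanity check, assuming q = 3329 and ignoring the Montgomery factor carried by fqmul():

#include <stdio.h>

#define Q 3329 /* ML-KEM modulus (assumption) */

/* Plain-integer GS butterfly; reduce_sum selects the "reduce" vs the */
/* "defer" behaviour for the sum coefficient.                         */
static void gs_plain(int *a, int *b, int zeta, int reduce_sum)
{
  const int t1 = *a, t2 = *b;
  *a = reduce_sum ? (t1 + t2) % Q : (t1 + t2);
  *b = (((t2 - t1) % Q + Q) * zeta) % Q;
}

int main(void)
{
  int a = 3000, b = 2000;
  gs_plain(&a, &b, 17, 0); /* defer:  a stays 5000, unreduced */
  printf("defer:  a=%d b=%d\n", a, b);
  a = 3000, b = 2000;
  gs_plain(&a, &b, 17, 1); /* reduce: a = 5000 mod Q = 1671 */
  printf("reduce: a=%d b=%d\n", a, b);
  return 0;
}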
@@ -467,17 +500,14 @@ __contract__(
   /* Invert and reduce all coefficients here the first time they */
   /* are read. This is efficient, and also means we can accept */
   /* any int16_t value for all coefficients as input. */
-  const int16_t c0 = fqmul(r[ci0], MONT_F);
-  const int16_t c1 = fqmul(r[ci1], MONT_F);
-  const int16_t c2 = fqmul(r[ci2], MONT_F);
-  const int16_t c3 = fqmul(r[ci3], MONT_F);
+  r[ci0] = fqmul(r[ci0], MONT_F);
+  r[ci1] = fqmul(r[ci1], MONT_F);
+  r[ci2] = fqmul(r[ci2], MONT_F);
+  r[ci3] = fqmul(r[ci3], MONT_F);
 
   /* Reduce all coefficients here to meet the precondition of Layer 6 */
-  r[ci0] = barrett_reduce(c0 + c2);
-  r[ci2] = fqmul((c2 - c0), zeta);
-
-  r[ci1] = barrett_reduce(c1 + c3);
-  r[ci3] = fqmul((c3 - c1), zeta);
+  gs_butterfly_reduce(r, ci0, ci2, zeta);
+  gs_butterfly_reduce(r, ci1, ci3, zeta);
 }
 
 static void invntt_layer7_invert(int16_t r[MLKEM_N])
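The invert step multiplies each coefficient by MONT_F = 1441 exactly once, on first read. In the Kyber reference code this constant equals mont^2/128 mod q, undoing the 1/128 scaling of the inverse NTT while re-entering the Montgomery domain; that reading is an assumption here, but it is easy to check numerically:

#include <assert.h>
#include <stdint.h>

int main(void)
{
  const int64_t q = 3329;             /* ML-KEM modulus */
  const int64_t mont = (1 << 16) % q; /* Montgomery factor R mod q = 2285 */

  /* MONT_F * 128 == mont^2 (mod q), i.e. MONT_F == mont^2/128 mod q. */
  assert((1441 * 128) % q == (mont * mont) % q); /* both sides equal 1353 */
  return 0;
}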
@@ -521,28 +551,13 @@ __contract__(
   const int ci5 = ci0 + 5;
   const int ci6 = ci0 + 6;
   const int ci7 = ci0 + 7;
-  const int16_t c0 = r[ci0];
-  const int16_t c1 = r[ci1];
-  const int16_t c2 = r[ci2];
-  const int16_t c3 = r[ci3];
-  const int16_t c4 = r[ci4];
-  const int16_t c5 = r[ci5];
-  const int16_t c6 = r[ci6];
-  const int16_t c7 = r[ci7];
 
   /* Defer reduction of coefficients 0, 1, 2, and 3 here so they */
   /* are bounded to NTT_BOUND2 after Layer6 */
-  r[ci0] = c0 + c4;
-  r[ci4] = fqmul((c4 - c0), zeta);
-
-  r[ci1] = c1 + c5;
-  r[ci5] = fqmul((c5 - c1), zeta);
-
-  r[ci2] = c2 + c6;
-  r[ci6] = fqmul((c6 - c2), zeta);
-
-  r[ci3] = c3 + c7;
-  r[ci7] = fqmul((c7 - c3), zeta);
+  gs_butterfly_defer(r, ci0, ci4, zeta);
+  gs_butterfly_defer(r, ci1, ci5, zeta);
+  gs_butterfly_defer(r, ci2, ci6, zeta);
+  gs_butterfly_defer(r, ci3, ci7, zeta);
 }
 
 static void invntt_layer6(int16_t r[MLKEM_N])
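Why deferring is safe at this point: assuming the layer-7 invert step leaves every coefficient with magnitude below q, and that fqmul() outputs stay below q in magnitude (standard properties of the reference arithmetic, assumed here), one deferred layer produces sums below 2q, comfortably within int16_t:

#include <assert.h>
#include <stdint.h>

int main(void)
{
  const int q = 3329;
  /* After layer 7: |c| < q. After one deferred GS layer:        */
  /*   sum coefficients  |c0 + c4| < 2q (presumably NTT_BOUND2), */
  /*   fqmul() outputs   remain below q.                         */
  assert(2 * q <= INT16_MAX); /* 6658 <= 32767: no overflow */
  return 0;
}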
@@ -603,36 +618,13 @@ __contract__(
     const int ci16 = ci0 + 16;
     const int ci24 = ci0 + 24;
 
-    /* Layer 5 */
-    {
-      const int16_t c0 = r[ci0];
-      const int16_t c8 = r[ci8];
-      const int16_t c16 = r[ci16];
-      const int16_t c24 = r[ci24];
-
-      /* Defer reduction of coeffs 0 and 16 here */
-      r[ci0] = c0 + c8;
-      r[ci8] = fqmul(c8 - c0, l5zeta2);
-
-      r[ci16] = c16 + c24;
-      r[ci24] = fqmul(c24 - c16, l5zeta1);
-    }
-
-    /* Layer 4 */
-    {
-      const int16_t c0 = r[ci0];
-      const int16_t c8 = r[ci8];
-      const int16_t c16 = r[ci16];
-      const int16_t c24 = r[ci24];
-
-      /* In layer 4, reduce all coefficients to be in NTT_BOUND1 */
-      /* to meet the pre-condition of Layer321 */
-      r[ci0] = barrett_reduce(c0 + c16);
-      r[ci16] = fqmul(c16 - c0, l4zeta);
-
-      r[ci8] = barrett_reduce(c8 + c24);
-      r[ci24] = fqmul(c24 - c8, l4zeta);
-    }
+    /* Layer 5 - Defer reduction of coeffs 0 and 16 here */
+    gs_butterfly_defer(r, ci0, ci8, l5zeta2);
+    gs_butterfly_defer(r, ci16, ci24, l5zeta1);
+    /* Layer 4 - reduce all coefficients to be in NTT_BOUND1 */
+    /* to meet the pre-condition of Layer321 */
+    gs_butterfly_reduce(r, ci0, ci16, l4zeta);
+    gs_butterfly_reduce(r, ci8, ci24, l4zeta);
   }
 }
 
@@ -688,76 +680,20 @@ __contract__(
     const int ci224 = j + 224;
 
     /* Layer 3 */
-    {
-      const int16_t c0 = r[ci0];
-      const int16_t c32 = r[ci32];
-      const int16_t c64 = r[ci64];
-      const int16_t c96 = r[ci96];
-      const int16_t c128 = r[ci128];
-      const int16_t c160 = r[ci160];
-      const int16_t c192 = r[ci192];
-      const int16_t c224 = r[ci224];
-
-      r[ci0] = c0 + c32;
-      r[ci32] = fqmul(c32 - c0, l3zeta7);
-
-      r[ci64] = c64 + c96;
-      r[ci96] = fqmul(c96 - c64, l3zeta6);
-
-      r[ci128] = c128 + c160;
-      r[ci160] = fqmul(c160 - c128, l3zeta5);
-
-      r[ci192] = c192 + c224;
-      r[ci224] = fqmul(c224 - c192, l3zeta4);
-    }
-
+    gs_butterfly_defer(r, ci0, ci32, l3zeta7);
+    gs_butterfly_defer(r, ci64, ci96, l3zeta6);
+    gs_butterfly_defer(r, ci128, ci160, l3zeta5);
+    gs_butterfly_defer(r, ci192, ci224, l3zeta4);
     /* Layer 2 */
-    {
-      const int16_t c0 = r[ci0];
-      const int16_t c32 = r[ci32];
-      const int16_t c64 = r[ci64];
-      const int16_t c96 = r[ci96];
-      const int16_t c128 = r[ci128];
-      const int16_t c160 = r[ci160];
-      const int16_t c192 = r[ci192];
-      const int16_t c224 = r[ci224];
-
-      r[ci0] = c0 + c64;
-      r[ci64] = fqmul(c64 - c0, l2zeta3);
-
-      r[ci32] = c32 + c96;
-      r[ci96] = fqmul(c96 - c32, l2zeta3);
-
-      r[ci128] = c128 + c192;
-      r[ci192] = fqmul(c192 - c128, l2zeta2);
-
-      r[ci160] = c160 + c224;
-      r[ci224] = fqmul(c224 - c160, l2zeta2);
-    }
-
+    gs_butterfly_defer(r, ci0, ci64, l2zeta3);
+    gs_butterfly_defer(r, ci32, ci96, l2zeta3);
+    gs_butterfly_defer(r, ci128, ci192, l2zeta2);
+    gs_butterfly_defer(r, ci160, ci224, l2zeta2);
     /* Layer 1 */
-    {
-      const int16_t c0 = r[ci0];
-      const int16_t c32 = r[ci32];
-      const int16_t c64 = r[ci64];
-      const int16_t c96 = r[ci96];
-      const int16_t c128 = r[ci128];
-      const int16_t c160 = r[ci160];
-      const int16_t c192 = r[ci192];
-      const int16_t c224 = r[ci224];
-
-      r[ci0] = c0 + c128;
-      r[ci128] = fqmul(c128 - c0, l1zeta1);
-
-      r[ci32] = c32 + c160;
-      r[ci160] = fqmul(c160 - c32, l1zeta1);
-
-      r[ci64] = c64 + c192;
-      r[ci192] = fqmul(c192 - c64, l1zeta1);
-
-      r[ci96] = c96 + c224;
-      r[ci224] = fqmul(c224 - c96, l1zeta1);
-    }
+    gs_butterfly_defer(r, ci0, ci128, l1zeta1);
+    gs_butterfly_defer(r, ci32, ci160, l1zeta1);
+    gs_butterfly_defer(r, ci64, ci192, l1zeta1);
+    gs_butterfly_defer(r, ci96, ci224, l1zeta1);
   }
 }
 

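A closing observation on why Layer321 can defer in all three of its layers: its inputs are below q in magnitude (the layer-4 barrett_reduce above guarantees exactly this), and assuming fqmul() outputs stay below q, each deferred layer at most doubles the bound on the unreduced sum coefficients, ending at 8q, which still fits int16_t:

#include <assert.h>
#include <stdint.h>

int main(void)
{
  const int q = 3329;
  /* |input| < q after layer 4; each deferred layer may double the */
  /* bound on the unreduced sum coefficients:                      */
  assert(2 * q == 6658);      /* after layer 3 */
  assert(4 * q == 13316);     /* after layer 2 */
  assert(8 * q == 26632);     /* after layer 1 */
  assert(8 * q <= INT16_MAX); /* 26632 <= 32767: no int16_t overflow */
  return 0;
}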