Skip to content

Commit 54308a0

Browse files
committed
Add first attempt at generalized 2- and 3-layer merged forward NTT functions.
CBMC proofs of these new functions are all TBD. For now, we call these in a "3,2,1,1" pattern to make sure there are no unreferenced functions.
Signed-off-by: Rod Chapman <[email protected]>
1 parent 366156b commit 54308a0

File tree

1 file changed

+120
-91
lines changed

1 file changed

+120
-91
lines changed

mlkem/poly.c

Lines changed: 120 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,9 @@
2222
#define fqmul MLK_NAMESPACE(fqmul)
2323
#define barrett_reduce MLK_NAMESPACE(barrett_reduce)
2424
#define scalar_signed_to_unsigned_q MLK_NAMESPACE(scalar_signed_to_unsigned_q)
25-
#define ntt_butterfly_block MLK_NAMESPACE(ntt_butterfly_block)
26-
#define ntt_layer MLK_NAMESPACE(ntt_layer)
25+
#define ntt_1_layer MLK_NAMESPACE(ntt_1_layer)
26+
#define ntt_2_layers MLK_NAMESPACE(ntt_2_layers)
27+
#define ntt_3_layers MLK_NAMESPACE(ntt_3_layers)
2728
#define invntt_layer MLK_NAMESPACE(invntt_layer)
2829
/* End of static namespacing */
2930

@@ -261,102 +262,133 @@ void poly_mulcache_compute(poly_mulcache *x, const poly *a)
261262
#endif /* MLK_USE_NATIVE_POLY_MULCACHE_COMPUTE */
262263

263264
#if !defined(MLK_USE_NATIVE_NTT)
264-
/*
265-
* Computes a block CT butterflies with a fixed twiddle factor,
266-
* using Montgomery multiplication.
267-
* Parameters:
268-
* - r: Pointer to base of polynomial (_not_ the base of butterfly block)
269-
* - root: Twiddle factor to use for the butterfly. This must be in
270-
* Montgomery form and signed canonical.
271-
* - start: Offset to the beginning of the butterfly block
272-
* - len: Index difference between coefficients subject to a butterfly
273-
* - bound: Ghost variable describing coefficient bound: Prior to `start`,
274-
* coefficients must be bound by `bound + MLKEM_Q`. Post `start`,
275-
* they must be bound by `bound`.
276-
* When this function returns, output coefficients in the index range
277-
* [start, start+2*len) have bound bumped to `bound + MLKEM_Q`.
278-
* Example:
279-
* - start=8, len=4
280-
* This would compute the following four butterflies
281-
* 8 -- 12
282-
* 9 -- 13
283-
* 10 -- 14
284-
* 11 -- 15
285-
* - start=4, len=2
286-
* This would compute the following two butterflies
287-
* 4 -- 6
288-
* 5 -- 7
289-
*/
290-
static void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta,
291-
unsigned start, unsigned len, int bound)
292-
__contract__(
293-
requires(start < MLKEM_N)
294-
requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N)
295-
requires(0 <= bound && bound < INT16_MAX - MLKEM_Q)
296-
requires(-MLKEM_Q_HALF < zeta && zeta < MLKEM_Q_HALF)
297-
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
298-
requires(array_abs_bound(r, 0, start, bound + MLKEM_Q))
299-
requires(array_abs_bound(r, start, MLKEM_N, bound))
300-
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
301-
ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q))
302-
ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound)))
265+
266+
static MLK_INLINE void ct_butterfly(int16_t r[MLKEM_N],
267+
const unsigned coeff1_index,
268+
const unsigned coeff2_index,
269+
const int16_t zeta)
303270
{
304-
/* `bound` is a ghost variable only needed in the CBMC specification */
305-
unsigned j;
306-
((void)bound);
307-
for (j = start; j < start + len; j++)
308-
__loop__(
309-
invariant(start <= j && j <= start + len)
310-
/*
311-
* Coefficients are updated in strided pairs, so the bounds for the
312-
* intermediate states alternate twice between the old and new bound
313-
*/
314-
invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q))
315-
invariant(array_abs_bound(r, j, start + len, bound))
316-
invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q))
317-
invariant(array_abs_bound(r, j + len, MLKEM_N, bound)))
318-
{
319-
int16_t t;
320-
t = fqmul(r[j + len], zeta);
321-
r[j + len] = r[j] - t;
322-
r[j] = r[j] + t;
323-
}
271+
int16_t t1 = r[coeff1_index];
272+
int16_t t2 = fqmul(r[coeff2_index], zeta);
273+
r[coeff1_index] = t1 + t2;
274+
r[coeff2_index] = t1 - t2;
324275
}
325276

326-
/*
327-
*Compute one layer of forward NTT
328-
* Parameters:
329-
* - r: Pointer to base of polynomial
330-
* - len: Stride of butterflies in this layer.
331-
* - layer: Ghost variable indicating which layer is being applied.
332-
* Must match `len` via `len == MLKEM_N >> layer`.
333-
* Note: `len` could be dropped and computed in the function, but
334-
* we are following the structure of the reference NTT from the
335-
* official Kyber implementation here, merely adding `layer` as
336-
* a ghost variable for the specifications.
337-
*/
338-
static void ntt_layer(int16_t r[MLKEM_N], unsigned len, unsigned layer)
277+
static void ntt_1_layer(int16_t r[MLKEM_N], unsigned layer)
339278
__contract__(
340279
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
341-
requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer))
280+
requires(1 <= layer && layer <= 7)
342281
requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))
343282
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
344283
ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q)))
345284
{
285+
const unsigned len = MLKEM_N >> layer;
346286
unsigned start, k;
347-
/* `layer` is a ghost variable only needed in the CBMC specification */
348-
((void)layer);
349-
/* Twiddle factors for layer n start at index 2^(layer-1) */
350-
k = MLKEM_N / (2 * len);
287+
/* Twiddle factors for layer n start at index 2 ** (layer-1) */
288+
k = 1 << (layer - 1);
351289
for (start = 0; start < MLKEM_N; start += 2 * len)
352290
__loop__(
353291
invariant(start < MLKEM_N + 2 * len)
354292
invariant(k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N)
355293
invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q))
356294
invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q)))
357295
{
358-
int16_t zeta = zetas[k++];
359-
ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q);
296+
const int16_t zeta = zetas[k++];
297+
unsigned j;
298+
for (j = 0; j < len; j++)
299+
{
300+
ct_butterfly(r, j + start, j + start + len, zeta);
301+
}
302+
}
303+
}
304+
305+
static void ntt_2_layers(int16_t r[MLKEM_N], unsigned layer)
306+
__contract__(
307+
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
308+
requires(1 <= layer && layer <= 6)
309+
requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))
310+
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
311+
ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 2) * MLKEM_Q)))
312+
{
313+
const unsigned len = MLKEM_N >> layer;
314+
unsigned start, k;
315+
/* Twiddle factors for layer n start at index 2 ** (layer-1) */
316+
k = 1 << (layer - 1);
317+
for (start = 0; start < MLKEM_N; start += 2 * len)
318+
{
319+
unsigned j;
320+
const int16_t this_layer_zeta = zetas[k];
321+
const int16_t next_layer_zeta1 = zetas[k * 2];
322+
const int16_t next_layer_zeta2 = zetas[k * 2 + 1];
323+
k++;
324+
325+
for (j = 0; j < len / 2; j++)
326+
{
327+
const unsigned ci0 = j + start;
328+
const unsigned ci1 = ci0 + len / 2;
329+
const unsigned ci2 = ci1 + len / 2;
330+
const unsigned ci3 = ci2 + len / 2;
331+
332+
ct_butterfly(r, ci0, ci2, this_layer_zeta);
333+
ct_butterfly(r, ci1, ci3, this_layer_zeta);
334+
ct_butterfly(r, ci0, ci1, next_layer_zeta1);
335+
ct_butterfly(r, ci2, ci3, next_layer_zeta2);
336+
}
337+
}
338+
}
339+
340+
static void ntt_3_layers(int16_t r[MLKEM_N], unsigned layer)
341+
__contract__(
342+
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
343+
requires(1 <= layer && layer <= 5)
344+
requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))
345+
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
346+
ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 3) * MLKEM_Q)))
347+
{
348+
const unsigned len = MLKEM_N >> layer;
349+
unsigned start, k;
350+
/* Twiddle factors for layer n start at index 2 ** (layer-1) */
351+
k = 1 << (layer - 1);
352+
for (start = 0; start < MLKEM_N; start += 2 * len)
353+
{
354+
unsigned j;
355+
const int16_t first_layer_zeta = zetas[k];
356+
const unsigned second_layer_zi1 = k * 2;
357+
const unsigned second_layer_zi2 = k * 2 + 1;
358+
const int16_t second_layer_zeta1 = zetas[second_layer_zi1];
359+
const int16_t second_layer_zeta2 = zetas[second_layer_zi2];
360+
const int16_t third_layer_zeta1 = zetas[second_layer_zi1 * 2];
361+
const int16_t third_layer_zeta2 = zetas[second_layer_zi1 * 2 + 1];
362+
const int16_t third_layer_zeta3 = zetas[second_layer_zi2 * 2];
363+
const int16_t third_layer_zeta4 = zetas[second_layer_zi2 * 2 + 1];
364+
k++;
365+
366+
for (j = 0; j < len / 4; j++)
367+
{
368+
const unsigned ci0 = j + start;
369+
const unsigned ci1 = ci0 + len / 4;
370+
const unsigned ci2 = ci1 + len / 4;
371+
const unsigned ci3 = ci2 + len / 4;
372+
const unsigned ci4 = ci3 + len / 4;
373+
const unsigned ci5 = ci4 + len / 4;
374+
const unsigned ci6 = ci5 + len / 4;
375+
const unsigned ci7 = ci6 + len / 4;
376+
377+
ct_butterfly(r, ci0, ci4, first_layer_zeta);
378+
ct_butterfly(r, ci1, ci5, first_layer_zeta);
379+
ct_butterfly(r, ci2, ci6, first_layer_zeta);
380+
ct_butterfly(r, ci3, ci7, first_layer_zeta);
381+
382+
ct_butterfly(r, ci0, ci2, second_layer_zeta1);
383+
ct_butterfly(r, ci1, ci3, second_layer_zeta1);
384+
ct_butterfly(r, ci4, ci6, second_layer_zeta2);
385+
ct_butterfly(r, ci5, ci7, second_layer_zeta2);
386+
387+
ct_butterfly(r, ci0, ci1, third_layer_zeta1);
388+
ct_butterfly(r, ci2, ci3, third_layer_zeta2);
389+
ct_butterfly(r, ci4, ci5, third_layer_zeta3);
390+
ct_butterfly(r, ci6, ci7, third_layer_zeta4);
391+
}
360392
}
361393
}
362394

@@ -372,18 +404,14 @@ __contract__(
372404
MLK_INTERNAL_API
373405
void poly_ntt(poly *p)
374406
{
375-
unsigned len, layer;
376407
int16_t *r;
377408
debug_assert_abs_bound(p, MLKEM_N, MLKEM_Q);
378409
r = p->coeffs;
379410

380-
for (len = 128, layer = 1; len >= 2; len >>= 1, layer++)
381-
__loop__(
382-
invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer))
383-
invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)))
384-
{
385-
ntt_layer(r, len, layer);
386-
}
411+
ntt_3_layers(r, 1);
412+
ntt_2_layers(r, 4);
413+
ntt_1_layer(r, 6);
414+
ntt_1_layer(r, 7);
387415

388416
/* Check the stronger bound */
389417
debug_assert_abs_bound(p, MLKEM_N, MLK_NTT_BOUND);
@@ -490,6 +518,7 @@ MLK_EMPTY_CU(poly)
490518
#undef fqmul
491519
#undef barrett_reduce
492520
#undef scalar_signed_to_unsigned_q
493-
#undef ntt_butterfly_block
494-
#undef ntt_layer
521+
#undef ntt_1_layer
522+
#undef ntt_2_layers
523+
#undef ntt_3_layers
495524
#undef invntt_layer

0 commit comments

Comments
 (0)