Skip to content

Commit c1fa0b0

Browse files
committed
Add first attempt at generalized 2- and 3-layer merged forward
NTT functions. CBMC proofs of these new functions are all TBD. For now, we call these in a "3,2,1,1" pattern to make sure there are no unreferenced functions. Signed-off-by: Rod Chapman <[email protected]> Correct one call from fqmul() to mlk_fqmul() Signed-off-by: Rod Chapman <[email protected]>
1 parent e319849 commit c1fa0b0

File tree

1 file changed

+112
-88
lines changed

1 file changed

+112
-88
lines changed

mlkem/poly.c

Lines changed: 112 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -279,106 +279,134 @@ void mlk_poly_mulcache_compute(mlk_poly_mulcache *x, const mlk_poly *a)
279279
#endif /* MLK_USE_NATIVE_POLY_MULCACHE_COMPUTE */
280280

281281
#if !defined(MLK_USE_NATIVE_NTT)
282-
/*
283-
* Computes a block CT butterflies with a fixed twiddle factor,
284-
* using Montgomery multiplication.
285-
* Parameters:
286-
* - r: Pointer to base of polynomial (_not_ the base of butterfly block)
287-
* - root: Twiddle factor to use for the butterfly. This must be in
288-
* Montgomery form and signed canonical.
289-
* - start: Offset to the beginning of the butterfly block
290-
* - len: Index difference between coefficients subject to a butterfly
291-
* - bound: Ghost variable describing coefficient bound: Prior to `start`,
292-
* coefficients must be bound by `bound + MLKEM_Q`. Post `start`,
293-
* they must be bound by `bound`.
294-
* When this function returns, output coefficients in the index range
295-
* [start, start+2*len) have bound bumped to `bound + MLKEM_Q`.
296-
* Example:
297-
* - start=8, len=4
298-
* This would compute the following four butterflies
299-
* 8 -- 12
300-
* 9 -- 13
301-
* 10 -- 14
302-
* 11 -- 15
303-
* - start=4, len=2
304-
* This would compute the following two butterflies
305-
* 4 -- 6
306-
* 5 -- 7
307-
*/
308282

309283
/* Reference: Embedded in `ntt()` in the reference implementation. */
310-
static void mlk_ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta,
311-
unsigned start, unsigned len, int bound)
284+
static MLK_INLINE void mlk_ct_butterfly(int16_t r[MLKEM_N],
285+
const unsigned coeff1_index,
286+
const unsigned coeff2_index,
287+
const int16_t zeta)
288+
{
289+
int16_t t1 = r[coeff1_index];
290+
int16_t t2 = mlk_fqmul(r[coeff2_index], zeta);
291+
r[coeff1_index] = t1 + t2;
292+
r[coeff2_index] = t1 - t2;
293+
}
294+
295+
static void mlk_ntt_1_layer(int16_t r[MLKEM_N], unsigned layer)
312296
__contract__(
313-
requires(start < MLKEM_N)
314-
requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N)
315-
requires(0 <= bound && bound < INT16_MAX - MLKEM_Q)
316-
requires(-MLKEM_Q_HALF < zeta && zeta < MLKEM_Q_HALF)
317297
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
318-
requires(array_abs_bound(r, 0, start, bound + MLKEM_Q))
319-
requires(array_abs_bound(r, start, MLKEM_N, bound))
298+
requires(1 <= layer && layer <= 7)
299+
requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))
320300
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
321-
ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q))
322-
ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound)))
301+
ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q)))
323302
{
324-
/* `bound` is a ghost variable only needed in the CBMC specification */
325-
unsigned j;
326-
((void)bound);
327-
for (j = start; j < start + len; j++)
303+
const unsigned len = MLKEM_N >> layer;
304+
unsigned start, k;
305+
/* Twiddle factors for layer n start at index 2 ** (layer-1) */
306+
k = 1 << (layer - 1);
307+
for (start = 0; start < MLKEM_N; start += 2 * len)
328308
__loop__(
329-
invariant(start <= j && j <= start + len)
330-
/*
331-
* Coefficients are updated in strided pairs, so the bounds for the
332-
* intermediate states alternate twice between the old and new bound
333-
*/
334-
invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q))
335-
invariant(array_abs_bound(r, j, start + len, bound))
336-
invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q))
337-
invariant(array_abs_bound(r, j + len, MLKEM_N, bound)))
309+
invariant(start < MLKEM_N + 2 * len)
310+
invariant(k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N)
311+
invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q))
312+
invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q)))
338313
{
339-
int16_t t;
340-
t = mlk_fqmul(r[j + len], zeta);
341-
r[j + len] = r[j] - t;
342-
r[j] = r[j] + t;
314+
const int16_t zeta = zetas[k++];
315+
unsigned j;
316+
for (j = 0; j < len; j++)
317+
{
318+
mlk_ct_butterfly(r, j + start, j + start + len, zeta);
319+
}
343320
}
344321
}
345322

346-
/*
347-
* Compute one layer of forward NTT
348-
* Parameters:
349-
* - r: Pointer to base of polynomial
350-
* - len: Stride of butterflies in this layer.
351-
* - layer: Ghost variable indicating which layer is being applied.
352-
* Must match `len` via `len == MLKEM_N >> layer`.
353-
* Note: `len` could be dropped and computed in the function, but
354-
* we are following the structure of the reference NTT from the
355-
* official Kyber implementation here, merely adding `layer` as
356-
* a ghost variable for the specifications.
357-
*/
323+
static void mlk_ntt_2_layers(int16_t r[MLKEM_N], unsigned layer)
324+
__contract__(
325+
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
326+
requires(1 <= layer && layer <= 6)
327+
requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))
328+
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
329+
ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 2) * MLKEM_Q)))
330+
{
331+
const unsigned len = MLKEM_N >> layer;
332+
unsigned start, k;
333+
/* Twiddle factors for layer n start at index 2 ** (layer-1) */
334+
k = 1 << (layer - 1);
335+
for (start = 0; start < MLKEM_N; start += 2 * len)
336+
{
337+
unsigned j;
338+
const int16_t this_layer_zeta = zetas[k];
339+
const int16_t next_layer_zeta1 = zetas[k * 2];
340+
const int16_t next_layer_zeta2 = zetas[k * 2 + 1];
341+
k++;
358342

359-
/* Reference: Embedded in `ntt()` in the reference implementation. */
360-
static void mlk_ntt_layer(int16_t r[MLKEM_N], unsigned len, unsigned layer)
343+
for (j = 0; j < len / 2; j++)
344+
{
345+
const unsigned ci0 = j + start;
346+
const unsigned ci1 = ci0 + len / 2;
347+
const unsigned ci2 = ci1 + len / 2;
348+
const unsigned ci3 = ci2 + len / 2;
349+
350+
mlk_ct_butterfly(r, ci0, ci2, this_layer_zeta);
351+
mlk_ct_butterfly(r, ci1, ci3, this_layer_zeta);
352+
mlk_ct_butterfly(r, ci0, ci1, next_layer_zeta1);
353+
mlk_ct_butterfly(r, ci2, ci3, next_layer_zeta2);
354+
}
355+
}
356+
}
357+
358+
static void mlk_ntt_3_layers(int16_t r[MLKEM_N], unsigned layer)
361359
__contract__(
362360
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
363-
requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer))
361+
requires(1 <= layer && layer <= 5)
364362
requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))
365363
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
366-
ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q)))
364+
ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 3) * MLKEM_Q)))
367365
{
366+
const unsigned len = MLKEM_N >> layer;
368367
unsigned start, k;
369-
/* `layer` is a ghost variable only needed in the CBMC specification */
370-
((void)layer);
371-
/* Twiddle factors for layer n start at index 2^(layer-1) */
372-
k = MLKEM_N / (2 * len);
368+
/* Twiddle factors for layer n start at index 2 ** (layer-1) */
369+
k = 1 << (layer - 1);
373370
for (start = 0; start < MLKEM_N; start += 2 * len)
374-
__loop__(
375-
invariant(start < MLKEM_N + 2 * len)
376-
invariant(k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N)
377-
invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q))
378-
invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q)))
379371
{
380-
int16_t zeta = zetas[k++];
381-
mlk_ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q);
372+
unsigned j;
373+
const int16_t first_layer_zeta = zetas[k];
374+
const unsigned second_layer_zi1 = k * 2;
375+
const unsigned second_layer_zi2 = k * 2 + 1;
376+
const int16_t second_layer_zeta1 = zetas[second_layer_zi1];
377+
const int16_t second_layer_zeta2 = zetas[second_layer_zi2];
378+
const int16_t third_layer_zeta1 = zetas[second_layer_zi1 * 2];
379+
const int16_t third_layer_zeta2 = zetas[second_layer_zi1 * 2 + 1];
380+
const int16_t third_layer_zeta3 = zetas[second_layer_zi2 * 2];
381+
const int16_t third_layer_zeta4 = zetas[second_layer_zi2 * 2 + 1];
382+
k++;
383+
384+
for (j = 0; j < len / 4; j++)
385+
{
386+
const unsigned ci0 = j + start;
387+
const unsigned ci1 = ci0 + len / 4;
388+
const unsigned ci2 = ci1 + len / 4;
389+
const unsigned ci3 = ci2 + len / 4;
390+
const unsigned ci4 = ci3 + len / 4;
391+
const unsigned ci5 = ci4 + len / 4;
392+
const unsigned ci6 = ci5 + len / 4;
393+
const unsigned ci7 = ci6 + len / 4;
394+
395+
mlk_ct_butterfly(r, ci0, ci4, first_layer_zeta);
396+
mlk_ct_butterfly(r, ci1, ci5, first_layer_zeta);
397+
mlk_ct_butterfly(r, ci2, ci6, first_layer_zeta);
398+
mlk_ct_butterfly(r, ci3, ci7, first_layer_zeta);
399+
400+
mlk_ct_butterfly(r, ci0, ci2, second_layer_zeta1);
401+
mlk_ct_butterfly(r, ci1, ci3, second_layer_zeta1);
402+
mlk_ct_butterfly(r, ci4, ci6, second_layer_zeta2);
403+
mlk_ct_butterfly(r, ci5, ci7, second_layer_zeta2);
404+
405+
mlk_ct_butterfly(r, ci0, ci1, third_layer_zeta1);
406+
mlk_ct_butterfly(r, ci2, ci3, third_layer_zeta2);
407+
mlk_ct_butterfly(r, ci4, ci5, third_layer_zeta3);
408+
mlk_ct_butterfly(r, ci6, ci7, third_layer_zeta4);
409+
}
382410
}
383411
}
384412

@@ -395,18 +423,14 @@ __contract__(
395423
MLK_INTERNAL_API
396424
void mlk_poly_ntt(mlk_poly *p)
397425
{
398-
unsigned len, layer;
399426
int16_t *r;
400427
mlk_assert_abs_bound(p, MLKEM_N, MLKEM_Q);
401428
r = p->coeffs;
402429

403-
for (len = 128, layer = 1; len >= 2; len >>= 1, layer++)
404-
__loop__(
405-
invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer))
406-
invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)))
407-
{
408-
mlk_ntt_layer(r, len, layer);
409-
}
430+
mlk_ntt_3_layers(r, 1);
431+
mlk_ntt_2_layers(r, 4);
432+
mlk_ntt_1_layer(r, 6);
433+
mlk_ntt_1_layer(r, 7);
410434

411435
/* Check the stronger bound */
412436
mlk_assert_abs_bound(p, MLKEM_N, MLK_NTT_BOUND);

0 commit comments

Comments
 (0)