
Commit 08a69ea

Run Neon NTT through SLOTHY and add Makefile
This adds a Makefile that runs the Neon NTT through SLOTHY. To accommodate this, the clean assembly is moved to dev/aarch64_clean/, while mldsa/native/aarch64 contains the optimized assembly.

The main difference from mlkem-native is that we need to set an explicit timeout: optimizing the second loop to optimality does not complete in reasonable time, but a good solution is found within one minute on my Apple M4. I set the timeout to 2 minutes in the hope that it works on most platforms; we may have to increase it later.

For now, the clean backend is not tested in CI - that is left for a follow-up PR. SLOTHY is also not run in CI yet. We probably want to put the assembly simplification scripts in place so we can follow the same structure as in mlkem-native.

Signed-off-by: Matthias J. Kannwischer <[email protected]>
1 parent 5e28164 commit 08a69ea
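
For context, a minimal sketch of how the new Makefile is meant to be driven, assuming slothy-cli is installed and on PATH; the authoritative instructions are in the README that the Makefile points to, and this invocation is illustrative rather than part of the commit.

# Regenerate the SLOTHY-optimized NTT from the clean source in dev/aarch64_clean/
# (illustrative invocation from the repository root).
make -C mldsa/native/aarch64/src ntt.S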


3 files changed: 1498 additions, 93 deletions


dev/aarch64_clean/src/ntt.S

Lines changed: 305 additions & 0 deletions
@@ -0,0 +1,305 @@
/* Copyright (c) 2022 Arm Limited
 * Copyright (c) 2022 Hanno Becker
 * Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer
 * Copyright (c) The mldsa-native project authors
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "../../../common.h"
#if defined(MLD_ARITH_BACKEND_AARCH64)

.macro mulmodq dst, src, const, idx0, idx1
        sqrdmulh t2.4s, \src\().4s, \const\().s[\idx1\()]
        mul \dst\().4s, \src\().4s, \const\().s[\idx0\()]
        mls \dst\().4s, t2.4s, consts.s[0]
.endm

.macro mulmod dst, src, const, const_twisted
        sqrdmulh t2.4s, \src\().4s, \const_twisted\().4s
        mul \dst\().4s, \src\().4s, \const\().4s
        mls \dst\().4s, t2.4s, consts.s[0]
.endm

.macro ct_butterfly a, b, root, idx0, idx1
        mulmodq tmp, \b, \root, \idx0, \idx1
        sub \b\().4s, \a\().4s, tmp.4s
        add \a\().4s, \a\().4s, tmp.4s
.endm

.macro ct_butterfly_v a, b, root, root_twisted
        mulmod tmp, \b, \root, \root_twisted
        sub \b\().4s, \a\().4s, tmp.4s
        add \a\().4s, \a\().4s, tmp.4s
.endm

.macro load_roots_123
        ldr q_root0, [r012345_ptr], #64
        ldr q_root1, [r012345_ptr, #(-64 + 16)]
        ldr q_root2, [r012345_ptr, #(-64 + 32)]
        ldr q_root3, [r012345_ptr, #(-64 + 48)]
.endm

.macro load_roots_456
        ldr q_root0, [r012345_ptr], #64
        ldr q_root1, [r012345_ptr, #(-64 + 16)]
        ldr q_root2, [r012345_ptr, #(-64 + 32)]
        ldr q_root3, [r012345_ptr, #(-64 + 48)]
.endm

.macro load_roots_78_part1
        ldr q_root0, [r67_ptr], #(12*16)
        ldr q_root0_tw, [r67_ptr, #(-12*16 + 1*16)]
        ldr q_root1, [r67_ptr, #(-12*16 + 2*16)]
        ldr q_root1_tw, [r67_ptr, #(-12*16 + 3*16)]
        ldr q_root2, [r67_ptr, #(-12*16 + 4*16)]
        ldr q_root2_tw, [r67_ptr, #(-12*16 + 5*16)]
.endm

.macro load_roots_78_part2
        ldr q_root0, [r67_ptr, #(-12*16 + 6*16)]
        ldr q_root0_tw, [r67_ptr, #(-12*16 + 7*16)]
        ldr q_root1, [r67_ptr, #(-12*16 + 8*16)]
        ldr q_root1_tw, [r67_ptr, #(-12*16 + 9*16)]
        ldr q_root2, [r67_ptr, #(-12*16 + 10*16)]
        ldr q_root2_tw, [r67_ptr, #(-12*16 + 11*16)]
.endm

.macro transpose4 data0, data1, data2, data3
        trn1 t0.4s, \data0\().4s, \data1\().4s
        trn2 t1.4s, \data0\().4s, \data1\().4s
        trn1 t2.4s, \data2\().4s, \data3\().4s
        trn2 t3.4s, \data2\().4s, \data3\().4s

        trn2 \data2\().2d, t0.2d, t2.2d
        trn2 \data3\().2d, t1.2d, t3.2d
        trn1 \data0\().2d, t0.2d, t2.2d
        trn1 \data1\().2d, t1.2d, t3.2d
.endm

.macro save_vregs
        sub sp, sp, #(16*4)
        stp d8, d9, [sp, #16*0]
        stp d10, d11, [sp, #16*1]
        stp d12, d13, [sp, #16*2]
        stp d14, d15, [sp, #16*3]
.endm

.macro restore_vregs
        ldp d8, d9, [sp, #16*0]
        ldp d10, d11, [sp, #16*1]
        ldp d12, d13, [sp, #16*2]
        ldp d14, d15, [sp, #16*3]
        add sp, sp, #(16*4)
.endm

.macro push_stack
        save_vregs
.endm

.macro pop_stack
        restore_vregs
.endm

// Inputs
in          .req x0 // Input/output buffer
r012345_ptr .req x1 // twiddles for layer 0,1,2,3,4,5
r67_ptr     .req x2 // twiddles for layer 6,7

count .req x3
inp   .req x4
inpp  .req x5
xtmp  .req x6
wtmp  .req w6

data0 .req v9
data1 .req v10
data2 .req v11
data3 .req v12
data4 .req v13
data5 .req v14
data6 .req v15
data7 .req v16

q_data0 .req q9
q_data1 .req q10
q_data2 .req q11
q_data3 .req q12
q_data4 .req q13
q_data5 .req q14
q_data6 .req q15
q_data7 .req q16

root0 .req v0
root1 .req v1
root2 .req v2
root3 .req v3

q_root0 .req q0
q_root1 .req q1
q_root2 .req q2
q_root3 .req q3

root0_tw .req v4
root1_tw .req v5
root2_tw .req v6
root3_tw .req v7

q_root0_tw .req q4
q_root1_tw .req q5
q_root2_tw .req q6
q_root3_tw .req q7

tmp .req v24
t0  .req v25
t1  .req v26
t2  .req v27
t3  .req v28
consts .req v8
q_consts .req q8

.text
.global MLD_ASM_NAMESPACE(ntt_asm)
.balign 4
MLD_ASM_FN_SYMBOL(ntt_asm)
        push_stack

        // load q = 8380417
        movz wtmp, #57345
        movk wtmp, #127, lsl #16
        dup consts.4s, wtmp

        mov inp, in
        mov count, #8

        load_roots_123

        .p2align 2
layer123_start:
        ldr q_data0, [in, #(0*(1024/8))]
        ldr q_data1, [in, #(1*(1024/8))]
        ldr q_data2, [in, #(2*(1024/8))]
        ldr q_data3, [in, #(3*(1024/8))]
        ldr q_data4, [in, #(4*(1024/8))]
        ldr q_data5, [in, #(5*(1024/8))]
        ldr q_data6, [in, #(6*(1024/8))]
        ldr q_data7, [in, #(7*(1024/8))]

        ct_butterfly data0, data4, root0, 0, 1
        ct_butterfly data1, data5, root0, 0, 1
        ct_butterfly data2, data6, root0, 0, 1
        ct_butterfly data3, data7, root0, 0, 1

        ct_butterfly data0, data2, root0, 2, 3
        ct_butterfly data1, data3, root0, 2, 3
        ct_butterfly data4, data6, root1, 0, 1
        ct_butterfly data5, data7, root1, 0, 1

        ct_butterfly data0, data1, root1, 2, 3
        ct_butterfly data2, data3, root2, 0, 1
        ct_butterfly data4, data5, root2, 2, 3
        ct_butterfly data6, data7, root3, 0, 1

        str q_data0, [in], #16
        str q_data1, [in, #(-16 + 1*(1024/8))]
        str q_data2, [in, #(-16 + 2*(1024/8))]
        str q_data3, [in, #(-16 + 3*(1024/8))]
        str q_data4, [in, #(-16 + 4*(1024/8))]
        str q_data5, [in, #(-16 + 5*(1024/8))]
        str q_data6, [in, #(-16 + 6*(1024/8))]
        str q_data7, [in, #(-16 + 7*(1024/8))]

        subs count, count, #1
        cbnz count, layer123_start

        mov in, inp
        add inpp, in, #64
        mov count, #8

        // Use two data pointers and carefully arrange
        // increments to facilitate reordering of loads
        // and stores by SLOTHY.
        //
        // TODO: Think of alternatives here -- we start with `in`
        // pointing 64 bytes below the actual data, which in theory
        // could underflow. It's unclear how the CPU would behave in this case.
        sub in, in, #64
        sub inpp, inpp, #64

        .p2align 2
layer45678_start:
        ldr q_data0, [in, #(64 + 16*0)]
        ldr q_data1, [in, #(64 + 16*1)]
        ldr q_data2, [in, #(64 + 16*2)]
        ldr q_data3, [in, #(64 + 16*3)]
        ldr q_data4, [inpp, #(64 + 16*0)]
        ldr q_data5, [inpp, #(64 + 16*1)]
        ldr q_data6, [inpp, #(64 + 16*2)]
        ldr q_data7, [inpp, #(64 + 16*3)]

        add in, in, #64
        add inpp, inpp, #64

        load_roots_456

        ct_butterfly data0, data4, root0, 0, 1
        ct_butterfly data1, data5, root0, 0, 1
        ct_butterfly data2, data6, root0, 0, 1
        ct_butterfly data3, data7, root0, 0, 1

        ct_butterfly data0, data2, root0, 2, 3
        ct_butterfly data1, data3, root0, 2, 3
        ct_butterfly data4, data6, root1, 0, 1
        ct_butterfly data5, data7, root1, 0, 1

        ct_butterfly data0, data1, root1, 2, 3
        ct_butterfly data2, data3, root2, 0, 1
        ct_butterfly data4, data5, root2, 2, 3
        ct_butterfly data6, data7, root3, 0, 1

        // Transpose using trn
        transpose4 data0, data1, data2, data3
        transpose4 data4, data5, data6, data7

        load_roots_78_part1

        ct_butterfly_v data0, data2, root0, root0_tw
        ct_butterfly_v data1, data3, root0, root0_tw
        ct_butterfly_v data0, data1, root1, root1_tw
        ct_butterfly_v data2, data3, root2, root2_tw

        load_roots_78_part2

        ct_butterfly_v data4, data6, root0, root0_tw
        ct_butterfly_v data5, data7, root0, root0_tw
        ct_butterfly_v data4, data5, root1, root1_tw
        ct_butterfly_v data6, data7, root2, root2_tw

        // Transpose as part of st4
        st4 {data0.4S, data1.4S, data2.4S, data3.4S}, [in], #64
        st4 {data4.4S, data5.4S, data6.4S, data7.4S}, [inpp], #64

        subs count, count, #1
        cbnz count, layer45678_start

        pop_stack
        ret

#endif /* MLD_ARITH_BACKEND_AARCH64 */
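
For readers unfamiliar with the sqrdmulh/mul/mls idiom behind mulmodq, mulmod and the butterflies above, the following is a rough scalar C model. It is a sketch under stated assumptions, not code from this commit: the names are invented, sqrdmulh saturation is ignored, and it assumes each root in the twiddle tables is paired with a twisted companion of roughly round(root * 2^31 / q); the actual precomputation lives in the twiddle-table generation, not in this file.

#include <stdint.h>

#define MLDSA_Q 8380417 /* the modulus loaded into `consts` via movz/movk above */

/* Scalar model of sqrdmulh: doubling multiply, keeping the rounded high 32 bits. */
static int32_t sqrdmulh_model(int32_t a, int32_t b)
{
  return (int32_t)(((int64_t)a * b + (1LL << 30)) >> 31);
}

/* Scalar model of the mul/mls pair in mulmodq/mulmod: with b_tw ~= round(b*2^31/q),
 * t approximates a*b/q, so a*b - t*q is congruent to a*b modulo q and stays on the
 * order of q for inputs bounded as in the NTT. */
static int32_t mulmod_model(int32_t a, int32_t b, int32_t b_tw)
{
  int32_t t = sqrdmulh_model(a, b_tw);
  return (int32_t)((int64_t)a * b - (int64_t)t * MLDSA_Q);
}

/* Scalar model of ct_butterfly / ct_butterfly_v, per coefficient:
 * (a, b) -> (a + b*root, a - b*root) modulo q. */
static void ct_butterfly_model(int32_t *a, int32_t *b, int32_t root, int32_t root_tw)
{
  int32_t t = mulmod_model(*b, root, root_tw);
  *b = *a - t;
  *a = *a + t;
}

Per the register aliases at the top of the file, the assembly routine receives the coefficient buffer in x0 and the two twiddle tables (layers 0-5 and layers 6-7) in x1 and x2, and applies the butterfly above to four coefficients per vector instruction.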

mldsa/native/aarch64/src/Makefile

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
# Copyright (c) The mldsa-native project authors
# SPDX-License-Identifier: Apache-2.0

######
# To run, see the README.md file
######
.PHONY: all clean

# ISA to optimize for
TARGET_ISA=Arm_AArch64

# MicroArch target to optimize for
TARGET_MICROARCH=Arm_Cortex_A55

SLOTHY_EXTRA_FLAGS ?=

SLOTHY_FLAGS=-c sw_pipelining.enabled=true \
	-c inputs_are_outputs \
	-c sw_pipelining.minimize_overlapping=False \
	-c sw_pipelining.allow_post \
	-c variable_size \
	-c constraints.stalls_first_attempt=64 \
	-c timeout=120 \
	$(SLOTHY_EXTRA_FLAGS)

# For kernels which stash callee-saved v8-v15 but don't stash callee-saved GPRs x19-x30.
# Allow SLOTHY to use all V-registers, but only caller-saved GPRs.
RESERVE_X_ONLY_FLAG=-c reserved_regs="[x18--x30,sp]"

# Used for kernels which don't stash callee-saved registers.
# Restrict SLOTHY to caller-saved registers.
RESERVE_ALL_FLAG=-c reserved_regs="[x18--x30,sp,v8--v15]"

all: ntt.S

# These units explicitly save and restore registers v8-v15, so SLOTHY can freely use
# those registers.
ntt.S: ../../../../dev/aarch64_clean/src/ntt.S
	slothy-cli $(TARGET_ISA) $(TARGET_MICROARCH) $< -o $@ -l layer123_start -l layer45678_start $(SLOTHY_FLAGS) $(RESERVE_X_ONLY_FLAG)

clean:
	-$(RM) -rf *.S
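
The -c timeout=120 above is the 2-minute budget discussed in the commit message. As a sketch of how one might experiment with a larger budget on a slower machine, the SLOTHY_EXTRA_FLAGS hook can pass additional slothy-cli options without editing the Makefile; whether a repeated -c timeout= overrides the default depends on slothy-cli's option handling, so treat the value below as an example only.

# Example only: run from this directory with a larger SLOTHY timeout.
make ntt.S SLOTHY_EXTRA_FLAGS="-c timeout=300"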
