Skip to content

Commit 677e362

Browse files
committed
Initial exploration around jit
1 parent 5b18a0e commit 677e362

File tree

2 files changed

+203
-0
lines changed

2 files changed

+203
-0
lines changed

cairo/jit/main.cairo

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
from starkware.cairo.common.alloc import alloc
2+
from starkware.cairo.common.registers import get_label_location
3+
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc
4+
5+
// Compiled Instructions
6+
const CALL_ABS = 0x1084800180018000;
7+
8+
// Modified return instruction to update the offset of return_pc.
9+
// Putting an offset >= 0 let's dynamically chose where to return.
10+
// Offset = -1 is the default ret opcode.
11+
const RET_FP_OFFSET_HIGH = 0x208b8000;
12+
const RET_FP_OFFSET_LOW = 0x7fff7ffe;
13+
const RET = (RET_FP_OFFSET_HIGH - 1) * 256 ** 4 + RET_FP_OFFSET_LOW;
14+
const RET_0 = (RET_FP_OFFSET_HIGH) * 256 ** 4 + RET_FP_OFFSET_LOW;
15+
16+
// Bytecode Opcode
17+
const OP_JUMP4 = 0;
18+
const OP_RET = 1;
19+
const OP_PUSH = 2;
20+
const OP_PC = 3;
21+
const OP_CALL = 4;
22+
const OP_ADD = 5;
23+
const OP_MUL = 6;
24+
25+
func compile(input: felt, code_len: felt, code_ptr: felt*) -> (
26+
compiled_code_len: felt, compiled_code_ptr: felt*
27+
) {
28+
alloc_locals;
29+
let (local op: felt*) = get_label_location(opcodes_location);
30+
let (local compiled_code) = alloc();
31+
if (code_len == 0) {
32+
return (0, compiled_code);
33+
}
34+
35+
tempvar i = 0;
36+
tempvar compiled_code = compiled_code;
37+
38+
loop:
39+
let i = [ap - 2];
40+
let compiled_code = cast([ap - 1], felt*);
41+
42+
let code_len = [fp - 4];
43+
let code = cast([fp - 3], felt*);
44+
let op = cast([fp], felt*);
45+
46+
tempvar opcode_number = code[i];
47+
assert [compiled_code] = CALL_ABS;
48+
assert [compiled_code + 1] = cast(op + 2 * opcode_number + op[2 * opcode_number + 1], felt);
49+
50+
tempvar is_push = opcode_number - OP_PUSH;
51+
jmp not_push if is_push != 0;
52+
53+
push:
54+
assert [compiled_code + 3] = code[i + 1];
55+
tempvar stop = code_len - i - 2;
56+
tempvar i = i + 2;
57+
tempvar compiled_code = compiled_code + 4;
58+
jmp loop if stop != 0;
59+
jmp end;
60+
61+
not_push:
62+
tempvar stop = code_len - i - 1;
63+
tempvar i = i + 1;
64+
tempvar compiled_code = compiled_code + 2;
65+
66+
static_assert i == [ap - 2];
67+
static_assert compiled_code == [ap - 1];
68+
jmp loop if stop != 0;
69+
jmp end;
70+
71+
end:
72+
let i = [ap - 2];
73+
let compiled_code = cast([ap - 1], felt*);
74+
assert [compiled_code] = RET;
75+
76+
let compiled_code = cast([fp + 1], felt*);
77+
78+
return (i, compiled_code);
79+
}
80+
81+
func main() {
82+
alloc_locals;
83+
84+
let (bytecode_start) = alloc();
85+
let bytecode = bytecode_start;
86+
assert [bytecode] = OP_JUMP4;
87+
let bytecode = bytecode + 1;
88+
assert [bytecode] = OP_RET;
89+
let bytecode = bytecode + 1;
90+
assert [bytecode] = OP_RET;
91+
let bytecode = bytecode + 1;
92+
assert [bytecode] = OP_RET;
93+
let bytecode = bytecode + 1;
94+
assert [bytecode] = OP_RET;
95+
let bytecode = bytecode + 1;
96+
assert [bytecode] = OP_PUSH;
97+
let bytecode = bytecode + 1;
98+
assert [bytecode] = OP_PC;
99+
let bytecode = bytecode + 1;
100+
assert [bytecode] = OP_PC;
101+
let bytecode = bytecode + 1;
102+
assert [bytecode] = OP_CALL;
103+
let bytecode = bytecode + 1;
104+
assert [bytecode] = OP_ADD;
105+
let bytecode = bytecode + 1;
106+
assert [bytecode] = OP_MUL;
107+
let bytecode = bytecode + 1;
108+
assert [bytecode] = OP_RET;
109+
110+
tempvar input = 0xdead;
111+
let (compiled_code_len, compiled_code_ptr) = compile(
112+
input, bytecode - bytecode_start, bytecode_start
113+
);
114+
115+
call abs compiled_code_ptr;
116+
let result = [ap - 1];
117+
assert result = 2;
118+
119+
return ();
120+
}
121+
122+
func op_jump(input: felt) -> felt {
123+
alloc_locals;
124+
local return_pc;
125+
tempvar jump_size = 4;
126+
assert return_pc = [fp - 1] + 2 * jump_size;
127+
128+
tempvar result = 0;
129+
130+
dw RET_0;
131+
}
132+
133+
func op_ret(input: felt) -> felt {
134+
alloc_locals;
135+
local main_return_pc;
136+
137+
let return_fp = [fp - 2];
138+
main_return_pc = [return_fp - 1];
139+
140+
tempvar result = input;
141+
dw RET_0;
142+
}
143+
144+
func op_push(input: felt) -> felt {
145+
alloc_locals;
146+
// [fp - 1] is next CALL_ABS instruction
147+
// [[fp - 1] + 1] is the word to push
148+
// [fp - 1] + 2 where to move the PC after the push
149+
local return_pc = [fp - 1] + 2;
150+
tempvar word = [[fp - 1] + 1];
151+
152+
dw RET_0;
153+
}
154+
155+
func op_pc(input: felt) -> felt {
156+
let return_pc = [fp - 1];
157+
let calling_pc = return_pc - 2;
158+
let return_fp = [fp - 2];
159+
let main_return_fp = [return_fp - 2];
160+
let compiled_code_ptr = [return_fp - 3];
161+
162+
return (calling_pc - compiled_code_ptr) / 2;
163+
}
164+
165+
func op_call(input: felt) -> felt {
166+
let (bytecode_start) = alloc();
167+
let bytecode = bytecode_start;
168+
assert [bytecode] = OP_PC;
169+
let bytecode = bytecode + 1;
170+
assert [bytecode] = OP_RET;
171+
172+
let (compiled_code_len, compiled_code_ptr) = compile(
173+
input, bytecode - bytecode_start, bytecode_start
174+
);
175+
176+
call abs compiled_code_ptr;
177+
ret;
178+
}
179+
180+
func op_add(input: felt) -> felt {
181+
let result = input + 1;
182+
return result;
183+
}
184+
185+
func op_mul(input: felt) -> felt {
186+
let result = input * 2;
187+
return result;
188+
}
189+
190+
// Create a label and a list of call rel op to be able to get all the opcodes locations
191+
// with a single call to get_label_location.
192+
opcodes_location:
193+
call op_jump;
194+
call op_ret;
195+
call op_push;
196+
call op_pc;
197+
call op_call;
198+
call op_add;
199+
call op_mul;

cairo/tests/jit/test_main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
class TestMain:
2+
3+
def test_main(self, cairo_run):
4+
cairo_run("main")

0 commit comments

Comments
 (0)