Skip to content

Commit e026d45

Browse files
Merge pull request #191 from davidbrochart/kernels_rest_api
Add POST execute cell
2 parents 1aaebb0 + 96d70df commit e026d45

File tree

15 files changed

+777
-3
lines changed

15 files changed

+777
-3
lines changed

.github/workflows/main.yml

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ jobs:
4646
pip install ./plugins/lab
4747
pip install ./plugins/jupyterlab
4848
pip install "jupyter_ydoc >=0.1.16,<0.2.0" # FIXME: remove with next JupyterLab release
49+
pip install "y-py >=0.5.4"
4950
5051
pip install mypy pytest pytest-asyncio requests ipykernel
5152

plugins/kernels/fps_kernels/kernel_driver/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import asyncio
2+
import json
3+
import os
4+
import socket
5+
import tempfile
6+
import uuid
7+
from typing import Dict, Tuple, Union
8+
9+
import zmq
10+
import zmq.asyncio
11+
from zmq.asyncio import Socket
12+
13+
channel_socket_types = {
14+
"hb": zmq.REQ,
15+
"shell": zmq.DEALER,
16+
"iopub": zmq.SUB,
17+
"stdin": zmq.DEALER,
18+
"control": zmq.DEALER,
19+
}
20+
21+
context = zmq.asyncio.Context()
22+
23+
cfg_t = Dict[str, Union[str, int]]
24+
25+
26+
def get_port(ip: str) -> int:
27+
sock = socket.socket()
28+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8)
29+
sock.bind((ip, 0))
30+
port = sock.getsockname()[1]
31+
sock.close()
32+
return port
33+
34+
35+
def write_connection_file(
36+
fname: str = "",
37+
ip: str = "",
38+
transport: str = "tcp",
39+
signature_scheme: str = "hmac-sha256",
40+
kernel_name: str = "",
41+
) -> Tuple[str, cfg_t]:
42+
ip = ip or "127.0.0.1"
43+
44+
if not fname:
45+
fd, fname = tempfile.mkstemp(suffix=".json")
46+
os.close(fd)
47+
f = open(fname, "wt")
48+
49+
channels = ["shell", "iopub", "stdin", "control", "hb"]
50+
51+
cfg: cfg_t = {f"{c}_port": get_port(ip) for c in channels}
52+
53+
cfg["ip"] = ip
54+
cfg["key"] = uuid.uuid4().hex
55+
cfg["transport"] = transport
56+
cfg["signature_scheme"] = signature_scheme
57+
cfg["kernel_name"] = kernel_name
58+
59+
f.write(json.dumps(cfg, indent=2))
60+
f.close()
61+
62+
return fname, cfg
63+
64+
65+
def read_connection_file(fname: str = "") -> cfg_t:
66+
with open(fname, "rt") as f:
67+
cfg: cfg_t = json.load(f)
68+
69+
return cfg
70+
71+
72+
async def launch_kernel(
73+
kernelspec_path: str, connection_file_path: str, capture_output: bool
74+
) -> asyncio.subprocess.Process:
75+
with open(kernelspec_path) as f:
76+
kernelspec = json.load(f)
77+
cmd = [s.format(connection_file=connection_file_path) for s in kernelspec["argv"]]
78+
if capture_output:
79+
p = await asyncio.create_subprocess_exec(
80+
*cmd, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.STDOUT
81+
)
82+
else:
83+
p = await asyncio.create_subprocess_exec(*cmd)
84+
return p
85+
86+
87+
def create_socket(channel: str, cfg: cfg_t) -> Socket:
88+
ip = cfg["ip"]
89+
port = cfg[f"{channel}_port"]
90+
url = f"tcp://{ip}:{port}"
91+
socket_type = channel_socket_types[channel]
92+
sock = context.socket(socket_type)
93+
sock.linger = 1000 # set linger to 1s to prevent hangs at exit
94+
sock.connect(url)
95+
return sock
96+
97+
98+
def connect_channel(channel_name: str, cfg: cfg_t) -> Socket:
99+
sock = create_socket(channel_name, cfg)
100+
if channel_name == "iopub":
101+
sock.setsockopt(zmq.SUBSCRIBE, b"")
102+
return sock
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import asyncio
2+
import os
3+
import time
4+
import uuid
5+
from typing import Any, Dict, List, Optional, Tuple, cast
6+
7+
from zmq.asyncio import Socket
8+
9+
from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file
10+
from .connect import write_connection_file as _write_connection_file
11+
from .kernelspec import find_kernelspec
12+
from .message import create_message, deserialize, serialize
13+
14+
DELIM = b"<IDS|MSG>"
15+
16+
17+
def deadline_to_timeout(deadline: float) -> float:
18+
return max(0, deadline - time.time())
19+
20+
21+
def feed_identities(msg_list: List[bytes]) -> Tuple[List[bytes], List[bytes]]:
22+
idx = msg_list.index(DELIM)
23+
return msg_list[:idx], msg_list[idx + 1 :] # noqa
24+
25+
26+
def send_message(msg: Dict[str, Any], sock: Socket, key: str) -> None:
27+
to_send = serialize(msg, key)
28+
sock.send_multipart(to_send, copy=True)
29+
30+
31+
async def receive_message(sock: Socket, timeout: float = float("inf")) -> Optional[Dict[str, Any]]:
32+
timeout *= 1000 # in ms
33+
ready = await sock.poll(timeout)
34+
if ready:
35+
msg_list = await sock.recv_multipart()
36+
idents, msg_list = feed_identities(msg_list)
37+
return deserialize(msg_list)
38+
return None
39+
40+
41+
class KernelDriver:
42+
def __init__(
43+
self,
44+
kernel_name: str = "",
45+
kernelspec_path: str = "",
46+
connection_file: str = "",
47+
write_connection_file: bool = True,
48+
capture_kernel_output: bool = True,
49+
) -> None:
50+
self.capture_kernel_output = capture_kernel_output
51+
self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name)
52+
if not self.kernelspec_path:
53+
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
54+
if write_connection_file:
55+
self.connection_file_path, self.connection_cfg = _write_connection_file(connection_file)
56+
else:
57+
self.connection_file_path = connection_file
58+
self.connection_cfg = read_connection_file(connection_file)
59+
self.key = cast(str, self.connection_cfg["key"])
60+
self.session_id = uuid.uuid4().hex
61+
self.msg_cnt = 0
62+
self.execute_requests: Dict[str, Dict[str, asyncio.Future]] = {}
63+
self.channel_tasks: List[asyncio.Task] = []
64+
65+
async def restart(self, startup_timeout: float = float("inf")) -> None:
66+
for task in self.channel_tasks:
67+
task.cancel()
68+
msg = create_message("shutdown_request", content={"restart": True})
69+
send_message(msg, self.control_channel, self.key)
70+
while True:
71+
msg = cast(Dict[str, Any], await receive_message(self.control_channel))
72+
if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]:
73+
break
74+
await self._wait_for_ready(startup_timeout)
75+
self.channel_tasks = []
76+
self.listen_channels()
77+
78+
async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None:
79+
self.kernel_process = await launch_kernel(
80+
self.kernelspec_path, self.connection_file_path, self.capture_kernel_output
81+
)
82+
if connect:
83+
await self.connect(startup_timeout)
84+
85+
async def connect(self, startup_timeout: float = float("inf")) -> None:
86+
self.connect_channels()
87+
await self._wait_for_ready(startup_timeout)
88+
self.listen_channels()
89+
90+
def connect_channels(self, connection_cfg: cfg_t = None):
91+
connection_cfg = connection_cfg or self.connection_cfg
92+
self.shell_channel = connect_channel("shell", connection_cfg)
93+
self.control_channel = connect_channel("control", connection_cfg)
94+
self.iopub_channel = connect_channel("iopub", connection_cfg)
95+
96+
def listen_channels(self):
97+
self.channel_tasks.append(asyncio.create_task(self.listen_iopub()))
98+
self.channel_tasks.append(asyncio.create_task(self.listen_shell()))
99+
100+
async def stop(self) -> None:
101+
self.kernel_process.kill()
102+
await self.kernel_process.wait()
103+
os.remove(self.connection_file_path)
104+
for task in self.channel_tasks:
105+
task.cancel()
106+
107+
async def listen_iopub(self):
108+
while True:
109+
msg = await receive_message(self.iopub_channel) # type: ignore
110+
msg_id = msg["parent_header"].get("msg_id")
111+
if msg_id in self.execute_requests.keys():
112+
self.execute_requests[msg_id]["iopub_msg"].set_result(msg)
113+
114+
async def listen_shell(self):
115+
while True:
116+
msg = await receive_message(self.shell_channel) # type: ignore
117+
msg_id = msg["parent_header"].get("msg_id")
118+
if msg_id in self.execute_requests.keys():
119+
self.execute_requests[msg_id]["shell_msg"].set_result(msg)
120+
121+
async def execute(
122+
self,
123+
cell: Dict[str, Any],
124+
timeout: float = float("inf"),
125+
msg_id: str = "",
126+
wait_for_executed: bool = True,
127+
) -> None:
128+
if cell["cell_type"] != "code":
129+
return
130+
content = {"code": cell["source"], "silent": False}
131+
msg = create_message(
132+
"execute_request", content, session_id=self.session_id, msg_cnt=self.msg_cnt
133+
)
134+
if msg_id:
135+
msg["header"]["msg_id"] = msg_id
136+
else:
137+
msg_id = msg["header"]["msg_id"]
138+
self.msg_cnt += 1
139+
send_message(msg, self.shell_channel, self.key)
140+
if wait_for_executed:
141+
deadline = time.time() + timeout
142+
self.execute_requests[msg_id] = {
143+
"iopub_msg": asyncio.Future(),
144+
"shell_msg": asyncio.Future(),
145+
}
146+
while True:
147+
try:
148+
await asyncio.wait_for(
149+
self.execute_requests[msg_id]["iopub_msg"],
150+
deadline_to_timeout(deadline),
151+
)
152+
except asyncio.TimeoutError:
153+
error_message = f"Kernel didn't respond in {timeout} seconds"
154+
raise RuntimeError(error_message)
155+
msg = self.execute_requests[msg_id]["iopub_msg"].result()
156+
self._handle_outputs(cell["outputs"], msg)
157+
if (
158+
msg["header"]["msg_type"] == "status"
159+
and msg["content"]["execution_state"] == "idle"
160+
):
161+
break
162+
self.execute_requests[msg_id]["iopub_msg"] = asyncio.Future()
163+
try:
164+
await asyncio.wait_for(
165+
self.execute_requests[msg_id]["shell_msg"],
166+
deadline_to_timeout(deadline),
167+
)
168+
except asyncio.TimeoutError:
169+
error_message = f"Kernel didn't respond in {timeout} seconds"
170+
raise RuntimeError(error_message)
171+
msg = self.execute_requests[msg_id]["shell_msg"].result()
172+
cell["execution_count"] = msg["content"]["execution_count"]
173+
del self.execute_requests[msg_id]
174+
175+
async def _wait_for_ready(self, timeout):
176+
deadline = time.time() + timeout
177+
new_timeout = timeout
178+
while True:
179+
msg = create_message(
180+
"kernel_info_request", session_id=self.session_id, msg_cnt=self.msg_cnt
181+
)
182+
self.msg_cnt += 1
183+
send_message(msg, self.shell_channel, self.key)
184+
msg = await receive_message(self.shell_channel, new_timeout)
185+
if msg is None:
186+
error_message = f"Kernel didn't respond in {timeout} seconds"
187+
raise RuntimeError(error_message)
188+
if msg["msg_type"] == "kernel_info_reply":
189+
msg = await receive_message(self.iopub_channel, 0.2)
190+
if msg is not None:
191+
break
192+
new_timeout = deadline_to_timeout(deadline)
193+
194+
def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
195+
msg_type = msg["header"]["msg_type"]
196+
content = msg["content"]
197+
if msg_type == "stream":
198+
if (not outputs) or (outputs[-1]["name"] != content["name"]):
199+
outputs.append({"name": content["name"], "output_type": msg_type, "text": []})
200+
outputs[-1]["text"].append(content["text"])
201+
elif msg_type in ("display_data", "execute_result"):
202+
outputs.append(
203+
{
204+
"data": {"text/plain": [content["data"].get("text/plain", "")]},
205+
"execution_count": content["execution_count"],
206+
"metadata": {},
207+
"output_type": msg_type,
208+
}
209+
)
210+
elif msg_type == "error":
211+
outputs.append(
212+
{
213+
"ename": content["ename"],
214+
"evalue": content["evalue"],
215+
"output_type": "error",
216+
"traceback": content["traceback"],
217+
}
218+
)
219+
else:
220+
return

0 commit comments

Comments
 (0)