Skip to content

Commit 47cfe36

Browse files
committed
Rework outputs handling
1 parent d63dedc commit 47cfe36

File tree

4 files changed

+54
-53
lines changed

4 files changed

+54
-53
lines changed

plugins/kernels/fps_kernels/kernel_driver/driver.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import os
3-
import sys
43
import time
54
import uuid
65
from typing import Any, Dict, List, Optional, Tuple, cast
@@ -39,19 +38,6 @@ async def receive_message(sock: Socket, timeout: float = float("inf")) -> Option
3938
return None
4039

4140

42-
def _output_hook_default(msg: Dict[str, Any]) -> None:
43-
"""Default hook for redisplaying plain-text output"""
44-
msg_type = msg["header"]["msg_type"]
45-
content = msg["content"]
46-
if msg_type == "stream":
47-
stream = getattr(sys, content["name"])
48-
stream.write(content["text"])
49-
elif msg_type in ("display_data", "execute_result"):
50-
sys.stdout.write(content["data"].get("text/plain", ""))
51-
elif msg_type == "error":
52-
print("\n".join(content["traceback"]), file=sys.stderr)
53-
54-
5541
class KernelDriver:
5642
def __init__(
5743
self,
@@ -136,13 +122,14 @@ async def listen_shell(self):
136122

137123
async def execute(
138124
self,
139-
code: str,
125+
cell: Dict[str, Any],
140126
timeout: float = float("inf"),
141127
msg_id: str = "",
142128
wait_for_executed: bool = True,
143-
output_hook=_output_hook_default,
144129
) -> None:
145-
content = {"code": code, "silent": False}
130+
if cell["cell_type"] != "code":
131+
return
132+
content = {"code": cell["source"], "silent": False}
146133
msg = create_message(
147134
"execute_request", content, session_id=self.session_id, msg_cnt=self.msg_cnt
148135
)
@@ -168,7 +155,7 @@ async def execute(
168155
error_message = f"Kernel didn't respond in {timeout} seconds"
169156
raise RuntimeError(error_message)
170157
msg = self.execute_requests[msg_id]["iopub_msg"]
171-
output_hook(msg)
158+
self._handle_outputs(cell["outputs"], msg)
172159
if (
173160
msg["header"]["msg_type"] == "status"
174161
and msg["content"]["execution_state"] == "idle"
@@ -184,6 +171,7 @@ async def execute(
184171
error_message = f"Kernel didn't respond in {timeout} seconds"
185172
raise RuntimeError(error_message)
186173
msg = self.execute_requests[msg_id]["shell_msg"]
174+
cell["execution_count"] = msg["content"]["execution_count"]
187175
del self.execute_requests[msg_id]
188176

189177
async def _wait_for_ready(self, timeout):
@@ -204,3 +192,31 @@ async def _wait_for_ready(self, timeout):
204192
if msg is not None:
205193
break
206194
new_timeout = deadline_to_timeout(deadline)
195+
196+
def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
197+
msg_type = msg["header"]["msg_type"]
198+
content = msg["content"]
199+
if msg_type == "stream":
200+
if (not outputs) or (outputs[-1]["name"] != content["name"]):
201+
outputs.append({"name": content["name"], "output_type": msg_type, "text": []})
202+
outputs[-1]["text"].append(content["text"])
203+
elif msg_type in ("display_data", "execute_result"):
204+
outputs.append(
205+
{
206+
"data": {"text/plain": [content["data"].get("text/plain", "")]},
207+
"execution_count": content["execution_count"],
208+
"metadata": {},
209+
"output_type": msg_type,
210+
}
211+
)
212+
elif msg_type == "error":
213+
outputs.append(
214+
{
215+
"ename": content["ename"],
216+
"evalue": content["evalue"],
217+
"output_type": "error",
218+
"traceback": content["traceback"],
219+
}
220+
)
221+
else:
222+
return

plugins/kernels/fps_kernels/routes.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
import pathlib
33
import sys
44
import uuid
5-
from functools import partial
65
from http import HTTPStatus
7-
from typing import Any, Dict
86

97
from fastapi import APIRouter, Depends, Response, WebSocket, status
108
from fastapi.responses import FileResponse
@@ -201,41 +199,10 @@ async def execute_cell(
201199
await driver.connect()
202200
driver = kernel["driver"]
203201

204-
await driver.execute(cell["source"], output_hook=partial(output_hook, cell["outputs"]))
202+
await driver.execute(cell)
205203
room.document.source = nb
206204

207205

208-
def output_hook(outputs, msg: Dict[str, Any]):
209-
# msg_id = msg["parent_header"]["msg_id"]
210-
execution_count = 1 # self.msg_id_2_execution_count[msg_id]
211-
msg_type = msg["header"]["msg_type"]
212-
content = msg["content"]
213-
if msg_type == "stream":
214-
if (not outputs) or (outputs[-1]["name"] != content["name"]):
215-
outputs.append({"name": content["name"], "output_type": msg_type, "text": []})
216-
outputs[-1]["text"].append(content["text"])
217-
elif msg_type in ("display_data", "execute_result"):
218-
outputs.append(
219-
{
220-
"data": {"text/plain": [content["data"].get("text/plain", "")]},
221-
"execution_count": execution_count,
222-
"metadata": {},
223-
"output_type": msg_type,
224-
}
225-
)
226-
elif msg_type == "error":
227-
outputs.append(
228-
{
229-
"ename": content["ename"],
230-
"evalue": content["evalue"],
231-
"output_type": "error",
232-
"traceback": content["traceback"],
233-
}
234-
)
235-
else:
236-
return
237-
238-
239206
@router.get("/api/kernels/{kernel_id}")
240207
async def get_kernel(
241208
kernel_id,

tests/data/notebook0.ipynb

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@
1919
"metadata": {
2020
"trusted": 0
2121
}
22+
},
23+
{
24+
"source": "3 + 4",
25+
"outputs": [],
26+
"metadata": {
27+
"trusted": 0
28+
},
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"id": "a7243792-6f06-4462-a6b5-7e9ec604348f"
2232
}
2333
],
2434
"metadata": {

tests/test_server.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ async def test_rest_api(start_jupyverse):
7171
# wait for file to be loaded and Y model to be created in server and client
7272
await asyncio.sleep(0.1)
7373
# execute notebook
74-
for cell_idx in range(2):
74+
for cell_idx in range(3):
7575
response = requests.post(
7676
f"{url}/api/kernels/{kernel_id}/execute",
7777
data=json.dumps(
@@ -96,3 +96,11 @@ async def test_rest_api(start_jupyverse):
9696
assert cells[1]["outputs"] == [
9797
{"name": "stdout", "output_type": "stream", "text": ["Hello World!\n"]}
9898
]
99+
assert cells[2]["outputs"] == [
100+
{
101+
"data": {"text/plain": ["7"]},
102+
"execution_count": 3.0,
103+
"metadata": {},
104+
"output_type": "execute_result",
105+
}
106+
]

0 commit comments

Comments
 (0)