Skip to content

Commit bd44c8f

Browse files
committed
Use asyncio.Future in execute_requests
1 parent 742c7a2 commit bd44c8f

File tree

1 file changed

+10
-12
lines changed
  • plugins/kernels/fps_kernels/kernel_driver

1 file changed

+10
-12
lines changed

plugins/kernels/fps_kernels/kernel_driver/driver.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(
5959
self.key = cast(str, self.connection_cfg["key"])
6060
self.session_id = uuid.uuid4().hex
6161
self.msg_cnt = 0
62-
self.execute_requests: Dict[str, Any] = {}
62+
self.execute_requests: Dict[str, Dict[str, asyncio.Future]] = {}
6363
self.channel_tasks: List[asyncio.Task] = []
6464

6565
async def restart(self, startup_timeout: float = float("inf")) -> None:
@@ -109,16 +109,14 @@ async def listen_iopub(self):
109109
msg = await receive_message(self.iopub_channel) # type: ignore
110110
msg_id = msg["parent_header"].get("msg_id")
111111
if msg_id in self.execute_requests.keys():
112-
self.execute_requests[msg_id]["iopub_msg"] = msg
113-
self.execute_requests[msg_id]["iopub_event"].set()
112+
self.execute_requests[msg_id]["iopub_msg"].set_result(msg)
114113

115114
async def listen_shell(self):
116115
while True:
117116
msg = await receive_message(self.shell_channel) # type: ignore
118117
msg_id = msg["parent_header"].get("msg_id")
119118
if msg_id in self.execute_requests.keys():
120-
self.execute_requests[msg_id]["shell_msg"] = msg
121-
self.execute_requests[msg_id]["shell_event"].set()
119+
self.execute_requests[msg_id]["shell_msg"].set_result(msg)
122120

123121
async def execute(
124122
self,
@@ -142,35 +140,35 @@ async def execute(
142140
if wait_for_executed:
143141
deadline = time.time() + timeout
144142
self.execute_requests[msg_id] = {
145-
"iopub_event": asyncio.Event(),
146-
"shell_event": asyncio.Event(),
143+
"iopub_msg": asyncio.Future(),
144+
"shell_msg": asyncio.Future(),
147145
}
148146
while True:
149147
try:
150148
await asyncio.wait_for(
151-
self.execute_requests[msg_id]["iopub_event"].wait(),
149+
self.execute_requests[msg_id]["iopub_msg"],
152150
deadline_to_timeout(deadline),
153151
)
154152
except asyncio.TimeoutError:
155153
error_message = f"Kernel didn't respond in {timeout} seconds"
156154
raise RuntimeError(error_message)
157-
msg = self.execute_requests[msg_id]["iopub_msg"]
155+
msg = self.execute_requests[msg_id]["iopub_msg"].result()
158156
self._handle_outputs(cell["outputs"], msg)
159157
if (
160158
msg["header"]["msg_type"] == "status"
161159
and msg["content"]["execution_state"] == "idle"
162160
):
163161
break
164-
self.execute_requests[msg_id]["iopub_event"].clear()
162+
self.execute_requests[msg_id]["iopub_msg"] = asyncio.Future()
165163
try:
166164
await asyncio.wait_for(
167-
self.execute_requests[msg_id]["shell_event"].wait(),
165+
self.execute_requests[msg_id]["shell_msg"],
168166
deadline_to_timeout(deadline),
169167
)
170168
except asyncio.TimeoutError:
171169
error_message = f"Kernel didn't respond in {timeout} seconds"
172170
raise RuntimeError(error_message)
173-
msg = self.execute_requests[msg_id]["shell_msg"]
171+
msg = self.execute_requests[msg_id]["shell_msg"].result()
174172
cell["execution_count"] = msg["content"]["execution_count"]
175173
del self.execute_requests[msg_id]
176174

0 commit comments

Comments
 (0)