Skip to content

Commit ce39606

Browse files
committed
🔧 fix session_id handling
1 parent 722f3b6 commit ce39606

File tree

2 files changed

+64
-39
lines changed

2 files changed

+64
-39
lines changed

codeboxapi/box/codebox.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,28 +63,29 @@ def __new__(cls, *args, **kwargs):
6363

6464
return super().__new__(cls)
6565

66-
def __init__(self, *args, **kwargs) -> None:
67-
super().__init__(*args, **kwargs)
66+
def __init__(self, /, **kwargs) -> None:
67+
super().__init__()
6868
self.session_id: Optional[UUID] = kwargs.pop("session_id", None)
6969
self.aiohttp_session: Optional[ClientSession] = None
7070

7171
@classmethod
72-
def from_id(cls, session_id: Union[int, UUID]) -> "CodeBox":
73-
if isinstance(session_id, int):
74-
session_id = UUID(int=session_id)
75-
return cls(session_id=session_id)
72+
def from_id(cls, session_id: Union[int, UUID], **kwargs) -> "CodeBox":
73+
kwargs["session_id"] = (
74+
UUID(int=session_id) if isinstance(session_id, int) else session_id
75+
)
76+
return cls(**kwargs)
7677

7778
def _update(self) -> None:
7879
"""Update last interaction time"""
79-
if self.session_id is None:
80-
raise RuntimeError("Make sure to start your CodeBox before using it.")
8180
self.last_interaction = datetime.now()
8281

8382
def codebox_request(self, method, endpoint, *args, **kwargs) -> Dict[str, Any]:
8483
"""Basic request to the CodeBox API"""
8584
self._update()
85+
if self.session_id is None:
86+
raise RuntimeError("Make sure to start your CodeBox before using it.")
8687
return base_request(
87-
method, f"/codebox/{self.session_id}" + endpoint, *args, **kwargs
88+
method, f"/codebox/{self.session_id.int}" + endpoint, *args, **kwargs
8889
)
8990

9091
async def acodebox_request(
@@ -94,30 +95,46 @@ async def acodebox_request(
9495
self._update()
9596
if self.aiohttp_session is None:
9697
self.aiohttp_session = ClientSession()
98+
if self.session_id is None:
99+
raise RuntimeError("Make sure to start your CodeBox before using it.")
97100
return await abase_request(
98101
self.aiohttp_session,
99102
method,
100-
f"/codebox/{self.session_id}" + endpoint,
103+
f"/codebox/{self.session_id.int}" + endpoint,
101104
*args,
102105
**kwargs,
103106
)
104107

105108
def start(self) -> CodeBoxStatus:
106-
self.session_id = base_request(
107-
method="GET",
108-
endpoint="/codebox/start",
109-
)["id"]
109+
if self.session_id is not None:
110+
if settings.VERBOSE:
111+
print(f"{self} is already started!")
112+
return CodeBoxStatus(status="started")
113+
self.session_id = UUID(
114+
int=base_request(
115+
method="GET",
116+
endpoint="/codebox/start",
117+
)["id"]
118+
)
119+
if settings.VERBOSE:
120+
print(f"{self} started!")
110121
return CodeBoxStatus(status="started")
111122

112123
async def astart(self) -> CodeBoxStatus:
113124
self.aiohttp_session = ClientSession()
125+
if self.session_id is not None:
126+
if settings.VERBOSE:
127+
print(f"{self} is already started!")
128+
return CodeBoxStatus(status="started")
114129
self.session_id = (
115130
await abase_request(
116131
self.aiohttp_session,
117132
method="GET",
118133
endpoint="/codebox/start",
119134
)
120135
)["id"]
136+
if settings.VERBOSE:
137+
print(f"{self} started!")
121138
return CodeBoxStatus(status="started")
122139

123140
def status(self):

codeboxapi/box/localbox.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
import json
1010
import os
11+
import signal
1112
import subprocess
1213
import sys
1314
import time
@@ -38,10 +39,11 @@ class LocalBox(BaseBox):
3839
"""
3940

4041
_instance: Optional["LocalBox"] = None
42+
_jupyter_pids: List[int] = []
4143

4244
def __new__(cls, *args, **kwargs):
4345
if not cls._instance:
44-
cls._instance = super().__new__(cls, *args, **kwargs)
46+
cls._instance = super().__new__(cls)
4547
else:
4648
if settings.SHOW_INFO:
4749
print(
@@ -52,15 +54,16 @@ def __new__(cls, *args, **kwargs):
5254
)
5355
return cls._instance
5456

55-
def __init__(self) -> None:
56-
super().__init__()
57+
def __init__(self, /, **kwargs) -> None:
58+
super().__init__(session_id=kwargs.pop("session_id", None))
5759
self.port: int = 8888
5860
self.kernel_id: Optional[dict] = None
5961
self.ws: Union[WebSocketClientProtocol, ClientConnection, None] = None
6062
self.jupyter: Union[Process, subprocess.Popen, None] = None
6163
self.aiohttp_session: Optional[aiohttp.ClientSession] = None
6264

6365
def start(self) -> CodeBoxStatus:
66+
self.session_id = uuid4()
6467
os.makedirs(".codebox", exist_ok=True)
6568
self._check_port()
6669
if settings.VERBOSE:
@@ -84,6 +87,7 @@ def start(self) -> CodeBoxStatus:
8487
stderr=out,
8588
cwd=".codebox",
8689
)
90+
self._jupyter_pids.append(self.jupyter.pid)
8791
except FileNotFoundError:
8892
raise ModuleNotFoundError(
8993
"Jupyter Kernel Gateway not found, please install it with:\n"
@@ -100,7 +104,10 @@ def start(self) -> CodeBoxStatus:
100104
if settings.VERBOSE:
101105
print("Waiting for kernel to start...")
102106
time.sleep(1)
107+
self._connect()
108+
return CodeBoxStatus(status="started")
103109

110+
def _connect(self) -> None:
104111
response = requests.post(
105112
f"{self.kernel_url}/kernels",
106113
headers={"Content-Type": "application/json"},
@@ -112,20 +119,15 @@ def start(self) -> CodeBoxStatus:
112119

113120
self.ws = ws_connect_sync(f"{self.ws_url}/kernels/{self.kernel_id}/channels")
114121

115-
return CodeBoxStatus(status="started")
116-
117122
def _check_port(self) -> None:
118123
try:
119124
response = requests.get(f"http://localhost:{self.port}", timeout=90)
120125
except requests.exceptions.ConnectionError:
121126
pass
122127
else:
123-
try:
124-
requests.post(f"http://localhost:{self.port}/api/shutdown")
125-
except requests.exceptions.ConnectionError:
126-
if response.status_code == 200:
127-
self.port += 1
128-
self._check_port()
128+
if response.status_code == 200:
129+
self.port += 1
130+
self._check_port()
129131

130132
def _check_installed(self) -> None:
131133
try:
@@ -139,6 +141,7 @@ def _check_installed(self) -> None:
139141
raise
140142

141143
async def astart(self) -> CodeBoxStatus:
144+
self.session_id = uuid4()
142145
os.makedirs(".codebox", exist_ok=True)
143146
self.aiohttp_session = aiohttp.ClientSession()
144147
await self._acheck_port()
@@ -161,6 +164,7 @@ async def astart(self) -> CodeBoxStatus:
161164
stderr=out,
162165
cwd=".codebox",
163166
)
167+
self._jupyter_pids.append(self.jupyter.pid)
164168
except Exception as e:
165169
print(e)
166170
raise ModuleNotFoundError(
@@ -180,7 +184,12 @@ async def astart(self) -> CodeBoxStatus:
180184
if settings.VERBOSE:
181185
print("Waiting for kernel to start...")
182186
await asyncio.sleep(1)
187+
await self._aconnect()
188+
return CodeBoxStatus(status="started")
183189

190+
async def _aconnect(self) -> None:
191+
if self.aiohttp_session is None:
192+
self.aiohttp_session = aiohttp.ClientSession()
184193
response = await self.aiohttp_session.post(
185194
f"{self.kernel_url}/kernels", headers={"Content-Type": "application/json"}
186195
)
@@ -189,8 +198,6 @@ async def astart(self) -> CodeBoxStatus:
189198
raise Exception("Could not start kernel")
190199
self.ws = await ws_connect(f"{self.ws_url}/kernels/{self.kernel_id}/channels")
191200

192-
return CodeBoxStatus(status="started")
193-
194201
async def _acheck_port(self) -> None:
195202
try:
196203
if self.aiohttp_session is None:
@@ -206,6 +213,9 @@ async def _acheck_port(self) -> None:
206213
await self._acheck_port()
207214

208215
def status(self) -> CodeBoxStatus:
216+
if not self.kernel_id:
217+
self._connect()
218+
209219
return CodeBoxStatus(
210220
status="running"
211221
if self.kernel_id
@@ -214,6 +224,8 @@ def status(self) -> CodeBoxStatus:
214224
)
215225

216226
async def astatus(self) -> CodeBoxStatus:
227+
if not self.kernel_id:
228+
await self._aconnect()
217229
return CodeBoxStatus(
218230
status="running"
219231
if self.kernel_id
@@ -242,9 +254,9 @@ def run(
242254
if retry <= 0:
243255
raise RuntimeError("Could not connect to kernel")
244256
if not self.ws:
245-
self.start()
257+
self._connect()
246258
if not self.ws:
247-
raise RuntimeError("Could not connect to kernel")
259+
raise RuntimeError("Jupyter not running. Make sure to start it first.")
248260

249261
if settings.VERBOSE:
250262
print("Running code:\n", code)
@@ -354,9 +366,9 @@ async def arun(
354366
if retry <= 0:
355367
raise RuntimeError("Could not connect to kernel")
356368
if not self.ws:
357-
await self.astart()
369+
await self._aconnect()
358370
if not self.ws:
359-
raise RuntimeError("Could not connect to kernel")
371+
raise RuntimeError("Jupyter not running. Make sure to start it first.")
360372

361373
if settings.VERBOSE:
362374
print("Running code:\n", code)
@@ -488,20 +500,16 @@ async def alist_files(self) -> List[CodeBoxFile]:
488500
return await asyncio.to_thread(self.list_files)
489501

490502
def restart(self) -> CodeBoxStatus:
491-
if self.jupyter is not None:
492-
self.stop()
493-
else:
494-
self.start()
495503
return CodeBoxStatus(status="restarted")
496504

497505
async def arestart(self) -> CodeBoxStatus:
498-
if self.jupyter is not None:
499-
await self.astop()
500-
else:
501-
await self.astart()
502506
return CodeBoxStatus(status="restarted")
503507

504508
def stop(self) -> CodeBoxStatus:
509+
for pid in self._jupyter_pids:
510+
print(f"Killing {pid}")
511+
os.kill(pid, signal.SIGTERM)
512+
505513
if self.jupyter is not None:
506514
self.jupyter.terminate()
507515
self.jupyter.wait()

0 commit comments

Comments
 (0)