Skip to content
177 changes: 176 additions & 1 deletion enterprise_gateway/services/kernels/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@
"""Tornado handlers for kernel CRUD and communication."""
import json
import os
from datetime import datetime, timezone
from functools import partial

import jupyter_server.services.kernels.handlers as jupyter_server_handlers
import tornado
from jupyter_client.jsonutil import date_default
from jupyter_server.base.handlers import APIHandler
from tornado import web

try:
from jupyter_client.jsonutil import json_default
except ImportError:
from jupyter_client.jsonutil import date_default as json_default

from ...mixins import CORSMixin, JSONErrorsMixin, TokenAuthorizationMixin


Expand Down Expand Up @@ -146,11 +153,179 @@ def get(self, kernel_id):
self.finish(json.dumps(model, default=date_default))


default_handlers = []
class ConfigureMagicHandler(CORSMixin, JSONErrorsMixin, APIHandler):
@web.authenticated
async def post(self, kernel_id):
self.log.info(f"Update request received for kernel: {kernel_id}")
km = self.kernel_manager
km.check_kernel_id(kernel_id)
payload = self.get_json_body()
self.log.debug(f"Request payload: {payload}")
if payload is None:
self.finish(
json.dumps(
{
"message": f"Empty payload received. No operation performed on kernel: {kernel_id}"
},
default=date_default,
)
)
return
if type(payload) != dict:
raise web.HTTPError(400, f"Invalid JSON payload received for kernel: {kernel_id}.")
if payload.get("env", None) is None: # We only allow env field for now.
raise web.HTTPError(
400, "Missing required field `env` in payload for kernel: {kernel_id}."
)
kernel = km.get_kernel(kernel_id)
if kernel.restarting: # handle duplicate request.
self.log.info(
"An existing restart request is still in progress. Skipping this request."
)
raise web.HTTPError(
400, f"Duplicate configure kernel request received for kernel: {kernel_id}."
)
try:
# update Kernel metadata
kernel.set_user_extra_overrides(payload)
await km.restart_kernel(kernel_id)
kernel.fire_kernel_event_callbacks(
event="kernel_refresh", zmq_messages=payload.get("zmq_messages", {})
)
except web.HTTPError as he:
self.log.exception(
f"HTTPError exception occurred while re-configuring kernel: {kernel_id}: {he}"
)
await km.shutdown_kernel(kernel_id)
kernel.fire_kernel_event_callbacks(
event="kernel_refresh_failure", zmq_messages=payload.get("zmq_messages", {})
)
raise he
except Exception as e:
self.log.exception(
f"An exception occurred while re-configuring kernel: {kernel_id}: {e}"
)
await km.shutdown_kernel(kernel_id)
kernel.fire_kernel_event_callbacks(
event="kernel_refresh_failure", zmq_messages=payload.get("zmq_messages", {})
)
raise web.HTTPError(
500,
f"Error occurred while re-configuring kernel: {kernel_id}",
reason=f"{e}",
)
else:
response_body = {"message": f"Successfully re-configured kernel: {kernel_id}."}
self.finish(json.dumps(response_body, default=date_default))
return


class RemoteZMQChannelsHandler(
TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, jupyter_server_handlers.ZMQChannelsHandler
):
def open(self, kernel_id):
self.log.debug(f"Websocket open request received for kernel: {kernel_id}")
super().open(kernel_id)
km = self.kernel_manager
km.add_kernel_event_callbacks(kernel_id, self.on_kernel_refresh, "kernel_refresh")
km.add_kernel_event_callbacks(
kernel_id, self.on_kernel_refresh_failure, "kernel_refresh_failure"
)

def on_kernel_refresh(self, **kwargs):
self.log.info("Refreshing the client websocket to kernel connection.")
self.refresh_zmq_sockets()
zmq_messages = kwargs.get("zmq_messages", {})
if "stream_reply" in zmq_messages:
self.log.debug("Sending stream_reply success message.")
success_message = zmq_messages.get("stream_reply")
success_message["content"] = {
"name": "stdout",
"text": "The kernel is successfully refreshed.",
}
self._send_ws_message(success_message)
if "exec_reply" in zmq_messages:
self.log.debug("Sending exec_reply message.")
self._send_ws_message(zmq_messages.get("exec_reply"))
if "idle_reply" in zmq_messages:
self.log.debug("Sending idle_reply message.")
self._send_ws_message(zmq_messages.get("idle_reply"))
self._send_status_message(
"kernel_refreshed"
) # In the future, UI clients might start to consume this.

def on_kernel_refresh_failure(self, **kwargs):
self.log.error("kernel %s refresh failed!", self.kernel_id)
zmq_messages = kwargs.get("zmq_messages", {})
if "error_reply" in zmq_messages:
self.log.debug("Sending stream_reply error message.")
error_message = zmq_messages.get("error_reply")
error_message["content"] = {
"ename": "KernelRefreshFailed",
"evalue": "The kernel refresh operation failed.",
"traceback": ["The kernel refresh operation failed."],
}
self._send_ws_message(error_message)
if "exec_reply" in zmq_messages:
self.log.debug("Sending exec_reply message.")
exec_reply = zmq_messages.get("exec_reply").copy()
if "metadata" in exec_reply:
exec_reply["metadata"]["status"] = "error"
exec_reply["content"]["status"] = "error"
exec_reply["content"]["ename"] = "KernelRefreshFailed."
exec_reply["content"]["evalue"] = "The kernel refresh operation failed."
exec_reply["content"]["traceback"] = ["The kernel refresh operation failed."]
self._send_ws_message(exec_reply)
if "idle_reply" in zmq_messages:
self.log.info("Sending idle reply message.")
self._send_ws_message(zmq_messages.get("idle_reply"))
self.log.debug("sending kernel dead message.")
self._send_status_message("dead")

def refresh_zmq_sockets(self):
self.close_existing_streams()
kernel = self.kernel_manager.get_kernel(self.kernel_id)
self.session.key = kernel.session.key # refresh the session key
self.log.debug("Creating new ZMQ Socket streams.")
self.create_stream()
for channel, stream in self.channels.items():
self.log.debug(f"Updating channel: {channel}")
stream.on_recv_stream(self._on_zmq_reply)

def close_existing_streams(self):
self.log.debug(f"Closing existing channels for kernel: {self.kernel_id}")
for channel, stream in self.channels.items():
if stream is not None and not stream.closed():
self.log.debug(f"Close channel : {channel}")
stream.on_recv(None)
stream.close()
self.channels = {}

def _send_ws_message(self, kernel_msg):
self.log.debug(f"Sending websocket message: {kernel_msg}")
if "header" in kernel_msg and type(kernel_msg["header"] == dict):
kernel_msg["header"]["date"] = datetime.utcnow().replace(tzinfo=timezone.utc)
self.write_message(json.dumps(kernel_msg, default=json_default))

def on_close(self):
self.log.info(f"Websocket close request received for kernel: {self.kernel_id}")
super().on_close()
self.kernel_manager.remove_kernel_event_callbacks(
self.kernel_id, self.on_kernel_refresh, "kernel_refresh"
)
self.kernel_manager.remove_kernel_event_callbacks(
self.kernel_id, self.on_kernel_refresh_failure, "kernel_refresh_failure"
)


_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
default_handlers = [(r"/api/kernels/configure/%s" % _kernel_id_regex, ConfigureMagicHandler)]
for path, cls in jupyter_server_handlers.default_handlers:
if cls.__name__ in globals():
# Use the same named class from here if it exists
default_handlers.append((path, globals()[cls.__name__]))
elif cls.__name__ == jupyter_server_handlers.ZMQChannelsHandler.__name__:
default_handlers.append((path, RemoteZMQChannelsHandler))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this meant to replace ZMQChannelsHandler? I guess I don't understand why ZMQChannelsHandler isn't satisfied by the first condition - but I'm not that familiar with globals() (sorry).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to discuss this further.
what I am trying do here is replace the ZMQChannelsHandler with RemoteZMQChannelsHandler for handling the channels requests.

I tried to re-use the same class name on EG but was facing some issue where websocket connection was failing.

else:
# Gen a new type with CORS and token auth
bases = (TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, cls)
Expand Down
Loading