diff --git a/solara/server/patch.py b/solara/server/patch.py index 5107ded78..9a480db84 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -269,6 +269,9 @@ def hook(msg): if msg["msg_type"] == "display_data": self.outputs += ({"output_type": "display_data", "data": msg["content"]["data"], "metadata": msg["content"]["metadata"]},) return None + if msg["msg_type"] == "clear_output": + self.outputs = () + return None return msg get_ipython().display_pub.register_hook(hook) diff --git a/solara/server/shell.py b/solara/server/shell.py index 01c635879..4a1d0df57 100644 --- a/solara/server/shell.py +++ b/solara/server/shell.py @@ -119,11 +119,15 @@ def clear_output(self, wait=False): """ content = dict(wait=wait) self._flush_streams() + msg = self.session.msg("clear_output", json_clean(content), parent=self.parent_header) + for hook in self._hooks: + msg = hook(msg) + if msg is None: + return + self.session.send( self.pub_socket, - "clear_output", - content, - parent=self.parent_header, + msg, ident=self.topic, ) diff --git a/tests/unit/output_widget_test.py b/tests/unit/output_widget_test.py new file mode 100644 index 000000000..a07273384 --- /dev/null +++ b/tests/unit/output_widget_test.py @@ -0,0 +1,43 @@ +from unittest.mock import Mock + +import IPython.display +import ipywidgets as widgets + +from solara.server import app, kernel + + +def test_interactive_shell(no_app_context): + ws1 = Mock() + ws2 = Mock() + kernel1 = kernel.Kernel() + kernel2 = kernel.Kernel() + kernel1.session.websockets.add(ws1) + kernel2.session.websockets.add(ws2) + context1 = app.AppContext(id="1", kernel=kernel1) + context2 = app.AppContext(id="2", kernel=kernel2) + + with context1: + output1 = widgets.Output() + with output1: + IPython.display.display("test1") + assert output1.outputs[0]["data"]["text/plain"] == "'test1'" + assert ws1.send.call_count == 3 # create 2 widgets (layout and output) and update data + assert ws2.send.call_count == 0 + with context2: + output2 = widgets.Output() + with output2: + IPython.display.display("test2") + assert output2.outputs[0]["data"]["text/plain"] == "'test2'" + assert ws1.send.call_count == 3 + assert ws2.send.call_count == 3 + + context1.close() + context2.close() + + +def test_clear_output(): + output1 = widgets.Output() + with output1: + IPython.display.display("test1") + IPython.display.clear_output() + assert len(output1.outputs) == 0