From 366bfa043a71dd646b6fc0092c6dc3636946542a Mon Sep 17 00:00:00 2001 From: Milan Oberkirch | geOps Date: Tue, 21 Feb 2023 20:08:21 +0100 Subject: [PATCH 1/6] Make connection handling more explicit With this patch we can be more certain that the connection is not kept for longer than the `read_timeout`. --- redis_websocket_api/handler.py | 11 ++++++----- redis_websocket_api/protocol.py | 2 +- redis_websocket_api/server.py | 16 ++++++++++++++-- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/redis_websocket_api/handler.py b/redis_websocket_api/handler.py index 980ad64..27e6795 100644 --- a/redis_websocket_api/handler.py +++ b/redis_websocket_api/handler.py @@ -132,7 +132,7 @@ def channel_is_allowed(self, channel_name): async def create(cls, redis, websocket, read_timeout=None): """Create a handler instance setting up tasks and queues.""" self = cls(redis, websocket, read_timeout=read_timeout) - self.consumer_task = asyncio.ensure_future(self._websocket_reader()) + self.consumer_task = asyncio.create_task(self._websocket_reader()) return self async def listen(self): @@ -141,7 +141,7 @@ async def listen(self): This coroutine blocks for up to self.read_timeout seconds. """ await self._send("websocket", {"status": "open"}) - while not self.consumer_task.done(): + while not self.consumer_task.done() and not self.websocket.closed: try: await self._queue_reader() except asyncio.TimeoutError: @@ -150,9 +150,10 @@ async def listen(self): async def close(self): """Close all connections and cancel all tasks.""" - asyncio.wait_for( # attempt to say goodbye to the client - self.websocket.close(), self.read_timeout - ) + if self.websocket.open: + await asyncio.wait_for( # attempt to say goodbye to the client + self.websocket.close(), self.read_timeout + ) self.consumer_task.cancel() diff --git a/redis_websocket_api/protocol.py b/redis_websocket_api/protocol.py index 2c21050..6c0a8b1 100644 --- a/redis_websocket_api/protocol.py +++ b/redis_websocket_api/protocol.py @@ -16,7 +16,7 @@ class CommandsMixin: The commend name is translated to a method name like this: - "_handle_{name}_command".format(name=a_tralis_protocol_command.lower()) + "_handle_{name}_command".format(name=a_custom_protocol_command.lower()) """ async def _handle_del_command(self, channel_name): diff --git a/redis_websocket_api/server.py b/redis_websocket_api/server.py index 6fe106a..b2e16b5 100644 --- a/redis_websocket_api/server.py +++ b/redis_websocket_api/server.py @@ -49,13 +49,25 @@ async def websocket_handler(self, websocket, path): logger.info("Client %s connected", websocket.remote_address) handler = await self.handler_class.create( - self.redis, websocket, read_timeout=self.read_timeout, + self.redis, + websocket, + read_timeout=self.read_timeout, ) self.handlers[websocket.remote_address] = handler + handler_listen_task = asyncio.create_task( + asyncio.wait_for(handler.listen(), self.keep_alive_timeout) + ) try: - await asyncio.wait_for(handler.listen(), self.keep_alive_timeout) + await asyncio.wait( + { + handler_listen_task, + websocket.handler_task, + }, + return_when="FIRST_COMPLETED", + ) finally: del self.handlers[websocket.remote_address] + handler_listen_task.cancel() await handler.close() logger.info("Client %s removed", websocket.remote_address) From 7277ba4e19767b9441dbc2ccd8d38a75e1cc7b6e Mon Sep 17 00:00:00 2001 From: Milan Oberkirch | geOps Date: Tue, 21 Feb 2023 20:10:58 +0100 Subject: [PATCH 2/6] Remove outdated examples Refer to https://developer.geops.io for up to date documentation. --- examples/demo.html | 41 ------------------ examples/demo.py | 102 --------------------------------------------- 2 files changed, 143 deletions(-) delete mode 100644 examples/demo.html delete mode 100644 examples/demo.py diff --git a/examples/demo.html b/examples/demo.html deleted file mode 100644 index aec3f0d..0000000 --- a/examples/demo.html +++ /dev/null @@ -1,41 +0,0 @@ - - - - - - WebSocket demo - - -

Frontend for demo.py

-

This demonstates some features by requesting

- - -

Received messages

-
    -

    ... waiting for messages

    - - - diff --git a/examples/demo.py b/examples/demo.py deleted file mode 100644 index d6169e5..0000000 --- a/examples/demo.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Example for extending the WebsocketServer with the GeoCommandsMixin""" - -import logging -import asyncio - -from aioredis import create_redis_pool -from websockets import exceptions - -from redis_websocket_api import WebsocketHandler, WebsocketServer -from redis_websocket_api.geo_protocol import GeoCommandsMixin -from redis_websocket_api.exceptions import ( - RemoteMessageHandlerError, InternalMessageHandlerError) - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.DEBUG) - -REDIS_ADDRESS = ('localhost', 6379) -EXAMPLE_GEOJSON = """\ -{{ - "type": "Feature", - "properties": {{"id": {id}}}, - "geometry": {{ - "type": "Point", - "coordinates": [{lon}, {lat}] - }} -}} -""" - - -class ExampleWebsocketHandler(WebsocketHandler, GeoCommandsMixin): - """Implement and handle websocket based Redis proxy protocol.""" - - allowed_commands = 'SUB', 'DEL', 'BBOX', 'PROJECTION', 'PING', 'GET' - - -class ExampleWebsocketServer(WebsocketServer): - - handler_class = ExampleWebsocketHandler - - async def websocket_handler(self, websocket, path): - """Add some error handling to the default implementation""" - try: - await super().websocket_handler(websocket, path) - except (exceptions.ConnectionClosed, asyncio.TimeoutError): - logger.info("Client %s disconnected", websocket.remote_address) - except RemoteMessageHandlerError as e: - logger.info("Hanging up on %s after invalid remote command: %s", - websocket.remote_address, e, exc_info=True) - except InternalMessageHandlerError: - logger.exception("Hanging up on %s because of buggy message!", - websocket.remote_address) - - -async def example_producer(): - """Dummy producer putting data into redis for demonstrating the API - - This will put one message into example_channel_2 every 0.1 seconds with - coordinates shifted by one each time. - """ - - redis = await create_redis_pool(REDIS_ADDRESS) - - # If there is a HSET with the same name as a channel it's content serves as - # initial data when performing a `GET channel_name` from the websocket - # client. - await redis.hset( - 'example_channel_1', 'initial_data_1', '{"some_json": "no GeoJSON"}') - await redis.hset( - 'example_channel_2', 'initial_data_1', EXAMPLE_GEOJSON.format( - id=1, lon=7.8486934304237375, lat=47.9914151679489)) - - counter = 1 - while True: - counter += 1 - message = EXAMPLE_GEOJSON.format( - id=counter, lon=counter % 180, lat=counter % 90) - # Keeping the initial data up to date - redis.hset('example_channel_2', 'data_{}'.format(counter), message) - # Pushing to subscribers - redis.publish('example_channel_2', message) - await asyncio.sleep(.1) - - -def main(): - loop = asyncio.get_event_loop() - loop.create_task(example_producer()) - - ExampleWebsocketServer( - redis=loop.run_until_complete(create_redis_pool(REDIS_ADDRESS)), - subscriber=loop.run_until_complete(create_redis_pool(REDIS_ADDRESS)), - read_timeout=30, - keep_alive_timeout=120, - ).listen( - host='localhost', - port=8000, - channel_names=('example_channel_1', 'example_channel_2'), - loop=loop - ) - - -if __name__ == '__main__': - main() From da8432c3bbf48199bfe3ca944433b9fbc55d3c90 Mon Sep 17 00:00:00 2001 From: Milan Oberkirch | geOps Date: Tue, 21 Feb 2023 20:25:30 +0100 Subject: [PATCH 3/6] Quick and dirty unittest fix --- tests/conftest.py | 2 ++ tests/test_server.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3d645ab..60ef29c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import MagicMock import pytest @@ -37,6 +38,7 @@ def __getattr__(self, name): def websocket(): websocket = AsyncMagicMock() websocket.remote_address = ("EGG", 2000) + websocket.handler_task = asyncio.sleep(0) return websocket diff --git a/tests/test_server.py b/tests/test_server.py index 5485c10..8dd8f4b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -9,8 +9,7 @@ def test_websocket_handler_creation(loop, server, websocket): server.handlers = MagicMock() websocket.await_recv.side_effect = exceptions.ConnectionClosed(1001, "foo") - with pytest.warns(RuntimeWarning): - asyncio.run(server.websocket_handler(websocket, "/foo")) + asyncio.run(server.websocket_handler(websocket, "/foo")) assert websocket.await_recv.call_count == 1 assert websocket.await_send.call_count == 1 From cabc3d19d69f633759ea1ad43903a62dbd33a754 Mon Sep 17 00:00:00 2001 From: Milan Oberkirch | geOps Date: Thu, 23 Feb 2023 14:17:09 +0100 Subject: [PATCH 4/6] Remove polling for state, cancel reader instead This deprecates the `read_timeout` argument but does not brake the existing public API. Also switched some early Python 3.5 asyncio coding style for Python 3.7+ patterns. Most notable the `loop` is only used directly if given as parameter. --- README.md | 19 ++++++------ redis_websocket_api/__main__.py | 8 +++-- redis_websocket_api/handler.py | 53 +++++++++++++++++++++------------ redis_websocket_api/server.py | 53 ++++++++++++++++++++++----------- 4 files changed, 86 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index fef9673..91a5cc0 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ to start a simple redis websocket api on `ws://localhost:8765`. This does [roughly](./redis_websocket_api/__main__.py) the equivalant of: ```python +import asyncio from aioredis import from_url from redis_websocket_api import WebsocketServer, WebsocketHandler @@ -37,20 +38,20 @@ class PublishEverythingHandler(WebsocketHandler): return True -WebsocketServer( +ws_server = WebsocketServer( redis=from_url("redis:///", encoding="utf-8", decode_responses=True), - read_timeout=30, keep_alive_timeout=120, handler_class=PublishEverythingHandler, -).listen( - host='localhost', - port=8000, - channel_patterns=["[a-z]*"], ) -``` -Have a look at `examples/demo.py` for an example with the `GeoCommandsMixin` -added. +asyncio.run( + ws_server.serve( + host='localhost', + port=8000, + channel_patterns=["[a-z]*"], + ) +) +``` Client-Side Usage diff --git a/redis_websocket_api/__main__.py b/redis_websocket_api/__main__.py index e101ccc..63d41ed 100755 --- a/redis_websocket_api/__main__.py +++ b/redis_websocket_api/__main__.py @@ -5,6 +5,7 @@ Intended for debugging and development only. """ +import asyncio from os import getenv from logging import basicConfig, INFO import aioredis @@ -22,10 +23,13 @@ def main(): redis = aioredis.from_url( getenv("REDIS_DSN", "redis:///"), encoding="utf-8", decode_responses=True ) - server = WebsocketServer(redis=redis, handler_class=PublishEverythingHandler,) + server = WebsocketServer( + redis=redis, + handler_class=PublishEverythingHandler, + ) host = getenv("HOST", "localhost") port = int(getenv("PORT", 8765)) - server.listen(host, port, channel_patterns=["[a-z]*"]) + asyncio.run(server.serve(host, port, channel_patterns=["[a-z]*"])) if __name__ == "__main__": diff --git a/redis_websocket_api/handler.py b/redis_websocket_api/handler.py index 27e6795..e5286f7 100644 --- a/redis_websocket_api/handler.py +++ b/redis_websocket_api/handler.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from collections import OrderedDict from logging import getLogger @@ -15,18 +17,26 @@ class WebsocketHandlerBase: """Define protocol for communication between web client and server.""" - read_timeout = NotImplemented allowed_commands = NotImplemented + consumer_task: asyncio.Task # for backwrds-compatibility, also in: + tasks: dict[str, asyncio.Task] # Self-made TaskGroup for old Python + def __init__(self, redis, websocket, read_timeout=None): + if read_timeout: + logger.warning( + "read_timeout is not used anymore because cleanup is trigered " + "immidiatly on connection loss" + ) + self.websocket = websocket self.redis = redis - self.read_timeout = read_timeout or self.read_timeout self.queue = asyncio.Queue() self.filters = OrderedDict() self.subscriptions = set() + self.tasks = {} async def _websocket_reader(self): try: @@ -62,7 +72,7 @@ def _apply_filters(self, message, exclude=()): return passes, data async def _queue_reader(self): - source, message = await asyncio.wait_for(self.queue.get(), self.read_timeout) + source, message = await self.queue.get() if source == "websocket": await self._handle_remote_message(message) else: @@ -131,36 +141,41 @@ def channel_is_allowed(self, channel_name): @classmethod async def create(cls, redis, websocket, read_timeout=None): """Create a handler instance setting up tasks and queues.""" - self = cls(redis, websocket, read_timeout=read_timeout) + + if read_timeout: + logger.warning( + "read_timeout is not used anymore because cleanup is trigered " + "immidiatly on connection loss" + ) + + self = cls(redis, websocket) self.consumer_task = asyncio.create_task(self._websocket_reader()) + self.tasks["consumer_task"] = self.consumer_task return self async def listen(self): - """Read and handle messages from internal message queue. - - This coroutine blocks for up to self.read_timeout seconds. - """ + """Read and handle messages from internal message queue""" await self._send("websocket", {"status": "open"}) - while not self.consumer_task.done() and not self.websocket.closed: - try: - await self._queue_reader() - except asyncio.TimeoutError: - if not self.websocket.open: - raise + while self.websocket.open: + queue_reader_task = asyncio.create_task(self._queue_reader()) + self.tasks["queue_reader"] = queue_reader_task + await queue_reader_task async def close(self): """Close all connections and cancel all tasks.""" if self.websocket.open: - await asyncio.wait_for( # attempt to say goodbye to the client - self.websocket.close(), self.read_timeout - ) - self.consumer_task.cancel() + await self.websocket.close() + for task in self.tasks.values(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass class WebsocketHandler(WebsocketHandlerBase, CommandsMixin): """Provides a Redis proxy to predefined channels""" - read_timeout = 30 allowed_commands = "SUB", "DEL", "PING", "GET" channel_names = set() diff --git a/redis_websocket_api/server.py b/redis_websocket_api/server.py index b2e16b5..3d6b100 100644 --- a/redis_websocket_api/server.py +++ b/redis_websocket_api/server.py @@ -26,13 +26,16 @@ def __init__( """Set default values for new WebsocketHandlers. :param redis: aioredis.client.Redis instance - :param read_timeout: Timeout, after which the websocket connection is - checked and kept if still open (does not cancel an open connection) :param keep_alive_timeout: Time after which the server cancels the handler task (independently of it's internal state) """ - self.read_timeout = read_timeout + if read_timeout: + logger.warning( + "read_timeout is not used anymore because cleanup is trigered " + "immidiatly on connection loss" + ) + self.keep_alive_timeout = keep_alive_timeout self.handlers = {} self.redis = redis @@ -48,11 +51,7 @@ async def websocket_handler(self, websocket, path): """Return handler for a single websocket connection.""" logger.info("Client %s connected", websocket.remote_address) - handler = await self.handler_class.create( - self.redis, - websocket, - read_timeout=self.read_timeout, - ) + handler = await self.handler_class.create(self.redis, websocket) self.handlers[websocket.remote_address] = handler handler_listen_task = asyncio.create_task( asyncio.wait_for(handler.listen(), self.keep_alive_timeout) @@ -61,7 +60,7 @@ async def websocket_handler(self, websocket, path): await asyncio.wait( { handler_listen_task, - websocket.handler_task, + asyncio.create_task(websocket.wait_closed()), }, return_when="FIRST_COMPLETED", ) @@ -69,6 +68,10 @@ async def websocket_handler(self, websocket, path): del self.handlers[websocket.remote_address] handler_listen_task.cancel() await handler.close() + try: + await handler_listen_task + except asyncio.CancelledError: + pass logger.info("Client %s removed", websocket.remote_address) async def redis_subscribe(self, p, channel_names=(), channel_patterns=()): @@ -106,13 +109,29 @@ async def redis_reader(self, channel_names=(), channel_patterns=()): Message(source=channel_name, content=message["data"]) ) - psub.close() - - def listen(self, host, port, channel_names=(), channel_patterns=(), loop=None): + async def serve( + self, + host, + port, + channel_names=(), + channel_patterns=(), + **websockets_serve_kwargs + ): """Listen for websocket connections and manage redis subscriptions.""" - loop = loop or asyncio.get_event_loop() - start_server = serve(self.websocket_handler, host, port) - loop.run_until_complete(start_server) - logger.info("Listening on %s:%s...", host, port) - loop.run_until_complete(self.redis_reader(channel_names, channel_patterns)) + async with serve( + ws_handler=self.websocket_handler, + host=host, + port=port, + **websockets_serve_kwargs, + ): + logger.info("Listening on %s:%s...", host, port) + await self.redis_reader(channel_names, channel_patterns) + + def listen(self, host, port, channel_names=(), channel_patterns=(), loop=None): + logger.warning("`listen` is deprecated, use `serve` instead") + serve_coro = self.serve(host, port, channel_names, channel_patterns) + if loop: + loop.run_until_complete(serve_coro) + else: + asyncio.run(serve_coro) From c1a2dbb8bf138050c1a35982e85b6af415a1dc5f Mon Sep 17 00:00:00 2001 From: Milan Oberkirch Date: Mon, 6 Mar 2023 11:53:43 +0100 Subject: [PATCH 5/6] Fix typos in warnings Co-authored-by: Alexander Held --- redis_websocket_api/handler.py | 8 ++++---- redis_websocket_api/server.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/redis_websocket_api/handler.py b/redis_websocket_api/handler.py index e5286f7..aed984a 100644 --- a/redis_websocket_api/handler.py +++ b/redis_websocket_api/handler.py @@ -25,8 +25,8 @@ class WebsocketHandlerBase: def __init__(self, redis, websocket, read_timeout=None): if read_timeout: logger.warning( - "read_timeout is not used anymore because cleanup is trigered " - "immidiatly on connection loss" + "read_timeout is not used anymore because cleanup is triggered " + "immediately on connection loss" ) self.websocket = websocket @@ -144,8 +144,8 @@ async def create(cls, redis, websocket, read_timeout=None): if read_timeout: logger.warning( - "read_timeout is not used anymore because cleanup is trigered " - "immidiatly on connection loss" + "read_timeout is not used anymore because cleanup is triggered " + "immediately on connection loss" ) self = cls(redis, websocket) diff --git a/redis_websocket_api/server.py b/redis_websocket_api/server.py index 3d6b100..7574fd1 100644 --- a/redis_websocket_api/server.py +++ b/redis_websocket_api/server.py @@ -32,8 +32,8 @@ def __init__( if read_timeout: logger.warning( - "read_timeout is not used anymore because cleanup is trigered " - "immidiatly on connection loss" + "read_timeout is not used anymore because cleanup is triggered " + "immediately on connection loss" ) self.keep_alive_timeout = keep_alive_timeout From 3f2ca6bfd45816535a5dc87e309ba0be850dd47d Mon Sep 17 00:00:00 2001 From: Milan Oberkirch | geOps Date: Mon, 6 Mar 2023 17:22:13 +0100 Subject: [PATCH 6/6] Await everything regadless of relevance --- redis_websocket_api/server.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/redis_websocket_api/server.py b/redis_websocket_api/server.py index 7574fd1..958efbb 100644 --- a/redis_websocket_api/server.py +++ b/redis_websocket_api/server.py @@ -1,7 +1,8 @@ import asyncio from logging import getLogger +import websockets -from websockets import serve +from websockets.legacy.server import WebSocketServerProtocol, serve from redis_websocket_api.handler import WebsocketHandler, WebsocketHandlerBase from redis_websocket_api.protocol import Message @@ -47,7 +48,7 @@ def __init__( "handler_class has to be a subclass of WebsocketHandlerBase" ) - async def websocket_handler(self, websocket, path): + async def websocket_handler(self, websocket: WebSocketServerProtocol, path) -> None: """Return handler for a single websocket connection.""" logger.info("Client %s connected", websocket.remote_address) @@ -56,22 +57,33 @@ async def websocket_handler(self, websocket, path): handler_listen_task = asyncio.create_task( asyncio.wait_for(handler.listen(), self.keep_alive_timeout) ) + wait_closed_task = asyncio.create_task(websocket.wait_closed()) + await asyncio.wait( + { + handler_listen_task, + wait_closed_task, + }, + return_when="FIRST_COMPLETED", + ) try: - await asyncio.wait( - { - handler_listen_task, - asyncio.create_task(websocket.wait_closed()), - }, - return_when="FIRST_COMPLETED", - ) - finally: del self.handlers[websocket.remote_address] handler_listen_task.cancel() await handler.close() try: await handler_listen_task + except (asyncio.CancelledError, asyncio.TimeoutError) as e: + logger.debug( + "Closing connection for %s: %s", websocket.remote_address, e + ) + wait_closed_task.cancel() + try: + await wait_closed_task except asyncio.CancelledError: - pass + logger.warning( + "Websocket connection for %s seems to be open after cleanup", + websocket.remote_address, + ) + finally: logger.info("Client %s removed", websocket.remote_address) async def redis_subscribe(self, p, channel_names=(), channel_patterns=()):