diff --git a/README.md b/README.md index fef9673..91a5cc0 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ to start a simple redis websocket api on `ws://localhost:8765`. This does [roughly](./redis_websocket_api/__main__.py) the equivalant of: ```python +import asyncio from aioredis import from_url from redis_websocket_api import WebsocketServer, WebsocketHandler @@ -37,20 +38,20 @@ class PublishEverythingHandler(WebsocketHandler): return True -WebsocketServer( +ws_server = WebsocketServer( redis=from_url("redis:///", encoding="utf-8", decode_responses=True), - read_timeout=30, keep_alive_timeout=120, handler_class=PublishEverythingHandler, -).listen( - host='localhost', - port=8000, - channel_patterns=["[a-z]*"], ) -``` -Have a look at `examples/demo.py` for an example with the `GeoCommandsMixin` -added. +asyncio.run( + ws_server.serve( + host='localhost', + port=8000, + channel_patterns=["[a-z]*"], + ) +) +``` Client-Side Usage diff --git a/examples/demo.html b/examples/demo.html deleted file mode 100644 index aec3f0d..0000000 --- a/examples/demo.html +++ /dev/null @@ -1,41 +0,0 @@ - - - -
- -This demonstates some features by requesting
-... waiting for messages
- - - diff --git a/examples/demo.py b/examples/demo.py deleted file mode 100644 index d6169e5..0000000 --- a/examples/demo.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Example for extending the WebsocketServer with the GeoCommandsMixin""" - -import logging -import asyncio - -from aioredis import create_redis_pool -from websockets import exceptions - -from redis_websocket_api import WebsocketHandler, WebsocketServer -from redis_websocket_api.geo_protocol import GeoCommandsMixin -from redis_websocket_api.exceptions import ( - RemoteMessageHandlerError, InternalMessageHandlerError) - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.DEBUG) - -REDIS_ADDRESS = ('localhost', 6379) -EXAMPLE_GEOJSON = """\ -{{ - "type": "Feature", - "properties": {{"id": {id}}}, - "geometry": {{ - "type": "Point", - "coordinates": [{lon}, {lat}] - }} -}} -""" - - -class ExampleWebsocketHandler(WebsocketHandler, GeoCommandsMixin): - """Implement and handle websocket based Redis proxy protocol.""" - - allowed_commands = 'SUB', 'DEL', 'BBOX', 'PROJECTION', 'PING', 'GET' - - -class ExampleWebsocketServer(WebsocketServer): - - handler_class = ExampleWebsocketHandler - - async def websocket_handler(self, websocket, path): - """Add some error handling to the default implementation""" - try: - await super().websocket_handler(websocket, path) - except (exceptions.ConnectionClosed, asyncio.TimeoutError): - logger.info("Client %s disconnected", websocket.remote_address) - except RemoteMessageHandlerError as e: - logger.info("Hanging up on %s after invalid remote command: %s", - websocket.remote_address, e, exc_info=True) - except InternalMessageHandlerError: - logger.exception("Hanging up on %s because of buggy message!", - websocket.remote_address) - - -async def example_producer(): - """Dummy producer putting data into redis for demonstrating the API - - This will put one message into example_channel_2 every 0.1 seconds with - coordinates shifted by one each time. - """ - - redis = await create_redis_pool(REDIS_ADDRESS) - - # If there is a HSET with the same name as a channel it's content serves as - # initial data when performing a `GET channel_name` from the websocket - # client. - await redis.hset( - 'example_channel_1', 'initial_data_1', '{"some_json": "no GeoJSON"}') - await redis.hset( - 'example_channel_2', 'initial_data_1', EXAMPLE_GEOJSON.format( - id=1, lon=7.8486934304237375, lat=47.9914151679489)) - - counter = 1 - while True: - counter += 1 - message = EXAMPLE_GEOJSON.format( - id=counter, lon=counter % 180, lat=counter % 90) - # Keeping the initial data up to date - redis.hset('example_channel_2', 'data_{}'.format(counter), message) - # Pushing to subscribers - redis.publish('example_channel_2', message) - await asyncio.sleep(.1) - - -def main(): - loop = asyncio.get_event_loop() - loop.create_task(example_producer()) - - ExampleWebsocketServer( - redis=loop.run_until_complete(create_redis_pool(REDIS_ADDRESS)), - subscriber=loop.run_until_complete(create_redis_pool(REDIS_ADDRESS)), - read_timeout=30, - keep_alive_timeout=120, - ).listen( - host='localhost', - port=8000, - channel_names=('example_channel_1', 'example_channel_2'), - loop=loop - ) - - -if __name__ == '__main__': - main() diff --git a/redis_websocket_api/__main__.py b/redis_websocket_api/__main__.py index e101ccc..63d41ed 100755 --- a/redis_websocket_api/__main__.py +++ b/redis_websocket_api/__main__.py @@ -5,6 +5,7 @@ Intended for debugging and development only. """ +import asyncio from os import getenv from logging import basicConfig, INFO import aioredis @@ -22,10 +23,13 @@ def main(): redis = aioredis.from_url( getenv("REDIS_DSN", "redis:///"), encoding="utf-8", decode_responses=True ) - server = WebsocketServer(redis=redis, handler_class=PublishEverythingHandler,) + server = WebsocketServer( + redis=redis, + handler_class=PublishEverythingHandler, + ) host = getenv("HOST", "localhost") port = int(getenv("PORT", 8765)) - server.listen(host, port, channel_patterns=["[a-z]*"]) + asyncio.run(server.serve(host, port, channel_patterns=["[a-z]*"])) if __name__ == "__main__": diff --git a/redis_websocket_api/handler.py b/redis_websocket_api/handler.py index 980ad64..aed984a 100644 --- a/redis_websocket_api/handler.py +++ b/redis_websocket_api/handler.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from collections import OrderedDict from logging import getLogger @@ -15,18 +17,26 @@ class WebsocketHandlerBase: """Define protocol for communication between web client and server.""" - read_timeout = NotImplemented allowed_commands = NotImplemented + consumer_task: asyncio.Task # for backwrds-compatibility, also in: + tasks: dict[str, asyncio.Task] # Self-made TaskGroup for old Python + def __init__(self, redis, websocket, read_timeout=None): + if read_timeout: + logger.warning( + "read_timeout is not used anymore because cleanup is triggered " + "immediately on connection loss" + ) + self.websocket = websocket self.redis = redis - self.read_timeout = read_timeout or self.read_timeout self.queue = asyncio.Queue() self.filters = OrderedDict() self.subscriptions = set() + self.tasks = {} async def _websocket_reader(self): try: @@ -62,7 +72,7 @@ def _apply_filters(self, message, exclude=()): return passes, data async def _queue_reader(self): - source, message = await asyncio.wait_for(self.queue.get(), self.read_timeout) + source, message = await self.queue.get() if source == "websocket": await self._handle_remote_message(message) else: @@ -131,35 +141,41 @@ def channel_is_allowed(self, channel_name): @classmethod async def create(cls, redis, websocket, read_timeout=None): """Create a handler instance setting up tasks and queues.""" - self = cls(redis, websocket, read_timeout=read_timeout) - self.consumer_task = asyncio.ensure_future(self._websocket_reader()) + + if read_timeout: + logger.warning( + "read_timeout is not used anymore because cleanup is triggered " + "immediately on connection loss" + ) + + self = cls(redis, websocket) + self.consumer_task = asyncio.create_task(self._websocket_reader()) + self.tasks["consumer_task"] = self.consumer_task return self async def listen(self): - """Read and handle messages from internal message queue. - - This coroutine blocks for up to self.read_timeout seconds. - """ + """Read and handle messages from internal message queue""" await self._send("websocket", {"status": "open"}) - while not self.consumer_task.done(): - try: - await self._queue_reader() - except asyncio.TimeoutError: - if not self.websocket.open: - raise + while self.websocket.open: + queue_reader_task = asyncio.create_task(self._queue_reader()) + self.tasks["queue_reader"] = queue_reader_task + await queue_reader_task async def close(self): """Close all connections and cancel all tasks.""" - asyncio.wait_for( # attempt to say goodbye to the client - self.websocket.close(), self.read_timeout - ) - self.consumer_task.cancel() + if self.websocket.open: + await self.websocket.close() + for task in self.tasks.values(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass class WebsocketHandler(WebsocketHandlerBase, CommandsMixin): """Provides a Redis proxy to predefined channels""" - read_timeout = 30 allowed_commands = "SUB", "DEL", "PING", "GET" channel_names = set() diff --git a/redis_websocket_api/protocol.py b/redis_websocket_api/protocol.py index 2c21050..6c0a8b1 100644 --- a/redis_websocket_api/protocol.py +++ b/redis_websocket_api/protocol.py @@ -16,7 +16,7 @@ class CommandsMixin: The commend name is translated to a method name like this: - "_handle_{name}_command".format(name=a_tralis_protocol_command.lower()) + "_handle_{name}_command".format(name=a_custom_protocol_command.lower()) """ async def _handle_del_command(self, channel_name): diff --git a/redis_websocket_api/server.py b/redis_websocket_api/server.py index 6fe106a..958efbb 100644 --- a/redis_websocket_api/server.py +++ b/redis_websocket_api/server.py @@ -1,7 +1,8 @@ import asyncio from logging import getLogger +import websockets -from websockets import serve +from websockets.legacy.server import WebSocketServerProtocol, serve from redis_websocket_api.handler import WebsocketHandler, WebsocketHandlerBase from redis_websocket_api.protocol import Message @@ -26,13 +27,16 @@ def __init__( """Set default values for new WebsocketHandlers. :param redis: aioredis.client.Redis instance - :param read_timeout: Timeout, after which the websocket connection is - checked and kept if still open (does not cancel an open connection) :param keep_alive_timeout: Time after which the server cancels the handler task (independently of it's internal state) """ - self.read_timeout = read_timeout + if read_timeout: + logger.warning( + "read_timeout is not used anymore because cleanup is triggered " + "immediately on connection loss" + ) + self.keep_alive_timeout = keep_alive_timeout self.handlers = {} self.redis = redis @@ -44,19 +48,42 @@ def __init__( "handler_class has to be a subclass of WebsocketHandlerBase" ) - async def websocket_handler(self, websocket, path): + async def websocket_handler(self, websocket: WebSocketServerProtocol, path) -> None: """Return handler for a single websocket connection.""" logger.info("Client %s connected", websocket.remote_address) - handler = await self.handler_class.create( - self.redis, websocket, read_timeout=self.read_timeout, - ) + handler = await self.handler_class.create(self.redis, websocket) self.handlers[websocket.remote_address] = handler + handler_listen_task = asyncio.create_task( + asyncio.wait_for(handler.listen(), self.keep_alive_timeout) + ) + wait_closed_task = asyncio.create_task(websocket.wait_closed()) + await asyncio.wait( + { + handler_listen_task, + wait_closed_task, + }, + return_when="FIRST_COMPLETED", + ) try: - await asyncio.wait_for(handler.listen(), self.keep_alive_timeout) - finally: del self.handlers[websocket.remote_address] + handler_listen_task.cancel() await handler.close() + try: + await handler_listen_task + except (asyncio.CancelledError, asyncio.TimeoutError) as e: + logger.debug( + "Closing connection for %s: %s", websocket.remote_address, e + ) + wait_closed_task.cancel() + try: + await wait_closed_task + except asyncio.CancelledError: + logger.warning( + "Websocket connection for %s seems to be open after cleanup", + websocket.remote_address, + ) + finally: logger.info("Client %s removed", websocket.remote_address) async def redis_subscribe(self, p, channel_names=(), channel_patterns=()): @@ -94,13 +121,29 @@ async def redis_reader(self, channel_names=(), channel_patterns=()): Message(source=channel_name, content=message["data"]) ) - psub.close() - - def listen(self, host, port, channel_names=(), channel_patterns=(), loop=None): + async def serve( + self, + host, + port, + channel_names=(), + channel_patterns=(), + **websockets_serve_kwargs + ): """Listen for websocket connections and manage redis subscriptions.""" - loop = loop or asyncio.get_event_loop() - start_server = serve(self.websocket_handler, host, port) - loop.run_until_complete(start_server) - logger.info("Listening on %s:%s...", host, port) - loop.run_until_complete(self.redis_reader(channel_names, channel_patterns)) + async with serve( + ws_handler=self.websocket_handler, + host=host, + port=port, + **websockets_serve_kwargs, + ): + logger.info("Listening on %s:%s...", host, port) + await self.redis_reader(channel_names, channel_patterns) + + def listen(self, host, port, channel_names=(), channel_patterns=(), loop=None): + logger.warning("`listen` is deprecated, use `serve` instead") + serve_coro = self.serve(host, port, channel_names, channel_patterns) + if loop: + loop.run_until_complete(serve_coro) + else: + asyncio.run(serve_coro) diff --git a/tests/conftest.py b/tests/conftest.py index 3d645ab..60ef29c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import MagicMock import pytest @@ -37,6 +38,7 @@ def __getattr__(self, name): def websocket(): websocket = AsyncMagicMock() websocket.remote_address = ("EGG", 2000) + websocket.handler_task = asyncio.sleep(0) return websocket diff --git a/tests/test_server.py b/tests/test_server.py index 5485c10..8dd8f4b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -9,8 +9,7 @@ def test_websocket_handler_creation(loop, server, websocket): server.handlers = MagicMock() websocket.await_recv.side_effect = exceptions.ConnectionClosed(1001, "foo") - with pytest.warns(RuntimeWarning): - asyncio.run(server.websocket_handler(websocket, "/foo")) + asyncio.run(server.websocket_handler(websocket, "/foo")) assert websocket.await_recv.call_count == 1 assert websocket.await_send.call_count == 1