diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2750cf916..a0559b2dd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,6 +10,15 @@ jobs: tests: name: Python ${{ matrix.python-version }} runs-on: ubuntu-latest + services: + postgres: + env: + POSTGRES_USER: channels + POSTGRES_PASSWORD: channels + POSTGRES_DB: channels + image: postgres:14-alpine + ports: ["5432:5432"] + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 strategy: fail-fast: false matrix: diff --git a/channels/consumer.py b/channels/consumer.py index fc065432b..fbc9c43d7 100644 --- a/channels/consumer.py +++ b/channels/consumer.py @@ -3,9 +3,10 @@ from asgiref.sync import async_to_sync from . import DEFAULT_CHANNEL_LAYER -from .db import aclose_old_connections, database_sync_to_async +from .db import database_sync_to_async from .exceptions import StopConsumer from .layers import get_channel_layer +from .signals import consumer_started, consumer_terminated from .utils import await_many_dispatch @@ -62,7 +63,7 @@ async def __call__(self, scope, receive, send): await await_many_dispatch([receive], self.dispatch) except StopConsumer: # Exit cleanly - pass + await consumer_terminated.asend(sender=self.__class__) async def dispatch(self, message): """ @@ -70,7 +71,7 @@ async def dispatch(self, message): """ handler = getattr(self, get_handler_name(message), None) if handler: - await aclose_old_connections() + await consumer_started.asend(sender=self.__class__) await handler(message) else: raise ValueError("No handler for message type %s" % message["type"]) diff --git a/channels/db.py b/channels/db.py index 2961b5cdb..c07233b10 100644 --- a/channels/db.py +++ b/channels/db.py @@ -1,5 +1,6 @@ from asgiref.sync import SyncToAsync, sync_to_async from django.db import close_old_connections +from .signals import consumer_started, consumer_terminated, db_sync_to_async class DatabaseSyncToAsync(SyncToAsync): @@ -8,16 +9,21 @@ class DatabaseSyncToAsync(SyncToAsync): """ def thread_handler(self, loop, *args, **kwargs): - close_old_connections() + db_sync_to_async.send(sender=self.__class__, start=True) try: return super().thread_handler(loop, *args, **kwargs) finally: - close_old_connections() + db_sync_to_async.send(sender=self.__class__, start=False) # The class is TitleCased, but we want to encourage use as a callable/decorator database_sync_to_async = DatabaseSyncToAsync -async def aclose_old_connections(): +async def aclose_old_connections(**kwargs): return await sync_to_async(close_old_connections)() + + +consumer_started.connect(aclose_old_connections) +consumer_terminated.connect(aclose_old_connections) +db_sync_to_async.connect(close_old_connections) diff --git a/channels/generic/http.py b/channels/generic/http.py index 0d043cc3a..909e85704 100644 --- a/channels/generic/http.py +++ b/channels/generic/http.py @@ -1,6 +1,5 @@ from channels.consumer import AsyncConsumer -from ..db import aclose_old_connections from ..exceptions import StopConsumer @@ -89,5 +88,4 @@ async def http_disconnect(self, message): Let the user do their cleanup and close the consumer. """ await self.disconnect() - await aclose_old_connections() raise StopConsumer() diff --git a/channels/generic/websocket.py b/channels/generic/websocket.py index b4d99119c..899ac8915 100644 --- a/channels/generic/websocket.py +++ b/channels/generic/websocket.py @@ -3,7 +3,6 @@ from asgiref.sync import async_to_sync from ..consumer import AsyncConsumer, SyncConsumer -from ..db import aclose_old_connections from ..exceptions import ( AcceptConnection, DenyConnection, @@ -248,7 +247,6 @@ async def websocket_disconnect(self, message): "BACKEND is unconfigured or doesn't support groups" ) await self.disconnect(message["code"]) - await aclose_old_connections() raise StopConsumer() async def disconnect(self, code): diff --git a/channels/signals.py b/channels/signals.py new file mode 100644 index 000000000..96c613778 --- /dev/null +++ b/channels/signals.py @@ -0,0 +1,5 @@ +from django.dispatch import Signal + +consumer_started = Signal() +consumer_terminated = Signal() +db_sync_to_async = Signal() diff --git a/channels/testing/__init__.py b/channels/testing/__init__.py index d7dee3ef7..991cf12c6 100644 --- a/channels/testing/__init__.py +++ b/channels/testing/__init__.py @@ -1,4 +1,7 @@ -from .application import ApplicationCommunicator # noqa +from channels.db import aclose_old_connections +from channels.signals import consumer_started, consumer_terminated, db_sync_to_async +from django.db import close_old_connections +from asgiref.testing import ApplicationCommunicator # noqa from .http import HttpCommunicator # noqa from .live import ChannelsLiveServerTestCase # noqa from .websocket import WebsocketCommunicator # noqa @@ -8,4 +11,27 @@ "HttpCommunicator", "ChannelsLiveServerTestCase", "WebsocketCommunicator", + "ConsumerTestMixin", ] + + +class ConsumerTestMixin: + """ + Mixin to be applied to Django `TestCase` or `TransactionTestCase` to ensure + that database connections are not closed by consumers during test execution. + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + consumer_started.disconnect(aclose_old_connections) + consumer_terminated.disconnect(aclose_old_connections) + db_sync_to_async.disconnect(close_old_connections) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + consumer_started.connect(aclose_old_connections) + consumer_terminated.connect(aclose_old_connections) + db_sync_to_async.connect(close_old_connections) + diff --git a/channels/testing/application.py b/channels/testing/application.py deleted file mode 100644 index 2003178c1..000000000 --- a/channels/testing/application.py +++ /dev/null @@ -1,17 +0,0 @@ -from unittest import mock - -from asgiref.testing import ApplicationCommunicator as BaseApplicationCommunicator - - -def no_op(): - pass - - -class ApplicationCommunicator(BaseApplicationCommunicator): - async def send_input(self, message): - with mock.patch("channels.db.close_old_connections", no_op): - return await super().send_input(message) - - async def receive_output(self, timeout=1): - with mock.patch("channels.db.close_old_connections", no_op): - return await super().receive_output(timeout) diff --git a/channels/testing/http.py b/channels/testing/http.py index 8130265a0..6b1514ca7 100644 --- a/channels/testing/http.py +++ b/channels/testing/http.py @@ -1,6 +1,6 @@ from urllib.parse import unquote, urlparse -from channels.testing.application import ApplicationCommunicator +from asgiref.testing import ApplicationCommunicator class HttpCommunicator(ApplicationCommunicator): diff --git a/channels/testing/websocket.py b/channels/testing/websocket.py index 24e58d369..57ea4a653 100644 --- a/channels/testing/websocket.py +++ b/channels/testing/websocket.py @@ -1,7 +1,7 @@ import json from urllib.parse import unquote, urlparse -from channels.testing.application import ApplicationCommunicator +from asgiref.testing import ApplicationCommunicator class WebsocketCommunicator(ApplicationCommunicator): diff --git a/setup.cfg b/setup.cfg index 45fa26294..1a15c66e4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,7 @@ tests = pytest-django pytest-asyncio selenium + psycopg daphne = daphne>=4.0.0 types = diff --git a/tests/sample_project/config/settings.py b/tests/sample_project/config/settings.py index 610572173..a472fbd23 100644 --- a/tests/sample_project/config/settings.py +++ b/tests/sample_project/config/settings.py @@ -64,6 +64,14 @@ # Override Django’s default behaviour of using an in-memory database # in tests for SQLite, since that avoids connection.close() working. "TEST": {"NAME": "test_db.sqlite3"}, + }, + "other": { + "ENGINE": "django.db.backends.postgresql", + "NAME": "channels", + "USER": "channels", + "PASSWORD": "channels", + "HOST": "localhost", + "PORT": 5432, } } diff --git a/tests/test_database.py b/tests/test_database.py index 3faf05b5b..4edd60311 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -4,14 +4,18 @@ from channels.db import database_sync_to_async from channels.generic.http import AsyncHttpConsumer from channels.generic.websocket import AsyncWebsocketConsumer -from channels.testing import HttpCommunicator, WebsocketCommunicator +from channels.testing import ConsumerTestMixin, HttpCommunicator, WebsocketCommunicator @database_sync_to_async def basic_query(): with db.connections["default"].cursor() as cursor: - cursor.execute("SELECT 1234") - return cursor.fetchone()[0] + cursor.execute("SELECT 1234;") + cursor.fetchone()[0] + + with db.connections["other"].cursor() as cursor: + cursor.execute("SELECT 1234;") + cursor.fetchone()[0] class WebsocketConsumer(AsyncWebsocketConsumer): @@ -30,7 +34,9 @@ async def handle(self, body): ) -class ConnectionClosingTests(TestCase): +class ConnectionClosingTests(ConsumerTestMixin, TestCase): + databases = {'default', 'other'} + async def test_websocket(self): self.assertNotRegex( db.connections["default"].settings_dict.get("NAME"),