diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d87e6ba1c3..f02ebcc43b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -81,6 +81,12 @@ using `invoke standalone-tests`; similarly, RedisCluster tests can be run by usi Each run of tests starts and stops the various dockers required. Sometimes things get stuck, an `invoke clean` can help. +## Linting and Formatting + +Call `invoke linters` to run linters without also running tests. This command will +only report issues, not fix them automatically. Run `invoke formatters` to +automatically format your code. + ## Documentation If relevant, update the code documentation, via docstrings, or in `/docs`. diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 4254441073..bdfb739788 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -362,6 +362,12 @@ def __init__( # on a set of redis commands self._single_conn_lock = asyncio.Lock() + # When used as an async context manager, we need to increment and decrement + # a usage counter so that we can close the connection pool when no one is + # using the client. + self._usage_counter = 0 + self._usage_lock = asyncio.Lock() + def __repr__(self): return ( f"<{self.__class__.__module__}.{self.__class__.__name__}" @@ -562,10 +568,40 @@ def client(self) -> "Redis": ) async def __aenter__(self: _RedisT) -> _RedisT: - return await self.initialize() + """ + Async context manager entry. Increments a usage counter so that the + connection pool is only closed (via aclose()) when no context is using + the client. + """ + async with self._usage_lock: + self._usage_counter += 1 + try: + # Initialize the client (i.e. establish connection, etc.) + return await self.initialize() + except Exception: + # If initialization fails, decrement the counter to keep it in sync + async with self._usage_lock: + self._usage_counter -= 1 + raise + + async def _decrement_usage(self) -> int: + """ + Helper coroutine to decrement the usage counter while holding the lock. + Returns the new value of the usage counter. + """ + async with self._usage_lock: + self._usage_counter -= 1 + return self._usage_counter async def __aexit__(self, exc_type, exc_value, traceback): - await self.aclose() + """ + Async context manager exit. Decrements a usage counter. If this is the + last exit (counter becomes zero), the client closes its connection pool. + """ + current_usage = await asyncio.shield(self._decrement_usage()) + if current_usage == 0: + # This was the last active context, so disconnect the pool. + await asyncio.shield(self.aclose()) _DEL_MESSAGE = "Unclosed Redis client" diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 87a2c16afa..b13a9ca6f5 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -379,6 +379,12 @@ def __init__( self._initialize = True self._lock: Optional[asyncio.Lock] = None + # When used as an async context manager, we need to increment and decrement + # a usage counter so that we can close the connection pool when no one is + # using the client. + self._usage_counter = 0 + self._usage_lock = asyncio.Lock() + async def initialize(self) -> "RedisCluster": """Get all nodes from startup nodes & creates connections if not initialized.""" if self._initialize: @@ -415,10 +421,40 @@ async def close(self) -> None: await self.aclose() async def __aenter__(self) -> "RedisCluster": - return await self.initialize() + """ + Async context manager entry. Increments a usage counter so that the + connection pool is only closed (via aclose()) when no context is using + the client. + """ + async with self._usage_lock: + self._usage_counter += 1 + try: + # Initialize the client (i.e. establish connection, etc.) + return await self.initialize() + except Exception: + # If initialization fails, decrement the counter to keep it in sync + async with self._usage_lock: + self._usage_counter -= 1 + raise - async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: - await self.aclose() + async def _decrement_usage(self) -> int: + """ + Helper coroutine to decrement the usage counter while holding the lock. + Returns the new value of the usage counter. + """ + async with self._usage_lock: + self._usage_counter -= 1 + return self._usage_counter + + async def __aexit__(self, exc_type, exc_value, traceback): + """ + Async context manager exit. Decrements a usage counter. If this is the + last exit (counter becomes zero), the client closes its connection pool. + """ + current_usage = await asyncio.shield(self._decrement_usage()) + if current_usage == 0: + # This was the last active context, so disconnect the pool. + await asyncio.shield(self.aclose()) def __await__(self) -> Generator[Any, None, "RedisCluster"]: return self.initialize().__await__() diff --git a/tasks.py b/tasks.py index f7b728aed4..f318da5608 100644 --- a/tasks.py +++ b/tasks.py @@ -33,6 +33,12 @@ def linters(c): run("vulture redis whitelist.py --min-confidence 80") run("flynt --fail-on-change --dry-run tests redis") +@task +def formatters(c): + """Format code""" + run("black --target-version py37 tests redis") + run("isort tests redis") + @task def all_tests(c): diff --git a/tests/test_asyncio/test_usage_counter.py b/tests/test_asyncio/test_usage_counter.py new file mode 100644 index 0000000000..566ec6b4d3 --- /dev/null +++ b/tests/test_asyncio/test_usage_counter.py @@ -0,0 +1,16 @@ +import asyncio + +import pytest + + +@pytest.mark.asyncio +async def test_usage_counter(r): + async def dummy_task(): + async with r: + await asyncio.sleep(0.01) + + tasks = [dummy_task() for _ in range(20)] + await asyncio.gather(*tasks) + + # After all tasks have completed, the usage counter should be back to zero. + assert r._usage_counter == 0