diff --git a/docs/multi_database.rst b/docs/multi_database.rst index 97becef27c..88ec21a2fb 100644 --- a/docs/multi_database.rst +++ b/docs/multi_database.rst @@ -20,13 +20,16 @@ Key concepts - Health checks: A set of checks determines whether a database is healthy in proactive manner. By default, an "PING" check runs against the database (all cluster nodes must - pass for a cluster). You can add custom checks. A Redis Enterprise specific + pass for a cluster). You can provide your own set of health checks or add an + additional health check on top of the default one. A Redis Enterprise specific "lag-aware" health check is also available. - Failure detector: A detector observes command failures over a moving window (reactive monitoring). You can specify an exact number of failures and failures rate to have more fine-grain tuned configuration of triggering fail over based on organic traffic. + You can provide your own set of custom failure detectors or add an additional + detector on top of the default one. - Failover strategy: The default strategy is based on statically configured weights. It prefers the highest weighted healthy database. diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index d3e3f241d2..f6385e46ea 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -28,19 +28,21 @@ class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): def __init__(self, config: MultiDbConfig): self._databases = config.databases() - self._health_checks = config.default_health_checks() - - if config.health_checks is not None: - self._health_checks.extend(config.health_checks) + self._health_checks = ( + config.default_health_checks() + if not config.health_checks + else config.health_checks + ) self._health_check_interval = config.health_check_interval self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value( config.health_check_probes, config.health_check_delay ) - self._failure_detectors = config.default_failure_detectors() - - if config.failure_detectors is not None: - self._failure_detectors.extend(config.failure_detectors) + self._failure_detectors = ( + config.default_failure_detectors() + if not config.failure_detectors + else config.failure_detectors + ) self._failover_strategy = ( config.default_failover_strategy() diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index c09b8b9969..51c50c223a 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -303,7 +303,7 @@ async def _check_active_database(self): self._active_database is None or self._active_database.circuit.state != CBState.CLOSED or ( - self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + self._auto_fallback_interval > 0 and self._next_fallback_attempt <= datetime.now() ) ): diff --git a/redis/multidb/client.py b/redis/multidb/client.py index c46a53af32..272064453a 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -29,19 +29,20 @@ class MultiDBClient(RedisModuleCommands, CoreCommands): def __init__(self, config: MultiDbConfig): self._databases = config.databases() - self._health_checks = config.default_health_checks() - - if config.health_checks is not None: - self._health_checks.extend(config.health_checks) - + self._health_checks = ( + config.default_health_checks() + if not config.health_checks + else config.health_checks + ) self._health_check_interval = config.health_check_interval self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value( config.health_check_probes, config.health_check_probes_delay ) - self._failure_detectors = config.default_failure_detectors() - - if config.failure_detectors is not None: - self._failure_detectors.extend(config.failure_detectors) + self._failure_detectors = ( + config.default_failure_detectors() + if not config.failure_detectors + else config.failure_detectors + ) self._failover_strategy = ( config.default_failover_strategy() diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index f8e6171bc8..202bf723fe 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -55,7 +55,7 @@ def auto_fallback_interval(self, auto_fallback_interval: int) -> None: self._auto_fallback_interval = auto_fallback_interval def _schedule_next_fallback(self) -> None: - if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: + if self._auto_fallback_interval < 0: return self._next_fallback_attempt = datetime.now() + timedelta( @@ -321,7 +321,7 @@ def _check_active_database(self): self._active_database is None or self._active_database.circuit.state != CBState.CLOSED or ( - self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + self._auto_fallback_interval > 0 and self._next_fallback_attempt <= datetime.now() ) ): diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index e912a00466..e028655c38 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -34,12 +34,10 @@ async def test_execute_command_against_correct_db_on_successful_initialization( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command = AsyncMock(return_value="OK1") @@ -71,12 +69,10 @@ async def test_execute_command_against_correct_db_and_closed_circuit( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command = AsyncMock(return_value="OK1") @@ -187,14 +183,10 @@ async def mock_check_health(database): return True mock_hc.check_health.side_effect = mock_check_health + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, - "default_health_checks", - return_value=[mock_hc], - ), ): mock_db.client.execute_command.return_value = "OK" mock_db1.client.execute_command.return_value = "OK1" @@ -264,14 +256,10 @@ async def mock_check_health(database): return True mock_hc.check_health.side_effect = mock_check_health + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, - "default_health_checks", - return_value=[mock_hc], - ), ): mock_db.client.execute_command.return_value = "OK" mock_db1.client.execute_command.return_value = "OK1" @@ -287,6 +275,60 @@ async def mock_check_health(database): await asyncio.sleep(0.5) assert await client.set("key", "value") == "OK1" + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_do_not_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + db1_counter = 0 + error_event = asyncio.Event() + check = False + + async def mock_check_health(database): + nonlocal db1_counter, check + + if database == mock_db1 and not check: + db1_counter += 1 + + if db1_counter > 1: + error_event.set() + check = True + return False + + return True + + mock_hc.check_health.side_effect = mock_check_health + mock_multi_db_config.health_checks = [mock_hc] + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + ): + mock_db.client.execute_command.return_value = "OK" + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = -1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() + + async with MultiDBClient(mock_multi_db_config) as client: + assert await client.set("key", "value") == "OK1" + await error_event.wait() + assert await client.set("key", "value") == "OK2" + await asyncio.sleep(0.5) + assert await client.set("key", "value") == "OK2" + @pytest.mark.asyncio @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -304,12 +346,10 @@ async def test_execute_command_throws_exception_on_failed_initialization( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_hc.check_health.return_value = False @@ -340,12 +380,10 @@ async def test_add_database_throws_exception_on_same_database( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_hc.check_health.return_value = False @@ -373,12 +411,10 @@ async def test_add_database_makes_new_database_active( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" mock_db2.client.execute_command.return_value = "OK2" @@ -413,12 +449,10 @@ async def test_remove_highest_weighted_database( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" mock_db2.client.execute_command.return_value = "OK2" @@ -451,12 +485,10 @@ async def test_update_database_weight_to_be_highest( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" mock_db2.client.execute_command.return_value = "OK2" @@ -491,12 +523,10 @@ async def test_add_new_failure_detector( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" mock_multi_db_config.event_dispatcher = EventDispatcher() @@ -552,12 +582,10 @@ async def test_add_new_health_check( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" @@ -594,12 +622,10 @@ async def test_set_active_database( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" mock_db.client.execute_command.return_value = "OK" diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index cbc81b15ed..b342b4b91b 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -34,12 +34,10 @@ def test_execute_command_against_correct_db_on_successful_initialization( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" @@ -70,12 +68,10 @@ def test_execute_command_against_correct_db_and_closed_circuit( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" @@ -185,14 +181,10 @@ def mock_check_health(database): return True mock_hc.check_health.side_effect = mock_check_health + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, - "default_health_checks", - return_value=[mock_hc], - ), ): mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() @@ -261,14 +253,10 @@ def mock_check_health(database): return True mock_hc.check_health.side_effect = mock_check_health + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, - "default_health_checks", - return_value=[mock_hc], - ), ): mock_db.client.execute_command.return_value = "OK" mock_db1.client.execute_command.return_value = "OK1" @@ -284,6 +272,59 @@ def mock_check_health(database): sleep(0.5) assert client.set("key", "value") == "OK1" + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_do_not_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + db1_counter = 0 + error_event = threading.Event() + check = False + + def mock_check_health(database): + nonlocal db1_counter, check + + if database == mock_db1 and not check: + db1_counter += 1 + + if db1_counter > 1: + error_event.set() + check = True + return False + + return True + + mock_hc.check_health.side_effect = mock_check_health + mock_multi_db_config.health_checks = [mock_hc] + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + ): + mock_db.client.execute_command.return_value = "OK" + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = -1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() + + client = MultiDBClient(mock_multi_db_config) + assert client.set("key", "value") == "OK1" + error_event.wait(timeout=0.5) + assert client.set("key", "value") == "OK2" + sleep(0.5) + assert client.set("key", "value") == "OK2" + @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", [ @@ -300,12 +341,10 @@ def test_execute_command_throws_exception_on_failed_initialization( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_hc.check_health.return_value = False @@ -336,12 +375,10 @@ def test_add_database_throws_exception_on_same_database( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_hc.check_health.return_value = False @@ -368,12 +405,10 @@ def test_add_database_makes_new_database_active( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" mock_db2.client.execute_command.return_value = "OK2" @@ -407,12 +442,10 @@ def test_remove_highest_weighted_database( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" mock_db2.client.execute_command.return_value = "OK2" @@ -445,12 +478,10 @@ def test_update_database_weight_to_be_highest( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" mock_db2.client.execute_command.return_value = "OK2" @@ -484,12 +515,10 @@ def test_add_new_failure_detector( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" mock_multi_db_config.event_dispatcher = EventDispatcher() @@ -540,12 +569,10 @@ def test_add_new_health_check( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" @@ -581,12 +608,10 @@ def test_set_active_database( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_multi_db_config.health_checks = [mock_hc] with ( patch.object(mock_multi_db_config, "databases", return_value=databases), - patch.object( - mock_multi_db_config, "default_health_checks", return_value=[mock_hc] - ), ): mock_db1.client.execute_command.return_value = "OK1" mock_db.client.execute_command.return_value = "OK"