From 5a8a46b0145d9f66db85cb9c20e3893063621667 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 7 Nov 2025 09:56:38 +0200 Subject: [PATCH] Fixing flaky tests - part 2 --- dev_requirements.txt | 1 + redis/multidb/client.py | 8 +- .../test_asyncio/test_multidb/test_client.py | 260 +++++++------ .../test_multidb/test_pipeline.py | 155 ++++---- tests/test_auth/test_token_manager.py | 25 +- tests/test_background.py | 16 +- tests/test_multidb/test_client.py | 356 +++++++++++------- tests/test_multidb/test_pipeline.py | 222 ++++++----- 8 files changed, 585 insertions(+), 458 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index f201098a14..0f6be6e848 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -12,6 +12,7 @@ pytest==8.3.4 ; platform_python_implementation == "PyPy" pytest-asyncio>=0.23.0 pytest-asyncio==1.1.0 ; platform_python_implementation == "PyPy" pytest-cov +coverage<7.11.1 pytest-cov==6.0.0 ; platform_python_implementation == "PyPy" coverage==7.6.12 ; platform_python_implementation == "PyPy" pytest-profiling==1.8.1 diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 272064453a..af38b1e0c4 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -301,7 +301,13 @@ def _on_circuit_state_change_callback( ) def close(self): - self.command_executor.active_database.client.close() + """ + Closes the client and all its resources. + """ + if self._bg_scheduler: + self._bg_scheduler.stop() + if self.command_executor.active_database: + self.command_executor.active_database.client.close() def _half_open_circuit(circuit: CircuitBreaker): diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py index 537fdb0c82..4580c897a8 100644 --- a/tests/test_asyncio/test_multidb/test_client.py +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -41,14 +41,16 @@ async def test_execute_command_against_correct_db_on_successful_initialization( mock_hc.check_health.return_value = True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + assert await client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) == 9 - assert mock_db.circuit.state == CBState.CLOSED - assert mock_db1.circuit.state == CBState.CLOSED - assert mock_db2.circuit.state == CBState.CLOSED + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED @pytest.mark.asyncio @pytest.mark.parametrize( @@ -66,31 +68,44 @@ async def test_execute_command_against_correct_db_on_successful_initialization( async def test_execute_command_against_correct_db_and_closed_circuit( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): + """ + Validates that commands are executed against the correct + database when one database becomes unhealthy during initialization. + Ensures the client selects the highest-weighted + healthy database (mock_db1) and executes commands against it + with a CLOSED circuit. + """ databases = create_weighted_list(mock_db, mock_db1, mock_db2) mock_multi_db_config.health_checks = [mock_hc] with patch.object(mock_multi_db_config, "databases", return_value=databases): + mock_db.client.execute_command = AsyncMock( + return_value="NOT_OK-->Response from unexpected db - mock_db" + ) mock_db1.client.execute_command = AsyncMock(return_value="OK1") + mock_db2.client.execute_command = AsyncMock( + return_value="NOT_OK-->Response from unexpected db - mock_db2" + ) - mock_hc.check_health.side_effect = [ - False, - True, - True, - True, - True, - True, - True, - ] + async def mock_check_health(database): + if database == mock_db2: + return False + else: + return True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - result = await client.set("key", "value") - assert result == "OK1" - assert mock_hc.check_health.call_count == 7 + mock_hc.check_health.side_effect = mock_check_health - assert mock_db.circuit.state == CBState.CLOSED - assert mock_db1.circuit.state == CBState.CLOSED - assert mock_db2.circuit.state == CBState.OPEN + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + result = await client.set("key", "value") + assert result == "OK1" + assert len(mock_hc.check_health.call_args_list) >= 7 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN @pytest.mark.asyncio @pytest.mark.parametrize( @@ -189,40 +204,40 @@ async def mock_check_health(database): mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() - client = MultiDBClient(mock_multi_db_config) - assert await client.set("key", "value") == "OK1" + async with MultiDBClient(mock_multi_db_config) as client: + assert await client.set("key", "value") == "OK1" - # Wait for mock_db1 to become unhealthy - assert await db1_became_unhealthy.wait(), ( - "Timeout waiting for mock_db1 to become unhealthy" - ) + # Wait for mock_db1 to become unhealthy + assert await db1_became_unhealthy.wait(), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) - await asyncio.sleep(0.01) + await asyncio.sleep(0.01) - assert await client.set("key", "value") == "OK2" + assert await client.set("key", "value") == "OK2" - # Wait for mock_db2 to become unhealthy - assert await db2_became_unhealthy.wait(), ( - "Timeout waiting for mock_db2 to become unhealthy" - ) + # Wait for mock_db2 to become unhealthy + assert await db2_became_unhealthy.wait(), ( + "Timeout waiting for mock_db2 to become unhealthy" + ) - # Wait for circuit breaker state to actually reflect the unhealthy status - # (instead of just sleeping) - max_retries = 20 - for _ in range(max_retries): - if cb2.state == CBState.OPEN: # Circuit is open (unhealthy) - break - await asyncio.sleep(0.01) + # Wait for circuit breaker state to actually reflect the unhealthy status + # (instead of just sleeping) + max_retries = 20 + for _ in range(max_retries): + if cb2.state == CBState.OPEN: # Circuit is open (unhealthy) + break + await asyncio.sleep(0.01) - assert await client.set("key", "value") == "OK" + assert await client.set("key", "value") == "OK" - # Wait for mock_db to become unhealthy - assert await db_became_unhealthy.wait(), ( - "Timeout waiting for mock_db to become unhealthy" - ) - await asyncio.sleep(0.01) + # Wait for mock_db to become unhealthy + assert await db_became_unhealthy.wait(), ( + "Timeout waiting for mock_db to become unhealthy" + ) + await asyncio.sleep(0.01) - assert await client.set("key", "value") == "OK1" + assert await client.set("key", "value") == "OK1" @pytest.mark.asyncio @pytest.mark.parametrize( @@ -375,7 +390,7 @@ async def test_execute_command_throws_exception_on_failed_initialization( match="Initial connection failed - no active database found", ): await client.set("key", "value") - assert mock_hc.check_health.call_count == 9 + assert len(mock_hc.check_health.call_args_list) == 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -404,7 +419,6 @@ async def test_add_database_throws_exception_on_same_database( with pytest.raises(ValueError, match="Given database already exists"): await client.add_database(mock_db) - assert mock_hc.check_health.call_count == 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -431,16 +445,18 @@ async def test_add_database_makes_new_database_active( mock_hc.check_health.return_value = True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - assert await client.set("key", "value") == "OK2" - assert mock_hc.check_health.call_count == 6 + assert await client.set("key", "value") == "OK2" + assert len(mock_hc.check_health.call_args_list) == 6 - await client.add_database(mock_db1) - assert mock_hc.check_health.call_count == 9 + await client.add_database(mock_db1) + assert len(mock_hc.check_health.call_args_list) == 9 - assert await client.set("key", "value") == "OK1" + assert await client.set("key", "value") == "OK1" @pytest.mark.asyncio @pytest.mark.parametrize( @@ -467,14 +483,16 @@ async def test_remove_highest_weighted_database( mock_hc.check_health.return_value = True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - assert await client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 + assert await client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) == 9 - await client.remove_database(mock_db1) - assert await client.set("key", "value") == "OK2" + await client.remove_database(mock_db1) + assert await client.set("key", "value") == "OK2" @pytest.mark.asyncio @pytest.mark.parametrize( @@ -501,16 +519,18 @@ async def test_update_database_weight_to_be_highest( mock_hc.check_health.return_value = True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - assert await client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 + assert await client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) == 9 - await client.update_database_weight(mock_db2, 0.8) - assert mock_db2.weight == 0.8 + await client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 - assert await client.set("key", "value") == "OK2" + assert await client.set("key", "value") == "OK2" @pytest.mark.asyncio @pytest.mark.parametrize( @@ -544,30 +564,32 @@ async def test_add_new_failure_detector( mock_hc.check_health.return_value = True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 - - # Simulate failing command events that lead to a failure detection - for _ in range(5): - await mock_multi_db_config.event_dispatcher.dispatch_async( - command_fail_event + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 ) + assert await client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) == 9 - assert mock_fd.register_failure.call_count == 5 + # Simulate failing command events that lead to a failure detection + for _ in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async( + command_fail_event + ) - another_fd = Mock(spec=AsyncFailureDetector) - client.add_failure_detector(another_fd) + assert mock_fd.register_failure.call_count == 5 - # Simulate failing command events that lead to a failure detection - for _ in range(5): - await mock_multi_db_config.event_dispatcher.dispatch_async( - command_fail_event - ) + another_fd = Mock(spec=AsyncFailureDetector) + client.add_failure_detector(another_fd) + + # Simulate failing command events that lead to a failure detection + for _ in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async( + command_fail_event + ) - assert mock_fd.register_failure.call_count == 10 - assert another_fd.register_failure.call_count == 5 + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -593,19 +615,21 @@ async def test_add_new_health_check( mock_hc.check_health.return_value = True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + assert await client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) == 9 - another_hc = Mock(spec=HealthCheck) - another_hc.check_health.return_value = True + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True - await client.add_health_check(another_hc) - await client._check_db_health(mock_db1) + await client.add_health_check(another_hc) + await client._check_db_health(mock_db1) - assert mock_hc.check_health.call_count == 12 - assert another_hc.check_health.call_count == 3 + assert len(mock_hc.check_health.call_args_list) == 12 + assert len(another_hc.check_health.call_args_list) == 3 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -632,23 +656,25 @@ async def test_set_active_database( mock_hc.check_health.return_value = True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert await client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + assert await client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) >= 9 - await client.set_active_database(mock_db) - assert await client.set("key", "value") == "OK" + await client.set_active_database(mock_db) + assert await client.set("key", "value") == "OK" - with pytest.raises( - ValueError, match="Given database is not a member of database list" - ): - await client.set_active_database(Mock(spec=AsyncDatabase)) + with pytest.raises( + ValueError, match="Given database is not a member of database list" + ): + await client.set_active_database(Mock(spec=AsyncDatabase)) - mock_hc.check_health.return_value = False + mock_hc.check_health.return_value = False - with pytest.raises( - NoValidDatabaseException, - match="Cannot set active database, database is unhealthy", - ): - await client.set_active_database(mock_db1) + with pytest.raises( + NoValidDatabaseException, + match="Cannot set active database, database is unhealthy", + ): + await client.set_active_database(mock_db1) diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py index 528f8e813b..d1918c68bc 100644 --- a/tests/test_asyncio/test_multidb/test_pipeline.py +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -50,15 +50,17 @@ async def test_executes_pipeline_against_correct_db( mock_hc.check_health.return_value = True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - pipe = client.pipeline() - pipe.set("key1", "value1") - pipe.get("key1") + pipe = client.pipeline() + pipe.set("key1", "value1") + pipe.get("key1") - assert await pipe.execute() == ["OK1", "value1"] - assert mock_hc.check_health.call_count == 9 + assert await pipe.execute() == ["OK1", "value1"] + assert len(mock_hc.check_health.call_args_list) >= 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -88,29 +90,28 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.execute.return_value = ["OK1", "value1"] mock_db1.client.pipeline.return_value = pipe - mock_hc.check_health.side_effect = [ - False, - True, - True, - True, - True, - True, - True, - ] + async def mock_check_health(database): + if database == mock_db2: + return False + else: + return True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + mock_hc.check_health.side_effect = mock_check_health + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - async with client.pipeline() as pipe: - pipe.set("key1", "value1") - pipe.get("key1") + async with client.pipeline() as pipe: + pipe.set("key1", "value1") + pipe.get("key1") - assert await pipe.execute() == ["OK1", "value1"] - assert mock_hc.check_health.call_count == 7 + assert await pipe.execute() == ["OK1", "value1"] + assert len(mock_hc.check_health.call_args_list) >= 7 - assert mock_db.circuit.state == CBState.CLOSED - assert mock_db1.circuit.state == CBState.CLOSED - assert mock_db2.circuit.state == CBState.OPEN + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN @pytest.mark.asyncio @pytest.mark.parametrize( @@ -291,15 +292,19 @@ async def test_executes_transaction_against_correct_db( mock_hc.check_health.return_value = True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - async def callback(pipe: Pipeline): - pipe.set("key1", "value1") - pipe.get("key1") + async def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") - assert await client.transaction(callback) == ["OK1", "value1"] - assert mock_hc.check_health.call_count == 9 + assert await client.transaction(callback) == ["OK1", "value1"] + # if we assume at least 3 health checks have run per each database + # we should have at least 9 total calls + assert len(mock_hc.check_health.call_args_list) >= 9 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -327,29 +332,29 @@ async def test_execute_transaction_against_correct_db_and_closed_circuit( ): mock_db1.client.transaction.return_value = ["OK1", "value1"] - mock_hc.check_health.side_effect = [ - False, - True, - True, - True, - True, - True, - True, - ] + async def mock_check_health(database): + if database == mock_db2: + return False + else: + return True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + mock_hc.check_health.side_effect = mock_check_health - async def callback(pipe: Pipeline): - pipe.set("key1", "value1") - pipe.get("key1") + async with MultiDBClient(mock_multi_db_config) as client: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - assert await client.transaction(callback) == ["OK1", "value1"] - assert mock_hc.check_health.call_count == 7 + async def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") - assert mock_db.circuit.state == CBState.CLOSED - assert mock_db1.circuit.state == CBState.CLOSED - assert mock_db2.circuit.state == CBState.OPEN + assert await client.transaction(callback) == ["OK1", "value1"] + assert len(mock_hc.check_health.call_args_list) >= 7 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN @pytest.mark.asyncio @pytest.mark.parametrize( @@ -455,34 +460,34 @@ async def mock_check_health(database): mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() - client = MultiDBClient(mock_multi_db_config) + async with MultiDBClient(mock_multi_db_config) as client: - async def callback(pipe: Pipeline): - pipe.set("key1", "value1") - pipe.get("key1") + async def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") - assert await client.transaction(callback) == ["OK1", "value"] + assert await client.transaction(callback) == ["OK1", "value"] - # Wait for mock_db1 to become unhealthy - assert await db1_became_unhealthy.wait(), ( - "Timeout waiting for mock_db1 to become unhealthy" - ) - await asyncio.sleep(0.01) + # Wait for mock_db1 to become unhealthy + assert await db1_became_unhealthy.wait(), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) + await asyncio.sleep(0.01) - assert await client.transaction(callback) == ["OK2", "value"] + assert await client.transaction(callback) == ["OK2", "value"] - # Wait for mock_db2 to become unhealthy - assert await db2_became_unhealthy.wait(), ( - "Timeout waiting for mock_db1 to become unhealthy" - ) - await asyncio.sleep(0.01) + # Wait for mock_db2 to become unhealthy + assert await db2_became_unhealthy.wait(), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) + await asyncio.sleep(0.01) - assert await client.transaction(callback) == ["OK", "value"] + assert await client.transaction(callback) == ["OK", "value"] - # Wait for mock_db to become unhealthy - assert await db_became_unhealthy.wait(), ( - "Timeout waiting for mock_db1 to become unhealthy" - ) - await asyncio.sleep(0.01) + # Wait for mock_db to become unhealthy + assert await db_became_unhealthy.wait(), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) + await asyncio.sleep(0.01) - assert await client.transaction(callback) == ["OK1", "value"] + assert await client.transaction(callback) == ["OK1", "value"] diff --git a/tests/test_auth/test_token_manager.py b/tests/test_auth/test_token_manager.py index f675c125dd..f9fdf2f14b 100644 --- a/tests/test_auth/test_token_manager.py +++ b/tests/test_auth/test_token_manager.py @@ -174,13 +174,13 @@ def test_token_renewal_with_skip_initial(self): mock_provider.request_token.side_effect = [ SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 50, + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 150, + (datetime.now(timezone.utc).timestamp() * 1000) + 1500, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), @@ -194,12 +194,12 @@ def on_next(token): mock_listener.on_next = on_next retry_policy = RetryPolicy(3, 10) - config = TokenManagerConfig(1, 0, 1000, retry_policy) + config = TokenManagerConfig(0.5, 0, 1000, retry_policy) mgr = TokenManager(mock_provider, config) mgr.start(mock_listener, skip_initial=True) - # Should be less than a 0.1, or it will be flacky due to - # additional token renewal. - sleep(0.1) + assert len(tokens) == 0 + + sleep(0.6) assert len(tokens) > 0 @@ -210,19 +210,19 @@ async def test_async_token_renewal_with_skip_initial(self): mock_provider.request_token.side_effect = [ SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000) + 1000, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 120, + (datetime.now(timezone.utc).timestamp() * 1000) + 1200, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 140, + (datetime.now(timezone.utc).timestamp() * 1000) + 1400, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), @@ -236,13 +236,12 @@ async def on_next(token): mock_listener.on_next = on_next retry_policy = RetryPolicy(3, 10) - config = TokenManagerConfig(1, 0, 1000, retry_policy) + config = TokenManagerConfig(0.5, 0, 1000, retry_policy) mgr = TokenManager(mock_provider, config) await mgr.start_async(mock_listener, skip_initial=True) - # Should be less than a 0.1, or it will be flacky - # due to additional token renewal. - await asyncio.sleep(0.2) + assert len(tokens) == 0 + await asyncio.sleep(0.6) assert len(tokens) > 0 def test_success_token_renewal_with_retry(self): diff --git a/tests/test_background.py b/tests/test_background.py index bac9c1eef6..dbd7ad58d6 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -39,7 +39,7 @@ def callback(arg1: str, arg2: int): ], ) def test_run_recurring(self, interval, timeout, call_count): - execute_counter = 0 + execute_counter = [] one = "arg1" two = 9999 @@ -48,18 +48,18 @@ def callback(arg1: str, arg2: int): nonlocal one nonlocal two - execute_counter += 1 + execute_counter.append(1) assert arg1 == one assert arg2 == two scheduler = BackgroundScheduler() scheduler.run_recurring(interval, callback, one, two) - assert execute_counter == 0 + assert len(execute_counter) == 0 sleep(timeout) - assert execute_counter == call_count + assert len(execute_counter) == call_count @pytest.mark.asyncio @pytest.mark.parametrize( @@ -71,7 +71,7 @@ def callback(arg1: str, arg2: int): ], ) async def test_run_recurring_async(self, interval, timeout, call_count): - execute_counter = 0 + execute_counter = [] one = "arg1" two = 9999 @@ -80,15 +80,15 @@ async def callback(arg1: str, arg2: int): nonlocal one nonlocal two - execute_counter += 1 + execute_counter.append(1) assert arg1 == one assert arg2 == two scheduler = BackgroundScheduler() await scheduler.run_recurring_async(interval, callback, one, two) - assert execute_counter == 0 + assert len(execute_counter) == 0 await asyncio.sleep(timeout) - assert execute_counter == call_count + assert len(execute_counter) == call_count diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 490126fdf2..ca74e9def6 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -42,13 +42,18 @@ def test_execute_command_against_correct_db_on_successful_initialization( mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 - - assert mock_db.circuit.state == CBState.CLOSED - assert mock_db1.circuit.state == CBState.CLOSED - assert mock_db2.circuit.state == CBState.CLOSED + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + assert client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) >= 9 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -65,31 +70,47 @@ def test_execute_command_against_correct_db_on_successful_initialization( def test_execute_command_against_correct_db_and_closed_circuit( self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): + """ + Validates that commands are executed against the correct + database when one database becomes unhealthy during initialization. + Ensures the client selects the highest-weighted + healthy database (mock_db1) and executes commands against it + with a CLOSED circuit. + """ databases = create_weighted_list(mock_db, mock_db1, mock_db2) mock_multi_db_config.health_checks = [mock_hc] with patch.object(mock_multi_db_config, "databases", return_value=databases): + mock_db.client.execute_command = MagicMock( + return_value="NOT_OK-->Response from unexpected db - mock_db" + ) mock_db1.client.execute_command = MagicMock(return_value="OK1") + mock_db2.client.execute_command = MagicMock( + return_value="NOT_OK-->Response from unexpected db - mock_db2" + ) - mock_hc.check_health.side_effect = [ - False, - True, - True, - True, - True, - True, - True, - ] + def mock_check_health(database): + if database == mock_db2: + return False + else: + return True - client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - result = client.set("key", "value") - assert result == "OK1" - assert mock_hc.check_health.call_count == 7 + mock_hc.check_health.side_effect = mock_check_health - assert mock_db.circuit.state == CBState.CLOSED - assert mock_db1.circuit.state == CBState.CLOSED - assert mock_db2.circuit.state == CBState.OPEN + client = MultiDBClient(mock_multi_db_config) + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + result = client.set("key", "value") + assert result == "OK1" + assert len(mock_hc.check_health.call_args_list) >= 7 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -188,31 +209,34 @@ def mock_check_health(database): mock_db2.client.execute_command.return_value = "OK2" client = MultiDBClient(mock_multi_db_config) - assert client.set("key", "value") == "OK1" + try: + assert client.set("key", "value") == "OK1" - # Wait for mock_db1 to become unhealthy - assert db1_became_unhealthy.wait(timeout=1.0), ( - "Timeout waiting for mock_db1 to become unhealthy" - ) - sleep(0.01) + # Wait for mock_db1 to become unhealthy + assert db1_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) + sleep(0.01) - assert client.set("key", "value") == "OK2" + assert client.set("key", "value") == "OK2" - # Wait for mock_db2 to become unhealthy - assert db2_became_unhealthy.wait(timeout=1.0), ( - "Timeout waiting for mock_db2 to become unhealthy" - ) - sleep(0.01) + # Wait for mock_db2 to become unhealthy + assert db2_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db2 to become unhealthy" + ) + sleep(0.01) - assert client.set("key", "value") == "OK" + assert client.set("key", "value") == "OK" - # Wait for mock_db to become unhealthy - assert db_became_unhealthy.wait(timeout=1.0), ( - "Timeout waiting for mock_db to become unhealthy" - ) - sleep(0.01) + # Wait for mock_db to become unhealthy + assert db_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db to become unhealthy" + ) + sleep(0.01) - assert client.set("key", "value") == "OK1" + assert client.set("key", "value") == "OK1" + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -259,20 +283,23 @@ def mock_check_health(database): mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) - assert client.set("key", "value") == "OK1" - error_event.wait(timeout=0.5) - - # Wait for circuit breaker to actually open (not just the event) - max_retries = 20 - for _ in range(max_retries): - if mock_db1.circuit.state == CBState.OPEN: # Circuit is open - break - sleep(0.01) - - # Now the failover strategy will select mock_db2 - assert client.set("key", "value") == "OK2" - sleep(0.5) - assert client.set("key", "value") == "OK1" + try: + assert client.set("key", "value") == "OK1" + error_event.wait(timeout=0.5) + + # Wait for circuit breaker to actually open (not just the event) + max_retries = 20 + for _ in range(max_retries): + if mock_db1.circuit.state == CBState.OPEN: # Circuit is open + break + sleep(0.01) + + # Now the failover strategy will select mock_db2 + assert client.set("key", "value") == "OK2" + sleep(0.5) + assert client.set("key", "value") == "OK1" + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -319,20 +346,23 @@ def mock_check_health(database): mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) - assert client.set("key", "value") == "OK1" - error_event.wait(timeout=0.5) - # Wait for circuit breaker state to actually reflect the unhealthy status - # (instead of just sleeping) - max_retries = 20 - for _ in range(max_retries): - if ( - mock_db1.circuit.state == CBState.OPEN - ): # Circuit is open (unhealthy) - break - sleep(0.01) - assert client.set("key", "value") == "OK2" - sleep(0.5) - assert client.set("key", "value") == "OK2" + try: + assert client.set("key", "value") == "OK1" + error_event.wait(timeout=0.5) + # Wait for circuit breaker state to actually reflect the unhealthy status + # (instead of just sleeping) + max_retries = 20 + for _ in range(max_retries): + if ( + mock_db1.circuit.state == CBState.OPEN + ): # Circuit is open (unhealthy) + break + sleep(0.01) + assert client.set("key", "value") == "OK2" + sleep(0.5) + assert client.set("key", "value") == "OK2" + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -356,15 +386,18 @@ def test_execute_command_throws_exception_on_failed_initialization( mock_hc.check_health.return_value = False client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - - with pytest.raises( - NoValidDatabaseException, - match="Initial connection failed - no active database found", - ): - client.set("key", "value") - - assert mock_hc.check_health.call_count == 3 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + + with pytest.raises( + NoValidDatabaseException, + match="Initial connection failed - no active database found", + ): + client.set("key", "value") + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -388,11 +421,17 @@ def test_add_database_throws_exception_on_same_database( mock_hc.check_health.return_value = False client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + + with pytest.raises(ValueError, match="Given database already exists"): + client.add_database(mock_db) + assert len(mock_hc.check_health.call_args_list) == 3 - with pytest.raises(ValueError, match="Given database already exists"): - client.add_database(mock_db) - assert mock_hc.check_health.call_count == 3 + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -419,15 +458,20 @@ def test_add_database_makes_new_database_active( mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - assert client.set("key", "value") == "OK2" - assert mock_hc.check_health.call_count == 6 + assert client.set("key", "value") == "OK2" + assert len(mock_hc.check_health.call_args_list) == 6 - client.add_database(mock_db1) - assert mock_hc.check_health.call_count == 9 + client.add_database(mock_db1) + assert len(mock_hc.check_health.call_args_list) == 9 - assert client.set("key", "value") == "OK1" + assert client.set("key", "value") == "OK1" + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -454,14 +498,19 @@ def test_remove_highest_weighted_database( mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - assert client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 + assert client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) >= 9 - client.remove_database(mock_db1) + client.remove_database(mock_db1) - assert client.set("key", "value") == "OK2" + assert client.set("key", "value") == "OK2" + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -488,15 +537,20 @@ def test_update_database_weight_to_be_highest( mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - assert client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 + assert client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) >= 9 - client.update_database_weight(mock_db2, 0.8) - assert mock_db2.weight == 0.8 + client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 - assert client.set("key", "value") == "OK2" + assert client.set("key", "value") == "OK2" + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -530,25 +584,30 @@ def test_add_new_failure_detector( mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + assert client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) >= 9 - # Simulate failing command events that lead to a failure detection - for i in range(5): - mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) + # Simulate failing command events that lead to a failure detection + for i in range(5): + mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) - assert mock_fd.register_failure.call_count == 5 + assert mock_fd.register_failure.call_count == 5 - another_fd = Mock(spec=FailureDetector) - client.add_failure_detector(another_fd) + another_fd = Mock(spec=FailureDetector) + client.add_failure_detector(another_fd) - # Simulate failing command events that lead to a failure detection - for i in range(5): - mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) + # Simulate failing command events that lead to a failure detection + for i in range(5): + mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) - assert mock_fd.register_failure.call_count == 10 - assert another_fd.register_failure.call_count == 5 + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -574,18 +633,24 @@ def test_add_new_health_check( mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + assert client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) == 9 - another_hc = Mock(spec=HealthCheck) - another_hc.check_health.return_value = True + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True - client.add_health_check(another_hc) - client._check_db_health(mock_db1) + client.add_health_check(another_hc) + client._check_db_health(mock_db1) - assert mock_hc.check_health.call_count == 12 - assert another_hc.check_health.call_count == 3 + assert len(mock_hc.check_health.call_args_list) == 12 + assert len(another_hc.check_health.call_args_list) == 3 + + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -612,22 +677,27 @@ def test_set_active_database( mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 - assert client.set("key", "value") == "OK1" - assert mock_hc.check_health.call_count == 9 - - client.set_active_database(mock_db) - assert client.set("key", "value") == "OK" - - with pytest.raises( - ValueError, match="Given database is not a member of database list" - ): - client.set_active_database(Mock(spec=SyncDatabase)) - - mock_hc.check_health.return_value = False - - with pytest.raises( - NoValidDatabaseException, - match="Cannot set active database, database is unhealthy", - ): - client.set_active_database(mock_db1) + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) + assert client.set("key", "value") == "OK1" + assert len(mock_hc.check_health.call_args_list) == 9 + + client.set_active_database(mock_db) + assert client.set("key", "value") == "OK" + + with pytest.raises( + ValueError, match="Given database is not a member of database list" + ): + client.set_active_database(Mock(spec=SyncDatabase)) + + mock_hc.check_health.return_value = False + + with pytest.raises( + NoValidDatabaseException, + match="Cannot set active database, database is unhealthy", + ): + client.set_active_database(mock_db1) + finally: + client.close() diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index 0055718d4f..3c90cff9c6 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -53,14 +53,19 @@ def test_executes_pipeline_against_correct_db( mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - pipe = client.pipeline() - pipe.set("key1", "value1") - pipe.get("key1") + pipe = client.pipeline() + pipe.set("key1", "value1") + pipe.get("key1") - assert pipe.execute() == ["OK1", "value1"] - assert mock_hc.check_health.call_count == 9 + assert pipe.execute() == ["OK1", "value1"] + assert len(mock_hc.check_health.call_args_list) == 9 + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -89,29 +94,31 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.execute.return_value = ["OK1", "value1"] mock_db1.client.pipeline.return_value = pipe - mock_hc.check_health.side_effect = [ - False, - True, - True, - True, - True, - True, - True, - ] + def mock_check_health(database): + if database == mock_db2: + return False + else: + return True + + mock_hc.check_health.side_effect = mock_check_health client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - with client.pipeline() as pipe: - pipe.set("key1", "value1") - pipe.get("key1") + with client.pipeline() as pipe: + pipe.set("key1", "value1") + pipe.get("key1") - assert pipe.execute() == ["OK1", "value1"] - assert mock_hc.check_health.call_count == 7 + assert pipe.execute() == ["OK1", "value1"] - assert mock_db.circuit.state == CBState.CLOSED - assert mock_db1.circuit.state == CBState.CLOSED - assert mock_db2.circuit.state == CBState.OPEN + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -225,40 +232,42 @@ def mock_check_health(database): mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) + try: + with client.pipeline() as pipe: + pipe.set("key1", "value") + pipe.get("key1") - with client.pipeline() as pipe: - pipe.set("key1", "value") - pipe.get("key1") + # Run 1: All databases healthy - should use mock_db1 (highest weight 0.7) + assert pipe.execute() == ["OK1", "value"] - # Run 1: All databases healthy - should use mock_db1 (highest weight 0.7) - assert pipe.execute() == ["OK1", "value"] + # Wait for mock_db1 to become unhealthy + assert db1_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) + sleep(0.01) - # Wait for mock_db1 to become unhealthy - assert db1_became_unhealthy.wait(timeout=1.0), ( - "Timeout waiting for mock_db1 to become unhealthy" - ) - sleep(0.01) + # Run 2: mock_db1 unhealthy - should failover to mock_db2 (weight 0.5) + assert pipe.execute() == ["OK2", "value"] - # Run 2: mock_db1 unhealthy - should failover to mock_db2 (weight 0.5) - assert pipe.execute() == ["OK2", "value"] - - # Wait for mock_db2 to become unhealthy - assert db2_became_unhealthy.wait(timeout=1.0), ( - "Timeout waiting for mock_db2 to become unhealthy" - ) - sleep(0.01) + # Wait for mock_db2 to become unhealthy + assert db2_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db2 to become unhealthy" + ) + sleep(0.01) - # Run 3: mock_db1 and mock_db2 unhealthy - should use mock_db (weight 0.2) - assert pipe.execute() == ["OK", "value"] + # Run 3: mock_db1 and mock_db2 unhealthy - should use mock_db (weight 0.2) + assert pipe.execute() == ["OK", "value"] - # Wait for mock_db to become unhealthy - assert db_became_unhealthy.wait(timeout=1.0), ( - "Timeout waiting for mock_db to become unhealthy" - ) - sleep(0.01) + # Wait for mock_db to become unhealthy + assert db_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db to become unhealthy" + ) + sleep(0.01) - # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) - assert pipe.execute() == ["OK1", "value"] + # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) + assert pipe.execute() == ["OK1", "value"] + finally: + client.close() @pytest.mark.onlynoncluster @@ -291,14 +300,19 @@ def test_executes_transaction_against_correct_db( mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - def callback(pipe: Pipeline): - pipe.set("key1", "value1") - pipe.get("key1") + def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") - assert client.transaction(callback) == ["OK1", "value1"] - assert mock_hc.check_health.call_count == 9 + assert client.transaction(callback) == ["OK1", "value1"] + assert len(mock_hc.check_health.call_args_list) >= 9 + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -325,29 +339,32 @@ def test_execute_transaction_against_correct_db_and_closed_circuit( ): mock_db1.client.transaction.return_value = ["OK1", "value1"] - mock_hc.check_health.side_effect = [ - False, - True, - True, - True, - True, - True, - True, - ] + def mock_check_health(database): + if database == mock_db2: + return False + else: + return True + + mock_hc.check_health.side_effect = mock_check_health client = MultiDBClient(mock_multi_db_config) - assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + try: + assert ( + mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + ) - def callback(pipe: Pipeline): - pipe.set("key1", "value1") - pipe.get("key1") + def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") - assert client.transaction(callback) == ["OK1", "value1"] - assert mock_hc.check_health.call_count == 7 + assert client.transaction(callback) == ["OK1", "value1"] + assert len(mock_hc.check_health.call_args_list) >= 7 - assert mock_db.circuit.state == CBState.CLOSED - assert mock_db1.circuit.state == CBState.CLOSED - assert mock_db2.circuit.state == CBState.OPEN + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + finally: + client.close() @pytest.mark.parametrize( "mock_multi_db_config,mock_db, mock_db1, mock_db2", @@ -453,37 +470,40 @@ def mock_check_health(database): mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() client = MultiDBClient(mock_multi_db_config) + try: - def callback(pipe: Pipeline): - pipe.set("key1", "value1") - pipe.get("key1") + def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") - # Run 1: All databases healthy - should use mock_db1 (highest weight 0.7) - assert client.transaction(callback) == ["OK1", "value"] + # Run 1: All databases healthy - should use mock_db1 (highest weight 0.7) + assert client.transaction(callback) == ["OK1", "value"] - # Wait for mock_db1 to become unhealthy - assert db1_became_unhealthy.wait(timeout=1.0), ( - "Timeout waiting for mock_db1 to become unhealthy" - ) - sleep(0.01) + # Wait for mock_db1 to become unhealthy + assert db1_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db1 to become unhealthy" + ) + sleep(0.01) - # Run 2: mock_db1 unhealthy - should failover to mock_db2 (weight 0.5) - assert client.transaction(callback) == ["OK2", "value"] + # Run 2: mock_db1 unhealthy - should failover to mock_db2 (weight 0.5) + assert client.transaction(callback) == ["OK2", "value"] - # Wait for mock_db2 to become unhealthy - assert db2_became_unhealthy.wait(timeout=1.0), ( - "Timeout waiting for mock_db2 to become unhealthy" - ) - sleep(0.01) + # Wait for mock_db2 to become unhealthy + assert db2_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db2 to become unhealthy" + ) + sleep(0.01) - # Run 3: mock_db1 and mock_db2 unhealthy - should use mock_db (weight 0.2) - assert client.transaction(callback) == ["OK", "value"] + # Run 3: mock_db1 and mock_db2 unhealthy - should use mock_db (weight 0.2) + assert client.transaction(callback) == ["OK", "value"] - # Wait for mock_db to become unhealthy - assert db_became_unhealthy.wait(timeout=1.0), ( - "Timeout waiting for mock_db to become unhealthy" - ) - sleep(0.01) + # Wait for mock_db to become unhealthy + assert db_became_unhealthy.wait(timeout=1.0), ( + "Timeout waiting for mock_db to become unhealthy" + ) + sleep(0.01) - # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) - assert client.transaction(callback) == ["OK1", "value"] + # Run 4: mock_db unhealthy, others healthy - should use mock_db1 (highest weight) + assert client.transaction(callback) == ["OK1", "value"] + finally: + client.close()