@@ -41,14 +41,16 @@ async def test_execute_command_against_correct_db_on_successful_initialization(
4141
4242 mock_hc .check_health .return_value = True
4343
44- client = MultiDBClient (mock_multi_db_config )
45- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
46- assert await client .set ("key" , "value" ) == "OK1"
47- assert mock_hc .check_health .call_count == 9
44+ async with MultiDBClient (mock_multi_db_config ) as client :
45+ assert (
46+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
47+ )
48+ assert await client .set ("key" , "value" ) == "OK1"
49+ assert len (mock_hc .check_health .call_args_list ) == 9
4850
49- assert mock_db .circuit .state == CBState .CLOSED
50- assert mock_db1 .circuit .state == CBState .CLOSED
51- assert mock_db2 .circuit .state == CBState .CLOSED
51+ assert mock_db .circuit .state == CBState .CLOSED
52+ assert mock_db1 .circuit .state == CBState .CLOSED
53+ assert mock_db2 .circuit .state == CBState .CLOSED
5254
5355 @pytest .mark .asyncio
5456 @pytest .mark .parametrize (
@@ -66,31 +68,44 @@ async def test_execute_command_against_correct_db_on_successful_initialization(
6668 async def test_execute_command_against_correct_db_and_closed_circuit (
6769 self , mock_multi_db_config , mock_db , mock_db1 , mock_db2 , mock_hc
6870 ):
71+ """
72+ Validates that commands are executed against the correct
73+ database when one database becomes unhealthy during initialization.
74+ Ensures the client selects the highest-weighted
75+ healthy database (mock_db1) and executes commands against it
76+ with a CLOSED circuit.
77+ """
6978 databases = create_weighted_list (mock_db , mock_db1 , mock_db2 )
7079 mock_multi_db_config .health_checks = [mock_hc ]
7180
7281 with patch .object (mock_multi_db_config , "databases" , return_value = databases ):
82+ mock_db .client .execute_command = AsyncMock (
83+ return_value = "NOT_OK-->Response from unexpected db - mock_db"
84+ )
7385 mock_db1 .client .execute_command = AsyncMock (return_value = "OK1" )
86+ mock_db2 .client .execute_command = AsyncMock (
87+ return_value = "NOT_OK-->Response from unexpected db - mock_db2"
88+ )
7489
75- mock_hc .check_health .side_effect = [
76- False ,
77- True ,
78- True ,
79- True ,
80- True ,
81- True ,
82- True ,
83- ]
90+ async def mock_check_health (database ):
91+ if database == mock_db2 :
92+ return False
93+ else :
94+ return True
8495
85- client = MultiDBClient (mock_multi_db_config )
86- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
87- result = await client .set ("key" , "value" )
88- assert result == "OK1"
89- assert mock_hc .check_health .call_count == 7
96+ mock_hc .check_health .side_effect = mock_check_health
9097
91- assert mock_db .circuit .state == CBState .CLOSED
92- assert mock_db1 .circuit .state == CBState .CLOSED
93- assert mock_db2 .circuit .state == CBState .OPEN
98+ async with MultiDBClient (mock_multi_db_config ) as client :
99+ assert (
100+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
101+ )
102+ result = await client .set ("key" , "value" )
103+ assert result == "OK1"
104+ assert mock_hc .check_health .call_count >= 7
105+
106+ assert mock_db .circuit .state == CBState .CLOSED
107+ assert mock_db1 .circuit .state == CBState .CLOSED
108+ assert mock_db2 .circuit .state == CBState .OPEN
94109
95110 @pytest .mark .asyncio
96111 @pytest .mark .parametrize (
@@ -189,40 +204,40 @@ async def mock_check_health(database):
189204 mock_multi_db_config .health_check_interval = 0.1
190205 mock_multi_db_config .failover_strategy = WeightBasedFailoverStrategy ()
191206
192- client = MultiDBClient (mock_multi_db_config )
193- assert await client .set ("key" , "value" ) == "OK1"
207+ async with MultiDBClient (mock_multi_db_config ) as client :
208+ assert await client .set ("key" , "value" ) == "OK1"
194209
195- # Wait for mock_db1 to become unhealthy
196- assert await db1_became_unhealthy .wait (), (
197- "Timeout waiting for mock_db1 to become unhealthy"
198- )
210+ # Wait for mock_db1 to become unhealthy
211+ assert await db1_became_unhealthy .wait (), (
212+ "Timeout waiting for mock_db1 to become unhealthy"
213+ )
199214
200- await asyncio .sleep (0.01 )
215+ await asyncio .sleep (0.01 )
201216
202- assert await client .set ("key" , "value" ) == "OK2"
217+ assert await client .set ("key" , "value" ) == "OK2"
203218
204- # Wait for mock_db2 to become unhealthy
205- assert await db2_became_unhealthy .wait (), (
206- "Timeout waiting for mock_db2 to become unhealthy"
207- )
219+ # Wait for mock_db2 to become unhealthy
220+ assert await db2_became_unhealthy .wait (), (
221+ "Timeout waiting for mock_db2 to become unhealthy"
222+ )
208223
209- # Wait for circuit breaker state to actually reflect the unhealthy status
210- # (instead of just sleeping)
211- max_retries = 20
212- for _ in range (max_retries ):
213- if cb2 .state == CBState .OPEN : # Circuit is open (unhealthy)
214- break
215- await asyncio .sleep (0.01 )
224+ # Wait for circuit breaker state to actually reflect the unhealthy status
225+ # (instead of just sleeping)
226+ max_retries = 20
227+ for _ in range (max_retries ):
228+ if cb2 .state == CBState .OPEN : # Circuit is open (unhealthy)
229+ break
230+ await asyncio .sleep (0.01 )
216231
217- assert await client .set ("key" , "value" ) == "OK"
232+ assert await client .set ("key" , "value" ) == "OK"
218233
219- # Wait for mock_db to become unhealthy
220- assert await db_became_unhealthy .wait (), (
221- "Timeout waiting for mock_db to become unhealthy"
222- )
223- await asyncio .sleep (0.01 )
234+ # Wait for mock_db to become unhealthy
235+ assert await db_became_unhealthy .wait (), (
236+ "Timeout waiting for mock_db to become unhealthy"
237+ )
238+ await asyncio .sleep (0.01 )
224239
225- assert await client .set ("key" , "value" ) == "OK1"
240+ assert await client .set ("key" , "value" ) == "OK1"
226241
227242 @pytest .mark .asyncio
228243 @pytest .mark .parametrize (
@@ -375,7 +390,7 @@ async def test_execute_command_throws_exception_on_failed_initialization(
375390 match = "Initial connection failed - no active database found" ,
376391 ):
377392 await client .set ("key" , "value" )
378- assert mock_hc .check_health .call_count == 9
393+ assert len ( mock_hc .check_health .call_args_list ) == 9
379394
380395 @pytest .mark .asyncio
381396 @pytest .mark .parametrize (
@@ -404,7 +419,6 @@ async def test_add_database_throws_exception_on_same_database(
404419
405420 with pytest .raises (ValueError , match = "Given database already exists" ):
406421 await client .add_database (mock_db )
407- assert mock_hc .check_health .call_count == 9
408422
409423 @pytest .mark .asyncio
410424 @pytest .mark .parametrize (
@@ -431,16 +445,18 @@ async def test_add_database_makes_new_database_active(
431445
432446 mock_hc .check_health .return_value = True
433447
434- client = MultiDBClient (mock_multi_db_config )
435- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
448+ async with MultiDBClient (mock_multi_db_config ) as client :
449+ assert (
450+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
451+ )
436452
437- assert await client .set ("key" , "value" ) == "OK2"
438- assert mock_hc .check_health .call_count == 6
453+ assert await client .set ("key" , "value" ) == "OK2"
454+ assert len ( mock_hc .check_health .call_args_list ) == 6
439455
440- await client .add_database (mock_db1 )
441- assert mock_hc .check_health .call_count == 9
456+ await client .add_database (mock_db1 )
457+ assert len ( mock_hc .check_health .call_args_list ) == 9
442458
443- assert await client .set ("key" , "value" ) == "OK1"
459+ assert await client .set ("key" , "value" ) == "OK1"
444460
445461 @pytest .mark .asyncio
446462 @pytest .mark .parametrize (
@@ -467,14 +483,16 @@ async def test_remove_highest_weighted_database(
467483
468484 mock_hc .check_health .return_value = True
469485
470- client = MultiDBClient (mock_multi_db_config )
471- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
486+ async with MultiDBClient (mock_multi_db_config ) as client :
487+ assert (
488+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
489+ )
472490
473- assert await client .set ("key" , "value" ) == "OK1"
474- assert mock_hc .check_health .call_count == 9
491+ assert await client .set ("key" , "value" ) == "OK1"
492+ assert len ( mock_hc .check_health .call_args_list ) == 9
475493
476- await client .remove_database (mock_db1 )
477- assert await client .set ("key" , "value" ) == "OK2"
494+ await client .remove_database (mock_db1 )
495+ assert await client .set ("key" , "value" ) == "OK2"
478496
479497 @pytest .mark .asyncio
480498 @pytest .mark .parametrize (
@@ -501,16 +519,18 @@ async def test_update_database_weight_to_be_highest(
501519
502520 mock_hc .check_health .return_value = True
503521
504- client = MultiDBClient (mock_multi_db_config )
505- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
522+ async with MultiDBClient (mock_multi_db_config ) as client :
523+ assert (
524+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
525+ )
506526
507- assert await client .set ("key" , "value" ) == "OK1"
508- assert mock_hc .check_health .call_count == 9
527+ assert await client .set ("key" , "value" ) == "OK1"
528+ assert len ( mock_hc .check_health .call_args_list ) == 9
509529
510- await client .update_database_weight (mock_db2 , 0.8 )
511- assert mock_db2 .weight == 0.8
530+ await client .update_database_weight (mock_db2 , 0.8 )
531+ assert mock_db2 .weight == 0.8
512532
513- assert await client .set ("key" , "value" ) == "OK2"
533+ assert await client .set ("key" , "value" ) == "OK2"
514534
515535 @pytest .mark .asyncio
516536 @pytest .mark .parametrize (
@@ -544,30 +564,32 @@ async def test_add_new_failure_detector(
544564
545565 mock_hc .check_health .return_value = True
546566
547- client = MultiDBClient (mock_multi_db_config )
548- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
549- assert await client .set ("key" , "value" ) == "OK1"
550- assert mock_hc .check_health .call_count == 9
551-
552- # Simulate failing command events that lead to a failure detection
553- for _ in range (5 ):
554- await mock_multi_db_config .event_dispatcher .dispatch_async (
555- command_fail_event
567+ async with MultiDBClient (mock_multi_db_config ) as client :
568+ assert (
569+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
556570 )
571+ assert await client .set ("key" , "value" ) == "OK1"
572+ assert len (mock_hc .check_health .call_args_list ) == 9
557573
558- assert mock_fd .register_failure .call_count == 5
574+ # Simulate failing command events that lead to a failure detection
575+ for _ in range (5 ):
576+ await mock_multi_db_config .event_dispatcher .dispatch_async (
577+ command_fail_event
578+ )
559579
560- another_fd = Mock (spec = AsyncFailureDetector )
561- client .add_failure_detector (another_fd )
580+ assert mock_fd .register_failure .call_count == 5
562581
563- # Simulate failing command events that lead to a failure detection
564- for _ in range (5 ):
565- await mock_multi_db_config .event_dispatcher .dispatch_async (
566- command_fail_event
567- )
582+ another_fd = Mock (spec = AsyncFailureDetector )
583+ client .add_failure_detector (another_fd )
584+
585+ # Simulate failing command events that lead to a failure detection
586+ for _ in range (5 ):
587+ await mock_multi_db_config .event_dispatcher .dispatch_async (
588+ command_fail_event
589+ )
568590
569- assert mock_fd .register_failure .call_count == 10
570- assert another_fd .register_failure .call_count == 5
591+ assert mock_fd .register_failure .call_count == 10
592+ assert another_fd .register_failure .call_count == 5
571593
572594 @pytest .mark .asyncio
573595 @pytest .mark .parametrize (
@@ -593,19 +615,21 @@ async def test_add_new_health_check(
593615
594616 mock_hc .check_health .return_value = True
595617
596- client = MultiDBClient (mock_multi_db_config )
597- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
598- assert await client .set ("key" , "value" ) == "OK1"
599- assert mock_hc .check_health .call_count == 9
618+ async with MultiDBClient (mock_multi_db_config ) as client :
619+ assert (
620+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
621+ )
622+ assert await client .set ("key" , "value" ) == "OK1"
623+ assert len (mock_hc .check_health .call_args_list ) == 9
600624
601- another_hc = Mock (spec = HealthCheck )
602- another_hc .check_health .return_value = True
625+ another_hc = Mock (spec = HealthCheck )
626+ another_hc .check_health .return_value = True
603627
604- await client .add_health_check (another_hc )
605- await client ._check_db_health (mock_db1 )
628+ await client .add_health_check (another_hc )
629+ await client ._check_db_health (mock_db1 )
606630
607- assert mock_hc .check_health .call_count == 12
608- assert another_hc .check_health .call_count == 3
631+ assert len ( mock_hc .check_health .call_args_list ) == 12
632+ assert len ( another_hc .check_health .call_args_list ) == 3
609633
610634 @pytest .mark .asyncio
611635 @pytest .mark .parametrize (
@@ -632,23 +656,26 @@ async def test_set_active_database(
632656
633657 mock_hc .check_health .return_value = True
634658
635- client = MultiDBClient (mock_multi_db_config )
636- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
637- assert await client .set ("key" , "value" ) == "OK1"
638- assert mock_hc .check_health .call_count == 9
659+ async with MultiDBClient (mock_multi_db_config ) as client :
660+ assert (
661+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
662+ )
663+ assert await client .set ("key" , "value" ) == "OK1"
664+ assert len (mock_hc .check_health .call_args_list ) >= 9
665+ assert mock_hc .check_health .call_count >= 9
639666
640- await client .set_active_database (mock_db )
641- assert await client .set ("key" , "value" ) == "OK"
667+ await client .set_active_database (mock_db )
668+ assert await client .set ("key" , "value" ) == "OK"
642669
643- with pytest .raises (
644- ValueError , match = "Given database is not a member of database list"
645- ):
646- await client .set_active_database (Mock (spec = AsyncDatabase ))
670+ with pytest .raises (
671+ ValueError , match = "Given database is not a member of database list"
672+ ):
673+ await client .set_active_database (Mock (spec = AsyncDatabase ))
647674
648- mock_hc .check_health .return_value = False
675+ mock_hc .check_health .return_value = False
649676
650- with pytest .raises (
651- NoValidDatabaseException ,
652- match = "Cannot set active database, database is unhealthy" ,
653- ):
654- await client .set_active_database (mock_db1 )
677+ with pytest .raises (
678+ NoValidDatabaseException ,
679+ match = "Cannot set active database, database is unhealthy" ,
680+ ):
681+ await client .set_active_database (mock_db1 )
0 commit comments