Skip to content

Commit dc827b7

Browse files
committed
Fixing flaky tests - part 2
1 parent 6bbc5c7 commit dc827b7

File tree

7 files changed

+582
-450
lines changed

7 files changed

+582
-450
lines changed

dev_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pytest==8.3.4 ; platform_python_implementation == "PyPy"
1212
pytest-asyncio>=0.23.0
1313
pytest-asyncio==1.1.0 ; platform_python_implementation == "PyPy"
1414
pytest-cov
15+
coverage<7.11.1
1516
pytest-cov==6.0.0 ; platform_python_implementation == "PyPy"
1617
coverage==7.6.12 ; platform_python_implementation == "PyPy"
1718
pytest-profiling==1.8.1

redis/multidb/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,13 @@ def _on_circuit_state_change_callback(
301301
)
302302

303303
def close(self):
304-
self.command_executor.active_database.client.close()
304+
"""
305+
Closes the client and all its resources.
306+
"""
307+
if self._bg_scheduler:
308+
self._bg_scheduler.stop()
309+
if self.command_executor.active_database:
310+
self.command_executor.active_database.client.close()
305311

306312

307313
def _half_open_circuit(circuit: CircuitBreaker):

tests/test_asyncio/test_multidb/test_client.py

Lines changed: 143 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,16 @@ async def test_execute_command_against_correct_db_on_successful_initialization(
4141

4242
mock_hc.check_health.return_value = True
4343

44-
client = MultiDBClient(mock_multi_db_config)
45-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
46-
assert await client.set("key", "value") == "OK1"
47-
assert mock_hc.check_health.call_count == 9
44+
async with MultiDBClient(mock_multi_db_config) as client:
45+
assert (
46+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
47+
)
48+
assert await client.set("key", "value") == "OK1"
49+
assert len(mock_hc.check_health.call_args_list) == 9
4850

49-
assert mock_db.circuit.state == CBState.CLOSED
50-
assert mock_db1.circuit.state == CBState.CLOSED
51-
assert mock_db2.circuit.state == CBState.CLOSED
51+
assert mock_db.circuit.state == CBState.CLOSED
52+
assert mock_db1.circuit.state == CBState.CLOSED
53+
assert mock_db2.circuit.state == CBState.CLOSED
5254

5355
@pytest.mark.asyncio
5456
@pytest.mark.parametrize(
@@ -66,31 +68,44 @@ async def test_execute_command_against_correct_db_on_successful_initialization(
6668
async def test_execute_command_against_correct_db_and_closed_circuit(
6769
self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc
6870
):
71+
"""
72+
Validates that commands are executed against the correct
73+
database when one database becomes unhealthy during initialization.
74+
Ensures the client selects the highest-weighted
75+
healthy database (mock_db1) and executes commands against it
76+
with a CLOSED circuit.
77+
"""
6978
databases = create_weighted_list(mock_db, mock_db1, mock_db2)
7079
mock_multi_db_config.health_checks = [mock_hc]
7180

7281
with patch.object(mock_multi_db_config, "databases", return_value=databases):
82+
mock_db.client.execute_command = AsyncMock(
83+
return_value="NOT_OK-->Response from unexpected db - mock_db"
84+
)
7385
mock_db1.client.execute_command = AsyncMock(return_value="OK1")
86+
mock_db2.client.execute_command = AsyncMock(
87+
return_value="NOT_OK-->Response from unexpected db - mock_db2"
88+
)
7489

75-
mock_hc.check_health.side_effect = [
76-
False,
77-
True,
78-
True,
79-
True,
80-
True,
81-
True,
82-
True,
83-
]
90+
async def mock_check_health(database):
91+
if database == mock_db2:
92+
return False
93+
else:
94+
return True
8495

85-
client = MultiDBClient(mock_multi_db_config)
86-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
87-
result = await client.set("key", "value")
88-
assert result == "OK1"
89-
assert mock_hc.check_health.call_count == 7
96+
mock_hc.check_health.side_effect = mock_check_health
9097

91-
assert mock_db.circuit.state == CBState.CLOSED
92-
assert mock_db1.circuit.state == CBState.CLOSED
93-
assert mock_db2.circuit.state == CBState.OPEN
98+
async with MultiDBClient(mock_multi_db_config) as client:
99+
assert (
100+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
101+
)
102+
result = await client.set("key", "value")
103+
assert result == "OK1"
104+
assert mock_hc.check_health.call_count >= 7
105+
106+
assert mock_db.circuit.state == CBState.CLOSED
107+
assert mock_db1.circuit.state == CBState.CLOSED
108+
assert mock_db2.circuit.state == CBState.OPEN
94109

95110
@pytest.mark.asyncio
96111
@pytest.mark.parametrize(
@@ -189,40 +204,40 @@ async def mock_check_health(database):
189204
mock_multi_db_config.health_check_interval = 0.1
190205
mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy()
191206

192-
client = MultiDBClient(mock_multi_db_config)
193-
assert await client.set("key", "value") == "OK1"
207+
async with MultiDBClient(mock_multi_db_config) as client:
208+
assert await client.set("key", "value") == "OK1"
194209

195-
# Wait for mock_db1 to become unhealthy
196-
assert await db1_became_unhealthy.wait(), (
197-
"Timeout waiting for mock_db1 to become unhealthy"
198-
)
210+
# Wait for mock_db1 to become unhealthy
211+
assert await db1_became_unhealthy.wait(), (
212+
"Timeout waiting for mock_db1 to become unhealthy"
213+
)
199214

200-
await asyncio.sleep(0.01)
215+
await asyncio.sleep(0.01)
201216

202-
assert await client.set("key", "value") == "OK2"
217+
assert await client.set("key", "value") == "OK2"
203218

204-
# Wait for mock_db2 to become unhealthy
205-
assert await db2_became_unhealthy.wait(), (
206-
"Timeout waiting for mock_db2 to become unhealthy"
207-
)
219+
# Wait for mock_db2 to become unhealthy
220+
assert await db2_became_unhealthy.wait(), (
221+
"Timeout waiting for mock_db2 to become unhealthy"
222+
)
208223

209-
# Wait for circuit breaker state to actually reflect the unhealthy status
210-
# (instead of just sleeping)
211-
max_retries = 20
212-
for _ in range(max_retries):
213-
if cb2.state == CBState.OPEN: # Circuit is open (unhealthy)
214-
break
215-
await asyncio.sleep(0.01)
224+
# Wait for circuit breaker state to actually reflect the unhealthy status
225+
# (instead of just sleeping)
226+
max_retries = 20
227+
for _ in range(max_retries):
228+
if cb2.state == CBState.OPEN: # Circuit is open (unhealthy)
229+
break
230+
await asyncio.sleep(0.01)
216231

217-
assert await client.set("key", "value") == "OK"
232+
assert await client.set("key", "value") == "OK"
218233

219-
# Wait for mock_db to become unhealthy
220-
assert await db_became_unhealthy.wait(), (
221-
"Timeout waiting for mock_db to become unhealthy"
222-
)
223-
await asyncio.sleep(0.01)
234+
# Wait for mock_db to become unhealthy
235+
assert await db_became_unhealthy.wait(), (
236+
"Timeout waiting for mock_db to become unhealthy"
237+
)
238+
await asyncio.sleep(0.01)
224239

225-
assert await client.set("key", "value") == "OK1"
240+
assert await client.set("key", "value") == "OK1"
226241

227242
@pytest.mark.asyncio
228243
@pytest.mark.parametrize(
@@ -375,7 +390,7 @@ async def test_execute_command_throws_exception_on_failed_initialization(
375390
match="Initial connection failed - no active database found",
376391
):
377392
await client.set("key", "value")
378-
assert mock_hc.check_health.call_count == 9
393+
assert len(mock_hc.check_health.call_args_list) == 9
379394

380395
@pytest.mark.asyncio
381396
@pytest.mark.parametrize(
@@ -404,7 +419,6 @@ async def test_add_database_throws_exception_on_same_database(
404419

405420
with pytest.raises(ValueError, match="Given database already exists"):
406421
await client.add_database(mock_db)
407-
assert mock_hc.check_health.call_count == 9
408422

409423
@pytest.mark.asyncio
410424
@pytest.mark.parametrize(
@@ -431,16 +445,18 @@ async def test_add_database_makes_new_database_active(
431445

432446
mock_hc.check_health.return_value = True
433447

434-
client = MultiDBClient(mock_multi_db_config)
435-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
448+
async with MultiDBClient(mock_multi_db_config) as client:
449+
assert (
450+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
451+
)
436452

437-
assert await client.set("key", "value") == "OK2"
438-
assert mock_hc.check_health.call_count == 6
453+
assert await client.set("key", "value") == "OK2"
454+
assert len(mock_hc.check_health.call_args_list) == 6
439455

440-
await client.add_database(mock_db1)
441-
assert mock_hc.check_health.call_count == 9
456+
await client.add_database(mock_db1)
457+
assert len(mock_hc.check_health.call_args_list) == 9
442458

443-
assert await client.set("key", "value") == "OK1"
459+
assert await client.set("key", "value") == "OK1"
444460

445461
@pytest.mark.asyncio
446462
@pytest.mark.parametrize(
@@ -467,14 +483,16 @@ async def test_remove_highest_weighted_database(
467483

468484
mock_hc.check_health.return_value = True
469485

470-
client = MultiDBClient(mock_multi_db_config)
471-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
486+
async with MultiDBClient(mock_multi_db_config) as client:
487+
assert (
488+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
489+
)
472490

473-
assert await client.set("key", "value") == "OK1"
474-
assert mock_hc.check_health.call_count == 9
491+
assert await client.set("key", "value") == "OK1"
492+
assert len(mock_hc.check_health.call_args_list) == 9
475493

476-
await client.remove_database(mock_db1)
477-
assert await client.set("key", "value") == "OK2"
494+
await client.remove_database(mock_db1)
495+
assert await client.set("key", "value") == "OK2"
478496

479497
@pytest.mark.asyncio
480498
@pytest.mark.parametrize(
@@ -501,16 +519,18 @@ async def test_update_database_weight_to_be_highest(
501519

502520
mock_hc.check_health.return_value = True
503521

504-
client = MultiDBClient(mock_multi_db_config)
505-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
522+
async with MultiDBClient(mock_multi_db_config) as client:
523+
assert (
524+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
525+
)
506526

507-
assert await client.set("key", "value") == "OK1"
508-
assert mock_hc.check_health.call_count == 9
527+
assert await client.set("key", "value") == "OK1"
528+
assert len(mock_hc.check_health.call_args_list) == 9
509529

510-
await client.update_database_weight(mock_db2, 0.8)
511-
assert mock_db2.weight == 0.8
530+
await client.update_database_weight(mock_db2, 0.8)
531+
assert mock_db2.weight == 0.8
512532

513-
assert await client.set("key", "value") == "OK2"
533+
assert await client.set("key", "value") == "OK2"
514534

515535
@pytest.mark.asyncio
516536
@pytest.mark.parametrize(
@@ -544,30 +564,32 @@ async def test_add_new_failure_detector(
544564

545565
mock_hc.check_health.return_value = True
546566

547-
client = MultiDBClient(mock_multi_db_config)
548-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
549-
assert await client.set("key", "value") == "OK1"
550-
assert mock_hc.check_health.call_count == 9
551-
552-
# Simulate failing command events that lead to a failure detection
553-
for _ in range(5):
554-
await mock_multi_db_config.event_dispatcher.dispatch_async(
555-
command_fail_event
567+
async with MultiDBClient(mock_multi_db_config) as client:
568+
assert (
569+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
556570
)
571+
assert await client.set("key", "value") == "OK1"
572+
assert len(mock_hc.check_health.call_args_list) == 9
557573

558-
assert mock_fd.register_failure.call_count == 5
574+
# Simulate failing command events that lead to a failure detection
575+
for _ in range(5):
576+
await mock_multi_db_config.event_dispatcher.dispatch_async(
577+
command_fail_event
578+
)
559579

560-
another_fd = Mock(spec=AsyncFailureDetector)
561-
client.add_failure_detector(another_fd)
580+
assert mock_fd.register_failure.call_count == 5
562581

563-
# Simulate failing command events that lead to a failure detection
564-
for _ in range(5):
565-
await mock_multi_db_config.event_dispatcher.dispatch_async(
566-
command_fail_event
567-
)
582+
another_fd = Mock(spec=AsyncFailureDetector)
583+
client.add_failure_detector(another_fd)
584+
585+
# Simulate failing command events that lead to a failure detection
586+
for _ in range(5):
587+
await mock_multi_db_config.event_dispatcher.dispatch_async(
588+
command_fail_event
589+
)
568590

569-
assert mock_fd.register_failure.call_count == 10
570-
assert another_fd.register_failure.call_count == 5
591+
assert mock_fd.register_failure.call_count == 10
592+
assert another_fd.register_failure.call_count == 5
571593

572594
@pytest.mark.asyncio
573595
@pytest.mark.parametrize(
@@ -593,19 +615,21 @@ async def test_add_new_health_check(
593615

594616
mock_hc.check_health.return_value = True
595617

596-
client = MultiDBClient(mock_multi_db_config)
597-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
598-
assert await client.set("key", "value") == "OK1"
599-
assert mock_hc.check_health.call_count == 9
618+
async with MultiDBClient(mock_multi_db_config) as client:
619+
assert (
620+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
621+
)
622+
assert await client.set("key", "value") == "OK1"
623+
assert len(mock_hc.check_health.call_args_list) == 9
600624

601-
another_hc = Mock(spec=HealthCheck)
602-
another_hc.check_health.return_value = True
625+
another_hc = Mock(spec=HealthCheck)
626+
another_hc.check_health.return_value = True
603627

604-
await client.add_health_check(another_hc)
605-
await client._check_db_health(mock_db1)
628+
await client.add_health_check(another_hc)
629+
await client._check_db_health(mock_db1)
606630

607-
assert mock_hc.check_health.call_count == 12
608-
assert another_hc.check_health.call_count == 3
631+
assert len(mock_hc.check_health.call_args_list) == 12
632+
assert len(another_hc.check_health.call_args_list) == 3
609633

610634
@pytest.mark.asyncio
611635
@pytest.mark.parametrize(
@@ -632,23 +656,25 @@ async def test_set_active_database(
632656

633657
mock_hc.check_health.return_value = True
634658

635-
client = MultiDBClient(mock_multi_db_config)
636-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
637-
assert await client.set("key", "value") == "OK1"
638-
assert mock_hc.check_health.call_count == 9
659+
async with MultiDBClient(mock_multi_db_config) as client:
660+
assert (
661+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
662+
)
663+
assert await client.set("key", "value") == "OK1"
664+
assert len(mock_hc.check_health.call_args_list) >= 9
639665

640-
await client.set_active_database(mock_db)
641-
assert await client.set("key", "value") == "OK"
666+
await client.set_active_database(mock_db)
667+
assert await client.set("key", "value") == "OK"
642668

643-
with pytest.raises(
644-
ValueError, match="Given database is not a member of database list"
645-
):
646-
await client.set_active_database(Mock(spec=AsyncDatabase))
669+
with pytest.raises(
670+
ValueError, match="Given database is not a member of database list"
671+
):
672+
await client.set_active_database(Mock(spec=AsyncDatabase))
647673

648-
mock_hc.check_health.return_value = False
674+
mock_hc.check_health.return_value = False
649675

650-
with pytest.raises(
651-
NoValidDatabaseException,
652-
match="Cannot set active database, database is unhealthy",
653-
):
654-
await client.set_active_database(mock_db1)
676+
with pytest.raises(
677+
NoValidDatabaseException,
678+
match="Cannot set active database, database is unhealthy",
679+
):
680+
await client.set_active_database(mock_db1)

0 commit comments

Comments
 (0)