Skip to content

Commit 80a5e5a

Browse files
committed
Fixing flaky tests - part 2
1 parent 6bbc5c7 commit 80a5e5a

File tree

3 files changed

+98
-33
lines changed

3 files changed

+98
-33
lines changed

tests/test_asyncio/test_multidb/test_client.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import threading
23
from unittest.mock import patch, AsyncMock, Mock
34

45
import pybreaker
@@ -66,27 +67,57 @@ async def test_execute_command_against_correct_db_on_successful_initialization(
6667
async def test_execute_command_against_correct_db_and_closed_circuit(
6768
self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc
6869
):
70+
"""
71+
Validates that commands are executed against the correct
72+
database when one database becomes unhealthy during initialization.
73+
Ensures the client selects the highest-weighted
74+
healthy database (mock_db1) and executes commands against it
75+
with a CLOSED circuit.
76+
"""
6977
databases = create_weighted_list(mock_db, mock_db1, mock_db2)
7078
mock_multi_db_config.health_checks = [mock_hc]
7179

7280
with patch.object(mock_multi_db_config, "databases", return_value=databases):
7381
mock_db1.client.execute_command = AsyncMock(return_value="OK1")
7482

75-
mock_hc.check_health.side_effect = [
76-
False,
77-
True,
78-
True,
79-
True,
80-
True,
81-
True,
82-
True,
83-
]
83+
# Track health check runs across all databases
84+
health_check_run = 0
85+
run_lock = threading.Lock()
8486

87+
async def mock_check_health(database):
88+
nonlocal health_check_run
89+
90+
# Increment run counter for each health check call
91+
with run_lock:
92+
health_check_run += 1
93+
current_run = health_check_run
94+
95+
# Run 1 (health_check_run 1): mock_db unhealthy
96+
if current_run == 1:
97+
if database == mock_db:
98+
return False
99+
return True
100+
101+
# Run 2 (health_check_run 2-3): mock_db1 and mock_db2 healthy
102+
elif current_run <= 3:
103+
return True
104+
105+
# Run 3 (health_check_run 4-6): All databases healthy
106+
elif current_run <= 6:
107+
return True
108+
109+
# Run 4 (health_check_run 7): mock_db2 unhealthy
110+
else:
111+
if database == mock_db2:
112+
return False
113+
return True
114+
115+
mock_hc.check_health.side_effect = mock_check_health
85116
client = MultiDBClient(mock_multi_db_config)
86117
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
87118
result = await client.set("key", "value")
88119
assert result == "OK1"
89-
assert mock_hc.check_health.call_count == 7
120+
assert mock_hc.check_health.call_count >= 7
90121

91122
assert mock_db.circuit.state == CBState.CLOSED
92123
assert mock_db1.circuit.state == CBState.CLOSED

tests/test_auth/test_token_manager.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,13 @@ def test_token_renewal_with_skip_initial(self):
174174
mock_provider.request_token.side_effect = [
175175
SimpleToken(
176176
"value",
177-
(datetime.now(timezone.utc).timestamp() * 1000) + 50,
177+
(datetime.now(timezone.utc).timestamp() * 1000) + 1000,
178178
(datetime.now(timezone.utc).timestamp() * 1000),
179179
{"oid": "test"},
180180
),
181181
SimpleToken(
182182
"value",
183-
(datetime.now(timezone.utc).timestamp() * 1000) + 150,
183+
(datetime.now(timezone.utc).timestamp() * 1000) + 1500,
184184
(datetime.now(timezone.utc).timestamp() * 1000),
185185
{"oid": "test"},
186186
),
@@ -194,12 +194,12 @@ def on_next(token):
194194
mock_listener.on_next = on_next
195195

196196
retry_policy = RetryPolicy(3, 10)
197-
config = TokenManagerConfig(1, 0, 1000, retry_policy)
197+
config = TokenManagerConfig(0.5, 0, 1000, retry_policy)
198198
mgr = TokenManager(mock_provider, config)
199199
mgr.start(mock_listener, skip_initial=True)
200-
# Should be less than a 0.1, or it will be flacky due to
201-
# additional token renewal.
202-
sleep(0.1)
200+
assert len(tokens) == 0
201+
202+
sleep(0.5)
203203

204204
assert len(tokens) > 0
205205

@@ -210,19 +210,19 @@ async def test_async_token_renewal_with_skip_initial(self):
210210
mock_provider.request_token.side_effect = [
211211
SimpleToken(
212212
"value",
213-
(datetime.now(timezone.utc).timestamp() * 1000) + 100,
213+
(datetime.now(timezone.utc).timestamp() * 1000) + 1000,
214214
(datetime.now(timezone.utc).timestamp() * 1000),
215215
{"oid": "test"},
216216
),
217217
SimpleToken(
218218
"value",
219-
(datetime.now(timezone.utc).timestamp() * 1000) + 120,
219+
(datetime.now(timezone.utc).timestamp() * 1000) + 1200,
220220
(datetime.now(timezone.utc).timestamp() * 1000),
221221
{"oid": "test"},
222222
),
223223
SimpleToken(
224224
"value",
225-
(datetime.now(timezone.utc).timestamp() * 1000) + 140,
225+
(datetime.now(timezone.utc).timestamp() * 1000) + 1400,
226226
(datetime.now(timezone.utc).timestamp() * 1000),
227227
{"oid": "test"},
228228
),
@@ -236,13 +236,12 @@ async def on_next(token):
236236
mock_listener.on_next = on_next
237237

238238
retry_policy = RetryPolicy(3, 10)
239-
config = TokenManagerConfig(1, 0, 1000, retry_policy)
239+
config = TokenManagerConfig(0.5, 0, 1000, retry_policy)
240240
mgr = TokenManager(mock_provider, config)
241241
await mgr.start_async(mock_listener, skip_initial=True)
242-
# Should be less than a 0.1, or it will be flacky
243-
# due to additional token renewal.
244-
await asyncio.sleep(0.2)
242+
assert len(tokens) == 0
245243

244+
await asyncio.sleep(0.5)
246245
assert len(tokens) > 0
247246

248247
def test_success_token_renewal_with_retry(self):

tests/test_multidb/test_client.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,27 +65,62 @@ def test_execute_command_against_correct_db_on_successful_initialization(
6565
def test_execute_command_against_correct_db_and_closed_circuit(
6666
self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc
6767
):
68+
"""
69+
Validates that commands are executed against the correct
70+
database when one database becomes unhealthy during initialization.
71+
Ensures the client selects the highest-weighted
72+
healthy database (mock_db1) and executes commands against it
73+
with a CLOSED circuit.
74+
"""
6875
databases = create_weighted_list(mock_db, mock_db1, mock_db2)
6976
mock_multi_db_config.health_checks = [mock_hc]
7077

7178
with patch.object(mock_multi_db_config, "databases", return_value=databases):
7279
mock_db1.client.execute_command = MagicMock(return_value="OK1")
7380

74-
mock_hc.check_health.side_effect = [
75-
False,
76-
True,
77-
True,
78-
True,
79-
True,
80-
True,
81-
True,
82-
]
81+
# Track health check runs across all databases
82+
health_check_run = 0
83+
run_lock = threading.Lock()
84+
85+
def mock_check_health(database):
86+
nonlocal health_check_run
87+
88+
# Increment run counter for each health check call
89+
with run_lock:
90+
health_check_run += 1
91+
current_run = health_check_run
92+
93+
# Run 1 (health_check_run 1): mock_db unhealthy
94+
if current_run == 1:
95+
if database == mock_db:
96+
return False
97+
return True
98+
99+
# Run 2 (health_check_run 2-3): mock_db1 and mock_db2 healthy
100+
elif current_run <= 3:
101+
return True
102+
103+
# Run 3 (health_check_run 4-6): All databases healthy
104+
elif current_run <= 6:
105+
return True
106+
107+
# Run 4 (health_check_run 7): mock_db2 unhealthy
108+
elif current_run == 7:
109+
if database == mock_db2:
110+
return False
111+
return True
112+
113+
# Background health checks (calls 8+): All healthy
114+
else:
115+
return True
116+
117+
mock_hc.check_health.side_effect = mock_check_health
83118

84119
client = MultiDBClient(mock_multi_db_config)
85120
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
86121
result = client.set("key", "value")
87122
assert result == "OK1"
88-
assert mock_hc.check_health.call_count == 7
123+
assert mock_hc.check_health.call_count >= 7
89124

90125
assert mock_db.circuit.state == CBState.CLOSED
91126
assert mock_db1.circuit.state == CBState.CLOSED

0 commit comments

Comments
 (0)