Skip to content

Commit 06f2445

Browse files
committed
Fixing flaky tests - part 2
1 parent 6bbc5c7 commit 06f2445

File tree

3 files changed

+110
-33
lines changed

3 files changed

+110
-33
lines changed

tests/test_asyncio/test_multidb/test_client.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import threading
23
from unittest.mock import patch, AsyncMock, Mock
34

45
import pybreaker
@@ -66,27 +67,63 @@ async def test_execute_command_against_correct_db_on_successful_initialization(
6667
async def test_execute_command_against_correct_db_and_closed_circuit(
6768
self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc
6869
):
70+
"""
71+
Validates that commands are executed against the correct
72+
database when one database becomes unhealthy during initialization.
73+
Ensures the client selects the highest-weighted
74+
healthy database (mock_db1) and executes commands against it
75+
with a CLOSED circuit.
76+
"""
6977
databases = create_weighted_list(mock_db, mock_db1, mock_db2)
7078
mock_multi_db_config.health_checks = [mock_hc]
7179

7280
with patch.object(mock_multi_db_config, "databases", return_value=databases):
81+
mock_db.client.execute_command = AsyncMock(
82+
return_value="NOT_OK-->Response from unexpected db - mock_db"
83+
)
7384
mock_db1.client.execute_command = AsyncMock(return_value="OK1")
85+
mock_db2.client.execute_command = AsyncMock(
86+
return_value="NOT_OK-->Response from unexpected db - mock_db2"
87+
)
88+
89+
# Track health check runs across all databases
90+
health_check_run = 0
91+
run_lock = threading.Lock()
92+
93+
async def mock_check_health(database):
94+
nonlocal health_check_run
95+
96+
# Increment run counter for each health check call
97+
with run_lock:
98+
health_check_run += 1
99+
current_run = health_check_run
100+
101+
# Run 1 (health_check_run 1): mock_db unhealthy
102+
if current_run == 1:
103+
if database == mock_db:
104+
return False
105+
return True
106+
107+
# Run 2 (health_check_run 2-3): mock_db1 and mock_db2 healthy
108+
elif current_run <= 3:
109+
return True
110+
111+
# Run 3 (health_check_run 4-6): All databases healthy
112+
elif current_run <= 6:
113+
return True
74114

75-
mock_hc.check_health.side_effect = [
76-
False,
77-
True,
78-
True,
79-
True,
80-
True,
81-
True,
82-
True,
83-
]
115+
# Run 4 (health_check_run 7): mock_db2 unhealthy
116+
else:
117+
if database == mock_db2:
118+
return False
119+
return True
84120

121+
mock_hc.check_health.side_effect = mock_check_health
85122
client = MultiDBClient(mock_multi_db_config)
86123
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
87124
result = await client.set("key", "value")
88125
assert result == "OK1"
89-
assert mock_hc.check_health.call_count == 7
126+
assert mock_hc.check_health.call_count >= 7
90127

91128
assert mock_db.circuit.state == CBState.CLOSED
92129
assert mock_db1.circuit.state == CBState.CLOSED

tests/test_auth/test_token_manager.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,13 @@ def test_token_renewal_with_skip_initial(self):
174174
mock_provider.request_token.side_effect = [
175175
SimpleToken(
176176
"value",
177-
(datetime.now(timezone.utc).timestamp() * 1000) + 50,
177+
(datetime.now(timezone.utc).timestamp() * 1000) + 1000,
178178
(datetime.now(timezone.utc).timestamp() * 1000),
179179
{"oid": "test"},
180180
),
181181
SimpleToken(
182182
"value",
183-
(datetime.now(timezone.utc).timestamp() * 1000) + 150,
183+
(datetime.now(timezone.utc).timestamp() * 1000) + 1500,
184184
(datetime.now(timezone.utc).timestamp() * 1000),
185185
{"oid": "test"},
186186
),
@@ -194,12 +194,12 @@ def on_next(token):
194194
mock_listener.on_next = on_next
195195

196196
retry_policy = RetryPolicy(3, 10)
197-
config = TokenManagerConfig(1, 0, 1000, retry_policy)
197+
config = TokenManagerConfig(0.5, 0, 1000, retry_policy)
198198
mgr = TokenManager(mock_provider, config)
199199
mgr.start(mock_listener, skip_initial=True)
200-
# Should be less than a 0.1, or it will be flacky due to
201-
# additional token renewal.
202-
sleep(0.1)
200+
assert len(tokens) == 0
201+
202+
sleep(0.6)
203203

204204
assert len(tokens) > 0
205205

@@ -210,19 +210,19 @@ async def test_async_token_renewal_with_skip_initial(self):
210210
mock_provider.request_token.side_effect = [
211211
SimpleToken(
212212
"value",
213-
(datetime.now(timezone.utc).timestamp() * 1000) + 100,
213+
(datetime.now(timezone.utc).timestamp() * 1000) + 1000,
214214
(datetime.now(timezone.utc).timestamp() * 1000),
215215
{"oid": "test"},
216216
),
217217
SimpleToken(
218218
"value",
219-
(datetime.now(timezone.utc).timestamp() * 1000) + 120,
219+
(datetime.now(timezone.utc).timestamp() * 1000) + 1200,
220220
(datetime.now(timezone.utc).timestamp() * 1000),
221221
{"oid": "test"},
222222
),
223223
SimpleToken(
224224
"value",
225-
(datetime.now(timezone.utc).timestamp() * 1000) + 140,
225+
(datetime.now(timezone.utc).timestamp() * 1000) + 1400,
226226
(datetime.now(timezone.utc).timestamp() * 1000),
227227
{"oid": "test"},
228228
),
@@ -236,13 +236,12 @@ async def on_next(token):
236236
mock_listener.on_next = on_next
237237

238238
retry_policy = RetryPolicy(3, 10)
239-
config = TokenManagerConfig(1, 0, 1000, retry_policy)
239+
config = TokenManagerConfig(0.5, 0, 1000, retry_policy)
240240
mgr = TokenManager(mock_provider, config)
241241
await mgr.start_async(mock_listener, skip_initial=True)
242-
# Should be less than a 0.1, or it will be flacky
243-
# due to additional token renewal.
244-
await asyncio.sleep(0.2)
242+
assert len(tokens) == 0
245243

244+
await asyncio.sleep(0.6)
246245
assert len(tokens) > 0
247246

248247
def test_success_token_renewal_with_retry(self):

tests/test_multidb/test_client.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,27 +65,68 @@ def test_execute_command_against_correct_db_on_successful_initialization(
6565
def test_execute_command_against_correct_db_and_closed_circuit(
6666
self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc
6767
):
68+
"""
69+
Validates that commands are executed against the correct
70+
database when one database becomes unhealthy during initialization.
71+
Ensures the client selects the highest-weighted
72+
healthy database (mock_db1) and executes commands against it
73+
with a CLOSED circuit.
74+
"""
6875
databases = create_weighted_list(mock_db, mock_db1, mock_db2)
6976
mock_multi_db_config.health_checks = [mock_hc]
7077

7178
with patch.object(mock_multi_db_config, "databases", return_value=databases):
79+
mock_db.client.execute_command = MagicMock(
80+
return_value="NOT_OK-->Response from unexpected db - mock_db"
81+
)
7282
mock_db1.client.execute_command = MagicMock(return_value="OK1")
83+
mock_db2.client.execute_command = MagicMock(
84+
return_value="NOT_OK-->Response from unexpected db - mock_db2"
85+
)
86+
87+
# Track health check runs across all databases
88+
health_check_run = 0
89+
run_lock = threading.Lock()
90+
91+
def mock_check_health(database):
92+
nonlocal health_check_run
93+
94+
# Increment run counter for each health check call
95+
with run_lock:
96+
health_check_run += 1
97+
current_run = health_check_run
98+
99+
# Run 1 (health_check_run 1): mock_db unhealthy
100+
if current_run == 1:
101+
if database == mock_db:
102+
return False
103+
return True
104+
105+
# Run 2 (health_check_run 2-3): mock_db1 and mock_db2 healthy
106+
elif current_run <= 3:
107+
return True
108+
109+
# Run 3 (health_check_run 4-6): All databases healthy
110+
elif current_run <= 6:
111+
return True
112+
113+
# Run 4 (health_check_run 7): mock_db2 unhealthy
114+
elif current_run == 7:
115+
if database == mock_db2:
116+
return False
117+
return True
118+
119+
# Background health checks (calls 8+): All healthy
120+
else:
121+
return True
73122

74-
mock_hc.check_health.side_effect = [
75-
False,
76-
True,
77-
True,
78-
True,
79-
True,
80-
True,
81-
True,
82-
]
123+
mock_hc.check_health.side_effect = mock_check_health
83124

84125
client = MultiDBClient(mock_multi_db_config)
85126
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
86127
result = client.set("key", "value")
87128
assert result == "OK1"
88-
assert mock_hc.check_health.call_count == 7
129+
assert mock_hc.check_health.call_count >= 7
89130

90131
assert mock_db.circuit.state == CBState.CLOSED
91132
assert mock_db1.circuit.state == CBState.CLOSED

0 commit comments

Comments
 (0)