Skip to content

Commit 5a8a46b

Browse files
committed
Fixing flaky tests - part 2
1 parent 6bbc5c7 commit 5a8a46b

File tree

8 files changed

+585
-458
lines changed

8 files changed

+585
-458
lines changed

dev_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pytest==8.3.4 ; platform_python_implementation == "PyPy"
1212
pytest-asyncio>=0.23.0
1313
pytest-asyncio==1.1.0 ; platform_python_implementation == "PyPy"
1414
pytest-cov
15+
coverage<7.11.1
1516
pytest-cov==6.0.0 ; platform_python_implementation == "PyPy"
1617
coverage==7.6.12 ; platform_python_implementation == "PyPy"
1718
pytest-profiling==1.8.1

redis/multidb/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,13 @@ def _on_circuit_state_change_callback(
301301
)
302302

303303
def close(self):
304-
self.command_executor.active_database.client.close()
304+
"""
305+
Closes the client and all its resources.
306+
"""
307+
if self._bg_scheduler:
308+
self._bg_scheduler.stop()
309+
if self.command_executor.active_database:
310+
self.command_executor.active_database.client.close()
305311

306312

307313
def _half_open_circuit(circuit: CircuitBreaker):

tests/test_asyncio/test_multidb/test_client.py

Lines changed: 143 additions & 117 deletions
Large diffs are not rendered by default.

tests/test_asyncio/test_multidb/test_pipeline.py

Lines changed: 80 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,17 @@ async def test_executes_pipeline_against_correct_db(
5050

5151
mock_hc.check_health.return_value = True
5252

53-
client = MultiDBClient(mock_multi_db_config)
54-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
53+
async with MultiDBClient(mock_multi_db_config) as client:
54+
assert (
55+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
56+
)
5557

56-
pipe = client.pipeline()
57-
pipe.set("key1", "value1")
58-
pipe.get("key1")
58+
pipe = client.pipeline()
59+
pipe.set("key1", "value1")
60+
pipe.get("key1")
5961

60-
assert await pipe.execute() == ["OK1", "value1"]
61-
assert mock_hc.check_health.call_count == 9
62+
assert await pipe.execute() == ["OK1", "value1"]
63+
assert len(mock_hc.check_health.call_args_list) >= 9
6264

6365
@pytest.mark.asyncio
6466
@pytest.mark.parametrize(
@@ -88,29 +90,28 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit(
8890
pipe.execute.return_value = ["OK1", "value1"]
8991
mock_db1.client.pipeline.return_value = pipe
9092

91-
mock_hc.check_health.side_effect = [
92-
False,
93-
True,
94-
True,
95-
True,
96-
True,
97-
True,
98-
True,
99-
]
93+
async def mock_check_health(database):
94+
if database == mock_db2:
95+
return False
96+
else:
97+
return True
10098

101-
client = MultiDBClient(mock_multi_db_config)
102-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
99+
mock_hc.check_health.side_effect = mock_check_health
100+
async with MultiDBClient(mock_multi_db_config) as client:
101+
assert (
102+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
103+
)
103104

104-
async with client.pipeline() as pipe:
105-
pipe.set("key1", "value1")
106-
pipe.get("key1")
105+
async with client.pipeline() as pipe:
106+
pipe.set("key1", "value1")
107+
pipe.get("key1")
107108

108-
assert await pipe.execute() == ["OK1", "value1"]
109-
assert mock_hc.check_health.call_count == 7
109+
assert await pipe.execute() == ["OK1", "value1"]
110+
assert len(mock_hc.check_health.call_args_list) >= 7
110111

111-
assert mock_db.circuit.state == CBState.CLOSED
112-
assert mock_db1.circuit.state == CBState.CLOSED
113-
assert mock_db2.circuit.state == CBState.OPEN
112+
assert mock_db.circuit.state == CBState.CLOSED
113+
assert mock_db1.circuit.state == CBState.CLOSED
114+
assert mock_db2.circuit.state == CBState.OPEN
114115

115116
@pytest.mark.asyncio
116117
@pytest.mark.parametrize(
@@ -291,15 +292,19 @@ async def test_executes_transaction_against_correct_db(
291292

292293
mock_hc.check_health.return_value = True
293294

294-
client = MultiDBClient(mock_multi_db_config)
295-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
295+
async with MultiDBClient(mock_multi_db_config) as client:
296+
assert (
297+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
298+
)
296299

297-
async def callback(pipe: Pipeline):
298-
pipe.set("key1", "value1")
299-
pipe.get("key1")
300+
async def callback(pipe: Pipeline):
301+
pipe.set("key1", "value1")
302+
pipe.get("key1")
300303

301-
assert await client.transaction(callback) == ["OK1", "value1"]
302-
assert mock_hc.check_health.call_count == 9
304+
assert await client.transaction(callback) == ["OK1", "value1"]
305+
# if we assume at least 3 health checks have run per each database
306+
# we should have at least 9 total calls
307+
assert len(mock_hc.check_health.call_args_list) >= 9
303308

304309
@pytest.mark.asyncio
305310
@pytest.mark.parametrize(
@@ -327,29 +332,29 @@ async def test_execute_transaction_against_correct_db_and_closed_circuit(
327332
):
328333
mock_db1.client.transaction.return_value = ["OK1", "value1"]
329334

330-
mock_hc.check_health.side_effect = [
331-
False,
332-
True,
333-
True,
334-
True,
335-
True,
336-
True,
337-
True,
338-
]
335+
async def mock_check_health(database):
336+
if database == mock_db2:
337+
return False
338+
else:
339+
return True
339340

340-
client = MultiDBClient(mock_multi_db_config)
341-
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
341+
mock_hc.check_health.side_effect = mock_check_health
342342

343-
async def callback(pipe: Pipeline):
344-
pipe.set("key1", "value1")
345-
pipe.get("key1")
343+
async with MultiDBClient(mock_multi_db_config) as client:
344+
assert (
345+
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
346+
)
346347

347-
assert await client.transaction(callback) == ["OK1", "value1"]
348-
assert mock_hc.check_health.call_count == 7
348+
async def callback(pipe: Pipeline):
349+
pipe.set("key1", "value1")
350+
pipe.get("key1")
349351

350-
assert mock_db.circuit.state == CBState.CLOSED
351-
assert mock_db1.circuit.state == CBState.CLOSED
352-
assert mock_db2.circuit.state == CBState.OPEN
352+
assert await client.transaction(callback) == ["OK1", "value1"]
353+
assert len(mock_hc.check_health.call_args_list) >= 7
354+
355+
assert mock_db.circuit.state == CBState.CLOSED
356+
assert mock_db1.circuit.state == CBState.CLOSED
357+
assert mock_db2.circuit.state == CBState.OPEN
353358

354359
@pytest.mark.asyncio
355360
@pytest.mark.parametrize(
@@ -455,34 +460,34 @@ async def mock_check_health(database):
455460
mock_multi_db_config.health_check_interval = 0.1
456461
mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy()
457462

458-
client = MultiDBClient(mock_multi_db_config)
463+
async with MultiDBClient(mock_multi_db_config) as client:
459464

460-
async def callback(pipe: Pipeline):
461-
pipe.set("key1", "value1")
462-
pipe.get("key1")
465+
async def callback(pipe: Pipeline):
466+
pipe.set("key1", "value1")
467+
pipe.get("key1")
463468

464-
assert await client.transaction(callback) == ["OK1", "value"]
469+
assert await client.transaction(callback) == ["OK1", "value"]
465470

466-
# Wait for mock_db1 to become unhealthy
467-
assert await db1_became_unhealthy.wait(), (
468-
"Timeout waiting for mock_db1 to become unhealthy"
469-
)
470-
await asyncio.sleep(0.01)
471+
# Wait for mock_db1 to become unhealthy
472+
assert await db1_became_unhealthy.wait(), (
473+
"Timeout waiting for mock_db1 to become unhealthy"
474+
)
475+
await asyncio.sleep(0.01)
471476

472-
assert await client.transaction(callback) == ["OK2", "value"]
477+
assert await client.transaction(callback) == ["OK2", "value"]
473478

474-
# Wait for mock_db2 to become unhealthy
475-
assert await db2_became_unhealthy.wait(), (
476-
"Timeout waiting for mock_db1 to become unhealthy"
477-
)
478-
await asyncio.sleep(0.01)
479+
# Wait for mock_db2 to become unhealthy
480+
assert await db2_became_unhealthy.wait(), (
481+
"Timeout waiting for mock_db1 to become unhealthy"
482+
)
483+
await asyncio.sleep(0.01)
479484

480-
assert await client.transaction(callback) == ["OK", "value"]
485+
assert await client.transaction(callback) == ["OK", "value"]
481486

482-
# Wait for mock_db to become unhealthy
483-
assert await db_became_unhealthy.wait(), (
484-
"Timeout waiting for mock_db1 to become unhealthy"
485-
)
486-
await asyncio.sleep(0.01)
487+
# Wait for mock_db to become unhealthy
488+
assert await db_became_unhealthy.wait(), (
489+
"Timeout waiting for mock_db1 to become unhealthy"
490+
)
491+
await asyncio.sleep(0.01)
487492

488-
assert await client.transaction(callback) == ["OK1", "value"]
493+
assert await client.transaction(callback) == ["OK1", "value"]

tests/test_auth/test_token_manager.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,13 @@ def test_token_renewal_with_skip_initial(self):
174174
mock_provider.request_token.side_effect = [
175175
SimpleToken(
176176
"value",
177-
(datetime.now(timezone.utc).timestamp() * 1000) + 50,
177+
(datetime.now(timezone.utc).timestamp() * 1000) + 1000,
178178
(datetime.now(timezone.utc).timestamp() * 1000),
179179
{"oid": "test"},
180180
),
181181
SimpleToken(
182182
"value",
183-
(datetime.now(timezone.utc).timestamp() * 1000) + 150,
183+
(datetime.now(timezone.utc).timestamp() * 1000) + 1500,
184184
(datetime.now(timezone.utc).timestamp() * 1000),
185185
{"oid": "test"},
186186
),
@@ -194,12 +194,12 @@ def on_next(token):
194194
mock_listener.on_next = on_next
195195

196196
retry_policy = RetryPolicy(3, 10)
197-
config = TokenManagerConfig(1, 0, 1000, retry_policy)
197+
config = TokenManagerConfig(0.5, 0, 1000, retry_policy)
198198
mgr = TokenManager(mock_provider, config)
199199
mgr.start(mock_listener, skip_initial=True)
200-
# Should be less than a 0.1, or it will be flacky due to
201-
# additional token renewal.
202-
sleep(0.1)
200+
assert len(tokens) == 0
201+
202+
sleep(0.6)
203203

204204
assert len(tokens) > 0
205205

@@ -210,19 +210,19 @@ async def test_async_token_renewal_with_skip_initial(self):
210210
mock_provider.request_token.side_effect = [
211211
SimpleToken(
212212
"value",
213-
(datetime.now(timezone.utc).timestamp() * 1000) + 100,
213+
(datetime.now(timezone.utc).timestamp() * 1000) + 1000,
214214
(datetime.now(timezone.utc).timestamp() * 1000),
215215
{"oid": "test"},
216216
),
217217
SimpleToken(
218218
"value",
219-
(datetime.now(timezone.utc).timestamp() * 1000) + 120,
219+
(datetime.now(timezone.utc).timestamp() * 1000) + 1200,
220220
(datetime.now(timezone.utc).timestamp() * 1000),
221221
{"oid": "test"},
222222
),
223223
SimpleToken(
224224
"value",
225-
(datetime.now(timezone.utc).timestamp() * 1000) + 140,
225+
(datetime.now(timezone.utc).timestamp() * 1000) + 1400,
226226
(datetime.now(timezone.utc).timestamp() * 1000),
227227
{"oid": "test"},
228228
),
@@ -236,13 +236,12 @@ async def on_next(token):
236236
mock_listener.on_next = on_next
237237

238238
retry_policy = RetryPolicy(3, 10)
239-
config = TokenManagerConfig(1, 0, 1000, retry_policy)
239+
config = TokenManagerConfig(0.5, 0, 1000, retry_policy)
240240
mgr = TokenManager(mock_provider, config)
241241
await mgr.start_async(mock_listener, skip_initial=True)
242-
# Should be less than a 0.1, or it will be flacky
243-
# due to additional token renewal.
244-
await asyncio.sleep(0.2)
242+
assert len(tokens) == 0
245243

244+
await asyncio.sleep(0.6)
246245
assert len(tokens) > 0
247246

248247
def test_success_token_renewal_with_retry(self):

tests/test_background.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def callback(arg1: str, arg2: int):
3939
],
4040
)
4141
def test_run_recurring(self, interval, timeout, call_count):
42-
execute_counter = 0
42+
execute_counter = []
4343
one = "arg1"
4444
two = 9999
4545

@@ -48,18 +48,18 @@ def callback(arg1: str, arg2: int):
4848
nonlocal one
4949
nonlocal two
5050

51-
execute_counter += 1
51+
execute_counter.append(1)
5252

5353
assert arg1 == one
5454
assert arg2 == two
5555

5656
scheduler = BackgroundScheduler()
5757
scheduler.run_recurring(interval, callback, one, two)
58-
assert execute_counter == 0
58+
assert len(execute_counter) == 0
5959

6060
sleep(timeout)
6161

62-
assert execute_counter == call_count
62+
assert len(execute_counter) == call_count
6363

6464
@pytest.mark.asyncio
6565
@pytest.mark.parametrize(
@@ -71,7 +71,7 @@ def callback(arg1: str, arg2: int):
7171
],
7272
)
7373
async def test_run_recurring_async(self, interval, timeout, call_count):
74-
execute_counter = 0
74+
execute_counter = []
7575
one = "arg1"
7676
two = 9999
7777

@@ -80,15 +80,15 @@ async def callback(arg1: str, arg2: int):
8080
nonlocal one
8181
nonlocal two
8282

83-
execute_counter += 1
83+
execute_counter.append(1)
8484

8585
assert arg1 == one
8686
assert arg2 == two
8787

8888
scheduler = BackgroundScheduler()
8989
await scheduler.run_recurring_async(interval, callback, one, two)
90-
assert execute_counter == 0
90+
assert len(execute_counter) == 0
9191

9292
await asyncio.sleep(timeout)
9393

94-
assert execute_counter == call_count
94+
assert len(execute_counter) == call_count

0 commit comments

Comments
 (0)