@@ -50,15 +50,17 @@ async def test_executes_pipeline_against_correct_db(
5050
5151 mock_hc .check_health .return_value = True
5252
53- client = MultiDBClient (mock_multi_db_config )
54- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
53+ async with MultiDBClient (mock_multi_db_config ) as client :
54+ assert (
55+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
56+ )
5557
56- pipe = client .pipeline ()
57- pipe .set ("key1" , "value1" )
58- pipe .get ("key1" )
58+ pipe = client .pipeline ()
59+ pipe .set ("key1" , "value1" )
60+ pipe .get ("key1" )
5961
60- assert await pipe .execute () == ["OK1" , "value1" ]
61- assert mock_hc .check_health .call_count = = 9
62+ assert await pipe .execute () == ["OK1" , "value1" ]
63+ assert len ( mock_hc .check_health .call_args_list ) > = 9
6264
6365 @pytest .mark .asyncio
6466 @pytest .mark .parametrize (
@@ -88,29 +90,28 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit(
8890 pipe .execute .return_value = ["OK1" , "value1" ]
8991 mock_db1 .client .pipeline .return_value = pipe
9092
91- mock_hc .check_health .side_effect = [
92- False ,
93- True ,
94- True ,
95- True ,
96- True ,
97- True ,
98- True ,
99- ]
93+ async def mock_check_health (database ):
94+ if database == mock_db2 :
95+ return False
96+ else :
97+ return True
10098
101- client = MultiDBClient (mock_multi_db_config )
102- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
99+ mock_hc .check_health .side_effect = mock_check_health
100+ async with MultiDBClient (mock_multi_db_config ) as client :
101+ assert (
102+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
103+ )
103104
104- async with client .pipeline () as pipe :
105- pipe .set ("key1" , "value1" )
106- pipe .get ("key1" )
105+ async with client .pipeline () as pipe :
106+ pipe .set ("key1" , "value1" )
107+ pipe .get ("key1" )
107108
108- assert await pipe .execute () == ["OK1" , "value1" ]
109- assert mock_hc .check_health .call_count = = 7
109+ assert await pipe .execute () == ["OK1" , "value1" ]
110+ assert len ( mock_hc .check_health .call_args_list ) > = 7
110111
111- assert mock_db .circuit .state == CBState .CLOSED
112- assert mock_db1 .circuit .state == CBState .CLOSED
113- assert mock_db2 .circuit .state == CBState .OPEN
112+ assert mock_db .circuit .state == CBState .CLOSED
113+ assert mock_db1 .circuit .state == CBState .CLOSED
114+ assert mock_db2 .circuit .state == CBState .OPEN
114115
115116 @pytest .mark .asyncio
116117 @pytest .mark .parametrize (
@@ -291,15 +292,19 @@ async def test_executes_transaction_against_correct_db(
291292
292293 mock_hc .check_health .return_value = True
293294
294- client = MultiDBClient (mock_multi_db_config )
295- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
295+ async with MultiDBClient (mock_multi_db_config ) as client :
296+ assert (
297+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
298+ )
296299
297- async def callback (pipe : Pipeline ):
298- pipe .set ("key1" , "value1" )
299- pipe .get ("key1" )
300+ async def callback (pipe : Pipeline ):
301+ pipe .set ("key1" , "value1" )
302+ pipe .get ("key1" )
300303
301- assert await client .transaction (callback ) == ["OK1" , "value1" ]
302- assert mock_hc .check_health .call_count == 9
304+ assert await client .transaction (callback ) == ["OK1" , "value1" ]
305+ # if we assume at least 3 health checks have run per each database
306+ # we should have at least 9 total calls
307+ assert len (mock_hc .check_health .call_args_list ) >= 9
303308
304309 @pytest .mark .asyncio
305310 @pytest .mark .parametrize (
@@ -327,29 +332,29 @@ async def test_execute_transaction_against_correct_db_and_closed_circuit(
327332 ):
328333 mock_db1 .client .transaction .return_value = ["OK1" , "value1" ]
329334
330- mock_hc .check_health .side_effect = [
331- False ,
332- True ,
333- True ,
334- True ,
335- True ,
336- True ,
337- True ,
338- ]
335+ async def mock_check_health (database ):
336+ if database == mock_db2 :
337+ return False
338+ else :
339+ return True
339340
340- client = MultiDBClient (mock_multi_db_config )
341- assert mock_multi_db_config .failover_strategy .set_databases .call_count == 1
341+ mock_hc .check_health .side_effect = mock_check_health
342342
343- async def callback (pipe : Pipeline ):
344- pipe .set ("key1" , "value1" )
345- pipe .get ("key1" )
343+ async with MultiDBClient (mock_multi_db_config ) as client :
344+ assert (
345+ mock_multi_db_config .failover_strategy .set_databases .call_count == 1
346+ )
346347
347- assert await client .transaction (callback ) == ["OK1" , "value1" ]
348- assert mock_hc .check_health .call_count == 7
348+ async def callback (pipe : Pipeline ):
349+ pipe .set ("key1" , "value1" )
350+ pipe .get ("key1" )
349351
350- assert mock_db .circuit .state == CBState .CLOSED
351- assert mock_db1 .circuit .state == CBState .CLOSED
352- assert mock_db2 .circuit .state == CBState .OPEN
352+ assert await client .transaction (callback ) == ["OK1" , "value1" ]
353+ assert len (mock_hc .check_health .call_args_list ) >= 7
354+
355+ assert mock_db .circuit .state == CBState .CLOSED
356+ assert mock_db1 .circuit .state == CBState .CLOSED
357+ assert mock_db2 .circuit .state == CBState .OPEN
353358
354359 @pytest .mark .asyncio
355360 @pytest .mark .parametrize (
@@ -455,34 +460,34 @@ async def mock_check_health(database):
455460 mock_multi_db_config .health_check_interval = 0.1
456461 mock_multi_db_config .failover_strategy = WeightBasedFailoverStrategy ()
457462
458- client = MultiDBClient (mock_multi_db_config )
463+ async with MultiDBClient (mock_multi_db_config ) as client :
459464
460- async def callback (pipe : Pipeline ):
461- pipe .set ("key1" , "value1" )
462- pipe .get ("key1" )
465+ async def callback (pipe : Pipeline ):
466+ pipe .set ("key1" , "value1" )
467+ pipe .get ("key1" )
463468
464- assert await client .transaction (callback ) == ["OK1" , "value" ]
469+ assert await client .transaction (callback ) == ["OK1" , "value" ]
465470
466- # Wait for mock_db1 to become unhealthy
467- assert await db1_became_unhealthy .wait (), (
468- "Timeout waiting for mock_db1 to become unhealthy"
469- )
470- await asyncio .sleep (0.01 )
471+ # Wait for mock_db1 to become unhealthy
472+ assert await db1_became_unhealthy .wait (), (
473+ "Timeout waiting for mock_db1 to become unhealthy"
474+ )
475+ await asyncio .sleep (0.01 )
471476
472- assert await client .transaction (callback ) == ["OK2" , "value" ]
477+ assert await client .transaction (callback ) == ["OK2" , "value" ]
473478
474- # Wait for mock_db2 to become unhealthy
475- assert await db2_became_unhealthy .wait (), (
476- "Timeout waiting for mock_db1 to become unhealthy"
477- )
478- await asyncio .sleep (0.01 )
479+ # Wait for mock_db2 to become unhealthy
480+ assert await db2_became_unhealthy .wait (), (
481+ "Timeout waiting for mock_db1 to become unhealthy"
482+ )
483+ await asyncio .sleep (0.01 )
479484
480- assert await client .transaction (callback ) == ["OK" , "value" ]
485+ assert await client .transaction (callback ) == ["OK" , "value" ]
481486
482- # Wait for mock_db to become unhealthy
483- assert await db_became_unhealthy .wait (), (
484- "Timeout waiting for mock_db1 to become unhealthy"
485- )
486- await asyncio .sleep (0.01 )
487+ # Wait for mock_db to become unhealthy
488+ assert await db_became_unhealthy .wait (), (
489+ "Timeout waiting for mock_db1 to become unhealthy"
490+ )
491+ await asyncio .sleep (0.01 )
487492
488- assert await client .transaction (callback ) == ["OK1" , "value" ]
493+ assert await client .transaction (callback ) == ["OK1" , "value" ]
0 commit comments