Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pytest==8.3.4 ; platform_python_implementation == "PyPy"
pytest-asyncio>=0.23.0
pytest-asyncio==1.1.0 ; platform_python_implementation == "PyPy"
pytest-cov
coverage<7.11.1
pytest-cov==6.0.0 ; platform_python_implementation == "PyPy"
coverage==7.6.12 ; platform_python_implementation == "PyPy"
pytest-profiling==1.8.1
Expand Down
8 changes: 7 additions & 1 deletion redis/multidb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,13 @@ def _on_circuit_state_change_callback(
)

def close(self):
self.command_executor.active_database.client.close()
"""
Closes the client and all its resources.
"""
if self._bg_scheduler:
self._bg_scheduler.stop()
if self.command_executor.active_database:
self.command_executor.active_database.client.close()


def _half_open_circuit(circuit: CircuitBreaker):
Expand Down
260 changes: 143 additions & 117 deletions tests/test_asyncio/test_multidb/test_client.py

Large diffs are not rendered by default.

155 changes: 80 additions & 75 deletions tests/test_asyncio/test_multidb/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,17 @@ async def test_executes_pipeline_against_correct_db(

mock_hc.check_health.return_value = True

client = MultiDBClient(mock_multi_db_config)
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
async with MultiDBClient(mock_multi_db_config) as client:
assert (
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
)

pipe = client.pipeline()
pipe.set("key1", "value1")
pipe.get("key1")
pipe = client.pipeline()
pipe.set("key1", "value1")
pipe.get("key1")

assert await pipe.execute() == ["OK1", "value1"]
assert mock_hc.check_health.call_count == 9
assert await pipe.execute() == ["OK1", "value1"]
assert len(mock_hc.check_health.call_args_list) >= 9

@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down Expand Up @@ -88,29 +90,28 @@ async def test_execute_pipeline_against_correct_db_and_closed_circuit(
pipe.execute.return_value = ["OK1", "value1"]
mock_db1.client.pipeline.return_value = pipe

mock_hc.check_health.side_effect = [
False,
True,
True,
True,
True,
True,
True,
]
async def mock_check_health(database):
if database == mock_db2:
return False
else:
return True

client = MultiDBClient(mock_multi_db_config)
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
mock_hc.check_health.side_effect = mock_check_health
async with MultiDBClient(mock_multi_db_config) as client:
assert (
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
)

async with client.pipeline() as pipe:
pipe.set("key1", "value1")
pipe.get("key1")
async with client.pipeline() as pipe:
pipe.set("key1", "value1")
pipe.get("key1")

assert await pipe.execute() == ["OK1", "value1"]
assert mock_hc.check_health.call_count == 7
assert await pipe.execute() == ["OK1", "value1"]
assert len(mock_hc.check_health.call_args_list) >= 7

assert mock_db.circuit.state == CBState.CLOSED
assert mock_db1.circuit.state == CBState.CLOSED
assert mock_db2.circuit.state == CBState.OPEN
assert mock_db.circuit.state == CBState.CLOSED
assert mock_db1.circuit.state == CBState.CLOSED
assert mock_db2.circuit.state == CBState.OPEN

@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down Expand Up @@ -291,15 +292,19 @@ async def test_executes_transaction_against_correct_db(

mock_hc.check_health.return_value = True

client = MultiDBClient(mock_multi_db_config)
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
async with MultiDBClient(mock_multi_db_config) as client:
assert (
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
)

async def callback(pipe: Pipeline):
pipe.set("key1", "value1")
pipe.get("key1")
async def callback(pipe: Pipeline):
pipe.set("key1", "value1")
pipe.get("key1")

assert await client.transaction(callback) == ["OK1", "value1"]
assert mock_hc.check_health.call_count == 9
assert await client.transaction(callback) == ["OK1", "value1"]
# if we assume at least 3 health checks have run per each database
# we should have at least 9 total calls
assert len(mock_hc.check_health.call_args_list) >= 9

@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down Expand Up @@ -327,29 +332,29 @@ async def test_execute_transaction_against_correct_db_and_closed_circuit(
):
mock_db1.client.transaction.return_value = ["OK1", "value1"]

mock_hc.check_health.side_effect = [
False,
True,
True,
True,
True,
True,
True,
]
async def mock_check_health(database):
if database == mock_db2:
return False
else:
return True

client = MultiDBClient(mock_multi_db_config)
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
mock_hc.check_health.side_effect = mock_check_health

async def callback(pipe: Pipeline):
pipe.set("key1", "value1")
pipe.get("key1")
async with MultiDBClient(mock_multi_db_config) as client:
assert (
mock_multi_db_config.failover_strategy.set_databases.call_count == 1
)

assert await client.transaction(callback) == ["OK1", "value1"]
assert mock_hc.check_health.call_count == 7
async def callback(pipe: Pipeline):
pipe.set("key1", "value1")
pipe.get("key1")

assert mock_db.circuit.state == CBState.CLOSED
assert mock_db1.circuit.state == CBState.CLOSED
assert mock_db2.circuit.state == CBState.OPEN
assert await client.transaction(callback) == ["OK1", "value1"]
assert len(mock_hc.check_health.call_args_list) >= 7

assert mock_db.circuit.state == CBState.CLOSED
assert mock_db1.circuit.state == CBState.CLOSED
assert mock_db2.circuit.state == CBState.OPEN

@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down Expand Up @@ -455,34 +460,34 @@ async def mock_check_health(database):
mock_multi_db_config.health_check_interval = 0.1
mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy()

client = MultiDBClient(mock_multi_db_config)
async with MultiDBClient(mock_multi_db_config) as client:

async def callback(pipe: Pipeline):
pipe.set("key1", "value1")
pipe.get("key1")
async def callback(pipe: Pipeline):
pipe.set("key1", "value1")
pipe.get("key1")

assert await client.transaction(callback) == ["OK1", "value"]
assert await client.transaction(callback) == ["OK1", "value"]

# Wait for mock_db1 to become unhealthy
assert await db1_became_unhealthy.wait(), (
"Timeout waiting for mock_db1 to become unhealthy"
)
await asyncio.sleep(0.01)
# Wait for mock_db1 to become unhealthy
assert await db1_became_unhealthy.wait(), (
"Timeout waiting for mock_db1 to become unhealthy"
)
await asyncio.sleep(0.01)

assert await client.transaction(callback) == ["OK2", "value"]
assert await client.transaction(callback) == ["OK2", "value"]

# Wait for mock_db2 to become unhealthy
assert await db2_became_unhealthy.wait(), (
"Timeout waiting for mock_db1 to become unhealthy"
)
await asyncio.sleep(0.01)
# Wait for mock_db2 to become unhealthy
assert await db2_became_unhealthy.wait(), (
"Timeout waiting for mock_db1 to become unhealthy"
)
await asyncio.sleep(0.01)

assert await client.transaction(callback) == ["OK", "value"]
assert await client.transaction(callback) == ["OK", "value"]

# Wait for mock_db to become unhealthy
assert await db_became_unhealthy.wait(), (
"Timeout waiting for mock_db1 to become unhealthy"
)
await asyncio.sleep(0.01)
# Wait for mock_db to become unhealthy
assert await db_became_unhealthy.wait(), (
"Timeout waiting for mock_db1 to become unhealthy"
)
await asyncio.sleep(0.01)

assert await client.transaction(callback) == ["OK1", "value"]
assert await client.transaction(callback) == ["OK1", "value"]
25 changes: 12 additions & 13 deletions tests/test_auth/test_token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ def test_token_renewal_with_skip_initial(self):
mock_provider.request_token.side_effect = [
SimpleToken(
"value",
(datetime.now(timezone.utc).timestamp() * 1000) + 50,
(datetime.now(timezone.utc).timestamp() * 1000) + 1000,
(datetime.now(timezone.utc).timestamp() * 1000),
{"oid": "test"},
),
SimpleToken(
"value",
(datetime.now(timezone.utc).timestamp() * 1000) + 150,
(datetime.now(timezone.utc).timestamp() * 1000) + 1500,
(datetime.now(timezone.utc).timestamp() * 1000),
{"oid": "test"},
),
Expand All @@ -194,12 +194,12 @@ def on_next(token):
mock_listener.on_next = on_next

retry_policy = RetryPolicy(3, 10)
config = TokenManagerConfig(1, 0, 1000, retry_policy)
config = TokenManagerConfig(0.5, 0, 1000, retry_policy)
mgr = TokenManager(mock_provider, config)
mgr.start(mock_listener, skip_initial=True)
# Should be less than a 0.1, or it will be flacky due to
# additional token renewal.
sleep(0.1)
assert len(tokens) == 0

sleep(0.6)

assert len(tokens) > 0

Expand All @@ -210,19 +210,19 @@ async def test_async_token_renewal_with_skip_initial(self):
mock_provider.request_token.side_effect = [
SimpleToken(
"value",
(datetime.now(timezone.utc).timestamp() * 1000) + 100,
(datetime.now(timezone.utc).timestamp() * 1000) + 1000,
(datetime.now(timezone.utc).timestamp() * 1000),
{"oid": "test"},
),
SimpleToken(
"value",
(datetime.now(timezone.utc).timestamp() * 1000) + 120,
(datetime.now(timezone.utc).timestamp() * 1000) + 1200,
(datetime.now(timezone.utc).timestamp() * 1000),
{"oid": "test"},
),
SimpleToken(
"value",
(datetime.now(timezone.utc).timestamp() * 1000) + 140,
(datetime.now(timezone.utc).timestamp() * 1000) + 1400,
(datetime.now(timezone.utc).timestamp() * 1000),
{"oid": "test"},
),
Expand All @@ -236,13 +236,12 @@ async def on_next(token):
mock_listener.on_next = on_next

retry_policy = RetryPolicy(3, 10)
config = TokenManagerConfig(1, 0, 1000, retry_policy)
config = TokenManagerConfig(0.5, 0, 1000, retry_policy)
mgr = TokenManager(mock_provider, config)
await mgr.start_async(mock_listener, skip_initial=True)
# Should be less than a 0.1, or it will be flacky
# due to additional token renewal.
await asyncio.sleep(0.2)
assert len(tokens) == 0

await asyncio.sleep(0.6)
assert len(tokens) > 0

def test_success_token_renewal_with_retry(self):
Expand Down
16 changes: 8 additions & 8 deletions tests/test_background.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def callback(arg1: str, arg2: int):
],
)
def test_run_recurring(self, interval, timeout, call_count):
execute_counter = 0
execute_counter = []
one = "arg1"
two = 9999

Expand All @@ -48,18 +48,18 @@ def callback(arg1: str, arg2: int):
nonlocal one
nonlocal two

execute_counter += 1
execute_counter.append(1)

assert arg1 == one
assert arg2 == two

scheduler = BackgroundScheduler()
scheduler.run_recurring(interval, callback, one, two)
assert execute_counter == 0
assert len(execute_counter) == 0

sleep(timeout)

assert execute_counter == call_count
assert len(execute_counter) == call_count

@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand All @@ -71,7 +71,7 @@ def callback(arg1: str, arg2: int):
],
)
async def test_run_recurring_async(self, interval, timeout, call_count):
execute_counter = 0
execute_counter = []
one = "arg1"
two = 9999

Expand All @@ -80,15 +80,15 @@ async def callback(arg1: str, arg2: int):
nonlocal one
nonlocal two

execute_counter += 1
execute_counter.append(1)

assert arg1 == one
assert arg2 == two

scheduler = BackgroundScheduler()
await scheduler.run_recurring_async(interval, callback, one, two)
assert execute_counter == 0
assert len(execute_counter) == 0

await asyncio.sleep(timeout)

assert execute_counter == call_count
assert len(execute_counter) == call_count
Loading
Loading