Skip to content

Commit 540171a

Browse files
fix connection pool for SSL (#129)
When an Async Redis SSLConnection (ACRE for example) is used, the method to check modules was failing. The module check has to be done with a sync connection because it is a decorator around certain `SearchIndex` methods.
1 parent ce0637e commit 540171a

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

redisvl/redis/connection.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
import os
2-
from typing import Any, Dict, List, Optional
2+
from typing import Any, Dict, List, Optional, Type
33

4-
from redis import ConnectionPool, Redis
4+
from redis import Redis
55
from redis.asyncio import Redis as AsyncRedis
6+
from redis.asyncio import SSLConnection as ASSLConnection
7+
from redis.connection import (
8+
AbstractConnection,
9+
Connection,
10+
ConnectionPool,
11+
SSLConnection,
12+
)
613

714
from redisvl.redis.constants import REDIS_REQUIRED_MODULES
815
from redisvl.redis.utils import convert_bytes
@@ -130,8 +137,18 @@ def validate_async_redis_modules(
130137
Raises:
131138
ValueError: If required Redis modules are not installed.
132139
"""
140+
# pick the right connection class
141+
connection_class: Type[AbstractConnection] = (
142+
SSLConnection
143+
if client.connection_pool.connection_class == ASSLConnection
144+
else Connection
145+
)
146+
# set up a temp sync client
133147
temp_client = Redis(
134-
connection_pool=ConnectionPool(**client.connection_pool.connection_kwargs)
148+
connection_pool=ConnectionPool(
149+
connection_class=connection_class,
150+
**client.connection_pool.connection_kwargs,
151+
)
135152
)
136153
RedisConnectionFactory.validate_redis_modules(
137154
temp_client, redis_required_modules

tests/integration/test_connection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,8 @@ def test_unknown_redis():
5151

5252
def test_required_modules(client):
5353
RedisConnectionFactory.validate_redis_modules(client)
54-
RedisConnectionFactory.validate_async_redis_modules(client)
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_async_required_modules(async_client):
58+
RedisConnectionFactory.validate_async_redis_modules(async_client)

0 commit comments

Comments
 (0)