Skip to content

Commit 23dab1a

Browse files
committed
feat: implement EntraID auth for Azure with Service Principal, Managed Identity and Default Credential support
1 parent 51b91d4 commit 23dab1a

File tree

6 files changed

+539
-59
lines changed

6 files changed

+539
-59
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ dependencies = [
2626
"dotenv>=0.9.9",
2727
"numpy>=2.2.4",
2828
"click>=8.0.0",
29+
"redis-entraid>=1.0.0",
2930
]
3031

3132
[project.scripts]

src/common/config.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55

66
load_dotenv()
77

8+
# Default values for Entra ID authentication
9+
DEFAULT_TOKEN_EXPIRATION_REFRESH_RATIO = 0.9
10+
DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 30000 # 30 seconds
11+
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_MS = 10000 # 10 seconds
12+
DEFAULT_RETRY_MAX_ATTEMPTS = 3
13+
DEFAULT_RETRY_DELAY_MS = 100
14+
815
REDIS_CFG = {
916
"host": os.getenv("REDIS_HOST", "127.0.0.1"),
1017
"port": int(os.getenv("REDIS_PORT", 6379)),
@@ -20,6 +27,55 @@
2027
"db": int(os.getenv("REDIS_DB", 0)),
2128
}
2229

30+
# Entra ID Authentication Configuration
31+
ENTRAID_CFG = {
32+
# Authentication flow selection
33+
"auth_flow": os.getenv(
34+
"REDIS_ENTRAID_AUTH_FLOW", None
35+
), # service_principal, managed_identity, default_credential
36+
# Service Principal Authentication
37+
"client_id": os.getenv("REDIS_ENTRAID_CLIENT_ID", None),
38+
"client_secret": os.getenv("REDIS_ENTRAID_CLIENT_SECRET", None),
39+
"tenant_id": os.getenv("REDIS_ENTRAID_TENANT_ID", None),
40+
# Managed Identity Authentication
41+
"identity_type": os.getenv(
42+
"REDIS_ENTRAID_IDENTITY_TYPE", "system_assigned"
43+
), # system_assigned, user_assigned
44+
"user_assigned_identity_client_id": os.getenv(
45+
"REDIS_ENTRAID_USER_ASSIGNED_CLIENT_ID", None
46+
),
47+
# Default Azure Credential Authentication
48+
"scopes": os.getenv("REDIS_ENTRAID_SCOPES", "https://redis.azure.com/.default"),
49+
# Token lifecycle configuration
50+
"token_expiration_refresh_ratio": float(
51+
os.getenv(
52+
"REDIS_ENTRAID_TOKEN_EXPIRATION_REFRESH_RATIO",
53+
DEFAULT_TOKEN_EXPIRATION_REFRESH_RATIO,
54+
)
55+
),
56+
"lower_refresh_bound_millis": int(
57+
os.getenv(
58+
"REDIS_ENTRAID_LOWER_REFRESH_BOUND_MILLIS",
59+
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
60+
)
61+
),
62+
"token_request_execution_timeout_ms": int(
63+
os.getenv(
64+
"REDIS_ENTRAID_TOKEN_REQUEST_EXECUTION_TIMEOUT_MS",
65+
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_MS,
66+
)
67+
),
68+
# Retry configuration
69+
"retry_max_attempts": int(
70+
os.getenv("REDIS_ENTRAID_RETRY_MAX_ATTEMPTS", DEFAULT_RETRY_MAX_ATTEMPTS)
71+
),
72+
"retry_delay_ms": int(
73+
os.getenv("REDIS_ENTRAID_RETRY_DELAY_MS", DEFAULT_RETRY_DELAY_MS)
74+
),
75+
# Resource configuration
76+
"resource": os.getenv("REDIS_ENTRAID_RESOURCE", "https://redis.azure.com/"),
77+
}
78+
2379

2480
def parse_redis_uri(uri: str) -> dict:
2581
"""Parse a Redis URI and return connection parameters."""
@@ -99,3 +155,77 @@ def set_redis_config_from_cli(config: dict):
99155
else:
100156
# Convert other values to strings
101157
REDIS_CFG[key] = str(value) if value is not None else None
158+
159+
160+
def set_entraid_config_from_cli(config: dict):
161+
"""Update Entra ID configuration from CLI parameters."""
162+
for key, value in config.items():
163+
if value is not None:
164+
if key in ["token_expiration_refresh_ratio"]:
165+
# Keep float values as floats
166+
ENTRAID_CFG[key] = float(value)
167+
elif key in [
168+
"lower_refresh_bound_millis",
169+
"token_request_execution_timeout_ms",
170+
"retry_max_attempts",
171+
"retry_delay_ms",
172+
]:
173+
# Keep integer values as integers
174+
ENTRAID_CFG[key] = int(value)
175+
else:
176+
# Convert other values to strings
177+
ENTRAID_CFG[key] = str(value)
178+
179+
180+
def is_entraid_auth_enabled() -> bool:
181+
"""Check if Entra ID authentication is enabled."""
182+
return ENTRAID_CFG["auth_flow"] is not None
183+
184+
185+
def get_entraid_auth_flow() -> str:
186+
"""Get the configured Entra ID authentication flow."""
187+
return ENTRAID_CFG["auth_flow"]
188+
189+
190+
def validate_entraid_config() -> tuple[bool, str]:
191+
"""Validate Entra ID configuration based on the selected auth flow.
192+
193+
Returns:
194+
tuple: (is_valid, error_message)
195+
"""
196+
auth_flow = ENTRAID_CFG["auth_flow"]
197+
198+
if not auth_flow:
199+
return True, "" # No Entra ID auth configured, which is valid
200+
201+
if auth_flow == "service_principal":
202+
required_fields = ["client_id", "client_secret", "tenant_id"]
203+
missing_fields = [field for field in required_fields if not ENTRAID_CFG[field]]
204+
if missing_fields:
205+
return (
206+
False,
207+
f"Service principal authentication requires: {', '.join(missing_fields)}",
208+
)
209+
210+
elif auth_flow == "managed_identity":
211+
identity_type = ENTRAID_CFG["identity_type"]
212+
if (
213+
identity_type == "user_assigned"
214+
and not ENTRAID_CFG["user_assigned_identity_client_id"]
215+
):
216+
return (
217+
False,
218+
"User-assigned managed identity requires user_assigned_identity_client_id",
219+
)
220+
221+
elif auth_flow == "default_credential":
222+
# Default credential doesn't require specific configuration
223+
pass
224+
225+
else:
226+
return (
227+
False,
228+
f"Invalid auth_flow: {auth_flow}. Must be one of: service_principal, managed_identity, default_credential",
229+
)
230+
231+
return True, ""

src/common/connection.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
from redis import Redis
66
from redis.cluster import RedisCluster
77

8-
from src.common.config import REDIS_CFG
8+
from src.common.config import REDIS_CFG, is_entraid_auth_enabled
9+
from src.common.entraid_auth import (
10+
create_credential_provider,
11+
EntraIDAuthenticationError,
12+
)
913
from src.version import __version__
1014

1115
_logger = logging.getLogger(__name__)
@@ -18,6 +22,17 @@ class RedisConnectionManager:
1822
def get_connection(cls, decode_responses=True) -> Redis:
1923
if cls._instance is None:
2024
try:
25+
# Create Entra ID credential provider if configured
26+
credential_provider = None
27+
if is_entraid_auth_enabled():
28+
try:
29+
credential_provider = create_credential_provider()
30+
except EntraIDAuthenticationError as e:
31+
_logger.error(
32+
"Failed to create Entra ID credential provider: %s", e
33+
)
34+
raise
35+
2136
if REDIS_CFG["cluster_mode"]:
2237
redis_class: Type[Union[Redis, RedisCluster]] = (
2338
redis.cluster.RedisCluster
@@ -37,6 +52,12 @@ def get_connection(cls, decode_responses=True) -> Redis:
3752
"lib_name": f"redis-py(mcp-server_v{__version__})",
3853
"max_connections_per_node": 10,
3954
}
55+
56+
# Add credential provider if available
57+
if credential_provider:
58+
connection_params["credential_provider"] = credential_provider
59+
# Note: Azure Redis Enterprise with EntraID uses plain text connections
60+
# SSL setting is controlled by REDIS_SSL environment variable
4061
else:
4162
redis_class: Type[Union[Redis, RedisCluster]] = redis.Redis
4263
connection_params = {
@@ -56,6 +77,12 @@ def get_connection(cls, decode_responses=True) -> Redis:
5677
"max_connections": 10,
5778
}
5879

80+
# Add credential provider if available
81+
if credential_provider:
82+
connection_params["credential_provider"] = credential_provider
83+
# Note: Azure Redis Enterprise with EntraID uses plain text connections
84+
# SSL setting is controlled by REDIS_SSL environment variable
85+
5986
cls._instance = redis_class(**connection_params)
6087

6188
except redis.exceptions.ConnectionError:

src/common/entraid_auth.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""
2+
Entra ID authentication provider factory for Redis MCP Server.
3+
4+
This module provides factory methods to create credential providers for different
5+
Azure authentication flows based on configuration.
6+
"""
7+
8+
import logging
9+
10+
from src.common.config import (
11+
ENTRAID_CFG,
12+
is_entraid_auth_enabled,
13+
validate_entraid_config,
14+
)
15+
16+
_logger = logging.getLogger(__name__)
17+
18+
# Reduce Azure SDK logging verbosity
19+
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
20+
logging.getLogger("azure.identity").setLevel(logging.WARNING)
21+
logging.getLogger("redis.auth.token_manager").setLevel(logging.WARNING)
22+
23+
# Import redis-entraid components only when needed
24+
try:
25+
from redis_entraid.cred_provider import (
26+
create_from_default_azure_credential,
27+
create_from_managed_identity,
28+
create_from_service_principal,
29+
ManagedIdentityType,
30+
TokenManagerConfig,
31+
RetryPolicy,
32+
)
33+
34+
ENTRAID_AVAILABLE = True
35+
except ImportError:
36+
_logger.warning(
37+
"redis-entraid package not available. Entra ID authentication will be disabled."
38+
)
39+
ENTRAID_AVAILABLE = False
40+
41+
42+
class EntraIDAuthenticationError(Exception):
43+
"""Exception raised for Entra ID authentication configuration errors."""
44+
45+
pass
46+
47+
48+
def create_credential_provider():
49+
"""
50+
Create an Entra ID credential provider based on the current configuration.
51+
52+
Returns:
53+
Credential provider instance or None if Entra ID auth is not configured.
54+
55+
Raises:
56+
EntraIDAuthenticationError: If configuration is invalid or required packages are missing.
57+
"""
58+
if not is_entraid_auth_enabled():
59+
return None
60+
61+
if not ENTRAID_AVAILABLE:
62+
raise EntraIDAuthenticationError(
63+
"redis-entraid package is required for Entra ID authentication. "
64+
"Install it with: pip install redis-entraid"
65+
)
66+
67+
# Validate configuration
68+
is_valid, error_message = validate_entraid_config()
69+
if not is_valid:
70+
raise EntraIDAuthenticationError(
71+
f"Invalid Entra ID configuration: {error_message}"
72+
)
73+
74+
auth_flow = ENTRAID_CFG["auth_flow"]
75+
76+
try:
77+
# Create token manager configuration
78+
token_manager_config = _create_token_manager_config()
79+
80+
if auth_flow == "service_principal":
81+
return _create_service_principal_provider(token_manager_config)
82+
elif auth_flow == "managed_identity":
83+
return _create_managed_identity_provider(token_manager_config)
84+
elif auth_flow == "default_credential":
85+
return _create_default_credential_provider(token_manager_config)
86+
else:
87+
raise EntraIDAuthenticationError(
88+
f"Unsupported authentication flow: {auth_flow}"
89+
)
90+
91+
except Exception as e:
92+
_logger.error("Failed to create Entra ID credential provider: %s", e)
93+
raise EntraIDAuthenticationError(f"Failed to create credential provider: {e}")
94+
95+
96+
def _create_token_manager_config():
97+
"""Create TokenManagerConfig from current configuration."""
98+
retry_policy = RetryPolicy(
99+
max_attempts=ENTRAID_CFG["retry_max_attempts"],
100+
delay_in_ms=ENTRAID_CFG["retry_delay_ms"],
101+
)
102+
103+
return TokenManagerConfig(
104+
expiration_refresh_ratio=ENTRAID_CFG["token_expiration_refresh_ratio"],
105+
lower_refresh_bound_millis=ENTRAID_CFG["lower_refresh_bound_millis"],
106+
token_request_execution_timeout_in_ms=ENTRAID_CFG[
107+
"token_request_execution_timeout_ms"
108+
],
109+
retry_policy=retry_policy,
110+
)
111+
112+
113+
def _create_service_principal_provider(token_manager_config):
114+
"""Create service principal credential provider."""
115+
116+
return create_from_service_principal(
117+
client_id=ENTRAID_CFG["client_id"],
118+
client_credential=ENTRAID_CFG["client_secret"],
119+
tenant_id=ENTRAID_CFG["tenant_id"],
120+
token_manager_config=token_manager_config,
121+
)
122+
123+
124+
def _create_managed_identity_provider(token_manager_config):
125+
"""Create managed identity credential provider."""
126+
identity_type_str = ENTRAID_CFG["identity_type"]
127+
128+
if identity_type_str == "system_assigned":
129+
identity_type = ManagedIdentityType.SYSTEM_ASSIGNED
130+
131+
return create_from_managed_identity(
132+
identity_type=identity_type,
133+
resource=ENTRAID_CFG["resource"],
134+
token_manager_config=token_manager_config,
135+
)
136+
137+
elif identity_type_str == "user_assigned":
138+
identity_type = ManagedIdentityType.USER_ASSIGNED
139+
140+
return create_from_managed_identity(
141+
identity_type=identity_type,
142+
resource=ENTRAID_CFG["resource"],
143+
client_id=ENTRAID_CFG["user_assigned_identity_client_id"],
144+
token_manager_config=token_manager_config,
145+
)
146+
147+
else:
148+
raise EntraIDAuthenticationError(f"Invalid identity type: {identity_type_str}")
149+
150+
151+
def _create_default_credential_provider(token_manager_config):
152+
"""Create default Azure credential provider."""
153+
154+
# Parse scopes from configuration
155+
scopes_str = ENTRAID_CFG["scopes"]
156+
scopes = tuple(scope.strip() for scope in scopes_str.split(","))
157+
158+
return create_from_default_azure_credential(
159+
scopes=scopes, token_manager_config=token_manager_config
160+
)

0 commit comments

Comments
 (0)