Skip to content

Commit f052355

Browse files
committed
feat: add LLMWithGateway for enterprise OAuth support
Add LLMWithGateway subclass to support enterprise API gateways with OAuth 2.0 authentication (e.g., APIGEE, Azure API Management).

Key features:
- OAuth 2.0 token fetch and automatic refresh
- Thread-safe token caching with TTL
- Custom header injection for gateway-specific requirements
- Template variable replacement for flexible configuration
- Fully generic implementation (no vendor lock-in)

Implementation approach:
- Separate LLMWithGateway class (addresses PR #963 feedback from neubig)
- Focused feature set for OAuth + custom headers (no over-engineering)
- Comprehensive test coverage

This replaces the previous approach of modifying the main LLM class, keeping the codebase cleaner and more maintainable.

Example usage (APIGEE + Tachyon):

```python
llm = LLMWithGateway(
    model="gemini-1.5-flash",
    base_url=os.environ["TACHYON_API_URL"],
    gateway_auth_url=os.environ["APIGEE_TOKEN_URL"],
    gateway_auth_headers={
        "X-Client-Id": os.environ["APIGEE_CLIENT_ID"],
        "X-Client-Secret": os.environ["APIGEE_CLIENT_SECRET"],
    },
    gateway_auth_body={"grant_type": "client_credentials"},
    custom_headers={"X-Tachyon-Key": os.environ["TACHYON_API_KEY"]},
)
```

Files added:
- openhands-sdk/openhands/sdk/llm/llm_with_gateway.py (new class)
- tests/sdk/llm/test_llm_with_gateway.py (comprehensive tests)
1 parent b9860ce commit f052355

File tree

3 files changed

+739
-0
lines changed

3 files changed

+739
-0
lines changed

openhands-sdk/openhands/sdk/llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from openhands.sdk.llm.llm import LLM
22
from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent
33
from openhands.sdk.llm.llm_response import LLMResponse
4+
from openhands.sdk.llm.llm_with_gateway import LLMWithGateway
45
from openhands.sdk.llm.message import (
56
ImageContent,
67
Message,
@@ -23,6 +24,7 @@
2324
__all__ = [
2425
"LLMResponse",
2526
"LLM",
27+
"LLMWithGateway",
2628
"LLMRegistry",
2729
"RouterLLM",
2830
"RegistryEvent",
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
"""LLM subclass with enterprise gateway support.
2+
3+
This module provides LLMWithGateway, which extends the base LLM class to support
4+
OAuth 2.0 authentication flows and custom headers for enterprise API gateways
5+
like APIGEE + Tachyon.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import threading
11+
import time
12+
from typing import Any
13+
14+
import httpx
15+
from pydantic import Field, PrivateAttr
16+
17+
from openhands.sdk.llm.llm import LLM
18+
from openhands.sdk.logger import get_logger
19+
20+
logger = get_logger(__name__)
21+
22+
__all__ = ["LLMWithGateway"]
23+
24+
25+
class LLMWithGateway(LLM):
    """LLM subclass with enterprise gateway support.

    Supports OAuth 2.0 token exchange with configurable headers and bodies.
    Designed for enterprise gateways like APIGEE + Tachyon that require:
    1. Initial OAuth call to get a bearer token
    2. Bearer token included in subsequent LLM API calls
    3. Custom headers for routing/authentication

    The fetched token is cached on the instance together with an expiry
    timestamp and is fetched again once it is within 5 seconds of expiring.
    When neither ``gateway_auth_url`` nor ``custom_headers`` is set, calls are
    forwarded to the base class without modification.

    Example usage for APIGEE + Tachyon:
        llm = LLMWithGateway(
            model="gemini-1.5-flash",
            base_url="https://apigee.example.com/.../tachyon/v1",
            gateway_auth_url="https://apigee.example.com/oauth/token",
            gateway_auth_headers={
                "X-Client-Id": os.environ["APIGEE_KEY"],
                "X-Client-Secret": os.environ["APIGEE_SECRET"],
            },
            gateway_auth_body={"grant_type": "client_credentials"},
            custom_headers={"X-Tachyon-Key": os.environ["TACHYON_KEY"]},
        )
    """

    # OAuth configuration
    gateway_auth_url: str | None = Field(
        default=None,
        description="Identity provider URL to fetch gateway tokens (OAuth endpoint).",
    )
    gateway_auth_method: str = Field(
        default="POST",
        description="HTTP method for identity provider requests.",
    )
    gateway_auth_headers: dict[str, str] | None = Field(
        default=None,
        description="Headers to include when calling the identity provider.",
    )
    gateway_auth_body: dict[str, Any] | None = Field(
        default=None,
        description="JSON body to include when calling the identity provider.",
    )
    gateway_auth_token_path: str = Field(
        default="access_token",
        description=(
            "Dot-notation path to the token in the OAuth response "
            "(e.g., 'access_token' or 'data.token')."
        ),
    )
    gateway_auth_token_ttl: int | None = Field(
        default=None,
        description="Token TTL in seconds. If not set, defaults to 300s (5 minutes).",
    )

    # Token header configuration
    gateway_token_header: str = Field(
        default="Authorization",
        description="Header name for the gateway token (defaults to 'Authorization').",
    )
    gateway_token_prefix: str = Field(
        default="Bearer ",
        description="Prefix prepended to the token (e.g., 'Bearer ').",
    )

    # Custom headers for all requests
    custom_headers: dict[str, str] | None = Field(
        default=None,
        description="Custom headers to include with every LLM request.",
    )

    # Private fields for token management
    _gateway_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
    _gateway_token: str | None = PrivateAttr(default=None)
    _gateway_token_expiry: float | None = PrivateAttr(default=None)

    def model_post_init(self, __context: Any) -> None:
        """Reset the token cache and lock after model validation.

        Args:
            __context: Validation context, forwarded unchanged to the base class.
        """
        super().model_post_init(__context)
        self._gateway_lock = threading.Lock()
        self._gateway_token = None
        self._gateway_token_expiry = None

    def _completion(
        self, messages: list[dict], **kwargs
    ) -> Any:  # Returns ModelResponse
        """Inject gateway headers into ``kwargs``, then delegate to the base class."""
        self._prepare_gateway_call(kwargs)
        return super()._completion(messages, **kwargs)

    def _responses(
        self, messages: list[dict], **kwargs
    ) -> Any:  # Returns ResponsesAPIResponse
        """Inject gateway headers into ``kwargs``, then delegate to the base class."""
        self._prepare_gateway_call(kwargs)
        return super()._responses(messages, **kwargs)

    def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None:
        """Augment the call kwargs in place with gateway headers and token.

        Headers are merged in increasing order of precedence:
        1. Any ``extra_headers`` dict already present in ``call_kwargs``
        2. ``custom_headers`` (after template variable replacement)
        3. The gateway token header, when ``gateway_auth_url`` is configured

        Args:
            call_kwargs: Keyword arguments of the pending call. Its
                ``extra_headers`` entry is replaced when any header results.

        Raises:
            Exception: Propagated from the token fetch when it fails.
        """
        if not self.gateway_auth_url and not self.custom_headers:
            return

        # Start from a copy so the caller's own dict is never mutated.
        headers: dict[str, str] = {}
        existing_headers = call_kwargs.get("extra_headers")
        if isinstance(existing_headers, dict):
            headers.update(existing_headers)

        # Add custom headers (with template replacement)
        if self.custom_headers:
            rendered_headers = self._render_templates(self.custom_headers)
            if isinstance(rendered_headers, dict):
                headers.update(rendered_headers)

        # Add gateway token if OAuth is configured
        if self.gateway_auth_url:
            headers.update(self._get_gateway_token_headers())

        if headers:
            call_kwargs["extra_headers"] = headers

    def _get_gateway_token_headers(self) -> dict[str, str]:
        """Return the header carrying the gateway token.

        Returns:
            ``{gateway_token_header: gateway_token_prefix + token}``, or an
            empty dict when no token is available.
        """
        token = self._ensure_gateway_token()
        if not token:
            return {}

        prefix = self.gateway_token_prefix
        value = f"{prefix}{token}" if prefix else token
        return {self.gateway_token_header: value}

    def _cached_gateway_token(self) -> str | None:
        """Return the cached token if it is still valid, else ``None``.

        A token counts as valid until 5 seconds before its recorded expiry.
        Token and expiry are read once into locals so the value returned is
        the same one that was checked.
        """
        token = self._gateway_token
        expiry = self._gateway_token_expiry
        if token and expiry and time.time() < expiry - 5:
            return token
        return None

    def _ensure_gateway_token(self) -> str | None:
        """Ensure we have a valid gateway token, refreshing if needed.

        Returns:
            Valid gateway token, or None if gateway auth is not configured.

        Raises:
            Exception: Propagated from ``_refresh_gateway_token``.
        """
        if not self.gateway_auth_url:
            return None

        # Fast path: no lock taken while the cached token is still valid.
        token = self._cached_gateway_token()
        if token:
            return token

        # Slow path: serialize refreshes, and re-check after acquiring the
        # lock because another thread may have refreshed while we waited.
        with self._gateway_lock:
            token = self._cached_gateway_token()
            if token:
                return token
            return self._refresh_gateway_token()

    def _refresh_gateway_token(self) -> str:
        """Fetch a new gateway token from the identity provider.

        This method is called while holding _gateway_lock.

        Returns:
            Fresh gateway token (surrounding whitespace stripped).

        Raises:
            ValueError: If the response has no non-empty string at
                ``gateway_auth_token_path``.
            Exception: If the HTTP request fails, returns an error status, or
                the response body is not valid JSON.
        """
        method = self.gateway_auth_method.upper()
        headers = self._render_templates(self.gateway_auth_headers or {})
        body = self._render_templates(self.gateway_auth_body or {})

        logger.debug(
            f"Fetching gateway token from {self.gateway_auth_url} "
            f"(method={method})"
        )

        try:
            response = httpx.request(
                method,
                self.gateway_auth_url,
                headers=headers if isinstance(headers, dict) else None,
                json=body if isinstance(body, dict) else None,
                timeout=self.timeout or 30,
            )
            response.raise_for_status()
        except Exception as exc:
            logger.error(f"Gateway auth request failed: {exc}")
            raise

        try:
            payload = response.json()
        except Exception as exc:
            logger.error(f"Failed to parse gateway auth response JSON: {exc}")
            raise

        # Extract token from response
        token_path = self.gateway_auth_token_path
        token_value = self._extract_from_path(payload, token_path)
        if not isinstance(token_value, str) or not token_value.strip():
            # Never echo the payload itself: with a misconfigured token path
            # the real token (or other secrets) may sit elsewhere in the
            # response and would end up in logs and tracebacks.
            raise ValueError(
                f"Gateway auth response did not contain token at path "
                f'"{token_path}". Response shape: '
                f"{self._describe_payload(payload)}"
            )

        # Effective TTL: configured value, else 300s, never below 1s.
        ttl_seconds = max(float(self.gateway_auth_token_ttl or 300), 1.0)

        # Update cache
        token = token_value.strip()
        self._gateway_token = token
        self._gateway_token_expiry = time.time() + ttl_seconds

        logger.info(
            f"Gateway token refreshed successfully (expires in {ttl_seconds}s)"
        )
        return token

    @staticmethod
    def _describe_payload(payload: Any) -> str:
        """Summarize a response payload's shape without exposing its values."""
        if isinstance(payload, dict):
            return f"object with keys {sorted(str(key) for key in payload)}"
        if isinstance(payload, list):
            return f"array of length {len(payload)}"
        return type(payload).__name__

    def _render_templates(self, value: Any) -> Any:
        """Replace template variables in strings with actual values.

        Supports:
        - {{llm_model}} -> self.model
        - {{llm_base_url}} -> self.base_url (empty string when unset)
        - {{llm_api_key}} -> self.api_key (if set; otherwise the placeholder
          is left untouched)

        Dict values and list items are rendered recursively; dict keys are
        not rendered.

        Args:
            value: String, dict, list, or other value to render.

        Returns:
            Value with templates replaced. Containers are returned as new
            objects; values of any other type are returned unchanged.
        """
        if isinstance(value, str):
            replacements: dict[str, str] = {
                "{{llm_model}}": self.model,
                "{{llm_base_url}}": self.base_url or "",
            }
            if self.api_key:
                replacements["{{llm_api_key}}"] = self.api_key.get_secret_value()

            result = value
            for placeholder, actual in replacements.items():
                result = result.replace(placeholder, actual)
            return result

        if isinstance(value, dict):
            return {k: self._render_templates(v) for k, v in value.items()}

        if isinstance(value, list):
            return [self._render_templates(v) for v in value]

        return value

    @staticmethod
    def _extract_from_path(payload: Any, path: str) -> Any:
        """Extract a value from nested dict/list using dot notation.

        Examples:
            _extract_from_path({"a": {"b": "value"}}, "a.b") -> "value"
            _extract_from_path({"data": [{"token": "x"}]}, "data.0.token") -> "x"

        Args:
            payload: Dict or list to traverse.
            path: Dot-separated path (e.g., "data.token" or "items.0.value").
                An empty path returns ``payload`` itself.

        Returns:
            Value at the specified path.

        Raises:
            ValueError: If path cannot be traversed. A dict key whose value is
                ``None`` is reported the same way as a missing key.
        """
        current = payload
        if not path:
            return current

        for part in path.split("."):
            if isinstance(current, dict):
                current = current.get(part)
                if current is None:
                    raise ValueError(
                        f'Key "{part}" not found in response while '
                        f'traversing path "{path}".'
                    )
            elif isinstance(current, list):
                try:
                    index = int(part)
                except (ValueError, TypeError):
                    raise ValueError(
                        f'Invalid list index "{part}" while traversing '
                        f'response path "{path}".'
                    ) from None
                try:
                    current = current[index]
                except (IndexError, TypeError):
                    raise ValueError(
                        f"Index {index} out of range while traversing "
                        f'response path "{path}".'
                    ) from None
            else:
                # Reached a scalar before the path was exhausted.
                raise ValueError(
                    f'Cannot traverse path "{path}"; segment "{part}" '
                    f"not found or not accessible."
                )

        return current

0 commit comments

Comments
 (0)