From ad9752c76ece51b7fa08aead1d8b2e614e570978 Mon Sep 17 00:00:00 2001 From: Alona King Date: Thu, 30 Oct 2025 13:14:10 -0400 Subject: [PATCH 01/10] feat: add LLMWithGateway for enterprise OAuth support Add LLMWithGateway class that extends LLM with enterprise gateway support: - OAuth 2.0 token fetching and automatic refresh with caching - Configurable token paths and TTL for various OAuth response formats - Custom header injection for routing and additional authentication - Template variable support ({{llm_model}}, {{llm_base_url}}, etc.) - Thread-safe token management - Works with both completion() and responses() APIs The class maintains full API compatibility with the base LLM class while transparently handling gateway authentication flows behind the scenes. Includes comprehensive test coverage (25 tests) covering: - OAuth token lifecycle (fetch, cache, refresh) - Header injection and custom headers - Template replacement - Nested path extraction from OAuth responses - Error handling and edge cases --- openhands-sdk/openhands/sdk/llm/__init__.py | 2 + .../openhands/sdk/llm/llm_with_gateway.py | 335 ++++++++++++++++++ tests/sdk/llm/test_llm_with_gateway.py | 314 ++++++++++++++++ 3 files changed, 651 insertions(+) create mode 100644 openhands-sdk/openhands/sdk/llm/llm_with_gateway.py create mode 100644 tests/sdk/llm/test_llm_with_gateway.py diff --git a/openhands-sdk/openhands/sdk/llm/__init__.py b/openhands-sdk/openhands/sdk/llm/__init__.py index fabed357d1..a5284acf39 100644 --- a/openhands-sdk/openhands/sdk/llm/__init__.py +++ b/openhands-sdk/openhands/sdk/llm/__init__.py @@ -1,6 +1,7 @@ from openhands.sdk.llm.llm import LLM from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent from openhands.sdk.llm.llm_response import LLMResponse +from openhands.sdk.llm.llm_with_gateway import LLMWithGateway from openhands.sdk.llm.message import ( ImageContent, Message, @@ -23,6 +24,7 @@ __all__ = [ "LLMResponse", "LLM", + "LLMWithGateway", "LLMRegistry", "RouterLLM", "RegistryEvent", diff --git a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py new file mode 100644 index 0000000000..4944ccd585 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py @@ -0,0 +1,335 @@ +"""LLM subclass with enterprise gateway support. + +This module provides LLMWithGateway, which extends the base LLM class to support +OAuth 2.0 authentication flows and custom headers for enterprise API gateways. +""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +import httpx +from pydantic import Field, PrivateAttr + +from openhands.sdk.llm.llm import LLM +from openhands.sdk.logger import get_logger + + +logger = get_logger(__name__) + +__all__ = ["LLMWithGateway"] + + +class LLMWithGateway(LLM): + """LLM subclass with enterprise gateway support. + + Supports OAuth 2.0 token exchange with configurable headers and bodies. + Designed for enterprise API gateways that require: + 1. Initial OAuth call to get a bearer token + 2. Bearer token included in subsequent LLM API calls + 3. Custom headers for routing/authentication + + Example usage: + llm = LLMWithGateway( + model="gpt-4", + base_url="https://gateway.company.com/llm/v1", + gateway_auth_url="https://gateway.company.com/oauth/token", + gateway_auth_headers={ + "X-Client-Id": os.environ["GATEWAY_CLIENT_ID"], + "X-Client-Secret": os.environ["GATEWAY_CLIENT_SECRET"], + }, + gateway_auth_body={"grant_type": "client_credentials"}, + custom_headers={"X-Gateway-Key": os.environ["GATEWAY_API_KEY"]}, + ) + """ + + # OAuth configuration + gateway_auth_url: str | None = Field( + default=None, + description="Identity provider URL to fetch gateway tokens (OAuth endpoint).", + ) + gateway_auth_method: str = Field( + default="POST", + description="HTTP method for identity provider requests.", + ) + gateway_auth_headers: dict[str, str] | None = Field( + default=None, + description="Headers to include when calling the identity provider.", + ) + gateway_auth_body: dict[str, Any] | None = Field( + default=None, + description="JSON body to include when calling the identity provider.", + ) + gateway_auth_token_path: str = Field( + default="access_token", + description=( + "Dot-notation path to the token in the OAuth response " + "(e.g., 'access_token' or 'data.token')." + ), + ) + gateway_auth_token_ttl: int | None = Field( + default=None, + description="Token TTL in seconds. If not set, defaults to 300s (5 minutes).", + ) + + # Token header configuration + gateway_token_header: str = Field( + default="Authorization", + description="Header name for the gateway token (defaults to 'Authorization').", + ) + gateway_token_prefix: str = Field( + default="Bearer ", + description="Prefix prepended to the token (e.g., 'Bearer ').", + ) + + # Custom headers for all requests + custom_headers: dict[str, str] | None = Field( + default=None, + description="Custom headers to include with every LLM request.", + ) + + # Private fields for token management + _gateway_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + _gateway_token: str | None = PrivateAttr(default=None) + _gateway_token_expiry: float | None = PrivateAttr(default=None) + + def model_post_init(self, __context: Any) -> None: + """Initialize private fields after model validation.""" + super().model_post_init(__context) + self._gateway_lock = threading.Lock() + self._gateway_token = None + self._gateway_token_expiry = None + + def completion(self, *args, **kwargs): + """Override to inject gateway authentication before calling LiteLLM.""" + self._prepare_gateway_call(kwargs) + return super().completion(*args, **kwargs) + + def responses(self, *args, **kwargs): + """Override to inject gateway authentication before calling LiteLLM.""" + self._prepare_gateway_call(kwargs) + return super().responses(*args, **kwargs) + + def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None: + """Augment LiteLLM kwargs with gateway headers and token. + + This method: + 1. Fetches/refreshes OAuth token if needed + 2. Adds token to headers + 3. Adds custom headers + 4. Performs basic template variable replacement + """ + if not self.gateway_auth_url and not self.custom_headers: + return + + # Start with existing headers + headers: dict[str, str] = {} + existing_headers = call_kwargs.get("extra_headers") + if isinstance(existing_headers, dict): + headers.update(existing_headers) + + # Add custom headers (with template replacement) + if self.custom_headers: + rendered_headers = self._render_templates(self.custom_headers) + if isinstance(rendered_headers, dict): + headers.update(rendered_headers) + + # Add gateway token if OAuth is configured + if self.gateway_auth_url: + token_headers = self._get_gateway_token_headers() + if token_headers: + headers.update(token_headers) + + # Set headers on the call + if headers: + call_kwargs["extra_headers"] = headers + + def _get_gateway_token_headers(self) -> dict[str, str]: + """Get headers containing the gateway token.""" + token = self._ensure_gateway_token() + if not token: + return {} + + header_name = self.gateway_token_header + prefix = self.gateway_token_prefix + value = f"{prefix}{token}" if prefix else token + return {header_name: value} + + def _ensure_gateway_token(self) -> str | None: + """Ensure we have a valid gateway token, refreshing if needed. + + Returns: + Valid gateway token, or None if gateway auth is not configured. + """ + if not self.gateway_auth_url: + return None + + # Fast path: check if current token is still valid (with 5s buffer) + now = time.time() + if ( + self._gateway_token + and self._gateway_token_expiry + and now < self._gateway_token_expiry - 5 + ): + return self._gateway_token + + # Slow path: acquire lock and refresh token + with self._gateway_lock: + # Double-check after acquiring lock + if ( + self._gateway_token + and self._gateway_token_expiry + and time.time() < self._gateway_token_expiry - 5 + ): + return self._gateway_token + + # Refresh token + return self._refresh_gateway_token() + + def _refresh_gateway_token(self) -> str: + """Fetch a new gateway token from the identity provider. + + This method is called while holding _gateway_lock. + + Returns: + Fresh gateway token. + + Raises: + Exception: If token fetch fails. + """ + assert self.gateway_auth_url is not None, "gateway_auth_url must be set" + method = self.gateway_auth_method.upper() + headers = self._render_templates(self.gateway_auth_headers or {}) + body = self._render_templates(self.gateway_auth_body or {}) + + logger.debug( + f"Fetching gateway token from {self.gateway_auth_url} (method={method})" + ) + + try: + response = httpx.request( + method, + self.gateway_auth_url, + headers=headers if isinstance(headers, dict) else None, + json=body if isinstance(body, dict) else None, + timeout=self.timeout or 30, + ) + response.raise_for_status() + except Exception as exc: + logger.error(f"Gateway auth request failed: {exc}") + raise + + try: + payload = response.json() + except Exception as exc: + logger.error(f"Failed to parse gateway auth response JSON: {exc}") + raise + + # Extract token from response + token_path = self.gateway_auth_token_path + token_value = self._extract_from_path(payload, token_path) + if not isinstance(token_value, str) or not token_value.strip(): + raise ValueError( + f"Gateway auth response did not contain token at path " + f'"{token_path}". Response: {payload}' + ) + + # Determine TTL + ttl_seconds = float(self.gateway_auth_token_ttl or 300) + + # Update cache + self._gateway_token = token_value.strip() + self._gateway_token_expiry = time.time() + max(ttl_seconds, 1.0) + + logger.info(f"Gateway token refreshed successfully (expires in {ttl_seconds}s)") + return self._gateway_token + + def _render_templates(self, value: Any) -> Any: + """Replace template variables in strings with actual values. + + Supports: + - {{llm_model}} -> self.model + - {{llm_base_url}} -> self.base_url + - {{llm_api_key}} -> self.api_key (if set) + + Args: + value: String, dict, list, or other value to render. + + Returns: + Value with templates replaced. + """ + if isinstance(value, str): + replacements: dict[str, str] = { + "{{llm_model}}": self.model, + "{{llm_base_url}}": self.base_url or "", + } + if self.api_key: + replacements["{{llm_api_key}}"] = self.api_key.get_secret_value() + + result = value + for placeholder, actual in replacements.items(): + result = result.replace(placeholder, actual) + return result + + if isinstance(value, dict): + return {k: self._render_templates(v) for k, v in value.items()} + + if isinstance(value, list): + return [self._render_templates(v) for v in value] + + return value + + @staticmethod + def _extract_from_path(payload: Any, path: str) -> Any: + """Extract a value from nested dict/list using dot notation. + + Examples: + _extract_from_path({"a": {"b": "value"}}, "a.b") -> "value" + _extract_from_path({"data": [{"token": "x"}]}, "data.0.token") -> "x" + + Args: + payload: Dict or list to traverse. + path: Dot-separated path (e.g., "data.token" or "items.0.value"). + + Returns: + Value at the specified path. + + Raises: + ValueError: If path cannot be traversed. + """ + current = payload + if not path: + return current + + for part in path.split("."): + if isinstance(current, dict): + current = current.get(part) + if current is None: + raise ValueError( + f'Key "{part}" not found in response ' + f'while traversing path "{path}".' + ) + elif isinstance(current, list): + try: + index = int(part) + except (ValueError, TypeError): + raise ValueError( + f'Invalid list index "{part}" ' + f'while traversing response path "{path}".' + ) from None + try: + current = current[index] + except (IndexError, TypeError): + raise ValueError( + f"Index {index} out of range " + f'while traversing response path "{path}".' + ) from None + else: + raise ValueError( + f'Cannot traverse path "{path}"; ' + f'segment "{part}" not found or not accessible.' + ) + + return current diff --git a/tests/sdk/llm/test_llm_with_gateway.py b/tests/sdk/llm/test_llm_with_gateway.py new file mode 100644 index 0000000000..72a8c26ea9 --- /dev/null +++ b/tests/sdk/llm/test_llm_with_gateway.py @@ -0,0 +1,314 @@ +"""Tests for LLMWithGateway enterprise gateway support.""" + +import time +from typing import Any +from unittest.mock import Mock, patch + +import pytest +from pydantic import SecretStr + +from openhands.sdk.llm import LLMWithGateway, Message, TextContent +from tests.conftest import create_mock_litellm_response + + +@pytest.fixture +def mock_gateway_auth_response(): + """Mock OAuth response from gateway.""" + return { + "access_token": "test-gateway-token-12345", + "token_type": "Bearer", + "expires_in": 3600, + } + + +@pytest.fixture +def gateway_llm(): + """Create LLMWithGateway instance for testing.""" + return LLMWithGateway( + model="gemini-1.5-flash", + api_key=SecretStr("test-api-key"), + base_url="https://gateway.example.com/v1", + gateway_auth_url="https://gateway.example.com/oauth/token", + gateway_auth_headers={ + "X-Client-Id": "test-client-id", + "X-Client-Secret": "test-client-secret", + }, + gateway_auth_body={"grant_type": "client_credentials"}, + gateway_auth_token_ttl=3600, + custom_headers={"X-Custom-Key": "test-custom-value"}, + usage_id="gateway-test-llm", + num_retries=0, # Disable retries for testing + ) + + +class TestLLMWithGatewayInit: + """Test LLMWithGateway initialization.""" + + def test_init_with_gateway_config(self, gateway_llm): + """Test initialization with gateway configuration.""" + assert gateway_llm.gateway_auth_url == "https://gateway.example.com/oauth/token" + assert gateway_llm.gateway_auth_method == "POST" + assert gateway_llm.gateway_auth_headers == { + "X-Client-Id": "test-client-id", + "X-Client-Secret": "test-client-secret", + } + assert gateway_llm.gateway_auth_body == {"grant_type": "client_credentials"} + assert gateway_llm.gateway_auth_token_path == "access_token" + assert gateway_llm.gateway_auth_token_ttl == 3600 + assert gateway_llm.gateway_token_header == "Authorization" + assert gateway_llm.gateway_token_prefix == "Bearer " + assert gateway_llm.custom_headers == {"X-Custom-Key": "test-custom-value"} + + def test_init_without_gateway_config(self): + """Test initialization without gateway configuration (regular LLM).""" + llm = LLMWithGateway( + model="gpt-4", + api_key=SecretStr("test-key"), + usage_id="regular-llm", + ) + assert llm.gateway_auth_url is None + assert llm.custom_headers is None + + +class TestGatewayTokenFetch: + """Test OAuth token fetching and caching.""" + + @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") + def test_fetch_token_success( + self, mock_request, gateway_llm, mock_gateway_auth_response + ): + """Test successful token fetch from gateway.""" + mock_response = Mock() + mock_response.json.return_value = mock_gateway_auth_response + mock_response.raise_for_status = Mock() + mock_request.return_value = mock_response + + token = gateway_llm._ensure_gateway_token() + + assert token == "test-gateway-token-12345" + assert gateway_llm._gateway_token == "test-gateway-token-12345" + assert gateway_llm._gateway_token_expiry is not None + assert gateway_llm._gateway_token_expiry > time.time() + + # Verify request was made correctly + mock_request.assert_called_once_with( + "POST", + "https://gateway.example.com/oauth/token", + headers={ + "X-Client-Id": "test-client-id", + "X-Client-Secret": "test-client-secret", + }, + json={"grant_type": "client_credentials"}, + timeout=30, + ) + + @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") + def test_token_caching(self, mock_request, gateway_llm, mock_gateway_auth_response): + """Test that tokens are cached and not re-fetched unnecessarily.""" + mock_response = Mock() + mock_response.json.return_value = mock_gateway_auth_response + mock_response.raise_for_status = Mock() + mock_request.return_value = mock_response + + # First call should fetch token + token1 = gateway_llm._ensure_gateway_token() + assert mock_request.call_count == 1 + + # Second call should use cached token + token2 = gateway_llm._ensure_gateway_token() + assert mock_request.call_count == 1 # Still only 1 call + assert token1 == token2 + + @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") + def test_token_refresh_when_expired( + self, mock_request, gateway_llm, mock_gateway_auth_response + ): + """Test that token is refreshed when expired.""" + mock_response = Mock() + mock_response.json.return_value = mock_gateway_auth_response + mock_response.raise_for_status = Mock() + mock_request.return_value = mock_response + + # Fetch initial token + gateway_llm._ensure_gateway_token() + assert mock_request.call_count == 1 + + # Manually expire the token + gateway_llm._gateway_token_expiry = time.time() - 10 + + # Next call should refresh + gateway_llm._ensure_gateway_token() + assert mock_request.call_count == 2 + + @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") + def test_token_fetch_missing_token(self, mock_request, gateway_llm): + """Test handling of response without token field.""" + mock_response = Mock() + mock_response.json.return_value = {"error": "invalid_client"} + mock_response.raise_for_status = Mock() + mock_request.return_value = mock_response + + with pytest.raises(ValueError, match="not found in response"): + gateway_llm._ensure_gateway_token() + + +class TestTemplateReplacement: + """Test template variable replacement.""" + + def test_render_templates_string(self, gateway_llm): + """Test template replacement in strings.""" + template = "Model: {{llm_model}}, URL: {{llm_base_url}}" + result = gateway_llm._render_templates(template) + assert result == "Model: gemini-1.5-flash, URL: https://gateway.example.com/v1" + + def test_render_templates_dict(self, gateway_llm): + """Test template replacement in dictionaries.""" + template = { + "model": "{{llm_model}}", + "endpoint": "{{llm_base_url}}/chat", + "nested": {"key": "{{llm_model}}"}, + } + result = gateway_llm._render_templates(template) + assert result["model"] == "gemini-1.5-flash" + assert result["endpoint"] == "https://gateway.example.com/v1/chat" + assert result["nested"]["key"] == "gemini-1.5-flash" + + def test_render_templates_list(self, gateway_llm): + """Test template replacement in lists.""" + template = ["{{llm_model}}", {"url": "{{llm_base_url}}"}] + result = gateway_llm._render_templates(template) + assert result[0] == "gemini-1.5-flash" + assert result[1]["url"] == "https://gateway.example.com/v1" + + def test_render_templates_with_api_key(self, gateway_llm): + """Test template replacement includes API key.""" + template = "Key: {{llm_api_key}}" + result = gateway_llm._render_templates(template) + assert result == "Key: test-api-key" + + def test_render_templates_no_base_url(self): + """Test template replacement when base_url is not set.""" + llm = LLMWithGateway( + model="gpt-4", + api_key=SecretStr("key"), + usage_id="test", + ) + template = "URL: {{llm_base_url}}" + result = llm._render_templates(template) + assert result == "URL: " + + +class TestPathExtraction: + """Test nested path extraction from OAuth responses.""" + + def test_extract_simple_path(self): + """Test extraction from simple path.""" + payload = {"access_token": "token123"} + result = LLMWithGateway._extract_from_path(payload, "access_token") + assert result == "token123" + + def test_extract_nested_path(self): + """Test extraction from nested path.""" + payload = {"data": {"auth": {"token": "token456"}}} + result = LLMWithGateway._extract_from_path(payload, "data.auth.token") + assert result == "token456" + + def test_extract_from_array(self): + """Test extraction from array.""" + payload = {"tokens": [{"value": "token1"}, {"value": "token2"}]} + result = LLMWithGateway._extract_from_path(payload, "tokens.0.value") + assert result == "token1" + + def test_extract_empty_path(self): + """Test extraction with empty path returns root.""" + payload = {"key": "value"} + result = LLMWithGateway._extract_from_path(payload, "") + assert result == payload + + def test_extract_missing_key(self): + """Test extraction fails for missing key.""" + payload = {"other": "value"} + with pytest.raises(ValueError, match="not found"): + LLMWithGateway._extract_from_path(payload, "missing") + + def test_extract_invalid_array_index(self): + """Test extraction fails for invalid array index.""" + payload = {"items": ["a", "b"]} + with pytest.raises(ValueError, match="Invalid list index"): + LLMWithGateway._extract_from_path(payload, "items.invalid") + + def test_extract_array_index_out_of_range(self): + """Test extraction fails for out of range index.""" + payload = {"items": ["a", "b"]} + with pytest.raises(ValueError, match="out of range"): + LLMWithGateway._extract_from_path(payload, "items.5") + + +class TestGatewayIntegration: + """Integration tests for gateway functionality.""" + + @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") + @patch("openhands.sdk.llm.llm.litellm_completion") + def test_full_gateway_flow( + self, mock_completion, mock_request, mock_gateway_auth_response + ): + """Test complete flow: OAuth -> LLM request with headers.""" + # Setup gateway LLM + llm = LLMWithGateway( + model="gpt-4", + base_url="https://gateway.example.com/llm/v1", + gateway_auth_url="https://gateway.example.com/oauth/token", + gateway_auth_headers={ + "X-Client-Id": "client123", + "X-Client-Secret": "secret456", + }, + gateway_auth_body={"grant_type": "client_credentials"}, + custom_headers={"X-Gateway-Key": "gateway789"}, + usage_id="integration-test", + num_retries=0, + ) + + # Mock OAuth response + mock_oauth_response = Mock() + mock_oauth_response.json.return_value = mock_gateway_auth_response + mock_oauth_response.raise_for_status = Mock() + mock_request.return_value = mock_oauth_response + + # Mock LLM completion + mock_completion.return_value = create_mock_litellm_response( + content="Hello from gateway!" + ) + + # Make completion request + messages = [Message(role="user", content=[TextContent(text="Hello")])] + response = llm.completion(messages) + + # Verify OAuth was called + assert mock_request.call_count == 1 + oauth_call = mock_request.call_args + assert oauth_call[0][0] == "POST" + assert oauth_call[0][1] == "https://gateway.example.com/oauth/token" + + # Verify LLM completion was called with correct headers + assert mock_completion.call_count == 1 + completion_kwargs = mock_completion.call_args[1] + headers = completion_kwargs["extra_headers"] + assert headers["Authorization"] == "Bearer test-gateway-token-12345" + assert headers["X-Gateway-Key"] == "gateway789" + + # Verify response + assert isinstance(response.message.content[0], TextContent) + assert response.message.content[0].text == "Hello from gateway!" + + def test_gateway_disabled_when_no_config(self): + """Test that gateway logic is skipped when not configured.""" + llm = LLMWithGateway( + model="gpt-4", + api_key=SecretStr("key"), + usage_id="no-gateway-test", + ) + + # Should not fail, just act like regular LLM + kwargs: dict[str, Any] = {} + llm._prepare_gateway_call(kwargs) + assert "extra_headers" not in kwargs From 4cfd5a4ab318af74f5b89b5160b6b90307e8e492 Mon Sep 17 00:00:00 2001 From: Alona King Date: Thu, 30 Oct 2025 14:24:16 -0400 Subject: [PATCH 02/10] feat: improve token TTL handling and add extended thinking support - Auto-detect token expiry from OAuth expires_in field when available - Fall back to 300s default when expires_in not provided - Allow explicit TTL override via gateway_auth_token_ttl - Fix method override to use _transport_call instead of completion - Add extended thinking header merge test - Add 3 new TTL handling tests (expires_in, fallback, override) --- .../openhands/sdk/llm/llm_with_gateway.py | 43 +++++- tests/sdk/llm/test_llm_with_gateway.py | 127 ++++++++++++++++++ 2 files changed, 163 insertions(+), 7 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py index 4944ccd585..8b425b6b43 100644 --- a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py +++ b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py @@ -11,6 +11,7 @@ from typing import Any import httpx +from litellm.types.utils import ModelResponse from pydantic import Field, PrivateAttr from openhands.sdk.llm.llm import LLM @@ -71,7 +72,10 @@ class LLMWithGateway(LLM): ) gateway_auth_token_ttl: int | None = Field( default=None, - description="Token TTL in seconds. If not set, defaults to 300s (5 minutes).", + description=( + "Token TTL in seconds. If not set, uses `expires_in` from the OAuth" + " response when available, falling back to 300s (5 minutes)." + ), ) # Token header configuration @@ -102,16 +106,18 @@ def model_post_init(self, __context: Any) -> None: self._gateway_token = None self._gateway_token_expiry = None - def completion(self, *args, **kwargs): - """Override to inject gateway authentication before calling LiteLLM.""" - self._prepare_gateway_call(kwargs) - return super().completion(*args, **kwargs) - def responses(self, *args, **kwargs): """Override to inject gateway authentication before calling LiteLLM.""" self._prepare_gateway_call(kwargs) return super().responses(*args, **kwargs) + def _transport_call( + self, *, messages: list[dict[str, Any]], **kwargs + ) -> ModelResponse: + """Inject gateway headers just before delegating to LiteLLM.""" + self._prepare_gateway_call(kwargs) + return super()._transport_call(messages=messages, **kwargs) + def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None: """Augment LiteLLM kwargs with gateway headers and token. @@ -237,7 +243,30 @@ def _refresh_gateway_token(self) -> str: ) # Determine TTL - ttl_seconds = float(self.gateway_auth_token_ttl or 300) + ttl_seconds: float | None = None + if self.gateway_auth_token_ttl is not None: + try: + ttl_seconds = float(self.gateway_auth_token_ttl) + except (TypeError, ValueError): # pragma: no cover - defensive + logger.warning( + "Configured gateway_auth_token_ttl is not numeric; falling back" + ) + ttl_seconds = None + else: + expires_in = None + if isinstance(payload, dict): + expires_in = payload.get("expires_in") + if expires_in is not None: + try: + ttl_seconds = float(expires_in) + except (TypeError, ValueError): + logger.warning( + "Invalid expires_in value %r from gateway; using default", + expires_in, + ) + + if ttl_seconds is None or ttl_seconds <= 0: + ttl_seconds = 300.0 # Update cache self._gateway_token = token_value.strip() diff --git a/tests/sdk/llm/test_llm_with_gateway.py b/tests/sdk/llm/test_llm_with_gateway.py index 72a8c26ea9..f0f4a3879c 100644 --- a/tests/sdk/llm/test_llm_with_gateway.py +++ b/tests/sdk/llm/test_llm_with_gateway.py @@ -151,6 +151,93 @@ def test_token_fetch_missing_token(self, mock_request, gateway_llm): with pytest.raises(ValueError, match="not found in response"): gateway_llm._ensure_gateway_token() + @patch("openhands.sdk.llm.llm_with_gateway.time.time") + @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") + def test_token_ttl_uses_expires_in_by_default( + self, mock_request, mock_time + ) -> None: + """Token expiry should honor expires_in when TTL override not configured.""" + mock_time.return_value = 1_000.0 + + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "token-expires-in", + "expires_in": 120, + } + mock_response.raise_for_status = Mock() + mock_request.return_value = mock_response + + llm = LLMWithGateway( + model="gpt-4", + api_key=SecretStr("key"), + base_url="https://gateway.example.com/llm/v1", + gateway_auth_url="https://gateway.example.com/oauth/token", + gateway_auth_headers={"X-Client-Id": "client"}, + gateway_auth_body={"grant_type": "client_credentials"}, + usage_id="ttl-expires-in-test", + num_retries=0, + ) + + token = llm._ensure_gateway_token() + + assert token == "token-expires-in" + assert llm._gateway_token_expiry == pytest.approx(1_120.0, abs=0.1) + + @patch("openhands.sdk.llm.llm_with_gateway.time.time") + @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") + def test_token_ttl_falls_back_to_default(self, mock_request, mock_time) -> None: + """Missing expires_in should fall back to default TTL.""" + mock_time.return_value = 2_000.0 + + mock_response = Mock() + mock_response.json.return_value = {"access_token": "token-default"} + mock_response.raise_for_status = Mock() + mock_request.return_value = mock_response + + llm = LLMWithGateway( + model="gpt-4", + api_key=SecretStr("key"), + base_url="https://gateway.example.com/llm/v1", + gateway_auth_url="https://gateway.example.com/oauth/token", + usage_id="ttl-default-test", + num_retries=0, + ) + + llm._ensure_gateway_token() + + assert llm._gateway_token_expiry == pytest.approx(2_300.0, abs=0.1) + + @patch("openhands.sdk.llm.llm_with_gateway.time.time") + @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") + def test_token_ttl_prefers_configured_override( + self, mock_request, mock_time + ) -> None: + """Configured TTL should override expires_in from response.""" + mock_time.return_value = 3_000.0 + + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "token-override", + "expires_in": 3_600, + } + mock_response.raise_for_status = Mock() + mock_request.return_value = mock_response + + llm = LLMWithGateway( + model="gpt-4", + api_key=SecretStr("key"), + base_url="https://gateway.example.com/llm/v1", + gateway_auth_url="https://gateway.example.com/oauth/token", + gateway_auth_body={"grant_type": "client_credentials"}, + gateway_auth_token_ttl=45, + usage_id="ttl-override-test", + num_retries=0, + ) + + llm._ensure_gateway_token() + + assert llm._gateway_token_expiry == pytest.approx(3_045.0, abs=0.1) + class TestTemplateReplacement: """Test template variable replacement.""" @@ -300,6 +387,46 @@ def test_full_gateway_flow( assert isinstance(response.message.content[0], TextContent) assert response.message.content[0].text == "Hello from gateway!" + @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") + @patch("openhands.sdk.llm.llm.litellm_completion") + def test_gateway_headers_merge_with_extended_thinking( + self, mock_completion, mock_request, mock_gateway_auth_response + ): + """Gateway headers should merge with extended thinking defaults.""" + mock_oauth_response = Mock() + mock_oauth_response.json.return_value = mock_gateway_auth_response + mock_oauth_response.raise_for_status = Mock() + mock_request.return_value = mock_oauth_response + + mock_completion.return_value = create_mock_litellm_response( + content="extended thinking response" + ) + + llm = LLMWithGateway( + model="claude-sonnet-4-5-latest", + api_key=SecretStr("test-api-key"), + base_url="https://gateway.example.com/llm/v1", + gateway_auth_url="https://gateway.example.com/oauth/token", + gateway_auth_headers={ + "X-Client-Id": "client123", + "X-Client-Secret": "secret456", + }, + gateway_auth_body={"grant_type": "client_credentials"}, + custom_headers={"X-Gateway-Key": "gateway789"}, + extended_thinking_budget=512, + usage_id="extended-thinking-test", + num_retries=0, + ) + + messages = [Message(role="user", content=[TextContent(text="Hello")])] + llm.completion(messages) + + completion_kwargs = mock_completion.call_args[1] + headers = completion_kwargs["extra_headers"] + assert headers["Authorization"] == "Bearer test-gateway-token-12345" + assert headers["X-Gateway-Key"] == "gateway789" + assert headers["anthropic-beta"] == "interleaved-thinking-2025-05-14" + def test_gateway_disabled_when_no_config(self): """Test that gateway logic is skipped when not configured.""" llm = LLMWithGateway( From 28eb5a1b604efc24d496d51313d8f12080a38d9b Mon Sep 17 00:00:00 2001 From: Alona King Date: Thu, 30 Oct 2025 18:59:24 -0400 Subject: [PATCH 03/10] feat: pass custom_llm_provider to litellm calls Add custom_llm_provider parameter to both litellm_completion and litellm_responses calls to support custom provider configurations. --- openhands-sdk/openhands/sdk/llm/llm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 1972a2005f..a09cc71c63 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -598,6 +598,7 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: else None, api_base=self.base_url, api_version=self.api_version, + custom_llm_provider=self.custom_llm_provider, timeout=self.timeout, drop_params=self.drop_params, seed=self.seed, @@ -666,6 +667,7 @@ def _transport_call( api_key=self.api_key.get_secret_value() if self.api_key else None, base_url=self.base_url, api_version=self.api_version, + custom_llm_provider=self.custom_llm_provider, timeout=self.timeout, drop_params=self.drop_params, seed=self.seed, From ec34c9359090f0fdb9d40bdf90aa73acc0a4bc94 Mon Sep 17 00:00:00 2001 From: Alona King Date: Fri, 31 Oct 2025 10:52:03 -0400 Subject: [PATCH 04/10] Add ssl_verify support for gateway LLM calls --- openhands-sdk/openhands/sdk/llm/llm.py | 9 +++++++++ openhands-sdk/openhands/sdk/llm/llm_with_gateway.py | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index a09cc71c63..9d8b01fc1f 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -162,6 +162,14 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): ) ollama_base_url: str | None = Field(default=None) + ssl_verify: bool | str | None = Field( + default=None, + description=( + "TLS verification forwarded to LiteLLM; " + "set to False when corporate proxies break certificate chains." + ), + ) + drop_params: bool = Field(default=True) modify_params: bool = Field( default=True, @@ -669,6 +677,7 @@ def _transport_call( api_version=self.api_version, custom_llm_provider=self.custom_llm_provider, timeout=self.timeout, + ssl_verify=self.ssl_verify, drop_params=self.drop_params, seed=self.seed, messages=messages, diff --git a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py index 8b425b6b43..0285c16f19 100644 --- a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py +++ b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py @@ -214,6 +214,10 @@ def _refresh_gateway_token(self) -> str: f"Fetching gateway token from {self.gateway_auth_url} (method={method})" ) + request_kwargs: dict[str, Any] = {} + if self.ssl_verify is not None: + request_kwargs["verify"] = self.ssl_verify + try: response = httpx.request( method, @@ -221,6 +225,7 @@ def _refresh_gateway_token(self) -> str: headers=headers if isinstance(headers, dict) else None, json=body if isinstance(body, dict) else None, timeout=self.timeout or 30, + **request_kwargs, ) response.raise_for_status() except Exception as exc: From 1029d4f64d857a9014d6bcbf4f63eeb3998403a6 Mon Sep 17 00:00:00 2001 From: Alona King Date: Fri, 31 Oct 2025 13:44:40 -0400 Subject: [PATCH 05/10] more tls fixes --- openhands-sdk/openhands/sdk/llm/llm.py | 144 ++++++++++++++----------- 1 file changed, 80 insertions(+), 64 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 9d8b01fc1f..8c2207764b 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -591,34 +591,35 @@ def responses( def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: final_kwargs = {**call_kwargs, **retry_kwargs} with self._litellm_modify_params_ctx(self.modify_params): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - typed_input: ResponseInputParam | str = ( - cast(ResponseInputParam, input_items) if input_items else "" - ) - ret = litellm_responses( - model=self.model, - input=typed_input, - instructions=instructions, - tools=resp_tools, - api_key=self.api_key.get_secret_value() - if self.api_key - else None, - api_base=self.base_url, - api_version=self.api_version, - custom_llm_provider=self.custom_llm_provider, - timeout=self.timeout, - drop_params=self.drop_params, - seed=self.seed, - **final_kwargs, - ) - assert isinstance(ret, ResponsesAPIResponse), ( - f"Expected ResponsesAPIResponse, got {type(ret)}" - ) - # telemetry (latency, cost). Token usage mapping we handle after. - assert self._telemetry is not None - self._telemetry.on_response(ret) - return ret + with self._litellm_ssl_verify_ctx(): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + typed_input: ResponseInputParam | str = ( + cast(ResponseInputParam, input_items) if input_items else "" + ) + ret = litellm_responses( + model=self.model, + input=typed_input, + instructions=instructions, + tools=resp_tools, + api_key=self.api_key.get_secret_value() + if self.api_key + else None, + api_base=self.base_url, + api_version=self.api_version, + custom_llm_provider=self.custom_llm_provider, + timeout=self.timeout, + drop_params=self.drop_params, + seed=self.seed, + **final_kwargs, + ) + assert isinstance(ret, ResponsesAPIResponse), ( + f"Expected ResponsesAPIResponse, got {type(ret)}" + ) + # telemetry (latency, cost). Token usage mapping we handle after. + assert self._telemetry is not None + self._telemetry.on_response(ret) + return ret try: resp: ResponsesAPIResponse = _one_attempt() @@ -651,42 +652,45 @@ def _transport_call( ) -> ModelResponse: # litellm.modify_params is GLOBAL; guard it for thread-safety with self._litellm_modify_params_ctx(self.modify_params): - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", category=DeprecationWarning, module="httpx.*" - ) - warnings.filterwarnings( - "ignore", - message=r".*content=.*upload.*", - category=DeprecationWarning, - ) - warnings.filterwarnings( - "ignore", - message=r"There is no current event loop", - category=DeprecationWarning, - ) - warnings.filterwarnings( - "ignore", - category=UserWarning, - ) - # Some providers need renames handled in _normalize_call_kwargs. - ret = litellm_completion( - model=self.model, - api_key=self.api_key.get_secret_value() if self.api_key else None, - base_url=self.base_url, - api_version=self.api_version, - custom_llm_provider=self.custom_llm_provider, - timeout=self.timeout, - ssl_verify=self.ssl_verify, - drop_params=self.drop_params, - seed=self.seed, - messages=messages, - **kwargs, - ) - assert isinstance(ret, ModelResponse), ( - f"Expected ModelResponse, got {type(ret)}" - ) - return ret + with self._litellm_ssl_verify_ctx(): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=DeprecationWarning, module="httpx.*" + ) + warnings.filterwarnings( + "ignore", + message=r".*content=.*upload.*", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"There is no current event loop", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + category=UserWarning, + ) + # Some providers need renames handled in _normalize_call_kwargs. + ret = litellm_completion( + model=self.model, + api_key=self.api_key.get_secret_value() + if self.api_key + else None, + base_url=self.base_url, + api_version=self.api_version, + custom_llm_provider=self.custom_llm_provider, + timeout=self.timeout, + ssl_verify=self.ssl_verify, + drop_params=self.drop_params, + seed=self.seed, + messages=messages, + **kwargs, + ) + assert isinstance(ret, ModelResponse), ( + f"Expected ModelResponse, got {type(ret)}" + ) + return ret @contextmanager def _litellm_modify_params_ctx(self, flag: bool): @@ -697,6 +701,18 @@ def _litellm_modify_params_ctx(self, flag: bool): finally: litellm.modify_params = old + @contextmanager + def _litellm_ssl_verify_ctx(self): + if self.ssl_verify is None: + yield + return + old = getattr(litellm, "ssl_verify", None) + try: + litellm.ssl_verify = self.ssl_verify + yield + finally: + litellm.ssl_verify = old + # ========================================================================= # Capabilities, formatting, and info # ========================================================================= From 10fdfbcb77ed9dcd39363cf6e17f66db074a7b4e Mon Sep 17 00:00:00 2001 From: Alona King Date: Fri, 31 Oct 2025 14:55:45 -0400 Subject: [PATCH 06/10] simply PR --- .../openhands/sdk/llm/llm_with_gateway.py | 285 +----------- tests/sdk/llm/test_llm_with_gateway.py | 434 +++--------------- 2 files changed, 63 insertions(+), 656 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py index 0285c16f19..d6b63912f7 100644 --- a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py +++ b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py @@ -1,24 +1,17 @@ """LLM subclass with enterprise gateway support. This module provides LLMWithGateway, which extends the base LLM class to support -OAuth 2.0 authentication flows and custom headers for enterprise API gateways. +custom headers for enterprise API gateways. """ from __future__ import annotations -import threading -import time from typing import Any -import httpx from litellm.types.utils import ModelResponse -from pydantic import Field, PrivateAttr +from pydantic import Field from openhands.sdk.llm.llm import LLM -from openhands.sdk.logger import get_logger - - -logger = get_logger(__name__) __all__ = ["LLMWithGateway"] @@ -26,88 +19,17 @@ class LLMWithGateway(LLM): """LLM subclass with enterprise gateway support. - Supports OAuth 2.0 token exchange with configurable headers and bodies. - Designed for enterprise API gateways that require: - 1. Initial OAuth call to get a bearer token - 2. Bearer token included in subsequent LLM API calls - 3. Custom headers for routing/authentication - - Example usage: - llm = LLMWithGateway( - model="gpt-4", - base_url="https://gateway.company.com/llm/v1", - gateway_auth_url="https://gateway.company.com/oauth/token", - gateway_auth_headers={ - "X-Client-Id": os.environ["GATEWAY_CLIENT_ID"], - "X-Client-Secret": os.environ["GATEWAY_CLIENT_SECRET"], - }, - gateway_auth_body={"grant_type": "client_credentials"}, - custom_headers={"X-Gateway-Key": os.environ["GATEWAY_API_KEY"]}, - ) + Supports adding custom headers on each request with optional template + rendering against LLM attributes. """ - # OAuth configuration - gateway_auth_url: str | None = Field( - default=None, - description="Identity provider URL to fetch gateway tokens (OAuth endpoint).", - ) - gateway_auth_method: str = Field( - default="POST", - description="HTTP method for identity provider requests.", - ) - gateway_auth_headers: dict[str, str] | None = Field( - default=None, - description="Headers to include when calling the identity provider.", - ) - gateway_auth_body: dict[str, Any] | None = Field( - default=None, - description="JSON body to include when calling the identity provider.", - ) - gateway_auth_token_path: str = Field( - default="access_token", - description=( - "Dot-notation path to the token in the OAuth response " - "(e.g., 'access_token' or 'data.token')." - ), - ) - gateway_auth_token_ttl: int | None = Field( - default=None, - description=( - "Token TTL in seconds. If not set, uses `expires_in` from the OAuth" - " response when available, falling back to 300s (5 minutes)." - ), - ) - - # Token header configuration - gateway_token_header: str = Field( - default="Authorization", - description="Header name for the gateway token (defaults to 'Authorization').", - ) - gateway_token_prefix: str = Field( - default="Bearer ", - description="Prefix prepended to the token (e.g., 'Bearer ').", - ) - - # Custom headers for all requests custom_headers: dict[str, str] | None = Field( default=None, description="Custom headers to include with every LLM request.", ) - # Private fields for token management - _gateway_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) - _gateway_token: str | None = PrivateAttr(default=None) - _gateway_token_expiry: float | None = PrivateAttr(default=None) - - def model_post_init(self, __context: Any) -> None: - """Initialize private fields after model validation.""" - super().model_post_init(__context) - self._gateway_lock = threading.Lock() - self._gateway_token = None - self._gateway_token_expiry = None - def responses(self, *args, **kwargs): - """Override to inject gateway authentication before calling LiteLLM.""" + """Override to inject gateway headers before calling LiteLLM.""" self._prepare_gateway_call(kwargs) return super().responses(*args, **kwargs) @@ -119,15 +41,13 @@ def _transport_call( return super()._transport_call(messages=messages, **kwargs) def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None: - """Augment LiteLLM kwargs with gateway headers and token. + """Augment LiteLLM kwargs with gateway headers. This method: - 1. Fetches/refreshes OAuth token if needed - 2. Adds token to headers - 3. Adds custom headers - 4. Performs basic template variable replacement + 1. Adds custom headers + 2. Performs basic template variable replacement """ - if not self.gateway_auth_url and not self.custom_headers: + if not self.custom_headers: return # Start with existing headers @@ -142,144 +62,10 @@ def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None: if isinstance(rendered_headers, dict): headers.update(rendered_headers) - # Add gateway token if OAuth is configured - if self.gateway_auth_url: - token_headers = self._get_gateway_token_headers() - if token_headers: - headers.update(token_headers) - # Set headers on the call if headers: call_kwargs["extra_headers"] = headers - def _get_gateway_token_headers(self) -> dict[str, str]: - """Get headers containing the gateway token.""" - token = self._ensure_gateway_token() - if not token: - return {} - - header_name = self.gateway_token_header - prefix = self.gateway_token_prefix - value = f"{prefix}{token}" if prefix else token - return {header_name: value} - - def _ensure_gateway_token(self) -> str | None: - """Ensure we have a valid gateway token, refreshing if needed. - - Returns: - Valid gateway token, or None if gateway auth is not configured. - """ - if not self.gateway_auth_url: - return None - - # Fast path: check if current token is still valid (with 5s buffer) - now = time.time() - if ( - self._gateway_token - and self._gateway_token_expiry - and now < self._gateway_token_expiry - 5 - ): - return self._gateway_token - - # Slow path: acquire lock and refresh token - with self._gateway_lock: - # Double-check after acquiring lock - if ( - self._gateway_token - and self._gateway_token_expiry - and time.time() < self._gateway_token_expiry - 5 - ): - return self._gateway_token - - # Refresh token - return self._refresh_gateway_token() - - def _refresh_gateway_token(self) -> str: - """Fetch a new gateway token from the identity provider. - - This method is called while holding _gateway_lock. - - Returns: - Fresh gateway token. - - Raises: - Exception: If token fetch fails. - """ - assert self.gateway_auth_url is not None, "gateway_auth_url must be set" - method = self.gateway_auth_method.upper() - headers = self._render_templates(self.gateway_auth_headers or {}) - body = self._render_templates(self.gateway_auth_body or {}) - - logger.debug( - f"Fetching gateway token from {self.gateway_auth_url} (method={method})" - ) - - request_kwargs: dict[str, Any] = {} - if self.ssl_verify is not None: - request_kwargs["verify"] = self.ssl_verify - - try: - response = httpx.request( - method, - self.gateway_auth_url, - headers=headers if isinstance(headers, dict) else None, - json=body if isinstance(body, dict) else None, - timeout=self.timeout or 30, - **request_kwargs, - ) - response.raise_for_status() - except Exception as exc: - logger.error(f"Gateway auth request failed: {exc}") - raise - - try: - payload = response.json() - except Exception as exc: - logger.error(f"Failed to parse gateway auth response JSON: {exc}") - raise - - # Extract token from response - token_path = self.gateway_auth_token_path - token_value = self._extract_from_path(payload, token_path) - if not isinstance(token_value, str) or not token_value.strip(): - raise ValueError( - f"Gateway auth response did not contain token at path " - f'"{token_path}". Response: {payload}' - ) - - # Determine TTL - ttl_seconds: float | None = None - if self.gateway_auth_token_ttl is not None: - try: - ttl_seconds = float(self.gateway_auth_token_ttl) - except (TypeError, ValueError): # pragma: no cover - defensive - logger.warning( - "Configured gateway_auth_token_ttl is not numeric; falling back" - ) - ttl_seconds = None - else: - expires_in = None - if isinstance(payload, dict): - expires_in = payload.get("expires_in") - if expires_in is not None: - try: - ttl_seconds = float(expires_in) - except (TypeError, ValueError): - logger.warning( - "Invalid expires_in value %r from gateway; using default", - expires_in, - ) - - if ttl_seconds is None or ttl_seconds <= 0: - ttl_seconds = 300.0 - - # Update cache - self._gateway_token = token_value.strip() - self._gateway_token_expiry = time.time() + max(ttl_seconds, 1.0) - - logger.info(f"Gateway token refreshed successfully (expires in {ttl_seconds}s)") - return self._gateway_token - def _render_templates(self, value: Any) -> Any: """Replace template variables in strings with actual values. @@ -314,56 +100,3 @@ def _render_templates(self, value: Any) -> Any: return [self._render_templates(v) for v in value] return value - - @staticmethod - def _extract_from_path(payload: Any, path: str) -> Any: - """Extract a value from nested dict/list using dot notation. - - Examples: - _extract_from_path({"a": {"b": "value"}}, "a.b") -> "value" - _extract_from_path({"data": [{"token": "x"}]}, "data.0.token") -> "x" - - Args: - payload: Dict or list to traverse. - path: Dot-separated path (e.g., "data.token" or "items.0.value"). - - Returns: - Value at the specified path. - - Raises: - ValueError: If path cannot be traversed. - """ - current = payload - if not path: - return current - - for part in path.split("."): - if isinstance(current, dict): - current = current.get(part) - if current is None: - raise ValueError( - f'Key "{part}" not found in response ' - f'while traversing path "{path}".' - ) - elif isinstance(current, list): - try: - index = int(part) - except (ValueError, TypeError): - raise ValueError( - f'Invalid list index "{part}" ' - f'while traversing response path "{path}".' - ) from None - try: - current = current[index] - except (IndexError, TypeError): - raise ValueError( - f"Index {index} out of range " - f'while traversing response path "{path}".' - ) from None - else: - raise ValueError( - f'Cannot traverse path "{path}"; ' - f'segment "{part}" not found or not accessible.' - ) - - return current diff --git a/tests/sdk/llm/test_llm_with_gateway.py b/tests/sdk/llm/test_llm_with_gateway.py index f0f4a3879c..9f239b8aaf 100644 --- a/tests/sdk/llm/test_llm_with_gateway.py +++ b/tests/sdk/llm/test_llm_with_gateway.py @@ -1,280 +1,110 @@ -"""Tests for LLMWithGateway enterprise gateway support.""" +"""Tests for LLMWithGateway custom header support.""" -import time -from typing import Any -from unittest.mock import Mock, patch +from __future__ import annotations + +from unittest.mock import patch -import pytest from pydantic import SecretStr from openhands.sdk.llm import LLMWithGateway, Message, TextContent from tests.conftest import create_mock_litellm_response -@pytest.fixture -def mock_gateway_auth_response(): - """Mock OAuth response from gateway.""" - return { - "access_token": "test-gateway-token-12345", - "token_type": "Bearer", - "expires_in": 3600, - } - - -@pytest.fixture -def gateway_llm(): - """Create LLMWithGateway instance for testing.""" +def create_llm(custom_headers: dict[str, str] | None = None) -> LLMWithGateway: + """Helper to build an LLMWithGateway for tests.""" return LLMWithGateway( model="gemini-1.5-flash", api_key=SecretStr("test-api-key"), base_url="https://gateway.example.com/v1", - gateway_auth_url="https://gateway.example.com/oauth/token", - gateway_auth_headers={ - "X-Client-Id": "test-client-id", - "X-Client-Secret": "test-client-secret", - }, - gateway_auth_body={"grant_type": "client_credentials"}, - gateway_auth_token_ttl=3600, - custom_headers={"X-Custom-Key": "test-custom-value"}, + custom_headers=custom_headers, usage_id="gateway-test-llm", - num_retries=0, # Disable retries for testing + num_retries=0, ) -class TestLLMWithGatewayInit: - """Test LLMWithGateway initialization.""" +class TestInitialization: + """Basic initialization behaviour.""" - def test_init_with_gateway_config(self, gateway_llm): - """Test initialization with gateway configuration.""" - assert gateway_llm.gateway_auth_url == "https://gateway.example.com/oauth/token" - assert gateway_llm.gateway_auth_method == "POST" - assert gateway_llm.gateway_auth_headers == { - "X-Client-Id": "test-client-id", - "X-Client-Secret": "test-client-secret", - } - assert gateway_llm.gateway_auth_body == {"grant_type": "client_credentials"} - assert gateway_llm.gateway_auth_token_path == "access_token" - assert gateway_llm.gateway_auth_token_ttl == 3600 - assert gateway_llm.gateway_token_header == "Authorization" - assert gateway_llm.gateway_token_prefix == "Bearer " - assert gateway_llm.custom_headers == {"X-Custom-Key": "test-custom-value"} - - def test_init_without_gateway_config(self): - """Test initialization without gateway configuration (regular LLM).""" - llm = LLMWithGateway( - model="gpt-4", - api_key=SecretStr("test-key"), - usage_id="regular-llm", - ) - assert llm.gateway_auth_url is None + def test_defaults(self) -> None: + llm = create_llm() assert llm.custom_headers is None + def test_custom_headers_configuration(self) -> None: + headers = {"X-Custom-Key": "value"} + llm = create_llm(custom_headers=headers) + assert llm.custom_headers == headers -class TestGatewayTokenFetch: - """Test OAuth token fetching and caching.""" - - @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") - def test_fetch_token_success( - self, mock_request, gateway_llm, mock_gateway_auth_response - ): - """Test successful token fetch from gateway.""" - mock_response = Mock() - mock_response.json.return_value = mock_gateway_auth_response - mock_response.raise_for_status = Mock() - mock_request.return_value = mock_response - - token = gateway_llm._ensure_gateway_token() - - assert token == "test-gateway-token-12345" - assert gateway_llm._gateway_token == "test-gateway-token-12345" - assert gateway_llm._gateway_token_expiry is not None - assert gateway_llm._gateway_token_expiry > time.time() - - # Verify request was made correctly - mock_request.assert_called_once_with( - "POST", - "https://gateway.example.com/oauth/token", - headers={ - "X-Client-Id": "test-client-id", - "X-Client-Secret": "test-client-secret", - }, - json={"grant_type": "client_credentials"}, - timeout=30, - ) - - @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") - def test_token_caching(self, mock_request, gateway_llm, mock_gateway_auth_response): - """Test that tokens are cached and not re-fetched unnecessarily.""" - mock_response = Mock() - mock_response.json.return_value = mock_gateway_auth_response - mock_response.raise_for_status = Mock() - mock_request.return_value = mock_response - - # First call should fetch token - token1 = gateway_llm._ensure_gateway_token() - assert mock_request.call_count == 1 - - # Second call should use cached token - token2 = gateway_llm._ensure_gateway_token() - assert mock_request.call_count == 1 # Still only 1 call - assert token1 == token2 - - @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") - def test_token_refresh_when_expired( - self, mock_request, gateway_llm, mock_gateway_auth_response - ): - """Test that token is refreshed when expired.""" - mock_response = Mock() - mock_response.json.return_value = mock_gateway_auth_response - mock_response.raise_for_status = Mock() - mock_request.return_value = mock_response - - # Fetch initial token - gateway_llm._ensure_gateway_token() - assert mock_request.call_count == 1 - - # Manually expire the token - gateway_llm._gateway_token_expiry = time.time() - 10 - - # Next call should refresh - gateway_llm._ensure_gateway_token() - assert mock_request.call_count == 2 - - @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") - def test_token_fetch_missing_token(self, mock_request, gateway_llm): - """Test handling of response without token field.""" - mock_response = Mock() - mock_response.json.return_value = {"error": "invalid_client"} - mock_response.raise_for_status = Mock() - mock_request.return_value = mock_response - - with pytest.raises(ValueError, match="not found in response"): - gateway_llm._ensure_gateway_token() - - @patch("openhands.sdk.llm.llm_with_gateway.time.time") - @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") - def test_token_ttl_uses_expires_in_by_default( - self, mock_request, mock_time - ) -> None: - """Token expiry should honor expires_in when TTL override not configured.""" - mock_time.return_value = 1_000.0 - mock_response = Mock() - mock_response.json.return_value = { - "access_token": "token-expires-in", - "expires_in": 120, - } - mock_response.raise_for_status = Mock() - mock_request.return_value = mock_response - - llm = LLMWithGateway( - model="gpt-4", - api_key=SecretStr("key"), - base_url="https://gateway.example.com/llm/v1", - gateway_auth_url="https://gateway.example.com/oauth/token", - gateway_auth_headers={"X-Client-Id": "client"}, - gateway_auth_body={"grant_type": "client_credentials"}, - usage_id="ttl-expires-in-test", - num_retries=0, - ) +class TestHeaderInjection: + """Ensure custom headers are merged into completion calls.""" - token = llm._ensure_gateway_token() - - assert token == "token-expires-in" - assert llm._gateway_token_expiry == pytest.approx(1_120.0, abs=0.1) - - @patch("openhands.sdk.llm.llm_with_gateway.time.time") - @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") - def test_token_ttl_falls_back_to_default(self, mock_request, mock_time) -> None: - """Missing expires_in should fall back to default TTL.""" - mock_time.return_value = 2_000.0 - - mock_response = Mock() - mock_response.json.return_value = {"access_token": "token-default"} - mock_response.raise_for_status = Mock() - mock_request.return_value = mock_response - - llm = LLMWithGateway( - model="gpt-4", - api_key=SecretStr("key"), - base_url="https://gateway.example.com/llm/v1", - gateway_auth_url="https://gateway.example.com/oauth/token", - usage_id="ttl-default-test", - num_retries=0, + @patch("openhands.sdk.llm.llm.litellm_completion") + def test_headers_passed_to_litellm(self, mock_completion) -> None: + llm = create_llm(custom_headers={"X-Test": "value"}) + mock_completion.return_value = create_mock_litellm_response( + content="Hello!" ) - llm._ensure_gateway_token() - - assert llm._gateway_token_expiry == pytest.approx(2_300.0, abs=0.1) + messages = [Message(role="user", content=[TextContent(text="Hi")])] + response = llm.completion(messages) - @patch("openhands.sdk.llm.llm_with_gateway.time.time") - @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") - def test_token_ttl_prefers_configured_override( - self, mock_request, mock_time - ) -> None: - """Configured TTL should override expires_in from response.""" - mock_time.return_value = 3_000.0 + mock_completion.assert_called_once() + headers = mock_completion.call_args.kwargs["extra_headers"] + assert headers["X-Test"] == "value" - mock_response = Mock() - mock_response.json.return_value = { - "access_token": "token-override", - "expires_in": 3_600, - } - mock_response.raise_for_status = Mock() - mock_request.return_value = mock_response + # Ensure we still surface the underlying content. + assert response.message.content[0].text == "Hello!" - llm = LLMWithGateway( - model="gpt-4", - api_key=SecretStr("key"), - base_url="https://gateway.example.com/llm/v1", - gateway_auth_url="https://gateway.example.com/oauth/token", - gateway_auth_body={"grant_type": "client_credentials"}, - gateway_auth_token_ttl=45, - usage_id="ttl-override-test", - num_retries=0, + @patch("openhands.sdk.llm.llm.litellm_completion") + def test_headers_merge_existing_extra_headers(self, mock_completion) -> None: + llm = create_llm(custom_headers={"X-Test": "value"}) + mock_completion.return_value = create_mock_litellm_response( + content="Merged!" ) - llm._ensure_gateway_token() + messages = [Message(role="user", content=[TextContent(text="Hi")])] + llm.completion(messages, extra_headers={"Existing": "1"}) - assert llm._gateway_token_expiry == pytest.approx(3_045.0, abs=0.1) + headers = mock_completion.call_args.kwargs["extra_headers"] + assert headers["X-Test"] == "value" + assert headers["Existing"] == "1" class TestTemplateReplacement: - """Test template variable replacement.""" + """Template replacement should render against LLM fields.""" - def test_render_templates_string(self, gateway_llm): - """Test template replacement in strings.""" + def test_render_templates_string(self) -> None: + llm = create_llm() template = "Model: {{llm_model}}, URL: {{llm_base_url}}" - result = gateway_llm._render_templates(template) + result = llm._render_templates(template) assert result == "Model: gemini-1.5-flash, URL: https://gateway.example.com/v1" - def test_render_templates_dict(self, gateway_llm): - """Test template replacement in dictionaries.""" + def test_render_templates_dict(self) -> None: + llm = create_llm() template = { "model": "{{llm_model}}", "endpoint": "{{llm_base_url}}/chat", "nested": {"key": "{{llm_model}}"}, } - result = gateway_llm._render_templates(template) + result = llm._render_templates(template) assert result["model"] == "gemini-1.5-flash" assert result["endpoint"] == "https://gateway.example.com/v1/chat" assert result["nested"]["key"] == "gemini-1.5-flash" - def test_render_templates_list(self, gateway_llm): - """Test template replacement in lists.""" + def test_render_templates_list(self) -> None: + llm = create_llm() template = ["{{llm_model}}", {"url": "{{llm_base_url}}"}] - result = gateway_llm._render_templates(template) + result = llm._render_templates(template) assert result[0] == "gemini-1.5-flash" assert result[1]["url"] == "https://gateway.example.com/v1" - def test_render_templates_with_api_key(self, gateway_llm): - """Test template replacement includes API key.""" + def test_render_templates_with_api_key(self) -> None: + llm = create_llm() template = "Key: {{llm_api_key}}" - result = gateway_llm._render_templates(template) + result = llm._render_templates(template) assert result == "Key: test-api-key" - def test_render_templates_no_base_url(self): - """Test template replacement when base_url is not set.""" + def test_render_templates_no_base_url(self) -> None: llm = LLMWithGateway( model="gpt-4", api_key=SecretStr("key"), @@ -283,159 +113,3 @@ def test_render_templates_no_base_url(self): template = "URL: {{llm_base_url}}" result = llm._render_templates(template) assert result == "URL: " - - -class TestPathExtraction: - """Test nested path extraction from OAuth responses.""" - - def test_extract_simple_path(self): - """Test extraction from simple path.""" - payload = {"access_token": "token123"} - result = LLMWithGateway._extract_from_path(payload, "access_token") - assert result == "token123" - - def test_extract_nested_path(self): - """Test extraction from nested path.""" - payload = {"data": {"auth": {"token": "token456"}}} - result = LLMWithGateway._extract_from_path(payload, "data.auth.token") - assert result == "token456" - - def test_extract_from_array(self): - """Test extraction from array.""" - payload = {"tokens": [{"value": "token1"}, {"value": "token2"}]} - result = LLMWithGateway._extract_from_path(payload, "tokens.0.value") - assert result == "token1" - - def test_extract_empty_path(self): - """Test extraction with empty path returns root.""" - payload = {"key": "value"} - result = LLMWithGateway._extract_from_path(payload, "") - assert result == payload - - def test_extract_missing_key(self): - """Test extraction fails for missing key.""" - payload = {"other": "value"} - with pytest.raises(ValueError, match="not found"): - LLMWithGateway._extract_from_path(payload, "missing") - - def test_extract_invalid_array_index(self): - """Test extraction fails for invalid array index.""" - payload = {"items": ["a", "b"]} - with pytest.raises(ValueError, match="Invalid list index"): - LLMWithGateway._extract_from_path(payload, "items.invalid") - - def test_extract_array_index_out_of_range(self): - """Test extraction fails for out of range index.""" - payload = {"items": ["a", "b"]} - with pytest.raises(ValueError, match="out of range"): - LLMWithGateway._extract_from_path(payload, "items.5") - - -class TestGatewayIntegration: - """Integration tests for gateway functionality.""" - - @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") - @patch("openhands.sdk.llm.llm.litellm_completion") - def test_full_gateway_flow( - self, mock_completion, mock_request, mock_gateway_auth_response - ): - """Test complete flow: OAuth -> LLM request with headers.""" - # Setup gateway LLM - llm = LLMWithGateway( - model="gpt-4", - base_url="https://gateway.example.com/llm/v1", - gateway_auth_url="https://gateway.example.com/oauth/token", - gateway_auth_headers={ - "X-Client-Id": "client123", - "X-Client-Secret": "secret456", - }, - gateway_auth_body={"grant_type": "client_credentials"}, - custom_headers={"X-Gateway-Key": "gateway789"}, - usage_id="integration-test", - num_retries=0, - ) - - # Mock OAuth response - mock_oauth_response = Mock() - mock_oauth_response.json.return_value = mock_gateway_auth_response - mock_oauth_response.raise_for_status = Mock() - mock_request.return_value = mock_oauth_response - - # Mock LLM completion - mock_completion.return_value = create_mock_litellm_response( - content="Hello from gateway!" - ) - - # Make completion request - messages = [Message(role="user", content=[TextContent(text="Hello")])] - response = llm.completion(messages) - - # Verify OAuth was called - assert mock_request.call_count == 1 - oauth_call = mock_request.call_args - assert oauth_call[0][0] == "POST" - assert oauth_call[0][1] == "https://gateway.example.com/oauth/token" - - # Verify LLM completion was called with correct headers - assert mock_completion.call_count == 1 - completion_kwargs = mock_completion.call_args[1] - headers = completion_kwargs["extra_headers"] - assert headers["Authorization"] == "Bearer test-gateway-token-12345" - assert headers["X-Gateway-Key"] == "gateway789" - - # Verify response - assert isinstance(response.message.content[0], TextContent) - assert response.message.content[0].text == "Hello from gateway!" - - @patch("openhands.sdk.llm.llm_with_gateway.httpx.request") - @patch("openhands.sdk.llm.llm.litellm_completion") - def test_gateway_headers_merge_with_extended_thinking( - self, mock_completion, mock_request, mock_gateway_auth_response - ): - """Gateway headers should merge with extended thinking defaults.""" - mock_oauth_response = Mock() - mock_oauth_response.json.return_value = mock_gateway_auth_response - mock_oauth_response.raise_for_status = Mock() - mock_request.return_value = mock_oauth_response - - mock_completion.return_value = create_mock_litellm_response( - content="extended thinking response" - ) - - llm = LLMWithGateway( - model="claude-sonnet-4-5-latest", - api_key=SecretStr("test-api-key"), - base_url="https://gateway.example.com/llm/v1", - gateway_auth_url="https://gateway.example.com/oauth/token", - gateway_auth_headers={ - "X-Client-Id": "client123", - "X-Client-Secret": "secret456", - }, - gateway_auth_body={"grant_type": "client_credentials"}, - custom_headers={"X-Gateway-Key": "gateway789"}, - extended_thinking_budget=512, - usage_id="extended-thinking-test", - num_retries=0, - ) - - messages = [Message(role="user", content=[TextContent(text="Hello")])] - llm.completion(messages) - - completion_kwargs = mock_completion.call_args[1] - headers = completion_kwargs["extra_headers"] - assert headers["Authorization"] == "Bearer test-gateway-token-12345" - assert headers["X-Gateway-Key"] == "gateway789" - assert headers["anthropic-beta"] == "interleaved-thinking-2025-05-14" - - def test_gateway_disabled_when_no_config(self): - """Test that gateway logic is skipped when not configured.""" - llm = LLMWithGateway( - model="gpt-4", - api_key=SecretStr("key"), - usage_id="no-gateway-test", - ) - - # Should not fail, just act like regular LLM - kwargs: dict[str, Any] = {} - llm._prepare_gateway_call(kwargs) - assert "extra_headers" not in kwargs From 00e9ae8fe225e5af65340d476bf95ae69da855ec Mon Sep 17 00:00:00 2001 From: Alona King Date: Fri, 31 Oct 2025 15:10:35 -0400 Subject: [PATCH 07/10] fix pre-commit --- openhands-sdk/openhands/sdk/llm/llm_with_gateway.py | 1 + tests/sdk/llm/test_llm_with_gateway.py | 12 +++++------- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py index d6b63912f7..c03c10ea97 100644 --- a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py +++ b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py @@ -13,6 +13,7 @@ from openhands.sdk.llm.llm import LLM + __all__ = ["LLMWithGateway"] diff --git a/tests/sdk/llm/test_llm_with_gateway.py b/tests/sdk/llm/test_llm_with_gateway.py index 9f239b8aaf..1265fd3914 100644 --- a/tests/sdk/llm/test_llm_with_gateway.py +++ b/tests/sdk/llm/test_llm_with_gateway.py @@ -41,9 +41,7 @@ class TestHeaderInjection: @patch("openhands.sdk.llm.llm.litellm_completion") def test_headers_passed_to_litellm(self, mock_completion) -> None: llm = create_llm(custom_headers={"X-Test": "value"}) - mock_completion.return_value = create_mock_litellm_response( - content="Hello!" - ) + mock_completion.return_value = create_mock_litellm_response(content="Hello!") messages = [Message(role="user", content=[TextContent(text="Hi")])] response = llm.completion(messages) @@ -53,14 +51,14 @@ def test_headers_passed_to_litellm(self, mock_completion) -> None: assert headers["X-Test"] == "value" # Ensure we still surface the underlying content. - assert response.message.content[0].text == "Hello!" + content = response.message.content[0] + assert isinstance(content, TextContent) + assert content.text == "Hello!" @patch("openhands.sdk.llm.llm.litellm_completion") def test_headers_merge_existing_extra_headers(self, mock_completion) -> None: llm = create_llm(custom_headers={"X-Test": "value"}) - mock_completion.return_value = create_mock_litellm_response( - content="Merged!" - ) + mock_completion.return_value = create_mock_litellm_response(content="Merged!") messages = [Message(role="user", content=[TextContent(text="Hi")])] llm.completion(messages, extra_headers={"Existing": "1"}) From c52bb3e11ad9cd369e41eb02e4d512e2596625bd Mon Sep 17 00:00:00 2001 From: Alona King Date: Fri, 31 Oct 2025 15:21:32 -0400 Subject: [PATCH 08/10] nit --- openhands-sdk/openhands/sdk/llm/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 8c2207764b..84ce4e3201 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -616,7 +616,7 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: assert isinstance(ret, ResponsesAPIResponse), ( f"Expected ResponsesAPIResponse, got {type(ret)}" ) - # telemetry (latency, cost). Token usage mapping we handle after. + # telemetry (latency, cost). Token usage handled after. assert self._telemetry is not None self._telemetry.on_response(ret) return ret From 4b9b99dc4a53525020b6b816f8d1b05b015a2ec8 Mon Sep 17 00:00:00 2001 From: Alona King Date: Sat, 1 Nov 2025 08:39:06 -0400 Subject: [PATCH 09/10] hardening gateway header/tls handling --- openhands-sdk/openhands/sdk/llm/llm.py | 198 ++++++++++-------- .../openhands/sdk/llm/llm_with_gateway.py | 103 +++++---- tests/sdk/llm/test_llm_with_gateway.py | 42 ++++ 3 files changed, 213 insertions(+), 130 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 84ce4e3201..ded2e62639 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -454,15 +454,19 @@ def completion( has_tools_flag = bool(cc_tools) and use_native_fc # Behavior-preserving: delegate to select_chat_options call_kwargs = select_chat_options(self, kwargs, has_tools=has_tools_flag) + call_kwargs = self._prepare_request_kwargs(call_kwargs) # 4) optional request logging context (kept small) assert self._telemetry is not None log_ctx = None if self._telemetry.log_enabled: + sanitized_kwargs = { + k: v for k, v in call_kwargs.items() if k != "extra_headers" + } log_ctx = { "messages": formatted_messages[:], # already simple dicts "tools": tools, - "kwargs": {k: v for k, v in call_kwargs.items()}, + "kwargs": sanitized_kwargs, "context_window": self.max_input_tokens or 0, } if tools and not use_native_fc: @@ -481,7 +485,7 @@ def completion( def _one_attempt(**retry_kwargs) -> ModelResponse: assert self._telemetry is not None # Merge retry-modified kwargs (like temperature) with call_kwargs - final_kwargs = {**call_kwargs, **retry_kwargs} + final_kwargs = self._prepare_request_kwargs({**call_kwargs, **retry_kwargs}) resp = self._transport_call(messages=formatted_messages, **final_kwargs) raw_resp: ModelResponse | None = None if use_mock_tools: @@ -565,16 +569,20 @@ def responses( call_kwargs = select_responses_options( self, kwargs, include=include, store=store ) + call_kwargs = self._prepare_request_kwargs(call_kwargs) # Optional request logging assert self._telemetry is not None log_ctx = None if self._telemetry.log_enabled: + sanitized_kwargs = { + k: v for k, v in call_kwargs.items() if k != "extra_headers" + } log_ctx = { "llm_path": "responses", "input": input_items[:], "tools": tools, - "kwargs": {k: v for k, v in call_kwargs.items()}, + "kwargs": sanitized_kwargs, "context_window": self.max_input_tokens or 0, } self._telemetry.on_request(log_ctx=log_ctx) @@ -589,37 +597,37 @@ def responses( retry_listener=self.retry_listener, ) def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: - final_kwargs = {**call_kwargs, **retry_kwargs} + final_kwargs = self._prepare_request_kwargs({**call_kwargs, **retry_kwargs}) with self._litellm_modify_params_ctx(self.modify_params): - with self._litellm_ssl_verify_ctx(): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - typed_input: ResponseInputParam | str = ( - cast(ResponseInputParam, input_items) if input_items else "" - ) - ret = litellm_responses( - model=self.model, - input=typed_input, - instructions=instructions, - tools=resp_tools, - api_key=self.api_key.get_secret_value() - if self.api_key - else None, - api_base=self.base_url, - api_version=self.api_version, - custom_llm_provider=self.custom_llm_provider, - timeout=self.timeout, - drop_params=self.drop_params, - seed=self.seed, - **final_kwargs, - ) - assert isinstance(ret, ResponsesAPIResponse), ( - f"Expected ResponsesAPIResponse, got {type(ret)}" - ) - # telemetry (latency, cost). Token usage handled after. - assert self._telemetry is not None - self._telemetry.on_response(ret) - return ret + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + typed_input: ResponseInputParam | str = ( + cast(ResponseInputParam, input_items) if input_items else "" + ) + ret = litellm_responses( + model=self.model, + input=typed_input, + instructions=instructions, + tools=resp_tools, + api_key=self.api_key.get_secret_value() + if self.api_key + else None, + api_base=self.base_url, + api_version=self.api_version, + custom_llm_provider=self.custom_llm_provider, + timeout=self.timeout, + ssl_verify=self.ssl_verify, + drop_params=self.drop_params, + seed=self.seed, + **final_kwargs, + ) + assert isinstance(ret, ResponsesAPIResponse), ( + f"Expected ResponsesAPIResponse, got {type(ret)}" + ) + # telemetry (latency, cost). Token usage handled after. + assert self._telemetry is not None + self._telemetry.on_response(ret) + return ret try: resp: ResponsesAPIResponse = _one_attempt() @@ -647,50 +655,52 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: # ========================================================================= # Transport + helpers # ========================================================================= + def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any]: + """Hook for subclasses to adjust final LiteLLM kwargs.""" + + return call_kwargs + def _transport_call( self, *, messages: list[dict[str, Any]], **kwargs ) -> ModelResponse: # litellm.modify_params is GLOBAL; guard it for thread-safety with self._litellm_modify_params_ctx(self.modify_params): - with self._litellm_ssl_verify_ctx(): - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", category=DeprecationWarning, module="httpx.*" - ) - warnings.filterwarnings( - "ignore", - message=r".*content=.*upload.*", - category=DeprecationWarning, - ) - warnings.filterwarnings( - "ignore", - message=r"There is no current event loop", - category=DeprecationWarning, - ) - warnings.filterwarnings( - "ignore", - category=UserWarning, - ) - # Some providers need renames handled in _normalize_call_kwargs. - ret = litellm_completion( - model=self.model, - api_key=self.api_key.get_secret_value() - if self.api_key - else None, - base_url=self.base_url, - api_version=self.api_version, - custom_llm_provider=self.custom_llm_provider, - timeout=self.timeout, - ssl_verify=self.ssl_verify, - drop_params=self.drop_params, - seed=self.seed, - messages=messages, - **kwargs, - ) - assert isinstance(ret, ModelResponse), ( - f"Expected ModelResponse, got {type(ret)}" - ) - return ret + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=DeprecationWarning, module="httpx.*" + ) + warnings.filterwarnings( + "ignore", + message=r".*content=.*upload.*", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"There is no current event loop", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + category=UserWarning, + ) + # Some providers need renames handled in _normalize_call_kwargs. + ret = litellm_completion( + model=self.model, + api_key=self.api_key.get_secret_value() if self.api_key else None, + base_url=self.base_url, + api_version=self.api_version, + custom_llm_provider=self.custom_llm_provider, + timeout=self.timeout, + ssl_verify=self.ssl_verify, + drop_params=self.drop_params, + seed=self.seed, + messages=messages, + **kwargs, + ) + assert isinstance(ret, ModelResponse), ( + f"Expected ModelResponse, got {type(ret)}" + ) + return ret @contextmanager def _litellm_modify_params_ctx(self, flag: bool): @@ -701,18 +711,6 @@ def _litellm_modify_params_ctx(self, flag: bool): finally: litellm.modify_params = old - @contextmanager - def _litellm_ssl_verify_ctx(self): - if self.ssl_verify is None: - yield - return - old = getattr(litellm, "ssl_verify", None) - try: - litellm.ssl_verify = self.ssl_verify - yield - finally: - litellm.ssl_verify = old - # ========================================================================= # Capabilities, formatting, and info # ========================================================================= @@ -955,6 +953,7 @@ def load_from_json(cls, json_path: str) -> LLM: @classmethod def load_from_env(cls, prefix: str = "LLM_") -> LLM: TRUTHY = {"true", "1", "yes", "on"} + FALSY = {"false", "0", "no", "off"} def _unwrap_type(t: Any) -> Any: origin = get_origin(t) @@ -963,20 +962,33 @@ def _unwrap_type(t: Any) -> Any: args = [a for a in get_args(t) if a is not type(None)] return args[0] if args else t - def _cast_value(raw: str, t: Any) -> Any: - t = _unwrap_type(t) + def _cast_value(field_name: str, raw: str, annotation: Any) -> Any: + stripped = raw.strip() + lowered = stripped.lower() + if field_name == "ssl_verify": + if lowered in TRUTHY: + return True + if lowered in FALSY: + return False + return stripped + + t = _unwrap_type(annotation) if t is SecretStr: - return SecretStr(raw) + return SecretStr(stripped) if t is bool: - return raw.lower() in TRUTHY + if lowered in TRUTHY: + return True + if lowered in FALSY: + return False + return None if t is int: try: - return int(raw) + return int(stripped) except ValueError: return None if t is float: try: - return float(raw) + return float(stripped) except ValueError: return None origin = get_origin(t) @@ -984,10 +996,10 @@ def _cast_value(raw: str, t: Any) -> Any: isinstance(t, type) and issubclass(t, BaseModel) ): try: - return json.loads(raw) + return json.loads(stripped) except Exception: pass - return raw + return stripped data: dict[str, Any] = {} fields: dict[str, Any] = { @@ -1002,7 +1014,7 @@ def _cast_value(raw: str, t: Any) -> Any: field_name = key[len(prefix) :].lower() if field_name not in fields: continue - v = _cast_value(value, fields[field_name]) + v = _cast_value(field_name, value, fields[field_name]) if v is not None: data[field_name] = v return cls(**data) diff --git a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py index c03c10ea97..d6ad61ff55 100644 --- a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py +++ b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py @@ -6,22 +6,28 @@ from __future__ import annotations +from collections.abc import Mapping from typing import Any -from litellm.types.utils import ModelResponse from pydantic import Field from openhands.sdk.llm.llm import LLM +from openhands.sdk.logger import get_logger __all__ = ["LLMWithGateway"] +logger = get_logger(__name__) + + class LLMWithGateway(LLM): """LLM subclass with enterprise gateway support. Supports adding custom headers on each request with optional template - rendering against LLM attributes. + rendering against LLM attributes. If you include ``{{llm_api_key}}`` in a + header value, the decrypted API key is sent to the gateway—treat the + gateway as a trusted recipient and avoid logging those headers. """ custom_headers: dict[str, str] | None = Field( @@ -29,43 +35,66 @@ class LLMWithGateway(LLM): description="Custom headers to include with every LLM request.", ) - def responses(self, *args, **kwargs): - """Override to inject gateway headers before calling LiteLLM.""" - self._prepare_gateway_call(kwargs) - return super().responses(*args, **kwargs) - - def _transport_call( - self, *, messages: list[dict[str, Any]], **kwargs - ) -> ModelResponse: - """Inject gateway headers just before delegating to LiteLLM.""" - self._prepare_gateway_call(kwargs) - return super()._transport_call(messages=messages, **kwargs) - - def _prepare_gateway_call(self, call_kwargs: dict[str, Any]) -> None: - """Augment LiteLLM kwargs with gateway headers. + def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any]: + prepared = dict(super()._prepare_request_kwargs(call_kwargs)) - This method: - 1. Adds custom headers - 2. Performs basic template variable replacement - """ if not self.custom_headers: - return - - # Start with existing headers - headers: dict[str, str] = {} - existing_headers = call_kwargs.get("extra_headers") - if isinstance(existing_headers, dict): - headers.update(existing_headers) - - # Add custom headers (with template replacement) - if self.custom_headers: - rendered_headers = self._render_templates(self.custom_headers) - if isinstance(rendered_headers, dict): - headers.update(rendered_headers) - - # Set headers on the call - if headers: - call_kwargs["extra_headers"] = headers + return prepared + + rendered = self._render_templates(self.custom_headers) + if not isinstance(rendered, dict): + return prepared + + existing = prepared.get("extra_headers") + base_headers: dict[str, Any] + if isinstance(existing, Mapping): + base_headers = dict(existing) + elif existing is None: + base_headers = {} + else: + base_headers = {} + + merged, collisions = self._merge_headers(base_headers, rendered) + for header, old_val, new_val in collisions: + logger.warning( + "LLMWithGateway overriding header %s (existing=%r, new=%r)", + header, + old_val, + new_val, + ) + + if merged: + prepared["extra_headers"] = merged + + return prepared + + @staticmethod + def _merge_headers( + existing: dict[str, Any], new_headers: dict[str, Any] + ) -> tuple[dict[str, Any], list[tuple[str, Any, Any]]]: + """Merge header dictionaries case-insensitively. + + Returns the merged headers and a list of collisions where an existing + header was replaced with a different value. + """ + + merged = dict(existing) + lower_map = {k.lower(): k for k in merged} + collisions: list[tuple[str, Any, Any]] = [] + + for key, value in new_headers.items(): + lower = key.lower() + if lower in lower_map: + canonical = lower_map[lower] + old_value = merged[canonical] + if old_value != value: + collisions.append((canonical, old_value, value)) + merged[canonical] = value + else: + merged[key] = value + lower_map[lower] = key + + return merged, collisions def _render_templates(self, value: Any) -> Any: """Replace template variables in strings with actual values. diff --git a/tests/sdk/llm/test_llm_with_gateway.py b/tests/sdk/llm/test_llm_with_gateway.py index 1265fd3914..5da176af89 100644 --- a/tests/sdk/llm/test_llm_with_gateway.py +++ b/tests/sdk/llm/test_llm_with_gateway.py @@ -4,6 +4,9 @@ from unittest.mock import patch +from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse +from openai.types.responses.response_output_message import ResponseOutputMessage +from openai.types.responses.response_output_text import ResponseOutputText from pydantic import SecretStr from openhands.sdk.llm import LLMWithGateway, Message, TextContent @@ -22,6 +25,35 @@ def create_llm(custom_headers: dict[str, str] | None = None) -> LLMWithGateway: ) +def make_responses_api_response(text: str) -> ResponsesAPIResponse: + """Construct a minimal ResponsesAPIResponse for testing.""" + + message = ResponseOutputMessage.model_construct( + id="msg", + type="message", + role="assistant", + status="completed", + content=[ # type: ignore[arg-type] + ResponseOutputText(type="output_text", text=text, annotations=[]) + ], + ) + + usage = ResponseAPIUsage(input_tokens=1, output_tokens=1, total_tokens=2) + + return ResponsesAPIResponse( + id="resp", + created_at=0, + output=[message], # type: ignore[arg-type] + parallel_tool_calls=False, + tool_choice="auto", + top_p=None, + tools=[], + usage=usage, + instructions=None, + status="completed", + ) + + class TestInitialization: """Basic initialization behaviour.""" @@ -67,6 +99,16 @@ def test_headers_merge_existing_extra_headers(self, mock_completion) -> None: assert headers["X-Test"] == "value" assert headers["Existing"] == "1" + @patch("openhands.sdk.llm.llm.litellm_responses") + def test_responses_headers_passed_to_litellm(self, mock_responses) -> None: + llm = create_llm(custom_headers={"X-Test": "value"}) + mock_responses.return_value = make_responses_api_response("ok") + + llm.responses([Message(role="user", content=[TextContent(text="Hi")])]) + + headers = mock_responses.call_args.kwargs["extra_headers"] + assert headers["X-Test"] == "value" + class TestTemplateReplacement: """Template replacement should render against LLM fields.""" From aa373aa1b4a53cefed40ea1264fa1b84c9ff491b Mon Sep 17 00:00:00 2001 From: Alona King Date: Mon, 3 Nov 2025 10:30:46 -0500 Subject: [PATCH 10/10] remove rendering --- .../openhands/sdk/llm/llm_with_gateway.py | 47 ++----------------- tests/sdk/llm/test_llm_with_gateway.py | 45 ------------------ 2 files changed, 3 insertions(+), 89 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py index d6ad61ff55..bdbde418a3 100644 --- a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py +++ b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py @@ -24,10 +24,8 @@ class LLMWithGateway(LLM): """LLM subclass with enterprise gateway support. - Supports adding custom headers on each request with optional template - rendering against LLM attributes. If you include ``{{llm_api_key}}`` in a - header value, the decrypted API key is sent to the gateway—treat the - gateway as a trusted recipient and avoid logging those headers. + Supports adding static custom headers on each request. Take care not to include + raw secrets in headers unless the gateway is trusted and headers are never logged. """ custom_headers: dict[str, str] | None = Field( @@ -41,10 +39,6 @@ def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any] if not self.custom_headers: return prepared - rendered = self._render_templates(self.custom_headers) - if not isinstance(rendered, dict): - return prepared - existing = prepared.get("extra_headers") base_headers: dict[str, Any] if isinstance(existing, Mapping): @@ -54,7 +48,7 @@ def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any] else: base_headers = {} - merged, collisions = self._merge_headers(base_headers, rendered) + merged, collisions = self._merge_headers(base_headers, self.custom_headers) for header, old_val, new_val in collisions: logger.warning( "LLMWithGateway overriding header %s (existing=%r, new=%r)", @@ -95,38 +89,3 @@ def _merge_headers( lower_map[lower] = key return merged, collisions - - def _render_templates(self, value: Any) -> Any: - """Replace template variables in strings with actual values. - - Supports: - - {{llm_model}} -> self.model - - {{llm_base_url}} -> self.base_url - - {{llm_api_key}} -> self.api_key (if set) - - Args: - value: String, dict, list, or other value to render. - - Returns: - Value with templates replaced. - """ - if isinstance(value, str): - replacements: dict[str, str] = { - "{{llm_model}}": self.model, - "{{llm_base_url}}": self.base_url or "", - } - if self.api_key: - replacements["{{llm_api_key}}"] = self.api_key.get_secret_value() - - result = value - for placeholder, actual in replacements.items(): - result = result.replace(placeholder, actual) - return result - - if isinstance(value, dict): - return {k: self._render_templates(v) for k, v in value.items()} - - if isinstance(value, list): - return [self._render_templates(v) for v in value] - - return value diff --git a/tests/sdk/llm/test_llm_with_gateway.py b/tests/sdk/llm/test_llm_with_gateway.py index 5da176af89..94449c5f02 100644 --- a/tests/sdk/llm/test_llm_with_gateway.py +++ b/tests/sdk/llm/test_llm_with_gateway.py @@ -108,48 +108,3 @@ def test_responses_headers_passed_to_litellm(self, mock_responses) -> None: headers = mock_responses.call_args.kwargs["extra_headers"] assert headers["X-Test"] == "value" - - -class TestTemplateReplacement: - """Template replacement should render against LLM fields.""" - - def test_render_templates_string(self) -> None: - llm = create_llm() - template = "Model: {{llm_model}}, URL: {{llm_base_url}}" - result = llm._render_templates(template) - assert result == "Model: gemini-1.5-flash, URL: https://gateway.example.com/v1" - - def test_render_templates_dict(self) -> None: - llm = create_llm() - template = { - "model": "{{llm_model}}", - "endpoint": "{{llm_base_url}}/chat", - "nested": {"key": "{{llm_model}}"}, - } - result = llm._render_templates(template) - assert result["model"] == "gemini-1.5-flash" - assert result["endpoint"] == "https://gateway.example.com/v1/chat" - assert result["nested"]["key"] == "gemini-1.5-flash" - - def test_render_templates_list(self) -> None: - llm = create_llm() - template = ["{{llm_model}}", {"url": "{{llm_base_url}}"}] - result = llm._render_templates(template) - assert result[0] == "gemini-1.5-flash" - assert result[1]["url"] == "https://gateway.example.com/v1" - - def test_render_templates_with_api_key(self) -> None: - llm = create_llm() - template = "Key: {{llm_api_key}}" - result = llm._render_templates(template) - assert result == "Key: test-api-key" - - def test_render_templates_no_base_url(self) -> None: - llm = LLMWithGateway( - model="gpt-4", - api_key=SecretStr("key"), - usage_id="test", - ) - template = "URL: {{llm_base_url}}" - result = llm._render_templates(template) - assert result == "URL: "