Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 132 additions & 150 deletions src/connectors/openrouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
import httpx

from src.connectors.openai import OpenAIConnector
from src.core.common.exceptions import (
AuthenticationError,
BackendError,
ConfigurationError,
ServiceUnavailableError,
)
from src.core.common.exceptions import AuthenticationError, ConfigurationError
from src.core.config.app_config import AppConfig
from src.core.domain.responses import ResponseEnvelope, StreamingResponseEnvelope
from src.core.interfaces.configuration_interface import IAppIdentityConfig
Expand Down Expand Up @@ -40,6 +35,7 @@ def __init__(
self.headers_provider: Callable[[Any, str], dict[str, str]] | None = None
self.key_name: str | None = None
self.api_keys: list[str] = []
self._health_check_enabled = False

def _build_openrouter_header_context(self) -> dict[str, str]:
"""Create a minimal context dictionary for header providers expecting config."""
Expand Down Expand Up @@ -156,19 +152,87 @@ def _try_provider_call(*args: Any) -> dict[str, str] | None:
code="missing_credentials",
)

@staticmethod
def _normalize_payload_value(value: Any) -> Any:
"""Normalize payload values to plain Python types."""

if hasattr(value, "model_dump") and callable(value.model_dump):
return value.model_dump() # type: ignore[no-any-return]
if isinstance(value, Mapping):
return {
key: OpenRouterBackend._normalize_payload_value(val)
for key, val in value.items()
}
if isinstance(value, (list, tuple, set)):
return [OpenRouterBackend._normalize_payload_value(item) for item in value]
return value

def _collect_openrouter_payload_fields(self, request: Any) -> dict[str, Any]:
"""Gather OpenRouter-specific parameters from the domain request."""

field_map: dict[str, Any] = {
"top_k": getattr(request, "top_k", None),
"repetition_penalty": getattr(request, "repetition_penalty", None),
"top_logprobs": getattr(request, "top_logprobs", None),
"min_p": getattr(request, "min_p", None),
"top_a": getattr(request, "top_a", None),
"reasoning_effort": getattr(request, "reasoning_effort", None),
"prediction": getattr(request, "prediction", None),
"transforms": getattr(request, "transforms", None),
"models": getattr(request, "models", None),
"route": getattr(request, "route", None),
"provider": getattr(request, "provider", None),
"response_format": getattr(request, "response_format", None),
}

normalized: dict[str, Any] = {}
for key, value in field_map.items():
if value is None:
continue
normalized[key] = self._normalize_payload_value(value)
return normalized

def get_headers(self, identity: IAppIdentityConfig | None = None) -> dict[str, str]:
if not self.headers_provider or not self.api_key:
if not self.api_key:
raise AuthenticationError(
message="OpenRouter headers provider or API key not set.",
message="OpenRouter API key not configured.",
code="missing_credentials",
)
headers = self._resolve_headers_from_provider()

headers: dict[str, str] = {}
if self.headers_provider is not None:
headers.update(self._resolve_headers_from_provider())

def _ensure_header(key: str, value: str) -> None:
for existing_key in headers.keys():
if existing_key.lower() == key.lower():
return
headers[key] = value

def _override_header(key: str, value: str) -> None:
normalized_key = key
for existing_key in list(headers.keys()):
if existing_key.lower() == key.lower():
normalized_key = existing_key
headers.pop(existing_key)
break
headers[normalized_key] = value

_ensure_header("Authorization", f"Bearer {self.api_key}")
_ensure_header("Content-Type", "application/json")

context = self._build_openrouter_header_context()
_ensure_header("HTTP-Referer", context["app_site_url"])
_ensure_header("X-Title", context["app_x_title"])

if identity is not None:
try:
identity_headers = identity.get_resolved_headers(None)
identity_headers = dict(identity_headers)
if identity_headers:
headers.update(identity_headers)
for key, value in identity_headers.items():
if isinstance(key, str) and isinstance(value, str):
_override_header(key, value)
except (AttributeError, TypeError, ValueError) as exc:
logger.error(
"Failed to resolve identity headers in get_headers()",
Expand All @@ -187,8 +251,12 @@ def get_headers(self, identity: IAppIdentityConfig | None = None) -> dict[str, s
message="Unexpected error resolving identity configuration",
details={"unexpected_error": str(exc)},
) from exc

logger.info(
f"OpenRouter headers: Authorization: Bearer {self.api_key[:20]}..., HTTP-Referer: {headers.get('HTTP-Referer', 'NOT_SET')}, X-Title: {headers.get('X-Title', 'NOT_SET')}"
"OpenRouter headers prepared: Authorization prefix=%s, HTTP-Referer=%s, X-Title=%s",
headers.get("Authorization", "")[:20],
headers.get("HTTP-Referer", "NOT_SET"),
headers.get("X-Title", "NOT_SET"),
)
return ensure_loop_guard_header(headers)

Expand Down Expand Up @@ -242,10 +310,6 @@ async def chat_completions( # type: ignore[override]
project: str | None = None,
**kwargs: Any,
) -> ResponseEnvelope | StreamingResponseEnvelope:
# request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest)
# (the frontend controller converts from frontend-specific format to domain format)
# Backends should ONLY convert FROM domain TO backend-specific format
# Type assertion: we know from architectural design that request_data is ChatRequest-like
from typing import cast

from src.core.domain.chat import CanonicalChatRequest, ChatRequest
Expand All @@ -255,153 +319,71 @@ async def chat_completions( # type: ignore[override]
f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. "
"Backend connectors should only receive domain-format requests."
)
# Cast to CanonicalChatRequest for mypy compatibility with translation service signature

domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data)

# Allow tests and callers to provide per-call OpenRouter settings via kwargs
headers_provider = kwargs.pop("openrouter_headers_provider", None)
key_name = kwargs.pop("key_name", None)
api_key = kwargs.pop("api_key", None)
api_base_url = kwargs.pop("openrouter_api_base_url", None)
headers_provider_override = kwargs.pop("openrouter_headers_provider", None)
key_name_override = kwargs.pop("key_name", None)
api_key_override = kwargs.pop("api_key", None)
explicit_api_base = kwargs.pop("openrouter_api_base_url", None)
api_base_kwarg = kwargs.pop("api_base_url", None)

original_headers_provider = self.headers_provider
original_key_name = self.key_name
original_api_key = self.api_key
original_api_base_url = self.api_base_url
if headers_provider_override is not None and not callable(
headers_provider_override
):
raise TypeError("openrouter_headers_provider must be callable if provided")

original_state = (
self.headers_provider,
self.key_name,
self.api_key,
self.api_base_url,
)

try:
if headers_provider is not None:
if headers_provider_override is not None:
self.headers_provider = cast(
Callable[[Any, str], dict[str, str]], headers_provider
Callable[[Any, str], dict[str, str]], headers_provider_override
)
if key_name is not None:
self.key_name = cast(str, key_name)
if api_key is not None:
self.api_key = cast(str, api_key)
if api_base_url:
self.api_base_url = cast(str, api_base_url)

# Compute explicit headers for this call and ensure the exact
# Authorization header and URL used by tests are passed to the
# parent's streaming/non-streaming implementation.
headers_override: dict[str, str] | None = None
if self.headers_provider:
try:
headers_override = dict(self._resolve_headers_from_provider())
except AuthenticationError:
headers_override = None
except Exception as exc:
logger.error(
"Unexpected error resolving headers from provider in chat_completions()",
exc_info=True,
)
raise BackendError(
message="Failed to resolve headers from provider",
backend_name="openrouter",
details={"provider_error": str(exc)},
) from exc

if headers_override is None:
headers_override = {}

if self.api_key:
headers_override.setdefault("Authorization", f"Bearer {self.api_key}")

if identity is not None:
try:
identity_headers = identity.get_resolved_headers(None)
if identity_headers:
headers_override.update(identity_headers)
except (AttributeError, TypeError, ValueError) as exc:
logger.error(
"Failed to resolve identity headers in chat_completions()",
exc_info=True,
)
raise ConfigurationError(
message="Failed to resolve identity configuration",
details={"identity_error": str(exc)},
) from exc
except Exception as exc:
logger.error(
"Unexpected error resolving identity headers in chat_completions()",
exc_info=True,
)
raise ConfigurationError(
message="Unexpected error resolving identity configuration",
details={"unexpected_error": str(exc)},
) from exc
if key_name_override is not None:
self.key_name = cast(str, key_name_override)
if api_key_override is not None:
self.api_key = cast(str, api_key_override)

if not headers_override:
headers_override = None
api_base_override = explicit_api_base or api_base_kwarg
if api_base_override:
self.api_base_url = cast(str, api_base_override)

# Determine the exact URL to call so tests that mock it see the
# same value. The parent expects `openai_url` kwarg for URL
# override; for OpenRouter we set it to our `api_base_url`.
call_kwargs = dict(kwargs)
call_kwargs["headers_override"] = headers_override
call_kwargs["openai_url"] = self.api_base_url

# Translate to a base payload using the shared hook so that
# processed_messages, effective_model and extra_body are applied
# consistently (and tests can patch _prepare_payload).
payload = await self._prepare_payload(
domain_request, processed_messages, effective_model
call_kwargs.setdefault("openai_url", self.api_base_url)

return await super().chat_completions(
request_data=domain_request,
processed_messages=processed_messages,
effective_model=effective_model,
identity=identity,
**call_kwargs,
)

# Add OpenRouter-specific parameters to the payload
if domain_request.top_k is not None:
payload["top_k"] = domain_request.top_k
if domain_request.seed is not None:
payload["seed"] = domain_request.seed
if domain_request.reasoning_effort is not None:
payload["reasoning_effort"] = domain_request.reasoning_effort

# Add frequency_penalty and presence_penalty if specified
if domain_request.frequency_penalty is not None:
payload["frequency_penalty"] = domain_request.frequency_penalty
if domain_request.presence_penalty is not None:
payload["presence_penalty"] = domain_request.presence_penalty

# Handle extra_body from the request (takes precedence)
if hasattr(domain_request, "extra_body") and domain_request.extra_body:
for key, value in domain_request.extra_body.items():
payload[key] = value

# Handle reasoning config
if hasattr(domain_request, "reasoning") and domain_request.reasoning:
payload["reasoning"] = domain_request.reasoning

# Manually call the appropriate handler from the parent class
api_base = call_kwargs.get("openai_url") or self.api_base_url
url = f"{api_base.rstrip('/')}/chat/completions"

if domain_request.stream:
stream_handle = await self._handle_streaming_response(
url,
payload,
headers_override,
domain_request.session_id or "",
"openai",
)
return StreamingResponseEnvelope(
content=stream_handle.iterator,
media_type="text/event-stream",
headers={},
cancel_callback=stream_handle.cancel_callback,
)
else:
return await self._handle_non_streaming_response(
url, payload, headers_override, domain_request.session_id or ""
)
except ServiceUnavailableError:
raise
except BackendError:
raise
finally:
self.headers_provider = original_headers_provider
self.key_name = original_key_name
self.api_key = original_api_key
self.api_base_url = original_api_base_url
self.headers_provider = original_state[0]
self.key_name = original_state[1]
self.api_key = original_state[2]
self.api_base_url = original_state[3]

async def _prepare_payload(
self,
request_data: "CanonicalChatRequest",
processed_messages: list[Any],
effective_model: str,
) -> dict[str, Any]:
payload = await super()._prepare_payload(
request_data, processed_messages, effective_model
)

for key, value in self._collect_openrouter_payload_fields(request_data).items():
payload.setdefault(key, value)

return payload


backend_registry.register_backend("openrouter", OpenRouterBackend)
Loading
Loading