diff --git a/src/connectors/openrouter.py b/src/connectors/openrouter.py index 75a6f011..d1e6764c 100644 --- a/src/connectors/openrouter.py +++ b/src/connectors/openrouter.py @@ -7,12 +7,7 @@ import httpx from src.connectors.openai import OpenAIConnector -from src.core.common.exceptions import ( - AuthenticationError, - BackendError, - ConfigurationError, - ServiceUnavailableError, -) +from src.core.common.exceptions import AuthenticationError, ConfigurationError from src.core.config.app_config import AppConfig from src.core.domain.responses import ResponseEnvelope, StreamingResponseEnvelope from src.core.interfaces.configuration_interface import IAppIdentityConfig @@ -40,6 +35,7 @@ def __init__( self.headers_provider: Callable[[Any, str], dict[str, str]] | None = None self.key_name: str | None = None self.api_keys: list[str] = [] + self._health_check_enabled = False def _build_openrouter_header_context(self) -> dict[str, str]: """Create a minimal context dictionary for header providers expecting config.""" @@ -156,19 +152,87 @@ def _try_provider_call(*args: Any) -> dict[str, str] | None: code="missing_credentials", ) + @staticmethod + def _normalize_payload_value(value: Any) -> Any: + """Normalize payload values to plain Python types.""" + + if hasattr(value, "model_dump") and callable(value.model_dump): + return value.model_dump() # type: ignore[no-any-return] + if isinstance(value, Mapping): + return { + key: OpenRouterBackend._normalize_payload_value(val) + for key, val in value.items() + } + if isinstance(value, (list, tuple, set)): + return [OpenRouterBackend._normalize_payload_value(item) for item in value] + return value + + def _collect_openrouter_payload_fields(self, request: Any) -> dict[str, Any]: + """Gather OpenRouter-specific parameters from the domain request.""" + + field_map: dict[str, Any] = { + "top_k": getattr(request, "top_k", None), + "repetition_penalty": getattr(request, "repetition_penalty", None), + "top_logprobs": getattr(request, "top_logprobs", None), + "min_p": getattr(request, "min_p", None), + "top_a": getattr(request, "top_a", None), + "reasoning_effort": getattr(request, "reasoning_effort", None), + "prediction": getattr(request, "prediction", None), + "transforms": getattr(request, "transforms", None), + "models": getattr(request, "models", None), + "route": getattr(request, "route", None), + "provider": getattr(request, "provider", None), + "response_format": getattr(request, "response_format", None), + } + + normalized: dict[str, Any] = {} + for key, value in field_map.items(): + if value is None: + continue + normalized[key] = self._normalize_payload_value(value) + return normalized + def get_headers(self, identity: IAppIdentityConfig | None = None) -> dict[str, str]: - if not self.headers_provider or not self.api_key: + if not self.api_key: raise AuthenticationError( - message="OpenRouter headers provider or API key not set.", + message="OpenRouter API key not configured.", code="missing_credentials", ) - headers = self._resolve_headers_from_provider() + + headers: dict[str, str] = {} + if self.headers_provider is not None: + headers.update(self._resolve_headers_from_provider()) + + def _ensure_header(key: str, value: str) -> None: + for existing_key in headers.keys(): + if existing_key.lower() == key.lower(): + return + headers[key] = value + + def _override_header(key: str, value: str) -> None: + normalized_key = key + for existing_key in list(headers.keys()): + if existing_key.lower() == key.lower(): + normalized_key = existing_key + headers.pop(existing_key) + break + headers[normalized_key] = value + + _ensure_header("Authorization", f"Bearer {self.api_key}") + _ensure_header("Content-Type", "application/json") + + context = self._build_openrouter_header_context() + _ensure_header("HTTP-Referer", context["app_site_url"]) + _ensure_header("X-Title", context["app_x_title"]) + if identity is not None: try: identity_headers = identity.get_resolved_headers(None) identity_headers = dict(identity_headers) if identity_headers: - headers.update(identity_headers) + for key, value in identity_headers.items(): + if isinstance(key, str) and isinstance(value, str): + _override_header(key, value) except (AttributeError, TypeError, ValueError) as exc: logger.error( "Failed to resolve identity headers in get_headers()", @@ -187,8 +251,12 @@ def get_headers(self, identity: IAppIdentityConfig | None = None) -> dict[str, s message="Unexpected error resolving identity configuration", details={"unexpected_error": str(exc)}, ) from exc + logger.info( - f"OpenRouter headers: Authorization: Bearer {self.api_key[:20]}..., HTTP-Referer: {headers.get('HTTP-Referer', 'NOT_SET')}, X-Title: {headers.get('X-Title', 'NOT_SET')}" + "OpenRouter headers prepared: Authorization prefix=%s, HTTP-Referer=%s, X-Title=%s", + headers.get("Authorization", "")[:20], + headers.get("HTTP-Referer", "NOT_SET"), + headers.get("X-Title", "NOT_SET"), ) return ensure_loop_guard_header(headers) @@ -242,10 +310,6 @@ async def chat_completions( # type: ignore[override] project: str | None = None, **kwargs: Any, ) -> ResponseEnvelope | StreamingResponseEnvelope: - # request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest) - # (the frontend controller converts from frontend-specific format to domain format) - # Backends should ONLY convert FROM domain TO backend-specific format - # Type assertion: we know from architectural design that request_data is ChatRequest-like from typing import cast from src.core.domain.chat import CanonicalChatRequest, ChatRequest @@ -255,153 +319,71 @@ async def chat_completions( # type: ignore[override] f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. " "Backend connectors should only receive domain-format requests." ) - # Cast to CanonicalChatRequest for mypy compatibility with translation service signature + domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data) - # Allow tests and callers to provide per-call OpenRouter settings via kwargs - headers_provider = kwargs.pop("openrouter_headers_provider", None) - key_name = kwargs.pop("key_name", None) - api_key = kwargs.pop("api_key", None) - api_base_url = kwargs.pop("openrouter_api_base_url", None) + headers_provider_override = kwargs.pop("openrouter_headers_provider", None) + key_name_override = kwargs.pop("key_name", None) + api_key_override = kwargs.pop("api_key", None) + explicit_api_base = kwargs.pop("openrouter_api_base_url", None) + api_base_kwarg = kwargs.pop("api_base_url", None) - original_headers_provider = self.headers_provider - original_key_name = self.key_name - original_api_key = self.api_key - original_api_base_url = self.api_base_url + if headers_provider_override is not None and not callable( + headers_provider_override + ): + raise TypeError("openrouter_headers_provider must be callable if provided") + + original_state = ( + self.headers_provider, + self.key_name, + self.api_key, + self.api_base_url, + ) try: - if headers_provider is not None: + if headers_provider_override is not None: self.headers_provider = cast( - Callable[[Any, str], dict[str, str]], headers_provider + Callable[[Any, str], dict[str, str]], headers_provider_override ) - if key_name is not None: - self.key_name = cast(str, key_name) - if api_key is not None: - self.api_key = cast(str, api_key) - if api_base_url: - self.api_base_url = cast(str, api_base_url) - - # Compute explicit headers for this call and ensure the exact - # Authorization header and URL used by tests are passed to the - # parent's streaming/non-streaming implementation. - headers_override: dict[str, str] | None = None - if self.headers_provider: - try: - headers_override = dict(self._resolve_headers_from_provider()) - except AuthenticationError: - headers_override = None - except Exception as exc: - logger.error( - "Unexpected error resolving headers from provider in chat_completions()", - exc_info=True, - ) - raise BackendError( - message="Failed to resolve headers from provider", - backend_name="openrouter", - details={"provider_error": str(exc)}, - ) from exc - - if headers_override is None: - headers_override = {} - - if self.api_key: - headers_override.setdefault("Authorization", f"Bearer {self.api_key}") - - if identity is not None: - try: - identity_headers = identity.get_resolved_headers(None) - if identity_headers: - headers_override.update(identity_headers) - except (AttributeError, TypeError, ValueError) as exc: - logger.error( - "Failed to resolve identity headers in chat_completions()", - exc_info=True, - ) - raise ConfigurationError( - message="Failed to resolve identity configuration", - details={"identity_error": str(exc)}, - ) from exc - except Exception as exc: - logger.error( - "Unexpected error resolving identity headers in chat_completions()", - exc_info=True, - ) - raise ConfigurationError( - message="Unexpected error resolving identity configuration", - details={"unexpected_error": str(exc)}, - ) from exc + if key_name_override is not None: + self.key_name = cast(str, key_name_override) + if api_key_override is not None: + self.api_key = cast(str, api_key_override) - if not headers_override: - headers_override = None + api_base_override = explicit_api_base or api_base_kwarg + if api_base_override: + self.api_base_url = cast(str, api_base_override) - # Determine the exact URL to call so tests that mock it see the - # same value. The parent expects `openai_url` kwarg for URL - # override; for OpenRouter we set it to our `api_base_url`. call_kwargs = dict(kwargs) - call_kwargs["headers_override"] = headers_override - call_kwargs["openai_url"] = self.api_base_url - - # Translate to a base payload using the shared hook so that - # processed_messages, effective_model and extra_body are applied - # consistently (and tests can patch _prepare_payload). - payload = await self._prepare_payload( - domain_request, processed_messages, effective_model + call_kwargs.setdefault("openai_url", self.api_base_url) + + return await super().chat_completions( + request_data=domain_request, + processed_messages=processed_messages, + effective_model=effective_model, + identity=identity, + **call_kwargs, ) - - # Add OpenRouter-specific parameters to the payload - if domain_request.top_k is not None: - payload["top_k"] = domain_request.top_k - if domain_request.seed is not None: - payload["seed"] = domain_request.seed - if domain_request.reasoning_effort is not None: - payload["reasoning_effort"] = domain_request.reasoning_effort - - # Add frequency_penalty and presence_penalty if specified - if domain_request.frequency_penalty is not None: - payload["frequency_penalty"] = domain_request.frequency_penalty - if domain_request.presence_penalty is not None: - payload["presence_penalty"] = domain_request.presence_penalty - - # Handle extra_body from the request (takes precedence) - if hasattr(domain_request, "extra_body") and domain_request.extra_body: - for key, value in domain_request.extra_body.items(): - payload[key] = value - - # Handle reasoning config - if hasattr(domain_request, "reasoning") and domain_request.reasoning: - payload["reasoning"] = domain_request.reasoning - - # Manually call the appropriate handler from the parent class - api_base = call_kwargs.get("openai_url") or self.api_base_url - url = f"{api_base.rstrip('/')}/chat/completions" - - if domain_request.stream: - stream_handle = await self._handle_streaming_response( - url, - payload, - headers_override, - domain_request.session_id or "", - "openai", - ) - return StreamingResponseEnvelope( - content=stream_handle.iterator, - media_type="text/event-stream", - headers={}, - cancel_callback=stream_handle.cancel_callback, - ) - else: - return await self._handle_non_streaming_response( - url, payload, headers_override, domain_request.session_id or "" - ) - except ServiceUnavailableError: - raise - except BackendError: - raise finally: - self.headers_provider = original_headers_provider - self.key_name = original_key_name - self.api_key = original_api_key - self.api_base_url = original_api_base_url + self.headers_provider = original_state[0] + self.key_name = original_state[1] + self.api_key = original_state[2] + self.api_base_url = original_state[3] + + async def _prepare_payload( + self, + request_data: "CanonicalChatRequest", + processed_messages: list[Any], + effective_model: str, + ) -> dict[str, Any]: + payload = await super()._prepare_payload( + request_data, processed_messages, effective_model + ) + + for key, value in self._collect_openrouter_payload_fields(request_data).items(): + payload.setdefault(key, value) + + return payload backend_registry.register_backend("openrouter", OpenRouterBackend) diff --git a/src/core/domain/chat.py b/src/core/domain/chat.py index de949c58..947564cb 100644 --- a/src/core/domain/chat.py +++ b/src/core/domain/chat.py @@ -1,213 +1,223 @@ -from collections.abc import Sequence -from typing import Any, TypeVar - -from pydantic import Field, field_validator - -from src.core.domain.base import ValueObject -from src.core.interfaces.model_bases import DomainModel - -# Define a type variable for generic methods -T = TypeVar("T", bound=DomainModel) - - -# For multimodal content parts -class MessageContentPartText(DomainModel): - """Represents a text content part in a multimodal message.""" - - type: str = "text" - text: str - - -class ImageURL(DomainModel): - """Specifies the URL and optional detail for an image in a multimodal message.""" - - # Should be a data URI (e.g., "data:image/jpeg;base64,...") or public URL - url: str - detail: str | None = Field(None, examples=["auto", "low", "high"]) - - -class MessageContentPartImage(DomainModel): - """Represents an image content part in a multimodal message.""" - - type: str = "image_url" - image_url: ImageURL - - -# Extend with other multimodal types as needed (e.g., audio, video file, documents) -# For now, text and image are common starting points. -MessageContentPart = MessageContentPartText | MessageContentPartImage -"""Type alias for possible content parts in a multimodal message.""" - - -class FunctionCall(DomainModel): - """Represents a function call within a tool call.""" - - name: str - arguments: str - - -class ToolCall(DomainModel): - """Represents a tool call in a chat completion response.""" - - id: str - type: str = "function" - function: FunctionCall - - -class FunctionDefinition(DomainModel): - """Represents a function definition for tool calling.""" - - name: str - description: str | None = None - parameters: dict[str, Any] | None = None - - -class ToolDefinition(DomainModel): - """Represents a tool definition in a chat completion request.""" - - type: str = "function" - function: FunctionDefinition - - @field_validator("function", mode="before") - @classmethod - def ensure_function_is_dict(cls, v: Any) -> dict[str, Any] | FunctionDefinition: - # Accept either a FunctionDefinition or a ToolDefinition/FunctionDefinition instance - # and normalize to a dict for ChatRequest validation - if isinstance(v, FunctionDefinition): - return v.model_dump() - # If v is already a dict, return it as is - if isinstance(v, dict): - return v - # If v is something else, try to convert it to a dict - # This should handle cases where v is a dict-like object - try: - return dict(v) # type: ignore - except (TypeError, ValueError): - # If we can't convert to dict, raise a ValueError to let Pydantic handle the error properly - raise ValueError(f"Cannot convert {type(v)} to dict or FunctionDefinition") - - -class ChatMessage(DomainModel): - """ - A chat message in a conversation. - """ - - role: str - content: str | Sequence[MessageContentPart] | None = None - name: str | None = None - tool_calls: list[ToolCall] | None = None - tool_call_id: str | None = None - metadata: dict[str, Any] | None = None - - def to_dict(self) -> dict[str, Any]: - """Convert the message to a dictionary.""" - result: dict[str, Any] = {"role": self.role} - if self.content is not None: - result["content"] = self._serialize_content(self.content) - if self.name: - result["name"] = self.name - if self.tool_calls: - result["tool_calls"] = [tc.model_dump() for tc in self.tool_calls] - if self.tool_call_id: - result["tool_call_id"] = self.tool_call_id - return result - - @staticmethod - def _serialize_content( - content: str | Sequence[MessageContentPart] | None, - ) -> Any: - """Normalize message content so downstream callers receive plain data structures.""" - - if content is None: - return None - - if isinstance(content, str): - return content - - if isinstance(content, DomainModel): - return content.model_dump() - - if isinstance(content, Sequence): - serialized_parts: list[Any] = [] - for part in content: - if isinstance(part, DomainModel): - serialized_parts.append(part.model_dump()) - else: - serialized_parts.append(part) - return serialized_parts - - return content - - -class ChatRequest(ValueObject): - """ - A request for a chat completion. - """ - - model: str - messages: list[ChatMessage] - system_prompt: str | None = None # Add system_prompt field - temperature: float | None = None - top_p: float | None = None - top_k: int | None = None - n: int | None = None - stream: bool | None = None - stop: list[str] | str | None = None - max_tokens: int | None = None - presence_penalty: float | None = None - frequency_penalty: float | None = None - logit_bias: dict[str, float] | None = None - user: str | None = None - seed: int | None = None - tools: list[dict[str, Any]] | None = None - tool_choice: str | dict[str, Any] | None = None - session_id: str | None = None - agent: str | None = None # Add agent field - extra_body: dict[str, Any] | None = None - - # Reasoning parameters for o1, o3, o4-mini and other reasoning models - reasoning_effort: str | None = None - reasoning: dict[str, Any] | None = None - - # Gemini-specific reasoning parameters - thinking_budget: int | None = None - generation_config: dict[str, Any] | None = None - - @field_validator("messages") - @classmethod - def validate_messages(cls, v: list[Any]) -> list[ChatMessage]: - """Validate and convert messages.""" - if not v: - raise ValueError("At least one message is required") - return [m if isinstance(m, ChatMessage) else ChatMessage(**m) for m in v] - - @field_validator("tools", mode="before") - @classmethod - def validate_tools(cls, v: Any) -> list[dict[str, Any]] | None: - """Allow passing ToolDefinition instances or dicts for tools.""" - if v is None: - return None - result: list[dict[str, Any]] = [] - for item in v: - if isinstance(item, ToolDefinition): - result.append(item.model_dump()) - elif isinstance(item, dict): - result.append(item) - else: - # Attempt to coerce - try: - td = ToolDefinition(**item) - result.append(td.model_dump()) - except Exception as e: - from src.core.common.exceptions import ToolCallParsingError - - raise ToolCallParsingError( - message="Invalid tool definition provided", - details={"original_error": str(e), "invalid_item": str(item)}, - ) from e - return result - - +from collections.abc import Sequence +from typing import Any, TypeVar + +from pydantic import Field, field_validator + +from src.core.domain.base import ValueObject +from src.core.interfaces.model_bases import DomainModel + +# Define a type variable for generic methods +T = TypeVar("T", bound=DomainModel) + + +# For multimodal content parts +class MessageContentPartText(DomainModel): + """Represents a text content part in a multimodal message.""" + + type: str = "text" + text: str + + +class ImageURL(DomainModel): + """Specifies the URL and optional detail for an image in a multimodal message.""" + + # Should be a data URI (e.g., "data:image/jpeg;base64,...") or public URL + url: str + detail: str | None = Field(None, examples=["auto", "low", "high"]) + + +class MessageContentPartImage(DomainModel): + """Represents an image content part in a multimodal message.""" + + type: str = "image_url" + image_url: ImageURL + + +# Extend with other multimodal types as needed (e.g., audio, video file, documents) +# For now, text and image are common starting points. +MessageContentPart = MessageContentPartText | MessageContentPartImage +"""Type alias for possible content parts in a multimodal message.""" + + +class FunctionCall(DomainModel): + """Represents a function call within a tool call.""" + + name: str + arguments: str + + +class ToolCall(DomainModel): + """Represents a tool call in a chat completion response.""" + + id: str + type: str = "function" + function: FunctionCall + + +class FunctionDefinition(DomainModel): + """Represents a function definition for tool calling.""" + + name: str + description: str | None = None + parameters: dict[str, Any] | None = None + + +class ToolDefinition(DomainModel): + """Represents a tool definition in a chat completion request.""" + + type: str = "function" + function: FunctionDefinition + + @field_validator("function", mode="before") + @classmethod + def ensure_function_is_dict(cls, v: Any) -> dict[str, Any] | FunctionDefinition: + # Accept either a FunctionDefinition or a ToolDefinition/FunctionDefinition instance + # and normalize to a dict for ChatRequest validation + if isinstance(v, FunctionDefinition): + return v.model_dump() + # If v is already a dict, return it as is + if isinstance(v, dict): + return v + # If v is something else, try to convert it to a dict + # This should handle cases where v is a dict-like object + try: + return dict(v) # type: ignore + except (TypeError, ValueError): + # If we can't convert to dict, raise a ValueError to let Pydantic handle the error properly + raise ValueError(f"Cannot convert {type(v)} to dict or FunctionDefinition") + + +class ChatMessage(DomainModel): + """ + A chat message in a conversation. + """ + + role: str + content: str | Sequence[MessageContentPart] | None = None + name: str | None = None + tool_calls: list[ToolCall] | None = None + tool_call_id: str | None = None + metadata: dict[str, Any] | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert the message to a dictionary.""" + result: dict[str, Any] = {"role": self.role} + if self.content is not None: + result["content"] = self._serialize_content(self.content) + if self.name: + result["name"] = self.name + if self.tool_calls: + result["tool_calls"] = [tc.model_dump() for tc in self.tool_calls] + if self.tool_call_id: + result["tool_call_id"] = self.tool_call_id + return result + + @staticmethod + def _serialize_content( + content: str | Sequence[MessageContentPart] | None, + ) -> Any: + """Normalize message content so downstream callers receive plain data structures.""" + + if content is None: + return None + + if isinstance(content, str): + return content + + if isinstance(content, DomainModel): + return content.model_dump() + + if isinstance(content, Sequence): + serialized_parts: list[Any] = [] + for part in content: + if isinstance(part, DomainModel): + serialized_parts.append(part.model_dump()) + else: + serialized_parts.append(part) + return serialized_parts + + return content + + +class ChatRequest(ValueObject): + """ + A request for a chat completion. + """ + + model: str + messages: list[ChatMessage] + system_prompt: str | None = None # Add system_prompt field + temperature: float | None = None + top_p: float | None = None + top_k: int | None = None + n: int | None = None + stream: bool | None = None + stop: list[str] | str | None = None + max_tokens: int | None = None + presence_penalty: float | None = None + frequency_penalty: float | None = None + logit_bias: dict[str, float] | None = None + user: str | None = None + seed: int | None = None + repetition_penalty: float | None = None + top_logprobs: int | None = None + min_p: float | None = None + top_a: float | None = None + prediction: dict[str, Any] | None = None + transforms: list[str] | None = None + models: list[str] | None = None + route: str | None = None + provider: dict[str, Any] | None = None + response_format: dict[str, Any] | None = None + tools: list[dict[str, Any]] | None = None + tool_choice: str | dict[str, Any] | None = None + session_id: str | None = None + agent: str | None = None # Add agent field + extra_body: dict[str, Any] | None = None + + # Reasoning parameters for o1, o3, o4-mini and other reasoning models + reasoning_effort: str | None = None + reasoning: dict[str, Any] | None = None + + # Gemini-specific reasoning parameters + thinking_budget: int | None = None + generation_config: dict[str, Any] | None = None + + @field_validator("messages") + @classmethod + def validate_messages(cls, v: list[Any]) -> list[ChatMessage]: + """Validate and convert messages.""" + if not v: + raise ValueError("At least one message is required") + return [m if isinstance(m, ChatMessage) else ChatMessage(**m) for m in v] + + @field_validator("tools", mode="before") + @classmethod + def validate_tools(cls, v: Any) -> list[dict[str, Any]] | None: + """Allow passing ToolDefinition instances or dicts for tools.""" + if v is None: + return None + result: list[dict[str, Any]] = [] + for item in v: + if isinstance(item, ToolDefinition): + result.append(item.model_dump()) + elif isinstance(item, dict): + result.append(item) + else: + # Attempt to coerce + try: + td = ToolDefinition(**item) + result.append(td.model_dump()) + except Exception as e: + from src.core.common.exceptions import ToolCallParsingError + + raise ToolCallParsingError( + message="Invalid tool definition provided", + details={"original_error": str(e), "invalid_item": str(item)}, + ) from e + return result + + class ChatCompletionChoiceMessage(DomainModel): """Represents the message content within a chat completion choice.""" @@ -216,121 +226,121 @@ class ChatCompletionChoiceMessage(DomainModel): tool_calls: list[ToolCall] | None = None tool_call_id: str | None = None metadata: dict[str, Any] | None = None - - -class ChatCompletionChoice(DomainModel): - """Represents a single choice in a chat completion response.""" - - index: int - message: ChatCompletionChoiceMessage - finish_reason: str | None = None - - -# ChatUsage class is defined elsewhere in this file - - -class ChatResponse(ValueObject): - """ - A response from a chat completion. - """ - - id: str - created: int - model: str - choices: list[ChatCompletionChoice] - usage: dict[str, Any] | None = None - system_fingerprint: str | None = None - object: str = "chat.completion" - - -class StreamingChatResponse(ValueObject): - """ - A streaming chunk of a chat completion response. - """ - - content: str | None - model: str - finish_reason: str | None = None - tool_calls: list[dict[str, Any]] | None = None - delta: dict[str, Any] | None = None - system_fingerprint: str | None = None - done: bool | None = None - metadata: dict[str, Any] | None = None - - @classmethod - def from_legacy_chunk(cls, chunk: dict[str, Any]) -> "StreamingChatResponse": - """ - Create a StreamingChatResponse from a legacy chunk format. - - Args: - chunk: A legacy streaming chunk - - Returns: - A new StreamingChatResponse - """ - # Extract the response content and other fields from the chunk - content: str | None = None - if chunk.get("choices"): - choice: dict[str, Any] = chunk["choices"][0] - if "delta" in choice: - delta: dict[str, Any] = choice["delta"] - if "content" in delta: - content = delta["content"] - - # Might have tool calls in delta - tool_calls: list[dict[str, Any]] | None = delta.get("tool_calls") - - # The delta is the actual delta object - delta_obj: dict[str, Any] | None = delta - else: - # Simpler format - content = choice.get("text", "") - tool_calls = None - delta_obj = None - - # Extract finish reason if present - finish_reason: str | None = choice.get("finish_reason") - else: - # Anthropic format - if "content" in chunk: - if isinstance(chunk["content"], list): - content_parts: list[str] = [ - p["text"] for p in chunk["content"] if p.get("type") == "text" - ] - content = "".join(content_parts) - else: - content = chunk["content"] - - tool_calls = chunk.get("tool_calls") - delta_obj = None - finish_reason = chunk.get("stop_reason") - - # Extract model - model: str = chunk.get("model", "unknown") - - # Extract system fingerprint - system_fingerprint: str | None = chunk.get("system_fingerprint") - - return cls( - content=content, - model=model, - finish_reason=finish_reason, - tool_calls=tool_calls, - delta=delta_obj, - system_fingerprint=system_fingerprint, - ) - - -# ChatUsage class is defined elsewhere in this file - - -class CanonicalChatRequest(ChatRequest): - """ - A canonical chat request model that is used internally throughout the application. - """ - - -class CanonicalChatResponse(ChatResponse): - """ - A canonical chat response model that is used internally throughout the application. - """ + + +class ChatCompletionChoice(DomainModel): + """Represents a single choice in a chat completion response.""" + + index: int + message: ChatCompletionChoiceMessage + finish_reason: str | None = None + + +# ChatUsage class is defined elsewhere in this file + + +class ChatResponse(ValueObject): + """ + A response from a chat completion. + """ + + id: str + created: int + model: str + choices: list[ChatCompletionChoice] + usage: dict[str, Any] | None = None + system_fingerprint: str | None = None + object: str = "chat.completion" + + +class StreamingChatResponse(ValueObject): + """ + A streaming chunk of a chat completion response. + """ + + content: str | None + model: str + finish_reason: str | None = None + tool_calls: list[dict[str, Any]] | None = None + delta: dict[str, Any] | None = None + system_fingerprint: str | None = None + done: bool | None = None + metadata: dict[str, Any] | None = None + + @classmethod + def from_legacy_chunk(cls, chunk: dict[str, Any]) -> "StreamingChatResponse": + """ + Create a StreamingChatResponse from a legacy chunk format. + + Args: + chunk: A legacy streaming chunk + + Returns: + A new StreamingChatResponse + """ + # Extract the response content and other fields from the chunk + content: str | None = None + if chunk.get("choices"): + choice: dict[str, Any] = chunk["choices"][0] + if "delta" in choice: + delta: dict[str, Any] = choice["delta"] + if "content" in delta: + content = delta["content"] + + # Might have tool calls in delta + tool_calls: list[dict[str, Any]] | None = delta.get("tool_calls") + + # The delta is the actual delta object + delta_obj: dict[str, Any] | None = delta + else: + # Simpler format + content = choice.get("text", "") + tool_calls = None + delta_obj = None + + # Extract finish reason if present + finish_reason: str | None = choice.get("finish_reason") + else: + # Anthropic format + if "content" in chunk: + if isinstance(chunk["content"], list): + content_parts: list[str] = [ + p["text"] for p in chunk["content"] if p.get("type") == "text" + ] + content = "".join(content_parts) + else: + content = chunk["content"] + + tool_calls = chunk.get("tool_calls") + delta_obj = None + finish_reason = chunk.get("stop_reason") + + # Extract model + model: str = chunk.get("model", "unknown") + + # Extract system fingerprint + system_fingerprint: str | None = chunk.get("system_fingerprint") + + return cls( + content=content, + model=model, + finish_reason=finish_reason, + tool_calls=tool_calls, + delta=delta_obj, + system_fingerprint=system_fingerprint, + ) + + +# ChatUsage class is defined elsewhere in this file + + +class CanonicalChatRequest(ChatRequest): + """ + A canonical chat request model that is used internally throughout the application. + """ + + +class CanonicalChatResponse(ChatResponse): + """ + A canonical chat response model that is used internally throughout the application. + """ diff --git a/tests/unit/openrouter_connector_tests/test_payload_construction_and_headers.py b/tests/unit/openrouter_connector_tests/test_payload_construction_and_headers.py index 7828d3f1..d109dee6 100644 --- a/tests/unit/openrouter_connector_tests/test_payload_construction_and_headers.py +++ b/tests/unit/openrouter_connector_tests/test_payload_construction_and_headers.py @@ -216,3 +216,53 @@ async def test_openrouter_processed_messages_remain_pydantic( assert isinstance( processed_msgs_fixture[2].content[1], MessageContentPartImage ) # Specific type + + +@pytest.mark.asyncio +async def test_openrouter_specific_parameters_forwarded( + openrouter_backend: OpenRouterBackend, httpx_mock: HTTPXMock +): + request = ChatRequest( + model="openai/gpt-4o", + messages=[ChatMessage(role="user", content="Hello")], + stream=False, + top_k=42, + repetition_penalty=1.15, + top_logprobs=5, + min_p=0.25, + top_a=0.8, + prediction={"type": "content", "content": "Hello"}, + transforms=["safe-mode"], + models=["openai/gpt-4o", "perplexity/llama-3"], + route="fallback", + provider={"allow_fallbacks": True, "order": ["openai/gpt-4o"]}, + response_format={"type": "json_object"}, + ) + + httpx_mock.add_response(status_code=200, json={"id": "ok"}) + + await openrouter_backend.chat_completions( + request_data=request, + processed_messages=request.messages, + effective_model=request.model, + openrouter_api_base_url=TEST_OPENROUTER_API_BASE_URL, + openrouter_headers_provider=mock_get_openrouter_headers, + key_name="test_key", + api_key="FAKE_KEY", + ) + + sent_request = httpx_mock.get_request() + assert sent_request is not None + payload = json.loads(sent_request.content) + + assert payload["top_k"] == 42 + assert payload["repetition_penalty"] == 1.15 + assert payload["top_logprobs"] == 5 + assert payload["min_p"] == 0.25 + assert payload["top_a"] == 0.8 + assert payload["prediction"] == {"type": "content", "content": "Hello"} + assert payload["transforms"] == ["safe-mode"] + assert payload["models"] == ["openai/gpt-4o", "perplexity/llama-3"] + assert payload["route"] == "fallback" + assert payload["provider"] == {"allow_fallbacks": True, "order": ["openai/gpt-4o"]} + assert payload["response_format"] == {"type": "json_object"}