diff --git a/src/connectors/anthropic.py b/src/connectors/anthropic.py index c9551c26..e8e64015 100644 --- a/src/connectors/anthropic.py +++ b/src/connectors/anthropic.py @@ -327,6 +327,16 @@ def _prepare_anthropic_payload( logger.warning( "AnthropicBackend does not support the 'logit_bias' parameter." ) + if request_data.repetition_penalty is not None and logger.isEnabledFor( + logging.WARNING + ): + logger.warning( + "AnthropicBackend does not support the 'repetition_penalty' parameter." + ) + if request_data.min_p is not None and logger.isEnabledFor(logging.WARNING): + logger.warning( + "AnthropicBackend does not support the 'min_p' parameter." + ) # Include tools and tool_choice when provided (tests set these fields) if request_data.tools is not None: diff --git a/src/connectors/gemini.py b/src/connectors/gemini.py index 35e29d17..4025cfa9 100644 --- a/src/connectors/gemini.py +++ b/src/connectors/gemini.py @@ -1,785 +1,793 @@ -from __future__ import annotations - -import asyncio -import contextlib -import json -import logging -import uuid -from collections.abc import AsyncGenerator, Callable -from typing import Any, cast - -import httpx -from fastapi import HTTPException - -from src.connectors.base import LLMBackend -from src.core.common.exceptions import ( - AuthenticationError, - BackendError, - ServiceUnavailableError, -) -from src.core.config.app_config import AppConfig # Added -from src.core.domain.chat import ( - ChatRequest, - MessageContentPartImage, - MessageContentPartText, -) -from src.core.domain.responses import ( - ResponseEnvelope, - StreamingResponseEnvelope, - StreamingResponseHandle, -) -from src.core.interfaces.configuration_interface import IAppIdentityConfig -from src.core.interfaces.model_bases import DomainModel, InternalDTO -from src.core.interfaces.response_processor_interface import ProcessedResponse -from src.core.security.loop_prevention import ensure_loop_guard_header -from src.core.services.backend_registry import backend_registry -from src.core.services.translation_service import TranslationService - -# Legacy ChatCompletionRequest removed from connector signatures; use domain ChatRequest - -# API key redaction and command filtering are now handled by middleware -# from src.security import APIKeyRedactor, ProxyCommandFilter - -logger = logging.getLogger(__name__) - - -class GeminiBackend(LLMBackend): - """LLMBackend implementation for Google's Gemini API.""" - - backend_type: str = "gemini" - - def __init__( - self, - client: httpx.AsyncClient, - config: AppConfig, - translation_service: TranslationService, - ) -> None: - self.client = client - self.config = config # Stored config - self.translation_service = translation_service - self.available_models: list[str] = [] - self.api_keys: list[str] = [] - - async def initialize(self, **kwargs: Any) -> None: - """Store configuration for lazy initialization.""" - self.gemini_api_base_url = kwargs.get("gemini_api_base_url") - self.key_name = kwargs.get("key_name") - self.api_key = kwargs.get("api_key") - - if not self.gemini_api_base_url or not self.key_name or not self.api_key: - raise ValueError( - "gemini_api_base_url, key_name, and api_key are required for GeminiBackend" - ) - - # Don't make HTTP calls during initialization - # Models will be fetched on first use - - async def _ensure_models_loaded(self) -> None: - """Fetch models if not already cached.""" - if ( - not self.available_models - and hasattr(self, "api_key") - and self.gemini_api_base_url - and self.key_name - and self.api_key - ): - try: - 
data = await self.list_models( - gemini_api_base_url=self.gemini_api_base_url, - key_name=self.key_name, - api_key=self.api_key, - ) - self.available_models = [ - m.get("name") for m in data.get("models", []) if m.get("name") - ] - except Exception as e: - logger.warning("Failed to fetch Gemini models: %s", e, exc_info=True) - # Return empty list on failure, don't crash - self.available_models = [] - - def get_available_models(self) -> list[str]: - """Return cached Gemini model names. For immediate use, prefer async version.""" - return list(self.available_models) - - async def get_available_models_async(self) -> list[str]: - """Return Gemini model names, fetching them if not cached.""" - await self._ensure_models_loaded() - return list(self.available_models) - - # Translation is now handled by TranslationService - - def _convert_part_for_gemini( - self, part: MessageContentPartText | MessageContentPartImage - ) -> dict[str, Any]: - """Convert a MessageContentPart into Gemini API format.""" - if isinstance(part, MessageContentPartText): - # Text content is already processed by middleware - return {"text": part.text} - if isinstance(part, MessageContentPartImage): - url = part.image_url.url - # Data URL -> inlineData - if url.startswith("data:"): - try: - header, b64_data = url.split(",", 1) - mime = header.split(";")[0][5:] - except Exception: - mime = "application/octet-stream" - b64_data = "" - return {"inlineData": {"mimeType": mime, "data": b64_data}} - # Otherwise treat as remote file URI - return { - "fileData": {"mimeType": "application/octet-stream", "fileUri": url} - } - data = part.model_dump(exclude_unset=True) - if data.get("type") == "text" and "text" in data: - # Text content is already processed by middleware - data.pop("type", None) - return data - - def _prepare_gemini_contents( - self, processed_messages: list[Any] - ) -> list[dict[str, Any]]: - payload_contents = [] - for msg in processed_messages: - # Handle both object and dict formats for backward compatibility - if isinstance(msg, dict): - role = msg.get("role") - # For dict format, check if it's already in Gemini format (has "parts") - # or in generic format (has "content") - if "parts" in msg: - # Already in Gemini format, use directly - payload_contents.append({"role": role, "parts": msg["parts"]}) - continue - else: - content = msg.get("content") - else: - role = getattr(msg, "role", None) - content = getattr(msg, "content", None) - - if role == "system": - # Gemini API does not support system role - continue - - if isinstance(content, str): - # If this is a tool or function role, represent it as functionResponse for Gemini - if role in ["tool", "function"]: - # Try to parse JSON payload; otherwise wrap string - try: - input_obj = json.loads(content) - except Exception: - input_obj = {"output": content} - parts: list[dict[str, Any]] = [ - { - "functionResponse": { - "name": ( - getattr(msg, "name", "tool") or "tool" - if not isinstance(msg, dict) - else msg.get("name", "tool") - ), - "response": input_obj, - } - } - ] - else: - # Content is already processed by middleware - parts = [{"text": content}] - elif content is not None: - parts = [self._convert_part_for_gemini(part) for part in content] - else: - # Skip messages with no content - continue - - # Map roles to 'user' or 'model' as required by Gemini API - if role == "user": - gemini_role = "user" - elif role in ["tool", "function"]: - # Tool/function results are treated as coming from the user side in Gemini - gemini_role = "user" - else: # e.g., assistant 
- gemini_role = "model" - - payload_contents.append({"role": gemini_role, "parts": parts}) - return payload_contents - - @staticmethod - def _coerce_stream_chunk(raw_chunk: Any) -> dict[str, Any] | None: - if isinstance(raw_chunk, dict): - return raw_chunk - - if isinstance(raw_chunk, bytes | bytearray): - raw_chunk = raw_chunk.decode("utf-8", errors="ignore") - - if not isinstance(raw_chunk, str): - return None - - stripped_chunk = raw_chunk.strip() - if not stripped_chunk: - return None - - data_segments: list[str] = [] - for line in stripped_chunk.splitlines(): - line = line.strip() - if not line or line.startswith(":"): - continue - if line.startswith("data:"): - data_value = line[5:].strip() - if not data_value: - continue - if data_value == "[DONE]": - return None - data_segments.append(data_value) - else: - data_segments.append(line) - - for segment in data_segments or [stripped_chunk]: - try: - parsed = json.loads(segment) - except json.JSONDecodeError: - continue - - if isinstance(parsed, dict): - return parsed - - if isinstance(parsed, str): - stripped_parsed = parsed.strip() - if stripped_parsed: - return { - "candidates": [ - { - "content": {"parts": [{"text": stripped_parsed}]}, - } - ] - } - - # If parsed value is not usable, continue searching remaining segments - continue - - # Fallback to treating the content as plain text - return { - "candidates": [ - { - "content": {"parts": [{"text": stripped_chunk}]}, - } - ] - } - - async def _handle_gemini_streaming_response( - self, - base_url: str, - payload: dict[str, Any], - headers: dict[str, str], - effective_model: str, - ) -> StreamingResponseHandle: - request_headers = ensure_loop_guard_header(headers) - request_id = request_headers.get("x-goog-request-id") or uuid.uuid4().hex - request_headers.setdefault("x-goog-request-id", request_id) - - url = f"{base_url}:streamGenerateContent" - try: - request = self.client.build_request( - "POST", url, json=payload, headers=request_headers - ) - response = await self.client.send(request, stream=True) - except httpx.RequestError as e: - logger.error("Request error connecting to Gemini: %s", e, exc_info=True) - raise ServiceUnavailableError(message=f"Could not connect to Gemini ({e})") - except (AttributeError, TypeError): - request = self.client.build_request( - "POST", url, json=payload, headers=request_headers - ) - try: - response = await self.client.send(request, stream=True) - except httpx.RequestError as e: - logger.error("Request error connecting to Gemini: %s", e, exc_info=True) - raise ServiceUnavailableError( - message=f"Could not connect to Gemini ({e})" - ) - - if response.status_code >= 400: - try: - if hasattr(response, "aread"): - body_bytes = await response.aread() # type: ignore[no-untyped-call] - else: - body_bytes = b"" - body_text = body_bytes.decode("utf-8", errors="ignore") - except Exception: - body_text = "" - finally: - if hasattr(response, "aclose"): - await response.aclose() - logger.error( - "HTTP error during Gemini stream: %s - %s", - response.status_code, - body_text, - ) - raise BackendError( - message=f"Gemini stream error: {response.status_code} - {body_text}", - code="gemini_error", - status_code=response.status_code, - ) - - # Prefer response-provided request identifiers when available - response_request_id = response.headers.get("x-goog-request-id") - if response_request_id: - request_id = response_request_id - - cancel_lock = asyncio.Lock() - cancel_state = {"called": False} - - async def cancel_stream() -> None: - async with cancel_lock: - if 
cancel_state["called"]: - return - cancel_state["called"] = True - - cancel_url = f"{base_url}:cancel" - cancel_headers = ensure_loop_guard_header(dict(request_headers)) - payload_body = {"requestId": request_id} - - try: - cancel_response = await self.client.post( - cancel_url, - json=payload_body, - headers=cancel_headers, - ) - except Exception as exc: - logger.debug( - "Gemini cancel request failed - url=%s request_id=%s error=%s", - cancel_url, - request_id, - exc, - exc_info=True, - ) - else: - with contextlib.suppress(Exception): - await cancel_response.aclose() - - with contextlib.suppress(Exception): - await response.aclose() - - async def stream_generator() -> AsyncGenerator[ProcessedResponse, None]: - processed_stream = response.aiter_text() - - try: - async for raw_chunk in processed_stream: - parsed_chunk = self._coerce_stream_chunk(raw_chunk) - if parsed_chunk is None: - continue - - yield ProcessedResponse( - content=self.translation_service.to_domain_stream_chunk( - parsed_chunk, source_format="gemini" - ) - ) - - done_chunk = { - "candidates": [ - { - "content": {"parts": []}, - "finishReason": "STOP", - } - ] - } - yield ProcessedResponse( - content=self.translation_service.to_domain_stream_chunk( - done_chunk, source_format="gemini" - ) - ) - except httpx.RequestError as stream_error: - logger.error( - "Request error while streaming from Gemini: %s", - stream_error, - exc_info=True, - ) - raise ServiceUnavailableError( - message=f"Gemini streaming connection error ({stream_error})" - ) from stream_error - finally: - with contextlib.suppress(Exception): - await response.aclose() - - try: - response_headers = dict(response.headers) - except Exception: - response_headers = {} - - return StreamingResponseHandle( - iterator=stream_generator(), - cancel_callback=cancel_stream, - headers=response_headers, - ) - - async def chat_completions( # type: ignore[override] - self, - request_data: DomainModel | InternalDTO | dict[str, Any], - processed_messages: list[Any], - effective_model: str, - identity: IAppIdentityConfig | None = None, - openrouter_api_base_url: str | None = None, - openrouter_headers_provider: Callable[[Any, str], dict[str, str]] | None = None, - key_name: str | None = None, - api_key: str | None = None, - project: str | None = None, - agent: str | None = None, - gemini_api_base_url: str | None = None, - **kwargs: Any, - ) -> ResponseEnvelope | StreamingResponseEnvelope: - # Resolve base configuration - base_api_url, headers = await self._resolve_gemini_api_config( - gemini_api_base_url, - openrouter_api_base_url, - api_key, - openrouter_headers_provider=openrouter_headers_provider, - key_name=key_name, - **kwargs, - ) - if identity: - headers.update(identity.get_resolved_headers(None)) - - # request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest) - # (the frontend controller converts from frontend-specific format to domain format) - # Backends should ONLY convert FROM domain TO backend-specific format - # Type assertion: we know from architectural design that request_data is ChatRequest-like - from typing import cast - - from src.core.domain.chat import CanonicalChatRequest, ChatRequest - - if not isinstance(request_data, ChatRequest): - raise TypeError( - f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. " - "Backend connectors should only receive domain-format requests." 
- ) - # Cast to CanonicalChatRequest for mypy compatibility with translation service signature - domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data) - - # Translate CanonicalChatRequest to Gemini request using the translation service - payload = self.translation_service.from_domain_request( - domain_request, target_format="gemini" - ) - - # Apply generation config including temperature clamping - self._apply_generation_config(payload, domain_request) - - # Apply contents and extra_body - payload["contents"] = self._prepare_gemini_contents(processed_messages) - if domain_request.extra_body: - # Merge extra_body with payload, but be careful with generationConfig. - # We support both legacy placement under 'generation_config' and - # the external 'generationConfig' key that Gemini expects. - # Normalize: prefer explicit generation_config on ChatRequest, then - # merge any 'generationConfig' present in extra_body on top. - extra_body_copy = dict(domain_request.extra_body) - - # If caller placed generation_config on ChatRequest it was already - # merged by _apply_generation_config into payload['generationConfig']. - # Now merge any generationConfig from extra_body on top of what we - # already have (extra body should be able to override specific keys). - # Accept either CamelCase 'generationConfig' (as used in tests and - # by external callers) or legacy snake_case 'generation_config' - extra_gen_cfg = extra_body_copy.pop("generationConfig", None) - if extra_gen_cfg is None: - extra_gen_cfg = extra_body_copy.pop("generation_config", None) - if extra_gen_cfg: - # merge by creating a new dict so we don't retain old references - existing = payload.get("generationConfig", {}) - merged = dict(existing) - - # Handle nested structures like thinkingConfig - for key, value in extra_gen_cfg.items(): - if ( - key == "thinkingConfig" - and isinstance(value, dict) - and "thinkingConfig" in merged - and isinstance(merged["thinkingConfig"], dict) - ): - # Deep merge thinkingConfig - merged["thinkingConfig"].update(value) - elif key == "maxOutputTokens" and "maxOutputTokens" not in merged: - # Add maxOutputTokens if not present - merged["maxOutputTokens"] = value - else: - # Regular update for other keys - merged[key] = value - - # Ensure extra_body overrides win for temperature specifically - if "temperature" in extra_gen_cfg: - merged["temperature"] = extra_gen_cfg["temperature"] - payload["generationConfig"] = merged - - # Finally update payload with remaining extra body fields - if extra_body_copy: - payload.update(extra_body_copy) - # Remove generation_config (legacy key) if present; we've migrated it - # into 'generationConfig' in _apply_generation_config. 
- payload.pop("generation_config", None) - # Debug output - logger.debug("Final payload: %s", payload) - - # Normalize model id and construct URL - model_name = self._normalize_model_name(effective_model) - logger.debug("Constructing Gemini API URL with model_name: %s", model_name) - model_url = f"{base_api_url}/v1beta/models/{model_name}" - - # Streaming vs non-streaming - if domain_request.stream: - stream_handle = await self._handle_gemini_streaming_response( - model_url, payload, headers, effective_model - ) - return StreamingResponseEnvelope( - content=stream_handle.iterator, - media_type="text/event-stream", - headers=stream_handle.headers or {}, - cancel_callback=stream_handle.cancel_callback, - ) - - return await self._handle_gemini_non_streaming_response( - model_url, payload, headers, effective_model - ) - - def _build_openrouter_header_context(self) -> dict[str, str]: - referer = "http://localhost:8000" - title = "InterceptorProxy" - - identity = getattr(self.config, "identity", None) - if identity is not None: - referer = ( - getattr(getattr(identity, "url", None), "default_value", referer) - or referer - ) - title = ( - getattr(getattr(identity, "title", None), "default_value", title) - or title - ) - - return {"app_site_url": referer, "app_x_title": title} - - async def _resolve_gemini_api_config( - self, - gemini_api_base_url: str | None, - openrouter_api_base_url: str | None, - api_key: str | None, - *, - openrouter_headers_provider: Callable[[Any, str], dict[str, str]] | None = None, - key_name: str | None = None, - **kwargs: Any, - ) -> tuple[str, dict[str, str]]: - # Prefer explicit params, then kwargs, then instance attributes set during initialize - base = ( - gemini_api_base_url - or openrouter_api_base_url - or kwargs.get("gemini_api_base_url") - or getattr(self, "gemini_api_base_url", None) - ) - key = api_key or kwargs.get("api_key") or getattr(self, "api_key", None) - if not base or not key: - raise HTTPException( - status_code=500, - detail="Gemini API base URL and API key must be provided.", - ) - normalized_base = base.rstrip("/") - - # Only use OpenRouter mode if the chosen base is actually OpenRouter - # OpenRouter mode should only be enabled when the resolved base URL is different - # from the default Gemini API base URL, indicating we're actually routing to OpenRouter - gemini_default_base = "https://generativelanguage.googleapis.com" - using_openrouter = ( - openrouter_api_base_url is not None - and normalized_base != gemini_default_base.rstrip("/") - ) - - headers: dict[str, str] - if using_openrouter: - headers = {} - provided_headers: dict[str, str] | None = None - - if openrouter_headers_provider is not None: - errors: list[Exception] = [] - - if key_name is not None: - try: - candidate = openrouter_headers_provider(key_name, key) - except (AttributeError, TypeError) as exc: - errors.append(exc) - else: - if candidate: - provided_headers = dict(candidate) - - if provided_headers is None: - context = self._build_openrouter_header_context() - try: - candidate = openrouter_headers_provider(context, key) - except Exception as exc: # pragma: no cover - defensive guard - if errors and logger.isEnabledFor(logging.DEBUG): - logger.debug( - "OpenRouter headers provider rejected key_name input: %s", - errors[-1], - exc_info=True, - ) - raise AuthenticationError( - message="OpenRouter headers provider failed to produce headers.", - code="missing_credentials", - ) from exc - else: - provided_headers = dict(candidate) - - if provided_headers is None: - context = 
self._build_openrouter_header_context() - provided_headers = { - "Authorization": f"Bearer {key}", - "Content-Type": "application/json", - "HTTP-Referer": context["app_site_url"], - "X-Title": context["app_x_title"], - } - - headers.update(provided_headers) - context = self._build_openrouter_header_context() - headers.setdefault("Authorization", f"Bearer {key}") - headers.setdefault("Content-Type", "application/json") - headers.setdefault("HTTP-Referer", context["app_site_url"]) - headers.setdefault("X-Title", context["app_x_title"]) - else: - key_name_to_use = ( - key_name - or kwargs.get("key_name") - or getattr(self, "key_name", None) - or "x-goog-api-key" - ) - headers = {key_name_to_use: key} - - return normalized_base, ensure_loop_guard_header(headers) - - def _apply_generation_config( - self, payload: dict[str, Any], request_data: ChatRequest - ) -> None: - # Initialize generationConfig - generation_config = payload.setdefault("generationConfig", {}) - - # thinking budget - if getattr(request_data, "thinking_budget", None): - thinking_config = generation_config.setdefault("thinkingConfig", {}) - thinking_config["thinkingBudget"] = request_data.thinking_budget # type: ignore[index] - - # top_k - if getattr(request_data, "top_k", None) is not None: - generation_config["topK"] = request_data.top_k - - # reasoning_effort - if getattr(request_data, "reasoning_effort", None) is not None: - thinking_config = generation_config.setdefault("thinkingConfig", {}) - thinking_config["reasoning_effort"] = request_data.reasoning_effort - - # generation config blob - merge with existing config - if getattr(request_data, "generation_config", None): - # Deep merge the generation_config into generationConfig - for key, value in request_data.generation_config.items(): # type: ignore[union-attr] - generation_config[key] = value - - # temperature clamped to [0,1] - temperature = getattr(request_data, "temperature", None) - if temperature is not None: - # Clamp temperature to [0,1] range for Gemini - if float(temperature) > 1.0: - logger.warning( - f"Temperature {temperature} > 1.0 for Gemini, clamping to 1.0" - ) - temperature = 1.0 - generation_config["temperature"] = float(temperature) - - # top_p - if request_data.top_p is not None: - generation_config["topP"] = request_data.top_p - - # stop sequences - if request_data.stop: - generation_config["stopSequences"] = request_data.stop - - # Unsupported parameters - if request_data.seed is not None and logger.isEnabledFor(logging.WARNING): - logger.warning("GeminiBackend does not support the 'seed' parameter.") - if request_data.presence_penalty is not None and logger.isEnabledFor( - logging.WARNING - ): - logger.warning( - "GeminiBackend does not support the 'presence_penalty' parameter." - ) - if request_data.frequency_penalty is not None and logger.isEnabledFor( - logging.WARNING - ): - logger.warning( - "GeminiBackend does not support the 'frequency_penalty' parameter." 
- ) - if request_data.logit_bias is not None and logger.isEnabledFor(logging.WARNING): - logger.warning("GeminiBackend does not support the 'logit_bias' parameter.") - if request_data.user is not None and logger.isEnabledFor(logging.WARNING): - logger.warning("GeminiBackend does not support the 'user' parameter.") - - def _normalize_model_name(self, effective_model: str) -> str: - model_name = effective_model - if model_name.startswith("gemini:"): - model_name = model_name.split(":", 1)[1] - if model_name.startswith("models/"): - model_name = model_name.split("/", 1)[1] - if model_name.startswith("gemini/"): - model_name = model_name.split("/", 1)[1] - if "/" in model_name: - logger.debug( - "Detected provider prefix in model name '%s'. Using last path segment as Gemini model id.", - model_name, - ) - model_name = model_name.rsplit("/", 1)[-1] - return model_name - - async def _handle_gemini_non_streaming_response( - self, base_url: str, payload: dict, headers: dict, effective_model: str - ) -> ResponseEnvelope: - headers = ensure_loop_guard_header(headers) - url = f"{base_url}:generateContent" - try: - response = await self.client.post(url, json=payload, headers=headers) - if response.status_code >= 400: - try: - error_detail = response.json() - except Exception: - error_detail = response.text - raise BackendError( - message=str(error_detail), - code="gemini_error", - status_code=response.status_code, - ) - data = response.json() - logger.debug("Gemini response headers: %s", dict(response.headers)) - return ResponseEnvelope( - content=self.translation_service.to_domain_response( - data, source_format="gemini" - ), - headers=dict(response.headers), - status_code=response.status_code, - ) - except httpx.RequestError as e: - logger.error("Request error connecting to Gemini: %s", e, exc_info=True) - raise ServiceUnavailableError(message=f"Could not connect to Gemini ({e})") - - async def list_models( - self, *, gemini_api_base_url: str, key_name: str, api_key: str - ) -> dict[str, Any]: - headers = ensure_loop_guard_header({key_name: api_key}) - url = f"{gemini_api_base_url.rstrip('/')}/v1beta/models" - try: - response = await self.client.get(url, headers=headers) - if response.status_code >= 400: - try: - error_detail = response.json() - except Exception: - error_detail = response.text - raise BackendError( - message=str(error_detail), - code="gemini_error", - status_code=response.status_code, - ) - return cast(dict[str, Any], response.json()) - except httpx.RequestError as e: - logger.error("Request error connecting to Gemini: %s", e, exc_info=True) - raise ServiceUnavailableError(message=f"Could not connect to Gemini ({e})") - - -backend_registry.register_backend("gemini", GeminiBackend) +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import uuid +from collections.abc import AsyncGenerator, Callable +from typing import Any, cast + +import httpx +from fastapi import HTTPException + +from src.connectors.base import LLMBackend +from src.core.common.exceptions import ( + AuthenticationError, + BackendError, + ServiceUnavailableError, +) +from src.core.config.app_config import AppConfig # Added +from src.core.domain.chat import ( + ChatRequest, + MessageContentPartImage, + MessageContentPartText, +) +from src.core.domain.responses import ( + ResponseEnvelope, + StreamingResponseEnvelope, + StreamingResponseHandle, +) +from src.core.interfaces.configuration_interface import IAppIdentityConfig +from src.core.interfaces.model_bases import 
DomainModel, InternalDTO +from src.core.interfaces.response_processor_interface import ProcessedResponse +from src.core.security.loop_prevention import ensure_loop_guard_header +from src.core.services.backend_registry import backend_registry +from src.core.services.translation_service import TranslationService + +# Legacy ChatCompletionRequest removed from connector signatures; use domain ChatRequest + +# API key redaction and command filtering are now handled by middleware +# from src.security import APIKeyRedactor, ProxyCommandFilter + +logger = logging.getLogger(__name__) + + +class GeminiBackend(LLMBackend): + """LLMBackend implementation for Google's Gemini API.""" + + backend_type: str = "gemini" + + def __init__( + self, + client: httpx.AsyncClient, + config: AppConfig, + translation_service: TranslationService, + ) -> None: + self.client = client + self.config = config # Stored config + self.translation_service = translation_service + self.available_models: list[str] = [] + self.api_keys: list[str] = [] + + async def initialize(self, **kwargs: Any) -> None: + """Store configuration for lazy initialization.""" + self.gemini_api_base_url = kwargs.get("gemini_api_base_url") + self.key_name = kwargs.get("key_name") + self.api_key = kwargs.get("api_key") + + if not self.gemini_api_base_url or not self.key_name or not self.api_key: + raise ValueError( + "gemini_api_base_url, key_name, and api_key are required for GeminiBackend" + ) + + # Don't make HTTP calls during initialization + # Models will be fetched on first use + + async def _ensure_models_loaded(self) -> None: + """Fetch models if not already cached.""" + if ( + not self.available_models + and hasattr(self, "api_key") + and self.gemini_api_base_url + and self.key_name + and self.api_key + ): + try: + data = await self.list_models( + gemini_api_base_url=self.gemini_api_base_url, + key_name=self.key_name, + api_key=self.api_key, + ) + self.available_models = [ + m.get("name") for m in data.get("models", []) if m.get("name") + ] + except Exception as e: + logger.warning("Failed to fetch Gemini models: %s", e, exc_info=True) + # Return empty list on failure, don't crash + self.available_models = [] + + def get_available_models(self) -> list[str]: + """Return cached Gemini model names. 
For immediate use, prefer async version.""" + return list(self.available_models) + + async def get_available_models_async(self) -> list[str]: + """Return Gemini model names, fetching them if not cached.""" + await self._ensure_models_loaded() + return list(self.available_models) + + # Translation is now handled by TranslationService + + def _convert_part_for_gemini( + self, part: MessageContentPartText | MessageContentPartImage + ) -> dict[str, Any]: + """Convert a MessageContentPart into Gemini API format.""" + if isinstance(part, MessageContentPartText): + # Text content is already processed by middleware + return {"text": part.text} + if isinstance(part, MessageContentPartImage): + url = part.image_url.url + # Data URL -> inlineData + if url.startswith("data:"): + try: + header, b64_data = url.split(",", 1) + mime = header.split(";")[0][5:] + except Exception: + mime = "application/octet-stream" + b64_data = "" + return {"inlineData": {"mimeType": mime, "data": b64_data}} + # Otherwise treat as remote file URI + return { + "fileData": {"mimeType": "application/octet-stream", "fileUri": url} + } + data = part.model_dump(exclude_unset=True) + if data.get("type") == "text" and "text" in data: + # Text content is already processed by middleware + data.pop("type", None) + return data + + def _prepare_gemini_contents( + self, processed_messages: list[Any] + ) -> list[dict[str, Any]]: + payload_contents = [] + for msg in processed_messages: + # Handle both object and dict formats for backward compatibility + if isinstance(msg, dict): + role = msg.get("role") + # For dict format, check if it's already in Gemini format (has "parts") + # or in generic format (has "content") + if "parts" in msg: + # Already in Gemini format, use directly + payload_contents.append({"role": role, "parts": msg["parts"]}) + continue + else: + content = msg.get("content") + else: + role = getattr(msg, "role", None) + content = getattr(msg, "content", None) + + if role == "system": + # Gemini API does not support system role + continue + + if isinstance(content, str): + # If this is a tool or function role, represent it as functionResponse for Gemini + if role in ["tool", "function"]: + # Try to parse JSON payload; otherwise wrap string + try: + input_obj = json.loads(content) + except Exception: + input_obj = {"output": content} + parts: list[dict[str, Any]] = [ + { + "functionResponse": { + "name": ( + getattr(msg, "name", "tool") or "tool" + if not isinstance(msg, dict) + else msg.get("name", "tool") + ), + "response": input_obj, + } + } + ] + else: + # Content is already processed by middleware + parts = [{"text": content}] + elif content is not None: + parts = [self._convert_part_for_gemini(part) for part in content] + else: + # Skip messages with no content + continue + + # Map roles to 'user' or 'model' as required by Gemini API + if role == "user": + gemini_role = "user" + elif role in ["tool", "function"]: + # Tool/function results are treated as coming from the user side in Gemini + gemini_role = "user" + else: # e.g., assistant + gemini_role = "model" + + payload_contents.append({"role": gemini_role, "parts": parts}) + return payload_contents + + @staticmethod + def _coerce_stream_chunk(raw_chunk: Any) -> dict[str, Any] | None: + if isinstance(raw_chunk, dict): + return raw_chunk + + if isinstance(raw_chunk, bytes | bytearray): + raw_chunk = raw_chunk.decode("utf-8", errors="ignore") + + if not isinstance(raw_chunk, str): + return None + + stripped_chunk = raw_chunk.strip() + if not stripped_chunk: + return 
None + + data_segments: list[str] = [] + for line in stripped_chunk.splitlines(): + line = line.strip() + if not line or line.startswith(":"): + continue + if line.startswith("data:"): + data_value = line[5:].strip() + if not data_value: + continue + if data_value == "[DONE]": + return None + data_segments.append(data_value) + else: + data_segments.append(line) + + for segment in data_segments or [stripped_chunk]: + try: + parsed = json.loads(segment) + except json.JSONDecodeError: + continue + + if isinstance(parsed, dict): + return parsed + + if isinstance(parsed, str): + stripped_parsed = parsed.strip() + if stripped_parsed: + return { + "candidates": [ + { + "content": {"parts": [{"text": stripped_parsed}]}, + } + ] + } + + # If parsed value is not usable, continue searching remaining segments + continue + + # Fallback to treating the content as plain text + return { + "candidates": [ + { + "content": {"parts": [{"text": stripped_chunk}]}, + } + ] + } + + async def _handle_gemini_streaming_response( + self, + base_url: str, + payload: dict[str, Any], + headers: dict[str, str], + effective_model: str, + ) -> StreamingResponseHandle: + request_headers = ensure_loop_guard_header(headers) + request_id = request_headers.get("x-goog-request-id") or uuid.uuid4().hex + request_headers.setdefault("x-goog-request-id", request_id) + + url = f"{base_url}:streamGenerateContent" + try: + request = self.client.build_request( + "POST", url, json=payload, headers=request_headers + ) + response = await self.client.send(request, stream=True) + except httpx.RequestError as e: + logger.error("Request error connecting to Gemini: %s", e, exc_info=True) + raise ServiceUnavailableError(message=f"Could not connect to Gemini ({e})") + except (AttributeError, TypeError): + request = self.client.build_request( + "POST", url, json=payload, headers=request_headers + ) + try: + response = await self.client.send(request, stream=True) + except httpx.RequestError as e: + logger.error("Request error connecting to Gemini: %s", e, exc_info=True) + raise ServiceUnavailableError( + message=f"Could not connect to Gemini ({e})" + ) + + if response.status_code >= 400: + try: + if hasattr(response, "aread"): + body_bytes = await response.aread() # type: ignore[no-untyped-call] + else: + body_bytes = b"" + body_text = body_bytes.decode("utf-8", errors="ignore") + except Exception: + body_text = "" + finally: + if hasattr(response, "aclose"): + await response.aclose() + logger.error( + "HTTP error during Gemini stream: %s - %s", + response.status_code, + body_text, + ) + raise BackendError( + message=f"Gemini stream error: {response.status_code} - {body_text}", + code="gemini_error", + status_code=response.status_code, + ) + + # Prefer response-provided request identifiers when available + response_request_id = response.headers.get("x-goog-request-id") + if response_request_id: + request_id = response_request_id + + cancel_lock = asyncio.Lock() + cancel_state = {"called": False} + + async def cancel_stream() -> None: + async with cancel_lock: + if cancel_state["called"]: + return + cancel_state["called"] = True + + cancel_url = f"{base_url}:cancel" + cancel_headers = ensure_loop_guard_header(dict(request_headers)) + payload_body = {"requestId": request_id} + + try: + cancel_response = await self.client.post( + cancel_url, + json=payload_body, + headers=cancel_headers, + ) + except Exception as exc: + logger.debug( + "Gemini cancel request failed - url=%s request_id=%s error=%s", + cancel_url, + request_id, + exc, + exc_info=True, 
+ ) + else: + with contextlib.suppress(Exception): + await cancel_response.aclose() + + with contextlib.suppress(Exception): + await response.aclose() + + async def stream_generator() -> AsyncGenerator[ProcessedResponse, None]: + processed_stream = response.aiter_text() + + try: + async for raw_chunk in processed_stream: + parsed_chunk = self._coerce_stream_chunk(raw_chunk) + if parsed_chunk is None: + continue + + yield ProcessedResponse( + content=self.translation_service.to_domain_stream_chunk( + parsed_chunk, source_format="gemini" + ) + ) + + done_chunk = { + "candidates": [ + { + "content": {"parts": []}, + "finishReason": "STOP", + } + ] + } + yield ProcessedResponse( + content=self.translation_service.to_domain_stream_chunk( + done_chunk, source_format="gemini" + ) + ) + except httpx.RequestError as stream_error: + logger.error( + "Request error while streaming from Gemini: %s", + stream_error, + exc_info=True, + ) + raise ServiceUnavailableError( + message=f"Gemini streaming connection error ({stream_error})" + ) from stream_error + finally: + with contextlib.suppress(Exception): + await response.aclose() + + try: + response_headers = dict(response.headers) + except Exception: + response_headers = {} + + return StreamingResponseHandle( + iterator=stream_generator(), + cancel_callback=cancel_stream, + headers=response_headers, + ) + + async def chat_completions( # type: ignore[override] + self, + request_data: DomainModel | InternalDTO | dict[str, Any], + processed_messages: list[Any], + effective_model: str, + identity: IAppIdentityConfig | None = None, + openrouter_api_base_url: str | None = None, + openrouter_headers_provider: Callable[[Any, str], dict[str, str]] | None = None, + key_name: str | None = None, + api_key: str | None = None, + project: str | None = None, + agent: str | None = None, + gemini_api_base_url: str | None = None, + **kwargs: Any, + ) -> ResponseEnvelope | StreamingResponseEnvelope: + # Resolve base configuration + base_api_url, headers = await self._resolve_gemini_api_config( + gemini_api_base_url, + openrouter_api_base_url, + api_key, + openrouter_headers_provider=openrouter_headers_provider, + key_name=key_name, + **kwargs, + ) + if identity: + headers.update(identity.get_resolved_headers(None)) + + # request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest) + # (the frontend controller converts from frontend-specific format to domain format) + # Backends should ONLY convert FROM domain TO backend-specific format + # Type assertion: we know from architectural design that request_data is ChatRequest-like + from typing import cast + + from src.core.domain.chat import CanonicalChatRequest, ChatRequest + + if not isinstance(request_data, ChatRequest): + raise TypeError( + f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. " + "Backend connectors should only receive domain-format requests." 
+ ) + # Cast to CanonicalChatRequest for mypy compatibility with translation service signature + domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data) + + # Translate CanonicalChatRequest to Gemini request using the translation service + payload = self.translation_service.from_domain_request( + domain_request, target_format="gemini" + ) + + # Apply generation config including temperature clamping + self._apply_generation_config(payload, domain_request) + + # Apply contents and extra_body + payload["contents"] = self._prepare_gemini_contents(processed_messages) + if domain_request.extra_body: + # Merge extra_body with payload, but be careful with generationConfig. + # We support both legacy placement under 'generation_config' and + # the external 'generationConfig' key that Gemini expects. + # Normalize: prefer explicit generation_config on ChatRequest, then + # merge any 'generationConfig' present in extra_body on top. + extra_body_copy = dict(domain_request.extra_body) + + # If caller placed generation_config on ChatRequest it was already + # merged by _apply_generation_config into payload['generationConfig']. + # Now merge any generationConfig from extra_body on top of what we + # already have (extra body should be able to override specific keys). + # Accept either CamelCase 'generationConfig' (as used in tests and + # by external callers) or legacy snake_case 'generation_config' + extra_gen_cfg = extra_body_copy.pop("generationConfig", None) + if extra_gen_cfg is None: + extra_gen_cfg = extra_body_copy.pop("generation_config", None) + if extra_gen_cfg: + # merge by creating a new dict so we don't retain old references + existing = payload.get("generationConfig", {}) + merged = dict(existing) + + # Handle nested structures like thinkingConfig + for key, value in extra_gen_cfg.items(): + if ( + key == "thinkingConfig" + and isinstance(value, dict) + and "thinkingConfig" in merged + and isinstance(merged["thinkingConfig"], dict) + ): + # Deep merge thinkingConfig + merged["thinkingConfig"].update(value) + elif key == "maxOutputTokens" and "maxOutputTokens" not in merged: + # Add maxOutputTokens if not present + merged["maxOutputTokens"] = value + else: + # Regular update for other keys + merged[key] = value + + # Ensure extra_body overrides win for temperature specifically + if "temperature" in extra_gen_cfg: + merged["temperature"] = extra_gen_cfg["temperature"] + payload["generationConfig"] = merged + + # Finally update payload with remaining extra body fields + if extra_body_copy: + payload.update(extra_body_copy) + # Remove generation_config (legacy key) if present; we've migrated it + # into 'generationConfig' in _apply_generation_config. 
+ payload.pop("generation_config", None) + # Debug output + logger.debug("Final payload: %s", payload) + + # Normalize model id and construct URL + model_name = self._normalize_model_name(effective_model) + logger.debug("Constructing Gemini API URL with model_name: %s", model_name) + model_url = f"{base_api_url}/v1beta/models/{model_name}" + + # Streaming vs non-streaming + if domain_request.stream: + stream_handle = await self._handle_gemini_streaming_response( + model_url, payload, headers, effective_model + ) + return StreamingResponseEnvelope( + content=stream_handle.iterator, + media_type="text/event-stream", + headers=stream_handle.headers or {}, + cancel_callback=stream_handle.cancel_callback, + ) + + return await self._handle_gemini_non_streaming_response( + model_url, payload, headers, effective_model + ) + + def _build_openrouter_header_context(self) -> dict[str, str]: + referer = "http://localhost:8000" + title = "InterceptorProxy" + + identity = getattr(self.config, "identity", None) + if identity is not None: + referer = ( + getattr(getattr(identity, "url", None), "default_value", referer) + or referer + ) + title = ( + getattr(getattr(identity, "title", None), "default_value", title) + or title + ) + + return {"app_site_url": referer, "app_x_title": title} + + async def _resolve_gemini_api_config( + self, + gemini_api_base_url: str | None, + openrouter_api_base_url: str | None, + api_key: str | None, + *, + openrouter_headers_provider: Callable[[Any, str], dict[str, str]] | None = None, + key_name: str | None = None, + **kwargs: Any, + ) -> tuple[str, dict[str, str]]: + # Prefer explicit params, then kwargs, then instance attributes set during initialize + base = ( + gemini_api_base_url + or openrouter_api_base_url + or kwargs.get("gemini_api_base_url") + or getattr(self, "gemini_api_base_url", None) + ) + key = api_key or kwargs.get("api_key") or getattr(self, "api_key", None) + if not base or not key: + raise HTTPException( + status_code=500, + detail="Gemini API base URL and API key must be provided.", + ) + normalized_base = base.rstrip("/") + + # Only use OpenRouter mode if the chosen base is actually OpenRouter + # OpenRouter mode should only be enabled when the resolved base URL is different + # from the default Gemini API base URL, indicating we're actually routing to OpenRouter + gemini_default_base = "https://generativelanguage.googleapis.com" + using_openrouter = ( + openrouter_api_base_url is not None + and normalized_base != gemini_default_base.rstrip("/") + ) + + headers: dict[str, str] + if using_openrouter: + headers = {} + provided_headers: dict[str, str] | None = None + + if openrouter_headers_provider is not None: + errors: list[Exception] = [] + + if key_name is not None: + try: + candidate = openrouter_headers_provider(key_name, key) + except (AttributeError, TypeError) as exc: + errors.append(exc) + else: + if candidate: + provided_headers = dict(candidate) + + if provided_headers is None: + context = self._build_openrouter_header_context() + try: + candidate = openrouter_headers_provider(context, key) + except Exception as exc: # pragma: no cover - defensive guard + if errors and logger.isEnabledFor(logging.DEBUG): + logger.debug( + "OpenRouter headers provider rejected key_name input: %s", + errors[-1], + exc_info=True, + ) + raise AuthenticationError( + message="OpenRouter headers provider failed to produce headers.", + code="missing_credentials", + ) from exc + else: + provided_headers = dict(candidate) + + if provided_headers is None: + context = 
self._build_openrouter_header_context() + provided_headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + "HTTP-Referer": context["app_site_url"], + "X-Title": context["app_x_title"], + } + + headers.update(provided_headers) + context = self._build_openrouter_header_context() + headers.setdefault("Authorization", f"Bearer {key}") + headers.setdefault("Content-Type", "application/json") + headers.setdefault("HTTP-Referer", context["app_site_url"]) + headers.setdefault("X-Title", context["app_x_title"]) + else: + key_name_to_use = ( + key_name + or kwargs.get("key_name") + or getattr(self, "key_name", None) + or "x-goog-api-key" + ) + headers = {key_name_to_use: key} + + return normalized_base, ensure_loop_guard_header(headers) + + def _apply_generation_config( + self, payload: dict[str, Any], request_data: ChatRequest + ) -> None: + # Initialize generationConfig + generation_config = payload.setdefault("generationConfig", {}) + + # thinking budget + if getattr(request_data, "thinking_budget", None): + thinking_config = generation_config.setdefault("thinkingConfig", {}) + thinking_config["thinkingBudget"] = request_data.thinking_budget # type: ignore[index] + + # top_k + if getattr(request_data, "top_k", None) is not None: + generation_config["topK"] = request_data.top_k + + # reasoning_effort + if getattr(request_data, "reasoning_effort", None) is not None: + thinking_config = generation_config.setdefault("thinkingConfig", {}) + thinking_config["reasoning_effort"] = request_data.reasoning_effort + + # generation config blob - merge with existing config + if getattr(request_data, "generation_config", None): + # Deep merge the generation_config into generationConfig + for key, value in request_data.generation_config.items(): # type: ignore[union-attr] + generation_config[key] = value + + # temperature clamped to [0,1] + temperature = getattr(request_data, "temperature", None) + if temperature is not None: + # Clamp temperature to [0,1] range for Gemini + if float(temperature) > 1.0: + logger.warning( + f"Temperature {temperature} > 1.0 for Gemini, clamping to 1.0" + ) + temperature = 1.0 + generation_config["temperature"] = float(temperature) + + # top_p + if request_data.top_p is not None: + generation_config["topP"] = request_data.top_p + + # stop sequences + if request_data.stop: + generation_config["stopSequences"] = request_data.stop + + # Unsupported parameters + if request_data.seed is not None and logger.isEnabledFor(logging.WARNING): + logger.warning("GeminiBackend does not support the 'seed' parameter.") + if request_data.presence_penalty is not None and logger.isEnabledFor( + logging.WARNING + ): + logger.warning( + "GeminiBackend does not support the 'presence_penalty' parameter." + ) + if request_data.frequency_penalty is not None and logger.isEnabledFor( + logging.WARNING + ): + logger.warning( + "GeminiBackend does not support the 'frequency_penalty' parameter." + ) + if request_data.logit_bias is not None and logger.isEnabledFor(logging.WARNING): + logger.warning("GeminiBackend does not support the 'logit_bias' parameter.") + if request_data.user is not None and logger.isEnabledFor(logging.WARNING): + logger.warning("GeminiBackend does not support the 'user' parameter.") + if request_data.repetition_penalty is not None and logger.isEnabledFor( + logging.WARNING + ): + logger.warning( + "GeminiBackend does not support the 'repetition_penalty' parameter." 
+ ) + if request_data.min_p is not None and logger.isEnabledFor(logging.WARNING): + logger.warning("GeminiBackend does not support the 'min_p' parameter.") + + def _normalize_model_name(self, effective_model: str) -> str: + model_name = effective_model + if model_name.startswith("gemini:"): + model_name = model_name.split(":", 1)[1] + if model_name.startswith("models/"): + model_name = model_name.split("/", 1)[1] + if model_name.startswith("gemini/"): + model_name = model_name.split("/", 1)[1] + if "/" in model_name: + logger.debug( + "Detected provider prefix in model name '%s'. Using last path segment as Gemini model id.", + model_name, + ) + model_name = model_name.rsplit("/", 1)[-1] + return model_name + + async def _handle_gemini_non_streaming_response( + self, base_url: str, payload: dict, headers: dict, effective_model: str + ) -> ResponseEnvelope: + headers = ensure_loop_guard_header(headers) + url = f"{base_url}:generateContent" + try: + response = await self.client.post(url, json=payload, headers=headers) + if response.status_code >= 400: + try: + error_detail = response.json() + except Exception: + error_detail = response.text + raise BackendError( + message=str(error_detail), + code="gemini_error", + status_code=response.status_code, + ) + data = response.json() + logger.debug("Gemini response headers: %s", dict(response.headers)) + return ResponseEnvelope( + content=self.translation_service.to_domain_response( + data, source_format="gemini" + ), + headers=dict(response.headers), + status_code=response.status_code, + ) + except httpx.RequestError as e: + logger.error("Request error connecting to Gemini: %s", e, exc_info=True) + raise ServiceUnavailableError(message=f"Could not connect to Gemini ({e})") + + async def list_models( + self, *, gemini_api_base_url: str, key_name: str, api_key: str + ) -> dict[str, Any]: + headers = ensure_loop_guard_header({key_name: api_key}) + url = f"{gemini_api_base_url.rstrip('/')}/v1beta/models" + try: + response = await self.client.get(url, headers=headers) + if response.status_code >= 400: + try: + error_detail = response.json() + except Exception: + error_detail = response.text + raise BackendError( + message=str(error_detail), + code="gemini_error", + status_code=response.status_code, + ) + return cast(dict[str, Any], response.json()) + except httpx.RequestError as e: + logger.error("Request error connecting to Gemini: %s", e, exc_info=True) + raise ServiceUnavailableError(message=f"Could not connect to Gemini ({e})") + + +backend_registry.register_backend("gemini", GeminiBackend) diff --git a/src/connectors/gemini_cloud_project.py b/src/connectors/gemini_cloud_project.py index b39f2bb8..a806965d 100644 --- a/src/connectors/gemini_cloud_project.py +++ b/src/connectors/gemini_cloud_project.py @@ -1420,6 +1420,18 @@ def _build_generation_config(self, request_data: Any) -> dict[str, Any]: if top_k is not None: with contextlib.suppress(Exception): cfg["topK"] = int(top_k) + if getattr(request_data, "repetition_penalty", None) is not None and logger.isEnabledFor( + logging.WARNING + ): + logger.warning( + "GeminiCloudProjectConnector does not support the 'repetition_penalty' parameter." + ) + if getattr(request_data, "min_p", None) is not None and logger.isEnabledFor( + logging.WARNING + ): + logger.warning( + "GeminiCloudProjectConnector does not support the 'min_p' parameter." 
+ ) return cfg async def _ensure_project_onboarded(self, auth_session) -> str: diff --git a/src/connectors/gemini_oauth_base.py b/src/connectors/gemini_oauth_base.py index 7e9f7647..02c554a2 100644 --- a/src/connectors/gemini_oauth_base.py +++ b/src/connectors/gemini_oauth_base.py @@ -2740,6 +2740,17 @@ def _build_generation_config(self, request_data: Any) -> dict[str, Any]: if "top_k" in cfg: cfg["topK"] = cfg.pop("top_k") + if getattr(request_data, "repetition_penalty", None) is not None and logger.isEnabledFor( + logging.WARNING + ): + logger.warning( + "GeminiOAuthBase does not support the 'repetition_penalty' parameter." + ) + if getattr(request_data, "min_p", None) is not None and logger.isEnabledFor( + logging.WARNING + ): + logger.warning("GeminiOAuthBase does not support the 'min_p' parameter.") + return cfg def _convert_from_code_assist_format( diff --git a/src/connectors/openai.py b/src/connectors/openai.py index 1f51bdba..7d70b2e7 100644 --- a/src/connectors/openai.py +++ b/src/connectors/openai.py @@ -1,1026 +1,1049 @@ -from __future__ import annotations - -import asyncio -import contextlib -import inspect -import json -import logging - -logger = logging.getLogger(__name__) - -from collections.abc import ( - AsyncGenerator, - Mapping, -) -from json import JSONDecodeError -from typing import Any - -import httpx -from fastapi import HTTPException - -from src.core.common.exceptions import ( - AuthenticationError, - ServiceResolutionError, - ServiceUnavailableError, -) -from src.core.config.app_config import AppConfig -from src.core.domain.chat import CanonicalChatRequest -from src.core.domain.responses import ( - ResponseEnvelope, - StreamingResponseEnvelope, - StreamingResponseHandle, -) -from src.core.interfaces.configuration_interface import IAppIdentityConfig -from src.core.interfaces.model_bases import DomainModel, InternalDTO -from src.core.interfaces.response_processor_interface import ( - IResponseProcessor, - ProcessedResponse, -) -from src.core.security.loop_prevention import ensure_loop_guard_header -from src.core.services.backend_registry import backend_registry -from src.core.services.translation_service import TranslationService - -from .base import LLMBackend - -# Legacy ChatCompletionRequest removed from connector signatures; use domain ChatRequest - - -class OpenAIConnector(LLMBackend): - """Minimal OpenAI-compatible connector used by OpenRouterBackend in tests. - - It supports an optional `headers_override` kwarg and treats streaming - responses that expose `aiter_bytes()` as streamable even if returned by - test doubles. 
- """ - - backend_type: str = "openai" - - def __init__( - self, - client: httpx.AsyncClient, - config: AppConfig, - translation_service: TranslationService | None = None, - response_processor: IResponseProcessor | None = None, - ) -> None: - super().__init__(config, response_processor) - self.client = client - # Allow callers/tests to omit TranslationService; resolve through DI for consistency - self.translation_service = ( - translation_service - if translation_service is not None - else self._resolve_translation_service() - ) - self.config = config # Stored config - self.available_models: list[str] = [] - self.api_key: str | None = None - self._api_base_url: str = "https://api.openai.com/v1" - - # Health check attributes - self._health_checked: bool = False - import os - - disable_health_checks_env = os.getenv( - "DISABLE_HEALTH_CHECKS", "false" - ).lower() in ("true", "1", "yes") - - disable_health_checks_config = bool( - getattr(self.config, "disable_health_checks", False) - ) - - # Enable health checks only when neither config nor env disable them - self._health_check_enabled = not ( - disable_health_checks_env or disable_health_checks_config - ) - - @property - def api_base_url(self) -> str: - """Return the API base URL.""" - return self._api_base_url - - @api_base_url.setter - def api_base_url(self, value: str) -> None: - """Set the API base URL.""" - self._api_base_url = value - - @staticmethod - def _resolve_translation_service() -> TranslationService: - """Resolve TranslationService from the DI container.""" - - from src.core.di.services import get_or_build_service_provider - - provider = get_or_build_service_provider() - service = provider.get_service(TranslationService) - if service is None: - raise ServiceResolutionError( - "TranslationService is not registered in the service provider", - service_name="TranslationService", - ) - return service - - def get_headers(self, identity: IAppIdentityConfig | None = None) -> dict[str, str]: - """Return request headers including API key and optional request identity.""" - - headers: dict[str, str] = {} - - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - - if identity is not None: - try: - identity_headers = identity.get_resolved_headers(None) - except Exception: - identity_headers = {} - else: - identity_headers = dict(identity_headers) - if identity_headers: - headers.update(identity_headers) - - return ensure_loop_guard_header(headers) - - async def initialize(self, **kwargs: Any) -> None: - self.api_key = kwargs.get("api_key") - logger.info( - "OpenAIConnector initialize called. 
api_key_provided=%s", - "yes" if self.api_key else "no", - ) - if "api_base_url" in kwargs: - self.api_base_url = kwargs["api_base_url"] - - # Proceed to fetch models only when we have credentials; failures are non-fatal - if not self.api_key: - logger.debug( - "Skipping OpenAI model listing during init; no API key configured" - ) - else: - try: - headers = self.get_headers() - response = await self.client.get( - f"{self.api_base_url}/models", headers=headers - ) - data = self._decode_json_payload(response) - if isinstance(data, dict): - self.available_models = [ - model["id"] - for model in data.get("data", []) - if isinstance(model, Mapping) and "id" in model - ] - else: - logger.debug( - "Unexpected models payload type from OpenAI: %s", - type(data).__name__, - ) - self.available_models = [] - except Exception as e: - logger.warning("Failed to fetch models: %s", e, exc_info=True) - # Log the error but don't fail initialization - - async def _perform_health_check(self) -> bool: - """Perform a health check by testing API connectivity. - - This method tests actual API connectivity by making a simple request to verify - the API key works and the service is accessible. - - Returns: - bool: True if health check passes, False otherwise - """ - try: - # Test API connectivity with a simple models endpoint request - if not self.api_key: - logger.warning("Health check failed - no API key available") - return False - - headers = self.get_headers() - if not headers.get("Authorization"): - logger.warning("Health check failed - no authorization header") - return False - - url = f"{self.api_base_url}/models" - response = await self.client.get(url, headers=headers) - - if response.status_code == 200: - logger.info("Health check passed - API connectivity verified") - self._health_checked = True - return True - else: - logger.warning( - f"Health check failed - API returned status {response.status_code}" - ) - return False - - except Exception as e: - logger.error("Health check failed - unexpected error: %s", e, exc_info=True) - return False - - async def _ensure_healthy(self) -> None: - """Ensure the backend is healthy before use. - - This method performs health checks on first use, similar to how - models are loaded lazily in the parent class. 
- """ - if not self._health_check_enabled: - # Health check is disabled, skip - return - - if not hasattr(self, "_health_checked") or not self._health_checked: - logger.info( - f"Performing first-use health check for {self.backend_type} backend" - ) - - healthy = await self._perform_health_check() - if not healthy: - logger.warning( - "Health check did not pass; continuing with lazy verification on first request" - ) - else: - logger.info("Health check passed - backend is ready for use") - - self._health_checked = True - - def enable_health_check(self) -> None: - """Enable health check functionality for this connector instance.""" - self._health_check_enabled = True - self._health_checked = False # Reset so it will check on next use - logger.info(f"Health check enabled for {self.backend_type} backend") - - def disable_health_check(self) -> None: - """Disable health check functionality for this connector instance.""" - self._health_check_enabled = False - logger.info(f"Health check disabled for {self.backend_type} backend") - - _XSSI_PREFIXES = ( - ")]}',\n", - ")]}',", - ")]}'", - "while(1);", - "while (1);", - ) - - def _decode_json_payload(self, response: httpx.Response) -> Any: - """Safely decode JSON payloads that may include XSSI guards or trailing data.""" - try: - return response.json() - except JSONDecodeError: - text = response.text or "" - sanitized = self._strip_xssi_prefix(text) - if sanitized != text: - try: - return json.loads(sanitized) - except JSONDecodeError: - pass - candidate = self._extract_first_json_value(sanitized) - if candidate: - try: - return json.loads(candidate) - except JSONDecodeError: - logger.debug( - "Failed to decode sanitized JSON payload; candidate snippet=%s", - candidate[:200], - ) - logger.warning( - "Unable to decode JSON payload from OpenAI response (status=%s, preview=%r)", - getattr(response, "status_code", "unknown"), - (sanitized or text)[:200], - ) - return None - - def _strip_xssi_prefix(self, payload: str) -> str: - stripped = payload.lstrip() - for prefix in self._XSSI_PREFIXES: - if stripped.startswith(prefix): - return stripped[len(prefix) :] - return stripped - - def _extract_first_json_value(self, payload: str) -> str | None: - candidate = payload.strip() - if not candidate: - return None - - opening = candidate[0] - if opening not in ("{", "["): - # Attempt to locate the first JSON object within the payload - for idx, ch in enumerate(candidate): - if ch in ("{", "["): - candidate = candidate[idx:] - opening = candidate[0] - break - else: - return None - - stack = [] - in_string = False - escape = False - for idx, ch in enumerate(candidate): - if in_string: - if escape: - escape = False - elif ch == "\\": - escape = True - elif ch == '"': - in_string = False - continue - - if ch == '"': - in_string = True - continue - - if ch in ("{", "["): - stack.append("]" if ch == "[" else "}") - elif ch in ("}", "]"): - if not stack or stack.pop() != ch: - return None - if not stack: - return candidate[: idx + 1] - - return None - - async def chat_completions( - self, - request_data: DomainModel | InternalDTO | dict[str, Any], - processed_messages: list[Any], - effective_model: str, - identity: IAppIdentityConfig | None = None, - **kwargs: Any, - ) -> ResponseEnvelope | StreamingResponseEnvelope: - # Perform health check if enabled (for subclasses that support it) - await self._ensure_healthy() - - # request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest) - # (the frontend controller converts from 
frontend-specific format to domain format) - # Backends should ONLY convert FROM domain TO backend-specific format - # Type assertion: we know from architectural design that request_data is ChatRequest-like - from typing import cast - - from src.core.domain.chat import CanonicalChatRequest, ChatRequest - - if not isinstance(request_data, ChatRequest): - raise TypeError( - f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. " - "Backend connectors should only receive domain-format requests." - ) - # Cast to CanonicalChatRequest for mypy compatibility with _prepare_payload signature - domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data) - - # Prepare the payload using a helper so subclasses and tests can - # override or patch payload construction logic easily. - payload = await self._prepare_payload( - domain_request, processed_messages, effective_model - ) - headers_override = kwargs.pop("headers_override", None) - headers: dict[str, str] | None = None - - base_headers: dict[str, str] | None - try: - base_headers = self.get_headers(identity=identity) - except Exception: - base_headers = None - - if headers_override is not None: - # Avoid mutating the caller-provided mapping while preserving any - # Authorization header we compute from the configured API key. - headers = dict(headers_override) - if base_headers: - merged_headers = dict(base_headers) - merged_headers.update(headers) - headers = merged_headers - else: - headers = base_headers - - api_base = kwargs.get("openai_url") or self.api_base_url - url = f"{api_base.rstrip('/')}/chat/completions" - - if domain_request.stream: - # Return a domain-level streaming envelope (raw bytes iterator) - try: - stream_handle = await self._handle_streaming_response( - url, - payload, - headers, - domain_request.session_id or "", - "openai", - ) - except AuthenticationError as e: - raise HTTPException(status_code=401, detail=str(e)) - return StreamingResponseEnvelope( - content=stream_handle.iterator, - media_type="text/event-stream", - headers={}, - cancel_callback=stream_handle.cancel_callback, - ) - else: - # Return a domain ResponseEnvelope for non-streaming - return await self._handle_non_streaming_response( - url, payload, headers, domain_request.session_id or "" - ) - - async def _prepare_payload( - self, - request_data: CanonicalChatRequest, - processed_messages: list[Any], - effective_model: str, - ) -> dict[str, Any]: - """ - Default payload preparation for OpenAI-compatible backends. - - Subclasses or tests may patch/override this method to customize the - final payload sent to the provider. - """ - # request_data is expected to be a CanonicalChatRequest already - # (the caller creates it via TranslationService.to_domain_request). - payload = self.translation_service.from_domain_request(request_data, "openai") - if inspect.isawaitable(payload): - payload = await payload - - # Prefer processed_messages (these are the canonical, post-processed - # messages ready to send). Convert them to plain dicts to ensure JSON - # serializability without mutating the original Pydantic models. 
- if processed_messages: - try: - normalized_messages: list[dict[str, Any]] = [] - - def _get_value(message: Any, key: str) -> Any: - if isinstance(message, Mapping): - return message.get(key) - return getattr(message, key, None) - - def _normalize_content(value: Any) -> Any: - if isinstance(value, list | tuple): - normalized_parts: list[Any] = [] - for part in value: - if hasattr(part, "model_dump") and callable( - part.model_dump - ): - normalized_parts.append( - part.model_dump(exclude_none=True) - ) - elif isinstance(part, Mapping): - normalized_parts.append(dict(part)) - else: - normalized_parts.append(part) - return normalized_parts - return value - - for message in processed_messages: - if hasattr(message, "model_dump") and callable(message.model_dump): - dumped = message.model_dump(exclude_none=False) - if isinstance(dumped, dict): - normalized_messages.append(dumped) - continue - - msg: dict[str, Any] - if isinstance(message, Mapping): - msg = dict(message) - else: - msg = {} - - role = _get_value(message, "role") or msg.get("role") or "user" - msg["role"] = role - - content = _get_value(message, "content") - if content is None and "content" in msg: - content = msg["content"] - msg["content"] = _normalize_content(content) - - name = _get_value(message, "name") - if name is not None: - msg["name"] = name - - tool_calls = _get_value(message, "tool_calls") - if tool_calls is None and isinstance(message, Mapping): - tool_calls = msg.get("tool_calls") - if tool_calls: - normalized_tool_calls: list[Any] = [] - for tool_call in tool_calls: - if hasattr(tool_call, "model_dump") and callable( - tool_call.model_dump - ): - normalized_tool_calls.append( - tool_call.model_dump(exclude_none=True) - ) - elif isinstance(tool_call, Mapping): - normalized_tool_calls.append(dict(tool_call)) - else: - normalized_tool_calls.append(tool_call) - msg["tool_calls"] = normalized_tool_calls - - tool_call_id = _get_value(message, "tool_call_id") - if tool_call_id is not None: - msg["tool_call_id"] = tool_call_id - - normalized_messages.append(msg) - - payload["messages"] = normalized_messages - except (KeyError, TypeError, AttributeError): - # Fallback - leave whatever the converter produced - pass - - # The caller may supply an "effective_model" which should override - # the model value coming from the domain request. Many tests expect - # the provider payload to use the effective_model. - if effective_model: - logger.info( - f"OpenAI DEBUG: Overriding model in payload from '{payload.get('model')}' to '{effective_model}'" - ) - payload["model"] = effective_model - - # Allow request.extra_body to override or augment the final payload. - extra = getattr(request_data, "extra_body", None) - if isinstance(extra, dict): - payload.update(extra) - - return payload # type: ignore[no-any-return] - - async def _handle_non_streaming_response( - self, - url: str, - payload: dict[str, Any], - headers: dict[str, str] | None, - session_id: str, - ) -> ResponseEnvelope: - if not headers or not headers.get("Authorization"): - raise AuthenticationError(message="No auth credentials found") - - guarded_headers = ensure_loop_guard_header(headers) - - try: - response = await self.client.post( - url, json=payload, headers=guarded_headers - ) - except httpx.RequestError as e: - raise ServiceUnavailableError(message=f"Could not connect to backend ({e})") - - if int(response.status_code) >= 400: - # For backwards compatibility with existing error handlers, still use HTTPException here. 
- # This will be replaced in a future update with domain exceptions. - try: - err = response.json() - except Exception: - err = response.text - raise HTTPException(status_code=response.status_code, detail=err) - - domain_response = self.translation_service.to_domain_response( - response.json(), "openai" - ) - # Some tests use mocks that set response.headers to AsyncMock or - # other non-dict types; defensively coerce to a dict and fall back - # to an empty dict on error so tests don't raise during header - # extraction. - try: - response_headers = dict(response.headers) - except Exception: - try: - response_headers = dict(getattr(response, "headers", {}) or {}) - except Exception: - response_headers = {} - - return ResponseEnvelope( - content=domain_response.model_dump(), - status_code=response.status_code, - headers=response_headers, - usage=domain_response.usage, - ) - - async def _handle_streaming_response( - self, - url: str, - payload: dict[str, Any], - headers: dict[str, str] | None, - session_id: str, - stream_format: str, - ) -> StreamingResponseHandle: - """Return a streaming handle with iterator and cancellation callback.""" - - if not headers or not headers.get("Authorization"): - raise AuthenticationError(message="No auth credentials found") - - guarded_headers = ensure_loop_guard_header(headers) - - request = self.client.build_request( - "POST", url, json=payload, headers=guarded_headers - ) - try: - response = await self.client.send(request, stream=True) - except httpx.RequestError as exc: # Normalize network failures - raise ServiceUnavailableError( - message=f"Could not connect to backend ({exc})" - ) from exc - - status_code = ( - int(response.status_code) if hasattr(response, "status_code") else 200 - ) - if status_code >= 400: - # For backwards compatibility with existing error handlers, still use HTTPException here. - # This will be replaced in a future update with domain exceptions. 
- body: str = "" - try: - body_bytes = await response.aread() - except Exception: - fallback: str = str(getattr(response, "text", "")) - body = fallback() if callable(fallback) else fallback - else: - try: - body = body_bytes.decode("utf-8") - except Exception: - fallback_text: str = str(getattr(response, "text", "")) - body = fallback_text() if callable(fallback_text) else fallback_text - finally: - with contextlib.suppress(Exception): - await response.aclose() - - if not isinstance(body, str): - body = str(body) - logger.warning( - "Backend %s returned HTTP %s with body: %s", url, status_code, body - ) - raise HTTPException( - status_code=status_code, - detail={ - "message": body, - "type": ( - "openrouter_error" if "openrouter" in url else "openai_error" - ), - "code": status_code, - }, - ) - - loop = asyncio.get_running_loop() - response_id_future: asyncio.Future[str] = loop.create_future() - cancel_lock = asyncio.Lock() - cancel_state = {"called": False} - supports_protocol_cancel = stream_format in {"responses", "openai-responses"} - cancel_headers = dict(guarded_headers) - cancel_headers.setdefault("Content-Type", "application/json") - cancel_base_url = url.rstrip("/") - - async def cancel_stream() -> None: - async with cancel_lock: - if cancel_state["called"]: - return - cancel_state["called"] = True - - if supports_protocol_cancel: - target_id: str | None = None - if response_id_future.done(): - target_id = response_id_future.result() - else: - try: - target_id = await asyncio.wait_for(response_id_future, 0.5) - except asyncio.TimeoutError: - target_id = None - - if target_id: - await self._send_openai_responses_cancel( - base_url=cancel_base_url, - headers=cancel_headers, - response_id=target_id, - session_id=session_id, - ) - - with contextlib.suppress(Exception): - await response.aclose() - - async def gen() -> AsyncGenerator[ProcessedResponse, None]: - async def text_generator() -> AsyncGenerator[dict[Any, Any] | Any, None]: - async def iter_sse_messages() -> AsyncGenerator[str, None]: - buffer = "" - separator = "\n\n" - alt_separator = "\r\n\r\n" - try: - async for chunk_bytes in response.aiter_bytes(): - chunk_text = ( - chunk_bytes.decode("utf-8", errors="replace") - if isinstance(chunk_bytes, bytes | bytearray) - else str(chunk_bytes) - ) - buffer += chunk_text - while True: - if alt_separator in buffer: - event, buffer = buffer.split(alt_separator, 1) - separator_used = alt_separator - elif separator in buffer: - event, buffer = buffer.split(separator, 1) - separator_used = separator - else: - break - if event: - yield event + separator_used - if buffer: - yield buffer - buffer = "" - except httpx.RequestError as exc: - if buffer: - yield buffer - buffer = "" - raise ServiceUnavailableError( - message=f"Streaming connection interrupted ({exc})" - ) from exc - - try: - if stream_format in {"openai", "responses", "openai-responses"}: - async for message in iter_sse_messages(): - domain_chunk = ( - self.translation_service.to_domain_stream_chunk( - message, stream_format - ) - ) - if ( - isinstance(domain_chunk, dict) - and domain_chunk.get("error") - and logger.isEnabledFor(logging.DEBUG) - ): - try: - logger.debug( - "Streaming chunk translation returned error=%s raw=%s", - domain_chunk.get("error"), - ( - message[:500] - if isinstance(message, str) - else str(message) - ), - ) - except Exception: - logger.debug( - "Streaming chunk translation returned error but raw chunk not serializable" - ) - yield domain_chunk - else: - async for chunk in response.aiter_text(): - 
domain_chunk = ( - self.translation_service.to_domain_stream_chunk( - chunk, stream_format - ) - ) - if ( - isinstance(domain_chunk, dict) - and domain_chunk.get("error") - and logger.isEnabledFor(logging.DEBUG) - ): - try: - logger.debug( - "Streaming chunk translation returned error=%s raw=%s", - domain_chunk.get("error"), - chunk[:500], - ) - except Exception: - logger.debug( - "Streaming chunk translation returned error but raw chunk not serializable" - ) - yield domain_chunk - except httpx.RequestError as exc: - raise ServiceUnavailableError( - message=f"Streaming connection interrupted ({exc})" - ) from exc - - pending_error: Exception | None = None - try: - async for chunk in text_generator(): - if ( - supports_protocol_cancel - and isinstance(chunk, dict) - and not response_id_future.done() - ): - chunk_id = chunk.get("id") - if isinstance(chunk_id, str) and chunk_id: - response_id_future.set_result(chunk_id) - yield ProcessedResponse(content=chunk) - except ServiceUnavailableError as exc: - pending_error = exc - except httpx.HTTPError as exc: - raise ServiceUnavailableError( - message=f"Streaming connection interrupted ({exc})" - ) from exc - finally: - with contextlib.suppress(Exception): - await response.aclose() - if pending_error: - raise pending_error - - try: - response_headers = dict(response.headers) - except Exception: - response_headers = {} - - return StreamingResponseHandle( - iterator=gen(), - cancel_callback=cancel_stream, - headers=response_headers, - ) - - async def _send_openai_responses_cancel( - self, - base_url: str, - headers: Mapping[str, str], - response_id: str, - session_id: str, - ) -> None: - cancel_url = f"{base_url}/{response_id}/cancel" - try: - request = self.client.build_request("POST", cancel_url, headers=headers) - except Exception as exc: - logger.debug( - "Failed to build cancellation request - session_id=%s, url=%s, error=%s", - session_id, - cancel_url, - exc, - exc_info=True, - ) - return - - try: - cancel_response = await self.client.send(request, stream=False) - except Exception as exc: - logger.warning( - "Failed to send cancellation request - session_id=%s, url=%s, error=%s", - session_id, - cancel_url, - exc, - exc_info=True, - ) - return - - with contextlib.suppress(Exception): - await cancel_response.aclose() - - async def responses( - self, - request_data: DomainModel | InternalDTO | dict[str, Any], - processed_messages: list[Any], - effective_model: str, - identity: IAppIdentityConfig | None = None, - **kwargs: Any, - ) -> ResponseEnvelope | StreamingResponseEnvelope: - """Handle OpenAI Responses API calls. - - This method handles requests to the /v1/responses endpoint, which provides - structured output generation with JSON schema validation. 
- """ - # Perform health check if enabled - await self._ensure_healthy() - - # Convert to domain request first - # Note: The responses() method can be called directly with dicts (e.g., from tests), - # unlike chat_completions() which only goes through the frontend->backend flow - domain_request = self.translation_service.to_domain_request( - request_data, "responses" - ) - - # Prepare the payload for Responses API - payload = self.translation_service.from_domain_to_responses_request( - domain_request - ) - - # Override model if effective_model is provided - if effective_model: - payload["model"] = effective_model - - # Update messages with processed_messages if available - if processed_messages: - try: - normalized_messages: list[dict[str, Any]] = [] - for m in processed_messages: - # If the message is a pydantic model, use model_dump - if hasattr(m, "model_dump") and callable(m.model_dump): - dumped = m.model_dump(exclude_none=False) - if isinstance(dumped, dict): - normalized_messages.append(dumped) - continue - - # Fallback: build a minimal dict - msg: dict[str, Any] = {"role": getattr(m, "role", "user")} - content = getattr(m, "content", None) - msg["content"] = content - - # Add other message fields if present - name = getattr(m, "name", None) - if name: - msg["name"] = name - tool_calls = getattr(m, "tool_calls", None) - if tool_calls: - msg["tool_calls"] = tool_calls - tool_call_id = getattr(m, "tool_call_id", None) - if tool_call_id: - msg["tool_call_id"] = tool_call_id - normalized_messages.append(msg) - - payload["messages"] = normalized_messages - except (KeyError, TypeError, AttributeError): - # Fallback - leave whatever the converter produced - pass - - headers_override = kwargs.pop("headers_override", None) - resolved_headers: dict[str, str] | None = None - - if headers_override is not None: - resolved_headers = dict(headers_override) - - base_headers: dict[str, str] | None - try: - base_headers = self.get_headers(identity=identity) - except Exception: - base_headers = None - - headers: dict[str, str] | None = None - if base_headers is not None: - merged_headers = dict(base_headers) - if resolved_headers: - merged_headers.update(resolved_headers) - headers = merged_headers - else: - headers = resolved_headers - - api_base = kwargs.get("openai_url") or self.api_base_url - url = f"{api_base.rstrip('/')}/responses" - - guarded_headers = ensure_loop_guard_header(headers) - - if domain_request.stream: - # Return a domain-level streaming envelope - try: - stream_handle = await self._handle_streaming_response( - url, - payload, - guarded_headers, - domain_request.session_id or "", - "openai-responses", - ) - except AuthenticationError as e: - raise HTTPException(status_code=401, detail=str(e)) - return StreamingResponseEnvelope( - content=stream_handle.iterator, - media_type="text/event-stream", - headers={}, - cancel_callback=stream_handle.cancel_callback, - ) - else: - # Return a domain ResponseEnvelope for non-streaming - return await self._handle_responses_non_streaming_response( - url, payload, guarded_headers, domain_request.session_id or "" - ) - - async def _handle_responses_non_streaming_response( - self, - url: str, - payload: dict[str, Any], - headers: dict[str, str] | None, - session_id: str, - ) -> ResponseEnvelope: - """Handle non-streaming Responses API responses with proper format conversion.""" - if not headers or not headers.get("Authorization"): - raise AuthenticationError(message="No auth credentials found") - - guarded_headers = ensure_loop_guard_header(headers) 
- - try: - response = await self.client.post( - url, json=payload, headers=guarded_headers - ) - except httpx.RequestError as e: - raise ServiceUnavailableError(message=f"Could not connect to backend ({e})") - - if int(response.status_code) >= 400: - try: - err = response.json() - except Exception: - err = response.text - raise HTTPException(status_code=response.status_code, detail=err) - - # For Responses API, we need to handle the response differently - # The response should already be in Responses API format from OpenAI - response_data = response.json() - - # Convert to domain response first, then back to ensure consistency - # We'll treat the Responses API response as a special case of OpenAI response - domain_response = self.translation_service.to_domain_response( - response_data, "openai-responses" - ) - - # Convert back to Responses API format for the final response - responses_content = self.translation_service.from_domain_to_responses_response( - domain_response - ) - - try: - response_headers = dict(response.headers) - except Exception: - try: - response_headers = dict(getattr(response, "headers", {}) or {}) - except Exception: - response_headers = {} - - return ResponseEnvelope( - content=responses_content, - status_code=response.status_code, - headers=response_headers, - usage=domain_response.usage, - ) - - async def list_models(self, api_base_url: str | None = None) -> dict[str, Any]: - headers = self.get_headers() - base = api_base_url or self.api_base_url - logger.info(f"OpenAIConnector list_models - base URL: {base}") - response = await self.client.get(f"{base.rstrip('/')}/models", headers=headers) - response.raise_for_status() - result = response.json() - return result # type: ignore[no-any-return] # type: ignore[no-any-return] - - -backend_registry.register_backend("openai", OpenAIConnector) +from __future__ import annotations + +import asyncio +import contextlib +import inspect +import json +import logging + +logger = logging.getLogger(__name__) + +from collections.abc import ( + AsyncGenerator, + Mapping, +) +from json import JSONDecodeError +from typing import Any + +import httpx +from fastapi import HTTPException + +from src.core.common.exceptions import ( + AuthenticationError, + ServiceResolutionError, + ServiceUnavailableError, +) +from src.core.config.app_config import AppConfig +from src.core.domain.chat import CanonicalChatRequest +from src.core.domain.responses import ( + ResponseEnvelope, + StreamingResponseEnvelope, + StreamingResponseHandle, +) +from src.core.interfaces.configuration_interface import IAppIdentityConfig +from src.core.interfaces.model_bases import DomainModel, InternalDTO +from src.core.interfaces.response_processor_interface import ( + IResponseProcessor, + ProcessedResponse, +) +from src.core.security.loop_prevention import ensure_loop_guard_header +from src.core.services.backend_registry import backend_registry +from src.core.services.translation_service import TranslationService + +from .base import LLMBackend + +# Legacy ChatCompletionRequest removed from connector signatures; use domain ChatRequest + + +class OpenAIConnector(LLMBackend): + """Minimal OpenAI-compatible connector used by OpenRouterBackend in tests. + + It supports an optional `headers_override` kwarg and treats streaming + responses that expose `aiter_bytes()` as streamable even if returned by + test doubles. 
+ """ + + backend_type: str = "openai" + SUPPORTED_CUSTOM_PARAMETERS: frozenset[str] = frozenset() + + def __init__( + self, + client: httpx.AsyncClient, + config: AppConfig, + translation_service: TranslationService | None = None, + response_processor: IResponseProcessor | None = None, + ) -> None: + super().__init__(config, response_processor) + self.client = client + # Allow callers/tests to omit TranslationService; resolve through DI for consistency + self.translation_service = ( + translation_service + if translation_service is not None + else self._resolve_translation_service() + ) + self.config = config # Stored config + self.available_models: list[str] = [] + self.api_key: str | None = None + self._api_base_url: str = "https://api.openai.com/v1" + + # Health check attributes + self._health_checked: bool = False + import os + + disable_health_checks_env = os.getenv( + "DISABLE_HEALTH_CHECKS", "false" + ).lower() in ("true", "1", "yes") + + disable_health_checks_config = bool( + getattr(self.config, "disable_health_checks", False) + ) + + # Enable health checks only when neither config nor env disable them + self._health_check_enabled = not ( + disable_health_checks_env or disable_health_checks_config + ) + + @property + def api_base_url(self) -> str: + """Return the API base URL.""" + return self._api_base_url + + @api_base_url.setter + def api_base_url(self, value: str) -> None: + """Set the API base URL.""" + self._api_base_url = value + + @staticmethod + def _resolve_translation_service() -> TranslationService: + """Resolve TranslationService from the DI container.""" + + from src.core.di.services import get_or_build_service_provider + + provider = get_or_build_service_provider() + service = provider.get_service(TranslationService) + if service is None: + raise ServiceResolutionError( + "TranslationService is not registered in the service provider", + service_name="TranslationService", + ) + return service + + def get_headers(self, identity: IAppIdentityConfig | None = None) -> dict[str, str]: + """Return request headers including API key and optional request identity.""" + + headers: dict[str, str] = {} + + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + if identity is not None: + try: + identity_headers = identity.get_resolved_headers(None) + except Exception: + identity_headers = {} + else: + identity_headers = dict(identity_headers) + if identity_headers: + headers.update(identity_headers) + + return ensure_loop_guard_header(headers) + + async def initialize(self, **kwargs: Any) -> None: + self.api_key = kwargs.get("api_key") + logger.info( + "OpenAIConnector initialize called. 
api_key_provided=%s", + "yes" if self.api_key else "no", + ) + if "api_base_url" in kwargs: + self.api_base_url = kwargs["api_base_url"] + + # Proceed to fetch models only when we have credentials; failures are non-fatal + if not self.api_key: + logger.debug( + "Skipping OpenAI model listing during init; no API key configured" + ) + else: + try: + headers = self.get_headers() + response = await self.client.get( + f"{self.api_base_url}/models", headers=headers + ) + data = self._decode_json_payload(response) + if isinstance(data, dict): + self.available_models = [ + model["id"] + for model in data.get("data", []) + if isinstance(model, Mapping) and "id" in model + ] + else: + logger.debug( + "Unexpected models payload type from OpenAI: %s", + type(data).__name__, + ) + self.available_models = [] + except Exception as e: + logger.warning("Failed to fetch models: %s", e, exc_info=True) + # Log the error but don't fail initialization + + async def _perform_health_check(self) -> bool: + """Perform a health check by testing API connectivity. + + This method tests actual API connectivity by making a simple request to verify + the API key works and the service is accessible. + + Returns: + bool: True if health check passes, False otherwise + """ + try: + # Test API connectivity with a simple models endpoint request + if not self.api_key: + logger.warning("Health check failed - no API key available") + return False + + headers = self.get_headers() + if not headers.get("Authorization"): + logger.warning("Health check failed - no authorization header") + return False + + url = f"{self.api_base_url}/models" + response = await self.client.get(url, headers=headers) + + if response.status_code == 200: + logger.info("Health check passed - API connectivity verified") + self._health_checked = True + return True + else: + logger.warning( + f"Health check failed - API returned status {response.status_code}" + ) + return False + + except Exception as e: + logger.error("Health check failed - unexpected error: %s", e, exc_info=True) + return False + + async def _ensure_healthy(self) -> None: + """Ensure the backend is healthy before use. + + This method performs health checks on first use, similar to how + models are loaded lazily in the parent class. 
+ """ + if not self._health_check_enabled: + # Health check is disabled, skip + return + + if not hasattr(self, "_health_checked") or not self._health_checked: + logger.info( + f"Performing first-use health check for {self.backend_type} backend" + ) + + healthy = await self._perform_health_check() + if not healthy: + logger.warning( + "Health check did not pass; continuing with lazy verification on first request" + ) + else: + logger.info("Health check passed - backend is ready for use") + + self._health_checked = True + + def enable_health_check(self) -> None: + """Enable health check functionality for this connector instance.""" + self._health_check_enabled = True + self._health_checked = False # Reset so it will check on next use + logger.info(f"Health check enabled for {self.backend_type} backend") + + def disable_health_check(self) -> None: + """Disable health check functionality for this connector instance.""" + self._health_check_enabled = False + logger.info(f"Health check disabled for {self.backend_type} backend") + + _XSSI_PREFIXES = ( + ")]}',\n", + ")]}',", + ")]}'", + "while(1);", + "while (1);", + ) + + def _decode_json_payload(self, response: httpx.Response) -> Any: + """Safely decode JSON payloads that may include XSSI guards or trailing data.""" + try: + return response.json() + except JSONDecodeError: + text = response.text or "" + sanitized = self._strip_xssi_prefix(text) + if sanitized != text: + try: + return json.loads(sanitized) + except JSONDecodeError: + pass + candidate = self._extract_first_json_value(sanitized) + if candidate: + try: + return json.loads(candidate) + except JSONDecodeError: + logger.debug( + "Failed to decode sanitized JSON payload; candidate snippet=%s", + candidate[:200], + ) + logger.warning( + "Unable to decode JSON payload from OpenAI response (status=%s, preview=%r)", + getattr(response, "status_code", "unknown"), + (sanitized or text)[:200], + ) + return None + + def _strip_xssi_prefix(self, payload: str) -> str: + stripped = payload.lstrip() + for prefix in self._XSSI_PREFIXES: + if stripped.startswith(prefix): + return stripped[len(prefix) :] + return stripped + + def _extract_first_json_value(self, payload: str) -> str | None: + candidate = payload.strip() + if not candidate: + return None + + opening = candidate[0] + if opening not in ("{", "["): + # Attempt to locate the first JSON object within the payload + for idx, ch in enumerate(candidate): + if ch in ("{", "["): + candidate = candidate[idx:] + opening = candidate[0] + break + else: + return None + + stack = [] + in_string = False + escape = False + for idx, ch in enumerate(candidate): + if in_string: + if escape: + escape = False + elif ch == "\\": + escape = True + elif ch == '"': + in_string = False + continue + + if ch == '"': + in_string = True + continue + + if ch in ("{", "["): + stack.append("]" if ch == "[" else "}") + elif ch in ("}", "]"): + if not stack or stack.pop() != ch: + return None + if not stack: + return candidate[: idx + 1] + + return None + + async def chat_completions( + self, + request_data: DomainModel | InternalDTO | dict[str, Any], + processed_messages: list[Any], + effective_model: str, + identity: IAppIdentityConfig | None = None, + **kwargs: Any, + ) -> ResponseEnvelope | StreamingResponseEnvelope: + # Perform health check if enabled (for subclasses that support it) + await self._ensure_healthy() + + # request_data is expected to be a domain ChatRequest (or subclass like CanonicalChatRequest) + # (the frontend controller converts from 
frontend-specific format to domain format) + # Backends should ONLY convert FROM domain TO backend-specific format + # Type assertion: we know from architectural design that request_data is ChatRequest-like + from typing import cast + + from src.core.domain.chat import CanonicalChatRequest, ChatRequest + + if not isinstance(request_data, ChatRequest): + raise TypeError( + f"Expected ChatRequest or CanonicalChatRequest, got {type(request_data).__name__}. " + "Backend connectors should only receive domain-format requests." + ) + # Cast to CanonicalChatRequest for mypy compatibility with _prepare_payload signature + domain_request: CanonicalChatRequest = cast(CanonicalChatRequest, request_data) + + # Prepare the payload using a helper so subclasses and tests can + # override or patch payload construction logic easily. + payload = await self._prepare_payload( + domain_request, processed_messages, effective_model + ) + headers_override = kwargs.pop("headers_override", None) + headers: dict[str, str] | None = None + + base_headers: dict[str, str] | None + try: + base_headers = self.get_headers(identity=identity) + except Exception: + base_headers = None + + if headers_override is not None: + # Avoid mutating the caller-provided mapping while preserving any + # Authorization header we compute from the configured API key. + headers = dict(headers_override) + if base_headers: + merged_headers = dict(base_headers) + merged_headers.update(headers) + headers = merged_headers + else: + headers = base_headers + + api_base = kwargs.get("openai_url") or self.api_base_url + url = f"{api_base.rstrip('/')}/chat/completions" + + if domain_request.stream: + # Return a domain-level streaming envelope (raw bytes iterator) + try: + stream_handle = await self._handle_streaming_response( + url, + payload, + headers, + domain_request.session_id or "", + "openai", + ) + except AuthenticationError as e: + raise HTTPException(status_code=401, detail=str(e)) + return StreamingResponseEnvelope( + content=stream_handle.iterator, + media_type="text/event-stream", + headers={}, + cancel_callback=stream_handle.cancel_callback, + ) + else: + # Return a domain ResponseEnvelope for non-streaming + return await self._handle_non_streaming_response( + url, payload, headers, domain_request.session_id or "" + ) + + async def _prepare_payload( + self, + request_data: CanonicalChatRequest, + processed_messages: list[Any], + effective_model: str, + ) -> dict[str, Any]: + """ + Default payload preparation for OpenAI-compatible backends. + + Subclasses or tests may patch/override this method to customize the + final payload sent to the provider. + """ + # request_data is expected to be a CanonicalChatRequest already + # (the caller creates it via TranslationService.to_domain_request). + payload = self.translation_service.from_domain_request(request_data, "openai") + if inspect.isawaitable(payload): + payload = await payload + + # Prefer processed_messages (these are the canonical, post-processed + # messages ready to send). Convert them to plain dicts to ensure JSON + # serializability without mutating the original Pydantic models. 
+ if processed_messages: + try: + normalized_messages: list[dict[str, Any]] = [] + + def _get_value(message: Any, key: str) -> Any: + if isinstance(message, Mapping): + return message.get(key) + return getattr(message, key, None) + + def _normalize_content(value: Any) -> Any: + if isinstance(value, list | tuple): + normalized_parts: list[Any] = [] + for part in value: + if hasattr(part, "model_dump") and callable( + part.model_dump + ): + normalized_parts.append( + part.model_dump(exclude_none=True) + ) + elif isinstance(part, Mapping): + normalized_parts.append(dict(part)) + else: + normalized_parts.append(part) + return normalized_parts + return value + + for message in processed_messages: + if hasattr(message, "model_dump") and callable(message.model_dump): + dumped = message.model_dump(exclude_none=False) + if isinstance(dumped, dict): + normalized_messages.append(dumped) + continue + + msg: dict[str, Any] + if isinstance(message, Mapping): + msg = dict(message) + else: + msg = {} + + role = _get_value(message, "role") or msg.get("role") or "user" + msg["role"] = role + + content = _get_value(message, "content") + if content is None and "content" in msg: + content = msg["content"] + msg["content"] = _normalize_content(content) + + name = _get_value(message, "name") + if name is not None: + msg["name"] = name + + tool_calls = _get_value(message, "tool_calls") + if tool_calls is None and isinstance(message, Mapping): + tool_calls = msg.get("tool_calls") + if tool_calls: + normalized_tool_calls: list[Any] = [] + for tool_call in tool_calls: + if hasattr(tool_call, "model_dump") and callable( + tool_call.model_dump + ): + normalized_tool_calls.append( + tool_call.model_dump(exclude_none=True) + ) + elif isinstance(tool_call, Mapping): + normalized_tool_calls.append(dict(tool_call)) + else: + normalized_tool_calls.append(tool_call) + msg["tool_calls"] = normalized_tool_calls + + tool_call_id = _get_value(message, "tool_call_id") + if tool_call_id is not None: + msg["tool_call_id"] = tool_call_id + + normalized_messages.append(msg) + + payload["messages"] = normalized_messages + except (KeyError, TypeError, AttributeError): + # Fallback - leave whatever the converter produced + pass + + # The caller may supply an "effective_model" which should override + # the model value coming from the domain request. Many tests expect + # the provider payload to use the effective_model. + if effective_model: + logger.info( + f"OpenAI DEBUG: Overriding model in payload from '{payload.get('model')}' to '{effective_model}'" + ) + payload["model"] = effective_model + + # Allow request.extra_body to override or augment the final payload. 
+        extra = getattr(request_data, "extra_body", None)
+        if isinstance(extra, dict):
+            payload.update(extra)
+
+        self._filter_unsupported_parameters(payload)
+
+        return payload  # type: ignore[no-any-return]
+
+    def _filter_unsupported_parameters(self, payload: dict[str, Any]) -> None:
+        unsupported_parameters = []
+        for param_name in ("repetition_penalty", "min_p"):
+            if param_name in payload and param_name not in self.SUPPORTED_CUSTOM_PARAMETERS:
+                unsupported_parameters.append(param_name)
+
+        if not unsupported_parameters:
+            return
+
+        backend_name = self.backend_type or self.__class__.__name__
+        for param_name in unsupported_parameters:
+            value = payload.pop(param_name, None)
+            if value is not None and logger.isEnabledFor(logging.WARNING):
+                logger.warning(
+                    "%s backend does not support the '%s' parameter; ignoring value %r",
+                    backend_name,
+                    param_name,
+                    value,
+                )
+
+    async def _handle_non_streaming_response(
+        self,
+        url: str,
+        payload: dict[str, Any],
+        headers: dict[str, str] | None,
+        session_id: str,
+    ) -> ResponseEnvelope:
+        if not headers or not headers.get("Authorization"):
+            raise AuthenticationError(message="No auth credentials found")
+
+        guarded_headers = ensure_loop_guard_header(headers)
+
+        try:
+            response = await self.client.post(
+                url, json=payload, headers=guarded_headers
+            )
+        except httpx.RequestError as e:
+            raise ServiceUnavailableError(message=f"Could not connect to backend ({e})")
+
+        if int(response.status_code) >= 400:
+            # For backwards compatibility with existing error handlers, still use HTTPException here.
+            # This will be replaced in a future update with domain exceptions.
+            try:
+                err = response.json()
+            except Exception:
+                err = response.text
+            raise HTTPException(status_code=response.status_code, detail=err)
+
+        domain_response = self.translation_service.to_domain_response(
+            response.json(), "openai"
+        )
+        # Some tests use mocks that set response.headers to AsyncMock or
+        # other non-dict types; defensively coerce to a dict and fall back
+        # to an empty dict on error so tests don't raise during header
+        # extraction.
+        try:
+            response_headers = dict(response.headers)
+        except Exception:
+            try:
+                response_headers = dict(getattr(response, "headers", {}) or {})
+            except Exception:
+                response_headers = {}
+
+        return ResponseEnvelope(
+            content=domain_response.model_dump(),
+            status_code=response.status_code,
+            headers=response_headers,
+            usage=domain_response.usage,
+        )
+
+    async def _handle_streaming_response(
+        self,
+        url: str,
+        payload: dict[str, Any],
+        headers: dict[str, str] | None,
+        session_id: str,
+        stream_format: str,
+    ) -> StreamingResponseHandle:
+        """Return a streaming handle with iterator and cancellation callback."""
+
+        if not headers or not headers.get("Authorization"):
+            raise AuthenticationError(message="No auth credentials found")
+
+        guarded_headers = ensure_loop_guard_header(headers)
+
+        request = self.client.build_request(
+            "POST", url, json=payload, headers=guarded_headers
+        )
+        try:
+            response = await self.client.send(request, stream=True)
+        except httpx.RequestError as exc:  # Normalize network failures
+            raise ServiceUnavailableError(
+                message=f"Could not connect to backend ({exc})"
+            ) from exc
+
+        status_code = (
+            int(response.status_code) if hasattr(response, "status_code") else 200
+        )
+        if status_code >= 400:
+            # For backwards compatibility with existing error handlers, still use HTTPException here.
+            # This will be replaced in a future update with domain exceptions.
+ body: str = "" + try: + body_bytes = await response.aread() + except Exception: + fallback: str = str(getattr(response, "text", "")) + body = fallback() if callable(fallback) else fallback + else: + try: + body = body_bytes.decode("utf-8") + except Exception: + fallback_text: str = str(getattr(response, "text", "")) + body = fallback_text() if callable(fallback_text) else fallback_text + finally: + with contextlib.suppress(Exception): + await response.aclose() + + if not isinstance(body, str): + body = str(body) + logger.warning( + "Backend %s returned HTTP %s with body: %s", url, status_code, body + ) + raise HTTPException( + status_code=status_code, + detail={ + "message": body, + "type": ( + "openrouter_error" if "openrouter" in url else "openai_error" + ), + "code": status_code, + }, + ) + + loop = asyncio.get_running_loop() + response_id_future: asyncio.Future[str] = loop.create_future() + cancel_lock = asyncio.Lock() + cancel_state = {"called": False} + supports_protocol_cancel = stream_format in {"responses", "openai-responses"} + cancel_headers = dict(guarded_headers) + cancel_headers.setdefault("Content-Type", "application/json") + cancel_base_url = url.rstrip("/") + + async def cancel_stream() -> None: + async with cancel_lock: + if cancel_state["called"]: + return + cancel_state["called"] = True + + if supports_protocol_cancel: + target_id: str | None = None + if response_id_future.done(): + target_id = response_id_future.result() + else: + try: + target_id = await asyncio.wait_for(response_id_future, 0.5) + except asyncio.TimeoutError: + target_id = None + + if target_id: + await self._send_openai_responses_cancel( + base_url=cancel_base_url, + headers=cancel_headers, + response_id=target_id, + session_id=session_id, + ) + + with contextlib.suppress(Exception): + await response.aclose() + + async def gen() -> AsyncGenerator[ProcessedResponse, None]: + async def text_generator() -> AsyncGenerator[dict[Any, Any] | Any, None]: + async def iter_sse_messages() -> AsyncGenerator[str, None]: + buffer = "" + separator = "\n\n" + alt_separator = "\r\n\r\n" + try: + async for chunk_bytes in response.aiter_bytes(): + chunk_text = ( + chunk_bytes.decode("utf-8", errors="replace") + if isinstance(chunk_bytes, bytes | bytearray) + else str(chunk_bytes) + ) + buffer += chunk_text + while True: + if alt_separator in buffer: + event, buffer = buffer.split(alt_separator, 1) + separator_used = alt_separator + elif separator in buffer: + event, buffer = buffer.split(separator, 1) + separator_used = separator + else: + break + if event: + yield event + separator_used + if buffer: + yield buffer + buffer = "" + except httpx.RequestError as exc: + if buffer: + yield buffer + buffer = "" + raise ServiceUnavailableError( + message=f"Streaming connection interrupted ({exc})" + ) from exc + + try: + if stream_format in {"openai", "responses", "openai-responses"}: + async for message in iter_sse_messages(): + domain_chunk = ( + self.translation_service.to_domain_stream_chunk( + message, stream_format + ) + ) + if ( + isinstance(domain_chunk, dict) + and domain_chunk.get("error") + and logger.isEnabledFor(logging.DEBUG) + ): + try: + logger.debug( + "Streaming chunk translation returned error=%s raw=%s", + domain_chunk.get("error"), + ( + message[:500] + if isinstance(message, str) + else str(message) + ), + ) + except Exception: + logger.debug( + "Streaming chunk translation returned error but raw chunk not serializable" + ) + yield domain_chunk + else: + async for chunk in response.aiter_text(): + 
domain_chunk = ( + self.translation_service.to_domain_stream_chunk( + chunk, stream_format + ) + ) + if ( + isinstance(domain_chunk, dict) + and domain_chunk.get("error") + and logger.isEnabledFor(logging.DEBUG) + ): + try: + logger.debug( + "Streaming chunk translation returned error=%s raw=%s", + domain_chunk.get("error"), + chunk[:500], + ) + except Exception: + logger.debug( + "Streaming chunk translation returned error but raw chunk not serializable" + ) + yield domain_chunk + except httpx.RequestError as exc: + raise ServiceUnavailableError( + message=f"Streaming connection interrupted ({exc})" + ) from exc + + pending_error: Exception | None = None + try: + async for chunk in text_generator(): + if ( + supports_protocol_cancel + and isinstance(chunk, dict) + and not response_id_future.done() + ): + chunk_id = chunk.get("id") + if isinstance(chunk_id, str) and chunk_id: + response_id_future.set_result(chunk_id) + yield ProcessedResponse(content=chunk) + except ServiceUnavailableError as exc: + pending_error = exc + except httpx.HTTPError as exc: + raise ServiceUnavailableError( + message=f"Streaming connection interrupted ({exc})" + ) from exc + finally: + with contextlib.suppress(Exception): + await response.aclose() + if pending_error: + raise pending_error + + try: + response_headers = dict(response.headers) + except Exception: + response_headers = {} + + return StreamingResponseHandle( + iterator=gen(), + cancel_callback=cancel_stream, + headers=response_headers, + ) + + async def _send_openai_responses_cancel( + self, + base_url: str, + headers: Mapping[str, str], + response_id: str, + session_id: str, + ) -> None: + cancel_url = f"{base_url}/{response_id}/cancel" + try: + request = self.client.build_request("POST", cancel_url, headers=headers) + except Exception as exc: + logger.debug( + "Failed to build cancellation request - session_id=%s, url=%s, error=%s", + session_id, + cancel_url, + exc, + exc_info=True, + ) + return + + try: + cancel_response = await self.client.send(request, stream=False) + except Exception as exc: + logger.warning( + "Failed to send cancellation request - session_id=%s, url=%s, error=%s", + session_id, + cancel_url, + exc, + exc_info=True, + ) + return + + with contextlib.suppress(Exception): + await cancel_response.aclose() + + async def responses( + self, + request_data: DomainModel | InternalDTO | dict[str, Any], + processed_messages: list[Any], + effective_model: str, + identity: IAppIdentityConfig | None = None, + **kwargs: Any, + ) -> ResponseEnvelope | StreamingResponseEnvelope: + """Handle OpenAI Responses API calls. + + This method handles requests to the /v1/responses endpoint, which provides + structured output generation with JSON schema validation. 
+ """ + # Perform health check if enabled + await self._ensure_healthy() + + # Convert to domain request first + # Note: The responses() method can be called directly with dicts (e.g., from tests), + # unlike chat_completions() which only goes through the frontend->backend flow + domain_request = self.translation_service.to_domain_request( + request_data, "responses" + ) + + # Prepare the payload for Responses API + payload = self.translation_service.from_domain_to_responses_request( + domain_request + ) + + # Override model if effective_model is provided + if effective_model: + payload["model"] = effective_model + + # Update messages with processed_messages if available + if processed_messages: + try: + normalized_messages: list[dict[str, Any]] = [] + for m in processed_messages: + # If the message is a pydantic model, use model_dump + if hasattr(m, "model_dump") and callable(m.model_dump): + dumped = m.model_dump(exclude_none=False) + if isinstance(dumped, dict): + normalized_messages.append(dumped) + continue + + # Fallback: build a minimal dict + msg: dict[str, Any] = {"role": getattr(m, "role", "user")} + content = getattr(m, "content", None) + msg["content"] = content + + # Add other message fields if present + name = getattr(m, "name", None) + if name: + msg["name"] = name + tool_calls = getattr(m, "tool_calls", None) + if tool_calls: + msg["tool_calls"] = tool_calls + tool_call_id = getattr(m, "tool_call_id", None) + if tool_call_id: + msg["tool_call_id"] = tool_call_id + normalized_messages.append(msg) + + payload["messages"] = normalized_messages + except (KeyError, TypeError, AttributeError): + # Fallback - leave whatever the converter produced + pass + + headers_override = kwargs.pop("headers_override", None) + resolved_headers: dict[str, str] | None = None + + if headers_override is not None: + resolved_headers = dict(headers_override) + + base_headers: dict[str, str] | None + try: + base_headers = self.get_headers(identity=identity) + except Exception: + base_headers = None + + headers: dict[str, str] | None = None + if base_headers is not None: + merged_headers = dict(base_headers) + if resolved_headers: + merged_headers.update(resolved_headers) + headers = merged_headers + else: + headers = resolved_headers + + api_base = kwargs.get("openai_url") or self.api_base_url + url = f"{api_base.rstrip('/')}/responses" + + guarded_headers = ensure_loop_guard_header(headers) + + if domain_request.stream: + # Return a domain-level streaming envelope + try: + stream_handle = await self._handle_streaming_response( + url, + payload, + guarded_headers, + domain_request.session_id or "", + "openai-responses", + ) + except AuthenticationError as e: + raise HTTPException(status_code=401, detail=str(e)) + return StreamingResponseEnvelope( + content=stream_handle.iterator, + media_type="text/event-stream", + headers={}, + cancel_callback=stream_handle.cancel_callback, + ) + else: + # Return a domain ResponseEnvelope for non-streaming + return await self._handle_responses_non_streaming_response( + url, payload, guarded_headers, domain_request.session_id or "" + ) + + async def _handle_responses_non_streaming_response( + self, + url: str, + payload: dict[str, Any], + headers: dict[str, str] | None, + session_id: str, + ) -> ResponseEnvelope: + """Handle non-streaming Responses API responses with proper format conversion.""" + if not headers or not headers.get("Authorization"): + raise AuthenticationError(message="No auth credentials found") + + guarded_headers = ensure_loop_guard_header(headers) 
+ + try: + response = await self.client.post( + url, json=payload, headers=guarded_headers + ) + except httpx.RequestError as e: + raise ServiceUnavailableError(message=f"Could not connect to backend ({e})") + + if int(response.status_code) >= 400: + try: + err = response.json() + except Exception: + err = response.text + raise HTTPException(status_code=response.status_code, detail=err) + + # For Responses API, we need to handle the response differently + # The response should already be in Responses API format from OpenAI + response_data = response.json() + + # Convert to domain response first, then back to ensure consistency + # We'll treat the Responses API response as a special case of OpenAI response + domain_response = self.translation_service.to_domain_response( + response_data, "openai-responses" + ) + + # Convert back to Responses API format for the final response + responses_content = self.translation_service.from_domain_to_responses_response( + domain_response + ) + + try: + response_headers = dict(response.headers) + except Exception: + try: + response_headers = dict(getattr(response, "headers", {}) or {}) + except Exception: + response_headers = {} + + return ResponseEnvelope( + content=responses_content, + status_code=response.status_code, + headers=response_headers, + usage=domain_response.usage, + ) + + async def list_models(self, api_base_url: str | None = None) -> dict[str, Any]: + headers = self.get_headers() + base = api_base_url or self.api_base_url + logger.info(f"OpenAIConnector list_models - base URL: {base}") + response = await self.client.get(f"{base.rstrip('/')}/models", headers=headers) + response.raise_for_status() + result = response.json() + return result # type: ignore[no-any-return] # type: ignore[no-any-return] + + +backend_registry.register_backend("openai", OpenAIConnector) diff --git a/src/connectors/openrouter.py b/src/connectors/openrouter.py index 75a6f011..44ca7a6b 100644 --- a/src/connectors/openrouter.py +++ b/src/connectors/openrouter.py @@ -28,6 +28,9 @@ class OpenRouterBackend(OpenAIConnector): """LLMBackend implementation for OpenRouter.ai.""" backend_type: str = "openrouter" + SUPPORTED_CUSTOM_PARAMETERS: frozenset[str] = frozenset( + {"repetition_penalty", "min_p"} + ) def __init__( self, diff --git a/src/core/domain/chat.py b/src/core/domain/chat.py index de949c58..57bee67b 100644 --- a/src/core/domain/chat.py +++ b/src/core/domain/chat.py @@ -1,213 +1,215 @@ -from collections.abc import Sequence -from typing import Any, TypeVar - -from pydantic import Field, field_validator - -from src.core.domain.base import ValueObject -from src.core.interfaces.model_bases import DomainModel - -# Define a type variable for generic methods -T = TypeVar("T", bound=DomainModel) - - -# For multimodal content parts -class MessageContentPartText(DomainModel): - """Represents a text content part in a multimodal message.""" - - type: str = "text" - text: str - - -class ImageURL(DomainModel): - """Specifies the URL and optional detail for an image in a multimodal message.""" - - # Should be a data URI (e.g., "data:image/jpeg;base64,...") or public URL - url: str - detail: str | None = Field(None, examples=["auto", "low", "high"]) - - -class MessageContentPartImage(DomainModel): - """Represents an image content part in a multimodal message.""" - - type: str = "image_url" - image_url: ImageURL - - -# Extend with other multimodal types as needed (e.g., audio, video file, documents) -# For now, text and image are common starting points. 
-MessageContentPart = MessageContentPartText | MessageContentPartImage -"""Type alias for possible content parts in a multimodal message.""" - - -class FunctionCall(DomainModel): - """Represents a function call within a tool call.""" - - name: str - arguments: str - - -class ToolCall(DomainModel): - """Represents a tool call in a chat completion response.""" - - id: str - type: str = "function" - function: FunctionCall - - -class FunctionDefinition(DomainModel): - """Represents a function definition for tool calling.""" - - name: str - description: str | None = None - parameters: dict[str, Any] | None = None - - -class ToolDefinition(DomainModel): - """Represents a tool definition in a chat completion request.""" - - type: str = "function" - function: FunctionDefinition - - @field_validator("function", mode="before") - @classmethod - def ensure_function_is_dict(cls, v: Any) -> dict[str, Any] | FunctionDefinition: - # Accept either a FunctionDefinition or a ToolDefinition/FunctionDefinition instance - # and normalize to a dict for ChatRequest validation - if isinstance(v, FunctionDefinition): - return v.model_dump() - # If v is already a dict, return it as is - if isinstance(v, dict): - return v - # If v is something else, try to convert it to a dict - # This should handle cases where v is a dict-like object - try: - return dict(v) # type: ignore - except (TypeError, ValueError): - # If we can't convert to dict, raise a ValueError to let Pydantic handle the error properly - raise ValueError(f"Cannot convert {type(v)} to dict or FunctionDefinition") - - -class ChatMessage(DomainModel): - """ - A chat message in a conversation. - """ - - role: str - content: str | Sequence[MessageContentPart] | None = None - name: str | None = None - tool_calls: list[ToolCall] | None = None - tool_call_id: str | None = None - metadata: dict[str, Any] | None = None - - def to_dict(self) -> dict[str, Any]: - """Convert the message to a dictionary.""" - result: dict[str, Any] = {"role": self.role} - if self.content is not None: - result["content"] = self._serialize_content(self.content) - if self.name: - result["name"] = self.name - if self.tool_calls: - result["tool_calls"] = [tc.model_dump() for tc in self.tool_calls] - if self.tool_call_id: - result["tool_call_id"] = self.tool_call_id - return result - - @staticmethod - def _serialize_content( - content: str | Sequence[MessageContentPart] | None, - ) -> Any: - """Normalize message content so downstream callers receive plain data structures.""" - - if content is None: - return None - - if isinstance(content, str): - return content - - if isinstance(content, DomainModel): - return content.model_dump() - - if isinstance(content, Sequence): - serialized_parts: list[Any] = [] - for part in content: - if isinstance(part, DomainModel): - serialized_parts.append(part.model_dump()) - else: - serialized_parts.append(part) - return serialized_parts - - return content - - -class ChatRequest(ValueObject): - """ - A request for a chat completion. 
- """ - - model: str - messages: list[ChatMessage] - system_prompt: str | None = None # Add system_prompt field - temperature: float | None = None - top_p: float | None = None - top_k: int | None = None - n: int | None = None - stream: bool | None = None - stop: list[str] | str | None = None - max_tokens: int | None = None - presence_penalty: float | None = None - frequency_penalty: float | None = None - logit_bias: dict[str, float] | None = None - user: str | None = None - seed: int | None = None - tools: list[dict[str, Any]] | None = None - tool_choice: str | dict[str, Any] | None = None - session_id: str | None = None - agent: str | None = None # Add agent field - extra_body: dict[str, Any] | None = None - - # Reasoning parameters for o1, o3, o4-mini and other reasoning models - reasoning_effort: str | None = None - reasoning: dict[str, Any] | None = None - - # Gemini-specific reasoning parameters - thinking_budget: int | None = None - generation_config: dict[str, Any] | None = None - - @field_validator("messages") - @classmethod - def validate_messages(cls, v: list[Any]) -> list[ChatMessage]: - """Validate and convert messages.""" - if not v: - raise ValueError("At least one message is required") - return [m if isinstance(m, ChatMessage) else ChatMessage(**m) for m in v] - - @field_validator("tools", mode="before") - @classmethod - def validate_tools(cls, v: Any) -> list[dict[str, Any]] | None: - """Allow passing ToolDefinition instances or dicts for tools.""" - if v is None: - return None - result: list[dict[str, Any]] = [] - for item in v: - if isinstance(item, ToolDefinition): - result.append(item.model_dump()) - elif isinstance(item, dict): - result.append(item) - else: - # Attempt to coerce - try: - td = ToolDefinition(**item) - result.append(td.model_dump()) - except Exception as e: - from src.core.common.exceptions import ToolCallParsingError - - raise ToolCallParsingError( - message="Invalid tool definition provided", - details={"original_error": str(e), "invalid_item": str(item)}, - ) from e - return result - - +from collections.abc import Sequence +from typing import Any, TypeVar + +from pydantic import Field, field_validator + +from src.core.domain.base import ValueObject +from src.core.interfaces.model_bases import DomainModel + +# Define a type variable for generic methods +T = TypeVar("T", bound=DomainModel) + + +# For multimodal content parts +class MessageContentPartText(DomainModel): + """Represents a text content part in a multimodal message.""" + + type: str = "text" + text: str + + +class ImageURL(DomainModel): + """Specifies the URL and optional detail for an image in a multimodal message.""" + + # Should be a data URI (e.g., "data:image/jpeg;base64,...") or public URL + url: str + detail: str | None = Field(None, examples=["auto", "low", "high"]) + + +class MessageContentPartImage(DomainModel): + """Represents an image content part in a multimodal message.""" + + type: str = "image_url" + image_url: ImageURL + + +# Extend with other multimodal types as needed (e.g., audio, video file, documents) +# For now, text and image are common starting points. 
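+# Example (illustrative only): a multimodal user message combines these parts, e.g.
+#   ChatMessage(
+#       role="user",
+#       content=[
+#           MessageContentPartText(text="Describe this image"),
+#           MessageContentPartImage(image_url=ImageURL(url="data:image/png;base64,...")),
+#       ],
+#   )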
+MessageContentPart = MessageContentPartText | MessageContentPartImage +"""Type alias for possible content parts in a multimodal message.""" + + +class FunctionCall(DomainModel): + """Represents a function call within a tool call.""" + + name: str + arguments: str + + +class ToolCall(DomainModel): + """Represents a tool call in a chat completion response.""" + + id: str + type: str = "function" + function: FunctionCall + + +class FunctionDefinition(DomainModel): + """Represents a function definition for tool calling.""" + + name: str + description: str | None = None + parameters: dict[str, Any] | None = None + + +class ToolDefinition(DomainModel): + """Represents a tool definition in a chat completion request.""" + + type: str = "function" + function: FunctionDefinition + + @field_validator("function", mode="before") + @classmethod + def ensure_function_is_dict(cls, v: Any) -> dict[str, Any] | FunctionDefinition: + # Accept either a FunctionDefinition or a ToolDefinition/FunctionDefinition instance + # and normalize to a dict for ChatRequest validation + if isinstance(v, FunctionDefinition): + return v.model_dump() + # If v is already a dict, return it as is + if isinstance(v, dict): + return v + # If v is something else, try to convert it to a dict + # This should handle cases where v is a dict-like object + try: + return dict(v) # type: ignore + except (TypeError, ValueError): + # If we can't convert to dict, raise a ValueError to let Pydantic handle the error properly + raise ValueError(f"Cannot convert {type(v)} to dict or FunctionDefinition") + + +class ChatMessage(DomainModel): + """ + A chat message in a conversation. + """ + + role: str + content: str | Sequence[MessageContentPart] | None = None + name: str | None = None + tool_calls: list[ToolCall] | None = None + tool_call_id: str | None = None + metadata: dict[str, Any] | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert the message to a dictionary.""" + result: dict[str, Any] = {"role": self.role} + if self.content is not None: + result["content"] = self._serialize_content(self.content) + if self.name: + result["name"] = self.name + if self.tool_calls: + result["tool_calls"] = [tc.model_dump() for tc in self.tool_calls] + if self.tool_call_id: + result["tool_call_id"] = self.tool_call_id + return result + + @staticmethod + def _serialize_content( + content: str | Sequence[MessageContentPart] | None, + ) -> Any: + """Normalize message content so downstream callers receive plain data structures.""" + + if content is None: + return None + + if isinstance(content, str): + return content + + if isinstance(content, DomainModel): + return content.model_dump() + + if isinstance(content, Sequence): + serialized_parts: list[Any] = [] + for part in content: + if isinstance(part, DomainModel): + serialized_parts.append(part.model_dump()) + else: + serialized_parts.append(part) + return serialized_parts + + return content + + +class ChatRequest(ValueObject): + """ + A request for a chat completion. 
+ """ + + model: str + messages: list[ChatMessage] + system_prompt: str | None = None # Add system_prompt field + temperature: float | None = None + top_p: float | None = None + top_k: int | None = None + repetition_penalty: float | None = None + min_p: float | None = None + n: int | None = None + stream: bool | None = None + stop: list[str] | str | None = None + max_tokens: int | None = None + presence_penalty: float | None = None + frequency_penalty: float | None = None + logit_bias: dict[str, float] | None = None + user: str | None = None + seed: int | None = None + tools: list[dict[str, Any]] | None = None + tool_choice: str | dict[str, Any] | None = None + session_id: str | None = None + agent: str | None = None # Add agent field + extra_body: dict[str, Any] | None = None + + # Reasoning parameters for o1, o3, o4-mini and other reasoning models + reasoning_effort: str | None = None + reasoning: dict[str, Any] | None = None + + # Gemini-specific reasoning parameters + thinking_budget: int | None = None + generation_config: dict[str, Any] | None = None + + @field_validator("messages") + @classmethod + def validate_messages(cls, v: list[Any]) -> list[ChatMessage]: + """Validate and convert messages.""" + if not v: + raise ValueError("At least one message is required") + return [m if isinstance(m, ChatMessage) else ChatMessage(**m) for m in v] + + @field_validator("tools", mode="before") + @classmethod + def validate_tools(cls, v: Any) -> list[dict[str, Any]] | None: + """Allow passing ToolDefinition instances or dicts for tools.""" + if v is None: + return None + result: list[dict[str, Any]] = [] + for item in v: + if isinstance(item, ToolDefinition): + result.append(item.model_dump()) + elif isinstance(item, dict): + result.append(item) + else: + # Attempt to coerce + try: + td = ToolDefinition(**item) + result.append(td.model_dump()) + except Exception as e: + from src.core.common.exceptions import ToolCallParsingError + + raise ToolCallParsingError( + message="Invalid tool definition provided", + details={"original_error": str(e), "invalid_item": str(item)}, + ) from e + return result + + class ChatCompletionChoiceMessage(DomainModel): """Represents the message content within a chat completion choice.""" @@ -216,121 +218,121 @@ class ChatCompletionChoiceMessage(DomainModel): tool_calls: list[ToolCall] | None = None tool_call_id: str | None = None metadata: dict[str, Any] | None = None - - -class ChatCompletionChoice(DomainModel): - """Represents a single choice in a chat completion response.""" - - index: int - message: ChatCompletionChoiceMessage - finish_reason: str | None = None - - -# ChatUsage class is defined elsewhere in this file - - -class ChatResponse(ValueObject): - """ - A response from a chat completion. - """ - - id: str - created: int - model: str - choices: list[ChatCompletionChoice] - usage: dict[str, Any] | None = None - system_fingerprint: str | None = None - object: str = "chat.completion" - - -class StreamingChatResponse(ValueObject): - """ - A streaming chunk of a chat completion response. - """ - - content: str | None - model: str - finish_reason: str | None = None - tool_calls: list[dict[str, Any]] | None = None - delta: dict[str, Any] | None = None - system_fingerprint: str | None = None - done: bool | None = None - metadata: dict[str, Any] | None = None - - @classmethod - def from_legacy_chunk(cls, chunk: dict[str, Any]) -> "StreamingChatResponse": - """ - Create a StreamingChatResponse from a legacy chunk format. 
- - Args: - chunk: A legacy streaming chunk - - Returns: - A new StreamingChatResponse - """ - # Extract the response content and other fields from the chunk - content: str | None = None - if chunk.get("choices"): - choice: dict[str, Any] = chunk["choices"][0] - if "delta" in choice: - delta: dict[str, Any] = choice["delta"] - if "content" in delta: - content = delta["content"] - - # Might have tool calls in delta - tool_calls: list[dict[str, Any]] | None = delta.get("tool_calls") - - # The delta is the actual delta object - delta_obj: dict[str, Any] | None = delta - else: - # Simpler format - content = choice.get("text", "") - tool_calls = None - delta_obj = None - - # Extract finish reason if present - finish_reason: str | None = choice.get("finish_reason") - else: - # Anthropic format - if "content" in chunk: - if isinstance(chunk["content"], list): - content_parts: list[str] = [ - p["text"] for p in chunk["content"] if p.get("type") == "text" - ] - content = "".join(content_parts) - else: - content = chunk["content"] - - tool_calls = chunk.get("tool_calls") - delta_obj = None - finish_reason = chunk.get("stop_reason") - - # Extract model - model: str = chunk.get("model", "unknown") - - # Extract system fingerprint - system_fingerprint: str | None = chunk.get("system_fingerprint") - - return cls( - content=content, - model=model, - finish_reason=finish_reason, - tool_calls=tool_calls, - delta=delta_obj, - system_fingerprint=system_fingerprint, - ) - - -# ChatUsage class is defined elsewhere in this file - - -class CanonicalChatRequest(ChatRequest): - """ - A canonical chat request model that is used internally throughout the application. - """ - - -class CanonicalChatResponse(ChatResponse): - """ - A canonical chat response model that is used internally throughout the application. - """ + + +class ChatCompletionChoice(DomainModel): + """Represents a single choice in a chat completion response.""" + + index: int + message: ChatCompletionChoiceMessage + finish_reason: str | None = None + + +# ChatUsage class is defined elsewhere in this file + + +class ChatResponse(ValueObject): + """ + A response from a chat completion. + """ + + id: str + created: int + model: str + choices: list[ChatCompletionChoice] + usage: dict[str, Any] | None = None + system_fingerprint: str | None = None + object: str = "chat.completion" + + +class StreamingChatResponse(ValueObject): + """ + A streaming chunk of a chat completion response. + """ + + content: str | None + model: str + finish_reason: str | None = None + tool_calls: list[dict[str, Any]] | None = None + delta: dict[str, Any] | None = None + system_fingerprint: str | None = None + done: bool | None = None + metadata: dict[str, Any] | None = None + + @classmethod + def from_legacy_chunk(cls, chunk: dict[str, Any]) -> "StreamingChatResponse": + """ + Create a StreamingChatResponse from a legacy chunk format. 
+ + Args: + chunk: A legacy streaming chunk + + Returns: + A new StreamingChatResponse + """ + # Extract the response content and other fields from the chunk + content: str | None = None + if chunk.get("choices"): + choice: dict[str, Any] = chunk["choices"][0] + if "delta" in choice: + delta: dict[str, Any] = choice["delta"] + if "content" in delta: + content = delta["content"] + + # Might have tool calls in delta + tool_calls: list[dict[str, Any]] | None = delta.get("tool_calls") + + # The delta is the actual delta object + delta_obj: dict[str, Any] | None = delta + else: + # Simpler format + content = choice.get("text", "") + tool_calls = None + delta_obj = None + + # Extract finish reason if present + finish_reason: str | None = choice.get("finish_reason") + else: + # Anthropic format + if "content" in chunk: + if isinstance(chunk["content"], list): + content_parts: list[str] = [ + p["text"] for p in chunk["content"] if p.get("type") == "text" + ] + content = "".join(content_parts) + else: + content = chunk["content"] + + tool_calls = chunk.get("tool_calls") + delta_obj = None + finish_reason = chunk.get("stop_reason") + + # Extract model + model: str = chunk.get("model", "unknown") + + # Extract system fingerprint + system_fingerprint: str | None = chunk.get("system_fingerprint") + + return cls( + content=content, + model=model, + finish_reason=finish_reason, + tool_calls=tool_calls, + delta=delta_obj, + system_fingerprint=system_fingerprint, + ) + + +# ChatUsage class is defined elsewhere in this file + + +class CanonicalChatRequest(ChatRequest): + """ + A canonical chat request model that is used internally throughout the application. + """ + + +class CanonicalChatResponse(ChatResponse): + """ + A canonical chat response model that is used internally throughout the application. 
+ """ diff --git a/src/core/domain/gemini_translation.py b/src/core/domain/gemini_translation.py index c19ea2de..60024658 100644 --- a/src/core/domain/gemini_translation.py +++ b/src/core/domain/gemini_translation.py @@ -222,6 +222,8 @@ def gemini_request_to_canonical_request( temperature = generation_config.get("temperature") top_p = generation_config.get("topP") top_k = generation_config.get("topK") + repetition_penalty = generation_config.get("repetitionPenalty") + min_p = generation_config.get("minP") max_tokens = generation_config.get("maxOutputTokens") stop = generation_config.get("stopSequences") @@ -298,6 +300,8 @@ def gemini_request_to_canonical_request( tools=tools, # type: ignore tool_choice=tool_choice, reasoning_effort=reasoning_effort, + repetition_penalty=repetition_penalty, + min_p=min_p, ) diff --git a/src/core/domain/translation.py b/src/core/domain/translation.py index 76dd03b6..9855d7af 100644 --- a/src/core/domain/translation.py +++ b/src/core/domain/translation.py @@ -1,3598 +1,3618 @@ -from __future__ import annotations - -import json -import logging -import mimetypes -import os -from typing import Any, cast - -from src.core.app.constants.logging_constants import TRACE_LEVEL - -_MAX_SANITIZE_DEPTH = 100 - -from src.core.domain.base_translator import BaseTranslator -from src.core.domain.chat import ( - CanonicalChatRequest, - CanonicalChatResponse, - ChatCompletionChoice, - ChatCompletionChoiceMessage, - ChatMessage, - ChatResponse, - FunctionCall, - ToolCall, -) -from src.core.services.tool_text_renderer import render_tool_call - -logger = logging.getLogger(__name__) - - -class Translation(BaseTranslator): - """ - A class for translating requests and responses between different API formats. - """ - - _codex_tool_call_index_base: dict[str, int] = {} - _codex_tool_call_item_index: dict[str, dict[str, int]] = {} - - @classmethod - def _reset_tool_call_state(cls, response_id: str | None) -> None: - if not response_id: - return - cls._codex_tool_call_index_base.pop(response_id, None) - cls._codex_tool_call_item_index.pop(response_id, None) - - @classmethod - def _assign_tool_call_index( - cls, - response_id: str | None, - output_index: Any, - item_id: str | None, - ) -> int: - if not response_id: - return 0 - - if not isinstance(output_index, int): - if item_id: - return cls._codex_tool_call_item_index.get(response_id, {}).get( - item_id, 0 - ) - return 0 - - base = cls._codex_tool_call_index_base.get(response_id) - if base is None or output_index < base: - cls._codex_tool_call_index_base[response_id] = output_index - base = output_index - - index = output_index - base - if index < 0: - index = 0 - - if item_id: - cls._codex_tool_call_item_index.setdefault(response_id, {})[item_id] = index - - return index - - @staticmethod - def validate_json_against_schema( - json_data: dict[str, Any], schema: dict[str, Any] - ) -> tuple[bool, str | None]: - """ - Validate JSON data against a JSON schema. 
- - Args: - json_data: The JSON data to validate - schema: The JSON schema to validate against - - Returns: - A tuple of (is_valid, error_message) - """ - try: - import jsonschema - - jsonschema.validate(json_data, schema) - return True, None - except ImportError: - # jsonschema not available, perform basic validation - return Translation._basic_schema_validation(json_data, schema) - except Exception as e: - # Check if this is a jsonschema error, even if the import failed - if "jsonschema" in str(e) and "ValidationError" in str(e): - return False, str(e) - # Fallback for other validation errors - return False, f"Schema validation error: {e!s}" - - @staticmethod - def _basic_schema_validation( - json_data: dict[str, Any], schema: dict[str, Any] - ) -> tuple[bool, str | None]: - """ - Perform basic JSON schema validation without jsonschema library. - - This is a fallback validation that checks basic schema requirements. - """ - try: - # Check type - schema_type = schema.get("type") - if schema_type == "object" and not isinstance(json_data, dict): - return False, f"Expected object, got {type(json_data).__name__}" - elif schema_type == "array" and not isinstance(json_data, list): - return False, f"Expected array, got {type(json_data).__name__}" - elif schema_type == "string" and not isinstance(json_data, str): - return False, f"Expected string, got {type(json_data).__name__}" - elif schema_type == "number" and not isinstance(json_data, int | float): - return False, f"Expected number, got {type(json_data).__name__}" - elif schema_type == "integer" and not isinstance(json_data, int): - return False, f"Expected integer, got {type(json_data).__name__}" - elif schema_type == "boolean" and not isinstance(json_data, bool): - return False, f"Expected boolean, got {type(json_data).__name__}" - - # Check required properties for objects - if schema_type == "object" and isinstance(json_data, dict): - required = schema.get("required", []) - for prop in required: - if prop not in json_data: - return False, f"Missing required property: {prop}" - - return True, None - except Exception as e: - return False, f"Basic validation error: {e!s}" - - @staticmethod - def _detect_image_mime_type(url: str) -> str: - """Detect the MIME type for an image URL or data URI.""" - if url.startswith("data:"): - header = url.split(",", 1)[0] - header = header.split(";", 1)[0] - if ":" in header: - candidate = header.split(":", 1)[1] - if candidate: - return candidate - return "image/jpeg" - - clean_url = url.split("?", 1)[0].split("#", 1)[0] - if "." 
in clean_url: - extension = clean_url.rsplit(".", 1)[-1].lower() - if extension: - mime_type = mimetypes.types_map.get(f".{extension}") - if mime_type and mime_type.startswith("image/"): - return mime_type - if extension == "jpg": - return "image/jpeg" - return "image/jpeg" - - @staticmethod - def _process_gemini_image_part(part: Any) -> dict[str, Any] | None: - """Convert a multimodal image part to Gemini format.""" - from src.core.domain.chat import MessageContentPartImage - - if not isinstance(part, MessageContentPartImage) or not part.image_url: - return None - - url_str = str(part.image_url.url or "").strip() - if not url_str: - return None - - # Inline data URIs are allowed - if url_str.startswith("data:"): - mime_type = Translation._detect_image_mime_type(url_str) - try: - _, base64_data = url_str.split(",", 1) - except ValueError: - base64_data = "" - return { - "inline_data": { - "mime_type": mime_type, - "data": base64_data, - } - } - - # For non-inline URIs, only allow http/https schemes. Reject file/ftp and local paths. - try: - from urllib.parse import urlparse - - scheme = (urlparse(url_str).scheme or "").lower() - except Exception: - scheme = "" - - allowed_schemes = {"http", "https"} - - if scheme not in allowed_schemes: - # Also treat Windows/local file paths (no scheme or drive-letter scheme) as invalid - return None - - mime_type = Translation._detect_image_mime_type(url_str) - return { - "file_data": { - "mime_type": mime_type, - "file_uri": url_str, - } - } - - @staticmethod - def _normalize_usage_metadata( - usage: dict[str, Any], source_format: str - ) -> dict[str, Any]: - """Normalize usage metadata from different API formats to a standard structure.""" - if source_format == "gemini": - return { - "prompt_tokens": usage.get("promptTokenCount", 0), - "completion_tokens": usage.get("candidatesTokenCount", 0), - "total_tokens": usage.get("totalTokenCount", 0), - } - elif source_format == "anthropic": - return { - "prompt_tokens": usage.get("input_tokens", 0), - "completion_tokens": usage.get("output_tokens", 0), - "total_tokens": usage.get("input_tokens", 0) - + usage.get("output_tokens", 0), - } - elif source_format in {"openai", "openai-responses"}: - prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens", 0)) - completion_tokens = usage.get( - "completion_tokens", usage.get("output_tokens", 0) - ) - total_tokens = usage.get("total_tokens") - if total_tokens is None: - total_tokens = prompt_tokens + completion_tokens - - return { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - } - else: - # Default normalization - return { - "prompt_tokens": usage.get("prompt_tokens", 0), - "completion_tokens": usage.get("completion_tokens", 0), - "total_tokens": usage.get("total_tokens", 0), - } - - @staticmethod - def _normalize_responses_input_to_messages( - input_payload: Any, - ) -> list[dict[str, Any]]: - """Coerce OpenAI Responses API input payloads into chat messages.""" - - def _normalize_message_entry(entry: Any) -> dict[str, Any] | None: - if entry is None: - return None - - if isinstance(entry, str): - return {"role": "user", "content": entry} - - if isinstance(entry, dict): - raw_role = entry.get("role") - if raw_role is None: - raw_role = "user" - role = str(raw_role) - message: dict[str, Any] = {"role": role} - - content = Translation._normalize_responses_content(entry.get("content")) - if content is not None: - if isinstance(content, list): - message["content_parts"] = content - 
message["content"] = content - else: - parts = [{"type": "text", "text": content}] - message["content_parts"] = parts - message["content"] = parts - - if "name" in entry and entry.get("name") is not None: - message["name"] = entry["name"] - - if "tool_calls" in entry and entry.get("tool_calls") is not None: - message["tool_calls"] = entry["tool_calls"] - - if "tool_call_id" in entry and entry.get("tool_call_id") is not None: - message["tool_call_id"] = entry["tool_call_id"] - - return message - - # Fallback: convert to string representation - return {"role": "user", "content": str(entry)} - - if input_payload is None: - return [] - - if isinstance(input_payload, str | bytes): - text_value = ( - input_payload.decode("utf-8", "ignore") - if isinstance(input_payload, bytes | bytearray) - else input_payload - ) - return [{"role": "user", "content": text_value}] - - if isinstance(input_payload, dict): - normalized = _normalize_message_entry(input_payload) - return [normalized] if normalized else [] - - if isinstance(input_payload, list | tuple): - messages: list[dict[str, Any]] = [] - for item in input_payload: - normalized = _normalize_message_entry(item) - if normalized: - messages.append(normalized) - return messages - - # Unknown type - coerce to a single user message - return [{"role": "user", "content": str(input_payload)}] - - @staticmethod - def _normalize_responses_content(content: Any) -> Any: - """Normalize Responses API content blocks into chat-compatible structures.""" - - def _coerce_text_value(value: Any) -> str: - if isinstance(value, str): - return value - if isinstance(value, bytes | bytearray): - return value.decode("utf-8", "ignore") - if isinstance(value, list): - segments: list[str] = [] - for segment in value: - if isinstance(segment, dict): - segments.append(_coerce_text_value(segment.get("text"))) - else: - segments.append(str(segment)) - return "".join(segments) - if isinstance(value, dict) and "text" in value: - return _coerce_text_value(value.get("text")) - return str(value) if value is not None else "" - - if content is None: - return None - - if isinstance(content, str | bytes | bytearray): - return _coerce_text_value(content) - - if isinstance(content, dict): - normalized_parts = Translation._normalize_responses_content_part(content) - if not normalized_parts: - return None - if len(normalized_parts) == 1 and normalized_parts[0].get("type") == "text": - return normalized_parts[0]["text"] - return normalized_parts - - if isinstance(content, list | tuple): - collected_parts: list[dict[str, Any]] = [] - for part in content: - if isinstance(part, dict): - collected_parts.extend( - Translation._normalize_responses_content_part(part) - ) - elif isinstance(part, str | bytes | bytearray): - collected_parts.append( - {"type": "text", "text": _coerce_text_value(part)} - ) - if not collected_parts: - return None - if len(collected_parts) == 1 and collected_parts[0].get("type") == "text": - return collected_parts[0]["text"] - return collected_parts - - return str(content) - - @staticmethod - def _normalize_responses_content_part(part: dict[str, Any]) -> list[dict[str, Any]]: - """Normalize a single Responses API content part.""" - - part_type = str(part.get("type") or "").lower() - normalized_parts: list[dict[str, Any]] = [] - - if part_type in {"text", "input_text", "output_text"}: - text_value = part.get("text") - if text_value is None: - text_value = part.get("value") - normalized_parts.append( - {"type": "text", "text": Translation._safe_string(text_value)} - ) - elif 
"image" in part_type: - image_payload = ( - part.get("image_url") - or part.get("imageUrl") - or part.get("image") - or part.get("image_data") - ) - if isinstance(image_payload, str): - image_payload = {"url": image_payload} - if isinstance(image_payload, dict) and image_payload.get("url"): - normalized_parts.append( - {"type": "image_url", "image_url": image_payload} - ) - elif part_type == "tool_call": - # Tool call parts are handled elsewhere in the pipeline; ignore here. - return [] - else: - # Preserve already-normalized structures (e.g., function calls) as-is - normalized_parts.append(part) - - return [p for p in normalized_parts if p] - - @staticmethod - def _safe_string(value: Any) -> str: - if value is None: - return "" - if isinstance(value, str): - return value - if isinstance(value, bytes | bytearray): - return value.decode("utf-8", "ignore") - return str(value) - - @staticmethod - def _map_gemini_finish_reason(finish_reason: str | None) -> str | None: - """Map Gemini finish reasons to canonical values.""" - if finish_reason is None: - return None - - normalized = str(finish_reason).lower() - mapping = { - "stop": "stop", - "max_tokens": "length", - "safety": "content_filter", - "tool_calls": "tool_calls", - } - return mapping.get(normalized, "stop") - - @staticmethod - def _normalize_stop_sequences(stop: Any) -> list[str] | None: - """Normalize stop sequences to a consistent format.""" - if stop is None: - return None - - if isinstance(stop, str): - return [stop] - - if isinstance(stop, list): - # Ensure all elements are strings - return [str(s) for s in stop] - - # Convert other types to string - return [str(stop)] - - @staticmethod - def _normalize_tool_arguments(args: Any) -> str: - """Normalize tool call arguments to a JSON string.""" - if args is None: - return "{}" - - if isinstance(args, str): - stripped = args.strip() - if not stripped: - return "{}" - - # First, try to load it as-is. It might be a valid JSON string. - try: - json.loads(stripped) - return stripped - except json.JSONDecodeError: - # If it fails, it might be a string using single quotes. - # We will try to fix it, but only if it doesn't create an invalid JSON. - pass - - try: - # Attempt to replace single quotes with double quotes for JSON compatibility. - # This is a common issue with LLM-generated JSON in string format. - # However, we must be careful not to corrupt strings that contain single quotes. - fixed_string = stripped.replace("'", '"') - json.loads(fixed_string) - return fixed_string - except (json.JSONDecodeError, TypeError): - # If replacement fails, it's likely not a simple quote issue. - # This can happen if the string contains legitimate single quotes. - # Return empty object instead of _raw format to maintain tool calling contract. 
- return "{}" - - if isinstance(args, dict): - try: - return json.dumps(args) - except TypeError: - # Handle dicts with non-serializable values - sanitized_dict = Translation._sanitize_dict_for_json(args) - return json.dumps(sanitized_dict) - - if isinstance(args, list | tuple): - try: - # PERFORMANCE OPTIMIZATION: Avoid unnecessary list copying - # Use args directly if it's already a list, only convert tuples - return json.dumps(args if isinstance(args, list) else list(args)) - except TypeError: - # Handle lists with non-serializable items - # PERFORMANCE OPTIMIZATION: Avoid unnecessary list copying - sanitized_list = Translation._sanitize_list_for_json( - args if isinstance(args, list) else list(args) - ) - return json.dumps(sanitized_list) - - # For primitive types that should be JSON serializable - if isinstance(args, int | float | bool): - return json.dumps(args) - - # For non-serializable objects, return empty object instead of _raw format - # This maintains the tool calling contract while preventing failures - return "{}" - - @staticmethod - def _is_json_serializable( - value: Any, - *, - max_depth: int, - _depth: int = 0, - _seen: set[int] | None = None, - ) -> bool: - """Best-effort check to determine if a value can be JSON-serialized.""" - - if _depth > max_depth: - return False - - if value is None or isinstance(value, str | int | float | bool): - return True - - if isinstance(value, list | tuple): - if _seen is None: - _seen = set() - obj_id = id(value) - if obj_id in _seen: - return False - _seen.add(obj_id) - try: - return all( - Translation._is_json_serializable( - item, - max_depth=max_depth, - _depth=_depth + 1, - _seen=_seen, - ) - for item in value - ) - finally: - _seen.remove(obj_id) - - if isinstance(value, dict): - if _seen is None: - _seen = set() - obj_id = id(value) - if obj_id in _seen: - return False - _seen.add(obj_id) - try: - for key, item in value.items(): - if key is not None and not isinstance( - key, str | int | float | bool - ): - return False - if not Translation._is_json_serializable( - item, - max_depth=max_depth, - _depth=_depth + 1, - _seen=_seen, - ): - return False - finally: - _seen.remove(obj_id) - return True - - return False - - @staticmethod - def _sanitize_dict_for_json( - data: dict[str, Any], - *, - max_depth: int = _MAX_SANITIZE_DEPTH, - _depth: int = 0, - _seen: set[int] | None = None, - ) -> dict[str, Any]: - """Sanitize a dictionary by removing or converting non-JSON-serializable values.""" - - if _depth > max_depth: - return {} - - if _seen is None: - _seen = set() - - obj_id = id(data) - if obj_id in _seen: - return {} - - _seen.add(obj_id) - try: - sanitized: dict[str, Any] = {} - sanitized_value: Any = None - for key, value in data.items(): - if key is not None and not isinstance(key, str | int | float | bool): - continue - - if Translation._is_json_serializable( - value, - max_depth=max_depth, - _depth=_depth + 1, - _seen=_seen, - ): - sanitized[key] = value - continue - - if isinstance(value, dict): - sanitized_value = Translation._sanitize_dict_for_json( - value, - max_depth=max_depth, - _depth=_depth + 1, - _seen=_seen, - ) - elif isinstance(value, list | tuple): - sanitized_value = Translation._sanitize_list_for_json( - value if isinstance(value, list) else list(value), - max_depth=max_depth, - _depth=_depth + 1, - _seen=_seen, - ) - elif isinstance(value, str | int | float | bool) or value is None: - sanitized_value = value - else: - continue - - sanitized[key] = sanitized_value - - return sanitized - finally: - 
_seen.remove(obj_id) - - @staticmethod - def _sanitize_list_for_json( - data: list[Any], - *, - max_depth: int = _MAX_SANITIZE_DEPTH, - _depth: int = 0, - _seen: set[int] | None = None, - ) -> list[Any]: - """Sanitize a list by removing or converting non-JSON-serializable items.""" - - if _depth > max_depth: - return [] - - if _seen is None: - _seen = set() - - obj_id = id(data) - if obj_id in _seen: - return [] - - _seen.add(obj_id) - try: - sanitized: list[Any] = [] - for item in data: - if Translation._is_json_serializable( - item, - max_depth=max_depth, - _depth=_depth + 1, - _seen=_seen, - ): - sanitized.append(item) - continue - - if isinstance(item, dict): - sanitized.append( - Translation._sanitize_dict_for_json( - item, - max_depth=max_depth, - _depth=_depth + 1, - _seen=_seen, - ) - ) - elif isinstance(item, list | tuple): - sanitized.append( - Translation._sanitize_list_for_json( - item if isinstance(item, list) else list(item), - max_depth=max_depth, - _depth=_depth + 1, - _seen=_seen, - ) - ) - elif isinstance(item, str | int | float | bool) or item is None: - sanitized.append(item) - else: - continue - - return sanitized - finally: - _seen.remove(obj_id) - - @staticmethod - def _process_gemini_function_call(function_call: dict[str, Any]) -> ToolCall: - """Process a Gemini function call part into a ToolCall.""" - import uuid - - name = function_call.get("name", "") - raw_args = function_call.get("args", function_call.get("arguments")) - normalized_args = Translation._normalize_tool_arguments(raw_args) - - return ToolCall( - id=f"call_{uuid.uuid4().hex[:12]}", - type="function", - function=FunctionCall(name=name, arguments=normalized_args), - ) - - @staticmethod - def gemini_to_domain_request(request: Any) -> CanonicalChatRequest: - """ - Translate a Gemini request to a CanonicalChatRequest. - """ - from src.core.domain.gemini_translation import ( - gemini_request_to_canonical_request, - ) - - return gemini_request_to_canonical_request(request) - - @staticmethod - def anthropic_to_domain_request(request: Any) -> CanonicalChatRequest: - """ - Translate an Anthropic request to a CanonicalChatRequest. 
- """ - # Use the helper method to safely access request parameters - system_prompt = Translation._get_request_param(request, "system") - raw_messages = Translation._get_request_param(request, "messages", []) - normalized_messages: list[Any] = [] - - if system_prompt: - normalized_messages.append({"role": "system", "content": system_prompt}) - - if raw_messages: - for message in raw_messages: - normalized_messages.append(message) - - stop_param = Translation._get_request_param(request, "stop") - stop_sequences = Translation._get_request_param(request, "stop_sequences") - normalized_stop = stop_param - if ( - normalized_stop is None or normalized_stop == [] or normalized_stop == "" - ) and stop_sequences not in (None, [], ""): - normalized_stop = stop_sequences - - return CanonicalChatRequest( - model=Translation._get_request_param(request, "model"), - messages=normalized_messages, - temperature=Translation._get_request_param(request, "temperature"), - top_p=Translation._get_request_param(request, "top_p"), - top_k=Translation._get_request_param(request, "top_k"), - n=Translation._get_request_param(request, "n"), - stream=Translation._get_request_param(request, "stream"), - stop=normalized_stop, - max_tokens=Translation._get_request_param(request, "max_tokens"), - presence_penalty=Translation._get_request_param( - request, "presence_penalty" - ), - frequency_penalty=Translation._get_request_param( - request, "frequency_penalty" - ), - logit_bias=Translation._get_request_param(request, "logit_bias"), - user=Translation._get_request_param(request, "user"), - reasoning_effort=Translation._get_request_param( - request, "reasoning_effort" - ), - seed=Translation._get_request_param(request, "seed"), - tools=Translation._get_request_param(request, "tools"), - tool_choice=Translation._get_request_param(request, "tool_choice"), - extra_body=Translation._get_request_param(request, "extra_body"), - ) - - @staticmethod - def anthropic_to_domain_response(response: Any) -> CanonicalChatResponse: - """ - Translate an Anthropic response to a CanonicalChatResponse. - """ - import time - - if not isinstance(response, dict): - # Handle non-dict responses - return CanonicalChatResponse( - id=f"chatcmpl-anthropic-{int(time.time())}", - object="chat.completion", - created=int(time.time()), - model="unknown", - choices=[ - ChatCompletionChoice( - index=0, - message=ChatCompletionChoiceMessage( - role="assistant", content=str(response) - ), - finish_reason="stop", - ) - ], - usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - ) - - # Extract choices - choices = [] - if "content" in response: - for idx, item in enumerate(response["content"]): - if item.get("type") == "text": - choice = ChatCompletionChoice( - index=idx, - message=ChatCompletionChoiceMessage( - role="assistant", content=item.get("text", "") - ), - finish_reason=response.get("stop_reason", "stop"), - ) - choices.append(choice) - - # Extract usage - usage = response.get("usage", {}) - normalized_usage = Translation._normalize_usage_metadata(usage, "anthropic") - - return CanonicalChatResponse( - id=response.get("id", f"chatcmpl-anthropic-{int(time.time())}"), - object="chat.completion", - created=int(time.time()), - model=response.get("model", "unknown"), - choices=choices, - usage=normalized_usage, - ) - - @staticmethod - def gemini_to_domain_response(response: Any) -> CanonicalChatResponse: - """ - Translate a Gemini response to a CanonicalChatResponse. 
- """ - import time - import uuid - - # Generate a unique ID for the response - response_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" - created = int(time.time()) - model = "gemini-pro" # Default model if not specified - - # Extract choices from candidates - choices = [] - if isinstance(response, dict) and "candidates" in response: - for idx, candidate in enumerate(response["candidates"]): - content = "" - tool_calls = None - - # Extract content from parts - if "content" in candidate and "parts" in candidate["content"]: - parts = candidate["content"]["parts"] - - # Extract text content - text_parts = [] - for part in parts: - if "text" in part: - text_parts.append(part["text"]) - elif "functionCall" in part: - # Handle function calls (tool calls) - if tool_calls is None: - tool_calls = [] - - function_call = part["functionCall"] - tool_calls.append( - Translation._process_gemini_function_call(function_call) - ) - - content = "".join(text_parts) - - # Map finish reason - finish_reason = Translation._map_gemini_finish_reason( - candidate.get("finishReason") - ) - - # Create choice - choice = ChatCompletionChoice( - index=idx, - message=ChatCompletionChoiceMessage( - role="assistant", - content=content, - tool_calls=tool_calls, - ), - finish_reason=finish_reason, - ) - choices.append(choice) - - # Extract usage metadata - usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} - if isinstance(response, dict) and "usageMetadata" in response: - usage_metadata = response["usageMetadata"] - usage = Translation._normalize_usage_metadata(usage_metadata, "gemini") - - # If no choices were extracted, create a default one - if not choices: - choices = [ - ChatCompletionChoice( - index=0, - message=ChatCompletionChoiceMessage(role="assistant", content=""), - finish_reason="stop", - ) - ] - - return CanonicalChatResponse( - id=response_id, - object="chat.completion", - created=created, - model=model, - choices=choices, - usage=usage, - ) - - @staticmethod - def gemini_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: - """ - Translate a Gemini streaming chunk to a canonical dictionary format. - - Args: - chunk: The Gemini streaming chunk. - - Returns: - A dictionary representing the canonical chunk format. 
- """ - import time - import uuid - - if not isinstance(chunk, dict): - return {"error": "Invalid chunk format: expected a dictionary"} - - response_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" - created = int(time.time()) - model = "gemini-pro" # Default model - - content_pieces: list[str] = [] - tool_calls: list[dict[str, Any]] = [] - finish_reason = None - - if "candidates" in chunk: - for candidate in chunk["candidates"]: - if "content" in candidate and "parts" in candidate["content"]: - for part in candidate["content"]["parts"]: - if "text" in part: - content_pieces.append(part["text"]) - elif "functionCall" in part: - try: - tool_calls.append( - Translation._process_gemini_function_call( - part["functionCall"] - ).model_dump() - ) - except Exception: - continue - if "finishReason" in candidate: - finish_reason = Translation._map_gemini_finish_reason( - candidate["finishReason"] - ) - - delta: dict[str, Any] = {"role": "assistant"} - if content_pieces: - delta["content"] = "".join(content_pieces) - if tool_calls: - delta["tool_calls"] = tool_calls - - return { - "id": response_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "delta": delta, - "finish_reason": finish_reason, - } - ], - } - - @staticmethod - def openai_to_domain_request(request: Any) -> CanonicalChatRequest: - """ - Translate an OpenAI request to a CanonicalChatRequest. - """ - if isinstance(request, dict): - model = request.get("model") - messages = request.get("messages", []) - top_k = request.get("top_k") - top_p = request.get("top_p") - temperature = request.get("temperature") - max_tokens = request.get("max_tokens") - stop = request.get("stop") - stream = request.get("stream", False) - tools = request.get("tools") - tool_choice = request.get("tool_choice") - seed = request.get("seed") - reasoning_effort = request.get("reasoning_effort") - reasoning_payload = request.get("reasoning") - else: - model = getattr(request, "model", None) - messages = getattr(request, "messages", []) - top_k = getattr(request, "top_k", None) - top_p = getattr(request, "top_p", None) - temperature = getattr(request, "temperature", None) - max_tokens = getattr(request, "max_tokens", None) - stop = getattr(request, "stop", None) - stream = getattr(request, "stream", False) - tools = getattr(request, "tools", None) - tool_choice = getattr(request, "tool_choice", None) - seed = getattr(request, "seed", None) - reasoning_effort = getattr(request, "reasoning_effort", None) - reasoning_payload = getattr(request, "reasoning", None) - - if reasoning_effort in ("", None) and isinstance(reasoning_payload, dict): - raw_effort = reasoning_payload.get("effort") - if isinstance(raw_effort, str) and raw_effort.strip(): - reasoning_effort = raw_effort - - normalized_reasoning: dict[str, Any] | None = None - if reasoning_payload: - if isinstance(reasoning_payload, dict): - normalized_reasoning = dict(reasoning_payload) - elif hasattr(reasoning_payload, "model_dump"): - normalized_reasoning = reasoning_payload.model_dump() # type: ignore[attr-defined] - - if not model: - raise ValueError("Model not found in request") - - # Convert messages to ChatMessage objects if they are dicts - chat_messages = [] - for msg in messages: - if isinstance(msg, dict): - chat_messages.append(ChatMessage(**msg)) - else: - chat_messages.append(msg) - - return CanonicalChatRequest( - model=model, - messages=chat_messages, - top_k=top_k, - top_p=top_p, - temperature=temperature, - max_tokens=max_tokens, - stop=stop, - 
stream=stream, - tools=tools, - tool_choice=tool_choice, - seed=seed, - reasoning_effort=reasoning_effort, - reasoning=normalized_reasoning, - ) - - @staticmethod - def openai_to_domain_response(response: Any) -> CanonicalChatResponse: - """ - Translate an OpenAI response to a CanonicalChatResponse. - """ - import time - - if not isinstance(response, dict): - return CanonicalChatResponse( - id=f"chatcmpl-openai-{int(time.time())}", - object="chat.completion", - created=int(time.time()), - model="unknown", - choices=[ - ChatCompletionChoice( - index=0, - message=ChatCompletionChoiceMessage( - role="assistant", content=str(response) - ), - finish_reason="stop", - ) - ], - usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - ) - - choices: list[ChatCompletionChoice] = [] - for idx, ch in enumerate(response.get("choices", [])): - msg = ch.get("message", {}) - role = msg.get("role", "assistant") - content = msg.get("content") - - # Preserve tool_calls if present - tool_calls = None - raw_tool_calls = msg.get("tool_calls") - if isinstance(raw_tool_calls, list): - # Validate each tool call in the list before including it - validated_tool_calls = [] - for tc in raw_tool_calls: - # Convert dict to ToolCall if necessary - if isinstance(tc, dict): - # Create a ToolCall object from the dict - # Assuming the dict has the necessary structure for ToolCall - # We'll need to import ToolCall if not already available - # For now, we'll use a simple approach - try: - # Create ToolCall from dict, assuming proper structure - tool_call_obj = ToolCall(**tc) - validated_tool_calls.append(tool_call_obj) - except (TypeError, ValueError): - # If conversion fails, skip this tool call - pass - elif isinstance(tc, ToolCall): - validated_tool_calls.append(tc) - else: - # Log or handle invalid tool call - # For now, we'll skip invalid ones - pass - tool_calls = validated_tool_calls if validated_tool_calls else None - - message_obj = ChatCompletionChoiceMessage( - role=role, content=content, tool_calls=tool_calls - ) - - choices.append( - ChatCompletionChoice( - index=idx, - message=message_obj, - finish_reason=ch.get("finish_reason"), - ) - ) - - usage = response.get("usage") or {} - normalized_usage = Translation._normalize_usage_metadata(usage, "openai") - - return CanonicalChatResponse( - id=response.get("id", "chatcmpl-openai-unk"), - object=response.get("object", "chat.completion"), - created=response.get("created", int(__import__("time").time())), - model=response.get("model", "unknown"), - choices=choices, - usage=normalized_usage, - ) - - @staticmethod - def responses_to_domain_response(response: Any) -> CanonicalChatResponse: - """Translate an OpenAI Responses API response to a canonical response.""" - import time - - if not isinstance(response, dict): - return Translation.openai_to_domain_response(response) - - # If the backend already returned OpenAI-style choices, reuse that logic. 
- if response.get("choices") and not response.get("output"): - return Translation.openai_to_domain_response(response) - - output_items = response.get("output") or [] - choices: list[ChatCompletionChoice] = [] - - for idx, item in enumerate(output_items): - if not isinstance(item, dict): - continue - - role = item.get("role", "assistant") - content_parts = item.get("content") - if not isinstance(content_parts, list): - content_parts = [] - - text_segments: list[str] = [] - tool_calls: list[ToolCall] = [] - - for part in content_parts: - if not isinstance(part, dict): - continue - - part_type = part.get("type") - if part_type in {"output_text", "text", "input_text"}: - text_value = part.get("text") or part.get("value") or "" - if text_value: - text_segments.append(str(text_value)) - elif part_type == "tool_call": - function_payload = ( - part.get("function") or part.get("function_call") or {} - ) - normalized_args = Translation._normalize_tool_arguments( - function_payload.get("arguments") - or function_payload.get("args") - or function_payload.get("arguments_json") - ) - tool_calls.append( - ToolCall( - id=part.get("id") or f"tool_call_{idx}_{len(tool_calls)}", - function=FunctionCall( - name=function_payload.get("name", ""), - arguments=normalized_args, - ), - ) - ) - - content_text = "\n".join( - segment for segment in text_segments if segment - ).strip() - - finish_reason = item.get("finish_reason") or item.get("status") - if finish_reason == "completed": - finish_reason = "stop" - elif finish_reason == "incomplete": - finish_reason = "length" - elif finish_reason in {"in_progress", "generating"}: - finish_reason = None - elif finish_reason is None and (content_text or tool_calls): - finish_reason = "stop" - - message = ChatCompletionChoiceMessage( - role=role, - content=content_text or None, - tool_calls=tool_calls or None, - ) - - choices.append( - ChatCompletionChoice( - index=idx, - message=message, - finish_reason=finish_reason, - ) - ) - - if not choices: - # Fallback to output_text aggregation used by the Responses API when - # the structured output array is empty. This happens when the - # backend only returns plain text without additional metadata. 
- output_text = response.get("output_text") - fallback_text_segments: list[str] = [] - if isinstance(output_text, list): - fallback_text_segments = [ - str(segment) for segment in output_text if segment - ] - elif isinstance(output_text, str) and output_text: - fallback_text_segments = [output_text] - - if fallback_text_segments: - aggregated_text = "".join(fallback_text_segments) - status = response.get("status") - fallback_finish_reason: str | None - if status == "completed": - fallback_finish_reason = "stop" - elif status == "incomplete": - fallback_finish_reason = "length" - elif status in {"in_progress", "generating"}: - fallback_finish_reason = None - else: - fallback_finish_reason = "stop" if aggregated_text else None - - message = ChatCompletionChoiceMessage( - role="assistant", - content=aggregated_text, - tool_calls=None, - ) - - choices.append( - ChatCompletionChoice( - index=0, - message=message, - finish_reason=fallback_finish_reason, - ) - ) - - if not choices: - # Fallback to OpenAI conversion to avoid returning an empty response - return Translation.openai_to_domain_response(response) - - usage = response.get("usage") or {} - normalized_usage = Translation._normalize_usage_metadata( - usage, "openai-responses" - ) - - return CanonicalChatResponse( - id=response.get("id", f"resp-{int(time.time())}"), - object=response.get("object", "response"), - created=response.get("created", int(time.time())), - model=response.get("model", "unknown"), - choices=choices, - usage=normalized_usage, - system_fingerprint=response.get("system_fingerprint"), - ) - - @staticmethod - def openai_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: - """ - Translate an OpenAI streaming chunk to a canonical dictionary format. - - Args: - chunk: The OpenAI streaming chunk. - - Returns: - A dictionary representing the canonical chunk format. - """ - import json - import time - import uuid - - if isinstance(chunk, bytes | bytearray): - try: - chunk = chunk.decode("utf-8") - except Exception: - return {"error": "Invalid chunk format: unable to decode bytes"} - - if isinstance(chunk, str): - stripped_chunk = chunk.strip() - - if not stripped_chunk: - return { - "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": "unknown", - "choices": [ - {"index": 0, "delta": {}, "finish_reason": None}, - ], - } - - if stripped_chunk.startswith(":"): - # Comment/heartbeat lines (e.g., ": ping") should be ignored by emitting - # an empty delta so downstream processors keep the stream alive. 
- return { - "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": "unknown", - "choices": [ - {"index": 0, "delta": {}, "finish_reason": None}, - ], - } - - if stripped_chunk.startswith("data:"): - stripped_chunk = stripped_chunk[5:].strip() - - if stripped_chunk == "[DONE]": - return { - "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": "unknown", - "choices": [ - {"index": 0, "delta": {}, "finish_reason": "stop"}, - ], - } - - try: - chunk = json.loads(stripped_chunk) - except json.JSONDecodeError as exc: - logger.warning( - "Responses stream chunk JSON decode failed: %s", - stripped_chunk[:300], - ) - return { - "error": "Invalid chunk format: expected JSON after 'data:' prefix", - "details": {"message": str(exc)}, - } - - if not isinstance(chunk, dict): - return {"error": "Invalid chunk format: expected a dictionary"} - - # Basic validation for essential keys - if "id" not in chunk or "choices" not in chunk: - if logger.isEnabledFor(TRACE_LEVEL): - try: - logger.log( - TRACE_LEVEL, - "OpenAI stream chunk missing id/choices: %s", - json.dumps(chunk)[:500], - ) - except Exception: - logger.log( - TRACE_LEVEL, - "OpenAI stream chunk missing id/choices (non-serializable)", - ) - return {"error": "Invalid chunk: missing 'id' or 'choices'"} - - # For simplicity, we'll return the chunk as a dictionary. - # In a more complex scenario, you might map this to a Pydantic model. - return dict(chunk) - - @staticmethod - def responses_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: - """Translate an OpenAI Responses streaming chunk to canonical format.""" - import json - import time - import uuid - - def _heartbeat_chunk(finish_reason: str | None = None) -> dict[str, Any]: - """Return a minimal chunk used for comments/heartbeats.""" - return { - "id": f"resp-{uuid.uuid4().hex[:16]}", - "object": "response.chunk", - "created": int(time.time()), - "model": "unknown", - "choices": [ - {"index": 0, "delta": {}, "finish_reason": finish_reason}, - ], - } - - def _extract_text(value: Any) -> str: - if isinstance(value, str): - return value - if isinstance(value, dict): - if "text" in value: - return _extract_text(value["text"]) - if "content" in value: - return _extract_text(value["content"]) - if "value" in value: - return _extract_text(value["value"]) - if isinstance(value, list): - parts = [_extract_text(v) for v in value] - return "".join(part for part in parts if part) - if value is None: - return "" - return str(value) - - if isinstance(chunk, bytes | bytearray): - try: - chunk = chunk.decode("utf-8") - except UnicodeDecodeError: - return { - "error": "Invalid chunk format: unable to decode bytes", - } - - event_type_from_sse: str | None = None - if isinstance(chunk, str): - stripped_chunk = chunk.strip() - - if not stripped_chunk: - return {"error": "Invalid chunk format: empty string"} - - if stripped_chunk.startswith(":"): - # Comment/heartbeat line (e.g., ": ping") - return _heartbeat_chunk() - - data_parts: list[str] = [] - has_data_prefix = False - for raw_line in chunk.splitlines(): - line = raw_line.strip() - if not line: - continue - if line.startswith(":"): - return _heartbeat_chunk() - if line.startswith("event:"): - event_type_from_sse = line[6:].strip() - continue - if line.startswith("data:"): - has_data_prefix = True - payload = line[5:].strip() - if payload.startswith("event:") and not event_type_from_sse: - event_type_from_sse = 
payload[6:].strip() - continue - data_parts.append(payload) - continue - data_parts.append(line) - - stripped_chunk = "\n".join(part for part in data_parts if part).strip() - - if not has_data_prefix: - return _heartbeat_chunk() - - if not stripped_chunk: - return _heartbeat_chunk() - - if stripped_chunk == "[DONE]": - return _heartbeat_chunk(finish_reason="stop") - - try: - chunk = json.loads(stripped_chunk) - except json.JSONDecodeError as exc: - logger.warning( - "Responses stream chunk JSON decode failed: %s", - stripped_chunk[:300], - ) - return { - "error": "Invalid chunk format: expected JSON after 'data:' prefix", - "details": {"message": str(exc)}, - } - - if not isinstance(chunk, dict): - return {"error": "Invalid chunk format: expected a dictionary"} - - response_payload = chunk.get("response") - if isinstance(response_payload, dict): - chunk_id = response_payload.get("id") - created = response_payload.get("created") - model = response_payload.get("model") - else: - chunk_id = None - created = None - model = None - - chunk_id = chunk_id or chunk.get("id") or f"resp-{uuid.uuid4().hex[:16]}" - created = created or chunk.get("created") or int(time.time()) - model = model or chunk.get("model") or "unknown" - object_type = chunk.get("object") or "response.chunk" - index = chunk.get("index", 0) - event_type = ( - (chunk.get("type") or event_type_from_sse or "").strip() - if chunk.get("type") or event_type_from_sse - else "" - ) - - if logger.isEnabledFor(TRACE_LEVEL): - try: - logger.log( - TRACE_LEVEL, - "Responses event type=%s payload=%s", - event_type or "", - json.dumps(chunk)[:400], - ) - except Exception: - logger.log( - TRACE_LEVEL, - "Responses event type=%s payload=", - event_type or "", - ) - - def _build_chunk( - delta: dict[str, Any] | None = None, - finish_reason: str | None = None, - ) -> dict[str, Any]: - return { - "id": chunk_id, - "object": object_type, - "created": created, - "model": model, - "choices": [ - { - "index": index, - "delta": delta or {}, - "finish_reason": finish_reason, - } - ], - } - - if event_type == "response.output_text.delta": - delta_payload = chunk.get("delta") - text = _extract_text(delta_payload) - if not text: - return _build_chunk() - delta_map: dict[str, Any] = {"content": text} - if isinstance(delta_payload, dict): - role = delta_payload.get("role") - if role: - delta_map["role"] = role - delta_map.setdefault("role", "assistant") - return _build_chunk(delta_map) - - if event_type == "response.reasoning_summary_text.delta": - summary_text = _extract_text(chunk.get("delta")) - return _build_chunk({"reasoning_summary": summary_text}) - - if event_type == "response.reasoning_text.delta": - reasoning_text = _extract_text(chunk.get("delta")) - return _build_chunk({"reasoning_content": reasoning_text}) - - if event_type == "response.function_call_arguments.delta": - call_id = chunk.get("item_id") or chunk.get("call_id") - name = chunk.get("name") or "" - delta_payload = chunk.get("delta") or {} - if isinstance(delta_payload, str): - arguments_fragment = delta_payload - else: - arguments_fragment = _extract_text(delta_payload) - if not isinstance(arguments_fragment, str): - arguments_fragment = json.dumps(delta_payload) - if arguments_fragment is None: - arguments_fragment = "" - tool_index = Translation._assign_tool_call_index( - chunk_id, chunk.get("output_index"), call_id - ) - function_payload: dict[str, Any] = { - "arguments": arguments_fragment, - } - if name: - function_payload["name"] = name - delta = { - "tool_calls": [ - { - "index": 
tool_index, - "id": call_id or "", - "type": "function", - "function": function_payload, - } - ] - } - return _build_chunk(delta) - - if event_type == "response.function_call_arguments.done": - call_id = chunk.get("item_id") or chunk.get("call_id") - name = chunk.get("name") or "" - arguments = chunk.get("arguments") - if isinstance(arguments, dict | list): - arguments = json.dumps(arguments) - elif arguments is None: - arguments = "{}" - else: - arguments = str(arguments) - tool_index = Translation._assign_tool_call_index( - chunk_id, chunk.get("output_index"), call_id - ) - tool_call_obj = ToolCall( - id=call_id or "", - type="function", - function=FunctionCall(name=name, arguments=arguments), - ) - tool_text = render_tool_call(tool_call_obj) - delta = { - "tool_calls": [ - { - "index": tool_index, - "id": call_id or "", - "type": "function", - "function": { - "name": name, - "arguments": arguments, - }, - } - ] - } - if tool_text: - delta["_tool_call_text"] = tool_text # type: ignore[assignment] - return _build_chunk(delta, "tool_calls") - - if event_type == "response.output_item.done": - item = chunk.get("item") or {} - item_type = item.get("type") - - if logger.isEnabledFor(TRACE_LEVEL): - try: - logger.log( - TRACE_LEVEL, - "Responses output_item.done item=%s", - json.dumps(item)[:400], - ) - except Exception: - logger.log( - TRACE_LEVEL, - "Responses output_item.done item=", - ) - - if item_type == "message": - text = _extract_text(item.get("content", [])) - message_delta: dict[str, Any] = {"content": text} if text else {} - role = item.get("role") - if role: - message_delta["role"] = role - return _build_chunk(message_delta or None) - - if item_type == "function_call": - arguments = item.get("arguments", "{}") - if not isinstance(arguments, str): - arguments = json.dumps(arguments) - call_id = ( - item.get("call_id") - or item.get("id") - or f"call_{uuid.uuid4().hex[:8]}" - ) - tool_index = Translation._assign_tool_call_index( - chunk_id, chunk.get("output_index"), call_id - ) - tool_call_obj = ToolCall( - id=call_id, - type="function", - function=FunctionCall( - name=item.get("name", ""), arguments=arguments - ), - ) - tool_text = render_tool_call(tool_call_obj) - delta = { - "tool_calls": [ - { - "id": call_id, - "index": tool_index, - "type": "function", - "function": { - "name": item.get("name", ""), - "arguments": arguments, - }, - } - ] - } - if tool_text: - delta["_tool_call_text"] = tool_text # type: ignore[assignment] - return _build_chunk(delta, "tool_calls") - - if item_type == "custom_tool_call": - input_payload = item.get("input", "") - if not isinstance(input_payload, str): - input_payload = json.dumps(input_payload) - call_id = ( - item.get("call_id") - or item.get("id") - or f"custom_{uuid.uuid4().hex[:8]}" - ) - tool_index = Translation._assign_tool_call_index( - chunk_id, chunk.get("output_index"), call_id - ) - tool_call_obj = ToolCall( - id=call_id, - type="function", - function=FunctionCall( - name=item.get("name", ""), arguments=input_payload - ), - ) - tool_text = render_tool_call(tool_call_obj) - delta = { - "tool_calls": [ - { - "id": call_id, - "index": tool_index, - "type": "function", - "function": { - "name": item.get("name", ""), - "arguments": input_payload or "{}", - }, - } - ] - } - if tool_text: - delta["_tool_call_text"] = tool_text # type: ignore[assignment] - - return _build_chunk(delta) - - if item_type == "local_shell_call": - action = item.get("action") or {} - arguments = action if isinstance(action, str) else json.dumps(action) - call_id = 
( - item.get("call_id") - or item.get("id") - or f"shell_{uuid.uuid4().hex[:8]}" - ) - tool_index = Translation._assign_tool_call_index( - chunk_id, chunk.get("output_index"), call_id - ) - tool_call_obj = ToolCall( - id=call_id, - type="function", - function=FunctionCall(name="shell", arguments=arguments), - ) - tool_text = render_tool_call(tool_call_obj) - delta = { - "tool_calls": [ - { - "id": call_id, - "index": tool_index, - "type": "function", - "function": { - "name": "shell", - "arguments": arguments, - }, - } - ] - } - if tool_text: - delta["_tool_call_text"] = tool_text # type: ignore[assignment] - - return _build_chunk(delta) - - return _build_chunk() - - if event_type == "response.completed": - response_info = chunk.get("response") or {} - result = _build_chunk({}, "stop") - usage = response_info.get("usage") - if usage: - result["usage"] = usage - response_id = response_info.get("id") or chunk_id - if response_id: - result["response_id"] = response_id - Translation._reset_tool_call_state(response_id) - return result - - if event_type == "response.created": - response_info = chunk.get("response") or {} - response_id = response_info.get("id") or chunk_id - if response_id: - Translation._reset_tool_call_state(response_id) - created_delta: dict[str, Any] = {} - if response_id: - created_delta["response_id"] = response_id - created_delta["role"] = "assistant" - return _build_chunk(created_delta or None) - - if event_type == "response.failed": - response_info = chunk.get("response") or {} - error_payload = response_info.get("error") or chunk.get("error") or {} - Translation._reset_tool_call_state(response_info.get("id") or chunk_id) - return { - "error": "Responses stream reported failure", - "details": error_payload, - } - - if event_type in { - "response.output_text.done", - "response.output_item.added", - "response.custom_tool_call_input.done", - "response.custom_tool_call_input.delta", - "response.function_call_arguments.delta", - "response.in_progress", - "response.content_part.done", - }: - return _build_chunk() - - if "choices" in chunk: - choices = chunk.get("choices") or [] - if not isinstance(choices, list) or not choices: - return _build_chunk() - - primary_choice = choices[0] or {} - finish_reason = primary_choice.get("finish_reason") - raw_delta = primary_choice.get("delta") or {} - if isinstance(raw_delta, dict): - delta = cast(dict[str, Any], dict(raw_delta)) - else: - delta = {"content": cast(Any, str(raw_delta))} - - content_value = delta.get("content") - if isinstance(content_value, list): - text_parts: list[str] = [] - for part in content_value: - if not isinstance(part, dict): - continue - part_type = part.get("type") - if part_type in {"output_text", "text", "input_text"}: - text_value = part.get("text") or part.get("value") or "" - if text_value: - text_parts.append(str(text_value)) - delta["content"] = cast(Any, "".join(text_parts)) - elif isinstance(content_value, dict): - delta["content"] = cast(Any, json.dumps(content_value)) - elif content_value is None: - delta.pop("content", None) - else: - delta["content"] = cast(Any, str(content_value)) - - tool_calls = delta.get("tool_calls") - if isinstance(tool_calls, list): - normalized_tool_calls: list[dict[str, Any]] = [] - for tool_call in tool_calls: - if isinstance(tool_call, dict): - call_data = dict(tool_call) - else: - function = getattr(tool_call, "function", None) - call_data = { - "id": getattr(tool_call, "id", ""), - "type": getattr(tool_call, "type", "function"), - "function": { - "name": 
getattr(function, "name", ""), - "arguments": getattr(function, "arguments", "{}"), - }, - } - - function_payload = call_data.get("function") or {} - if isinstance(function_payload, dict): - arguments = function_payload.get("arguments") - if isinstance(arguments, dict | list): - function_payload["arguments"] = json.dumps(arguments) - elif arguments is None: - function_payload["arguments"] = "{}" - else: - function_payload["arguments"] = str(arguments) - - normalized_tool_calls.append(call_data) - - if normalized_tool_calls: - delta["tool_calls"] = normalized_tool_calls - else: - delta.pop("tool_calls", None) - - return _build_chunk(delta, finish_reason) - - # Default: emit an empty chunk to keep the stream progressing. - return _build_chunk() - - @staticmethod - def openrouter_to_domain_request(request: Any) -> CanonicalChatRequest: - """ - Translate an OpenRouter request to a CanonicalChatRequest. - """ - if isinstance(request, dict): - model = request.get("model") - messages = request.get("messages", []) - top_k = request.get("top_k") - top_p = request.get("top_p") - temperature = request.get("temperature") - max_tokens = request.get("max_tokens") - stop = request.get("stop") - seed = request.get("seed") - reasoning_effort = request.get("reasoning_effort") - extra_params = request.get("extra_params") - else: - model = getattr(request, "model", None) - messages = getattr(request, "messages", []) - top_k = getattr(request, "top_k", None) - top_p = getattr(request, "top_p", None) - temperature = getattr(request, "temperature", None) - max_tokens = getattr(request, "max_tokens", None) - stop = getattr(request, "stop", None) - seed = getattr(request, "seed", None) - reasoning_effort = getattr(request, "reasoning_effort", None) - extra_params = getattr(request, "extra_params", None) - - if not model: - raise ValueError("Model not found in request") - - # Convert messages to ChatMessage objects if they are dicts - chat_messages = [] - for msg in messages: - if isinstance(msg, dict): - chat_messages.append(ChatMessage(**msg)) - else: - chat_messages.append(msg) - - return CanonicalChatRequest( - model=model, - messages=chat_messages, - top_k=top_k, - top_p=top_p, - temperature=temperature, - max_tokens=max_tokens, - stop=stop, - seed=seed, - reasoning_effort=reasoning_effort, - stream=( - request.get("stream") - if isinstance(request, dict) - else getattr(request, "stream", None) - ), - extra_body=( - request.get("extra_body") - if isinstance(request, dict) - else getattr(request, "extra_body", None) - ) - or (extra_params if extra_params is not None else None), - tools=( - request.get("tools") - if isinstance(request, dict) - else getattr(request, "tools", None) - ), - tool_choice=( - request.get("tool_choice") - if isinstance(request, dict) - else getattr(request, "tool_choice", None) - ), - ) - - @staticmethod - def _validate_request_parameters(request: CanonicalChatRequest) -> None: - """Validate required parameters in a domain request.""" - if not request.model: - raise ValueError("Model is required") - - if not request.messages: - raise ValueError("Messages are required") - - # Validate message structure - for message in request.messages: - if not message.role: - raise ValueError("Message role is required") - - # Allow assistant messages that carry only tool_calls (no textual content) - if message.role != "system": - has_text = bool(message.content) - has_tool_calls = bool(getattr(message, "tool_calls", None)) - if not has_text and not ( - message.role == "assistant" and has_tool_calls - 
): - raise ValueError(f"Content is required for {message.role} messages") - - # Validate tool parameters if present - if request.tools: - for tool in request.tools: - if isinstance(tool, dict): - if "function" not in tool: - raise ValueError("Tool must have a function") - if "name" not in tool.get("function", {}): - raise ValueError("Tool function must have a name") - - @staticmethod - def from_domain_to_gemini_request(request: CanonicalChatRequest) -> dict[str, Any]: - """ - Translate a CanonicalChatRequest to a Gemini request. - """ - - Translation._validate_request_parameters(request) - - config: dict[str, Any] = {} - if request.top_k is not None: - config["topK"] = request.top_k - if request.top_p is not None: - config["topP"] = request.top_p - if request.temperature is not None: - config["temperature"] = request.temperature - if request.max_tokens is not None: - config["maxOutputTokens"] = request.max_tokens - if request.stop: - config["stopSequences"] = Translation._normalize_stop_sequences( - request.stop - ) - - # Handle thinking budget overrides and reasoning effort mapping. - def _resolve_thinking_budget(reasoning_effort: str | None) -> int | None: - """Resolve thinking budget from CLI override or reasoning effort.""" - cli_value = os.environ.get("THINKING_BUDGET") - if cli_value is not None: - try: - return int(cli_value) - except ValueError: - return None - - if reasoning_effort is None: - return None - - effort_to_budget: dict[str, int] = { - "low": 512, - "medium": 2048, - "high": -1, - } - - return effort_to_budget.get(reasoning_effort.lower(), None) - - thinking_budget = _resolve_thinking_budget(request.reasoning_effort) - if thinking_budget is not None: - config["thinkingConfig"] = { - "thinkingBudget": thinking_budget, - "includeThoughts": True, - } - - # Process messages with proper handling of multimodal content and tool calls - contents: list[dict[str, Any]] = [] - # Track tool_call id -> function name to map tool responses - tool_name_by_id: dict[str, str] = {} - - # Group consecutive tool messages together to match Gemini's requirement - # that all functionResponse parts must be in a single user message - i = 0 - while i < len(request.messages): - message = request.messages[i] - - # Map assistant role to 'model' for Gemini compatibility; keep others as-is - if message.role == "assistant": - gemini_role = "model" - elif message.role == "tool": - # Gemini expects function responses from the "user" role - gemini_role = "user" - else: - gemini_role = message.role - msg_dict: dict[str, Any] = {"role": gemini_role} - parts: list[dict[str, Any]] = [] - - # Add assistant tool calls as functionCall parts - has_tool_calls = message.role == "assistant" and getattr( - message, "tool_calls", None - ) - if has_tool_calls: - try: - for tc in message.tool_calls or []: - tc_dict = tc if isinstance(tc, dict) else tc.model_dump() - fn = (tc_dict.get("function") or {}).get("name", "") - args_raw = (tc_dict.get("function") or {}).get("arguments", "") - # Remember mapping for subsequent tool responses - if "id" in tc_dict: - tool_name_by_id[tc_dict["id"]] = fn - # Parse arguments as JSON when possible - import json as _json - - try: - args_val = ( - _json.loads(args_raw) - if isinstance(args_raw, str) - else args_raw - ) - except Exception: - args_val = args_raw - parts.append({"functionCall": {"name": fn, "args": args_val}}) - except Exception: - # Best-effort; continue even if a tool call cannot be parsed - pass - - # Handle content which could be string, list of parts, or None - # 
IMPORTANT: Gemini API requires that if a message has functionCall parts, - # it should NOT have text content in the same message. This prevents - # "number of function response parts not equal to function call parts" errors. - if not has_tool_calls: - if isinstance(message.content, str): - # Simple text content - parts.append({"text": message.content}) - elif isinstance(message.content, list): - # Multimodal content (list of parts) - for part in message.content: - if hasattr(part, "type") and part.type == "image_url": - processed_image = Translation._process_gemini_image_part( - part - ) - if processed_image: - parts.append(processed_image) - elif hasattr(part, "type") and part.type == "text": - from src.core.domain.chat import MessageContentPartText - - # Handle text part - if isinstance(part, MessageContentPartText) and hasattr( - part, "text" - ): - parts.append({"text": part.text}) - else: - # Try best effort conversion - if hasattr(part, "model_dump"): - part_dict = part.model_dump() - if "text" in part_dict: - parts.append({"text": part_dict["text"]}) - - # Map tool role messages to functionResponse parts - # Group all consecutive tool messages into a single user message - if message.role == "tool": - # Collect all consecutive tool messages - tool_messages = [message] - j = i + 1 - while j < len(request.messages) and request.messages[j].role == "tool": - tool_messages.append(request.messages[j]) - j += 1 - - # Process all tool messages into functionResponse parts - for tool_msg in tool_messages: - # Try to map tool_call_id back to the function name - name = tool_name_by_id.get( - getattr(tool_msg, "tool_call_id", ""), "" - ) - resp_obj: dict[str, Any] - val = tool_msg.content - # Try to parse JSON result if provided - if isinstance(val, str): - import json as _json - - try: - resp_obj = _json.loads(val) - except Exception: - resp_obj = {"text": val} - elif isinstance(val, dict): - resp_obj = val - else: - resp_obj = {"text": str(val)} - - parts.append( - {"functionResponse": {"name": name, "response": resp_obj}} - ) - - # Skip the tool messages we just processed - i = j - 1 - - # Add parts to message - msg_dict["parts"] = parts # type: ignore - - # Only add non-empty messages - if parts: - contents.append(msg_dict) - - i += 1 - - result = {"contents": contents, "generationConfig": config} - - # Add tools if present - if request.tools: - # Gemini Code Assist only allows multiple tools when they are all - # search tools. For function calling, we must group ALL functions - # into a SINGLE tool entry with a combined function_declarations list. 
- function_declarations: list[dict[str, Any]] = [] - - for tool in request.tools: - # Accept dict-like or model-like entries - tool_dict: dict[str, Any] - if isinstance(tool, dict): - tool_dict = tool - else: - try: - tool_dict = tool.model_dump() # type: ignore[attr-defined] - except Exception: - tool_dict = {} - function = ( - tool_dict.get("function") if isinstance(tool_dict, dict) else None - ) - if not function: - # Skip non-function tools for now (unsupported/mixed types) - continue - - params = Translation._sanitize_gemini_parameters( - function.get("parameters", {}) - ) - function_declarations.append( - { - "name": function.get("name", ""), - "description": function.get("description", ""), - "parameters": params, - } - ) - - if function_declarations: - result["tools"] = [{"function_declarations": function_declarations}] - - # Handle tool_choice for Gemini - if request.tool_choice: - mode = "AUTO" # Default - allowed_functions = None - - if isinstance(request.tool_choice, str): - if request.tool_choice == "none": - mode = "NONE" - elif request.tool_choice == "auto": - mode = "AUTO" - elif request.tool_choice in ["any", "required"]: - mode = "ANY" - elif ( - isinstance(request.tool_choice, dict) - and request.tool_choice.get("type") == "function" - ): - function_spec = request.tool_choice.get("function", {}) - function_name = function_spec.get("name") - if function_name: - mode = "ANY" - allowed_functions = [function_name] - - fcc: dict[str, Any] = {"mode": mode} - if allowed_functions: - fcc["allowedFunctionNames"] = allowed_functions - result["toolConfig"] = {"functionCallingConfig": fcc} - - # Handle structured output for Responses API - if request.extra_body and "response_format" in request.extra_body: - response_format = request.extra_body["response_format"] - if response_format.get("type") == "json_schema": - json_schema = response_format.get("json_schema", {}) - schema = json_schema.get("schema", {}) - - # For Gemini, add JSON mode and schema constraint to generation config - generation_config = result["generationConfig"] - if isinstance(generation_config, dict): - generation_config["responseMimeType"] = "application/json" - generation_config["responseSchema"] = schema - - # Add schema name and description as additional context if available - schema_name = json_schema.get("name") - schema_description = json_schema.get("description") - if schema_name or schema_description: - # Add schema context to the last user message or create a system-like instruction - schema_context = "Generate a JSON response" - if schema_name: - schema_context += f" for '{schema_name}'" - if schema_description: - schema_context += f": {schema_description}" - schema_context += ( - ". The response must conform to the provided JSON schema." - ) - - # Add this as context to help the model understand the structured output requirement - if ( - contents - and isinstance(contents[-1], dict) - and contents[-1].get("role") == "user" - ): - # Append to the last user message - last_message = contents[-1] - if ( - isinstance(last_message, dict) - and last_message.get("parts") - and isinstance(last_message["parts"], list) - ): - last_message["parts"].append( - {"text": f"\n\n{schema_context}"} - ) - else: - # Add as a new user message - contents.append( - {"role": "user", "parts": [{"text": schema_context}]} - ) - - return result - - @staticmethod - def _sanitize_gemini_parameters(schema: dict[str, Any]) -> dict[str, Any]: - """Sanitize OpenAI tool JSON schema for Gemini Code Assist function_declarations. 
- - The Code Assist API rejects certain JSON Schema keywords (e.g., "$schema", - and sometimes draft-specific fields like "exclusiveMinimum"). This method - removes unsupported keywords while preserving the core shape (type, - properties, required, items, enum, etc.). - - Args: - schema: Original JSON schema dict from OpenAI tool definition - - Returns: - A sanitized schema dict suitable for Gemini Code Assist. - """ - if not isinstance(schema, dict): - return {} - - blacklist = { - "$schema", - "$id", - "$comment", - "exclusiveMinimum", - "exclusiveMaximum", - } - - def _clean(obj: Any) -> Any: - if isinstance(obj, dict): - cleaned: dict[str, Any] = {} - for k, v in obj.items(): - if k in blacklist: - continue - cleaned[k] = _clean(v) - return cleaned - if isinstance(obj, list): - return [_clean(x) for x in obj] - return obj - - cleaned = _clean(schema) - return cleaned if isinstance(cleaned, dict) else {} - - @staticmethod - def from_domain_to_openai_request(request: CanonicalChatRequest) -> dict[str, Any]: - """ - Translate a CanonicalChatRequest to an OpenAI request. - """ - messages_payload: list[dict[str, Any]] = [] - for message in request.messages: - if hasattr(message, "to_dict"): - message_dict = message.to_dict() - # Preserve explicit `content: None` semantics expected by the - # OpenAI API when tool_calls are present. - if "content" not in message_dict: - message_dict["content"] = None - else: - message_dict = { - "role": getattr(message, "role", "assistant"), - "content": getattr(message, "content", None), - } - tool_calls = getattr(message, "tool_calls", None) - if tool_calls is not None: - message_dict["tool_calls"] = tool_calls - tool_call_id = getattr(message, "tool_call_id", None) - if tool_call_id is not None: - message_dict["tool_call_id"] = tool_call_id - - messages_payload.append(message_dict) - - payload: dict[str, Any] = { - "model": request.model, - "messages": messages_payload, - } - - # Add all supported parameters - if request.top_p is not None: - payload["top_p"] = request.top_p - if request.temperature is not None: - payload["temperature"] = request.temperature - if request.max_tokens is not None: - payload["max_tokens"] = request.max_tokens - if request.stream is not None: - payload["stream"] = request.stream - if request.stop is not None: - payload["stop"] = Translation._normalize_stop_sequences(request.stop) - if request.seed is not None: - payload["seed"] = request.seed - if request.presence_penalty is not None: - payload["presence_penalty"] = request.presence_penalty - if request.frequency_penalty is not None: - payload["frequency_penalty"] = request.frequency_penalty - if request.user is not None: - payload["user"] = request.user - if request.tools is not None: - payload["tools"] = request.tools - if request.tool_choice is not None: - payload["tool_choice"] = request.tool_choice - - # Handle OpenAI reasoning configuration - reasoning_payload: dict[str, Any] | None = None - if request.reasoning is not None: - if isinstance(request.reasoning, dict): - reasoning_payload = dict(request.reasoning) - elif hasattr(request.reasoning, "model_dump"): - reasoning_payload = request.reasoning.model_dump() # type: ignore[attr-defined] - - effort_value = request.reasoning_effort - normalized_effort: str | None - if isinstance(effort_value, str): - normalized_effort = effort_value.strip() - else: - normalized_effort = str(effort_value) if effort_value is not None else None - - if normalized_effort: - if reasoning_payload is None: - reasoning_payload = {} - if "effort" 
not in reasoning_payload: - reasoning_payload["effort"] = effort_value - - if reasoning_payload: - payload["reasoning"] = reasoning_payload - - # Handle structured output for Responses API - if request.extra_body and "response_format" in request.extra_body: - response_format = request.extra_body["response_format"] - if response_format.get("type") == "json_schema": - # For OpenAI, we can pass the response_format directly - payload["response_format"] = response_format - - return payload - - @staticmethod - def anthropic_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: - """ - Translate an Anthropic streaming chunk to a canonical dictionary format. - - Args: - chunk: The Anthropic streaming chunk (can be SSE string or dict). - - Returns: - A dictionary representing the canonical chunk format. - """ - import json - import time - import uuid - - # Handle SSE-formatted strings - if isinstance(chunk, str): - # Parse SSE format - handle multi-line SSE events - chunk = chunk.strip() - - # Handle [DONE] marker - if "data: [DONE]" in chunk or chunk == "[DONE]": - return { - "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": "claude-3-opus-20240229", - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": None, - } - ], - } - - # Extract data line from multi-line SSE chunk - data_line = None - for line in chunk.split("\n"): - line = line.strip() - if line.startswith("data:"): - data_line = line[5:].strip() - break - - # If no data line found, check if entire chunk is just event/id lines - if data_line is None: - if chunk.startswith(("event:", "id:")) or not chunk: - return { - "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": "claude-3-opus-20240229", - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": None, - } - ], - } - # Try to use the whole chunk as JSON - data_line = chunk - - # Try to parse as JSON - try: - chunk = json.loads(data_line) - except json.JSONDecodeError: - return {"error": "Invalid chunk format: expected a dictionary"} - - if not isinstance(chunk, dict): - return {"error": "Invalid chunk format: expected a dictionary"} - - response_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" - created = int(time.time()) - model = "claude-3-opus-20240229" # Default model - - content = "" - finish_reason = None - role = None - - # Handle different Anthropic event types - event_type = chunk.get("type") - - if event_type == "message_start": - # Message start event - set role - role = "assistant" - elif event_type == "content_block_start": - # Content block start - no content yet - pass - elif event_type == "content_block_delta": - # Content delta - extract text - delta = chunk.get("delta", {}) - if delta.get("type") == "text_delta": - content = delta.get("text", "") - elif event_type == "content_block_stop": - # Content block stop - no action needed - pass - elif event_type == "message_delta": - # Message delta - check for finish reason - delta = chunk.get("delta", {}) - stop_reason = delta.get("stop_reason") - if stop_reason == "end_turn": - finish_reason = "stop" - elif stop_reason == "max_tokens": - finish_reason = "length" - elif stop_reason == "tool_use": - finish_reason = "tool_calls" - elif event_type == "message_stop": - # Message stop - mark as complete - finish_reason = "stop" - - # Build delta - output_delta: dict[str, Any] = {} - if role: - output_delta["role"] = role - if content: - output_delta["content"] = content - - 
return { - "id": response_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "delta": output_delta, - "finish_reason": finish_reason, - } - ], - } - - @staticmethod - def from_domain_to_anthropic_request( - request: CanonicalChatRequest, - ) -> dict[str, Any]: - """ - Translate a CanonicalChatRequest to an Anthropic request. - """ - # Process messages with proper handling of system messages and multimodal content - processed_messages = [] - system_message = None - - for message in request.messages: - if message.role == "system": - # Extract system message - system_message = message.content - continue - - # Process regular messages - msg_dict = {"role": message.role} - - # Handle content which could be string, list of parts, or None - if message.content is None: - # Skip empty content - continue - elif isinstance(message.content, str): - # Simple text content - msg_dict["content"] = message.content - elif isinstance(message.content, list): - # Multimodal content (list of parts) - content_parts = [] - for part in message.content: - from src.core.domain.chat import ( - MessageContentPartImage, - MessageContentPartText, - ) - - if isinstance(part, MessageContentPartImage): - # Handle image part - if part.image_url: - url_str = str(part.image_url.url) - # Only include data URLs; skip http/https URLs - if url_str.startswith("data:"): - content_parts.append( - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/jpeg", - "data": url_str.split(",", 1)[-1], - }, - } - ) - elif isinstance(part, MessageContentPartText): - # Handle text part - content_parts.append({"type": "text", "text": part.text}) - else: - # Try best effort conversion - if hasattr(part, "model_dump"): - part_dict = part.model_dump() - if "text" in part_dict: - content_parts.append( - {"type": "text", "text": part_dict["text"]} - ) - - if content_parts: - # Use type annotation to help mypy - msg_dict["content"] = content_parts # type: ignore - - # Handle tool calls if present - if message.tool_calls: - tool_calls = [] - for tool_call in message.tool_calls: - if hasattr(tool_call, "model_dump"): - tool_call_dict = tool_call.model_dump() - tool_calls.append(tool_call_dict) - elif isinstance(tool_call, dict): - tool_calls.append(tool_call) - else: - # Convert to dict if possible - try: - tool_call_dict = dict(tool_call) - tool_calls.append(tool_call_dict) - except (TypeError, ValueError): - # Skip if can't convert - continue - - if tool_calls: - # Use type annotation to help mypy - msg_dict["tool_calls"] = tool_calls # type: ignore - - # Handle tool call ID if present - if message.tool_call_id: - msg_dict["tool_call_id"] = message.tool_call_id - - # Handle name if present - if message.name: - msg_dict["name"] = message.name - - processed_messages.append(msg_dict) - - payload: dict[str, Any] = { - "model": request.model, - "messages": processed_messages, - "max_tokens": request.max_tokens or 1024, - "stream": request.stream, - } - - if system_message: - payload["system"] = system_message - if request.temperature is not None: - payload["temperature"] = request.temperature - if request.top_p is not None: - payload["top_p"] = request.top_p - if request.top_k is not None: - payload["top_k"] = request.top_k - - # Handle tools if present - if request.tools: - # Convert tools to Anthropic format - anthropic_tools = [] - for tool in request.tools: - if isinstance(tool, dict) and "function" in tool: - anthropic_tool = {"type": "function", "function": 
tool["function"]} - anthropic_tools.append(anthropic_tool) - elif not isinstance(tool, dict): - tool_dict = tool.model_dump() - if "function" in tool_dict: - anthropic_tool = { - "type": "function", - "function": tool_dict["function"], - } - anthropic_tools.append(anthropic_tool) - - if anthropic_tools: - payload["tools"] = anthropic_tools - - # Handle tool_choice if present - if request.tool_choice: - if isinstance(request.tool_choice, dict): - if request.tool_choice.get("type") == "function": - # Already in Anthropic format - payload["tool_choice"] = request.tool_choice - elif "function" in request.tool_choice: - # Convert from OpenAI format to Anthropic format - payload["tool_choice"] = { - "type": "function", - "function": request.tool_choice["function"], - } - elif request.tool_choice == "auto" or request.tool_choice == "none": - payload["tool_choice"] = request.tool_choice - - # Add stop sequences if present - if request.stop: - payload["stop_sequences"] = Translation._normalize_stop_sequences( - request.stop - ) - - # Add metadata if present in extra_body - if request.extra_body and isinstance(request.extra_body, dict): - metadata = request.extra_body.get("metadata") - if metadata: - payload["metadata"] = metadata - - # Handle structured output for Responses API - response_format = request.extra_body.get("response_format") - if response_format and response_format.get("type") == "json_schema": - json_schema = response_format.get("json_schema", {}) - schema = json_schema.get("schema", {}) - schema_name = json_schema.get("name") - schema_description = json_schema.get("description") - strict = json_schema.get("strict", True) - - # For Anthropic, add comprehensive JSON schema instruction to system message - import json - - schema_instruction = ( - "\n\nYou must respond with valid JSON that conforms to this schema" - ) - if schema_name: - schema_instruction += f" for '{schema_name}'" - if schema_description: - schema_instruction += f" ({schema_description})" - schema_instruction += f":\n\n{json.dumps(schema, indent=2)}" - - if strict: - schema_instruction += "\n\nIMPORTANT: The response must strictly adhere to this schema. Do not include any additional fields or deviate from the specified structure." - else: - schema_instruction += "\n\nNote: The response should generally follow this schema, but minor variations may be acceptable." - - schema_instruction += "\n\nRespond only with the JSON object, no additional text or formatting." - - if payload.get("system"): - if isinstance(payload["system"], str): - payload["system"] += schema_instruction - else: - # If not a string, we cannot append. Replace it. - payload["system"] = schema_instruction - else: - payload["system"] = ( - f"You are a helpful assistant.{schema_instruction}" - ) - - return payload - - @staticmethod - def code_assist_to_domain_request(request: Any) -> CanonicalChatRequest: - """ - Translate a Code Assist API request to a CanonicalChatRequest. - - The Code Assist API uses the same format as OpenAI for the core request, - but with additional project field and different endpoint. 
- """ - # Code Assist API request format is essentially the same as OpenAI - # but may include a "project" field - if isinstance(request, dict): - # Remove Code Assist specific fields and treat as OpenAI format - cleaned_request = {k: v for k, v in request.items() if k != "project"} - return Translation.openai_to_domain_request(cleaned_request) - else: - # Handle object format by extracting fields - return Translation.openai_to_domain_request(request) - - @staticmethod - def code_assist_to_domain_response(response: Any) -> CanonicalChatResponse: - """ - Translate a Code Assist API response to a CanonicalChatResponse. - - The Code Assist API wraps the response in a "response" object and uses - different structure than standard Gemini API. - """ - import time - - if not isinstance(response, dict): - # Handle non-dict responses - return CanonicalChatResponse( - id=f"chatcmpl-code-assist-{int(time.time())}", - object="chat.completion", - created=int(time.time()), - model="unknown", - choices=[ - ChatCompletionChoice( - index=0, - message=ChatCompletionChoiceMessage( - role="assistant", content=str(response) - ), - finish_reason="stop", - ) - ], - usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - ) - - # Extract from Code Assist response wrapper - response_wrapper = response.get("response", {}) - candidates = response_wrapper.get("candidates", []) - generated_text = "" - - if candidates and len(candidates) > 0: - candidate = candidates[0] - content = candidate.get("content") or {} - parts = content.get("parts", []) - - if parts and len(parts) > 0: - generated_text = parts[0].get("text", "") - - # Create canonical response - return CanonicalChatResponse( - id=f"chatcmpl-code-assist-{int(time.time())}", - object="chat.completion", - created=int(time.time()), - model=response.get("model", "code-assist-model"), - choices=[ - ChatCompletionChoice( - index=0, - message=ChatCompletionChoiceMessage( - role="assistant", content=generated_text - ), - finish_reason="stop", - ) - ], - usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - ) - - @staticmethod - def code_assist_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: - """ - Translate a Code Assist API streaming chunk to a canonical dictionary format. - - Code Assist API uses Server-Sent Events (SSE) format with "data: " prefix. 
- """ - import time - import uuid - - if chunk is None: - # Handle end of stream - return { - "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": "code-assist-model", - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "stop", - } - ], - } - - if not isinstance(chunk, dict): - return {"error": "Invalid chunk format: expected a dictionary"} - - response_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" - created = int(time.time()) - model = "code-assist-model" - - content = "" - finish_reason = None - tool_calls: list[dict[str, Any]] | None = None - - # Extract from Code Assist response wrapper - response_wrapper = chunk.get("response", {}) - candidates = response_wrapper.get("candidates", []) - - if candidates and len(candidates) > 0: - candidate = candidates[0] - content_obj = candidate.get("content") or {} - parts = content_obj.get("parts", []) - - if parts and len(parts) > 0: - # Collect text and function calls - text_parts: list[str] = [] - for part in parts: - if isinstance(part, dict) and "text" in part: - text_parts.append(part.get("text", "")) - elif isinstance(part, dict) and "functionCall" in part: - try: - if tool_calls is None: - tool_calls = [] - tool_calls.append( - Translation._process_gemini_function_call( - part["functionCall"] - ).model_dump() - ) - except Exception: - # Ignore malformed functionCall parts - continue - content = "".join(text_parts) - - if "finishReason" in candidate: - finish_reason = candidate["finishReason"] - - delta: dict[str, Any] = {"role": "assistant"} - if tool_calls: - delta["tool_calls"] = tool_calls - # Enforce OpenAI semantics: when tool_calls are present, do not include content - delta.pop("content", None) - # Force finish_reason to tool_calls to signal clients to execute tools - finish_reason = "tool_calls" - elif content: - delta["content"] = content - - return { - "id": response_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "delta": delta, - "finish_reason": finish_reason, - } - ], - } - - @staticmethod - def raw_text_to_domain_request(request: Any) -> CanonicalChatRequest: - """ - Translate a raw text request to a CanonicalChatRequest. - - Raw text format is typically used for simple text processing where - the input is just a plain text string. - """ - - if isinstance(request, str): - # Create a simple request with the text as user message - from src.core.domain.chat import ChatMessage - - return CanonicalChatRequest( - model="text-model", - messages=[ChatMessage(role="user", content=request)], - ) - elif isinstance(request, dict): - # If it's already a dict, treat it as OpenAI format - return Translation.openai_to_domain_request(request) - else: - # Handle object format - return Translation.openai_to_domain_request(request) - - @staticmethod - def raw_text_to_domain_response(response: Any) -> CanonicalChatResponse: - """ - Translate a raw text response to a CanonicalChatResponse. - - Raw text format is typically used for simple text responses. 
- """ - import time - - if isinstance(response, str): - return CanonicalChatResponse( - id=f"chatcmpl-raw-text-{int(time.time())}", - object="chat.completion", - created=int(time.time()), - model="text-model", - choices=[ - ChatCompletionChoice( - index=0, - message=ChatCompletionChoiceMessage( - role="assistant", content=response - ), - finish_reason="stop", - ) - ], - usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - ) - elif isinstance(response, dict): - # If it's already a dict, treat it as OpenAI format - return Translation.openai_to_domain_response(response) - else: - # Handle object format - return Translation.openai_to_domain_response(response) - - @staticmethod - def raw_text_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: - """ - Translate a raw text stream chunk to a canonical dictionary format. - - Raw text chunks are typically plain text strings. - """ - import time - import uuid - - if chunk is None: - # Handle end of stream - return { - "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": "text-model", - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "stop", - } - ], - } - - if isinstance(chunk, str): - # Raw text chunk - return { - "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": "text-model", - "choices": [ - { - "index": 0, - "delta": {"role": "assistant", "content": chunk}, - "finish_reason": None, - } - ], - } - elif isinstance(chunk, dict): - # Check if it's a wrapped text dict like {"text": "content"} - if "text" in chunk and isinstance(chunk["text"], str): - return { - "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": "text-model", - "choices": [ - { - "index": 0, - "delta": {"role": "assistant", "content": chunk["text"]}, - "finish_reason": None, - } - ], - } - else: - # If it's already a dict, treat it as OpenAI format - return Translation.openai_to_domain_stream_chunk(chunk) - else: - return {"error": "Invalid raw text chunk format"} - - @staticmethod - def responses_to_domain_request(request: Any) -> CanonicalChatRequest: - """ - Translate a Responses API request to a CanonicalChatRequest. - - The Responses API request includes structured output requirements via response_format. - This method converts the request to the internal domain format while preserving - the JSON schema information for later use by backends. 
- """ - from src.core.domain.responses_api import ResponsesRequest - - # Normalize incoming payload regardless of format (dict, model, or object) - def _prepare_payload(payload: dict[str, Any]) -> dict[str, Any]: - normalized_payload = dict(payload) - if "messages" not in normalized_payload and "input" in normalized_payload: - normalized_payload["messages"] = ( - Translation._normalize_responses_input_to_messages( - normalized_payload["input"] - ) - ) - return normalized_payload - - if isinstance(request, dict): - request_payload = _prepare_payload(request) - if not request_payload.get("model"): - raise ValueError("'model' is a required property") - other_params = { - k: v - for k, v in request_payload.items() - if k not in ["model", "messages"] - } - responses_request = ResponsesRequest( - model=request_payload.get("model") or "", - messages=request_payload.get("messages") or [], - **other_params, - ) - elif hasattr(request, "model_dump"): - request_payload = _prepare_payload(request.model_dump()) - responses_request = ( - request - if isinstance(request, ResponsesRequest) - else ResponsesRequest(**request_payload) - ) - else: - request_payload = { - "model": getattr(request, "model", None), - "messages": getattr(request, "messages", None), - "response_format": getattr(request, "response_format", None), - "max_tokens": getattr(request, "max_tokens", None), - "temperature": getattr(request, "temperature", None), - "top_p": getattr(request, "top_p", None), - "n": getattr(request, "n", None), - "stream": getattr(request, "stream", None), - "stop": getattr(request, "stop", None), - "presence_penalty": getattr(request, "presence_penalty", None), - "frequency_penalty": getattr(request, "frequency_penalty", None), - "logit_bias": getattr(request, "logit_bias", None), - "user": getattr(request, "user", None), - "seed": getattr(request, "seed", None), - "session_id": getattr(request, "session_id", None), - "agent": getattr(request, "agent", None), - "extra_body": getattr(request, "extra_body", None), - } - - input_value = getattr(request, "input", None) - if (not request_payload.get("messages")) and input_value is not None: - request_payload["messages"] = ( - Translation._normalize_responses_input_to_messages(input_value) - ) - - other_params = { - k: v - for k, v in request_payload.items() - if k not in ["model", "messages"] - } - responses_request = ResponsesRequest( - model=request_payload.get("model") or "", - messages=request_payload.get("messages") or [], - **other_params, - ) - - # Prepare extra_body with response format - extra_body = dict(responses_request.extra_body or {}) - if responses_request.response_format is not None: - extra_body["response_format"] = ( - responses_request.response_format.model_dump() - ) - - # Convert to CanonicalChatRequest - canonical_request = CanonicalChatRequest( - model=responses_request.model, - messages=responses_request.messages, - temperature=responses_request.temperature, - top_p=responses_request.top_p, - max_tokens=responses_request.max_tokens, - n=responses_request.n, - stream=responses_request.stream, - stop=responses_request.stop, - presence_penalty=responses_request.presence_penalty, - frequency_penalty=responses_request.frequency_penalty, - logit_bias=responses_request.logit_bias, - user=responses_request.user, - seed=responses_request.seed, - session_id=responses_request.session_id, - agent=responses_request.agent, - extra_body=extra_body, - ) - - return canonical_request - - @staticmethod - def from_domain_to_responses_response(response: 
ChatResponse) -> dict[str, Any]: - """ - Translate a domain ChatResponse to a Responses API response format. - - This method converts the internal domain response to the OpenAI Responses API format, - including parsing structured outputs and handling JSON schema validation results. - """ - import json - import time - - # Convert choices to Responses API format - choices = [] - output_items: list[dict[str, Any]] = [] - aggregated_output_text: list[str | None] = [] - - def _map_finish_reason_to_status(finish_reason: str | None) -> str: - if finish_reason in (None, "", "stop"): - return "completed" - if finish_reason == "length": - return "incomplete" - if finish_reason in {"tool_calls", "function_call"}: - return "requires_action" - if finish_reason == "content_filter": - return "blocked" - return "completed" - - for choice in response.choices: - if choice.message: - # Try to parse the content as JSON for structured output - parsed_content = None - raw_content = choice.message.content or "" - - # Clean up content for JSON parsing - cleaned_content = raw_content.strip() - - # Handle cases where the model might wrap JSON in markdown code blocks - if cleaned_content.startswith("```json") and cleaned_content.endswith( - "```" - ): - cleaned_content = cleaned_content[7:-3].strip() - elif cleaned_content.startswith("```") and cleaned_content.endswith( - "```" - ): - cleaned_content = cleaned_content[3:-3].strip() - - # Attempt to parse JSON content - if cleaned_content: - try: - parsed_content = json.loads(cleaned_content) - # If parsing succeeded, use the cleaned content as the actual content - raw_content = cleaned_content - except json.JSONDecodeError: - # Content is not valid JSON, leave parsed as None - # Try to extract JSON from the content if it contains other text - try: - # Look for JSON-like patterns in the content - import re - - json_pattern = r"\{.*\}" - json_match = re.search( - json_pattern, cleaned_content, re.DOTALL - ) - if json_match: - potential_json = json_match.group(0) - parsed_content = json.loads(potential_json) - raw_content = potential_json - except (json.JSONDecodeError, AttributeError): - # Still not valid JSON, leave parsed as None - pass - - message_payload: dict[str, Any] = { - "role": choice.message.role, - "content": raw_content or None, - "parsed": parsed_content, - } - - tool_calls_payload: list[dict[str, Any]] = [] - if choice.message.tool_calls: - for tool_call in choice.message.tool_calls: - if hasattr(tool_call, "model_dump"): - tool_data = tool_call.model_dump() - elif isinstance(tool_call, dict): - tool_data = dict(tool_call) - else: - function = getattr(tool_call, "function", None) - tool_data = { - "id": getattr(tool_call, "id", ""), - "type": getattr(tool_call, "type", "function"), - "function": { - "name": getattr(function, "name", ""), - "arguments": getattr(function, "arguments", "{}"), - }, - } - - function_payload = tool_data.get("function") - if isinstance(function_payload, dict): - arguments = function_payload.get("arguments") - if isinstance(arguments, dict | list): - function_payload["arguments"] = json.dumps(arguments) - elif arguments is None: - function_payload["arguments"] = "{}" - - tool_calls_payload.append(tool_data) - - if tool_calls_payload: - message_payload["tool_calls"] = tool_calls_payload - - response_choice = { - "index": choice.index, - "message": message_payload, - "finish_reason": choice.finish_reason or "stop", - } - choices.append(response_choice) - - # Build Responses API output structure (mirrors incoming payloads) - 
text_value = ( - message_payload.get("content") - if isinstance(message_payload.get("content"), str) - else None - ) - if text_value: - aggregated_output_text.append(text_value) - else: - aggregated_output_text.append(None) - - output_content_parts: list[dict[str, Any]] = [] - - if text_value: - output_content_parts.append( - {"type": "output_text", "text": text_value} - ) - - if tool_calls_payload: - for tool_call in tool_calls_payload: # type: ignore[assignment] - tool_call_dict: dict[str, Any] = ( - tool_call if isinstance(tool_call, dict) else {} - ) - output_content_parts.append( - { - "type": "tool_call", - "id": tool_call_dict.get("id", ""), - "function": tool_call_dict.get("function", {}), - } - ) - - output_item = { - "id": f"msg-{response.id}-{choice.index}", - "type": "message", - "role": choice.message.role, - "status": _map_finish_reason_to_status(choice.finish_reason), - "content": output_content_parts, - } - - if choice.finish_reason: - output_item["finish_reason"] = choice.finish_reason - - output_items.append(output_item) - - # Build the Responses API response - responses_response = { - "id": response.id, - "object": "response", - "created": response.created or int(time.time()), - "model": response.model, - "choices": choices, - } - - if output_items: - responses_response["output"] = output_items - - # Only include output_text if we have any textual content - text_values = [text for text in aggregated_output_text if text is not None] - if text_values: - responses_response["output_text"] = [ - text if text is not None else "" for text in aggregated_output_text - ] - - # Add usage information if available - if response.usage: - responses_response["usage"] = response.usage - - # Add system fingerprint if available - if hasattr(response, "system_fingerprint") and response.system_fingerprint: - responses_response["system_fingerprint"] = response.system_fingerprint - - return responses_response - - @staticmethod - def from_domain_to_responses_request( - request: CanonicalChatRequest, - ) -> dict[str, Any]: - """ - Translate a CanonicalChatRequest to an OpenAI Responses API request format. - - This method converts the internal domain request to the OpenAI Responses API format, - extracting the response_format from extra_body and structuring it properly. 
- """ - # Start with basic OpenAI request format - payload = Translation.from_domain_to_openai_request(request) - - if request.extra_body: - extra_body_copy = dict(request.extra_body) - - # Extract and restructure response_format from extra_body - response_format = extra_body_copy.pop("response_format", None) - if response_format is not None: - # Ensure the response_format is properly structured for Responses API - if isinstance(response_format, dict): - payload["response_format"] = response_format - elif hasattr(response_format, "model_dump"): - payload["response_format"] = response_format.model_dump() - else: - payload["response_format"] = response_format - - # Add any remaining extra_body parameters that are safe for Responses API - safe_extra_body = Translation._filter_responses_extra_body(extra_body_copy) - if safe_extra_body: - payload.update(safe_extra_body) - - return payload - - @staticmethod - def _filter_responses_extra_body(extra_body: dict[str, Any]) -> dict[str, Any]: - """Filter extra_body entries to include only Responses API specific parameters.""" - - if not extra_body: - return {} - - allowed_keys: set[str] = {"metadata"} - - return {key: value for key, value in extra_body.items() if key in allowed_keys} - - @staticmethod - def enhance_structured_output_response( - response: ChatResponse, - original_request_extra_body: dict[str, Any] | None = None, - ) -> ChatResponse: - """ - Enhance a ChatResponse with structured output validation and repair. - - This method validates the response against the original JSON schema - and attempts repair if validation fails. - - Args: - response: The original ChatResponse - original_request_extra_body: The extra_body from the original request containing schema info - - Returns: - Enhanced ChatResponse with validated/repaired structured output - """ - if not original_request_extra_body: - return response - - response_format = original_request_extra_body.get("response_format") - if not response_format or response_format.get("type") != "json_schema": - return response - - json_schema_info = response_format.get("json_schema", {}) - schema = json_schema_info.get("schema", {}) - - if not schema: - return response - - # Process each choice - enhanced_choices = [] - for choice in response.choices: - if not choice.message or not choice.message.content: - enhanced_choices.append(choice) - continue - - content = choice.message.content.strip() - - # Try to parse and validate the JSON - try: - parsed_json = json.loads(content) - - # Validate against schema - is_valid, error_msg = Translation.validate_json_against_schema( - parsed_json, schema - ) - - if is_valid: - # Content is valid, keep as is - enhanced_choices.append(choice) - else: - # Try to repair the JSON - repaired_json = Translation._attempt_json_repair( - parsed_json, schema, error_msg - ) - if repaired_json is not None: - # Use repaired JSON - repaired_content = json.dumps(repaired_json, indent=2) - enhanced_message = ChatCompletionChoiceMessage( - role=choice.message.role, - content=repaired_content, - tool_calls=choice.message.tool_calls, - ) - enhanced_choice = ChatCompletionChoice( - index=choice.index, - message=enhanced_message, - finish_reason=choice.finish_reason, - ) - enhanced_choices.append(enhanced_choice) - else: - # Repair failed, keep original - enhanced_choices.append(choice) - - except json.JSONDecodeError: - # Not valid JSON, try to extract and repair - extracted_and_repaired_content: str | None = ( - Translation._extract_and_repair_json(content, schema) - ) - if 
extracted_and_repaired_content is not None: - enhanced_message = ChatCompletionChoiceMessage( - role=choice.message.role, - content=extracted_and_repaired_content, - tool_calls=choice.message.tool_calls, - ) - enhanced_choice = ChatCompletionChoice( - index=choice.index, - message=enhanced_message, - finish_reason=choice.finish_reason, - ) - enhanced_choices.append(enhanced_choice) - else: - # Repair failed, keep original - enhanced_choices.append(choice) - - # Create enhanced response - enhanced_response = CanonicalChatResponse( - id=response.id, - object=response.object, - created=response.created, - model=response.model, - choices=enhanced_choices, - usage=response.usage, - system_fingerprint=getattr(response, "system_fingerprint", None), - ) - - return enhanced_response - - @staticmethod - def _attempt_json_repair( - json_data: dict[str, Any], schema: dict[str, Any], error_msg: str | None - ) -> dict[str, Any] | None: - """ - Attempt to repair JSON data to conform to schema. - - This is a basic repair mechanism that handles common issues. - """ - try: - repaired = dict(json_data) - - # Add missing required properties - if schema.get("type") == "object": - required = schema.get("required", []) - properties = schema.get("properties", {}) - - for prop in required: - if prop not in repaired: - # Add default value based on property type - prop_schema = properties.get(prop, {}) - prop_type = prop_schema.get("type", "string") - - if prop_type == "string": - repaired[prop] = "" - elif prop_type == "number": - repaired[prop] = 0.0 - elif prop_type == "integer": - repaired[prop] = 0 - elif prop_type == "boolean": - repaired[prop] = False - elif prop_type == "array": - repaired[prop] = [] - elif prop_type == "object": - repaired[prop] = {} - else: - repaired[prop] = None - - # Validate the repaired JSON - is_valid, _ = Translation.validate_json_against_schema(repaired, schema) - return repaired if is_valid else None - - except Exception: - return None - - @staticmethod - def _iter_json_candidates( - content: str, - *, - max_candidates: int = 20, - max_object_size: int = 512 * 1024, - ) -> list[str]: - """Find potential JSON object substrings using a linear-time scan.""" - - candidates: list[str] = [] - depth = 0 - start_index: int | None = None - escape_next = False - string_delimiter: str | None = None - - for index, char in enumerate(content): - if string_delimiter is not None: - if escape_next: - escape_next = False - continue - if char == "\\": - escape_next = True - continue - if char == string_delimiter: - string_delimiter = None - continue - - if char in ('"', "'"): - string_delimiter = char - continue - - if char == "{": - if depth == 0: - start_index = index - depth += 1 - elif char == "}": - if depth == 0: - continue - depth -= 1 - if depth == 0 and start_index is not None: - candidate = content[start_index : index + 1] - start_index = None - if len(candidate) > max_object_size: - logger.warning( - "Skipping oversized JSON candidate (%d bytes)", - len(candidate), - ) - continue - candidates.append(candidate) - if len(candidates) >= max_candidates: - break - - return candidates - - @staticmethod - def _extract_and_repair_json(content: str, schema: dict[str, Any]) -> str | None: - """Extract JSON from content and attempt repair.""" - - try: - for candidate in Translation._iter_json_candidates(content): - try: - parsed = json.loads(candidate) - except json.JSONDecodeError: - continue - - if not isinstance(parsed, dict): - continue - - repaired = Translation._attempt_json_repair(parsed, 
schema, None) - if repaired is not None: - return json.dumps(repaired, indent=2) - - return None - except Exception: - return None +from __future__ import annotations + +import json +import logging +import mimetypes +import os +from typing import Any, cast + +from src.core.app.constants.logging_constants import TRACE_LEVEL + +_MAX_SANITIZE_DEPTH = 100 + +from src.core.domain.base_translator import BaseTranslator +from src.core.domain.chat import ( + CanonicalChatRequest, + CanonicalChatResponse, + ChatCompletionChoice, + ChatCompletionChoiceMessage, + ChatMessage, + ChatResponse, + FunctionCall, + ToolCall, +) +from src.core.services.tool_text_renderer import render_tool_call + +logger = logging.getLogger(__name__) + + +class Translation(BaseTranslator): + """ + A class for translating requests and responses between different API formats. + """ + + _codex_tool_call_index_base: dict[str, int] = {} + _codex_tool_call_item_index: dict[str, dict[str, int]] = {} + + @classmethod + def _reset_tool_call_state(cls, response_id: str | None) -> None: + if not response_id: + return + cls._codex_tool_call_index_base.pop(response_id, None) + cls._codex_tool_call_item_index.pop(response_id, None) + + @classmethod + def _assign_tool_call_index( + cls, + response_id: str | None, + output_index: Any, + item_id: str | None, + ) -> int: + if not response_id: + return 0 + + if not isinstance(output_index, int): + if item_id: + return cls._codex_tool_call_item_index.get(response_id, {}).get( + item_id, 0 + ) + return 0 + + base = cls._codex_tool_call_index_base.get(response_id) + if base is None or output_index < base: + cls._codex_tool_call_index_base[response_id] = output_index + base = output_index + + index = output_index - base + if index < 0: + index = 0 + + if item_id: + cls._codex_tool_call_item_index.setdefault(response_id, {})[item_id] = index + + return index + + @staticmethod + def validate_json_against_schema( + json_data: dict[str, Any], schema: dict[str, Any] + ) -> tuple[bool, str | None]: + """ + Validate JSON data against a JSON schema. + + Args: + json_data: The JSON data to validate + schema: The JSON schema to validate against + + Returns: + A tuple of (is_valid, error_message) + """ + try: + import jsonschema + + jsonschema.validate(json_data, schema) + return True, None + except ImportError: + # jsonschema not available, perform basic validation + return Translation._basic_schema_validation(json_data, schema) + except Exception as e: + # Check if this is a jsonschema error, even if the import failed + if "jsonschema" in str(e) and "ValidationError" in str(e): + return False, str(e) + # Fallback for other validation errors + return False, f"Schema validation error: {e!s}" + + @staticmethod + def _basic_schema_validation( + json_data: dict[str, Any], schema: dict[str, Any] + ) -> tuple[bool, str | None]: + """ + Perform basic JSON schema validation without jsonschema library. + + This is a fallback validation that checks basic schema requirements. 
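+        Only the top-level type and any required object properties are checked;
+        nested schemas, enums, and string formats are not validated by this fallback.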
+ """ + try: + # Check type + schema_type = schema.get("type") + if schema_type == "object" and not isinstance(json_data, dict): + return False, f"Expected object, got {type(json_data).__name__}" + elif schema_type == "array" and not isinstance(json_data, list): + return False, f"Expected array, got {type(json_data).__name__}" + elif schema_type == "string" and not isinstance(json_data, str): + return False, f"Expected string, got {type(json_data).__name__}" + elif schema_type == "number" and not isinstance(json_data, int | float): + return False, f"Expected number, got {type(json_data).__name__}" + elif schema_type == "integer" and not isinstance(json_data, int): + return False, f"Expected integer, got {type(json_data).__name__}" + elif schema_type == "boolean" and not isinstance(json_data, bool): + return False, f"Expected boolean, got {type(json_data).__name__}" + + # Check required properties for objects + if schema_type == "object" and isinstance(json_data, dict): + required = schema.get("required", []) + for prop in required: + if prop not in json_data: + return False, f"Missing required property: {prop}" + + return True, None + except Exception as e: + return False, f"Basic validation error: {e!s}" + + @staticmethod + def _detect_image_mime_type(url: str) -> str: + """Detect the MIME type for an image URL or data URI.""" + if url.startswith("data:"): + header = url.split(",", 1)[0] + header = header.split(";", 1)[0] + if ":" in header: + candidate = header.split(":", 1)[1] + if candidate: + return candidate + return "image/jpeg" + + clean_url = url.split("?", 1)[0].split("#", 1)[0] + if "." in clean_url: + extension = clean_url.rsplit(".", 1)[-1].lower() + if extension: + mime_type = mimetypes.types_map.get(f".{extension}") + if mime_type and mime_type.startswith("image/"): + return mime_type + if extension == "jpg": + return "image/jpeg" + return "image/jpeg" + + @staticmethod + def _process_gemini_image_part(part: Any) -> dict[str, Any] | None: + """Convert a multimodal image part to Gemini format.""" + from src.core.domain.chat import MessageContentPartImage + + if not isinstance(part, MessageContentPartImage) or not part.image_url: + return None + + url_str = str(part.image_url.url or "").strip() + if not url_str: + return None + + # Inline data URIs are allowed + if url_str.startswith("data:"): + mime_type = Translation._detect_image_mime_type(url_str) + try: + _, base64_data = url_str.split(",", 1) + except ValueError: + base64_data = "" + return { + "inline_data": { + "mime_type": mime_type, + "data": base64_data, + } + } + + # For non-inline URIs, only allow http/https schemes. Reject file/ftp and local paths. 
+ try: + from urllib.parse import urlparse + + scheme = (urlparse(url_str).scheme or "").lower() + except Exception: + scheme = "" + + allowed_schemes = {"http", "https"} + + if scheme not in allowed_schemes: + # Also treat Windows/local file paths (no scheme or drive-letter scheme) as invalid + return None + + mime_type = Translation._detect_image_mime_type(url_str) + return { + "file_data": { + "mime_type": mime_type, + "file_uri": url_str, + } + } + + @staticmethod + def _normalize_usage_metadata( + usage: dict[str, Any], source_format: str + ) -> dict[str, Any]: + """Normalize usage metadata from different API formats to a standard structure.""" + if source_format == "gemini": + return { + "prompt_tokens": usage.get("promptTokenCount", 0), + "completion_tokens": usage.get("candidatesTokenCount", 0), + "total_tokens": usage.get("totalTokenCount", 0), + } + elif source_format == "anthropic": + return { + "prompt_tokens": usage.get("input_tokens", 0), + "completion_tokens": usage.get("output_tokens", 0), + "total_tokens": usage.get("input_tokens", 0) + + usage.get("output_tokens", 0), + } + elif source_format in {"openai", "openai-responses"}: + prompt_tokens = usage.get("prompt_tokens", usage.get("input_tokens", 0)) + completion_tokens = usage.get( + "completion_tokens", usage.get("output_tokens", 0) + ) + total_tokens = usage.get("total_tokens") + if total_tokens is None: + total_tokens = prompt_tokens + completion_tokens + + return { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + else: + # Default normalization + return { + "prompt_tokens": usage.get("prompt_tokens", 0), + "completion_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + } + + @staticmethod + def _normalize_responses_input_to_messages( + input_payload: Any, + ) -> list[dict[str, Any]]: + """Coerce OpenAI Responses API input payloads into chat messages.""" + + def _normalize_message_entry(entry: Any) -> dict[str, Any] | None: + if entry is None: + return None + + if isinstance(entry, str): + return {"role": "user", "content": entry} + + if isinstance(entry, dict): + raw_role = entry.get("role") + if raw_role is None: + raw_role = "user" + role = str(raw_role) + message: dict[str, Any] = {"role": role} + + content = Translation._normalize_responses_content(entry.get("content")) + if content is not None: + if isinstance(content, list): + message["content_parts"] = content + message["content"] = content + else: + parts = [{"type": "text", "text": content}] + message["content_parts"] = parts + message["content"] = parts + + if "name" in entry and entry.get("name") is not None: + message["name"] = entry["name"] + + if "tool_calls" in entry and entry.get("tool_calls") is not None: + message["tool_calls"] = entry["tool_calls"] + + if "tool_call_id" in entry and entry.get("tool_call_id") is not None: + message["tool_call_id"] = entry["tool_call_id"] + + return message + + # Fallback: convert to string representation + return {"role": "user", "content": str(entry)} + + if input_payload is None: + return [] + + if isinstance(input_payload, str | bytes): + text_value = ( + input_payload.decode("utf-8", "ignore") + if isinstance(input_payload, bytes | bytearray) + else input_payload + ) + return [{"role": "user", "content": text_value}] + + if isinstance(input_payload, dict): + normalized = _normalize_message_entry(input_payload) + return [normalized] if normalized else [] + + if isinstance(input_payload, list | tuple): + 
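+            # Normalize each entry independently; entries that normalize to nothing
+            # are skipped rather than raising.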
messages: list[dict[str, Any]] = [] + for item in input_payload: + normalized = _normalize_message_entry(item) + if normalized: + messages.append(normalized) + return messages + + # Unknown type - coerce to a single user message + return [{"role": "user", "content": str(input_payload)}] + + @staticmethod + def _normalize_responses_content(content: Any) -> Any: + """Normalize Responses API content blocks into chat-compatible structures.""" + + def _coerce_text_value(value: Any) -> str: + if isinstance(value, str): + return value + if isinstance(value, bytes | bytearray): + return value.decode("utf-8", "ignore") + if isinstance(value, list): + segments: list[str] = [] + for segment in value: + if isinstance(segment, dict): + segments.append(_coerce_text_value(segment.get("text"))) + else: + segments.append(str(segment)) + return "".join(segments) + if isinstance(value, dict) and "text" in value: + return _coerce_text_value(value.get("text")) + return str(value) if value is not None else "" + + if content is None: + return None + + if isinstance(content, str | bytes | bytearray): + return _coerce_text_value(content) + + if isinstance(content, dict): + normalized_parts = Translation._normalize_responses_content_part(content) + if not normalized_parts: + return None + if len(normalized_parts) == 1 and normalized_parts[0].get("type") == "text": + return normalized_parts[0]["text"] + return normalized_parts + + if isinstance(content, list | tuple): + collected_parts: list[dict[str, Any]] = [] + for part in content: + if isinstance(part, dict): + collected_parts.extend( + Translation._normalize_responses_content_part(part) + ) + elif isinstance(part, str | bytes | bytearray): + collected_parts.append( + {"type": "text", "text": _coerce_text_value(part)} + ) + if not collected_parts: + return None + if len(collected_parts) == 1 and collected_parts[0].get("type") == "text": + return collected_parts[0]["text"] + return collected_parts + + return str(content) + + @staticmethod + def _normalize_responses_content_part(part: dict[str, Any]) -> list[dict[str, Any]]: + """Normalize a single Responses API content part.""" + + part_type = str(part.get("type") or "").lower() + normalized_parts: list[dict[str, Any]] = [] + + if part_type in {"text", "input_text", "output_text"}: + text_value = part.get("text") + if text_value is None: + text_value = part.get("value") + normalized_parts.append( + {"type": "text", "text": Translation._safe_string(text_value)} + ) + elif "image" in part_type: + image_payload = ( + part.get("image_url") + or part.get("imageUrl") + or part.get("image") + or part.get("image_data") + ) + if isinstance(image_payload, str): + image_payload = {"url": image_payload} + if isinstance(image_payload, dict) and image_payload.get("url"): + normalized_parts.append( + {"type": "image_url", "image_url": image_payload} + ) + elif part_type == "tool_call": + # Tool call parts are handled elsewhere in the pipeline; ignore here. 
+ return [] + else: + # Preserve already-normalized structures (e.g., function calls) as-is + normalized_parts.append(part) + + return [p for p in normalized_parts if p] + + @staticmethod + def _safe_string(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, bytes | bytearray): + return value.decode("utf-8", "ignore") + return str(value) + + @staticmethod + def _map_gemini_finish_reason(finish_reason: str | None) -> str | None: + """Map Gemini finish reasons to canonical values.""" + if finish_reason is None: + return None + + normalized = str(finish_reason).lower() + mapping = { + "stop": "stop", + "max_tokens": "length", + "safety": "content_filter", + "tool_calls": "tool_calls", + } + return mapping.get(normalized, "stop") + + @staticmethod + def _normalize_stop_sequences(stop: Any) -> list[str] | None: + """Normalize stop sequences to a consistent format.""" + if stop is None: + return None + + if isinstance(stop, str): + return [stop] + + if isinstance(stop, list): + # Ensure all elements are strings + return [str(s) for s in stop] + + # Convert other types to string + return [str(stop)] + + @staticmethod + def _normalize_tool_arguments(args: Any) -> str: + """Normalize tool call arguments to a JSON string.""" + if args is None: + return "{}" + + if isinstance(args, str): + stripped = args.strip() + if not stripped: + return "{}" + + # First, try to load it as-is. It might be a valid JSON string. + try: + json.loads(stripped) + return stripped + except json.JSONDecodeError: + # If it fails, it might be a string using single quotes. + # We will try to fix it, but only if it doesn't create an invalid JSON. + pass + + try: + # Attempt to replace single quotes with double quotes for JSON compatibility. + # This is a common issue with LLM-generated JSON in string format. + # However, we must be careful not to corrupt strings that contain single quotes. + fixed_string = stripped.replace("'", '"') + json.loads(fixed_string) + return fixed_string + except (json.JSONDecodeError, TypeError): + # If replacement fails, it's likely not a simple quote issue. + # This can happen if the string contains legitimate single quotes. + # Return empty object instead of _raw format to maintain tool calling contract. 
+ return "{}" + + if isinstance(args, dict): + try: + return json.dumps(args) + except TypeError: + # Handle dicts with non-serializable values + sanitized_dict = Translation._sanitize_dict_for_json(args) + return json.dumps(sanitized_dict) + + if isinstance(args, list | tuple): + try: + # PERFORMANCE OPTIMIZATION: Avoid unnecessary list copying + # Use args directly if it's already a list, only convert tuples + return json.dumps(args if isinstance(args, list) else list(args)) + except TypeError: + # Handle lists with non-serializable items + # PERFORMANCE OPTIMIZATION: Avoid unnecessary list copying + sanitized_list = Translation._sanitize_list_for_json( + args if isinstance(args, list) else list(args) + ) + return json.dumps(sanitized_list) + + # For primitive types that should be JSON serializable + if isinstance(args, int | float | bool): + return json.dumps(args) + + # For non-serializable objects, return empty object instead of _raw format + # This maintains the tool calling contract while preventing failures + return "{}" + + @staticmethod + def _is_json_serializable( + value: Any, + *, + max_depth: int, + _depth: int = 0, + _seen: set[int] | None = None, + ) -> bool: + """Best-effort check to determine if a value can be JSON-serialized.""" + + if _depth > max_depth: + return False + + if value is None or isinstance(value, str | int | float | bool): + return True + + if isinstance(value, list | tuple): + if _seen is None: + _seen = set() + obj_id = id(value) + if obj_id in _seen: + return False + _seen.add(obj_id) + try: + return all( + Translation._is_json_serializable( + item, + max_depth=max_depth, + _depth=_depth + 1, + _seen=_seen, + ) + for item in value + ) + finally: + _seen.remove(obj_id) + + if isinstance(value, dict): + if _seen is None: + _seen = set() + obj_id = id(value) + if obj_id in _seen: + return False + _seen.add(obj_id) + try: + for key, item in value.items(): + if key is not None and not isinstance( + key, str | int | float | bool + ): + return False + if not Translation._is_json_serializable( + item, + max_depth=max_depth, + _depth=_depth + 1, + _seen=_seen, + ): + return False + finally: + _seen.remove(obj_id) + return True + + return False + + @staticmethod + def _sanitize_dict_for_json( + data: dict[str, Any], + *, + max_depth: int = _MAX_SANITIZE_DEPTH, + _depth: int = 0, + _seen: set[int] | None = None, + ) -> dict[str, Any]: + """Sanitize a dictionary by removing or converting non-JSON-serializable values.""" + + if _depth > max_depth: + return {} + + if _seen is None: + _seen = set() + + obj_id = id(data) + if obj_id in _seen: + return {} + + _seen.add(obj_id) + try: + sanitized: dict[str, Any] = {} + sanitized_value: Any = None + for key, value in data.items(): + if key is not None and not isinstance(key, str | int | float | bool): + continue + + if Translation._is_json_serializable( + value, + max_depth=max_depth, + _depth=_depth + 1, + _seen=_seen, + ): + sanitized[key] = value + continue + + if isinstance(value, dict): + sanitized_value = Translation._sanitize_dict_for_json( + value, + max_depth=max_depth, + _depth=_depth + 1, + _seen=_seen, + ) + elif isinstance(value, list | tuple): + sanitized_value = Translation._sanitize_list_for_json( + value if isinstance(value, list) else list(value), + max_depth=max_depth, + _depth=_depth + 1, + _seen=_seen, + ) + elif isinstance(value, str | int | float | bool) or value is None: + sanitized_value = value + else: + continue + + sanitized[key] = sanitized_value + + return sanitized + finally: + 
_seen.remove(obj_id) + + @staticmethod + def _sanitize_list_for_json( + data: list[Any], + *, + max_depth: int = _MAX_SANITIZE_DEPTH, + _depth: int = 0, + _seen: set[int] | None = None, + ) -> list[Any]: + """Sanitize a list by removing or converting non-JSON-serializable items.""" + + if _depth > max_depth: + return [] + + if _seen is None: + _seen = set() + + obj_id = id(data) + if obj_id in _seen: + return [] + + _seen.add(obj_id) + try: + sanitized: list[Any] = [] + for item in data: + if Translation._is_json_serializable( + item, + max_depth=max_depth, + _depth=_depth + 1, + _seen=_seen, + ): + sanitized.append(item) + continue + + if isinstance(item, dict): + sanitized.append( + Translation._sanitize_dict_for_json( + item, + max_depth=max_depth, + _depth=_depth + 1, + _seen=_seen, + ) + ) + elif isinstance(item, list | tuple): + sanitized.append( + Translation._sanitize_list_for_json( + item if isinstance(item, list) else list(item), + max_depth=max_depth, + _depth=_depth + 1, + _seen=_seen, + ) + ) + elif isinstance(item, str | int | float | bool) or item is None: + sanitized.append(item) + else: + continue + + return sanitized + finally: + _seen.remove(obj_id) + + @staticmethod + def _process_gemini_function_call(function_call: dict[str, Any]) -> ToolCall: + """Process a Gemini function call part into a ToolCall.""" + import uuid + + name = function_call.get("name", "") + raw_args = function_call.get("args", function_call.get("arguments")) + normalized_args = Translation._normalize_tool_arguments(raw_args) + + return ToolCall( + id=f"call_{uuid.uuid4().hex[:12]}", + type="function", + function=FunctionCall(name=name, arguments=normalized_args), + ) + + @staticmethod + def gemini_to_domain_request(request: Any) -> CanonicalChatRequest: + """ + Translate a Gemini request to a CanonicalChatRequest. + """ + from src.core.domain.gemini_translation import ( + gemini_request_to_canonical_request, + ) + + return gemini_request_to_canonical_request(request) + + @staticmethod + def anthropic_to_domain_request(request: Any) -> CanonicalChatRequest: + """ + Translate an Anthropic request to a CanonicalChatRequest. 
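+
+        The Anthropic-style top-level system prompt is folded into the message list
+        as a leading system message, and stop_sequences is mapped onto the canonical
+        stop field when stop itself is empty.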
+ """ + # Use the helper method to safely access request parameters + system_prompt = Translation._get_request_param(request, "system") + raw_messages = Translation._get_request_param(request, "messages", []) + normalized_messages: list[Any] = [] + + if system_prompt: + normalized_messages.append({"role": "system", "content": system_prompt}) + + if raw_messages: + for message in raw_messages: + normalized_messages.append(message) + + stop_param = Translation._get_request_param(request, "stop") + stop_sequences = Translation._get_request_param(request, "stop_sequences") + normalized_stop = stop_param + if ( + normalized_stop is None or normalized_stop == [] or normalized_stop == "" + ) and stop_sequences not in (None, [], ""): + normalized_stop = stop_sequences + + return CanonicalChatRequest( + model=Translation._get_request_param(request, "model"), + messages=normalized_messages, + temperature=Translation._get_request_param(request, "temperature"), + top_p=Translation._get_request_param(request, "top_p"), + top_k=Translation._get_request_param(request, "top_k"), + repetition_penalty=Translation._get_request_param( + request, "repetition_penalty" + ), + min_p=Translation._get_request_param(request, "min_p"), + n=Translation._get_request_param(request, "n"), + stream=Translation._get_request_param(request, "stream"), + stop=normalized_stop, + max_tokens=Translation._get_request_param(request, "max_tokens"), + presence_penalty=Translation._get_request_param( + request, "presence_penalty" + ), + frequency_penalty=Translation._get_request_param( + request, "frequency_penalty" + ), + logit_bias=Translation._get_request_param(request, "logit_bias"), + user=Translation._get_request_param(request, "user"), + reasoning_effort=Translation._get_request_param( + request, "reasoning_effort" + ), + seed=Translation._get_request_param(request, "seed"), + tools=Translation._get_request_param(request, "tools"), + tool_choice=Translation._get_request_param(request, "tool_choice"), + extra_body=Translation._get_request_param(request, "extra_body"), + ) + + @staticmethod + def anthropic_to_domain_response(response: Any) -> CanonicalChatResponse: + """ + Translate an Anthropic response to a CanonicalChatResponse. 
+ """ + import time + + if not isinstance(response, dict): + # Handle non-dict responses + return CanonicalChatResponse( + id=f"chatcmpl-anthropic-{int(time.time())}", + object="chat.completion", + created=int(time.time()), + model="unknown", + choices=[ + ChatCompletionChoice( + index=0, + message=ChatCompletionChoiceMessage( + role="assistant", content=str(response) + ), + finish_reason="stop", + ) + ], + usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + ) + + # Extract choices + choices = [] + if "content" in response: + for idx, item in enumerate(response["content"]): + if item.get("type") == "text": + choice = ChatCompletionChoice( + index=idx, + message=ChatCompletionChoiceMessage( + role="assistant", content=item.get("text", "") + ), + finish_reason=response.get("stop_reason", "stop"), + ) + choices.append(choice) + + # Extract usage + usage = response.get("usage", {}) + normalized_usage = Translation._normalize_usage_metadata(usage, "anthropic") + + return CanonicalChatResponse( + id=response.get("id", f"chatcmpl-anthropic-{int(time.time())}"), + object="chat.completion", + created=int(time.time()), + model=response.get("model", "unknown"), + choices=choices, + usage=normalized_usage, + ) + + @staticmethod + def gemini_to_domain_response(response: Any) -> CanonicalChatResponse: + """ + Translate a Gemini response to a CanonicalChatResponse. + """ + import time + import uuid + + # Generate a unique ID for the response + response_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" + created = int(time.time()) + model = "gemini-pro" # Default model if not specified + + # Extract choices from candidates + choices = [] + if isinstance(response, dict) and "candidates" in response: + for idx, candidate in enumerate(response["candidates"]): + content = "" + tool_calls = None + + # Extract content from parts + if "content" in candidate and "parts" in candidate["content"]: + parts = candidate["content"]["parts"] + + # Extract text content + text_parts = [] + for part in parts: + if "text" in part: + text_parts.append(part["text"]) + elif "functionCall" in part: + # Handle function calls (tool calls) + if tool_calls is None: + tool_calls = [] + + function_call = part["functionCall"] + tool_calls.append( + Translation._process_gemini_function_call(function_call) + ) + + content = "".join(text_parts) + + # Map finish reason + finish_reason = Translation._map_gemini_finish_reason( + candidate.get("finishReason") + ) + + # Create choice + choice = ChatCompletionChoice( + index=idx, + message=ChatCompletionChoiceMessage( + role="assistant", + content=content, + tool_calls=tool_calls, + ), + finish_reason=finish_reason, + ) + choices.append(choice) + + # Extract usage metadata + usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + if isinstance(response, dict) and "usageMetadata" in response: + usage_metadata = response["usageMetadata"] + usage = Translation._normalize_usage_metadata(usage_metadata, "gemini") + + # If no choices were extracted, create a default one + if not choices: + choices = [ + ChatCompletionChoice( + index=0, + message=ChatCompletionChoiceMessage(role="assistant", content=""), + finish_reason="stop", + ) + ] + + return CanonicalChatResponse( + id=response_id, + object="chat.completion", + created=created, + model=model, + choices=choices, + usage=usage, + ) + + @staticmethod + def gemini_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: + """ + Translate a Gemini streaming chunk to a canonical dictionary format. 
+ + Args: + chunk: The Gemini streaming chunk. + + Returns: + A dictionary representing the canonical chunk format. + """ + import time + import uuid + + if not isinstance(chunk, dict): + return {"error": "Invalid chunk format: expected a dictionary"} + + response_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" + created = int(time.time()) + model = "gemini-pro" # Default model + + content_pieces: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + finish_reason = None + + if "candidates" in chunk: + for candidate in chunk["candidates"]: + if "content" in candidate and "parts" in candidate["content"]: + for part in candidate["content"]["parts"]: + if "text" in part: + content_pieces.append(part["text"]) + elif "functionCall" in part: + try: + tool_calls.append( + Translation._process_gemini_function_call( + part["functionCall"] + ).model_dump() + ) + except Exception: + continue + if "finishReason" in candidate: + finish_reason = Translation._map_gemini_finish_reason( + candidate["finishReason"] + ) + + delta: dict[str, Any] = {"role": "assistant"} + if content_pieces: + delta["content"] = "".join(content_pieces) + if tool_calls: + delta["tool_calls"] = tool_calls + + return { + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + } + ], + } + + @staticmethod + def openai_to_domain_request(request: Any) -> CanonicalChatRequest: + """ + Translate an OpenAI request to a CanonicalChatRequest. + """ + if isinstance(request, dict): + model = request.get("model") + messages = request.get("messages", []) + top_k = request.get("top_k") + top_p = request.get("top_p") + temperature = request.get("temperature") + max_tokens = request.get("max_tokens") + stop = request.get("stop") + stream = request.get("stream", False) + tools = request.get("tools") + tool_choice = request.get("tool_choice") + seed = request.get("seed") + reasoning_effort = request.get("reasoning_effort") + reasoning_payload = request.get("reasoning") + repetition_penalty = request.get("repetition_penalty") + min_p = request.get("min_p") + else: + model = getattr(request, "model", None) + messages = getattr(request, "messages", []) + top_k = getattr(request, "top_k", None) + top_p = getattr(request, "top_p", None) + temperature = getattr(request, "temperature", None) + max_tokens = getattr(request, "max_tokens", None) + stop = getattr(request, "stop", None) + stream = getattr(request, "stream", False) + tools = getattr(request, "tools", None) + tool_choice = getattr(request, "tool_choice", None) + seed = getattr(request, "seed", None) + reasoning_effort = getattr(request, "reasoning_effort", None) + reasoning_payload = getattr(request, "reasoning", None) + repetition_penalty = getattr(request, "repetition_penalty", None) + min_p = getattr(request, "min_p", None) + + if reasoning_effort in ("", None) and isinstance(reasoning_payload, dict): + raw_effort = reasoning_payload.get("effort") + if isinstance(raw_effort, str) and raw_effort.strip(): + reasoning_effort = raw_effort + + normalized_reasoning: dict[str, Any] | None = None + if reasoning_payload: + if isinstance(reasoning_payload, dict): + normalized_reasoning = dict(reasoning_payload) + elif hasattr(reasoning_payload, "model_dump"): + normalized_reasoning = reasoning_payload.model_dump() # type: ignore[attr-defined] + + if not model: + raise ValueError("Model not found in request") + + # Convert messages to ChatMessage objects if they are dicts + 
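+        # NOTE: assumes dict-form messages match the ChatMessage schema; non-dict
+        # messages are passed through unchanged.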
chat_messages = [] + for msg in messages: + if isinstance(msg, dict): + chat_messages.append(ChatMessage(**msg)) + else: + chat_messages.append(msg) + + return CanonicalChatRequest( + model=model, + messages=chat_messages, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + min_p=min_p, + max_tokens=max_tokens, + stop=stop, + stream=stream, + tools=tools, + tool_choice=tool_choice, + seed=seed, + reasoning_effort=reasoning_effort, + reasoning=normalized_reasoning, + ) + + @staticmethod + def openai_to_domain_response(response: Any) -> CanonicalChatResponse: + """ + Translate an OpenAI response to a CanonicalChatResponse. + """ + import time + + if not isinstance(response, dict): + return CanonicalChatResponse( + id=f"chatcmpl-openai-{int(time.time())}", + object="chat.completion", + created=int(time.time()), + model="unknown", + choices=[ + ChatCompletionChoice( + index=0, + message=ChatCompletionChoiceMessage( + role="assistant", content=str(response) + ), + finish_reason="stop", + ) + ], + usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + ) + + choices: list[ChatCompletionChoice] = [] + for idx, ch in enumerate(response.get("choices", [])): + msg = ch.get("message", {}) + role = msg.get("role", "assistant") + content = msg.get("content") + + # Preserve tool_calls if present + tool_calls = None + raw_tool_calls = msg.get("tool_calls") + if isinstance(raw_tool_calls, list): + # Validate each tool call in the list before including it + validated_tool_calls = [] + for tc in raw_tool_calls: + # Convert dict to ToolCall if necessary + if isinstance(tc, dict): + # Create a ToolCall object from the dict + # Assuming the dict has the necessary structure for ToolCall + # We'll need to import ToolCall if not already available + # For now, we'll use a simple approach + try: + # Create ToolCall from dict, assuming proper structure + tool_call_obj = ToolCall(**tc) + validated_tool_calls.append(tool_call_obj) + except (TypeError, ValueError): + # If conversion fails, skip this tool call + pass + elif isinstance(tc, ToolCall): + validated_tool_calls.append(tc) + else: + # Log or handle invalid tool call + # For now, we'll skip invalid ones + pass + tool_calls = validated_tool_calls if validated_tool_calls else None + + message_obj = ChatCompletionChoiceMessage( + role=role, content=content, tool_calls=tool_calls + ) + + choices.append( + ChatCompletionChoice( + index=idx, + message=message_obj, + finish_reason=ch.get("finish_reason"), + ) + ) + + usage = response.get("usage") or {} + normalized_usage = Translation._normalize_usage_metadata(usage, "openai") + + return CanonicalChatResponse( + id=response.get("id", "chatcmpl-openai-unk"), + object=response.get("object", "chat.completion"), + created=response.get("created", int(__import__("time").time())), + model=response.get("model", "unknown"), + choices=choices, + usage=normalized_usage, + ) + + @staticmethod + def responses_to_domain_response(response: Any) -> CanonicalChatResponse: + """Translate an OpenAI Responses API response to a canonical response.""" + import time + + if not isinstance(response, dict): + return Translation.openai_to_domain_response(response) + + # If the backend already returned OpenAI-style choices, reuse that logic. 
+ if response.get("choices") and not response.get("output"): + return Translation.openai_to_domain_response(response) + + output_items = response.get("output") or [] + choices: list[ChatCompletionChoice] = [] + + for idx, item in enumerate(output_items): + if not isinstance(item, dict): + continue + + role = item.get("role", "assistant") + content_parts = item.get("content") + if not isinstance(content_parts, list): + content_parts = [] + + text_segments: list[str] = [] + tool_calls: list[ToolCall] = [] + + for part in content_parts: + if not isinstance(part, dict): + continue + + part_type = part.get("type") + if part_type in {"output_text", "text", "input_text"}: + text_value = part.get("text") or part.get("value") or "" + if text_value: + text_segments.append(str(text_value)) + elif part_type == "tool_call": + function_payload = ( + part.get("function") or part.get("function_call") or {} + ) + normalized_args = Translation._normalize_tool_arguments( + function_payload.get("arguments") + or function_payload.get("args") + or function_payload.get("arguments_json") + ) + tool_calls.append( + ToolCall( + id=part.get("id") or f"tool_call_{idx}_{len(tool_calls)}", + function=FunctionCall( + name=function_payload.get("name", ""), + arguments=normalized_args, + ), + ) + ) + + content_text = "\n".join( + segment for segment in text_segments if segment + ).strip() + + finish_reason = item.get("finish_reason") or item.get("status") + if finish_reason == "completed": + finish_reason = "stop" + elif finish_reason == "incomplete": + finish_reason = "length" + elif finish_reason in {"in_progress", "generating"}: + finish_reason = None + elif finish_reason is None and (content_text or tool_calls): + finish_reason = "stop" + + message = ChatCompletionChoiceMessage( + role=role, + content=content_text or None, + tool_calls=tool_calls or None, + ) + + choices.append( + ChatCompletionChoice( + index=idx, + message=message, + finish_reason=finish_reason, + ) + ) + + if not choices: + # Fallback to output_text aggregation used by the Responses API when + # the structured output array is empty. This happens when the + # backend only returns plain text without additional metadata. 
+ output_text = response.get("output_text") + fallback_text_segments: list[str] = [] + if isinstance(output_text, list): + fallback_text_segments = [ + str(segment) for segment in output_text if segment + ] + elif isinstance(output_text, str) and output_text: + fallback_text_segments = [output_text] + + if fallback_text_segments: + aggregated_text = "".join(fallback_text_segments) + status = response.get("status") + fallback_finish_reason: str | None + if status == "completed": + fallback_finish_reason = "stop" + elif status == "incomplete": + fallback_finish_reason = "length" + elif status in {"in_progress", "generating"}: + fallback_finish_reason = None + else: + fallback_finish_reason = "stop" if aggregated_text else None + + message = ChatCompletionChoiceMessage( + role="assistant", + content=aggregated_text, + tool_calls=None, + ) + + choices.append( + ChatCompletionChoice( + index=0, + message=message, + finish_reason=fallback_finish_reason, + ) + ) + + if not choices: + # Fallback to OpenAI conversion to avoid returning an empty response + return Translation.openai_to_domain_response(response) + + usage = response.get("usage") or {} + normalized_usage = Translation._normalize_usage_metadata( + usage, "openai-responses" + ) + + return CanonicalChatResponse( + id=response.get("id", f"resp-{int(time.time())}"), + object=response.get("object", "response"), + created=response.get("created", int(time.time())), + model=response.get("model", "unknown"), + choices=choices, + usage=normalized_usage, + system_fingerprint=response.get("system_fingerprint"), + ) + + @staticmethod + def openai_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: + """ + Translate an OpenAI streaming chunk to a canonical dictionary format. + + Args: + chunk: The OpenAI streaming chunk. + + Returns: + A dictionary representing the canonical chunk format. + """ + import json + import time + import uuid + + if isinstance(chunk, bytes | bytearray): + try: + chunk = chunk.decode("utf-8") + except Exception: + return {"error": "Invalid chunk format: unable to decode bytes"} + + if isinstance(chunk, str): + stripped_chunk = chunk.strip() + + if not stripped_chunk: + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "unknown", + "choices": [ + {"index": 0, "delta": {}, "finish_reason": None}, + ], + } + + if stripped_chunk.startswith(":"): + # Comment/heartbeat lines (e.g., ": ping") should be ignored by emitting + # an empty delta so downstream processors keep the stream alive. 
+ return { + "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "unknown", + "choices": [ + {"index": 0, "delta": {}, "finish_reason": None}, + ], + } + + if stripped_chunk.startswith("data:"): + stripped_chunk = stripped_chunk[5:].strip() + + if stripped_chunk == "[DONE]": + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "unknown", + "choices": [ + {"index": 0, "delta": {}, "finish_reason": "stop"}, + ], + } + + try: + chunk = json.loads(stripped_chunk) + except json.JSONDecodeError as exc: + logger.warning( + "Responses stream chunk JSON decode failed: %s", + stripped_chunk[:300], + ) + return { + "error": "Invalid chunk format: expected JSON after 'data:' prefix", + "details": {"message": str(exc)}, + } + + if not isinstance(chunk, dict): + return {"error": "Invalid chunk format: expected a dictionary"} + + # Basic validation for essential keys + if "id" not in chunk or "choices" not in chunk: + if logger.isEnabledFor(TRACE_LEVEL): + try: + logger.log( + TRACE_LEVEL, + "OpenAI stream chunk missing id/choices: %s", + json.dumps(chunk)[:500], + ) + except Exception: + logger.log( + TRACE_LEVEL, + "OpenAI stream chunk missing id/choices (non-serializable)", + ) + return {"error": "Invalid chunk: missing 'id' or 'choices'"} + + # For simplicity, we'll return the chunk as a dictionary. + # In a more complex scenario, you might map this to a Pydantic model. + return dict(chunk) + + @staticmethod + def responses_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: + """Translate an OpenAI Responses streaming chunk to canonical format.""" + import json + import time + import uuid + + def _heartbeat_chunk(finish_reason: str | None = None) -> dict[str, Any]: + """Return a minimal chunk used for comments/heartbeats.""" + return { + "id": f"resp-{uuid.uuid4().hex[:16]}", + "object": "response.chunk", + "created": int(time.time()), + "model": "unknown", + "choices": [ + {"index": 0, "delta": {}, "finish_reason": finish_reason}, + ], + } + + def _extract_text(value: Any) -> str: + if isinstance(value, str): + return value + if isinstance(value, dict): + if "text" in value: + return _extract_text(value["text"]) + if "content" in value: + return _extract_text(value["content"]) + if "value" in value: + return _extract_text(value["value"]) + if isinstance(value, list): + parts = [_extract_text(v) for v in value] + return "".join(part for part in parts if part) + if value is None: + return "" + return str(value) + + if isinstance(chunk, bytes | bytearray): + try: + chunk = chunk.decode("utf-8") + except UnicodeDecodeError: + return { + "error": "Invalid chunk format: unable to decode bytes", + } + + event_type_from_sse: str | None = None + if isinstance(chunk, str): + stripped_chunk = chunk.strip() + + if not stripped_chunk: + return {"error": "Invalid chunk format: empty string"} + + if stripped_chunk.startswith(":"): + # Comment/heartbeat line (e.g., ": ping") + return _heartbeat_chunk() + + data_parts: list[str] = [] + has_data_prefix = False + for raw_line in chunk.splitlines(): + line = raw_line.strip() + if not line: + continue + if line.startswith(":"): + return _heartbeat_chunk() + if line.startswith("event:"): + event_type_from_sse = line[6:].strip() + continue + if line.startswith("data:"): + has_data_prefix = True + payload = line[5:].strip() + if payload.startswith("event:") and not event_type_from_sse: + event_type_from_sse = 
payload[6:].strip() + continue + data_parts.append(payload) + continue + data_parts.append(line) + + stripped_chunk = "\n".join(part for part in data_parts if part).strip() + + if not has_data_prefix: + return _heartbeat_chunk() + + if not stripped_chunk: + return _heartbeat_chunk() + + if stripped_chunk == "[DONE]": + return _heartbeat_chunk(finish_reason="stop") + + try: + chunk = json.loads(stripped_chunk) + except json.JSONDecodeError as exc: + logger.warning( + "Responses stream chunk JSON decode failed: %s", + stripped_chunk[:300], + ) + return { + "error": "Invalid chunk format: expected JSON after 'data:' prefix", + "details": {"message": str(exc)}, + } + + if not isinstance(chunk, dict): + return {"error": "Invalid chunk format: expected a dictionary"} + + response_payload = chunk.get("response") + if isinstance(response_payload, dict): + chunk_id = response_payload.get("id") + created = response_payload.get("created") + model = response_payload.get("model") + else: + chunk_id = None + created = None + model = None + + chunk_id = chunk_id or chunk.get("id") or f"resp-{uuid.uuid4().hex[:16]}" + created = created or chunk.get("created") or int(time.time()) + model = model or chunk.get("model") or "unknown" + object_type = chunk.get("object") or "response.chunk" + index = chunk.get("index", 0) + event_type = ( + (chunk.get("type") or event_type_from_sse or "").strip() + if chunk.get("type") or event_type_from_sse + else "" + ) + + if logger.isEnabledFor(TRACE_LEVEL): + try: + logger.log( + TRACE_LEVEL, + "Responses event type=%s payload=%s", + event_type or "", + json.dumps(chunk)[:400], + ) + except Exception: + logger.log( + TRACE_LEVEL, + "Responses event type=%s payload=", + event_type or "", + ) + + def _build_chunk( + delta: dict[str, Any] | None = None, + finish_reason: str | None = None, + ) -> dict[str, Any]: + return { + "id": chunk_id, + "object": object_type, + "created": created, + "model": model, + "choices": [ + { + "index": index, + "delta": delta or {}, + "finish_reason": finish_reason, + } + ], + } + + if event_type == "response.output_text.delta": + delta_payload = chunk.get("delta") + text = _extract_text(delta_payload) + if not text: + return _build_chunk() + delta_map: dict[str, Any] = {"content": text} + if isinstance(delta_payload, dict): + role = delta_payload.get("role") + if role: + delta_map["role"] = role + delta_map.setdefault("role", "assistant") + return _build_chunk(delta_map) + + if event_type == "response.reasoning_summary_text.delta": + summary_text = _extract_text(chunk.get("delta")) + return _build_chunk({"reasoning_summary": summary_text}) + + if event_type == "response.reasoning_text.delta": + reasoning_text = _extract_text(chunk.get("delta")) + return _build_chunk({"reasoning_content": reasoning_text}) + + if event_type == "response.function_call_arguments.delta": + call_id = chunk.get("item_id") or chunk.get("call_id") + name = chunk.get("name") or "" + delta_payload = chunk.get("delta") or {} + if isinstance(delta_payload, str): + arguments_fragment = delta_payload + else: + arguments_fragment = _extract_text(delta_payload) + if not isinstance(arguments_fragment, str): + arguments_fragment = json.dumps(delta_payload) + if arguments_fragment is None: + arguments_fragment = "" + tool_index = Translation._assign_tool_call_index( + chunk_id, chunk.get("output_index"), call_id + ) + function_payload: dict[str, Any] = { + "arguments": arguments_fragment, + } + if name: + function_payload["name"] = name + delta = { + "tool_calls": [ + { + "index": 
tool_index, + "id": call_id or "", + "type": "function", + "function": function_payload, + } + ] + } + return _build_chunk(delta) + + if event_type == "response.function_call_arguments.done": + call_id = chunk.get("item_id") or chunk.get("call_id") + name = chunk.get("name") or "" + arguments = chunk.get("arguments") + if isinstance(arguments, dict | list): + arguments = json.dumps(arguments) + elif arguments is None: + arguments = "{}" + else: + arguments = str(arguments) + tool_index = Translation._assign_tool_call_index( + chunk_id, chunk.get("output_index"), call_id + ) + tool_call_obj = ToolCall( + id=call_id or "", + type="function", + function=FunctionCall(name=name, arguments=arguments), + ) + tool_text = render_tool_call(tool_call_obj) + delta = { + "tool_calls": [ + { + "index": tool_index, + "id": call_id or "", + "type": "function", + "function": { + "name": name, + "arguments": arguments, + }, + } + ] + } + if tool_text: + delta["_tool_call_text"] = tool_text # type: ignore[assignment] + return _build_chunk(delta, "tool_calls") + + if event_type == "response.output_item.done": + item = chunk.get("item") or {} + item_type = item.get("type") + + if logger.isEnabledFor(TRACE_LEVEL): + try: + logger.log( + TRACE_LEVEL, + "Responses output_item.done item=%s", + json.dumps(item)[:400], + ) + except Exception: + logger.log( + TRACE_LEVEL, + "Responses output_item.done item=", + ) + + if item_type == "message": + text = _extract_text(item.get("content", [])) + message_delta: dict[str, Any] = {"content": text} if text else {} + role = item.get("role") + if role: + message_delta["role"] = role + return _build_chunk(message_delta or None) + + if item_type == "function_call": + arguments = item.get("arguments", "{}") + if not isinstance(arguments, str): + arguments = json.dumps(arguments) + call_id = ( + item.get("call_id") + or item.get("id") + or f"call_{uuid.uuid4().hex[:8]}" + ) + tool_index = Translation._assign_tool_call_index( + chunk_id, chunk.get("output_index"), call_id + ) + tool_call_obj = ToolCall( + id=call_id, + type="function", + function=FunctionCall( + name=item.get("name", ""), arguments=arguments + ), + ) + tool_text = render_tool_call(tool_call_obj) + delta = { + "tool_calls": [ + { + "id": call_id, + "index": tool_index, + "type": "function", + "function": { + "name": item.get("name", ""), + "arguments": arguments, + }, + } + ] + } + if tool_text: + delta["_tool_call_text"] = tool_text # type: ignore[assignment] + return _build_chunk(delta, "tool_calls") + + if item_type == "custom_tool_call": + input_payload = item.get("input", "") + if not isinstance(input_payload, str): + input_payload = json.dumps(input_payload) + call_id = ( + item.get("call_id") + or item.get("id") + or f"custom_{uuid.uuid4().hex[:8]}" + ) + tool_index = Translation._assign_tool_call_index( + chunk_id, chunk.get("output_index"), call_id + ) + tool_call_obj = ToolCall( + id=call_id, + type="function", + function=FunctionCall( + name=item.get("name", ""), arguments=input_payload + ), + ) + tool_text = render_tool_call(tool_call_obj) + delta = { + "tool_calls": [ + { + "id": call_id, + "index": tool_index, + "type": "function", + "function": { + "name": item.get("name", ""), + "arguments": input_payload or "{}", + }, + } + ] + } + if tool_text: + delta["_tool_call_text"] = tool_text # type: ignore[assignment] + + return _build_chunk(delta) + + if item_type == "local_shell_call": + action = item.get("action") or {} + arguments = action if isinstance(action, str) else json.dumps(action) + call_id = 
( + item.get("call_id") + or item.get("id") + or f"shell_{uuid.uuid4().hex[:8]}" + ) + tool_index = Translation._assign_tool_call_index( + chunk_id, chunk.get("output_index"), call_id + ) + tool_call_obj = ToolCall( + id=call_id, + type="function", + function=FunctionCall(name="shell", arguments=arguments), + ) + tool_text = render_tool_call(tool_call_obj) + delta = { + "tool_calls": [ + { + "id": call_id, + "index": tool_index, + "type": "function", + "function": { + "name": "shell", + "arguments": arguments, + }, + } + ] + } + if tool_text: + delta["_tool_call_text"] = tool_text # type: ignore[assignment] + + return _build_chunk(delta) + + return _build_chunk() + + if event_type == "response.completed": + response_info = chunk.get("response") or {} + result = _build_chunk({}, "stop") + usage = response_info.get("usage") + if usage: + result["usage"] = usage + response_id = response_info.get("id") or chunk_id + if response_id: + result["response_id"] = response_id + Translation._reset_tool_call_state(response_id) + return result + + if event_type == "response.created": + response_info = chunk.get("response") or {} + response_id = response_info.get("id") or chunk_id + if response_id: + Translation._reset_tool_call_state(response_id) + created_delta: dict[str, Any] = {} + if response_id: + created_delta["response_id"] = response_id + created_delta["role"] = "assistant" + return _build_chunk(created_delta or None) + + if event_type == "response.failed": + response_info = chunk.get("response") or {} + error_payload = response_info.get("error") or chunk.get("error") or {} + Translation._reset_tool_call_state(response_info.get("id") or chunk_id) + return { + "error": "Responses stream reported failure", + "details": error_payload, + } + + if event_type in { + "response.output_text.done", + "response.output_item.added", + "response.custom_tool_call_input.done", + "response.custom_tool_call_input.delta", + "response.function_call_arguments.delta", + "response.in_progress", + "response.content_part.done", + }: + return _build_chunk() + + if "choices" in chunk: + choices = chunk.get("choices") or [] + if not isinstance(choices, list) or not choices: + return _build_chunk() + + primary_choice = choices[0] or {} + finish_reason = primary_choice.get("finish_reason") + raw_delta = primary_choice.get("delta") or {} + if isinstance(raw_delta, dict): + delta = cast(dict[str, Any], dict(raw_delta)) + else: + delta = {"content": cast(Any, str(raw_delta))} + + content_value = delta.get("content") + if isinstance(content_value, list): + text_parts: list[str] = [] + for part in content_value: + if not isinstance(part, dict): + continue + part_type = part.get("type") + if part_type in {"output_text", "text", "input_text"}: + text_value = part.get("text") or part.get("value") or "" + if text_value: + text_parts.append(str(text_value)) + delta["content"] = cast(Any, "".join(text_parts)) + elif isinstance(content_value, dict): + delta["content"] = cast(Any, json.dumps(content_value)) + elif content_value is None: + delta.pop("content", None) + else: + delta["content"] = cast(Any, str(content_value)) + + tool_calls = delta.get("tool_calls") + if isinstance(tool_calls, list): + normalized_tool_calls: list[dict[str, Any]] = [] + for tool_call in tool_calls: + if isinstance(tool_call, dict): + call_data = dict(tool_call) + else: + function = getattr(tool_call, "function", None) + call_data = { + "id": getattr(tool_call, "id", ""), + "type": getattr(tool_call, "type", "function"), + "function": { + "name": 
getattr(function, "name", ""), + "arguments": getattr(function, "arguments", "{}"), + }, + } + + function_payload = call_data.get("function") or {} + if isinstance(function_payload, dict): + arguments = function_payload.get("arguments") + if isinstance(arguments, dict | list): + function_payload["arguments"] = json.dumps(arguments) + elif arguments is None: + function_payload["arguments"] = "{}" + else: + function_payload["arguments"] = str(arguments) + + normalized_tool_calls.append(call_data) + + if normalized_tool_calls: + delta["tool_calls"] = normalized_tool_calls + else: + delta.pop("tool_calls", None) + + return _build_chunk(delta, finish_reason) + + # Default: emit an empty chunk to keep the stream progressing. + return _build_chunk() + + @staticmethod + def openrouter_to_domain_request(request: Any) -> CanonicalChatRequest: + """ + Translate an OpenRouter request to a CanonicalChatRequest. + """ + if isinstance(request, dict): + model = request.get("model") + messages = request.get("messages", []) + top_k = request.get("top_k") + top_p = request.get("top_p") + temperature = request.get("temperature") + max_tokens = request.get("max_tokens") + stop = request.get("stop") + seed = request.get("seed") + reasoning_effort = request.get("reasoning_effort") + repetition_penalty = request.get("repetition_penalty") + min_p = request.get("min_p") + extra_params = request.get("extra_params") + else: + model = getattr(request, "model", None) + messages = getattr(request, "messages", []) + top_k = getattr(request, "top_k", None) + top_p = getattr(request, "top_p", None) + temperature = getattr(request, "temperature", None) + max_tokens = getattr(request, "max_tokens", None) + stop = getattr(request, "stop", None) + seed = getattr(request, "seed", None) + reasoning_effort = getattr(request, "reasoning_effort", None) + repetition_penalty = getattr(request, "repetition_penalty", None) + min_p = getattr(request, "min_p", None) + extra_params = getattr(request, "extra_params", None) + + if not model: + raise ValueError("Model not found in request") + + # Convert messages to ChatMessage objects if they are dicts + chat_messages = [] + for msg in messages: + if isinstance(msg, dict): + chat_messages.append(ChatMessage(**msg)) + else: + chat_messages.append(msg) + + return CanonicalChatRequest( + model=model, + messages=chat_messages, + top_k=top_k, + top_p=top_p, + temperature=temperature, + max_tokens=max_tokens, + stop=stop, + seed=seed, + reasoning_effort=reasoning_effort, + repetition_penalty=repetition_penalty, + min_p=min_p, + stream=( + request.get("stream") + if isinstance(request, dict) + else getattr(request, "stream", None) + ), + extra_body=( + request.get("extra_body") + if isinstance(request, dict) + else getattr(request, "extra_body", None) + ) + or (extra_params if extra_params is not None else None), + tools=( + request.get("tools") + if isinstance(request, dict) + else getattr(request, "tools", None) + ), + tool_choice=( + request.get("tool_choice") + if isinstance(request, dict) + else getattr(request, "tool_choice", None) + ), + ) + + @staticmethod + def _validate_request_parameters(request: CanonicalChatRequest) -> None: + """Validate required parameters in a domain request.""" + if not request.model: + raise ValueError("Model is required") + + if not request.messages: + raise ValueError("Messages are required") + + # Validate message structure + for message in request.messages: + if not message.role: + raise ValueError("Message role is required") + + # Allow assistant messages 
that carry only tool_calls (no textual content) + if message.role != "system": + has_text = bool(message.content) + has_tool_calls = bool(getattr(message, "tool_calls", None)) + if not has_text and not ( + message.role == "assistant" and has_tool_calls + ): + raise ValueError(f"Content is required for {message.role} messages") + + # Validate tool parameters if present + if request.tools: + for tool in request.tools: + if isinstance(tool, dict): + if "function" not in tool: + raise ValueError("Tool must have a function") + if "name" not in tool.get("function", {}): + raise ValueError("Tool function must have a name") + + @staticmethod + def from_domain_to_gemini_request(request: CanonicalChatRequest) -> dict[str, Any]: + """ + Translate a CanonicalChatRequest to a Gemini request. + """ + + Translation._validate_request_parameters(request) + + config: dict[str, Any] = {} + if request.top_k is not None: + config["topK"] = request.top_k + if request.top_p is not None: + config["topP"] = request.top_p + if request.temperature is not None: + config["temperature"] = request.temperature + if request.max_tokens is not None: + config["maxOutputTokens"] = request.max_tokens + if request.stop: + config["stopSequences"] = Translation._normalize_stop_sequences( + request.stop + ) + + # Handle thinking budget overrides and reasoning effort mapping. + def _resolve_thinking_budget(reasoning_effort: str | None) -> int | None: + """Resolve thinking budget from CLI override or reasoning effort.""" + cli_value = os.environ.get("THINKING_BUDGET") + if cli_value is not None: + try: + return int(cli_value) + except ValueError: + return None + + if reasoning_effort is None: + return None + + effort_to_budget: dict[str, int] = { + "low": 512, + "medium": 2048, + "high": -1, + } + + return effort_to_budget.get(reasoning_effort.lower(), None) + + thinking_budget = _resolve_thinking_budget(request.reasoning_effort) + if thinking_budget is not None: + config["thinkingConfig"] = { + "thinkingBudget": thinking_budget, + "includeThoughts": True, + } + + # Process messages with proper handling of multimodal content and tool calls + contents: list[dict[str, Any]] = [] + # Track tool_call id -> function name to map tool responses + tool_name_by_id: dict[str, str] = {} + + # Group consecutive tool messages together to match Gemini's requirement + # that all functionResponse parts must be in a single user message + i = 0 + while i < len(request.messages): + message = request.messages[i] + + # Map assistant role to 'model' for Gemini compatibility; keep others as-is + if message.role == "assistant": + gemini_role = "model" + elif message.role == "tool": + # Gemini expects function responses from the "user" role + gemini_role = "user" + else: + gemini_role = message.role + msg_dict: dict[str, Any] = {"role": gemini_role} + parts: list[dict[str, Any]] = [] + + # Add assistant tool calls as functionCall parts + has_tool_calls = message.role == "assistant" and getattr( + message, "tool_calls", None + ) + if has_tool_calls: + try: + for tc in message.tool_calls or []: + tc_dict = tc if isinstance(tc, dict) else tc.model_dump() + fn = (tc_dict.get("function") or {}).get("name", "") + args_raw = (tc_dict.get("function") or {}).get("arguments", "") + # Remember mapping for subsequent tool responses + if "id" in tc_dict: + tool_name_by_id[tc_dict["id"]] = fn + # Parse arguments as JSON when possible + import json as _json + + try: + args_val = ( + _json.loads(args_raw) + if isinstance(args_raw, str) + else args_raw + ) + except 
Exception: + args_val = args_raw + parts.append({"functionCall": {"name": fn, "args": args_val}}) + except Exception: + # Best-effort; continue even if a tool call cannot be parsed + pass + + # Handle content which could be string, list of parts, or None + # IMPORTANT: Gemini API requires that if a message has functionCall parts, + # it should NOT have text content in the same message. This prevents + # "number of function response parts not equal to function call parts" errors. + if not has_tool_calls: + if isinstance(message.content, str): + # Simple text content + parts.append({"text": message.content}) + elif isinstance(message.content, list): + # Multimodal content (list of parts) + for part in message.content: + if hasattr(part, "type") and part.type == "image_url": + processed_image = Translation._process_gemini_image_part( + part + ) + if processed_image: + parts.append(processed_image) + elif hasattr(part, "type") and part.type == "text": + from src.core.domain.chat import MessageContentPartText + + # Handle text part + if isinstance(part, MessageContentPartText) and hasattr( + part, "text" + ): + parts.append({"text": part.text}) + else: + # Try best effort conversion + if hasattr(part, "model_dump"): + part_dict = part.model_dump() + if "text" in part_dict: + parts.append({"text": part_dict["text"]}) + + # Map tool role messages to functionResponse parts + # Group all consecutive tool messages into a single user message + if message.role == "tool": + # Collect all consecutive tool messages + tool_messages = [message] + j = i + 1 + while j < len(request.messages) and request.messages[j].role == "tool": + tool_messages.append(request.messages[j]) + j += 1 + + # Process all tool messages into functionResponse parts + for tool_msg in tool_messages: + # Try to map tool_call_id back to the function name + name = tool_name_by_id.get( + getattr(tool_msg, "tool_call_id", ""), "" + ) + resp_obj: dict[str, Any] + val = tool_msg.content + # Try to parse JSON result if provided + if isinstance(val, str): + import json as _json + + try: + resp_obj = _json.loads(val) + except Exception: + resp_obj = {"text": val} + elif isinstance(val, dict): + resp_obj = val + else: + resp_obj = {"text": str(val)} + + parts.append( + {"functionResponse": {"name": name, "response": resp_obj}} + ) + + # Skip the tool messages we just processed + i = j - 1 + + # Add parts to message + msg_dict["parts"] = parts # type: ignore + + # Only add non-empty messages + if parts: + contents.append(msg_dict) + + i += 1 + + result = {"contents": contents, "generationConfig": config} + + # Add tools if present + if request.tools: + # Gemini Code Assist only allows multiple tools when they are all + # search tools. For function calling, we must group ALL functions + # into a SINGLE tool entry with a combined function_declarations list. 
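+        # Illustrative shape of the grouped result (tool names are examples only):
+        #   [{"function_declarations": [
+        #       {"name": "get_weather", "description": "...", "parameters": {...}},
+        #       {"name": "get_time", "description": "...", "parameters": {...}}]}]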
+ function_declarations: list[dict[str, Any]] = [] + + for tool in request.tools: + # Accept dict-like or model-like entries + tool_dict: dict[str, Any] + if isinstance(tool, dict): + tool_dict = tool + else: + try: + tool_dict = tool.model_dump() # type: ignore[attr-defined] + except Exception: + tool_dict = {} + function = ( + tool_dict.get("function") if isinstance(tool_dict, dict) else None + ) + if not function: + # Skip non-function tools for now (unsupported/mixed types) + continue + + params = Translation._sanitize_gemini_parameters( + function.get("parameters", {}) + ) + function_declarations.append( + { + "name": function.get("name", ""), + "description": function.get("description", ""), + "parameters": params, + } + ) + + if function_declarations: + result["tools"] = [{"function_declarations": function_declarations}] + + # Handle tool_choice for Gemini + if request.tool_choice: + mode = "AUTO" # Default + allowed_functions = None + + if isinstance(request.tool_choice, str): + if request.tool_choice == "none": + mode = "NONE" + elif request.tool_choice == "auto": + mode = "AUTO" + elif request.tool_choice in ["any", "required"]: + mode = "ANY" + elif ( + isinstance(request.tool_choice, dict) + and request.tool_choice.get("type") == "function" + ): + function_spec = request.tool_choice.get("function", {}) + function_name = function_spec.get("name") + if function_name: + mode = "ANY" + allowed_functions = [function_name] + + fcc: dict[str, Any] = {"mode": mode} + if allowed_functions: + fcc["allowedFunctionNames"] = allowed_functions + result["toolConfig"] = {"functionCallingConfig": fcc} + + # Handle structured output for Responses API + if request.extra_body and "response_format" in request.extra_body: + response_format = request.extra_body["response_format"] + if response_format.get("type") == "json_schema": + json_schema = response_format.get("json_schema", {}) + schema = json_schema.get("schema", {}) + + # For Gemini, add JSON mode and schema constraint to generation config + generation_config = result["generationConfig"] + if isinstance(generation_config, dict): + generation_config["responseMimeType"] = "application/json" + generation_config["responseSchema"] = schema + + # Add schema name and description as additional context if available + schema_name = json_schema.get("name") + schema_description = json_schema.get("description") + if schema_name or schema_description: + # Add schema context to the last user message or create a system-like instruction + schema_context = "Generate a JSON response" + if schema_name: + schema_context += f" for '{schema_name}'" + if schema_description: + schema_context += f": {schema_description}" + schema_context += ( + ". The response must conform to the provided JSON schema." + ) + + # Add this as context to help the model understand the structured output requirement + if ( + contents + and isinstance(contents[-1], dict) + and contents[-1].get("role") == "user" + ): + # Append to the last user message + last_message = contents[-1] + if ( + isinstance(last_message, dict) + and last_message.get("parts") + and isinstance(last_message["parts"], list) + ): + last_message["parts"].append( + {"text": f"\n\n{schema_context}"} + ) + else: + # Add as a new user message + contents.append( + {"role": "user", "parts": [{"text": schema_context}]} + ) + + return result + + @staticmethod + def _sanitize_gemini_parameters(schema: dict[str, Any]) -> dict[str, Any]: + """Sanitize OpenAI tool JSON schema for Gemini Code Assist function_declarations. 
+ + The Code Assist API rejects certain JSON Schema keywords (e.g., "$schema", + and sometimes draft-specific fields like "exclusiveMinimum"). This method + removes unsupported keywords while preserving the core shape (type, + properties, required, items, enum, etc.). + + Args: + schema: Original JSON schema dict from OpenAI tool definition + + Returns: + A sanitized schema dict suitable for Gemini Code Assist. + """ + if not isinstance(schema, dict): + return {} + + blacklist = { + "$schema", + "$id", + "$comment", + "exclusiveMinimum", + "exclusiveMaximum", + } + + def _clean(obj: Any) -> Any: + if isinstance(obj, dict): + cleaned: dict[str, Any] = {} + for k, v in obj.items(): + if k in blacklist: + continue + cleaned[k] = _clean(v) + return cleaned + if isinstance(obj, list): + return [_clean(x) for x in obj] + return obj + + cleaned = _clean(schema) + return cleaned if isinstance(cleaned, dict) else {} + + @staticmethod + def from_domain_to_openai_request(request: CanonicalChatRequest) -> dict[str, Any]: + """ + Translate a CanonicalChatRequest to an OpenAI request. + """ + messages_payload: list[dict[str, Any]] = [] + for message in request.messages: + if hasattr(message, "to_dict"): + message_dict = message.to_dict() + # Preserve explicit `content: None` semantics expected by the + # OpenAI API when tool_calls are present. + if "content" not in message_dict: + message_dict["content"] = None + else: + message_dict = { + "role": getattr(message, "role", "assistant"), + "content": getattr(message, "content", None), + } + tool_calls = getattr(message, "tool_calls", None) + if tool_calls is not None: + message_dict["tool_calls"] = tool_calls + tool_call_id = getattr(message, "tool_call_id", None) + if tool_call_id is not None: + message_dict["tool_call_id"] = tool_call_id + + messages_payload.append(message_dict) + + payload: dict[str, Any] = { + "model": request.model, + "messages": messages_payload, + } + + # Add all supported parameters + if request.top_p is not None: + payload["top_p"] = request.top_p + if request.temperature is not None: + payload["temperature"] = request.temperature + if request.max_tokens is not None: + payload["max_tokens"] = request.max_tokens + if request.stream is not None: + payload["stream"] = request.stream + if request.stop is not None: + payload["stop"] = Translation._normalize_stop_sequences(request.stop) + if request.seed is not None: + payload["seed"] = request.seed + if request.presence_penalty is not None: + payload["presence_penalty"] = request.presence_penalty + if request.frequency_penalty is not None: + payload["frequency_penalty"] = request.frequency_penalty + if request.repetition_penalty is not None: + payload["repetition_penalty"] = request.repetition_penalty + if request.min_p is not None: + payload["min_p"] = request.min_p + if request.user is not None: + payload["user"] = request.user + if request.tools is not None: + payload["tools"] = request.tools + if request.tool_choice is not None: + payload["tool_choice"] = request.tool_choice + + # Handle OpenAI reasoning configuration + reasoning_payload: dict[str, Any] | None = None + if request.reasoning is not None: + if isinstance(request.reasoning, dict): + reasoning_payload = dict(request.reasoning) + elif hasattr(request.reasoning, "model_dump"): + reasoning_payload = request.reasoning.model_dump() # type: ignore[attr-defined] + + effort_value = request.reasoning_effort + normalized_effort: str | None + if isinstance(effort_value, str): + normalized_effort = effort_value.strip() + else: + 
normalized_effort = str(effort_value) if effort_value is not None else None + + if normalized_effort: + if reasoning_payload is None: + reasoning_payload = {} + if "effort" not in reasoning_payload: + reasoning_payload["effort"] = effort_value + + if reasoning_payload: + payload["reasoning"] = reasoning_payload + + # Handle structured output for Responses API + if request.extra_body and "response_format" in request.extra_body: + response_format = request.extra_body["response_format"] + if response_format.get("type") == "json_schema": + # For OpenAI, we can pass the response_format directly + payload["response_format"] = response_format + + return payload + + @staticmethod + def anthropic_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: + """ + Translate an Anthropic streaming chunk to a canonical dictionary format. + + Args: + chunk: The Anthropic streaming chunk (can be SSE string or dict). + + Returns: + A dictionary representing the canonical chunk format. + """ + import json + import time + import uuid + + # Handle SSE-formatted strings + if isinstance(chunk, str): + # Parse SSE format - handle multi-line SSE events + chunk = chunk.strip() + + # Handle [DONE] marker + if "data: [DONE]" in chunk or chunk == "[DONE]": + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "claude-3-opus-20240229", + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": None, + } + ], + } + + # Extract data line from multi-line SSE chunk + data_line = None + for line in chunk.split("\n"): + line = line.strip() + if line.startswith("data:"): + data_line = line[5:].strip() + break + + # If no data line found, check if entire chunk is just event/id lines + if data_line is None: + if chunk.startswith(("event:", "id:")) or not chunk: + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "claude-3-opus-20240229", + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": None, + } + ], + } + # Try to use the whole chunk as JSON + data_line = chunk + + # Try to parse as JSON + try: + chunk = json.loads(data_line) + except json.JSONDecodeError: + return {"error": "Invalid chunk format: expected a dictionary"} + + if not isinstance(chunk, dict): + return {"error": "Invalid chunk format: expected a dictionary"} + + response_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" + created = int(time.time()) + model = "claude-3-opus-20240229" # Default model + + content = "" + finish_reason = None + role = None + + # Handle different Anthropic event types + event_type = chunk.get("type") + + if event_type == "message_start": + # Message start event - set role + role = "assistant" + elif event_type == "content_block_start": + # Content block start - no content yet + pass + elif event_type == "content_block_delta": + # Content delta - extract text + delta = chunk.get("delta", {}) + if delta.get("type") == "text_delta": + content = delta.get("text", "") + elif event_type == "content_block_stop": + # Content block stop - no action needed + pass + elif event_type == "message_delta": + # Message delta - check for finish reason + delta = chunk.get("delta", {}) + stop_reason = delta.get("stop_reason") + if stop_reason == "end_turn": + finish_reason = "stop" + elif stop_reason == "max_tokens": + finish_reason = "length" + elif stop_reason == "tool_use": + finish_reason = "tool_calls" + elif event_type == "message_stop": + # Message stop - mark as complete + 
finish_reason = "stop" + + # Build delta + output_delta: dict[str, Any] = {} + if role: + output_delta["role"] = role + if content: + output_delta["content"] = content + + return { + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "delta": output_delta, + "finish_reason": finish_reason, + } + ], + } + + @staticmethod + def from_domain_to_anthropic_request( + request: CanonicalChatRequest, + ) -> dict[str, Any]: + """ + Translate a CanonicalChatRequest to an Anthropic request. + """ + # Process messages with proper handling of system messages and multimodal content + processed_messages = [] + system_message = None + + for message in request.messages: + if message.role == "system": + # Extract system message + system_message = message.content + continue + + # Process regular messages + msg_dict = {"role": message.role} + + # Handle content which could be string, list of parts, or None + if message.content is None: + # Skip empty content + continue + elif isinstance(message.content, str): + # Simple text content + msg_dict["content"] = message.content + elif isinstance(message.content, list): + # Multimodal content (list of parts) + content_parts = [] + for part in message.content: + from src.core.domain.chat import ( + MessageContentPartImage, + MessageContentPartText, + ) + + if isinstance(part, MessageContentPartImage): + # Handle image part + if part.image_url: + url_str = str(part.image_url.url) + # Only include data URLs; skip http/https URLs + if url_str.startswith("data:"): + content_parts.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": url_str.split(",", 1)[-1], + }, + } + ) + elif isinstance(part, MessageContentPartText): + # Handle text part + content_parts.append({"type": "text", "text": part.text}) + else: + # Try best effort conversion + if hasattr(part, "model_dump"): + part_dict = part.model_dump() + if "text" in part_dict: + content_parts.append( + {"type": "text", "text": part_dict["text"]} + ) + + if content_parts: + # Use type annotation to help mypy + msg_dict["content"] = content_parts # type: ignore + + # Handle tool calls if present + if message.tool_calls: + tool_calls = [] + for tool_call in message.tool_calls: + if hasattr(tool_call, "model_dump"): + tool_call_dict = tool_call.model_dump() + tool_calls.append(tool_call_dict) + elif isinstance(tool_call, dict): + tool_calls.append(tool_call) + else: + # Convert to dict if possible + try: + tool_call_dict = dict(tool_call) + tool_calls.append(tool_call_dict) + except (TypeError, ValueError): + # Skip if can't convert + continue + + if tool_calls: + # Use type annotation to help mypy + msg_dict["tool_calls"] = tool_calls # type: ignore + + # Handle tool call ID if present + if message.tool_call_id: + msg_dict["tool_call_id"] = message.tool_call_id + + # Handle name if present + if message.name: + msg_dict["name"] = message.name + + processed_messages.append(msg_dict) + + payload: dict[str, Any] = { + "model": request.model, + "messages": processed_messages, + "max_tokens": request.max_tokens or 1024, + "stream": request.stream, + } + + if system_message: + payload["system"] = system_message + if request.temperature is not None: + payload["temperature"] = request.temperature + if request.top_p is not None: + payload["top_p"] = request.top_p + if request.top_k is not None: + payload["top_k"] = request.top_k + + # Handle tools if present + if request.tools: + # Convert tools to 
Anthropic format + anthropic_tools = [] + for tool in request.tools: + if isinstance(tool, dict) and "function" in tool: + anthropic_tool = {"type": "function", "function": tool["function"]} + anthropic_tools.append(anthropic_tool) + elif not isinstance(tool, dict): + tool_dict = tool.model_dump() + if "function" in tool_dict: + anthropic_tool = { + "type": "function", + "function": tool_dict["function"], + } + anthropic_tools.append(anthropic_tool) + + if anthropic_tools: + payload["tools"] = anthropic_tools + + # Handle tool_choice if present + if request.tool_choice: + if isinstance(request.tool_choice, dict): + if request.tool_choice.get("type") == "function": + # Already in Anthropic format + payload["tool_choice"] = request.tool_choice + elif "function" in request.tool_choice: + # Convert from OpenAI format to Anthropic format + payload["tool_choice"] = { + "type": "function", + "function": request.tool_choice["function"], + } + elif request.tool_choice == "auto" or request.tool_choice == "none": + payload["tool_choice"] = request.tool_choice + + # Add stop sequences if present + if request.stop: + payload["stop_sequences"] = Translation._normalize_stop_sequences( + request.stop + ) + + # Add metadata if present in extra_body + if request.extra_body and isinstance(request.extra_body, dict): + metadata = request.extra_body.get("metadata") + if metadata: + payload["metadata"] = metadata + + # Handle structured output for Responses API + response_format = request.extra_body.get("response_format") + if response_format and response_format.get("type") == "json_schema": + json_schema = response_format.get("json_schema", {}) + schema = json_schema.get("schema", {}) + schema_name = json_schema.get("name") + schema_description = json_schema.get("description") + strict = json_schema.get("strict", True) + + # For Anthropic, add comprehensive JSON schema instruction to system message + import json + + schema_instruction = ( + "\n\nYou must respond with valid JSON that conforms to this schema" + ) + if schema_name: + schema_instruction += f" for '{schema_name}'" + if schema_description: + schema_instruction += f" ({schema_description})" + schema_instruction += f":\n\n{json.dumps(schema, indent=2)}" + + if strict: + schema_instruction += "\n\nIMPORTANT: The response must strictly adhere to this schema. Do not include any additional fields or deviate from the specified structure." + else: + schema_instruction += "\n\nNote: The response should generally follow this schema, but minor variations may be acceptable." + + schema_instruction += "\n\nRespond only with the JSON object, no additional text or formatting." + + if payload.get("system"): + if isinstance(payload["system"], str): + payload["system"] += schema_instruction + else: + # If not a string, we cannot append. Replace it. + payload["system"] = schema_instruction + else: + payload["system"] = ( + f"You are a helpful assistant.{schema_instruction}" + ) + + return payload + + @staticmethod + def code_assist_to_domain_request(request: Any) -> CanonicalChatRequest: + """ + Translate a Code Assist API request to a CanonicalChatRequest. + + The Code Assist API uses the same format as OpenAI for the core request, + but with additional project field and different endpoint. 
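+
+        Example (illustrative; values are placeholders):
+            A payload such as
+                {"model": "gemini-pro", "project": "my-gcp-project",
+                 "messages": [{"role": "user", "content": "Hello"}]}
+            has its "project" field stripped and is then translated exactly like
+            an OpenAI-style request.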
+ """ + # Code Assist API request format is essentially the same as OpenAI + # but may include a "project" field + if isinstance(request, dict): + # Remove Code Assist specific fields and treat as OpenAI format + cleaned_request = {k: v for k, v in request.items() if k != "project"} + return Translation.openai_to_domain_request(cleaned_request) + else: + # Handle object format by extracting fields + return Translation.openai_to_domain_request(request) + + @staticmethod + def code_assist_to_domain_response(response: Any) -> CanonicalChatResponse: + """ + Translate a Code Assist API response to a CanonicalChatResponse. + + The Code Assist API wraps the response in a "response" object and uses + different structure than standard Gemini API. + """ + import time + + if not isinstance(response, dict): + # Handle non-dict responses + return CanonicalChatResponse( + id=f"chatcmpl-code-assist-{int(time.time())}", + object="chat.completion", + created=int(time.time()), + model="unknown", + choices=[ + ChatCompletionChoice( + index=0, + message=ChatCompletionChoiceMessage( + role="assistant", content=str(response) + ), + finish_reason="stop", + ) + ], + usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + ) + + # Extract from Code Assist response wrapper + response_wrapper = response.get("response", {}) + candidates = response_wrapper.get("candidates", []) + generated_text = "" + + if candidates and len(candidates) > 0: + candidate = candidates[0] + content = candidate.get("content") or {} + parts = content.get("parts", []) + + if parts and len(parts) > 0: + generated_text = parts[0].get("text", "") + + # Create canonical response + return CanonicalChatResponse( + id=f"chatcmpl-code-assist-{int(time.time())}", + object="chat.completion", + created=int(time.time()), + model=response.get("model", "code-assist-model"), + choices=[ + ChatCompletionChoice( + index=0, + message=ChatCompletionChoiceMessage( + role="assistant", content=generated_text + ), + finish_reason="stop", + ) + ], + usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + ) + + @staticmethod + def code_assist_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: + """ + Translate a Code Assist API streaming chunk to a canonical dictionary format. + + Code Assist API uses Server-Sent Events (SSE) format with "data: " prefix. 
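+
+        Example (illustrative):
+            {"response": {"candidates": [{"content": {"parts": [{"text": "Hi"}]}}]}}
+            becomes an OpenAI-style chunk whose first choice has
+            delta={"role": "assistant", "content": "Hi"} and finish_reason=None,
+            while a part carrying "functionCall" is surfaced as a tool_calls
+            delta with finish_reason="tool_calls".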
+ """ + import time + import uuid + + if chunk is None: + # Handle end of stream + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "code-assist-model", + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + } + ], + } + + if not isinstance(chunk, dict): + return {"error": "Invalid chunk format: expected a dictionary"} + + response_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" + created = int(time.time()) + model = "code-assist-model" + + content = "" + finish_reason = None + tool_calls: list[dict[str, Any]] | None = None + + # Extract from Code Assist response wrapper + response_wrapper = chunk.get("response", {}) + candidates = response_wrapper.get("candidates", []) + + if candidates and len(candidates) > 0: + candidate = candidates[0] + content_obj = candidate.get("content") or {} + parts = content_obj.get("parts", []) + + if parts and len(parts) > 0: + # Collect text and function calls + text_parts: list[str] = [] + for part in parts: + if isinstance(part, dict) and "text" in part: + text_parts.append(part.get("text", "")) + elif isinstance(part, dict) and "functionCall" in part: + try: + if tool_calls is None: + tool_calls = [] + tool_calls.append( + Translation._process_gemini_function_call( + part["functionCall"] + ).model_dump() + ) + except Exception: + # Ignore malformed functionCall parts + continue + content = "".join(text_parts) + + if "finishReason" in candidate: + finish_reason = candidate["finishReason"] + + delta: dict[str, Any] = {"role": "assistant"} + if tool_calls: + delta["tool_calls"] = tool_calls + # Enforce OpenAI semantics: when tool_calls are present, do not include content + delta.pop("content", None) + # Force finish_reason to tool_calls to signal clients to execute tools + finish_reason = "tool_calls" + elif content: + delta["content"] = content + + return { + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + } + ], + } + + @staticmethod + def raw_text_to_domain_request(request: Any) -> CanonicalChatRequest: + """ + Translate a raw text request to a CanonicalChatRequest. + + Raw text format is typically used for simple text processing where + the input is just a plain text string. + """ + + if isinstance(request, str): + # Create a simple request with the text as user message + from src.core.domain.chat import ChatMessage + + return CanonicalChatRequest( + model="text-model", + messages=[ChatMessage(role="user", content=request)], + ) + elif isinstance(request, dict): + # If it's already a dict, treat it as OpenAI format + return Translation.openai_to_domain_request(request) + else: + # Handle object format + return Translation.openai_to_domain_request(request) + + @staticmethod + def raw_text_to_domain_response(response: Any) -> CanonicalChatResponse: + """ + Translate a raw text response to a CanonicalChatResponse. + + Raw text format is typically used for simple text responses. 
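+
+        Example (illustrative):
+            Passing the plain string "Hello" yields a single-choice
+            chat.completion for model "text-model" whose assistant message
+            content is "Hello" and whose finish_reason is "stop"; dict inputs
+            are delegated to the OpenAI response translator instead.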
+ """ + import time + + if isinstance(response, str): + return CanonicalChatResponse( + id=f"chatcmpl-raw-text-{int(time.time())}", + object="chat.completion", + created=int(time.time()), + model="text-model", + choices=[ + ChatCompletionChoice( + index=0, + message=ChatCompletionChoiceMessage( + role="assistant", content=response + ), + finish_reason="stop", + ) + ], + usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + ) + elif isinstance(response, dict): + # If it's already a dict, treat it as OpenAI format + return Translation.openai_to_domain_response(response) + else: + # Handle object format + return Translation.openai_to_domain_response(response) + + @staticmethod + def raw_text_to_domain_stream_chunk(chunk: Any) -> dict[str, Any]: + """ + Translate a raw text stream chunk to a canonical dictionary format. + + Raw text chunks are typically plain text strings. + """ + import time + import uuid + + if chunk is None: + # Handle end of stream + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "text-model", + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + } + ], + } + + if isinstance(chunk, str): + # Raw text chunk + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "text-model", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": chunk}, + "finish_reason": None, + } + ], + } + elif isinstance(chunk, dict): + # Check if it's a wrapped text dict like {"text": "content"} + if "text" in chunk and isinstance(chunk["text"], str): + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:16]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "text-model", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": chunk["text"]}, + "finish_reason": None, + } + ], + } + else: + # If it's already a dict, treat it as OpenAI format + return Translation.openai_to_domain_stream_chunk(chunk) + else: + return {"error": "Invalid raw text chunk format"} + + @staticmethod + def responses_to_domain_request(request: Any) -> CanonicalChatRequest: + """ + Translate a Responses API request to a CanonicalChatRequest. + + The Responses API request includes structured output requirements via response_format. + This method converts the request to the internal domain format while preserving + the JSON schema information for later use by backends. 
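+
+        Example (illustrative; schema contents are placeholders):
+            {"model": "gpt-4o", "input": "Summarise this",
+             "response_format": {"type": "json_schema",
+                                 "json_schema": {"name": "summary",
+                                                 "schema": {"type": "object"}}}}
+            is normalised so that "input" becomes the messages list and the
+            response_format is preserved in extra_body for downstream backends.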
+ """ + from src.core.domain.responses_api import ResponsesRequest + + # Normalize incoming payload regardless of format (dict, model, or object) + def _prepare_payload(payload: dict[str, Any]) -> dict[str, Any]: + normalized_payload = dict(payload) + if "messages" not in normalized_payload and "input" in normalized_payload: + normalized_payload["messages"] = ( + Translation._normalize_responses_input_to_messages( + normalized_payload["input"] + ) + ) + return normalized_payload + + if isinstance(request, dict): + request_payload = _prepare_payload(request) + if not request_payload.get("model"): + raise ValueError("'model' is a required property") + other_params = { + k: v + for k, v in request_payload.items() + if k not in ["model", "messages"] + } + responses_request = ResponsesRequest( + model=request_payload.get("model") or "", + messages=request_payload.get("messages") or [], + **other_params, + ) + elif hasattr(request, "model_dump"): + request_payload = _prepare_payload(request.model_dump()) + responses_request = ( + request + if isinstance(request, ResponsesRequest) + else ResponsesRequest(**request_payload) + ) + else: + request_payload = { + "model": getattr(request, "model", None), + "messages": getattr(request, "messages", None), + "response_format": getattr(request, "response_format", None), + "max_tokens": getattr(request, "max_tokens", None), + "temperature": getattr(request, "temperature", None), + "top_p": getattr(request, "top_p", None), + "n": getattr(request, "n", None), + "stream": getattr(request, "stream", None), + "stop": getattr(request, "stop", None), + "presence_penalty": getattr(request, "presence_penalty", None), + "frequency_penalty": getattr(request, "frequency_penalty", None), + "logit_bias": getattr(request, "logit_bias", None), + "user": getattr(request, "user", None), + "seed": getattr(request, "seed", None), + "session_id": getattr(request, "session_id", None), + "agent": getattr(request, "agent", None), + "extra_body": getattr(request, "extra_body", None), + } + + input_value = getattr(request, "input", None) + if (not request_payload.get("messages")) and input_value is not None: + request_payload["messages"] = ( + Translation._normalize_responses_input_to_messages(input_value) + ) + + other_params = { + k: v + for k, v in request_payload.items() + if k not in ["model", "messages"] + } + responses_request = ResponsesRequest( + model=request_payload.get("model") or "", + messages=request_payload.get("messages") or [], + **other_params, + ) + + # Prepare extra_body with response format + extra_body = dict(responses_request.extra_body or {}) + if responses_request.response_format is not None: + extra_body["response_format"] = ( + responses_request.response_format.model_dump() + ) + + # Convert to CanonicalChatRequest + canonical_request = CanonicalChatRequest( + model=responses_request.model, + messages=responses_request.messages, + temperature=responses_request.temperature, + top_p=responses_request.top_p, + max_tokens=responses_request.max_tokens, + n=responses_request.n, + stream=responses_request.stream, + stop=responses_request.stop, + presence_penalty=responses_request.presence_penalty, + frequency_penalty=responses_request.frequency_penalty, + logit_bias=responses_request.logit_bias, + user=responses_request.user, + seed=responses_request.seed, + session_id=responses_request.session_id, + agent=responses_request.agent, + extra_body=extra_body, + ) + + return canonical_request + + @staticmethod + def from_domain_to_responses_response(response: 
ChatResponse) -> dict[str, Any]: + """ + Translate a domain ChatResponse to a Responses API response format. + + This method converts the internal domain response to the OpenAI Responses API format, + including parsing structured outputs and handling JSON schema validation results. + """ + import json + import time + + # Convert choices to Responses API format + choices = [] + output_items: list[dict[str, Any]] = [] + aggregated_output_text: list[str | None] = [] + + def _map_finish_reason_to_status(finish_reason: str | None) -> str: + if finish_reason in (None, "", "stop"): + return "completed" + if finish_reason == "length": + return "incomplete" + if finish_reason in {"tool_calls", "function_call"}: + return "requires_action" + if finish_reason == "content_filter": + return "blocked" + return "completed" + + for choice in response.choices: + if choice.message: + # Try to parse the content as JSON for structured output + parsed_content = None + raw_content = choice.message.content or "" + + # Clean up content for JSON parsing + cleaned_content = raw_content.strip() + + # Handle cases where the model might wrap JSON in markdown code blocks + if cleaned_content.startswith("```json") and cleaned_content.endswith( + "```" + ): + cleaned_content = cleaned_content[7:-3].strip() + elif cleaned_content.startswith("```") and cleaned_content.endswith( + "```" + ): + cleaned_content = cleaned_content[3:-3].strip() + + # Attempt to parse JSON content + if cleaned_content: + try: + parsed_content = json.loads(cleaned_content) + # If parsing succeeded, use the cleaned content as the actual content + raw_content = cleaned_content + except json.JSONDecodeError: + # Content is not valid JSON, leave parsed as None + # Try to extract JSON from the content if it contains other text + try: + # Look for JSON-like patterns in the content + import re + + json_pattern = r"\{.*\}" + json_match = re.search( + json_pattern, cleaned_content, re.DOTALL + ) + if json_match: + potential_json = json_match.group(0) + parsed_content = json.loads(potential_json) + raw_content = potential_json + except (json.JSONDecodeError, AttributeError): + # Still not valid JSON, leave parsed as None + pass + + message_payload: dict[str, Any] = { + "role": choice.message.role, + "content": raw_content or None, + "parsed": parsed_content, + } + + tool_calls_payload: list[dict[str, Any]] = [] + if choice.message.tool_calls: + for tool_call in choice.message.tool_calls: + if hasattr(tool_call, "model_dump"): + tool_data = tool_call.model_dump() + elif isinstance(tool_call, dict): + tool_data = dict(tool_call) + else: + function = getattr(tool_call, "function", None) + tool_data = { + "id": getattr(tool_call, "id", ""), + "type": getattr(tool_call, "type", "function"), + "function": { + "name": getattr(function, "name", ""), + "arguments": getattr(function, "arguments", "{}"), + }, + } + + function_payload = tool_data.get("function") + if isinstance(function_payload, dict): + arguments = function_payload.get("arguments") + if isinstance(arguments, dict | list): + function_payload["arguments"] = json.dumps(arguments) + elif arguments is None: + function_payload["arguments"] = "{}" + + tool_calls_payload.append(tool_data) + + if tool_calls_payload: + message_payload["tool_calls"] = tool_calls_payload + + response_choice = { + "index": choice.index, + "message": message_payload, + "finish_reason": choice.finish_reason or "stop", + } + choices.append(response_choice) + + # Build Responses API output structure (mirrors incoming payloads) + 
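+            # Illustrative shape of one resulting output item (values are examples):
+            #   {"id": "msg-<response id>-<choice index>", "type": "message",
+            #    "role": "assistant", "status": "completed",
+            #    "content": [{"type": "output_text", "text": "..."}]}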
text_value = ( + message_payload.get("content") + if isinstance(message_payload.get("content"), str) + else None + ) + if text_value: + aggregated_output_text.append(text_value) + else: + aggregated_output_text.append(None) + + output_content_parts: list[dict[str, Any]] = [] + + if text_value: + output_content_parts.append( + {"type": "output_text", "text": text_value} + ) + + if tool_calls_payload: + for tool_call in tool_calls_payload: # type: ignore[assignment] + tool_call_dict: dict[str, Any] = ( + tool_call if isinstance(tool_call, dict) else {} + ) + output_content_parts.append( + { + "type": "tool_call", + "id": tool_call_dict.get("id", ""), + "function": tool_call_dict.get("function", {}), + } + ) + + output_item = { + "id": f"msg-{response.id}-{choice.index}", + "type": "message", + "role": choice.message.role, + "status": _map_finish_reason_to_status(choice.finish_reason), + "content": output_content_parts, + } + + if choice.finish_reason: + output_item["finish_reason"] = choice.finish_reason + + output_items.append(output_item) + + # Build the Responses API response + responses_response = { + "id": response.id, + "object": "response", + "created": response.created or int(time.time()), + "model": response.model, + "choices": choices, + } + + if output_items: + responses_response["output"] = output_items + + # Only include output_text if we have any textual content + text_values = [text for text in aggregated_output_text if text is not None] + if text_values: + responses_response["output_text"] = [ + text if text is not None else "" for text in aggregated_output_text + ] + + # Add usage information if available + if response.usage: + responses_response["usage"] = response.usage + + # Add system fingerprint if available + if hasattr(response, "system_fingerprint") and response.system_fingerprint: + responses_response["system_fingerprint"] = response.system_fingerprint + + return responses_response + + @staticmethod + def from_domain_to_responses_request( + request: CanonicalChatRequest, + ) -> dict[str, Any]: + """ + Translate a CanonicalChatRequest to an OpenAI Responses API request format. + + This method converts the internal domain request to the OpenAI Responses API format, + extracting the response_format from extra_body and structuring it properly. 
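+
+        Example (illustrative; key names other than "response_format" and
+        "metadata" are placeholders):
+            An extra_body of
+                {"response_format": {"type": "json_schema",
+                                     "json_schema": {"name": "report",
+                                                     "schema": {"type": "object"}}},
+                 "metadata": {"trace": "abc"}, "internal_flag": True}
+            yields a payload with a top-level "response_format", keeps
+            "metadata" (the only extra_body key passed through), and drops
+            "internal_flag".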
+ """ + # Start with basic OpenAI request format + payload = Translation.from_domain_to_openai_request(request) + + if request.extra_body: + extra_body_copy = dict(request.extra_body) + + # Extract and restructure response_format from extra_body + response_format = extra_body_copy.pop("response_format", None) + if response_format is not None: + # Ensure the response_format is properly structured for Responses API + if isinstance(response_format, dict): + payload["response_format"] = response_format + elif hasattr(response_format, "model_dump"): + payload["response_format"] = response_format.model_dump() + else: + payload["response_format"] = response_format + + # Add any remaining extra_body parameters that are safe for Responses API + safe_extra_body = Translation._filter_responses_extra_body(extra_body_copy) + if safe_extra_body: + payload.update(safe_extra_body) + + return payload + + @staticmethod + def _filter_responses_extra_body(extra_body: dict[str, Any]) -> dict[str, Any]: + """Filter extra_body entries to include only Responses API specific parameters.""" + + if not extra_body: + return {} + + allowed_keys: set[str] = {"metadata"} + + return {key: value for key, value in extra_body.items() if key in allowed_keys} + + @staticmethod + def enhance_structured_output_response( + response: ChatResponse, + original_request_extra_body: dict[str, Any] | None = None, + ) -> ChatResponse: + """ + Enhance a ChatResponse with structured output validation and repair. + + This method validates the response against the original JSON schema + and attempts repair if validation fails. + + Args: + response: The original ChatResponse + original_request_extra_body: The extra_body from the original request containing schema info + + Returns: + Enhanced ChatResponse with validated/repaired structured output + """ + if not original_request_extra_body: + return response + + response_format = original_request_extra_body.get("response_format") + if not response_format or response_format.get("type") != "json_schema": + return response + + json_schema_info = response_format.get("json_schema", {}) + schema = json_schema_info.get("schema", {}) + + if not schema: + return response + + # Process each choice + enhanced_choices = [] + for choice in response.choices: + if not choice.message or not choice.message.content: + enhanced_choices.append(choice) + continue + + content = choice.message.content.strip() + + # Try to parse and validate the JSON + try: + parsed_json = json.loads(content) + + # Validate against schema + is_valid, error_msg = Translation.validate_json_against_schema( + parsed_json, schema + ) + + if is_valid: + # Content is valid, keep as is + enhanced_choices.append(choice) + else: + # Try to repair the JSON + repaired_json = Translation._attempt_json_repair( + parsed_json, schema, error_msg + ) + if repaired_json is not None: + # Use repaired JSON + repaired_content = json.dumps(repaired_json, indent=2) + enhanced_message = ChatCompletionChoiceMessage( + role=choice.message.role, + content=repaired_content, + tool_calls=choice.message.tool_calls, + ) + enhanced_choice = ChatCompletionChoice( + index=choice.index, + message=enhanced_message, + finish_reason=choice.finish_reason, + ) + enhanced_choices.append(enhanced_choice) + else: + # Repair failed, keep original + enhanced_choices.append(choice) + + except json.JSONDecodeError: + # Not valid JSON, try to extract and repair + extracted_and_repaired_content: str | None = ( + Translation._extract_and_repair_json(content, schema) + ) + if 
extracted_and_repaired_content is not None: + enhanced_message = ChatCompletionChoiceMessage( + role=choice.message.role, + content=extracted_and_repaired_content, + tool_calls=choice.message.tool_calls, + ) + enhanced_choice = ChatCompletionChoice( + index=choice.index, + message=enhanced_message, + finish_reason=choice.finish_reason, + ) + enhanced_choices.append(enhanced_choice) + else: + # Repair failed, keep original + enhanced_choices.append(choice) + + # Create enhanced response + enhanced_response = CanonicalChatResponse( + id=response.id, + object=response.object, + created=response.created, + model=response.model, + choices=enhanced_choices, + usage=response.usage, + system_fingerprint=getattr(response, "system_fingerprint", None), + ) + + return enhanced_response + + @staticmethod + def _attempt_json_repair( + json_data: dict[str, Any], schema: dict[str, Any], error_msg: str | None + ) -> dict[str, Any] | None: + """ + Attempt to repair JSON data to conform to schema. + + This is a basic repair mechanism that handles common issues. + """ + try: + repaired = dict(json_data) + + # Add missing required properties + if schema.get("type") == "object": + required = schema.get("required", []) + properties = schema.get("properties", {}) + + for prop in required: + if prop not in repaired: + # Add default value based on property type + prop_schema = properties.get(prop, {}) + prop_type = prop_schema.get("type", "string") + + if prop_type == "string": + repaired[prop] = "" + elif prop_type == "number": + repaired[prop] = 0.0 + elif prop_type == "integer": + repaired[prop] = 0 + elif prop_type == "boolean": + repaired[prop] = False + elif prop_type == "array": + repaired[prop] = [] + elif prop_type == "object": + repaired[prop] = {} + else: + repaired[prop] = None + + # Validate the repaired JSON + is_valid, _ = Translation.validate_json_against_schema(repaired, schema) + return repaired if is_valid else None + + except Exception: + return None + + @staticmethod + def _iter_json_candidates( + content: str, + *, + max_candidates: int = 20, + max_object_size: int = 512 * 1024, + ) -> list[str]: + """Find potential JSON object substrings using a linear-time scan.""" + + candidates: list[str] = [] + depth = 0 + start_index: int | None = None + escape_next = False + string_delimiter: str | None = None + + for index, char in enumerate(content): + if string_delimiter is not None: + if escape_next: + escape_next = False + continue + if char == "\\": + escape_next = True + continue + if char == string_delimiter: + string_delimiter = None + continue + + if char in ('"', "'"): + string_delimiter = char + continue + + if char == "{": + if depth == 0: + start_index = index + depth += 1 + elif char == "}": + if depth == 0: + continue + depth -= 1 + if depth == 0 and start_index is not None: + candidate = content[start_index : index + 1] + start_index = None + if len(candidate) > max_object_size: + logger.warning( + "Skipping oversized JSON candidate (%d bytes)", + len(candidate), + ) + continue + candidates.append(candidate) + if len(candidates) >= max_candidates: + break + + return candidates + + @staticmethod + def _extract_and_repair_json(content: str, schema: dict[str, Any]) -> str | None: + """Extract JSON from content and attempt repair.""" + + try: + for candidate in Translation._iter_json_candidates(content): + try: + parsed = json.loads(candidate) + except json.JSONDecodeError: + continue + + if not isinstance(parsed, dict): + continue + + repaired = Translation._attempt_json_repair(parsed, 
schema, None) + if repaired is not None: + return json.dumps(repaired, indent=2) + + return None + except Exception: + return None diff --git a/src/core/services/parameter_resolution_service.py b/src/core/services/parameter_resolution_service.py index 842b2a95..68deb0a4 100644 --- a/src/core/services/parameter_resolution_service.py +++ b/src/core/services/parameter_resolution_service.py @@ -1,293 +1,315 @@ -""" -Parameter Resolution Service - -This module provides parameter resolution from multiple sources with precedence handling. -Tracks parameter sources for debugging and applies precedence rules. -""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass -from typing import Any - -logger = logging.getLogger(__name__) - - -@dataclass -class ParameterSource: - """Tracks the source and value of a parameter.""" - - value: Any - source: str # "uri", "session", "header", "config", "default" - - def __repr__(self) -> str: - return f"ParameterSource(value={self.value!r}, source={self.source!r})" - - -@dataclass -class ResolvedParameters: - """Container for resolved parameters with source tracking.""" - - temperature: ParameterSource | None = None - reasoning_effort: ParameterSource | None = None - top_p: ParameterSource | None = None - top_k: ParameterSource | None = None - - def to_dict(self) -> dict[str, Any]: - """ - Extract just the parameter values for backend application. - - Returns: - Dictionary with parameter names and their effective values, - excluding None values. - - Examples: - >>> params = ResolvedParameters( - ... temperature=ParameterSource(0.5, "uri"), - ... reasoning_effort=ParameterSource("high", "session") - ... ) - >>> params.to_dict() - {"temperature": 0.5, "reasoning_effort": "high"} - """ - result: dict[str, Any] = {} - - if self.temperature is not None: - result["temperature"] = self.temperature.value - - if self.top_p is not None: - result["top_p"] = self.top_p.value - - if self.top_k is not None: - result["top_k"] = self.top_k.value - - if self.reasoning_effort is not None: - result["reasoning_effort"] = self.reasoning_effort.value - - return result - - def get_debug_info(self) -> dict[str, dict[str, Any]]: - """ - Get parameter sources and values for debugging. - - Returns: - Dictionary with parameter names mapped to their debug information - including effective value and source. - - Examples: - >>> params = ResolvedParameters( - ... temperature=ParameterSource(0.5, "uri") - ... ) - >>> params.get_debug_info() - { - "temperature": { - "effective_value": 0.5, - "source": "uri" - } - } - """ - result: dict[str, dict[str, Any]] = {} - - if self.temperature is not None: - result["temperature"] = { - "effective_value": self.temperature.value, - "source": self.temperature.source, - } - - if self.top_p is not None: - result["top_p"] = { - "effective_value": self.top_p.value, - "source": self.top_p.source, - } - - if self.top_k is not None: - result["top_k"] = { - "effective_value": self.top_k.value, - "source": self.top_k.source, - } - - if self.reasoning_effort is not None: - result["reasoning_effort"] = { - "effective_value": self.reasoning_effort.value, - "source": self.reasoning_effort.source, - } - - return result - - -class ParameterResolutionService: - """ - Resolves model parameters from multiple sources with precedence. - - Precedence (highest to lowest): - 1. Interactive session commands (highest priority) - 2. URI parameters from model string - 3. Request headers - 4. 
Configuration file defaults (lowest priority) - - The service tracks the source of each parameter value for debugging - and transparency. - """ - - # Supported parameter names - SUPPORTED_PARAMETERS = [ - "temperature", - "top_p", - "top_k", - "reasoning_effort", - ] - - def resolve_parameters( - self, - uri_params: dict[str, Any] | None = None, - header_params: dict[str, Any] | None = None, - config_params: dict[str, Any] | None = None, - session_params: dict[str, Any] | None = None, - backend: str = "", - ) -> ResolvedParameters: - """ - Resolve parameters from all sources with precedence. - - Args: - uri_params: Parameters from URI query string - header_params: Parameters from request headers - config_params: Parameters from configuration file - session_params: Parameters from interactive session commands - backend: Backend name for logging context - - Returns: - ResolvedParameters with effective values and source tracking - - Examples: - >>> service = ParameterResolutionService() - >>> result = service.resolve_parameters( - ... uri_params={"temperature": 0.5}, - ... config_params={"temperature": 0.8} - ... ) - >>> result.temperature.value - 0.5 - >>> result.temperature.source - 'uri' - """ - # Initialize with None values - uri_params = uri_params or {} - header_params = header_params or {} - config_params = config_params or {} - session_params = session_params or {} - - # Track overridden sources for debugging - overridden_sources: dict[str, list[tuple[str, Any]]] = { - param: [] for param in self.SUPPORTED_PARAMETERS - } - - # Resolve each parameter with precedence - resolved = ResolvedParameters() - - for param_name in self.SUPPORTED_PARAMETERS: - resolved_value = self._resolve_single_parameter( - param_name, - uri_params, - header_params, - config_params, - session_params, - overridden_sources, - ) - setattr(resolved, param_name, resolved_value) - - # Emit debug logs - self._log_resolution_debug(backend, resolved, overridden_sources) - - return resolved - - def _resolve_single_parameter( - self, - param_name: str, - uri_params: dict[str, Any], - header_params: dict[str, Any], - config_params: dict[str, Any], - session_params: dict[str, Any], - overridden_sources: dict[str, list[tuple[str, Any]]], - ) -> ParameterSource | None: - """ - Resolve a single parameter from all sources with precedence. - - Precedence order (highest to lowest): - 1. session_params - 2. uri_params - 3. header_params - 4. 
config_params - - Args: - param_name: Name of the parameter to resolve - uri_params: URI parameters - header_params: Header parameters - config_params: Config parameters - session_params: Session parameters - overridden_sources: Dict to track overridden sources for debugging - - Returns: - ParameterSource with the effective value and source, or None if not found - """ - # Collect all sources in precedence order (lowest to highest) - sources = [ - ("config", config_params.get(param_name)), - ("header", header_params.get(param_name)), - ("uri", uri_params.get(param_name)), - ("session", session_params.get(param_name)), - ] - - # Find the highest priority source with a value - effective_source: ParameterSource | None = None - - for source_name, value in sources: - if value is not None: - # Track overridden sources - if effective_source is not None: - overridden_sources[param_name].append( - (effective_source.source, effective_source.value) - ) - - # Update effective source (higher priority) - effective_source = ParameterSource(value=value, source=source_name) - - return effective_source - - def _log_resolution_debug( - self, - backend: str, - resolved: ResolvedParameters, - overridden_sources: dict[str, list[tuple[str, Any]]], - ) -> None: - """ - Emit debug logs showing parameter resolution details. - - Args: - backend: Backend name for context - resolved: Resolved parameters - overridden_sources: Dict of overridden sources for each parameter - """ - # Only log if there are resolved parameters - debug_info = resolved.get_debug_info() - if not debug_info: - return - - # Build log message - log_lines = [f"Parameter resolution for {backend}:"] - - for param_name, info in debug_info.items(): - effective_value = info["effective_value"] - source = info["source"] - - # Build override information - overrides = overridden_sources.get(param_name, []) - if overrides: - override_str = ", ".join([f"{src}={val}" for src, val in overrides]) - log_lines.append( - f" {param_name}: {effective_value} (source: {source}, overrode: {override_str})" - ) - else: - log_lines.append( - f" {param_name}: {effective_value} (source: {source})" - ) - - logger.debug("\n".join(log_lines)) +""" +Parameter Resolution Service + +This module provides parameter resolution from multiple sources with precedence handling. +Tracks parameter sources for debugging and applies precedence rules. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class ParameterSource: + """Tracks the source and value of a parameter.""" + + value: Any + source: str # "uri", "session", "header", "config", "default" + + def __repr__(self) -> str: + return f"ParameterSource(value={self.value!r}, source={self.source!r})" + + +@dataclass +class ResolvedParameters: + """Container for resolved parameters with source tracking.""" + + temperature: ParameterSource | None = None + reasoning_effort: ParameterSource | None = None + top_p: ParameterSource | None = None + top_k: ParameterSource | None = None + repetition_penalty: ParameterSource | None = None + min_p: ParameterSource | None = None + + def to_dict(self) -> dict[str, Any]: + """ + Extract just the parameter values for backend application. + + Returns: + Dictionary with parameter names and their effective values, + excluding None values. + + Examples: + >>> params = ResolvedParameters( + ... temperature=ParameterSource(0.5, "uri"), + ... 
reasoning_effort=ParameterSource("high", "session") + ... ) + >>> params.to_dict() + {"temperature": 0.5, "reasoning_effort": "high"} + """ + result: dict[str, Any] = {} + + if self.temperature is not None: + result["temperature"] = self.temperature.value + + if self.top_p is not None: + result["top_p"] = self.top_p.value + + if self.top_k is not None: + result["top_k"] = self.top_k.value + + if self.reasoning_effort is not None: + result["reasoning_effort"] = self.reasoning_effort.value + + if self.repetition_penalty is not None: + result["repetition_penalty"] = self.repetition_penalty.value + + if self.min_p is not None: + result["min_p"] = self.min_p.value + + return result + + def get_debug_info(self) -> dict[str, dict[str, Any]]: + """ + Get parameter sources and values for debugging. + + Returns: + Dictionary with parameter names mapped to their debug information + including effective value and source. + + Examples: + >>> params = ResolvedParameters( + ... temperature=ParameterSource(0.5, "uri") + ... ) + >>> params.get_debug_info() + { + "temperature": { + "effective_value": 0.5, + "source": "uri" + } + } + """ + result: dict[str, dict[str, Any]] = {} + + if self.temperature is not None: + result["temperature"] = { + "effective_value": self.temperature.value, + "source": self.temperature.source, + } + + if self.top_p is not None: + result["top_p"] = { + "effective_value": self.top_p.value, + "source": self.top_p.source, + } + + if self.top_k is not None: + result["top_k"] = { + "effective_value": self.top_k.value, + "source": self.top_k.source, + } + + if self.reasoning_effort is not None: + result["reasoning_effort"] = { + "effective_value": self.reasoning_effort.value, + "source": self.reasoning_effort.source, + } + + if self.repetition_penalty is not None: + result["repetition_penalty"] = { + "effective_value": self.repetition_penalty.value, + "source": self.repetition_penalty.source, + } + + if self.min_p is not None: + result["min_p"] = { + "effective_value": self.min_p.value, + "source": self.min_p.source, + } + + return result + + +class ParameterResolutionService: + """ + Resolves model parameters from multiple sources with precedence. + + Precedence (highest to lowest): + 1. Interactive session commands (highest priority) + 2. URI parameters from model string + 3. Request headers + 4. Configuration file defaults (lowest priority) + + The service tracks the source of each parameter value for debugging + and transparency. + """ + + # Supported parameter names + SUPPORTED_PARAMETERS = [ + "temperature", + "top_p", + "top_k", + "reasoning_effort", + "repetition_penalty", + "min_p", + ] + + def resolve_parameters( + self, + uri_params: dict[str, Any] | None = None, + header_params: dict[str, Any] | None = None, + config_params: dict[str, Any] | None = None, + session_params: dict[str, Any] | None = None, + backend: str = "", + ) -> ResolvedParameters: + """ + Resolve parameters from all sources with precedence. + + Args: + uri_params: Parameters from URI query string + header_params: Parameters from request headers + config_params: Parameters from configuration file + session_params: Parameters from interactive session commands + backend: Backend name for logging context + + Returns: + ResolvedParameters with effective values and source tracking + + Examples: + >>> service = ParameterResolutionService() + >>> result = service.resolve_parameters( + ... uri_params={"temperature": 0.5}, + ... config_params={"temperature": 0.8} + ... 
) + >>> result.temperature.value + 0.5 + >>> result.temperature.source + 'uri' + """ + # Initialize with None values + uri_params = uri_params or {} + header_params = header_params or {} + config_params = config_params or {} + session_params = session_params or {} + + # Track overridden sources for debugging + overridden_sources: dict[str, list[tuple[str, Any]]] = { + param: [] for param in self.SUPPORTED_PARAMETERS + } + + # Resolve each parameter with precedence + resolved = ResolvedParameters() + + for param_name in self.SUPPORTED_PARAMETERS: + resolved_value = self._resolve_single_parameter( + param_name, + uri_params, + header_params, + config_params, + session_params, + overridden_sources, + ) + setattr(resolved, param_name, resolved_value) + + # Emit debug logs + self._log_resolution_debug(backend, resolved, overridden_sources) + + return resolved + + def _resolve_single_parameter( + self, + param_name: str, + uri_params: dict[str, Any], + header_params: dict[str, Any], + config_params: dict[str, Any], + session_params: dict[str, Any], + overridden_sources: dict[str, list[tuple[str, Any]]], + ) -> ParameterSource | None: + """ + Resolve a single parameter from all sources with precedence. + + Precedence order (highest to lowest): + 1. session_params + 2. uri_params + 3. header_params + 4. config_params + + Args: + param_name: Name of the parameter to resolve + uri_params: URI parameters + header_params: Header parameters + config_params: Config parameters + session_params: Session parameters + overridden_sources: Dict to track overridden sources for debugging + + Returns: + ParameterSource with the effective value and source, or None if not found + """ + # Collect all sources in precedence order (lowest to highest) + sources = [ + ("config", config_params.get(param_name)), + ("header", header_params.get(param_name)), + ("uri", uri_params.get(param_name)), + ("session", session_params.get(param_name)), + ] + + # Find the highest priority source with a value + effective_source: ParameterSource | None = None + + for source_name, value in sources: + if value is not None: + # Track overridden sources + if effective_source is not None: + overridden_sources[param_name].append( + (effective_source.source, effective_source.value) + ) + + # Update effective source (higher priority) + effective_source = ParameterSource(value=value, source=source_name) + + return effective_source + + def _log_resolution_debug( + self, + backend: str, + resolved: ResolvedParameters, + overridden_sources: dict[str, list[tuple[str, Any]]], + ) -> None: + """ + Emit debug logs showing parameter resolution details. 
+ + Args: + backend: Backend name for context + resolved: Resolved parameters + overridden_sources: Dict of overridden sources for each parameter + """ + # Only log if there are resolved parameters + debug_info = resolved.get_debug_info() + if not debug_info: + return + + # Build log message + log_lines = [f"Parameter resolution for {backend}:"] + + for param_name, info in debug_info.items(): + effective_value = info["effective_value"] + source = info["source"] + + # Build override information + overrides = overridden_sources.get(param_name, []) + if overrides: + override_str = ", ".join([f"{src}={val}" for src, val in overrides]) + log_lines.append( + f" {param_name}: {effective_value} (source: {source}, overrode: {override_str})" + ) + else: + log_lines.append( + f" {param_name}: {effective_value} (source: {source})" + ) + + logger.debug("\n".join(log_lines)) diff --git a/tests/unit/connectors/test_precision_payload_mapping.py b/tests/unit/connectors/test_precision_payload_mapping.py index b1c443ca..656badec 100644 --- a/tests/unit/connectors/test_precision_payload_mapping.py +++ b/tests/unit/connectors/test_precision_payload_mapping.py @@ -45,6 +45,41 @@ async def fake_post(url: str, json: dict, headers: dict) -> httpx.Response: assert captured_payload.get("top_p") == 0.34 +@pytest.mark.asyncio +async def test_openai_warns_for_unsupported_penalties( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + cfg = AppConfig() + client = httpx.AsyncClient() + connector = OpenAIConnector(client, cfg, translation_service=TranslationService()) + connector.api_key = "test-api-key" + req = ChatRequest( + model="gpt-4", + messages=_messages(), + repetition_penalty=1.1, + min_p=0.05, + ) + + captured_payload: dict[str, Any] = {} + + async def fake_post(url: str, json: dict, headers: dict) -> httpx.Response: + captured_payload.update(json) + return httpx.Response( + 200, json={"id": "1", "choices": [{"message": {"content": "ok"}}]} + ) + + monkeypatch.setattr(client, "post", fake_post) + caplog.set_level("WARNING") + + await connector.chat_completions(req, req.messages, req.model) + + assert "does not support the 'repetition_penalty'" in caplog.text + assert "does not support the 'min_p'" in caplog.text + assert "repetition_penalty" not in captured_payload + assert "min_p" not in captured_payload + + @pytest.mark.asyncio async def test_openai_payload_uses_processed_messages_with_list_content( monkeypatch: pytest.MonkeyPatch, @@ -89,7 +124,12 @@ async def test_openrouter_payload_contains_temperature_and_top_p( connector = OpenRouterBackend(client, cfg, translation_service=TranslationService()) connector.api_key = "test-api-key" # Add API key to avoid authentication error req = ChatRequest( - model="openrouter:gpt-4", messages=_messages(), temperature=0.2, top_p=0.5 + model="openrouter:gpt-4", + messages=_messages(), + temperature=0.2, + top_p=0.5, + repetition_penalty=1.01, + min_p=0.15, ) captured_payload = {} @@ -106,11 +146,14 @@ async def fake_post(url: str, json: dict, headers: dict) -> httpx.Response: assert captured_payload.get("temperature") == 0.2 assert captured_payload.get("top_p") == 0.5 + assert captured_payload.get("repetition_penalty") == 1.01 + assert captured_payload.get("min_p") == 0.15 @pytest.mark.asyncio async def test_anthropic_payload_contains_temperature_and_top_p( monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, ) -> None: cfg = AppConfig() client = httpx.AsyncClient() @@ -143,12 +186,22 @@ async def fake_post(url: str, json: 
dict, headers: dict) -> Any: # type: ignore monkeypatch.setattr(client, "post", fake_post) req = ChatRequest( - model="claude-3", messages=_messages(), temperature=0.25, top_p=0.6 + model="claude-3", + messages=_messages(), + temperature=0.25, + top_p=0.6, + repetition_penalty=1.05, + min_p=0.2, ) + caplog.set_level("WARNING") await backend.chat_completions(req, req.messages, req.model, api_key="test-key") payload = captured.get("payload", {}) assert payload.get("temperature") == 0.25 assert payload.get("top_p") == 0.6 + assert "repetition_penalty" not in payload + assert "min_p" not in payload + assert "repetition_penalty" in caplog.text + assert "min_p" in caplog.text def test_gemini_public_generation_config_clamping_and_topk() -> None: @@ -167,7 +220,32 @@ def test_gemini_public_generation_config_clamping_and_topk() -> None: assert gc.get("topK") == 50 -def test_gemini_oauth_personal_builds_topk() -> None: +def test_gemini_generation_config_warns_for_penalties( + caplog: pytest.LogCaptureFixture, +) -> None: + cfg = AppConfig() + backend = GeminiBackend( + httpx.AsyncClient(), cfg, translation_service=TranslationService() + ) + payload: dict[str, Any] = {} + req = ChatRequest( + model="gemini-pro", + messages=_messages(), + repetition_penalty=1.05, + min_p=0.22, + ) + caplog.set_level("WARNING") + backend._apply_generation_config(payload, req) + gc = payload.get("generationConfig", {}) + assert "repetitionPenalty" not in gc + assert "minP" not in gc + assert "repetition_penalty" in caplog.text + assert "min_p" in caplog.text + + +def test_gemini_oauth_personal_builds_topk( + caplog: pytest.LogCaptureFixture, +) -> None: cfg = AppConfig() backend = GeminiOAuthPlanConnector( httpx.AsyncClient(), cfg, translation_service=TranslationService() @@ -178,14 +256,21 @@ class _Req: top_p = 0.55 top_k = 33 max_tokens = 777 + repetition_penalty = 1.02 + min_p = 0.18 + caplog.set_level("WARNING") gc = backend._build_generation_config(_Req()) assert gc["temperature"] == pytest.approx(0.22) assert gc["topP"] == pytest.approx(0.55) assert gc["topK"] == 33 + assert "repetition_penalty" in caplog.text + assert "min_p" in caplog.text -def test_gemini_cloud_project_builds_topk() -> None: +def test_gemini_cloud_project_builds_topk( + caplog: pytest.LogCaptureFixture, +) -> None: cfg = AppConfig() # Minimal init (project id may be None for this isolated helper test) backend = GeminiCloudProjectConnector( @@ -200,8 +285,13 @@ class _Req: top_p = 0.77 top_k = 21 max_tokens = 512 + repetition_penalty = 1.04 + min_p = 0.12 + caplog.set_level("WARNING") gc = backend._build_generation_config(_Req()) assert gc["temperature"] == pytest.approx(0.3) assert gc["topP"] == pytest.approx(0.77) assert gc["topK"] == 21 + assert "repetition_penalty" in caplog.text + assert "min_p" in caplog.text diff --git a/tests/unit/core/services/test_parameter_resolution_service.py b/tests/unit/core/services/test_parameter_resolution_service.py index b8138589..97247d1d 100644 --- a/tests/unit/core/services/test_parameter_resolution_service.py +++ b/tests/unit/core/services/test_parameter_resolution_service.py @@ -1,679 +1,737 @@ -"""Unit tests for parameter resolution service.""" - -import logging - -import pytest -from src.core.services.parameter_resolution_service import ( - ParameterResolutionService, - ParameterSource, - ResolvedParameters, -) - - -class TestParameterSource: - """Test cases for ParameterSource dataclass.""" - - def test_parameter_source_creation(self): - """Test creating a ParameterSource instance.""" - source = 
ParameterSource(value=0.5, source="uri") - - assert source.value == 0.5 - assert source.source == "uri" - - def test_parameter_source_repr(self): - """Test ParameterSource string representation.""" - source = ParameterSource(value=0.7, source="header") - repr_str = repr(source) - - assert "ParameterSource" in repr_str - assert "0.7" in repr_str - assert "header" in repr_str - - -class TestResolvedParameters: - """Test cases for ResolvedParameters dataclass.""" - - def test_resolved_parameters_creation_empty(self): - """Test creating an empty ResolvedParameters instance.""" - params = ResolvedParameters() - - assert params.temperature is None - assert params.reasoning_effort is None - assert params.top_p is None - assert params.top_k is None - - def test_resolved_parameters_creation_with_values(self): - """Test creating ResolvedParameters with values.""" - temp_source = ParameterSource(value=0.5, source="uri") - effort_source = ParameterSource(value="high", source="session") - top_p_source = ParameterSource(value=0.9, source="config") - top_k_source = ParameterSource(value=42, source="header") - - params = ResolvedParameters( - temperature=temp_source, - reasoning_effort=effort_source, - top_p=top_p_source, - top_k=top_k_source, - ) - - assert params.temperature == temp_source - assert params.reasoning_effort == effort_source - assert params.top_p == top_p_source - assert params.top_k == top_k_source - - def test_to_dict_empty(self): - """Test to_dict with no parameters.""" - params = ResolvedParameters() - result = params.to_dict() - - assert result == {} - - def test_to_dict_with_temperature_only(self): - """Test to_dict with only temperature.""" - params = ResolvedParameters(temperature=ParameterSource(0.5, "uri")) - result = params.to_dict() - - assert result == {"temperature": 0.5} - - def test_to_dict_with_reasoning_effort_only(self): - """Test to_dict with only reasoning_effort.""" - params = ResolvedParameters(reasoning_effort=ParameterSource("high", "session")) - result = params.to_dict() - - assert result == {"reasoning_effort": "high"} - - def test_to_dict_with_both_parameters(self): - """Test to_dict with both parameters.""" - params = ResolvedParameters( - temperature=ParameterSource(0.7, "header"), - reasoning_effort=ParameterSource("medium", "config"), - ) - result = params.to_dict() - - assert result == {"temperature": 0.7, "reasoning_effort": "medium"} - - def test_to_dict_with_top_parameters(self): - """Test to_dict with top_p and top_k parameters.""" - params = ResolvedParameters( - top_p=ParameterSource(0.92, "uri"), - top_k=ParameterSource(32, "session"), - ) - result = params.to_dict() - - assert result == {"top_p": 0.92, "top_k": 32} - - def test_get_debug_info_empty(self): - """Test get_debug_info with no parameters.""" - params = ResolvedParameters() - debug_info = params.get_debug_info() - - assert debug_info == {} - - def test_get_debug_info_with_temperature(self): - """Test get_debug_info with temperature.""" - params = ResolvedParameters(temperature=ParameterSource(0.5, "uri")) - debug_info = params.get_debug_info() - - assert "temperature" in debug_info - assert debug_info["temperature"]["effective_value"] == 0.5 - assert debug_info["temperature"]["source"] == "uri" - - def test_get_debug_info_with_reasoning_effort(self): - """Test get_debug_info with reasoning_effort.""" - params = ResolvedParameters(reasoning_effort=ParameterSource("high", "session")) - debug_info = params.get_debug_info() - - assert "reasoning_effort" in debug_info - assert 
debug_info["reasoning_effort"]["effective_value"] == "high" - assert debug_info["reasoning_effort"]["source"] == "session" - - def test_get_debug_info_with_both_parameters(self): - """Test get_debug_info with both parameters.""" - params = ResolvedParameters( - temperature=ParameterSource(0.8, "config"), - reasoning_effort=ParameterSource("low", "header"), - ) - debug_info = params.get_debug_info() - - assert len(debug_info) == 2 - assert debug_info["temperature"]["effective_value"] == 0.8 - assert debug_info["temperature"]["source"] == "config" - assert debug_info["reasoning_effort"]["effective_value"] == "low" - assert debug_info["reasoning_effort"]["source"] == "header" - - def test_get_debug_info_with_top_parameters(self): - """Test get_debug_info includes top_p and top_k entries.""" - params = ResolvedParameters( - top_p=ParameterSource(0.85, "uri"), - top_k=ParameterSource(16, "session"), - ) - debug_info = params.get_debug_info() - - assert "top_p" in debug_info - assert debug_info["top_p"]["effective_value"] == 0.85 - assert debug_info["top_p"]["source"] == "uri" - assert "top_k" in debug_info - assert debug_info["top_k"]["effective_value"] == 16 - assert debug_info["top_k"]["source"] == "session" - - -class TestParameterResolutionService: - """Test cases for ParameterResolutionService.""" - - @pytest.fixture - def service(self): - """Create a service instance for testing.""" - return ParameterResolutionService() - - # ======================================================================== - # Precedence Order Tests (session > uri > header > config) - # ======================================================================== - - def test_precedence_config_only(self, service): - """Test resolution with only config parameters.""" - result = service.resolve_parameters(config_params={"temperature": 0.8}) - - assert result.temperature is not None - assert result.temperature.value == 0.8 - assert result.temperature.source == "config" - - def test_precedence_header_overrides_config(self, service): - """Test that header parameters override config parameters.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, header_params={"temperature": 0.6} - ) - - assert result.temperature is not None - assert result.temperature.value == 0.6 - assert result.temperature.source == "header" - - def test_precedence_uri_overrides_header_and_config(self, service): - """Test that URI parameters override header and config parameters.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, - header_params={"temperature": 0.6}, - uri_params={"temperature": 0.4}, - ) - - assert result.temperature is not None - assert result.temperature.value == 0.4 - assert result.temperature.source == "uri" - - def test_precedence_session_overrides_all(self, service): - """Test that session parameters override all other sources.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, - header_params={"temperature": 0.6}, - uri_params={"temperature": 0.4}, - session_params={"temperature": 0.2}, - ) - - assert result.temperature is not None - assert result.temperature.value == 0.2 - assert result.temperature.source == "session" - - def test_precedence_reasoning_effort_config_only(self, service): - """Test reasoning_effort resolution with only config.""" - result = service.resolve_parameters(config_params={"reasoning_effort": "low"}) - - assert result.reasoning_effort is not None - assert result.reasoning_effort.value == "low" - assert 
result.reasoning_effort.source == "config" - - def test_precedence_reasoning_effort_header_overrides_config(self, service): - """Test that header reasoning_effort overrides config.""" - result = service.resolve_parameters( - config_params={"reasoning_effort": "low"}, - header_params={"reasoning_effort": "medium"}, - ) - - assert result.reasoning_effort is not None - assert result.reasoning_effort.value == "medium" - assert result.reasoning_effort.source == "header" - - def test_precedence_reasoning_effort_uri_overrides_header(self, service): - """Test that URI reasoning_effort overrides header and config.""" - result = service.resolve_parameters( - config_params={"reasoning_effort": "low"}, - header_params={"reasoning_effort": "medium"}, - uri_params={"reasoning_effort": "high"}, - ) - - assert result.reasoning_effort is not None - assert result.reasoning_effort.value == "high" - assert result.reasoning_effort.source == "uri" - - def test_precedence_reasoning_effort_session_overrides_all(self, service): - """Test that session reasoning_effort overrides all sources.""" - result = service.resolve_parameters( - config_params={"reasoning_effort": "low"}, - header_params={"reasoning_effort": "medium"}, - uri_params={"reasoning_effort": "high"}, - session_params={"reasoning_effort": "low"}, - ) - - assert result.reasoning_effort is not None - assert result.reasoning_effort.value == "low" - assert result.reasoning_effort.source == "session" - - def test_precedence_mixed_parameters_different_sources(self, service): - """Test precedence with different parameters from different sources.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8, "reasoning_effort": "low"}, - uri_params={"temperature": 0.5}, - session_params={"reasoning_effort": "high"}, - ) - - # Temperature from URI (overrides config) - assert result.temperature is not None - assert result.temperature.value == 0.5 - assert result.temperature.source == "uri" - - # Reasoning effort from session (overrides config) - assert result.reasoning_effort is not None - assert result.reasoning_effort.value == "high" - assert result.reasoning_effort.source == "session" - - def test_precedence_top_p_all_sources(self, service): - """Test precedence handling for top_p across all sources.""" - result = service.resolve_parameters( - config_params={"top_p": 0.2}, - header_params={"top_p": 0.4}, - uri_params={"top_p": 0.6}, - session_params={"top_p": 0.8}, - ) - - assert result.top_p is not None - assert result.top_p.value == 0.8 - assert result.top_p.source == "session" - - def test_precedence_top_k_uri_overrides(self, service): - """Test precedence for top_k where URI overrides config/header.""" - result = service.resolve_parameters( - config_params={"top_k": 16}, - header_params={"top_k": 24}, - uri_params={"top_k": 32}, - ) - - assert result.top_k is not None - assert result.top_k.value == 32 - assert result.top_k.source == "uri" - - # ======================================================================== - # Source Tracking Tests - # ======================================================================== - - def test_source_tracking_single_source(self, service): - """Test source tracking with a single source.""" - result = service.resolve_parameters(uri_params={"temperature": 0.5}) - - assert result.temperature is not None - assert result.temperature.source == "uri" - - def test_source_tracking_multiple_sources_temperature(self, service): - """Test source tracking for temperature from multiple sources.""" - result = 
service.resolve_parameters( - config_params={"temperature": 0.8}, - header_params={"temperature": 0.6}, - uri_params={"temperature": 0.4}, - ) - - # Should track that URI is the effective source - assert result.temperature.source == "uri" - - def test_source_tracking_multiple_sources_reasoning_effort(self, service): - """Test source tracking for reasoning_effort from multiple sources.""" - result = service.resolve_parameters( - config_params={"reasoning_effort": "low"}, - session_params={"reasoning_effort": "high"}, - ) - - # Should track that session is the effective source - assert result.reasoning_effort.source == "session" - - def test_source_tracking_independent_parameters(self, service): - """Test that source tracking is independent for each parameter.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, - uri_params={"reasoning_effort": "medium"}, - ) - - assert result.temperature.source == "config" - assert result.reasoning_effort.source == "uri" - - def test_source_tracking_top_parameters(self, service): - """Test source tracking for top_p and top_k parameters.""" - result = service.resolve_parameters( - config_params={"top_p": 0.2, "top_k": 16}, - session_params={"top_k": 64}, - uri_params={"top_p": 0.9}, - ) - - assert result.top_p is not None - assert result.top_p.source == "uri" - assert result.top_p.value == 0.9 - assert result.top_k is not None - assert result.top_k.source == "session" - assert result.top_k.value == 64 - - # ======================================================================== - # Debug Output Tests - # ======================================================================== - - def test_debug_output_format_single_parameter(self, service, caplog): - """Test debug output format with a single parameter.""" - with caplog.at_level(logging.DEBUG): - _result = service.resolve_parameters( - uri_params={"temperature": 0.5}, backend="openai:gpt-4" - ) - - assert "Parameter resolution for openai:gpt-4" in caplog.text - assert "temperature: 0.5" in caplog.text - assert "source: uri" in caplog.text - - def test_debug_output_format_multiple_parameters(self, service, caplog): - """Test debug output format with multiple parameters.""" - with caplog.at_level(logging.DEBUG): - _result = service.resolve_parameters( - uri_params={"temperature": 0.5, "reasoning_effort": "high"}, - backend="anthropic:claude", - ) - - assert "Parameter resolution for anthropic:claude" in caplog.text - assert "temperature: 0.5" in caplog.text - assert "reasoning_effort: high" in caplog.text - - def test_debug_output_includes_top_parameters(self, service, caplog): - """Test debug logging includes top_p and top_k values.""" - with caplog.at_level(logging.DEBUG): - _result = service.resolve_parameters( - uri_params={"top_p": 0.9}, - header_params={"top_k": 24}, - backend="test:debug", - ) - - assert "Parameter resolution for test:debug" in caplog.text - assert "top_p: 0.9" in caplog.text - assert "top_k: 24" in caplog.text - - def test_debug_output_shows_overridden_sources(self, service, caplog): - """Test that debug output shows overridden sources.""" - with caplog.at_level(logging.DEBUG): - _result = service.resolve_parameters( - config_params={"temperature": 0.8}, - header_params={"temperature": 0.6}, - uri_params={"temperature": 0.4}, - backend="test:model", - ) - - assert "temperature: 0.4" in caplog.text - assert "source: uri" in caplog.text - assert "overrode:" in caplog.text - assert "config=0.8" in caplog.text - assert "header=0.6" in caplog.text - - def 
test_debug_output_no_overrides(self, service, caplog): - """Test debug output when there are no overrides.""" - with caplog.at_level(logging.DEBUG): - _result = service.resolve_parameters( - uri_params={"temperature": 0.5}, backend="test:model" - ) - - assert "temperature: 0.5" in caplog.text - assert "source: uri" in caplog.text - # Should not contain "overrode:" when there are no overrides - log_lines = [line for line in caplog.text.split("\n") if "temperature" in line] - assert any( - "source: uri" in line and "overrode:" not in line for line in log_lines - ) - - def test_debug_output_empty_parameters(self, service, caplog): - """Test that no debug output is generated for empty parameters.""" - with caplog.at_level(logging.DEBUG): - _result = service.resolve_parameters(backend="test:model") - - # Should not log anything when no parameters are resolved - assert "Parameter resolution" not in caplog.text - - def test_debug_info_structure(self, service): - """Test the structure of debug info returned by get_debug_info.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, - uri_params={"temperature": 0.5, "reasoning_effort": "high"}, - ) - - debug_info = result.get_debug_info() - - assert "temperature" in debug_info - assert "reasoning_effort" in debug_info - assert debug_info["temperature"]["effective_value"] == 0.5 - assert debug_info["temperature"]["source"] == "uri" - assert debug_info["reasoning_effort"]["effective_value"] == "high" - assert debug_info["reasoning_effort"]["source"] == "uri" - - # ======================================================================== - # Missing Sources Tests - # ======================================================================== - - def test_missing_all_sources(self, service): - """Test resolution when all sources are missing.""" - result = service.resolve_parameters() - - assert result.temperature is None - assert result.reasoning_effort is None - assert result.top_p is None - assert result.top_k is None - - def test_missing_session_params(self, service): - """Test resolution when session params are missing.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, - header_params={"temperature": 0.6}, - uri_params={"temperature": 0.4}, - ) - - # Should still resolve correctly without session params - assert result.temperature.value == 0.4 - assert result.temperature.source == "uri" - - def test_missing_uri_params(self, service): - """Test resolution when URI params are missing.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, - header_params={"temperature": 0.6}, - session_params={"temperature": 0.2}, - ) - - # Should still resolve correctly without URI params - assert result.temperature.value == 0.2 - assert result.temperature.source == "session" - - def test_missing_header_params(self, service): - """Test resolution when header params are missing.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, - uri_params={"temperature": 0.4}, - session_params={"temperature": 0.2}, - ) - - # Should still resolve correctly without header params - assert result.temperature.value == 0.2 - assert result.temperature.source == "session" - - def test_missing_config_params(self, service): - """Test resolution when config params are missing.""" - result = service.resolve_parameters( - header_params={"temperature": 0.6}, - uri_params={"temperature": 0.4}, - session_params={"temperature": 0.2}, - ) - - # Should still resolve correctly without config params - 
assert result.temperature.value == 0.2 - assert result.temperature.source == "session" - - def test_missing_multiple_sources(self, service): - """Test resolution when multiple sources are missing.""" - result = service.resolve_parameters(uri_params={"temperature": 0.5}) - - # Should resolve with only URI params - assert result.temperature.value == 0.5 - assert result.temperature.source == "uri" - assert result.reasoning_effort is None - - def test_partial_parameters_across_sources(self, service): - """Test resolution with partial parameters from different sources.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, - uri_params={"reasoning_effort": "high"}, - ) - - assert result.temperature.value == 0.8 - assert result.temperature.source == "config" - assert result.reasoning_effort.value == "high" - assert result.reasoning_effort.source == "uri" - - def test_none_values_treated_as_missing(self, service): - """Test that None values in parameter dicts are treated as missing.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, - uri_params={"temperature": None}, - ) - - # None in URI params should not override config - assert result.temperature.value == 0.8 - assert result.temperature.source == "config" - - # ======================================================================== - # Edge Cases and Special Scenarios - # ======================================================================== - - def test_empty_dict_sources(self, service): - """Test resolution with empty dict sources.""" - result = service.resolve_parameters( - config_params={}, header_params={}, uri_params={}, session_params={} - ) - - assert result.temperature is None - assert result.reasoning_effort is None - - def test_backend_parameter_in_logging(self, service, caplog): - """Test that backend parameter is used in logging.""" - with caplog.at_level(logging.DEBUG): - _result = service.resolve_parameters( - uri_params={"temperature": 0.5}, backend="custom:backend:model" - ) - - assert "custom:backend:model" in caplog.text - - def test_empty_backend_string(self, service, caplog): - """Test resolution with empty backend string.""" - with caplog.at_level(logging.DEBUG): - result = service.resolve_parameters( - uri_params={"temperature": 0.5}, backend="" - ) - - # Should still work, just with empty backend in logs - assert result.temperature.value == 0.5 - - def test_parameter_value_types_preserved(self, service): - """Test that parameter value types are preserved through resolution.""" - result = service.resolve_parameters( - uri_params={"temperature": 0.5, "reasoning_effort": "high"} - ) - - assert isinstance(result.temperature.value, float) - assert isinstance(result.reasoning_effort.value, str) - - def test_resolution_with_all_sources_different_params(self, service): - """Test resolution when each source provides different parameters.""" - result = service.resolve_parameters( - config_params={"temperature": 0.8}, - header_params={"reasoning_effort": "low"}, - uri_params={}, - session_params={}, - ) - - assert result.temperature.value == 0.8 - assert result.temperature.source == "config" - assert result.reasoning_effort.value == "low" - assert result.reasoning_effort.source == "header" - - def test_override_tracking_all_sources(self, service, caplog): - """Test that all overridden sources are tracked in debug output.""" - with caplog.at_level(logging.DEBUG): - _result = service.resolve_parameters( - config_params={"temperature": 0.1}, - header_params={"temperature": 
0.3}, - uri_params={"temperature": 0.5}, - session_params={"temperature": 0.7}, - backend="test:model", - ) - - # Session should be effective, and should show all overridden sources - assert "temperature: 0.7" in caplog.text - assert "source: session" in caplog.text - assert "config=0.1" in caplog.text - assert "header=0.3" in caplog.text - assert "uri=0.5" in caplog.text - - def test_to_dict_excludes_none_values(self, service): - """Test that to_dict excludes None values.""" - result = service.resolve_parameters(uri_params={"temperature": 0.5}) - - result_dict = result.to_dict() - - assert "temperature" in result_dict - assert "reasoning_effort" not in result_dict - - def test_supported_parameters_constant(self, service): - """Test that SUPPORTED_PARAMETERS constant is defined correctly.""" - assert hasattr(service, "SUPPORTED_PARAMETERS") - assert "temperature" in service.SUPPORTED_PARAMETERS - assert "reasoning_effort" in service.SUPPORTED_PARAMETERS - assert "top_p" in service.SUPPORTED_PARAMETERS - assert "top_k" in service.SUPPORTED_PARAMETERS - assert len(service.SUPPORTED_PARAMETERS) == 4 - - # ======================================================================== - # Integration-like Tests - # ======================================================================== - - def test_realistic_scenario_uri_overrides(self, service): - """Test realistic scenario where URI params override config.""" - result = service.resolve_parameters( - config_params={"temperature": 0.7, "reasoning_effort": "medium"}, - uri_params={"temperature": 0.9}, - backend="openai:gpt-4", - ) - - assert result.temperature.value == 0.9 - assert result.temperature.source == "uri" - assert result.reasoning_effort.value == "medium" - assert result.reasoning_effort.source == "config" - - def test_realistic_scenario_session_commands(self, service): - """Test realistic scenario with session commands overriding everything.""" - result = service.resolve_parameters( - config_params={"temperature": 0.7}, - header_params={"temperature": 0.8}, - uri_params={"temperature": 0.9}, - session_params={"temperature": 0.5}, - backend="anthropic:claude-3", - ) - - assert result.temperature.value == 0.5 - assert result.temperature.source == "session" - - def test_realistic_scenario_no_overrides(self, service): - """Test realistic scenario where each source provides unique parameters.""" - result = service.resolve_parameters( - config_params={"reasoning_effort": "low"}, - uri_params={"temperature": 0.6}, - backend="gemini:pro", - ) - - assert result.temperature.value == 0.6 - assert result.temperature.source == "uri" - assert result.reasoning_effort.value == "low" - assert result.reasoning_effort.source == "config" +"""Unit tests for parameter resolution service.""" + +import logging + +import pytest +from src.core.services.parameter_resolution_service import ( + ParameterResolutionService, + ParameterSource, + ResolvedParameters, +) + + +class TestParameterSource: + """Test cases for ParameterSource dataclass.""" + + def test_parameter_source_creation(self): + """Test creating a ParameterSource instance.""" + source = ParameterSource(value=0.5, source="uri") + + assert source.value == 0.5 + assert source.source == "uri" + + def test_parameter_source_repr(self): + """Test ParameterSource string representation.""" + source = ParameterSource(value=0.7, source="header") + repr_str = repr(source) + + assert "ParameterSource" in repr_str + assert "0.7" in repr_str + assert "header" in repr_str + + +class TestResolvedParameters: + """Test 
cases for ResolvedParameters dataclass.""" + + def test_resolved_parameters_creation_empty(self): + """Test creating an empty ResolvedParameters instance.""" + params = ResolvedParameters() + + assert params.temperature is None + assert params.reasoning_effort is None + assert params.top_p is None + assert params.top_k is None + assert params.repetition_penalty is None + assert params.min_p is None + + def test_resolved_parameters_creation_with_values(self): + """Test creating ResolvedParameters with values.""" + temp_source = ParameterSource(value=0.5, source="uri") + effort_source = ParameterSource(value="high", source="session") + top_p_source = ParameterSource(value=0.9, source="config") + top_k_source = ParameterSource(value=42, source="header") + + repetition_source = ParameterSource(value=1.1, source="uri") + min_p_source = ParameterSource(value=0.05, source="session") + + params = ResolvedParameters( + temperature=temp_source, + reasoning_effort=effort_source, + top_p=top_p_source, + top_k=top_k_source, + repetition_penalty=repetition_source, + min_p=min_p_source, + ) + + assert params.temperature == temp_source + assert params.reasoning_effort == effort_source + assert params.top_p == top_p_source + assert params.top_k == top_k_source + assert params.repetition_penalty == repetition_source + assert params.min_p == min_p_source + + def test_to_dict_empty(self): + """Test to_dict with no parameters.""" + params = ResolvedParameters() + result = params.to_dict() + + assert result == {} + + def test_to_dict_with_temperature_only(self): + """Test to_dict with only temperature.""" + params = ResolvedParameters(temperature=ParameterSource(0.5, "uri")) + result = params.to_dict() + + assert result == {"temperature": 0.5} + + def test_to_dict_with_reasoning_effort_only(self): + """Test to_dict with only reasoning_effort.""" + params = ResolvedParameters(reasoning_effort=ParameterSource("high", "session")) + result = params.to_dict() + + assert result == {"reasoning_effort": "high"} + + def test_to_dict_with_both_parameters(self): + """Test to_dict with both parameters.""" + params = ResolvedParameters( + temperature=ParameterSource(0.7, "header"), + reasoning_effort=ParameterSource("medium", "config"), + ) + result = params.to_dict() + + assert result == {"temperature": 0.7, "reasoning_effort": "medium"} + + def test_to_dict_with_top_parameters(self): + """Test to_dict with top_p and top_k parameters.""" + params = ResolvedParameters( + top_p=ParameterSource(0.92, "uri"), + top_k=ParameterSource(32, "session"), + ) + result = params.to_dict() + + assert result == {"top_p": 0.92, "top_k": 32} + + def test_to_dict_with_penalty_parameters(self): + """Test to_dict with repetition_penalty and min_p.""" + params = ResolvedParameters( + repetition_penalty=ParameterSource(1.2, "uri"), + min_p=ParameterSource(0.07, "session"), + ) + result = params.to_dict() + + assert result == {"repetition_penalty": 1.2, "min_p": 0.07} + + def test_get_debug_info_empty(self): + """Test get_debug_info with no parameters.""" + params = ResolvedParameters() + debug_info = params.get_debug_info() + + assert debug_info == {} + + def test_get_debug_info_with_temperature(self): + """Test get_debug_info with temperature.""" + params = ResolvedParameters(temperature=ParameterSource(0.5, "uri")) + debug_info = params.get_debug_info() + + assert "temperature" in debug_info + assert debug_info["temperature"]["effective_value"] == 0.5 + assert debug_info["temperature"]["source"] == "uri" + + def 
test_get_debug_info_with_reasoning_effort(self): + """Test get_debug_info with reasoning_effort.""" + params = ResolvedParameters(reasoning_effort=ParameterSource("high", "session")) + debug_info = params.get_debug_info() + + assert "reasoning_effort" in debug_info + assert debug_info["reasoning_effort"]["effective_value"] == "high" + assert debug_info["reasoning_effort"]["source"] == "session" + + def test_get_debug_info_with_both_parameters(self): + """Test get_debug_info with both parameters.""" + params = ResolvedParameters( + temperature=ParameterSource(0.8, "config"), + reasoning_effort=ParameterSource("low", "header"), + ) + debug_info = params.get_debug_info() + + assert len(debug_info) == 2 + assert debug_info["temperature"]["effective_value"] == 0.8 + assert debug_info["temperature"]["source"] == "config" + assert debug_info["reasoning_effort"]["effective_value"] == "low" + assert debug_info["reasoning_effort"]["source"] == "header" + + def test_get_debug_info_with_top_parameters(self): + """Test get_debug_info includes top_p and top_k entries.""" + params = ResolvedParameters( + top_p=ParameterSource(0.85, "uri"), + top_k=ParameterSource(16, "session"), + ) + debug_info = params.get_debug_info() + + assert "top_p" in debug_info + assert debug_info["top_p"]["effective_value"] == 0.85 + assert debug_info["top_p"]["source"] == "uri" + assert "top_k" in debug_info + assert debug_info["top_k"]["effective_value"] == 16 + assert debug_info["top_k"]["source"] == "session" + + def test_get_debug_info_with_penalty_parameters(self): + """Test get_debug_info includes penalty parameters.""" + params = ResolvedParameters( + repetition_penalty=ParameterSource(1.25, "config"), + min_p=ParameterSource(0.09, "header"), + ) + debug_info = params.get_debug_info() + + assert "repetition_penalty" in debug_info + assert debug_info["repetition_penalty"]["effective_value"] == 1.25 + assert debug_info["repetition_penalty"]["source"] == "config" + assert "min_p" in debug_info + assert debug_info["min_p"]["effective_value"] == 0.09 + assert debug_info["min_p"]["source"] == "header" + + +class TestParameterResolutionService: + """Test cases for ParameterResolutionService.""" + + @pytest.fixture + def service(self): + """Create a service instance for testing.""" + return ParameterResolutionService() + + # ======================================================================== + # Precedence Order Tests (session > uri > header > config) + # ======================================================================== + + def test_precedence_config_only(self, service): + """Test resolution with only config parameters.""" + result = service.resolve_parameters(config_params={"temperature": 0.8}) + + assert result.temperature is not None + assert result.temperature.value == 0.8 + assert result.temperature.source == "config" + + def test_precedence_header_overrides_config(self, service): + """Test that header parameters override config parameters.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, header_params={"temperature": 0.6} + ) + + assert result.temperature is not None + assert result.temperature.value == 0.6 + assert result.temperature.source == "header" + + def test_precedence_uri_overrides_header_and_config(self, service): + """Test that URI parameters override header and config parameters.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + header_params={"temperature": 0.6}, + uri_params={"temperature": 0.4}, + ) + + assert result.temperature is 
not None + assert result.temperature.value == 0.4 + assert result.temperature.source == "uri" + + def test_precedence_session_overrides_all(self, service): + """Test that session parameters override all other sources.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + header_params={"temperature": 0.6}, + uri_params={"temperature": 0.4}, + session_params={"temperature": 0.2}, + ) + + assert result.temperature is not None + assert result.temperature.value == 0.2 + assert result.temperature.source == "session" + + def test_precedence_reasoning_effort_config_only(self, service): + """Test reasoning_effort resolution with only config.""" + result = service.resolve_parameters(config_params={"reasoning_effort": "low"}) + + assert result.reasoning_effort is not None + assert result.reasoning_effort.value == "low" + assert result.reasoning_effort.source == "config" + + def test_precedence_reasoning_effort_header_overrides_config(self, service): + """Test that header reasoning_effort overrides config.""" + result = service.resolve_parameters( + config_params={"reasoning_effort": "low"}, + header_params={"reasoning_effort": "medium"}, + ) + + assert result.reasoning_effort is not None + assert result.reasoning_effort.value == "medium" + assert result.reasoning_effort.source == "header" + + def test_precedence_reasoning_effort_uri_overrides_header(self, service): + """Test that URI reasoning_effort overrides header and config.""" + result = service.resolve_parameters( + config_params={"reasoning_effort": "low"}, + header_params={"reasoning_effort": "medium"}, + uri_params={"reasoning_effort": "high"}, + ) + + assert result.reasoning_effort is not None + assert result.reasoning_effort.value == "high" + assert result.reasoning_effort.source == "uri" + + def test_precedence_reasoning_effort_session_overrides_all(self, service): + """Test that session reasoning_effort overrides all sources.""" + result = service.resolve_parameters( + config_params={"reasoning_effort": "low"}, + header_params={"reasoning_effort": "medium"}, + uri_params={"reasoning_effort": "high"}, + session_params={"reasoning_effort": "low"}, + ) + + assert result.reasoning_effort is not None + assert result.reasoning_effort.value == "low" + assert result.reasoning_effort.source == "session" + + def test_precedence_repetition_penalty_uri(self, service): + """Test repetition_penalty resolution from URI.""" + result = service.resolve_parameters( + config_params={"repetition_penalty": 1.05}, + uri_params={"repetition_penalty": 1.2}, + ) + + assert result.repetition_penalty is not None + assert result.repetition_penalty.value == 1.2 + assert result.repetition_penalty.source == "uri" + + def test_precedence_min_p_session(self, service): + """Test min_p resolution with session override.""" + result = service.resolve_parameters( + config_params={"min_p": 0.1}, + session_params={"min_p": 0.2}, + ) + + assert result.min_p is not None + assert result.min_p.value == 0.2 + assert result.min_p.source == "session" + + def test_precedence_mixed_parameters_different_sources(self, service): + """Test precedence with different parameters from different sources.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8, "reasoning_effort": "low"}, + uri_params={"temperature": 0.5}, + session_params={"reasoning_effort": "high"}, + ) + + # Temperature from URI (overrides config) + assert result.temperature is not None + assert result.temperature.value == 0.5 + assert result.temperature.source == "uri" + + # 
Reasoning effort from session (overrides config) + assert result.reasoning_effort is not None + assert result.reasoning_effort.value == "high" + assert result.reasoning_effort.source == "session" + + def test_precedence_top_p_all_sources(self, service): + """Test precedence handling for top_p across all sources.""" + result = service.resolve_parameters( + config_params={"top_p": 0.2}, + header_params={"top_p": 0.4}, + uri_params={"top_p": 0.6}, + session_params={"top_p": 0.8}, + ) + + assert result.top_p is not None + assert result.top_p.value == 0.8 + assert result.top_p.source == "session" + + def test_precedence_top_k_uri_overrides(self, service): + """Test precedence for top_k where URI overrides config/header.""" + result = service.resolve_parameters( + config_params={"top_k": 16}, + header_params={"top_k": 24}, + uri_params={"top_k": 32}, + ) + + assert result.top_k is not None + assert result.top_k.value == 32 + assert result.top_k.source == "uri" + + # ======================================================================== + # Source Tracking Tests + # ======================================================================== + + def test_source_tracking_single_source(self, service): + """Test source tracking with a single source.""" + result = service.resolve_parameters(uri_params={"temperature": 0.5}) + + assert result.temperature is not None + assert result.temperature.source == "uri" + + def test_source_tracking_multiple_sources_temperature(self, service): + """Test source tracking for temperature from multiple sources.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + header_params={"temperature": 0.6}, + uri_params={"temperature": 0.4}, + ) + + # Should track that URI is the effective source + assert result.temperature.source == "uri" + + def test_source_tracking_multiple_sources_reasoning_effort(self, service): + """Test source tracking for reasoning_effort from multiple sources.""" + result = service.resolve_parameters( + config_params={"reasoning_effort": "low"}, + session_params={"reasoning_effort": "high"}, + ) + + # Should track that session is the effective source + assert result.reasoning_effort.source == "session" + + def test_source_tracking_independent_parameters(self, service): + """Test that source tracking is independent for each parameter.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + uri_params={"reasoning_effort": "medium"}, + ) + + assert result.temperature.source == "config" + assert result.reasoning_effort.source == "uri" + + def test_source_tracking_top_parameters(self, service): + """Test source tracking for top_p and top_k parameters.""" + result = service.resolve_parameters( + config_params={"top_p": 0.2, "top_k": 16}, + session_params={"top_k": 64}, + uri_params={"top_p": 0.9}, + ) + + assert result.top_p is not None + assert result.top_p.source == "uri" + assert result.top_p.value == 0.9 + assert result.top_k is not None + assert result.top_k.source == "session" + assert result.top_k.value == 64 + + # ======================================================================== + # Debug Output Tests + # ======================================================================== + + def test_debug_output_format_single_parameter(self, service, caplog): + """Test debug output format with a single parameter.""" + with caplog.at_level(logging.DEBUG): + _result = service.resolve_parameters( + uri_params={"temperature": 0.5}, backend="openai:gpt-4" + ) + + assert "Parameter resolution for 
openai:gpt-4" in caplog.text + assert "temperature: 0.5" in caplog.text + assert "source: uri" in caplog.text + + def test_debug_output_format_multiple_parameters(self, service, caplog): + """Test debug output format with multiple parameters.""" + with caplog.at_level(logging.DEBUG): + _result = service.resolve_parameters( + uri_params={"temperature": 0.5, "reasoning_effort": "high"}, + backend="anthropic:claude", + ) + + assert "Parameter resolution for anthropic:claude" in caplog.text + assert "temperature: 0.5" in caplog.text + assert "reasoning_effort: high" in caplog.text + + def test_debug_output_includes_top_parameters(self, service, caplog): + """Test debug logging includes top_p and top_k values.""" + with caplog.at_level(logging.DEBUG): + _result = service.resolve_parameters( + uri_params={"top_p": 0.9}, + header_params={"top_k": 24}, + backend="test:debug", + ) + + assert "Parameter resolution for test:debug" in caplog.text + assert "top_p: 0.9" in caplog.text + assert "top_k: 24" in caplog.text + + def test_debug_output_shows_overridden_sources(self, service, caplog): + """Test that debug output shows overridden sources.""" + with caplog.at_level(logging.DEBUG): + _result = service.resolve_parameters( + config_params={"temperature": 0.8}, + header_params={"temperature": 0.6}, + uri_params={"temperature": 0.4}, + backend="test:model", + ) + + assert "temperature: 0.4" in caplog.text + assert "source: uri" in caplog.text + assert "overrode:" in caplog.text + assert "config=0.8" in caplog.text + assert "header=0.6" in caplog.text + + def test_debug_output_no_overrides(self, service, caplog): + """Test debug output when there are no overrides.""" + with caplog.at_level(logging.DEBUG): + _result = service.resolve_parameters( + uri_params={"temperature": 0.5}, backend="test:model" + ) + + assert "temperature: 0.5" in caplog.text + assert "source: uri" in caplog.text + # Should not contain "overrode:" when there are no overrides + log_lines = [line for line in caplog.text.split("\n") if "temperature" in line] + assert any( + "source: uri" in line and "overrode:" not in line for line in log_lines + ) + + def test_debug_output_empty_parameters(self, service, caplog): + """Test that no debug output is generated for empty parameters.""" + with caplog.at_level(logging.DEBUG): + _result = service.resolve_parameters(backend="test:model") + + # Should not log anything when no parameters are resolved + assert "Parameter resolution" not in caplog.text + + def test_debug_info_structure(self, service): + """Test the structure of debug info returned by get_debug_info.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + uri_params={"temperature": 0.5, "reasoning_effort": "high"}, + ) + + debug_info = result.get_debug_info() + + assert "temperature" in debug_info + assert "reasoning_effort" in debug_info + assert debug_info["temperature"]["effective_value"] == 0.5 + assert debug_info["temperature"]["source"] == "uri" + assert debug_info["reasoning_effort"]["effective_value"] == "high" + assert debug_info["reasoning_effort"]["source"] == "uri" + + # ======================================================================== + # Missing Sources Tests + # ======================================================================== + + def test_missing_all_sources(self, service): + """Test resolution when all sources are missing.""" + result = service.resolve_parameters() + + assert result.temperature is None + assert result.reasoning_effort is None + assert result.top_p is None + 
assert result.top_k is None + + def test_missing_session_params(self, service): + """Test resolution when session params are missing.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + header_params={"temperature": 0.6}, + uri_params={"temperature": 0.4}, + ) + + # Should still resolve correctly without session params + assert result.temperature.value == 0.4 + assert result.temperature.source == "uri" + + def test_missing_uri_params(self, service): + """Test resolution when URI params are missing.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + header_params={"temperature": 0.6}, + session_params={"temperature": 0.2}, + ) + + # Should still resolve correctly without URI params + assert result.temperature.value == 0.2 + assert result.temperature.source == "session" + + def test_missing_header_params(self, service): + """Test resolution when header params are missing.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + uri_params={"temperature": 0.4}, + session_params={"temperature": 0.2}, + ) + + # Should still resolve correctly without header params + assert result.temperature.value == 0.2 + assert result.temperature.source == "session" + + def test_missing_config_params(self, service): + """Test resolution when config params are missing.""" + result = service.resolve_parameters( + header_params={"temperature": 0.6}, + uri_params={"temperature": 0.4}, + session_params={"temperature": 0.2}, + ) + + # Should still resolve correctly without config params + assert result.temperature.value == 0.2 + assert result.temperature.source == "session" + + def test_missing_multiple_sources(self, service): + """Test resolution when multiple sources are missing.""" + result = service.resolve_parameters(uri_params={"temperature": 0.5}) + + # Should resolve with only URI params + assert result.temperature.value == 0.5 + assert result.temperature.source == "uri" + assert result.reasoning_effort is None + + def test_partial_parameters_across_sources(self, service): + """Test resolution with partial parameters from different sources.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + uri_params={"reasoning_effort": "high"}, + ) + + assert result.temperature.value == 0.8 + assert result.temperature.source == "config" + assert result.reasoning_effort.value == "high" + assert result.reasoning_effort.source == "uri" + + def test_none_values_treated_as_missing(self, service): + """Test that None values in parameter dicts are treated as missing.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + uri_params={"temperature": None}, + ) + + # None in URI params should not override config + assert result.temperature.value == 0.8 + assert result.temperature.source == "config" + + # ======================================================================== + # Edge Cases and Special Scenarios + # ======================================================================== + + def test_empty_dict_sources(self, service): + """Test resolution with empty dict sources.""" + result = service.resolve_parameters( + config_params={}, header_params={}, uri_params={}, session_params={} + ) + + assert result.temperature is None + assert result.reasoning_effort is None + + def test_backend_parameter_in_logging(self, service, caplog): + """Test that backend parameter is used in logging.""" + with caplog.at_level(logging.DEBUG): + _result = service.resolve_parameters( + 
uri_params={"temperature": 0.5}, backend="custom:backend:model" + ) + + assert "custom:backend:model" in caplog.text + + def test_empty_backend_string(self, service, caplog): + """Test resolution with empty backend string.""" + with caplog.at_level(logging.DEBUG): + result = service.resolve_parameters( + uri_params={"temperature": 0.5}, backend="" + ) + + # Should still work, just with empty backend in logs + assert result.temperature.value == 0.5 + + def test_parameter_value_types_preserved(self, service): + """Test that parameter value types are preserved through resolution.""" + result = service.resolve_parameters( + uri_params={"temperature": 0.5, "reasoning_effort": "high"} + ) + + assert isinstance(result.temperature.value, float) + assert isinstance(result.reasoning_effort.value, str) + + def test_resolution_with_all_sources_different_params(self, service): + """Test resolution when each source provides different parameters.""" + result = service.resolve_parameters( + config_params={"temperature": 0.8}, + header_params={"reasoning_effort": "low"}, + uri_params={}, + session_params={}, + ) + + assert result.temperature.value == 0.8 + assert result.temperature.source == "config" + assert result.reasoning_effort.value == "low" + assert result.reasoning_effort.source == "header" + + def test_override_tracking_all_sources(self, service, caplog): + """Test that all overridden sources are tracked in debug output.""" + with caplog.at_level(logging.DEBUG): + _result = service.resolve_parameters( + config_params={"temperature": 0.1}, + header_params={"temperature": 0.3}, + uri_params={"temperature": 0.5}, + session_params={"temperature": 0.7}, + backend="test:model", + ) + + # Session should be effective, and should show all overridden sources + assert "temperature: 0.7" in caplog.text + assert "source: session" in caplog.text + assert "config=0.1" in caplog.text + assert "header=0.3" in caplog.text + assert "uri=0.5" in caplog.text + + def test_to_dict_excludes_none_values(self, service): + """Test that to_dict excludes None values.""" + result = service.resolve_parameters(uri_params={"temperature": 0.5}) + + result_dict = result.to_dict() + + assert "temperature" in result_dict + assert "reasoning_effort" not in result_dict + + def test_supported_parameters_constant(self, service): + """Test that SUPPORTED_PARAMETERS constant is defined correctly.""" + assert hasattr(service, "SUPPORTED_PARAMETERS") + assert "temperature" in service.SUPPORTED_PARAMETERS + assert "reasoning_effort" in service.SUPPORTED_PARAMETERS + assert "top_p" in service.SUPPORTED_PARAMETERS + assert "top_k" in service.SUPPORTED_PARAMETERS + assert "repetition_penalty" in service.SUPPORTED_PARAMETERS + assert "min_p" in service.SUPPORTED_PARAMETERS + assert len(service.SUPPORTED_PARAMETERS) == 6 + + # ======================================================================== + # Integration-like Tests + # ======================================================================== + + def test_realistic_scenario_uri_overrides(self, service): + """Test realistic scenario where URI params override config.""" + result = service.resolve_parameters( + config_params={"temperature": 0.7, "reasoning_effort": "medium"}, + uri_params={"temperature": 0.9}, + backend="openai:gpt-4", + ) + + assert result.temperature.value == 0.9 + assert result.temperature.source == "uri" + assert result.reasoning_effort.value == "medium" + assert result.reasoning_effort.source == "config" + + def test_realistic_scenario_session_commands(self, service): + 
"""Test realistic scenario with session commands overriding everything.""" + result = service.resolve_parameters( + config_params={"temperature": 0.7}, + header_params={"temperature": 0.8}, + uri_params={"temperature": 0.9}, + session_params={"temperature": 0.5}, + backend="anthropic:claude-3", + ) + + assert result.temperature.value == 0.5 + assert result.temperature.source == "session" + + def test_realistic_scenario_no_overrides(self, service): + """Test realistic scenario where each source provides unique parameters.""" + result = service.resolve_parameters( + config_params={"reasoning_effort": "low"}, + uri_params={"temperature": 0.6}, + backend="gemini:pro", + ) + + assert result.temperature.value == 0.6 + assert result.temperature.source == "uri" + assert result.reasoning_effort.value == "low" + assert result.reasoning_effort.source == "config"