diff --git a/src/core/config/app_config.py b/src/core/config/app_config.py index 4131d6400..6094967eb 100644 --- a/src/core/config/app_config.py +++ b/src/core/config/app_config.py @@ -1,2007 +1,2009 @@ -from __future__ import annotations - -import json -import logging -import os -from collections.abc import Callable, Mapping -from enum import Enum -from pathlib import Path -from typing import Any, cast - -from pydantic import ConfigDict, Field, field_validator, model_validator - -from src.core.config.parameter_resolution import ParameterResolution, ParameterSource - - -def get_openrouter_headers(cfg: dict[str, str], api_key: str) -> dict[str, str]: - """Construct headers for OpenRouter requests. - - Be tolerant of minimal cfg dicts provided by tests by falling back to - sensible defaults when optional keys are absent. - """ - referer: str = cfg.get("app_site_url", "http://localhost:8000") - x_title: str = cfg.get("app_x_title", "InterceptorProxy") - return { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "HTTP-Referer": referer, - "X-Title": x_title, - } - - -def _collect_api_keys_from_env( - base_name: str, - env: Mapping[str, str], - resolution: ParameterResolution | None = None, -) -> dict[str, str]: - """Collect API keys from environment variables with parameter resolution tracking.""" - single_key = env.get(base_name) - numbered_keys: dict[str, str] = {} - numbered_key_names = [] - for i in range(1, 21): - key_name = f"{base_name}_{i}" - key = env.get(key_name) - if key: - numbered_keys[key_name] = key - numbered_key_names.append(key_name) - - if single_key and numbered_keys: - logger.warning( - "Both %s and %s_ environment variables are set. Prioritizing %s_ and ignoring %s.", - base_name, - base_name, - base_name, - base_name, - ) - if resolution is not None: - resolution.record( - f"backends.{base_name.lower().replace('_', '')}.api_key", - list(numbered_keys.values()), - ParameterSource.ENVIRONMENT, - origin=",".join(numbered_key_names), - ) - return numbered_keys - - if single_key: - result = {base_name: single_key} - if resolution is not None: - resolution.record( - f"backends.{base_name.lower().replace('_', '')}.api_key", - list(result.values()), - ParameterSource.ENVIRONMENT, - origin=base_name, - ) - return result - - if resolution is not None and numbered_keys: - resolution.record( - f"backends.{base_name.lower().replace('_', '')}.api_key", - list(numbered_keys.values()), - ParameterSource.ENVIRONMENT, - origin=",".join(numbered_key_names), - ) - return numbered_keys - - -from src.core.domain.configuration.app_identity_config import AppIdentityConfig -from src.core.domain.configuration.assessment_config import AssessmentConfig -from src.core.domain.configuration.header_config import ( - HeaderConfig, - HeaderOverrideMode, -) -from src.core.domain.configuration.reasoning_aliases_config import ( - ReasoningAliasesConfig, -) -from src.core.domain.configuration.sandboxing_config import SandboxingConfiguration -from src.core.interfaces.configuration_interface import IConfig -from src.core.interfaces.model_bases import DomainModel - -# Note: Avoid self-imports to prevent circular dependencies. Classes are defined below. 
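# Illustrative sketch of the precedence implemented by
# _collect_api_keys_from_env above: numbered variants win over the bare name.
#
#     env = {
#         "OPENROUTER_API_KEY": "bare",
#         "OPENROUTER_API_KEY_1": "first",
#         "OPENROUTER_API_KEY_2": "second",
#     }
#     _collect_api_keys_from_env("OPENROUTER_API_KEY", env)
#     # -> {"OPENROUTER_API_KEY_1": "first", "OPENROUTER_API_KEY_2": "second"}
#     # (a warning is logged and the bare OPENROUTER_API_KEY is ignored)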
- -logger = logging.getLogger(__name__) - - -def _process_api_keys(keys_string: str) -> list[str]: - """Process a comma-separated string of API keys.""" - keys = keys_string.split(",") - result: list[str] = [] - for key in keys: - stripped_key = key.strip() - if stripped_key: - result.append(stripped_key) - return result - - -def _get_api_keys_from_env( - env: Mapping[str, str], resolution: ParameterResolution | None = None -) -> list[str]: - """Get API keys from environment variables.""" - result: list[str] = [] - - # Get API keys from API_KEYS environment variable - api_keys_raw: str | None = env.get("API_KEYS") - if api_keys_raw and isinstance(api_keys_raw, str): - result.extend(_process_api_keys(api_keys_raw)) - - if result and resolution is not None: - resolution.record( - "auth.api_keys", - result, - ParameterSource.ENVIRONMENT, - origin="API_KEYS", - ) - - return result - - -def _env_to_bool( - name: str, - default: bool, - env: Mapping[str, str], - *, - path: str | None = None, - resolution: ParameterResolution | None = None, -) -> bool: - """Return an environment variable parsed as a boolean flag.""" - value = env.get(name) - if value is None: - return default - result = value.strip().lower() in {"1", "true", "yes", "on"} - if resolution is not None and path is not None: - resolution.record(path, result, ParameterSource.ENVIRONMENT, origin=name) - return result - - -def _env_to_int( - name: str, - default: int, - env: Mapping[str, str], - *, - path: str | None = None, - resolution: ParameterResolution | None = None, -) -> int: - """Return an environment variable parsed as an integer.""" - value = env.get(name) - if value is None: - return default - try: - result = int(value) - except (TypeError, ValueError): - result = default - if resolution is not None and path is not None and value is not None: - resolution.record(path, result, ParameterSource.ENVIRONMENT, origin=name) - return result - - -def _env_to_float( - name: str, - default: float, - env: Mapping[str, str], - *, - path: str | None = None, - resolution: ParameterResolution | None = None, -) -> float: - """Return an environment variable parsed as a float.""" - value = env.get(name) - if value is None: - return default - try: - result = float(value) - except (TypeError, ValueError): - result = default - if resolution is not None and path is not None and value is not None: - resolution.record(path, result, ParameterSource.ENVIRONMENT, origin=name) - return result - - -def _get_env_value( - env: Mapping[str, str], - name: str, - default: Any, - *, - path: str, - resolution: ParameterResolution | None = None, - transform: Callable[[str], Any] | None = None, -) -> Any: - """Return an environment variable value and optionally record its source.""" - - if name in env: - raw_value = env[name] - value = transform(raw_value) if transform is not None else raw_value - if resolution is not None: - resolution.record(path, value, ParameterSource.ENVIRONMENT, origin=name) - return value - return default - - -def _to_int(value: str, fallback: int) -> int: - try: - return int(value) - except (TypeError, ValueError): - return fallback - - -def _to_float(value: str, fallback: float | None) -> float | None: - try: - return float(value) - except (TypeError, ValueError): - return fallback - - -class LogLevel(str, Enum): - """Log levels for configuration.""" - - DEBUG = "DEBUG" - INFO = "INFO" - WARNING = "WARNING" - ERROR = "ERROR" - CRITICAL = "CRITICAL" - - -class BackendConfig(DomainModel): - """Configuration for a backend service.""" - - 
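    # Illustrative behaviour of the validators declared below:
    # BackendConfig(api_key="sk-one").api_key == ["sk-one"], while api_url
    # values that do not start with http:// or https:// fail validation.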
model_config = ConfigDict(frozen=True) - - api_key: list[str] = Field(default_factory=list) - api_url: str | None = None - models: list[str] = Field(default_factory=list) - timeout: int = 120 # seconds - identity: AppIdentityConfig | None = None - extra: dict[str, Any] = Field(default_factory=dict) - - @field_validator("api_key", mode="before") - @classmethod - def validate_api_key(cls, v: Any) -> list[str]: - """Ensure api_key is always a list.""" - if isinstance(v, str): - return [v] - return v if isinstance(v, list) else [] - - @field_validator("api_url") - @classmethod - def validate_api_url(cls, v: str | None) -> str | None: - """Validate the API URL if provided.""" - if v is not None and not v.startswith(("http://", "https://")): - raise ValueError("API URL must start with http:// or https://") - return v - - -class AuthConfig(DomainModel): - """Authentication configuration.""" - - model_config = ConfigDict(frozen=True) - - disable_auth: bool = False - api_keys: list[str] = Field(default_factory=list) - auth_token: str | None = None - redact_api_keys_in_prompts: bool = True - trusted_ips: list[str] = Field(default_factory=list) - brute_force_protection: BruteForceProtectionConfig = Field( - default_factory=lambda: BruteForceProtectionConfig() - ) - - -class BruteForceProtectionConfig(DomainModel): - """Configuration for brute-force protection on API authentication.""" - - model_config = ConfigDict(frozen=True) - - enabled: bool = True - max_failed_attempts: int = 5 - ttl_seconds: int = 900 - initial_block_seconds: int = 30 - block_multiplier: float = 2.0 - max_block_seconds: int = 3600 - - -class LoggingConfig(DomainModel): - """Logging configuration.""" - - model_config = ConfigDict(frozen=True) - - level: LogLevel = LogLevel.INFO - request_logging: bool = False - response_logging: bool = False - log_file: str | None = None - # Optional separate wire-capture log file; when set, all outbound requests - # and inbound replies/SSE payloads are captured verbatim to this file. - capture_file: str | None = None - # Optional max size in bytes; when exceeded, rotate current capture to - # `.1` and start a new file (overwrite existing .1). - capture_max_bytes: int | None = None - # Optional per-chunk truncation size in bytes for streaming capture. When - # set, stream chunks written to capture are truncated to this size with a - # short marker appended; streaming to client remains unmodified. - capture_truncate_bytes: int | None = None - # Optional number of rotated files to keep (e.g., file.1..file.N). If not - # set or <= 0, keeps a single rotation (file.1). Used only when - # capture_max_bytes is set. - capture_max_files: int | None = None - # Time-based rotation period in seconds (default 1 day). If set <= 0, time - # rotation is disabled. - capture_rotate_interval_seconds: int = 86400 - # Total disk cap across current capture file and rotated files. If set <= 0, - # disabled. Default is 100 MiB. - capture_total_max_bytes: int = 104857600 - # Buffer size for wire capture writes (bytes). Default 64KB. - capture_buffer_size: int = 65536 - # How often to flush buffer to disk (seconds). Default 1.0 second. - capture_flush_interval: float = 1.0 - # Maximum entries to buffer before forcing flush. Default 100. - capture_max_entries_per_flush: int = 100 - - -class ToolCallReactorConfig(DomainModel): - """Configuration for the Tool Call Reactor system. 
- - The Tool Call Reactor provides event-driven reactions to tool calls - from LLMs, allowing custom handlers to monitor, modify, or replace responses. - """ - - model_config = ConfigDict(frozen=True) - - enabled: bool = True - """Whether the Tool Call Reactor is enabled.""" - - apply_diff_steering_enabled: bool = True - """Whether the legacy apply_diff steering handler is enabled.""" - - apply_diff_steering_rate_limit_seconds: int = 60 - """Legacy rate limit window for apply_diff steering in seconds. - - Controls how often steering messages are shown for apply_diff tool calls - within the same session. Default: 60 seconds (1 message per minute). - """ - - apply_diff_steering_message: str | None = None - """Legacy custom steering message for apply_diff tool calls. - - If None, uses the default message. Can be customized to fit your workflow. - """ - - pytest_full_suite_steering_enabled: bool = False - """Whether steering for full pytest suite commands is enabled.""" - - pytest_full_suite_steering_message: str | None = None - """Optional custom steering message when detecting full pytest suite runs.""" - - pytest_context_saving_enabled: bool = False - """Whether pytest context-saving command rewrites are enabled.""" - - fix_think_tags_enabled: bool = False - """Whether correction of improperly formatted tags is enabled.""" - - # New: fully configurable steering rules - steering_rules: list[dict[str, Any]] = Field(default_factory=list) - """Configurable steering rules. - - Each rule is a dict describing when to trigger steering and what message to - return. See README for details. Minimal fields: - - name: Unique rule name - - enabled: bool - - triggers: { tool_names: [..], phrases: [..] } - - message: Replacement content when swallowed - - rate_limit: { calls_per_window: int, window_seconds: int } - - priority: int (optional; higher runs first) - """ - - # Tool access control policies - access_policies: list[dict[str, Any]] = Field(default_factory=list) - """Tool access control policies. - - Each policy defines which tools are allowed or blocked for specific models/agents. 
- Minimal fields: - - name: Unique policy identifier - - model_pattern: Regex pattern for matching model names - - default_policy: "allow" or "deny" - Optional fields: - - agent_pattern: Regex pattern for matching agents - - allowed_patterns: List of regex patterns for allowed tools - - blocked_patterns: List of regex patterns for blocked tools - - block_message: Message to return when blocking a tool call - - priority: int (higher values take precedence) - """ - - -class PlanningPhaseConfig(DomainModel): - """Configuration for planning phase model routing.""" - - model_config = ConfigDict(frozen=True) - - enabled: bool = False - strong_model: str | None = None - max_turns: int = 10 - max_file_writes: int = 1 - # Optional parameter overrides for the strong model - overrides: dict[str, Any] | None = None - - -class SessionContinuityConfig(DomainModel): - """Configuration for intelligent session continuity detection.""" - - model_config = ConfigDict(frozen=True) - - enabled: bool = True - fuzzy_matching: bool = True - max_session_age_seconds: int = 604800 # 7 days - fingerprint_message_count: int = 5 - client_key_includes_ip: bool = True - - -class SessionConfig(DomainModel): - """Session management configuration.""" - - model_config = ConfigDict(frozen=True) - - cleanup_enabled: bool = True - cleanup_interval: int = 3600 # 1 hour - max_age: int = 86400 # 1 day - default_interactive_mode: bool = True - force_set_project: bool = False - disable_interactive_commands: bool = False - project_dir_resolution_model: str | None = None - project_dir_resolution_mode: str = "hybrid" - tool_call_repair_enabled: bool = True - # Max per-session buffer for tool-call repair streaming (bytes) - tool_call_repair_buffer_cap_bytes: int = 64 * 1024 - json_repair_enabled: bool = True - # Max per-session buffer for JSON repair streaming (bytes) - json_repair_buffer_cap_bytes: int = 64 * 1024 - json_repair_strict_mode: bool = False - json_repair_schema: dict[str, Any] | None = None # Added - tool_call_reactor: ToolCallReactorConfig = Field( - default_factory=ToolCallReactorConfig - ) - dangerous_command_prevention_enabled: bool = True - dangerous_command_steering_message: str | None = None - pytest_compression_enabled: bool = True - pytest_compression_min_lines: int = 30 - pytest_full_suite_steering_enabled: bool | None = None - pytest_full_suite_steering_message: str | None = None - fix_think_tags_enabled: bool = False - fix_think_tags_streaming_buffer_size: int = 4096 - planning_phase: PlanningPhaseConfig = Field(default_factory=PlanningPhaseConfig) - max_per_session_backends: int = 32 - session_continuity: SessionContinuityConfig = Field( - default_factory=SessionContinuityConfig - ) - tool_access_global_overrides: dict[str, Any] | None = None - # Tool call processing behavior configuration - force_reprocess_tool_calls: bool = False - log_skipped_tool_calls: bool = False - - @model_validator(mode="before") - @classmethod - def _sync_pytest_full_suite_settings(cls, values: dict[str, Any]) -> dict[str, Any]: - """Keep pytest full-suite steering settings mirrored with reactor config.""" - reactor_config = values.get("tool_call_reactor") - - # Convert to dict if it's already a ToolCallReactorConfig instance - if isinstance(reactor_config, ToolCallReactorConfig): - reactor_config_dict = reactor_config.model_dump() - elif isinstance(reactor_config, dict): - reactor_config_dict = dict(reactor_config) - else: - reactor_config_dict = {} - - enabled = values.get("pytest_full_suite_steering_enabled") - message = 
values.get("pytest_full_suite_steering_message") - - # Update the dict instead of mutating frozen model - if enabled is not None: - reactor_config_dict["pytest_full_suite_steering_enabled"] = enabled - else: - values["pytest_full_suite_steering_enabled"] = reactor_config_dict.get( - "pytest_full_suite_steering_enabled", False - ) - - if message is not None: - reactor_config_dict["pytest_full_suite_steering_message"] = message - else: - values["pytest_full_suite_steering_message"] = reactor_config_dict.get( - "pytest_full_suite_steering_message" - ) - - # Store the dict - Pydantic will convert it to ToolCallReactorConfig - values["tool_call_reactor"] = reactor_config_dict - return values - - -class EmptyResponseConfig(DomainModel): - """Configuration for empty response handling.""" - - model_config = ConfigDict(frozen=True) - - enabled: bool = True - """Whether the empty response recovery is enabled.""" - - max_retries: int = 1 - """Maximum number of retries for empty responses.""" - - -class ModelAliasRule(DomainModel): - """A rule for rewriting a model name.""" - - model_config = ConfigDict(frozen=True) - - pattern: str - replacement: str - - -class RewritingConfig(DomainModel): - """Configuration for content rewriting.""" - - model_config = ConfigDict(frozen=True) - - enabled: bool = False - config_path: str = "config/replacements" - - -class EditPrecisionConfig(DomainModel): - """Configuration for automated edit-precision tuning. - - When enabled, detects agent edit-failure prompts and lowers sampling - parameters for the next single call to improve precision. - """ - - model_config = ConfigDict(frozen=True) - - enabled: bool = True - temperature: float = 0.1 - # Only applied if override_top_p is True; otherwise top_p remains unchanged - min_top_p: float | None = 0.3 - # Control whether top_p/top_k are overridden by this feature - override_top_p: bool = False - override_top_k: bool = False - # Target top_k to apply when override_top_k is True (for providers that support it, e.g., Gemini) - target_top_k: int | None = None - # Optional regex pattern; when set, agents with names matching this pattern - # will be excluded (feature disabled) even if enabled=True. - exclude_agents_regex: str | None = None - - -from src.core.services.backend_registry import ( - backend_registry, # Updated import path -) - - -class BackendSettings(DomainModel): - """Settings for all backends. - - Note: This class is intentionally not frozen because it needs to support - dynamic backend configurations that are added at runtime. Backend configs - are stored in __dict__ to allow attribute-style access (e.g., config.backends.openai) - without pre-defining all possible backends as fields. 
- """ - - model_config = ConfigDict(frozen=False, extra="allow") - - default_backend: str = "openai" - static_route: str | None = ( - None # Force all requests to backend:model (e.g., "gemini-oauth-plan:gemini-2.5-pro") - ) - disable_gemini_oauth_fallback: bool = False - disable_hybrid_backend: bool = False - hybrid_backend_repeat_messages: bool = False - reasoning_injection_probability: float = Field( - default=1.0, - ge=0.0, - le=1.0, - description="Probability of using the reasoning model for a request in the hybrid backend.", - ) - - def __init__(self, **data: Any) -> None: - # Separate standard fields from backend-specific configs - known_fields = set(self.model_fields.keys()) - - init_data = {k: v for k, v in data.items() if k in known_fields} - backend_data = {k: v for k, v in data.items() if k not in known_fields} - - # Initialize the model with standard fields - super().__init__(**init_data) - - # Manually set the backend configurations - for backend_name, config_data in backend_data.items(): - if isinstance(config_data, dict): - self.__dict__[backend_name] = BackendConfig(**config_data) - elif isinstance(config_data, BackendConfig): - self.__dict__[backend_name] = config_data - - # Ensure all registered backends have a config - for backend_name in backend_registry.get_registered_backends(): - if backend_name not in self.__dict__: - self.__dict__[backend_name] = BackendConfig() - - self._initialization_complete = True - - def __getitem__(self, key: str) -> BackendConfig: - """Allow dictionary-style access to backend configs.""" - if key in self.__dict__: - return cast(BackendConfig, self.__dict__[key]) - raise KeyError(f"Backend '{key}' not found") - - def __setitem__(self, key: str, value: BackendConfig) -> None: - """Allow dictionary-style setting of backend configs.""" - self.__dict__[key] = value - - def __setattr__(self, name: str, value: Any) -> None: - """Allow attribute-style assignment for backend configs.""" - if ( - name in {"default_backend"} - or name.startswith("_") - or name in self.model_fields - ): - super().__setattr__(name, value) - return - if isinstance(value, BackendConfig): - config = value - elif isinstance(value, dict): - config = BackendConfig(**value) - else: - config = BackendConfig() - self.__dict__[name] = config - - def get(self, key: str, default: Any = None) -> Any: - """Dictionary-style get with default.""" - return cast(BackendConfig | None, self.__dict__.get(key, default)) - - @property - def functional_backends(self) -> set[str]: - """Get the set of functional backends (those with API keys).""" - functional: set[str] = set() - registered = backend_registry.get_registered_backends() - for backend_name in registered: - if backend_name in self.__dict__: - config: Any = self.__dict__[backend_name] - if isinstance(config, BackendConfig) and config.api_key: - functional.add(backend_name) - - # Consider OAuth-style backends functional even without an api_key in config, - # since they source credentials from local auth stores (e.g., CLI-managed files). - oauth_like: set[str] = set() - for name in registered: - if name.endswith("-oauth") or name.startswith("gemini-oauth"): - oauth_like.add(name) - if name == "gemini-cli-cloud-project": - oauth_like.add(name) - - functional.update(oauth_like.intersection(set(registered))) - - # Include any dynamically added backends present in __dict__ that have api_key - # (used in tests and when users add custom backends not in the registry). 
- for name, cfg in getattr(self, "__dict__", {}).items(): - if ( - name == "default_backend" - or name.startswith("_") - or not isinstance(cfg, BackendConfig) - ): - continue - if cfg.api_key: - functional.add(name) - return functional - - def __getattr__(self, name: str) -> Any: - """Allow accessing backend configs as attributes. - - If an attribute for a backend is missing, create a default - BackendConfig instance lazily. This ensures tests and runtime - code can access `config.backends.openai` / `config.backends.gemini` - even if the registry hasn't been populated yet. - """ - if name == "default_backend": # Handle default_backend separately - # Ensure we use the explicitly set default_backend if available - if "default_backend" in self.__dict__: - return self.__dict__["default_backend"] - # Otherwise fall back to openai - return "openai" - - # Check if the attribute exists in __dict__ - if name in self.__dict__: - return cast(BackendConfig, self.__dict__[name]) - - # Avoid creating configs for private/internal attributes to maintain security - if name.startswith(("_", "__")): - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) - - # Check if we're still initializing (indicated by presence of __dict__ keys - # that suggest initialization hasn't completed). Don't create empty configs - # during initialization - let the __init__ method handle it. - # Only create empty configs after initialization is complete. - if not hasattr(self, "_initialization_complete"): - # During initialization, raise AttributeError to let __init__ handle it - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) - - # Lazily create a default backend configuration for unknown backends. - # This allows accessing backend configs without pre-registration while - # maintaining backward compatibility. Created configs are cached for - # subsequent access to avoid creating multiple instances. 
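        # Illustrative: the first read of e.g. `settings.new_backend` runs the
        # assignment below; later reads return the same cached instance.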
- config = BackendConfig() - self.__dict__[name] = config - return config - - def model_dump(self, **kwargs: Any) -> dict[str, Any]: - """Override model_dump to include default_backend and dynamic backends.""" - dumped: dict[str, Any] = super().model_dump(**kwargs) - # Add dynamic backends to the dumped dictionary - for backend_name in backend_registry.get_registered_backends(): - if backend_name in self.__dict__: - config: Any = self.__dict__[backend_name] - if isinstance(config, BackendConfig): - dumped[backend_name] = config.model_dump() - return dumped - - def model_is_functional(self, model_id: str) -> bool: - """Check if a model is available in any functional backend.""" - if ":" not in model_id: - return False # Invalid format - - backend_name, _ = model_id.split(":", 1) - return backend_name in self.functional_backends - - -class AppConfig(DomainModel, IConfig): - """Complete application configuration.""" - - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - - host: str = "127.0.0.1" # Default to localhost for security - port: int = 8000 - anthropic_port: int | None = None # Will be set to port + 1 if not provided - proxy_timeout: int = 120 - command_prefix: str = "!/" - strict_command_detection: bool = False - context_window_override: int | None = None # Override context window for all models - gcp_project_id: str | None = None - gemini_credentials_path: str | None = None - disable_health_checks: bool = False - - # Rate limit settings - default_rate_limit: int = 60 - default_rate_window: int = 60 - - # Backend settings - backends: BackendSettings = Field(default_factory=BackendSettings) - model_defaults: dict[str, dict[str, Any]] = Field(default_factory=dict) - failover_routes: dict[str, dict[str, Any]] = Field(default_factory=dict) - - # No nested class references - use direct imports instead - - # Identity settings - identity: AppIdentityConfig = Field(default_factory=AppIdentityConfig) - - # Auth settings - auth: AuthConfig = Field(default_factory=AuthConfig) - - # Session settings - session: SessionConfig = Field(default_factory=SessionConfig) - - # Logging settings - logging: LoggingConfig = Field(default_factory=LoggingConfig) - - # Empty response handling settings - empty_response: EmptyResponseConfig = Field(default_factory=EmptyResponseConfig) - - # Edit-precision tuning settings - edit_precision: EditPrecisionConfig = Field(default_factory=EditPrecisionConfig) - - # Rewriting settings - rewriting: RewritingConfig = Field(default_factory=RewritingConfig) - assessment: AssessmentConfig = Field(default_factory=AssessmentConfig) - - # Reasoning aliases settings - reasoning_aliases: ReasoningAliasesConfig = Field( - default_factory=lambda: ReasoningAliasesConfig(reasoning_alias_settings=[]) - ) - - # Model name rewrite rules - model_aliases: list[ModelAliasRule] = Field(default_factory=list) - - # Sandboxing settings - sandboxing: SandboxingConfiguration = Field(default_factory=SandboxingConfiguration) - - # FastAPI app instance - app: Any = None - - def model_is_functional(self, model_id: str) -> bool: - """Check if a model is available in any functional backend.""" - return self.backends.model_is_functional(model_id) - - def save(self, path: str | Path) -> None: - """Save the current configuration to a file.""" - p = Path(path) - data = self.model_dump(mode="json", exclude_none=True) - # Normalize structure to match schema expectations - # - default_backend must be at top-level (already present) - # - Remove runtime-only fields that are not part of 
schema or can cause validation errors - for runtime_key in ["app"]: - if runtime_key in data: - data[runtime_key] = None - # Filter out unsupported top-level keys (schema has additionalProperties: false) - allowed_top_keys = { - "host", - "port", - "anthropic_port", - "proxy_timeout", - "command_prefix", - "strict_command_detection", - "context_window_override", - "default_rate_limit", - "default_rate_window", - "model_defaults", - "failover_routes", - "identity", - "empty_response", - "edit_precision", - "rewriting", - "app", - "logging", - "auth", - "session", - "backends", - "default_backend", - "reasoning_aliases", - "model_aliases", - "sandboxing", - } - data = {k: v for k, v in data.items() if k in allowed_top_keys} - # Ensure nested sections only include serializable primitives - # (model_dump already handles pydantic models) - if p.suffix.lower() in {".yaml", ".yml"}: - import yaml - - logger.debug(f"Saving configuration to {p}: {data}") - with p.open("w", encoding="utf-8") as f: - yaml.safe_dump(data, f, sort_keys=False) - else: - # Legacy: still allow JSON save if requested by extension - with p.open("w", encoding="utf-8") as f: - f.write(self.model_dump_json(indent=4)) - - @classmethod - def from_env( - cls, - *, - environ: Mapping[str, str] | None = None, - resolution: ParameterResolution | None = None, - ) -> AppConfig: - """Create AppConfig from environment variables. - - Returns: - AppConfig instance - """ - env: Mapping[str, str] = os.environ if environ is None else environ - - # Build configuration from environment - config: dict[str, Any] = { - # Server settings - "gcp_project_id": _get_env_value( - env, - "GOOGLE_CLOUD_PROJECT", - _get_env_value( - env, - "GCP_PROJECT_ID", - None, - path="gcp_project_id", - resolution=resolution, - ), - path="gcp_project_id", - resolution=resolution, - ), - "gemini_credentials_path": _get_env_value( - env, - "GEMINI_CREDENTIALS_PATH", - None, - path="gemini_credentials_path", - resolution=resolution, - ), - "disable_health_checks": _env_to_bool( - "DISABLE_HEALTH_CHECKS", - False, - env, - path="disable_health_checks", - resolution=resolution, - ), - "host": _get_env_value( - env, - "APP_HOST", - "127.0.0.1", # Default to localhost for security - path="host", - resolution=resolution, - ), - "port": _get_env_value( - env, - "APP_PORT", - 8000, - path="port", - resolution=resolution, - transform=lambda value: _to_int(value, 8000), - ), - "anthropic_port": _get_env_value( - env, - "ANTHROPIC_PORT", - None, - path="anthropic_port", - resolution=resolution, - transform=lambda value: _to_int(value, 0) if value else None, - ), - "proxy_timeout": _get_env_value( - env, - "PROXY_TIMEOUT", - 120, - path="proxy_timeout", - resolution=resolution, - transform=lambda value: _to_int(value, 120), - ), - "command_prefix": _get_env_value( - env, - "COMMAND_PREFIX", - "!/", - path="command_prefix", - resolution=resolution, - ), - "auth": { - "disable_auth": _env_to_bool( - "DISABLE_AUTH", - False, - env, - path="auth.disable_auth", - resolution=resolution, - ), - "api_keys": _get_api_keys_from_env(env, resolution), - "auth_token": _get_env_value( - env, - "AUTH_TOKEN", - None, - path="auth.auth_token", - resolution=resolution, - ), - "brute_force_protection": { - "enabled": _env_to_bool( - "BRUTE_FORCE_PROTECTION_ENABLED", - True, - env, - path="auth.brute_force_protection.enabled", - resolution=resolution, - ), - "max_failed_attempts": _env_to_int( - "BRUTE_FORCE_MAX_FAILED_ATTEMPTS", - 5, - env, - path="auth.brute_force_protection.max_failed_attempts", - 
resolution=resolution, - ), - "ttl_seconds": _env_to_int( - "BRUTE_FORCE_TTL_SECONDS", - 900, - env, - path="auth.brute_force_protection.ttl_seconds", - resolution=resolution, - ), - "initial_block_seconds": _env_to_int( - "BRUTE_FORCE_INITIAL_BLOCK_SECONDS", - 30, - env, - path="auth.brute_force_protection.initial_block_seconds", - resolution=resolution, - ), - "block_multiplier": _env_to_float( - "BRUTE_FORCE_BLOCK_MULTIPLIER", - 2.0, - env, - path="auth.brute_force_protection.block_multiplier", - resolution=resolution, - ), - "max_block_seconds": _env_to_int( - "BRUTE_FORCE_MAX_BLOCK_SECONDS", - 3600, - env, - path="auth.brute_force_protection.max_block_seconds", - resolution=resolution, - ), - }, - }, - } - - if not config.get("anthropic_port"): - config["anthropic_port"] = int(config["port"]) + 1 - if resolution is not None: - resolution.record( - "anthropic_port", - config["anthropic_port"], - ParameterSource.DERIVED, - origin="port+1", - ) - - # After populating auth config, if disable_auth is true, clear api_keys - auth_config: dict[str, Any] = config["auth"] - if isinstance(auth_config, dict) and auth_config.get("disable_auth"): - auth_config["api_keys"] = [] - - # Add session, logging, and backend config - planning_overrides: dict[str, Any] = {} - planning_temperature = _get_env_value( - env, - "PLANNING_PHASE_TEMPERATURE", - None, - path="session.planning_phase.overrides.temperature", - resolution=resolution, - transform=lambda value: _to_float(value, None), - ) - if planning_temperature is not None: - planning_overrides["temperature"] = planning_temperature - - planning_top_p = _get_env_value( - env, - "PLANNING_PHASE_TOP_P", - None, - path="session.planning_phase.overrides.top_p", - resolution=resolution, - transform=lambda value: _to_float(value, None), - ) - if planning_top_p is not None: - planning_overrides["top_p"] = planning_top_p - - planning_reasoning = _get_env_value( - env, - "PLANNING_PHASE_REASONING_EFFORT", - None, - path="session.planning_phase.overrides.reasoning_effort", - resolution=resolution, - ) - if planning_reasoning is not None: - planning_overrides["reasoning_effort"] = planning_reasoning - - planning_budget = _get_env_value( - env, - "PLANNING_PHASE_THINKING_BUDGET", - None, - path="session.planning_phase.overrides.thinking_budget", - resolution=resolution, - transform=lambda value: _to_int(value, 0), - ) - if planning_budget is not None: - planning_overrides["thinking_budget"] = planning_budget - - config["session"] = { - "cleanup_enabled": _env_to_bool( - "SESSION_CLEANUP_ENABLED", - True, - env, - path="session.cleanup_enabled", - resolution=resolution, - ), - "cleanup_interval": _env_to_int( - "SESSION_CLEANUP_INTERVAL", - 3600, - env, - path="session.cleanup_interval", - resolution=resolution, - ), - "max_age": _env_to_int( - "SESSION_MAX_AGE", - 86400, - env, - path="session.max_age", - resolution=resolution, - ), - "default_interactive_mode": _env_to_bool( - "DEFAULT_INTERACTIVE_MODE", - True, - env, - path="session.default_interactive_mode", - resolution=resolution, - ), - "force_set_project": _env_to_bool( - "FORCE_SET_PROJECT", - False, - env, - path="session.force_set_project", - resolution=resolution, - ), - "project_dir_resolution_model": _get_env_value( - env, - "PROJECT_DIR_RESOLUTION_MODEL", - None, - path="session.project_dir_resolution_model", - resolution=resolution, - ), - "project_dir_resolution_mode": _get_env_value( - env, - "PROJECT_DIR_RESOLUTION_MODE", - "hybrid", - path="session.project_dir_resolution_mode", - 
resolution=resolution, - ), - "tool_call_repair_enabled": _env_to_bool( - "TOOL_CALL_REPAIR_ENABLED", - True, - env, - path="session.tool_call_repair_enabled", - resolution=resolution, - ), - "tool_call_repair_buffer_cap_bytes": _get_env_value( - env, - "TOOL_CALL_REPAIR_BUFFER_CAP_BYTES", - 65536, - path="session.tool_call_repair_buffer_cap_bytes", - resolution=resolution, - transform=lambda value: _to_int(value, 65536), - ), - "json_repair_enabled": _env_to_bool( - "JSON_REPAIR_ENABLED", - True, - env, - path="session.json_repair_enabled", - resolution=resolution, - ), - "json_repair_buffer_cap_bytes": _get_env_value( - env, - "JSON_REPAIR_BUFFER_CAP_BYTES", - 65536, - path="session.json_repair_buffer_cap_bytes", - resolution=resolution, - transform=lambda value: _to_int(value, 65536), - ), - "json_repair_schema": _get_env_value( - env, - "JSON_REPAIR_SCHEMA", - None, - path="session.json_repair_schema", - resolution=resolution, - transform=lambda value: json.loads(value), - ), - "dangerous_command_prevention_enabled": _env_to_bool( - "DANGEROUS_COMMAND_PREVENTION_ENABLED", - True, - env, - path="session.dangerous_command_prevention_enabled", - resolution=resolution, - ), - "dangerous_command_steering_message": _get_env_value( - env, - "DANGEROUS_COMMAND_STEERING_MESSAGE", - None, - path="session.dangerous_command_steering_message", - resolution=resolution, - ), - "pytest_compression_enabled": _env_to_bool( - "PYTEST_COMPRESSION_ENABLED", - True, - env, - path="session.pytest_compression_enabled", - resolution=resolution, - ), - "pytest_compression_min_lines": _env_to_int( - "PYTEST_COMPRESSION_MIN_LINES", - 30, - env, - path="session.pytest_compression_min_lines", - resolution=resolution, - ), - "pytest_full_suite_steering_enabled": _env_to_bool( - "PYTEST_FULL_SUITE_STEERING_ENABLED", - False, - env, - path="session.pytest_full_suite_steering_enabled", - resolution=resolution, - ), - "pytest_full_suite_steering_message": _get_env_value( - env, - "PYTEST_FULL_SUITE_STEERING_MESSAGE", - None, - path="session.pytest_full_suite_steering_message", - resolution=resolution, - ), - "fix_think_tags_enabled": _env_to_bool( - "FIX_THINK_TAGS_ENABLED", - False, - env, - path="session.fix_think_tags_enabled", - resolution=resolution, - ), - "fix_think_tags_streaming_buffer_size": _env_to_int( - "FIX_THINK_TAGS_STREAMING_BUFFER_SIZE", - 4096, - env, - path="session.fix_think_tags_streaming_buffer_size", - resolution=resolution, - ), - "planning_phase": { - "enabled": _env_to_bool( - "PLANNING_PHASE_ENABLED", - False, - env, - path="session.planning_phase.enabled", - resolution=resolution, - ), - "strong_model": _get_env_value( - env, - "PLANNING_PHASE_STRONG_MODEL", - None, - path="session.planning_phase.strong_model", - resolution=resolution, - ), - "max_turns": _env_to_int( - "PLANNING_PHASE_MAX_TURNS", - 10, - env, - path="session.planning_phase.max_turns", - resolution=resolution, - ), - "max_file_writes": _env_to_int( - "PLANNING_PHASE_MAX_FILE_WRITES", - 1, - env, - path="session.planning_phase.max_file_writes", - resolution=resolution, - ), - "overrides": planning_overrides, - }, - "force_reprocess_tool_calls": _env_to_bool( - "FORCE_REPROCESS_TOOL_CALLS", - False, - env, - path="session.force_reprocess_tool_calls", - resolution=resolution, - ), - "log_skipped_tool_calls": _env_to_bool( - "LOG_SKIPPED_TOOL_CALLS", - False, - env, - path="session.log_skipped_tool_calls", - resolution=resolution, - ), - } - - config["logging"] = { - "level": _get_env_value( - env, - "LOG_LEVEL", - "INFO", - 
path="logging.level", - resolution=resolution, - ), - "request_logging": _env_to_bool( - "REQUEST_LOGGING", - False, - env, - path="logging.request_logging", - resolution=resolution, - ), - "response_logging": _env_to_bool( - "RESPONSE_LOGGING", - False, - env, - path="logging.response_logging", - resolution=resolution, - ), - "log_file": _get_env_value( - env, - "LOG_FILE", - None, - path="logging.log_file", - resolution=resolution, - ), - "capture_file": _get_env_value( - env, - "CAPTURE_FILE", - None, - path="logging.capture_file", - resolution=resolution, - ), - "capture_max_bytes": _get_env_value( - env, - "CAPTURE_MAX_BYTES", - None, - path="logging.capture_max_bytes", - resolution=resolution, - transform=lambda value: _to_int(value, 0), - ), - "capture_truncate_bytes": _get_env_value( - env, - "CAPTURE_TRUNCATE_BYTES", - None, - path="logging.capture_truncate_bytes", - resolution=resolution, - transform=lambda value: _to_int(value, 0), - ), - "capture_max_files": _get_env_value( - env, - "CAPTURE_MAX_FILES", - None, - path="logging.capture_max_files", - resolution=resolution, - transform=lambda value: _to_int(value, 0), - ), - "capture_rotate_interval_seconds": _get_env_value( - env, - "CAPTURE_ROTATE_INTERVAL_SECONDS", - 86400, - path="logging.capture_rotate_interval_seconds", - resolution=resolution, - transform=lambda value: _to_int(value, 86400), - ), - "capture_total_max_bytes": _get_env_value( - env, - "CAPTURE_TOTAL_MAX_BYTES", - 104857600, - path="logging.capture_total_max_bytes", - resolution=resolution, - transform=lambda value: _to_int(value, 104857600), - ), - "capture_buffer_size": _get_env_value( - env, - "CAPTURE_BUFFER_SIZE", - 65536, - path="logging.capture_buffer_size", - resolution=resolution, - transform=lambda value: _to_int(value, 65536), - ), - "capture_flush_interval": _get_env_value( - env, - "CAPTURE_FLUSH_INTERVAL", - 1.0, - path="logging.capture_flush_interval", - resolution=resolution, - transform=lambda value: _to_float(value, 1.0), - ), - "capture_max_entries_per_flush": _get_env_value( - env, - "CAPTURE_MAX_ENTRIES_PER_FLUSH", - 100, - path="logging.capture_max_entries_per_flush", - resolution=resolution, - transform=lambda value: _to_int(value, 100), - ), - } - - config["empty_response"] = { - "enabled": _env_to_bool( - "EMPTY_RESPONSE_HANDLING_ENABLED", - True, - env, - path="empty_response.enabled", - resolution=resolution, - ), - "max_retries": _env_to_int( - "EMPTY_RESPONSE_MAX_RETRIES", - 1, - env, - path="empty_response.max_retries", - resolution=resolution, - ), - } - - # Edit precision settings - config["edit_precision"] = { - "enabled": _env_to_bool( - "EDIT_PRECISION_ENABLED", - True, - env, - path="edit_precision.enabled", - resolution=resolution, - ), - "temperature": _env_to_float( - "EDIT_PRECISION_TEMPERATURE", - 0.1, - env, - path="edit_precision.temperature", - resolution=resolution, - ), - "min_top_p": _env_to_float( - "EDIT_PRECISION_MIN_TOP_P", - 0.3, - env, - path="edit_precision.min_top_p", - resolution=resolution, - ), - "override_top_p": _env_to_bool( - "EDIT_PRECISION_OVERRIDE_TOP_P", - False, - env, - path="edit_precision.override_top_p", - resolution=resolution, - ), - "override_top_k": _env_to_bool( - "EDIT_PRECISION_OVERRIDE_TOP_K", - False, - env, - path="edit_precision.override_top_k", - resolution=resolution, - ), - "target_top_k": _get_env_value( - env, - "EDIT_PRECISION_TARGET_TOP_K", - None, - path="edit_precision.target_top_k", - resolution=resolution, - transform=lambda value: _to_int(value, 0) or None, - ), - 
"exclude_agents_regex": _get_env_value( - env, - "EDIT_PRECISION_EXCLUDE_AGENTS_REGEX", - None, - path="edit_precision.exclude_agents_regex", - resolution=resolution, - ), - } - - config["rewriting"] = { - "enabled": _env_to_bool( - "REWRITING_ENABLED", - False, - env, - path="rewriting.enabled", - resolution=resolution, - ), - "config_path": _get_env_value( - env, - "REWRITING_CONFIG_PATH", - "config/replacements", - path="rewriting.config_path", - resolution=resolution, - ), - } - - # Assessment configuration from environment - config["assessment"] = { - "enabled": _env_to_bool( - "LLM_ASSESSMENT_ENABLED", - False, - env, - path="assessment.enabled", - resolution=resolution, - ), - "turn_threshold": _env_to_int( - "LLM_ASSESSMENT_TURN_THRESHOLD", - 30, - env, - path="assessment.turn_threshold", - resolution=resolution, - ), - "confidence_threshold": _env_to_float( - "LLM_ASSESSMENT_CONFIDENCE_THRESHOLD", - 0.9, - env, - path="assessment.confidence_threshold", - resolution=resolution, - ), - "backend": _get_env_value( - env, - "LLM_ASSESSMENT_BACKEND", - "openai", # Default backend - path="assessment.backend", - resolution=resolution, - ), - "model": _get_env_value( - env, - "LLM_ASSESSMENT_MODEL", - "gpt-4o-mini", # Default model - path="assessment.model", - resolution=resolution, - ), - "history_window": _env_to_int( - "LLM_ASSESSMENT_HISTORY_WINDOW", - 20, - env, - path="assessment.history_window", - resolution=resolution, - ), - } - - # Sandboxing configuration from environment - config["sandboxing"] = { - "enabled": _env_to_bool( - "ENABLE_SANDBOXING", - False, - env, - path="sandboxing.enabled", - resolution=resolution, - ), - "strict_mode": _env_to_bool( - "SANDBOXING_STRICT_MODE", - False, - env, - path="sandboxing.strict_mode", - resolution=resolution, - ), - "allow_parent_access": _env_to_bool( - "SANDBOXING_ALLOW_PARENT_ACCESS", - False, - env, - path="sandboxing.allow_parent_access", - resolution=resolution, - ), - } - - # Model aliases configuration from environment - model_aliases_env = env.get("MODEL_ALIASES") - if model_aliases_env: - try: - alias_data = json.loads(model_aliases_env) - if isinstance(alias_data, list): - config["model_aliases"] = [ - {"pattern": item["pattern"], "replacement": item["replacement"]} - for item in alias_data - if isinstance(item, dict) - and "pattern" in item - and "replacement" in item - ] - if resolution is not None: - resolution.record( - "model_aliases", - config["model_aliases"], - ParameterSource.ENVIRONMENT, - origin="MODEL_ALIASES", - ) - except (json.JSONDecodeError, KeyError, TypeError) as e: - logger.warning( - f"Invalid MODEL_ALIASES environment variable format: {e}" - ) - config["model_aliases"] = [] - else: - config["model_aliases"] = [] - - config["backends"] = { - "default_backend": _get_env_value( - env, - "LLM_BACKEND", - "openai", - path="backends.default_backend", - resolution=resolution, - ), - "disable_gemini_oauth_fallback": _env_to_bool( - "DISABLE_GEMINI_OAUTH_FALLBACK", - False, - env, - path="backends.disable_gemini_oauth_fallback", - resolution=resolution, - ), - "disable_hybrid_backend": _env_to_bool( - "DISABLE_HYBRID_BACKEND", - False, - env, - path="backends.disable_hybrid_backend", - resolution=resolution, - ), - "hybrid_backend_repeat_messages": _env_to_bool( - "HYBRID_BACKEND_REPEAT_MESSAGES", - False, - env, - path="backends.hybrid_backend_repeat_messages", - resolution=resolution, - ), - "reasoning_injection_probability": _env_to_float( - "REASONING_INJECTION_PROBABILITY", - 1.0, - env, - 
path="backends.reasoning_injection_probability", - resolution=resolution, - ), - } - - config["identity"] = AppIdentityConfig( - title=HeaderConfig( - override_value=_get_env_value( - env, - "APP_TITLE", - None, - path="identity.title.override_value", - resolution=resolution, - ), - mode=HeaderOverrideMode( - _get_env_value( - env, - "APP_TITLE_MODE", - "passthrough", - path="identity.title.mode", - resolution=resolution, - ) - ), - default_value="llm-interactive-proxy", - passthrough_name="x-title", - ), - url=HeaderConfig( - override_value=_get_env_value( - env, - "APP_URL", - None, - path="identity.url.override_value", - resolution=resolution, - ), - mode=HeaderOverrideMode( - _get_env_value( - env, - "APP_URL_MODE", - "passthrough", - path="identity.url.mode", - resolution=resolution, - ) - ), - default_value="https://github.com/matdev83/llm-interactive-proxy", - passthrough_name="http-referer", - ), - user_agent=HeaderConfig( - override_value=_get_env_value( - env, - "APP_USER_AGENT", - None, - path="identity.user_agent.override_value", - resolution=resolution, - ), - mode=HeaderOverrideMode( - _get_env_value( - env, - "APP_USER_AGENT_MODE", - "passthrough", - path="identity.user_agent.mode", - resolution=resolution, - ) - ), - default_value="llm-interactive-proxy", - passthrough_name="user-agent", - ), - ) - - # Log the determined default_backend - logger.info( - f"AppConfig.from_env - Determined default_backend: {config['backends']['default_backend']}" - ) - - # Extract backend configurations from environment - config_backends: dict[str, Any] = config["backends"] - assert isinstance(config_backends, dict) - - # Collect and assign API keys for specific backends - openrouter_keys = _collect_api_keys_from_env( - "OPENROUTER_API_KEY", env, resolution - ) - if openrouter_keys: - config_backends["openrouter"] = config_backends.get("openrouter", {}) - config_backends["openrouter"]["api_key"] = list(openrouter_keys.values()) - config_backends["openrouter"]["api_url"] = _get_env_value( - env, - "OPENROUTER_API_BASE_URL", - "https://openrouter.ai/api/v1", - path="backends.openrouter.api_url", - resolution=resolution, - ) - timeout_value = _get_env_value( - env, - "OPENROUTER_TIMEOUT", - None, - path="backends.openrouter.timeout", - resolution=resolution, - transform=lambda value: _to_int(value, 0), - ) - if timeout_value: - config_backends["openrouter"]["timeout"] = timeout_value - if resolution is not None: - resolution.record( - "backends.openrouter.api_key", - config_backends["openrouter"]["api_key"], - ParameterSource.ENVIRONMENT, - origin="OPENROUTER_API_KEY*", - ) - - gemini_keys: dict[str, str] = _collect_api_keys_from_env( - "GEMINI_API_KEY", env, resolution - ) - if gemini_keys: - config_backends["gemini"] = config_backends.get("gemini", {}) - config_backends["gemini"]["api_key"] = list(gemini_keys.values()) - config_backends["gemini"]["api_url"] = _get_env_value( - env, - "GEMINI_API_BASE_URL", - "https://generativelanguage.googleapis.com", - path="backends.gemini.api_url", - resolution=resolution, - ) - gemini_timeout = _get_env_value( - env, - "GEMINI_TIMEOUT", - None, - path="backends.gemini.timeout", - resolution=resolution, - transform=lambda value: _to_int(value, 0), - ) - if gemini_timeout: - config_backends["gemini"]["timeout"] = gemini_timeout - if resolution is not None: - resolution.record( - "backends.gemini.api_key", - config_backends["gemini"]["api_key"], - ParameterSource.ENVIRONMENT, - origin="GEMINI_API_KEY*", - ) - - anthropic_keys: dict[str, str] = 
_collect_api_keys_from_env( - "ANTHROPIC_API_KEY", env, resolution - ) - if anthropic_keys: - config_backends["anthropic"] = config_backends.get("anthropic", {}) - config_backends["anthropic"]["api_key"] = list(anthropic_keys.values()) - config_backends["anthropic"]["api_url"] = _get_env_value( - env, - "ANTHROPIC_API_BASE_URL", - "https://api.anthropic.com/v1", - path="backends.anthropic.api_url", - resolution=resolution, - ) - anthropic_timeout = _get_env_value( - env, - "ANTHROPIC_TIMEOUT", - None, - path="backends.anthropic.timeout", - resolution=resolution, - transform=lambda value: _to_int(value, 0), - ) - if anthropic_timeout: - config_backends["anthropic"]["timeout"] = anthropic_timeout - if resolution is not None: - resolution.record( - "backends.anthropic.api_key", - config_backends["anthropic"]["api_key"], - ParameterSource.ENVIRONMENT, - origin="ANTHROPIC_API_KEY*", - ) - - zai_keys: dict[str, str] = _collect_api_keys_from_env( - "ZAI_API_KEY", env, resolution - ) - if zai_keys: - config_backends["zai"] = config_backends.get("zai", {}) - config_backends["zai"]["api_key"] = list(zai_keys.values()) - config_backends["zai"]["api_url"] = _get_env_value( - env, - "ZAI_API_BASE_URL", - None, - path="backends.zai.api_url", - resolution=resolution, - ) - zai_timeout = _get_env_value( - env, - "ZAI_TIMEOUT", - None, - path="backends.zai.timeout", - resolution=resolution, - transform=lambda value: _to_int(value, 0), - ) - if zai_timeout: - config_backends["zai"]["timeout"] = zai_timeout - if resolution is not None: - resolution.record( - "backends.zai.api_key", - config_backends["zai"]["api_key"], - ParameterSource.ENVIRONMENT, - origin="ZAI_API_KEY*", - ) - - openai_keys: dict[str, str] = _collect_api_keys_from_env( - "OPENAI_API_KEY", env, resolution - ) - if openai_keys: - config_backends["openai"] = config_backends.get("openai", {}) - config_backends["openai"]["api_key"] = list(openai_keys.values()) - config_backends["openai"]["api_url"] = _get_env_value( - env, - "OPENAI_API_BASE_URL", - "https://api.openai.com/v1", - path="backends.openai.api_url", - resolution=resolution, - ) - openai_timeout = _get_env_value( - env, - "OPENAI_TIMEOUT", - None, - path="backends.openai.timeout", - resolution=resolution, - transform=lambda value: _to_int(value, 0), - ) - if openai_timeout: - config_backends["openai"]["timeout"] = openai_timeout - if resolution is not None: - resolution.record( - "backends.openai.api_key", - config_backends["openai"]["api_key"], - ParameterSource.ENVIRONMENT, - origin="OPENAI_API_KEY*", - ) - - minimax_keys: dict[str, str] = _collect_api_keys_from_env( - "MINIMAX_API_KEY", env, resolution - ) - if minimax_keys: - config_backends["minimax"] = config_backends.get("minimax", {}) - config_backends["minimax"]["api_key"] = list(minimax_keys.values()) - config_backends["minimax"]["api_url"] = _get_env_value( - env, - "MINIMAX_API_BASE_URL", - "https://api.minimax.io/v1", - path="backends.minimax.api_url", - resolution=resolution, - ) - minimax_timeout = _get_env_value( - env, - "MINIMAX_TIMEOUT", - None, - path="backends.minimax.timeout", - resolution=resolution, - transform=lambda value: _to_int(value, 0), - ) - if minimax_timeout: - config_backends["minimax"]["timeout"] = minimax_timeout - if resolution is not None: - resolution.record( - "backends.minimax.api_key", - config_backends["minimax"]["api_key"], - ParameterSource.ENVIRONMENT, - origin="MINIMAX_API_KEY*", - ) - - # Handle default backend if it's not explicitly configured above - default_backend_type: str = str( 
- config["backends"].get("default_backend", "openai") - ) - if default_backend_type not in config_backends: - # If the default backend is not explicitly configured, ensure it has a basic config - config_backends[default_backend_type] = config_backends.get( - default_backend_type, {} - ) - # Add a dummy API key if running in test environment and no API key is present - if env.get("PYTEST_CURRENT_TEST") and ( - not config_backends[default_backend_type] - or not config_backends[default_backend_type].get("api_key") - ): - config_backends[default_backend_type]["api_key"] = [ - f"test-key-{default_backend_type}" - ] - logger.info( - f"Added test API key for default backend {default_backend_type}" - ) - - return cls(**config) # type: ignore - - def get(self, key: str, default: Any = None) -> Any: - """Get a configuration value by key.""" - # Split the key by dots to handle nested attributes - keys = key.split(".") - value: Any = self - - try: - for k in keys: - if isinstance(value, dict): - value = value.get(k, default) - else: - value = getattr(value, k, default) - return value - except Exception: - return default - - def set(self, key: str, value: Any) -> None: - """Set a configuration value.""" - # For simplicity, we'll only handle top-level attributes - # In a more complex implementation, we might want to handle nested attributes - setattr(self, key, value) - - def get_gcp_project_id(self) -> str | None: - """Return the GCP Project ID.""" - return self.gcp_project_id - - -def _merge_dicts(d1: dict[str, Any], d2: dict[str, Any]) -> dict[str, Any]: - for k, v in d2.items(): - if k in d1 and isinstance(d1[k], dict) and isinstance(v, dict): - _merge_dicts(d1[k], v) - else: - d1[k] = v - return d1 - - -def _set_by_path(target: dict[str, Any], path: str, value: Any) -> None: - parts = path.split(".") - current: dict[str, Any] = target - for key in parts[:-1]: - current = current.setdefault(key, {}) # type: ignore[assignment] - current[parts[-1]] = value - - -def _get_by_path(source: dict[str, Any], path: str) -> Any: - parts = path.split(".") - current: Any = source - for key in parts: - if not isinstance(current, dict): - return None - current = current.get(key) - return current - - -def _flatten_dict(data: dict[str, Any]) -> dict[str, Any]: - flattened: dict[str, Any] = {} - - def _walk(value: Any, prefix: str) -> None: - if isinstance(value, dict): - for key, child in value.items(): - new_prefix = f"{prefix}.{key}" if prefix else key - _walk(child, new_prefix) - else: - flattened[prefix] = value - - _walk(data, "") - return flattened - - -def load_config( - config_path: str | Path | None = None, - *, - resolution: ParameterResolution | None = None, - environ: Mapping[str, str] | None = None, -) -> AppConfig: - """ - Load configuration from file and environment. - - Args: - config_path: Optional path to configuration file - - Returns: - AppConfig instance - """ - env = os.environ if environ is None else environ - res = resolution or ParameterResolution() - - config_data: dict[str, Any] = AppConfig().model_dump() - - if config_path: - try: - import yaml - - path: Path = Path(config_path) - if not path.exists(): - logger.warning(f"Configuration file not found: {config_path}") - else: - if path.suffix.lower() not in [".yaml", ".yml"]: - raise ValueError( - f"Unsupported configuration file format: {path.suffix}. Use YAML (.yaml/.yml)." 
- ) - - with open(path, encoding="utf-8") as f: - file_config: dict[str, Any] = yaml.safe_load(f) or {} - - from pathlib import Path as _Path - - from src.core.config.semantic_validation import ( - validate_config_semantics, - ) - from src.core.config.yaml_validation import validate_yaml_against_schema - - schema_path = ( - _Path.cwd() / "config" / "schemas" / "app_config.schema.yaml" - ) - validate_yaml_against_schema(_Path(path), schema_path) - validate_config_semantics(file_config, path) - - _merge_dicts(config_data, file_config) - origin = str(path) - for name, value in _flatten_dict(file_config).items(): - res.record( - name, - value, - ParameterSource.CONFIG_FILE, - origin=origin, - ) - except Exception as exc: # type: ignore[misc] - logger.critical(f"Error loading configuration file: {exc!s}") - raise - - env_config = AppConfig.from_env(environ=env, resolution=res) - env_dump = env_config.model_dump() - for name in res.latest_by_source(ParameterSource.ENVIRONMENT): - value = _get_by_path(env_dump, name) - _set_by_path(config_data, name, value) - - return AppConfig.model_validate(config_data) +from __future__ import annotations + +import json +import logging +import os +from collections.abc import Callable, Mapping +from enum import Enum +from pathlib import Path +from typing import Any, cast + +from pydantic import ConfigDict, Field, field_validator, model_validator + +from src.core.config.parameter_resolution import ParameterResolution, ParameterSource + + +def get_openrouter_headers(cfg: dict[str, str], api_key: str) -> dict[str, str]: + """Construct headers for OpenRouter requests. + + Be tolerant of minimal cfg dicts provided by tests by falling back to + sensible defaults when optional keys are absent. + """ + referer: str = cfg.get("app_site_url", "http://localhost:8000") + x_title: str = cfg.get("app_x_title", "InterceptorProxy") + return { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "HTTP-Referer": referer, + "X-Title": x_title, + } + + +def _collect_api_keys_from_env( + base_name: str, + env: Mapping[str, str], + resolution: ParameterResolution | None = None, +) -> dict[str, str]: + """Collect API keys from environment variables with parameter resolution tracking.""" + single_key = env.get(base_name) + numbered_keys: dict[str, str] = {} + numbered_key_names = [] + for i in range(1, 21): + key_name = f"{base_name}_{i}" + key = env.get(key_name) + if key: + numbered_keys[key_name] = key + numbered_key_names.append(key_name) + + if single_key and numbered_keys: + logger.warning( + "Both %s and %s_ environment variables are set. 
Prioritizing %s_ and ignoring %s.", + base_name, + base_name, + base_name, + base_name, + ) + if resolution is not None: + resolution.record( + f"backends.{base_name.lower().replace('_', '')}.api_key", + list(numbered_keys.values()), + ParameterSource.ENVIRONMENT, + origin=",".join(numbered_key_names), + ) + return numbered_keys + + if single_key: + result = {base_name: single_key} + if resolution is not None: + resolution.record( + f"backends.{base_name.lower().replace('_', '')}.api_key", + list(result.values()), + ParameterSource.ENVIRONMENT, + origin=base_name, + ) + return result + + if resolution is not None and numbered_keys: + resolution.record( + f"backends.{base_name.lower().replace('_', '')}.api_key", + list(numbered_keys.values()), + ParameterSource.ENVIRONMENT, + origin=",".join(numbered_key_names), + ) + return numbered_keys + + +from src.core.domain.configuration.app_identity_config import AppIdentityConfig +from src.core.domain.configuration.assessment_config import AssessmentConfig +from src.core.domain.configuration.header_config import ( + HeaderConfig, + HeaderOverrideMode, +) +from src.core.domain.configuration.reasoning_aliases_config import ( + ReasoningAliasesConfig, +) +from src.core.domain.configuration.sandboxing_config import SandboxingConfiguration +from src.core.interfaces.configuration_interface import IConfig +from src.core.interfaces.model_bases import DomainModel + +# Note: Avoid self-imports to prevent circular dependencies. Classes are defined below. + +logger = logging.getLogger(__name__) + + +def _process_api_keys(keys_string: str) -> list[str]: + """Process a comma-separated string of API keys.""" + keys = keys_string.split(",") + result: list[str] = [] + for key in keys: + stripped_key = key.strip() + if stripped_key: + result.append(stripped_key) + return result + + +def _get_api_keys_from_env( + env: Mapping[str, str], resolution: ParameterResolution | None = None +) -> list[str]: + """Get API keys from environment variables.""" + result: list[str] = [] + + # Get API keys from API_KEYS environment variable + api_keys_raw: str | None = env.get("API_KEYS") + if api_keys_raw and isinstance(api_keys_raw, str): + result.extend(_process_api_keys(api_keys_raw)) + + if result and resolution is not None: + resolution.record( + "auth.api_keys", + result, + ParameterSource.ENVIRONMENT, + origin="API_KEYS", + ) + + return result + + +def _env_to_bool( + name: str, + default: bool, + env: Mapping[str, str], + *, + path: str | None = None, + resolution: ParameterResolution | None = None, +) -> bool: + """Return an environment variable parsed as a boolean flag.""" + value = env.get(name) + if value is None: + return default + result = value.strip().lower() in {"1", "true", "yes", "on"} + if resolution is not None and path is not None: + resolution.record(path, result, ParameterSource.ENVIRONMENT, origin=name) + return result + + +def _env_to_int( + name: str, + default: int, + env: Mapping[str, str], + *, + path: str | None = None, + resolution: ParameterResolution | None = None, +) -> int: + """Return an environment variable parsed as an integer.""" + value = env.get(name) + if value is None: + return default + try: + result = int(value) + except (TypeError, ValueError): + result = default + if resolution is not None and path is not None and value is not None: + resolution.record(path, result, ParameterSource.ENVIRONMENT, origin=name) + return result + + +def _env_to_float( + name: str, + default: float, + env: Mapping[str, str], + *, + path: str | None = None, + 
resolution: ParameterResolution | None = None, +) -> float: + """Return an environment variable parsed as a float.""" + value = env.get(name) + if value is None: + return default + try: + result = float(value) + except (TypeError, ValueError): + result = default + if resolution is not None and path is not None and value is not None: + resolution.record(path, result, ParameterSource.ENVIRONMENT, origin=name) + return result + + +def _get_env_value( + env: Mapping[str, str], + name: str, + default: Any, + *, + path: str, + resolution: ParameterResolution | None = None, + transform: Callable[[str], Any] | None = None, +) -> Any: + """Return an environment variable value and optionally record its source.""" + + if name in env: + raw_value = env[name] + value = transform(raw_value) if transform is not None else raw_value + if resolution is not None: + resolution.record(path, value, ParameterSource.ENVIRONMENT, origin=name) + return value + return default + + +def _to_int(value: str, fallback: int) -> int: + try: + return int(value) + except (TypeError, ValueError): + return fallback + + +def _to_float(value: str, fallback: float | None) -> float | None: + try: + return float(value) + except (TypeError, ValueError): + return fallback + + +class LogLevel(str, Enum): + """Log levels for configuration.""" + + DEBUG = "DEBUG" + INFO = "INFO" + WARNING = "WARNING" + ERROR = "ERROR" + CRITICAL = "CRITICAL" + + +class BackendConfig(DomainModel): + """Configuration for a backend service.""" + + model_config = ConfigDict(frozen=True) + + api_key: list[str] = Field(default_factory=list) + api_url: str | None = None + models: list[str] = Field(default_factory=list) + timeout: int = 120 # seconds + identity: AppIdentityConfig | None = None + extra: dict[str, Any] = Field(default_factory=dict) + + @field_validator("api_key", mode="before") + @classmethod + def validate_api_key(cls, v: Any) -> list[str]: + """Ensure api_key is always a list.""" + if isinstance(v, str): + return [v] + return v if isinstance(v, list) else [] + + @field_validator("api_url") + @classmethod + def validate_api_url(cls, v: str | None) -> str | None: + """Validate the API URL if provided.""" + if v is not None and not v.startswith(("http://", "https://")): + raise ValueError("API URL must start with http:// or https://") + return v + + +class AuthConfig(DomainModel): + """Authentication configuration.""" + + model_config = ConfigDict(frozen=True) + + disable_auth: bool = False + api_keys: list[str] = Field(default_factory=list) + auth_token: str | None = None + redact_api_keys_in_prompts: bool = True + trusted_ips: list[str] = Field(default_factory=list) + brute_force_protection: BruteForceProtectionConfig = Field( + default_factory=lambda: BruteForceProtectionConfig() + ) + + +class BruteForceProtectionConfig(DomainModel): + """Configuration for brute-force protection on API authentication.""" + + model_config = ConfigDict(frozen=True) + + enabled: bool = True + max_failed_attempts: int = 5 + ttl_seconds: int = 900 + initial_block_seconds: int = 30 + block_multiplier: float = 2.0 + max_block_seconds: int = 3600 + + +class LoggingConfig(DomainModel): + """Logging configuration.""" + + model_config = ConfigDict(frozen=True) + + level: LogLevel = LogLevel.INFO + request_logging: bool = False + response_logging: bool = False + log_file: str | None = None + # Optional separate wire-capture log file; when set, all outbound requests + # and inbound replies/SSE payloads are captured verbatim to this file. 
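+    # For example, setting CAPTURE_FILE=wire.log in the environment (or
+    # logging.capture_file in YAML) turns the capture on; the rotation and
+    # truncation knobs below then bound its disk usage. (Illustrative file
+    # name; any writable path works.)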
+ capture_file: str | None = None + # Optional max size in bytes; when exceeded, rotate current capture to + # `.1` and start a new file (overwrite existing .1). + capture_max_bytes: int | None = None + # Optional per-chunk truncation size in bytes for streaming capture. When + # set, stream chunks written to capture are truncated to this size with a + # short marker appended; streaming to client remains unmodified. + capture_truncate_bytes: int | None = None + # Optional number of rotated files to keep (e.g., file.1..file.N). If not + # set or <= 0, keeps a single rotation (file.1). Used only when + # capture_max_bytes is set. + capture_max_files: int | None = None + # Time-based rotation period in seconds (default 1 day). If set <= 0, time + # rotation is disabled. + capture_rotate_interval_seconds: int = 86400 + # Total disk cap across current capture file and rotated files. If set <= 0, + # disabled. Default is 100 MiB. + capture_total_max_bytes: int = 104857600 + # Buffer size for wire capture writes (bytes). Default 64KB. + capture_buffer_size: int = 65536 + # How often to flush buffer to disk (seconds). Default 1.0 second. + capture_flush_interval: float = 1.0 + # Maximum entries to buffer before forcing flush. Default 100. + capture_max_entries_per_flush: int = 100 + + +class ToolCallReactorConfig(DomainModel): + """Configuration for the Tool Call Reactor system. + + The Tool Call Reactor provides event-driven reactions to tool calls + from LLMs, allowing custom handlers to monitor, modify, or replace responses. + """ + + model_config = ConfigDict(frozen=True) + + enabled: bool = True + """Whether the Tool Call Reactor is enabled.""" + + apply_diff_steering_enabled: bool = True + """Whether the legacy apply_diff steering handler is enabled.""" + + apply_diff_steering_rate_limit_seconds: int = 60 + """Legacy rate limit window for apply_diff steering in seconds. + + Controls how often steering messages are shown for apply_diff tool calls + within the same session. Default: 60 seconds (1 message per minute). + """ + + apply_diff_steering_message: str | None = None + """Legacy custom steering message for apply_diff tool calls. + + If None, uses the default message. Can be customized to fit your workflow. + """ + + pytest_full_suite_steering_enabled: bool = False + """Whether steering for full pytest suite commands is enabled.""" + + pytest_full_suite_steering_message: str | None = None + """Optional custom steering message when detecting full pytest suite runs.""" + + pytest_context_saving_enabled: bool = False + """Whether pytest context-saving command rewrites are enabled.""" + + fix_think_tags_enabled: bool = False + """Whether correction of improperly formatted tags is enabled.""" + + # New: fully configurable steering rules + steering_rules: list[dict[str, Any]] = Field(default_factory=list) + """Configurable steering rules. + + Each rule is a dict describing when to trigger steering and what message to + return. See README for details. Minimal fields: + - name: Unique rule name + - enabled: bool + - triggers: { tool_names: [..], phrases: [..] } + - message: Replacement content when swallowed + - rate_limit: { calls_per_window: int, window_seconds: int } + - priority: int (optional; higher runs first) + """ + + # Tool access control policies + access_policies: list[dict[str, Any]] = Field(default_factory=list) + """Tool access control policies. + + Each policy defines which tools are allowed or blocked for specific models/agents. 
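+
+    For example, a policy might look like this (values are illustrative,
+    not shipped defaults)::
+
+        {
+            "name": "deny-shell-for-gpt4",
+            "model_pattern": "gpt-4.*",
+            "default_policy": "allow",
+            "blocked_patterns": ["execute_command"],
+        }
+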
+ Minimal fields: + - name: Unique policy identifier + - model_pattern: Regex pattern for matching model names + - default_policy: "allow" or "deny" + Optional fields: + - agent_pattern: Regex pattern for matching agents + - allowed_patterns: List of regex patterns for allowed tools + - blocked_patterns: List of regex patterns for blocked tools + - block_message: Message to return when blocking a tool call + - priority: int (higher values take precedence) + """ + + +class PlanningPhaseConfig(DomainModel): + """Configuration for planning phase model routing.""" + + model_config = ConfigDict(frozen=True) + + enabled: bool = False + strong_model: str | None = None + max_turns: int = 10 + max_file_writes: int = 1 + # Optional parameter overrides for the strong model + overrides: dict[str, Any] | None = None + + +class SessionContinuityConfig(DomainModel): + """Configuration for intelligent session continuity detection.""" + + model_config = ConfigDict(frozen=True) + + enabled: bool = True + fuzzy_matching: bool = True + max_session_age_seconds: int = 604800 # 7 days + fingerprint_message_count: int = 5 + client_key_includes_ip: bool = True + + +class SessionConfig(DomainModel): + """Session management configuration.""" + + model_config = ConfigDict(frozen=True) + + cleanup_enabled: bool = True + cleanup_interval: int = 3600 # 1 hour + max_age: int = 86400 # 1 day + default_interactive_mode: bool = True + force_set_project: bool = False + disable_interactive_commands: bool = False + project_dir_resolution_model: str | None = None + project_dir_resolution_mode: str = "hybrid" + tool_call_repair_enabled: bool = True + # Max per-session buffer for tool-call repair streaming (bytes) + tool_call_repair_buffer_cap_bytes: int = 64 * 1024 + json_repair_enabled: bool = True + # Max per-session buffer for JSON repair streaming (bytes) + json_repair_buffer_cap_bytes: int = 64 * 1024 + json_repair_strict_mode: bool = False + json_repair_schema: dict[str, Any] | None = None # Added + # TTL for cleaning up idle loop detection sessions to avoid memory leaks (seconds) + loop_detection_session_ttl_seconds: int = 600 + tool_call_reactor: ToolCallReactorConfig = Field( + default_factory=ToolCallReactorConfig + ) + dangerous_command_prevention_enabled: bool = True + dangerous_command_steering_message: str | None = None + pytest_compression_enabled: bool = True + pytest_compression_min_lines: int = 30 + pytest_full_suite_steering_enabled: bool | None = None + pytest_full_suite_steering_message: str | None = None + fix_think_tags_enabled: bool = False + fix_think_tags_streaming_buffer_size: int = 4096 + planning_phase: PlanningPhaseConfig = Field(default_factory=PlanningPhaseConfig) + max_per_session_backends: int = 32 + session_continuity: SessionContinuityConfig = Field( + default_factory=SessionContinuityConfig + ) + tool_access_global_overrides: dict[str, Any] | None = None + # Tool call processing behavior configuration + force_reprocess_tool_calls: bool = False + log_skipped_tool_calls: bool = False + + @model_validator(mode="before") + @classmethod + def _sync_pytest_full_suite_settings(cls, values: dict[str, Any]) -> dict[str, Any]: + """Keep pytest full-suite steering settings mirrored with reactor config.""" + reactor_config = values.get("tool_call_reactor") + + # Convert to dict if it's already a ToolCallReactorConfig instance + if isinstance(reactor_config, ToolCallReactorConfig): + reactor_config_dict = reactor_config.model_dump() + elif isinstance(reactor_config, dict): + reactor_config_dict = 
dict(reactor_config) + else: + reactor_config_dict = {} + + enabled = values.get("pytest_full_suite_steering_enabled") + message = values.get("pytest_full_suite_steering_message") + + # Update the dict instead of mutating frozen model + if enabled is not None: + reactor_config_dict["pytest_full_suite_steering_enabled"] = enabled + else: + values["pytest_full_suite_steering_enabled"] = reactor_config_dict.get( + "pytest_full_suite_steering_enabled", False + ) + + if message is not None: + reactor_config_dict["pytest_full_suite_steering_message"] = message + else: + values["pytest_full_suite_steering_message"] = reactor_config_dict.get( + "pytest_full_suite_steering_message" + ) + + # Store the dict - Pydantic will convert it to ToolCallReactorConfig + values["tool_call_reactor"] = reactor_config_dict + return values + + +class EmptyResponseConfig(DomainModel): + """Configuration for empty response handling.""" + + model_config = ConfigDict(frozen=True) + + enabled: bool = True + """Whether the empty response recovery is enabled.""" + + max_retries: int = 1 + """Maximum number of retries for empty responses.""" + + +class ModelAliasRule(DomainModel): + """A rule for rewriting a model name.""" + + model_config = ConfigDict(frozen=True) + + pattern: str + replacement: str + + +class RewritingConfig(DomainModel): + """Configuration for content rewriting.""" + + model_config = ConfigDict(frozen=True) + + enabled: bool = False + config_path: str = "config/replacements" + + +class EditPrecisionConfig(DomainModel): + """Configuration for automated edit-precision tuning. + + When enabled, detects agent edit-failure prompts and lowers sampling + parameters for the next single call to improve precision. + """ + + model_config = ConfigDict(frozen=True) + + enabled: bool = True + temperature: float = 0.1 + # Only applied if override_top_p is True; otherwise top_p remains unchanged + min_top_p: float | None = 0.3 + # Control whether top_p/top_k are overridden by this feature + override_top_p: bool = False + override_top_k: bool = False + # Target top_k to apply when override_top_k is True (for providers that support it, e.g., Gemini) + target_top_k: int | None = None + # Optional regex pattern; when set, agents with names matching this pattern + # will be excluded (feature disabled) even if enabled=True. + exclude_agents_regex: str | None = None + + +from src.core.services.backend_registry import ( + backend_registry, # Updated import path +) + + +class BackendSettings(DomainModel): + """Settings for all backends. + + Note: This class is intentionally not frozen because it needs to support + dynamic backend configurations that are added at runtime. Backend configs + are stored in __dict__ to allow attribute-style access (e.g., config.backends.openai) + without pre-defining all possible backends as fields. 
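+
+    Example (illustrative; assumes no other configuration)::
+
+        settings = BackendSettings(openai={"api_key": ["sk-test"]})
+        settings.openai.api_key     # -> ["sk-test"]
+        settings["openai"].timeout  # -> 120 (the BackendConfig default)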
+ """ + + model_config = ConfigDict(frozen=False, extra="allow") + + default_backend: str = "openai" + static_route: str | None = ( + None # Force all requests to backend:model (e.g., "gemini-oauth-plan:gemini-2.5-pro") + ) + disable_gemini_oauth_fallback: bool = False + disable_hybrid_backend: bool = False + hybrid_backend_repeat_messages: bool = False + reasoning_injection_probability: float = Field( + default=1.0, + ge=0.0, + le=1.0, + description="Probability of using the reasoning model for a request in the hybrid backend.", + ) + + def __init__(self, **data: Any) -> None: + # Separate standard fields from backend-specific configs + known_fields = set(self.model_fields.keys()) + + init_data = {k: v for k, v in data.items() if k in known_fields} + backend_data = {k: v for k, v in data.items() if k not in known_fields} + + # Initialize the model with standard fields + super().__init__(**init_data) + + # Manually set the backend configurations + for backend_name, config_data in backend_data.items(): + if isinstance(config_data, dict): + self.__dict__[backend_name] = BackendConfig(**config_data) + elif isinstance(config_data, BackendConfig): + self.__dict__[backend_name] = config_data + + # Ensure all registered backends have a config + for backend_name in backend_registry.get_registered_backends(): + if backend_name not in self.__dict__: + self.__dict__[backend_name] = BackendConfig() + + self._initialization_complete = True + + def __getitem__(self, key: str) -> BackendConfig: + """Allow dictionary-style access to backend configs.""" + if key in self.__dict__: + return cast(BackendConfig, self.__dict__[key]) + raise KeyError(f"Backend '{key}' not found") + + def __setitem__(self, key: str, value: BackendConfig) -> None: + """Allow dictionary-style setting of backend configs.""" + self.__dict__[key] = value + + def __setattr__(self, name: str, value: Any) -> None: + """Allow attribute-style assignment for backend configs.""" + if ( + name in {"default_backend"} + or name.startswith("_") + or name in self.model_fields + ): + super().__setattr__(name, value) + return + if isinstance(value, BackendConfig): + config = value + elif isinstance(value, dict): + config = BackendConfig(**value) + else: + config = BackendConfig() + self.__dict__[name] = config + + def get(self, key: str, default: Any = None) -> Any: + """Dictionary-style get with default.""" + return cast(BackendConfig | None, self.__dict__.get(key, default)) + + @property + def functional_backends(self) -> set[str]: + """Get the set of functional backends (those with API keys).""" + functional: set[str] = set() + registered = backend_registry.get_registered_backends() + for backend_name in registered: + if backend_name in self.__dict__: + config: Any = self.__dict__[backend_name] + if isinstance(config, BackendConfig) and config.api_key: + functional.add(backend_name) + + # Consider OAuth-style backends functional even without an api_key in config, + # since they source credentials from local auth stores (e.g., CLI-managed files). + oauth_like: set[str] = set() + for name in registered: + if name.endswith("-oauth") or name.startswith("gemini-oauth"): + oauth_like.add(name) + if name == "gemini-cli-cloud-project": + oauth_like.add(name) + + functional.update(oauth_like.intersection(set(registered))) + + # Include any dynamically added backends present in __dict__ that have api_key + # (used in tests and when users add custom backends not in the registry). 
+        for name, cfg in getattr(self, "__dict__", {}).items():
+            if (
+                name == "default_backend"
+                or name.startswith("_")
+                or not isinstance(cfg, BackendConfig)
+            ):
+                continue
+            if cfg.api_key:
+                functional.add(name)
+        return functional
+
+    def __getattr__(self, name: str) -> Any:
+        """Allow accessing backend configs as attributes.
+
+        If an attribute for a backend is missing, create a default
+        BackendConfig instance lazily. This ensures tests and runtime
+        code can access `config.backends.openai` / `config.backends.gemini`
+        even if the registry hasn't been populated yet.
+        """
+        if name == "default_backend":  # Handle default_backend separately
+            # Ensure we use the explicitly set default_backend if available
+            if "default_backend" in self.__dict__:
+                return self.__dict__["default_backend"]
+            # Otherwise fall back to openai
+            return "openai"
+
+        # Check if the attribute exists in __dict__
+        if name in self.__dict__:
+            return cast(BackendConfig, self.__dict__[name])
+
+        # Avoid creating configs for private/internal attributes to maintain security
+        if name.startswith(("_", "__")):
+            raise AttributeError(
+                f"'{self.__class__.__name__}' object has no attribute '{name}'"
+            )
+
+        # If initialization has not completed yet (the _initialization_complete
+        # flag set at the end of __init__ is still absent), don't create empty
+        # configs here; raise AttributeError so __init__ handles the attribute.
+        if not hasattr(self, "_initialization_complete"):
+            raise AttributeError(
+                f"'{self.__class__.__name__}' object has no attribute '{name}'"
+            )
+
+        # Lazily create a default backend configuration for unknown backends.
+        # This allows accessing backend configs without pre-registration while
+        # maintaining backward compatibility. Created configs are cached for
+        # subsequent access to avoid creating multiple instances.
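+        # For example, the first read of `config.backends.notyetconfigured`
+        # (hypothetical name) lands here, stores an empty BackendConfig in
+        # __dict__, and returns it; later reads hit the cached instance.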
+ config = BackendConfig() + self.__dict__[name] = config + return config + + def model_dump(self, **kwargs: Any) -> dict[str, Any]: + """Override model_dump to include default_backend and dynamic backends.""" + dumped: dict[str, Any] = super().model_dump(**kwargs) + # Add dynamic backends to the dumped dictionary + for backend_name in backend_registry.get_registered_backends(): + if backend_name in self.__dict__: + config: Any = self.__dict__[backend_name] + if isinstance(config, BackendConfig): + dumped[backend_name] = config.model_dump() + return dumped + + def model_is_functional(self, model_id: str) -> bool: + """Check if a model is available in any functional backend.""" + if ":" not in model_id: + return False # Invalid format + + backend_name, _ = model_id.split(":", 1) + return backend_name in self.functional_backends + + +class AppConfig(DomainModel, IConfig): + """Complete application configuration.""" + + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + + host: str = "127.0.0.1" # Default to localhost for security + port: int = 8000 + anthropic_port: int | None = None # Will be set to port + 1 if not provided + proxy_timeout: int = 120 + command_prefix: str = "!/" + strict_command_detection: bool = False + context_window_override: int | None = None # Override context window for all models + gcp_project_id: str | None = None + gemini_credentials_path: str | None = None + disable_health_checks: bool = False + + # Rate limit settings + default_rate_limit: int = 60 + default_rate_window: int = 60 + + # Backend settings + backends: BackendSettings = Field(default_factory=BackendSettings) + model_defaults: dict[str, dict[str, Any]] = Field(default_factory=dict) + failover_routes: dict[str, dict[str, Any]] = Field(default_factory=dict) + + # No nested class references - use direct imports instead + + # Identity settings + identity: AppIdentityConfig = Field(default_factory=AppIdentityConfig) + + # Auth settings + auth: AuthConfig = Field(default_factory=AuthConfig) + + # Session settings + session: SessionConfig = Field(default_factory=SessionConfig) + + # Logging settings + logging: LoggingConfig = Field(default_factory=LoggingConfig) + + # Empty response handling settings + empty_response: EmptyResponseConfig = Field(default_factory=EmptyResponseConfig) + + # Edit-precision tuning settings + edit_precision: EditPrecisionConfig = Field(default_factory=EditPrecisionConfig) + + # Rewriting settings + rewriting: RewritingConfig = Field(default_factory=RewritingConfig) + assessment: AssessmentConfig = Field(default_factory=AssessmentConfig) + + # Reasoning aliases settings + reasoning_aliases: ReasoningAliasesConfig = Field( + default_factory=lambda: ReasoningAliasesConfig(reasoning_alias_settings=[]) + ) + + # Model name rewrite rules + model_aliases: list[ModelAliasRule] = Field(default_factory=list) + + # Sandboxing settings + sandboxing: SandboxingConfiguration = Field(default_factory=SandboxingConfiguration) + + # FastAPI app instance + app: Any = None + + def model_is_functional(self, model_id: str) -> bool: + """Check if a model is available in any functional backend.""" + return self.backends.model_is_functional(model_id) + + def save(self, path: str | Path) -> None: + """Save the current configuration to a file.""" + p = Path(path) + data = self.model_dump(mode="json", exclude_none=True) + # Normalize structure to match schema expectations + # - default_backend must be at top-level (already present) + # - Remove runtime-only fields that are not part of 
schema or can cause validation errors + for runtime_key in ["app"]: + if runtime_key in data: + data[runtime_key] = None + # Filter out unsupported top-level keys (schema has additionalProperties: false) + allowed_top_keys = { + "host", + "port", + "anthropic_port", + "proxy_timeout", + "command_prefix", + "strict_command_detection", + "context_window_override", + "default_rate_limit", + "default_rate_window", + "model_defaults", + "failover_routes", + "identity", + "empty_response", + "edit_precision", + "rewriting", + "app", + "logging", + "auth", + "session", + "backends", + "default_backend", + "reasoning_aliases", + "model_aliases", + "sandboxing", + } + data = {k: v for k, v in data.items() if k in allowed_top_keys} + # Ensure nested sections only include serializable primitives + # (model_dump already handles pydantic models) + if p.suffix.lower() in {".yaml", ".yml"}: + import yaml + + logger.debug(f"Saving configuration to {p}: {data}") + with p.open("w", encoding="utf-8") as f: + yaml.safe_dump(data, f, sort_keys=False) + else: + # Legacy: still allow JSON save if requested by extension + with p.open("w", encoding="utf-8") as f: + f.write(self.model_dump_json(indent=4)) + + @classmethod + def from_env( + cls, + *, + environ: Mapping[str, str] | None = None, + resolution: ParameterResolution | None = None, + ) -> AppConfig: + """Create AppConfig from environment variables. + + Returns: + AppConfig instance + """ + env: Mapping[str, str] = os.environ if environ is None else environ + + # Build configuration from environment + config: dict[str, Any] = { + # Server settings + "gcp_project_id": _get_env_value( + env, + "GOOGLE_CLOUD_PROJECT", + _get_env_value( + env, + "GCP_PROJECT_ID", + None, + path="gcp_project_id", + resolution=resolution, + ), + path="gcp_project_id", + resolution=resolution, + ), + "gemini_credentials_path": _get_env_value( + env, + "GEMINI_CREDENTIALS_PATH", + None, + path="gemini_credentials_path", + resolution=resolution, + ), + "disable_health_checks": _env_to_bool( + "DISABLE_HEALTH_CHECKS", + False, + env, + path="disable_health_checks", + resolution=resolution, + ), + "host": _get_env_value( + env, + "APP_HOST", + "127.0.0.1", # Default to localhost for security + path="host", + resolution=resolution, + ), + "port": _get_env_value( + env, + "APP_PORT", + 8000, + path="port", + resolution=resolution, + transform=lambda value: _to_int(value, 8000), + ), + "anthropic_port": _get_env_value( + env, + "ANTHROPIC_PORT", + None, + path="anthropic_port", + resolution=resolution, + transform=lambda value: _to_int(value, 0) if value else None, + ), + "proxy_timeout": _get_env_value( + env, + "PROXY_TIMEOUT", + 120, + path="proxy_timeout", + resolution=resolution, + transform=lambda value: _to_int(value, 120), + ), + "command_prefix": _get_env_value( + env, + "COMMAND_PREFIX", + "!/", + path="command_prefix", + resolution=resolution, + ), + "auth": { + "disable_auth": _env_to_bool( + "DISABLE_AUTH", + False, + env, + path="auth.disable_auth", + resolution=resolution, + ), + "api_keys": _get_api_keys_from_env(env, resolution), + "auth_token": _get_env_value( + env, + "AUTH_TOKEN", + None, + path="auth.auth_token", + resolution=resolution, + ), + "brute_force_protection": { + "enabled": _env_to_bool( + "BRUTE_FORCE_PROTECTION_ENABLED", + True, + env, + path="auth.brute_force_protection.enabled", + resolution=resolution, + ), + "max_failed_attempts": _env_to_int( + "BRUTE_FORCE_MAX_FAILED_ATTEMPTS", + 5, + env, + path="auth.brute_force_protection.max_failed_attempts", + 
resolution=resolution, + ), + "ttl_seconds": _env_to_int( + "BRUTE_FORCE_TTL_SECONDS", + 900, + env, + path="auth.brute_force_protection.ttl_seconds", + resolution=resolution, + ), + "initial_block_seconds": _env_to_int( + "BRUTE_FORCE_INITIAL_BLOCK_SECONDS", + 30, + env, + path="auth.brute_force_protection.initial_block_seconds", + resolution=resolution, + ), + "block_multiplier": _env_to_float( + "BRUTE_FORCE_BLOCK_MULTIPLIER", + 2.0, + env, + path="auth.brute_force_protection.block_multiplier", + resolution=resolution, + ), + "max_block_seconds": _env_to_int( + "BRUTE_FORCE_MAX_BLOCK_SECONDS", + 3600, + env, + path="auth.brute_force_protection.max_block_seconds", + resolution=resolution, + ), + }, + }, + } + + if not config.get("anthropic_port"): + config["anthropic_port"] = int(config["port"]) + 1 + if resolution is not None: + resolution.record( + "anthropic_port", + config["anthropic_port"], + ParameterSource.DERIVED, + origin="port+1", + ) + + # After populating auth config, if disable_auth is true, clear api_keys + auth_config: dict[str, Any] = config["auth"] + if isinstance(auth_config, dict) and auth_config.get("disable_auth"): + auth_config["api_keys"] = [] + + # Add session, logging, and backend config + planning_overrides: dict[str, Any] = {} + planning_temperature = _get_env_value( + env, + "PLANNING_PHASE_TEMPERATURE", + None, + path="session.planning_phase.overrides.temperature", + resolution=resolution, + transform=lambda value: _to_float(value, None), + ) + if planning_temperature is not None: + planning_overrides["temperature"] = planning_temperature + + planning_top_p = _get_env_value( + env, + "PLANNING_PHASE_TOP_P", + None, + path="session.planning_phase.overrides.top_p", + resolution=resolution, + transform=lambda value: _to_float(value, None), + ) + if planning_top_p is not None: + planning_overrides["top_p"] = planning_top_p + + planning_reasoning = _get_env_value( + env, + "PLANNING_PHASE_REASONING_EFFORT", + None, + path="session.planning_phase.overrides.reasoning_effort", + resolution=resolution, + ) + if planning_reasoning is not None: + planning_overrides["reasoning_effort"] = planning_reasoning + + planning_budget = _get_env_value( + env, + "PLANNING_PHASE_THINKING_BUDGET", + None, + path="session.planning_phase.overrides.thinking_budget", + resolution=resolution, + transform=lambda value: _to_int(value, 0), + ) + if planning_budget is not None: + planning_overrides["thinking_budget"] = planning_budget + + config["session"] = { + "cleanup_enabled": _env_to_bool( + "SESSION_CLEANUP_ENABLED", + True, + env, + path="session.cleanup_enabled", + resolution=resolution, + ), + "cleanup_interval": _env_to_int( + "SESSION_CLEANUP_INTERVAL", + 3600, + env, + path="session.cleanup_interval", + resolution=resolution, + ), + "max_age": _env_to_int( + "SESSION_MAX_AGE", + 86400, + env, + path="session.max_age", + resolution=resolution, + ), + "default_interactive_mode": _env_to_bool( + "DEFAULT_INTERACTIVE_MODE", + True, + env, + path="session.default_interactive_mode", + resolution=resolution, + ), + "force_set_project": _env_to_bool( + "FORCE_SET_PROJECT", + False, + env, + path="session.force_set_project", + resolution=resolution, + ), + "project_dir_resolution_model": _get_env_value( + env, + "PROJECT_DIR_RESOLUTION_MODEL", + None, + path="session.project_dir_resolution_model", + resolution=resolution, + ), + "project_dir_resolution_mode": _get_env_value( + env, + "PROJECT_DIR_RESOLUTION_MODE", + "hybrid", + path="session.project_dir_resolution_mode", + 
resolution=resolution, + ), + "tool_call_repair_enabled": _env_to_bool( + "TOOL_CALL_REPAIR_ENABLED", + True, + env, + path="session.tool_call_repair_enabled", + resolution=resolution, + ), + "tool_call_repair_buffer_cap_bytes": _get_env_value( + env, + "TOOL_CALL_REPAIR_BUFFER_CAP_BYTES", + 65536, + path="session.tool_call_repair_buffer_cap_bytes", + resolution=resolution, + transform=lambda value: _to_int(value, 65536), + ), + "json_repair_enabled": _env_to_bool( + "JSON_REPAIR_ENABLED", + True, + env, + path="session.json_repair_enabled", + resolution=resolution, + ), + "json_repair_buffer_cap_bytes": _get_env_value( + env, + "JSON_REPAIR_BUFFER_CAP_BYTES", + 65536, + path="session.json_repair_buffer_cap_bytes", + resolution=resolution, + transform=lambda value: _to_int(value, 65536), + ), + "json_repair_schema": _get_env_value( + env, + "JSON_REPAIR_SCHEMA", + None, + path="session.json_repair_schema", + resolution=resolution, + transform=lambda value: json.loads(value), + ), + "dangerous_command_prevention_enabled": _env_to_bool( + "DANGEROUS_COMMAND_PREVENTION_ENABLED", + True, + env, + path="session.dangerous_command_prevention_enabled", + resolution=resolution, + ), + "dangerous_command_steering_message": _get_env_value( + env, + "DANGEROUS_COMMAND_STEERING_MESSAGE", + None, + path="session.dangerous_command_steering_message", + resolution=resolution, + ), + "pytest_compression_enabled": _env_to_bool( + "PYTEST_COMPRESSION_ENABLED", + True, + env, + path="session.pytest_compression_enabled", + resolution=resolution, + ), + "pytest_compression_min_lines": _env_to_int( + "PYTEST_COMPRESSION_MIN_LINES", + 30, + env, + path="session.pytest_compression_min_lines", + resolution=resolution, + ), + "pytest_full_suite_steering_enabled": _env_to_bool( + "PYTEST_FULL_SUITE_STEERING_ENABLED", + False, + env, + path="session.pytest_full_suite_steering_enabled", + resolution=resolution, + ), + "pytest_full_suite_steering_message": _get_env_value( + env, + "PYTEST_FULL_SUITE_STEERING_MESSAGE", + None, + path="session.pytest_full_suite_steering_message", + resolution=resolution, + ), + "fix_think_tags_enabled": _env_to_bool( + "FIX_THINK_TAGS_ENABLED", + False, + env, + path="session.fix_think_tags_enabled", + resolution=resolution, + ), + "fix_think_tags_streaming_buffer_size": _env_to_int( + "FIX_THINK_TAGS_STREAMING_BUFFER_SIZE", + 4096, + env, + path="session.fix_think_tags_streaming_buffer_size", + resolution=resolution, + ), + "planning_phase": { + "enabled": _env_to_bool( + "PLANNING_PHASE_ENABLED", + False, + env, + path="session.planning_phase.enabled", + resolution=resolution, + ), + "strong_model": _get_env_value( + env, + "PLANNING_PHASE_STRONG_MODEL", + None, + path="session.planning_phase.strong_model", + resolution=resolution, + ), + "max_turns": _env_to_int( + "PLANNING_PHASE_MAX_TURNS", + 10, + env, + path="session.planning_phase.max_turns", + resolution=resolution, + ), + "max_file_writes": _env_to_int( + "PLANNING_PHASE_MAX_FILE_WRITES", + 1, + env, + path="session.planning_phase.max_file_writes", + resolution=resolution, + ), + "overrides": planning_overrides, + }, + "force_reprocess_tool_calls": _env_to_bool( + "FORCE_REPROCESS_TOOL_CALLS", + False, + env, + path="session.force_reprocess_tool_calls", + resolution=resolution, + ), + "log_skipped_tool_calls": _env_to_bool( + "LOG_SKIPPED_TOOL_CALLS", + False, + env, + path="session.log_skipped_tool_calls", + resolution=resolution, + ), + } + + config["logging"] = { + "level": _get_env_value( + env, + "LOG_LEVEL", + "INFO", + 
path="logging.level", + resolution=resolution, + ), + "request_logging": _env_to_bool( + "REQUEST_LOGGING", + False, + env, + path="logging.request_logging", + resolution=resolution, + ), + "response_logging": _env_to_bool( + "RESPONSE_LOGGING", + False, + env, + path="logging.response_logging", + resolution=resolution, + ), + "log_file": _get_env_value( + env, + "LOG_FILE", + None, + path="logging.log_file", + resolution=resolution, + ), + "capture_file": _get_env_value( + env, + "CAPTURE_FILE", + None, + path="logging.capture_file", + resolution=resolution, + ), + "capture_max_bytes": _get_env_value( + env, + "CAPTURE_MAX_BYTES", + None, + path="logging.capture_max_bytes", + resolution=resolution, + transform=lambda value: _to_int(value, 0), + ), + "capture_truncate_bytes": _get_env_value( + env, + "CAPTURE_TRUNCATE_BYTES", + None, + path="logging.capture_truncate_bytes", + resolution=resolution, + transform=lambda value: _to_int(value, 0), + ), + "capture_max_files": _get_env_value( + env, + "CAPTURE_MAX_FILES", + None, + path="logging.capture_max_files", + resolution=resolution, + transform=lambda value: _to_int(value, 0), + ), + "capture_rotate_interval_seconds": _get_env_value( + env, + "CAPTURE_ROTATE_INTERVAL_SECONDS", + 86400, + path="logging.capture_rotate_interval_seconds", + resolution=resolution, + transform=lambda value: _to_int(value, 86400), + ), + "capture_total_max_bytes": _get_env_value( + env, + "CAPTURE_TOTAL_MAX_BYTES", + 104857600, + path="logging.capture_total_max_bytes", + resolution=resolution, + transform=lambda value: _to_int(value, 104857600), + ), + "capture_buffer_size": _get_env_value( + env, + "CAPTURE_BUFFER_SIZE", + 65536, + path="logging.capture_buffer_size", + resolution=resolution, + transform=lambda value: _to_int(value, 65536), + ), + "capture_flush_interval": _get_env_value( + env, + "CAPTURE_FLUSH_INTERVAL", + 1.0, + path="logging.capture_flush_interval", + resolution=resolution, + transform=lambda value: _to_float(value, 1.0), + ), + "capture_max_entries_per_flush": _get_env_value( + env, + "CAPTURE_MAX_ENTRIES_PER_FLUSH", + 100, + path="logging.capture_max_entries_per_flush", + resolution=resolution, + transform=lambda value: _to_int(value, 100), + ), + } + + config["empty_response"] = { + "enabled": _env_to_bool( + "EMPTY_RESPONSE_HANDLING_ENABLED", + True, + env, + path="empty_response.enabled", + resolution=resolution, + ), + "max_retries": _env_to_int( + "EMPTY_RESPONSE_MAX_RETRIES", + 1, + env, + path="empty_response.max_retries", + resolution=resolution, + ), + } + + # Edit precision settings + config["edit_precision"] = { + "enabled": _env_to_bool( + "EDIT_PRECISION_ENABLED", + True, + env, + path="edit_precision.enabled", + resolution=resolution, + ), + "temperature": _env_to_float( + "EDIT_PRECISION_TEMPERATURE", + 0.1, + env, + path="edit_precision.temperature", + resolution=resolution, + ), + "min_top_p": _env_to_float( + "EDIT_PRECISION_MIN_TOP_P", + 0.3, + env, + path="edit_precision.min_top_p", + resolution=resolution, + ), + "override_top_p": _env_to_bool( + "EDIT_PRECISION_OVERRIDE_TOP_P", + False, + env, + path="edit_precision.override_top_p", + resolution=resolution, + ), + "override_top_k": _env_to_bool( + "EDIT_PRECISION_OVERRIDE_TOP_K", + False, + env, + path="edit_precision.override_top_k", + resolution=resolution, + ), + "target_top_k": _get_env_value( + env, + "EDIT_PRECISION_TARGET_TOP_K", + None, + path="edit_precision.target_top_k", + resolution=resolution, + transform=lambda value: _to_int(value, 0) or None, + ), + 
"exclude_agents_regex": _get_env_value( + env, + "EDIT_PRECISION_EXCLUDE_AGENTS_REGEX", + None, + path="edit_precision.exclude_agents_regex", + resolution=resolution, + ), + } + + config["rewriting"] = { + "enabled": _env_to_bool( + "REWRITING_ENABLED", + False, + env, + path="rewriting.enabled", + resolution=resolution, + ), + "config_path": _get_env_value( + env, + "REWRITING_CONFIG_PATH", + "config/replacements", + path="rewriting.config_path", + resolution=resolution, + ), + } + + # Assessment configuration from environment + config["assessment"] = { + "enabled": _env_to_bool( + "LLM_ASSESSMENT_ENABLED", + False, + env, + path="assessment.enabled", + resolution=resolution, + ), + "turn_threshold": _env_to_int( + "LLM_ASSESSMENT_TURN_THRESHOLD", + 30, + env, + path="assessment.turn_threshold", + resolution=resolution, + ), + "confidence_threshold": _env_to_float( + "LLM_ASSESSMENT_CONFIDENCE_THRESHOLD", + 0.9, + env, + path="assessment.confidence_threshold", + resolution=resolution, + ), + "backend": _get_env_value( + env, + "LLM_ASSESSMENT_BACKEND", + "openai", # Default backend + path="assessment.backend", + resolution=resolution, + ), + "model": _get_env_value( + env, + "LLM_ASSESSMENT_MODEL", + "gpt-4o-mini", # Default model + path="assessment.model", + resolution=resolution, + ), + "history_window": _env_to_int( + "LLM_ASSESSMENT_HISTORY_WINDOW", + 20, + env, + path="assessment.history_window", + resolution=resolution, + ), + } + + # Sandboxing configuration from environment + config["sandboxing"] = { + "enabled": _env_to_bool( + "ENABLE_SANDBOXING", + False, + env, + path="sandboxing.enabled", + resolution=resolution, + ), + "strict_mode": _env_to_bool( + "SANDBOXING_STRICT_MODE", + False, + env, + path="sandboxing.strict_mode", + resolution=resolution, + ), + "allow_parent_access": _env_to_bool( + "SANDBOXING_ALLOW_PARENT_ACCESS", + False, + env, + path="sandboxing.allow_parent_access", + resolution=resolution, + ), + } + + # Model aliases configuration from environment + model_aliases_env = env.get("MODEL_ALIASES") + if model_aliases_env: + try: + alias_data = json.loads(model_aliases_env) + if isinstance(alias_data, list): + config["model_aliases"] = [ + {"pattern": item["pattern"], "replacement": item["replacement"]} + for item in alias_data + if isinstance(item, dict) + and "pattern" in item + and "replacement" in item + ] + if resolution is not None: + resolution.record( + "model_aliases", + config["model_aliases"], + ParameterSource.ENVIRONMENT, + origin="MODEL_ALIASES", + ) + except (json.JSONDecodeError, KeyError, TypeError) as e: + logger.warning( + f"Invalid MODEL_ALIASES environment variable format: {e}" + ) + config["model_aliases"] = [] + else: + config["model_aliases"] = [] + + config["backends"] = { + "default_backend": _get_env_value( + env, + "LLM_BACKEND", + "openai", + path="backends.default_backend", + resolution=resolution, + ), + "disable_gemini_oauth_fallback": _env_to_bool( + "DISABLE_GEMINI_OAUTH_FALLBACK", + False, + env, + path="backends.disable_gemini_oauth_fallback", + resolution=resolution, + ), + "disable_hybrid_backend": _env_to_bool( + "DISABLE_HYBRID_BACKEND", + False, + env, + path="backends.disable_hybrid_backend", + resolution=resolution, + ), + "hybrid_backend_repeat_messages": _env_to_bool( + "HYBRID_BACKEND_REPEAT_MESSAGES", + False, + env, + path="backends.hybrid_backend_repeat_messages", + resolution=resolution, + ), + "reasoning_injection_probability": _env_to_float( + "REASONING_INJECTION_PROBABILITY", + 1.0, + env, + 
path="backends.reasoning_injection_probability", + resolution=resolution, + ), + } + + config["identity"] = AppIdentityConfig( + title=HeaderConfig( + override_value=_get_env_value( + env, + "APP_TITLE", + None, + path="identity.title.override_value", + resolution=resolution, + ), + mode=HeaderOverrideMode( + _get_env_value( + env, + "APP_TITLE_MODE", + "passthrough", + path="identity.title.mode", + resolution=resolution, + ) + ), + default_value="llm-interactive-proxy", + passthrough_name="x-title", + ), + url=HeaderConfig( + override_value=_get_env_value( + env, + "APP_URL", + None, + path="identity.url.override_value", + resolution=resolution, + ), + mode=HeaderOverrideMode( + _get_env_value( + env, + "APP_URL_MODE", + "passthrough", + path="identity.url.mode", + resolution=resolution, + ) + ), + default_value="https://github.com/matdev83/llm-interactive-proxy", + passthrough_name="http-referer", + ), + user_agent=HeaderConfig( + override_value=_get_env_value( + env, + "APP_USER_AGENT", + None, + path="identity.user_agent.override_value", + resolution=resolution, + ), + mode=HeaderOverrideMode( + _get_env_value( + env, + "APP_USER_AGENT_MODE", + "passthrough", + path="identity.user_agent.mode", + resolution=resolution, + ) + ), + default_value="llm-interactive-proxy", + passthrough_name="user-agent", + ), + ) + + # Log the determined default_backend + logger.info( + f"AppConfig.from_env - Determined default_backend: {config['backends']['default_backend']}" + ) + + # Extract backend configurations from environment + config_backends: dict[str, Any] = config["backends"] + assert isinstance(config_backends, dict) + + # Collect and assign API keys for specific backends + openrouter_keys = _collect_api_keys_from_env( + "OPENROUTER_API_KEY", env, resolution + ) + if openrouter_keys: + config_backends["openrouter"] = config_backends.get("openrouter", {}) + config_backends["openrouter"]["api_key"] = list(openrouter_keys.values()) + config_backends["openrouter"]["api_url"] = _get_env_value( + env, + "OPENROUTER_API_BASE_URL", + "https://openrouter.ai/api/v1", + path="backends.openrouter.api_url", + resolution=resolution, + ) + timeout_value = _get_env_value( + env, + "OPENROUTER_TIMEOUT", + None, + path="backends.openrouter.timeout", + resolution=resolution, + transform=lambda value: _to_int(value, 0), + ) + if timeout_value: + config_backends["openrouter"]["timeout"] = timeout_value + if resolution is not None: + resolution.record( + "backends.openrouter.api_key", + config_backends["openrouter"]["api_key"], + ParameterSource.ENVIRONMENT, + origin="OPENROUTER_API_KEY*", + ) + + gemini_keys: dict[str, str] = _collect_api_keys_from_env( + "GEMINI_API_KEY", env, resolution + ) + if gemini_keys: + config_backends["gemini"] = config_backends.get("gemini", {}) + config_backends["gemini"]["api_key"] = list(gemini_keys.values()) + config_backends["gemini"]["api_url"] = _get_env_value( + env, + "GEMINI_API_BASE_URL", + "https://generativelanguage.googleapis.com", + path="backends.gemini.api_url", + resolution=resolution, + ) + gemini_timeout = _get_env_value( + env, + "GEMINI_TIMEOUT", + None, + path="backends.gemini.timeout", + resolution=resolution, + transform=lambda value: _to_int(value, 0), + ) + if gemini_timeout: + config_backends["gemini"]["timeout"] = gemini_timeout + if resolution is not None: + resolution.record( + "backends.gemini.api_key", + config_backends["gemini"]["api_key"], + ParameterSource.ENVIRONMENT, + origin="GEMINI_API_KEY*", + ) + + anthropic_keys: dict[str, str] = 
_collect_api_keys_from_env( + "ANTHROPIC_API_KEY", env, resolution + ) + if anthropic_keys: + config_backends["anthropic"] = config_backends.get("anthropic", {}) + config_backends["anthropic"]["api_key"] = list(anthropic_keys.values()) + config_backends["anthropic"]["api_url"] = _get_env_value( + env, + "ANTHROPIC_API_BASE_URL", + "https://api.anthropic.com/v1", + path="backends.anthropic.api_url", + resolution=resolution, + ) + anthropic_timeout = _get_env_value( + env, + "ANTHROPIC_TIMEOUT", + None, + path="backends.anthropic.timeout", + resolution=resolution, + transform=lambda value: _to_int(value, 0), + ) + if anthropic_timeout: + config_backends["anthropic"]["timeout"] = anthropic_timeout + if resolution is not None: + resolution.record( + "backends.anthropic.api_key", + config_backends["anthropic"]["api_key"], + ParameterSource.ENVIRONMENT, + origin="ANTHROPIC_API_KEY*", + ) + + zai_keys: dict[str, str] = _collect_api_keys_from_env( + "ZAI_API_KEY", env, resolution + ) + if zai_keys: + config_backends["zai"] = config_backends.get("zai", {}) + config_backends["zai"]["api_key"] = list(zai_keys.values()) + config_backends["zai"]["api_url"] = _get_env_value( + env, + "ZAI_API_BASE_URL", + None, + path="backends.zai.api_url", + resolution=resolution, + ) + zai_timeout = _get_env_value( + env, + "ZAI_TIMEOUT", + None, + path="backends.zai.timeout", + resolution=resolution, + transform=lambda value: _to_int(value, 0), + ) + if zai_timeout: + config_backends["zai"]["timeout"] = zai_timeout + if resolution is not None: + resolution.record( + "backends.zai.api_key", + config_backends["zai"]["api_key"], + ParameterSource.ENVIRONMENT, + origin="ZAI_API_KEY*", + ) + + openai_keys: dict[str, str] = _collect_api_keys_from_env( + "OPENAI_API_KEY", env, resolution + ) + if openai_keys: + config_backends["openai"] = config_backends.get("openai", {}) + config_backends["openai"]["api_key"] = list(openai_keys.values()) + config_backends["openai"]["api_url"] = _get_env_value( + env, + "OPENAI_API_BASE_URL", + "https://api.openai.com/v1", + path="backends.openai.api_url", + resolution=resolution, + ) + openai_timeout = _get_env_value( + env, + "OPENAI_TIMEOUT", + None, + path="backends.openai.timeout", + resolution=resolution, + transform=lambda value: _to_int(value, 0), + ) + if openai_timeout: + config_backends["openai"]["timeout"] = openai_timeout + if resolution is not None: + resolution.record( + "backends.openai.api_key", + config_backends["openai"]["api_key"], + ParameterSource.ENVIRONMENT, + origin="OPENAI_API_KEY*", + ) + + minimax_keys: dict[str, str] = _collect_api_keys_from_env( + "MINIMAX_API_KEY", env, resolution + ) + if minimax_keys: + config_backends["minimax"] = config_backends.get("minimax", {}) + config_backends["minimax"]["api_key"] = list(minimax_keys.values()) + config_backends["minimax"]["api_url"] = _get_env_value( + env, + "MINIMAX_API_BASE_URL", + "https://api.minimax.io/v1", + path="backends.minimax.api_url", + resolution=resolution, + ) + minimax_timeout = _get_env_value( + env, + "MINIMAX_TIMEOUT", + None, + path="backends.minimax.timeout", + resolution=resolution, + transform=lambda value: _to_int(value, 0), + ) + if minimax_timeout: + config_backends["minimax"]["timeout"] = minimax_timeout + if resolution is not None: + resolution.record( + "backends.minimax.api_key", + config_backends["minimax"]["api_key"], + ParameterSource.ENVIRONMENT, + origin="MINIMAX_API_KEY*", + ) + + # Handle default backend if it's not explicitly configured above + default_backend_type: str = str( 
+ config["backends"].get("default_backend", "openai") + ) + if default_backend_type not in config_backends: + # If the default backend is not explicitly configured, ensure it has a basic config + config_backends[default_backend_type] = config_backends.get( + default_backend_type, {} + ) + # Add a dummy API key if running in test environment and no API key is present + if env.get("PYTEST_CURRENT_TEST") and ( + not config_backends[default_backend_type] + or not config_backends[default_backend_type].get("api_key") + ): + config_backends[default_backend_type]["api_key"] = [ + f"test-key-{default_backend_type}" + ] + logger.info( + f"Added test API key for default backend {default_backend_type}" + ) + + return cls(**config) # type: ignore + + def get(self, key: str, default: Any = None) -> Any: + """Get a configuration value by key.""" + # Split the key by dots to handle nested attributes + keys = key.split(".") + value: Any = self + + try: + for k in keys: + if isinstance(value, dict): + value = value.get(k, default) + else: + value = getattr(value, k, default) + return value + except Exception: + return default + + def set(self, key: str, value: Any) -> None: + """Set a configuration value.""" + # For simplicity, we'll only handle top-level attributes + # In a more complex implementation, we might want to handle nested attributes + setattr(self, key, value) + + def get_gcp_project_id(self) -> str | None: + """Return the GCP Project ID.""" + return self.gcp_project_id + + +def _merge_dicts(d1: dict[str, Any], d2: dict[str, Any]) -> dict[str, Any]: + for k, v in d2.items(): + if k in d1 and isinstance(d1[k], dict) and isinstance(v, dict): + _merge_dicts(d1[k], v) + else: + d1[k] = v + return d1 + + +def _set_by_path(target: dict[str, Any], path: str, value: Any) -> None: + parts = path.split(".") + current: dict[str, Any] = target + for key in parts[:-1]: + current = current.setdefault(key, {}) # type: ignore[assignment] + current[parts[-1]] = value + + +def _get_by_path(source: dict[str, Any], path: str) -> Any: + parts = path.split(".") + current: Any = source + for key in parts: + if not isinstance(current, dict): + return None + current = current.get(key) + return current + + +def _flatten_dict(data: dict[str, Any]) -> dict[str, Any]: + flattened: dict[str, Any] = {} + + def _walk(value: Any, prefix: str) -> None: + if isinstance(value, dict): + for key, child in value.items(): + new_prefix = f"{prefix}.{key}" if prefix else key + _walk(child, new_prefix) + else: + flattened[prefix] = value + + _walk(data, "") + return flattened + + +def load_config( + config_path: str | Path | None = None, + *, + resolution: ParameterResolution | None = None, + environ: Mapping[str, str] | None = None, +) -> AppConfig: + """ + Load configuration from file and environment. + + Args: + config_path: Optional path to configuration file + + Returns: + AppConfig instance + """ + env = os.environ if environ is None else environ + res = resolution or ParameterResolution() + + config_data: dict[str, Any] = AppConfig().model_dump() + + if config_path: + try: + import yaml + + path: Path = Path(config_path) + if not path.exists(): + logger.warning(f"Configuration file not found: {config_path}") + else: + if path.suffix.lower() not in [".yaml", ".yml"]: + raise ValueError( + f"Unsupported configuration file format: {path.suffix}. Use YAML (.yaml/.yml)." 
+ ) + + with open(path, encoding="utf-8") as f: + file_config: dict[str, Any] = yaml.safe_load(f) or {} + + from pathlib import Path as _Path + + from src.core.config.semantic_validation import ( + validate_config_semantics, + ) + from src.core.config.yaml_validation import validate_yaml_against_schema + + schema_path = ( + _Path.cwd() / "config" / "schemas" / "app_config.schema.yaml" + ) + validate_yaml_against_schema(_Path(path), schema_path) + validate_config_semantics(file_config, path) + + _merge_dicts(config_data, file_config) + origin = str(path) + for name, value in _flatten_dict(file_config).items(): + res.record( + name, + value, + ParameterSource.CONFIG_FILE, + origin=origin, + ) + except Exception as exc: # type: ignore[misc] + logger.critical(f"Error loading configuration file: {exc!s}") + raise + + env_config = AppConfig.from_env(environ=env, resolution=res) + env_dump = env_config.model_dump() + for name in res.latest_by_source(ParameterSource.ENVIRONMENT): + value = _get_by_path(env_dump, name) + _set_by_path(config_data, name, value) + + return AppConfig.model_validate(config_data) diff --git a/src/core/di/services.py b/src/core/di/services.py index 7d8d9c937..5c306ac9c 100644 --- a/src/core/di/services.py +++ b/src/core/di/services.py @@ -1,1780 +1,1788 @@ -""" -Services and DI container configuration. - -This module provides functions for configuring the DI container with services -and resolving services from the container. -""" - -from __future__ import annotations - -import logging -import os -from collections.abc import Callable -from typing import Any, TypeVar, cast - -from src.core.common.exceptions import ServiceResolutionError -from src.core.config.app_config import AppConfig -from src.core.di.container import ServiceCollection -from src.core.domain.streaming_response_processor import ( - IStreamProcessor, - LoopDetectionProcessor, -) -from src.core.interfaces.agent_response_formatter_interface import ( - IAgentResponseFormatter, -) -from src.core.interfaces.app_settings_interface import IAppSettings -from src.core.interfaces.application_state_interface import IApplicationState -from src.core.interfaces.backend_config_provider_interface import ( - IBackendConfigProvider, -) -from src.core.interfaces.backend_processor_interface import IBackendProcessor -from src.core.interfaces.backend_request_manager_interface import ( - IBackendRequestManager, -) -from src.core.interfaces.backend_service_interface import IBackendService -from src.core.interfaces.command_processor_interface import ICommandProcessor -from src.core.interfaces.command_service_interface import ICommandService -from src.core.interfaces.configuration_interface import IConfig -from src.core.interfaces.di_interface import IServiceProvider -from src.core.interfaces.middleware_application_manager_interface import ( - IMiddlewareApplicationManager, -) -from src.core.interfaces.path_validator_interface import IPathValidator -from src.core.interfaces.rate_limiter_interface import IRateLimiter -from src.core.interfaces.repositories_interface import ISessionRepository -from src.core.interfaces.request_processor_interface import IRequestProcessor -from src.core.interfaces.response_handler_interface import ( - INonStreamingResponseHandler, - IStreamingResponseHandler, -) -from src.core.interfaces.response_manager_interface import IResponseManager -from src.core.interfaces.response_parser_interface import IResponseParser -from src.core.interfaces.response_processor_interface import ( - IResponseMiddleware, - 
IResponseProcessor, -) -from src.core.interfaces.session_manager_interface import ISessionManager -from src.core.interfaces.session_resolver_interface import ISessionResolver -from src.core.interfaces.session_service_interface import ISessionService -from src.core.interfaces.state_provider_interface import ( - ISecureStateAccess, - ISecureStateModification, -) -from src.core.interfaces.streaming_response_processor_interface import IStreamNormalizer -from src.core.interfaces.tool_call_repair_service_interface import ( - IToolCallRepairService, -) -from src.core.interfaces.translation_service_interface import ITranslationService -from src.core.interfaces.wire_capture_interface import IWireCapture -from src.core.services.app_settings_service import AppSettings -from src.core.services.application_state_service import ApplicationStateService -from src.core.services.backend_processor import BackendProcessor -from src.core.services.backend_request_manager_service import BackendRequestManager -from src.core.services.backend_service import BackendService -from src.core.services.command_processor import CommandProcessor -from src.core.services.dangerous_command_service import DangerousCommandService -from src.core.services.failover_service import FailoverService -from src.core.services.file_sandboxing_handler import FileSandboxingHandler -from src.core.services.json_repair_service import JsonRepairService -from src.core.services.middleware_application_manager import ( - MiddlewareApplicationManager, -) -from src.core.services.path_validation_service import PathValidationService -from src.core.services.pytest_compression_service import PytestCompressionService -from src.core.services.request_processor_service import RequestProcessor -from src.core.services.response_handlers import ( - DefaultNonStreamingResponseHandler, - DefaultStreamingResponseHandler, -) -from src.core.services.response_manager_service import ( - AgentResponseFormatter, - ResponseManager, -) -from src.core.services.response_parser_service import ResponseParser -from src.core.services.response_processor_service import ResponseProcessor -from src.core.services.secure_command_factory import SecureCommandFactory -from src.core.services.secure_state_service import SecureStateService -from src.core.services.session_manager_service import SessionManager -from src.core.services.session_resolver_service import DefaultSessionResolver -from src.core.services.session_service_impl import SessionService -from src.core.services.streaming.content_accumulation_processor import ( - ContentAccumulationProcessor, -) -from src.core.services.streaming.json_repair_processor import JsonRepairProcessor -from src.core.services.streaming.middleware_application_processor import ( - MiddlewareApplicationProcessor, -) -from src.core.services.streaming.stream_normalizer import StreamNormalizer -from src.core.services.streaming.tool_call_repair_processor import ( - ToolCallRepairProcessor, -) -from src.core.services.structured_output_middleware import StructuredOutputMiddleware -from src.core.services.tool_call_reactor_middleware import ToolCallReactorMiddleware -from src.core.services.tool_call_reactor_service import ( - InMemoryToolCallHistoryTracker, - ToolCallReactorService, -) -from src.core.services.tool_call_repair_service import ToolCallRepairService -from src.core.services.translation_service import TranslationService - -T = TypeVar("T") - -# Global service collection -_service_collection: ServiceCollection | None = None -_service_provider: 
IServiceProvider | None = None - - -def _get_di_diagnostics() -> bool: - """Get DI diagnostics setting from environment.""" - return os.getenv("DI_STRICT_DIAGNOSTICS", "false").lower() in ( - "true", - "1", - "yes", - ) - - -def get_service_collection() -> ServiceCollection: - """Get the global service collection. - - Returns: - The global service collection - """ - global _service_collection - if _service_collection is None: - _service_collection = ServiceCollection() - # Ensure core services are registered into the global collection early. - # This makes DI shape consistent across processes/tests and avoids many - # order-dependent failures. register_core_services is idempotent. - try: - register_core_services(_service_collection, None) - except Exception as exc: - logging.getLogger(__name__).exception( - "Failed to register core services into global service collection" - ) - _service_collection = None - raise ServiceResolutionError( - "Failed to register core services", - details={ - "error_type": type(exc).__name__, - "error_message": str(exc), - }, - ) from exc - return _service_collection - - -def get_or_build_service_provider() -> IServiceProvider: - """Get the global service provider or build one if it doesn't exist. - - Returns: - The global service provider - """ - global _service_provider - if _service_provider is None: - if _get_di_diagnostics(): - logging.getLogger("llm.di").info( - "Building service provider; descriptors=%d", - len(get_service_collection()._descriptors), - ) - _service_provider = get_service_collection().build_service_provider() - return _service_provider - - -def set_service_provider(provider: IServiceProvider) -> None: - """Set the global service provider (used for tests/late init). - - Args: - provider: The ServiceProvider instance to set as the global provider - """ - global _service_provider - _service_provider = provider - - -def get_service_provider() -> IServiceProvider: - """Return the global service provider, building it if necessary. - - This is a compatibility wrapper used by callers that expect a - `get_service_provider` symbol. - """ - provider = get_or_build_service_provider() - return _ensure_tool_call_reactor_services(provider) - - -def _ensure_tool_call_reactor_services( - provider: IServiceProvider, -) -> IServiceProvider: - """Ensure the provider can resolve ToolCallReactor components. - - Args: - provider: The current service provider instance. - - Returns: - A provider that can resolve the ToolCallReactor service and middleware. - - Raises: - ServiceResolutionError: If re-registration fails to provide the required services. - """ - - from src.core.services.tool_call_reactor_middleware import ToolCallReactorMiddleware - from src.core.services.tool_call_reactor_service import ToolCallReactorService - - missing_components: list[str] = [] - - if provider.get_service(ToolCallReactorService) is None: - missing_components.append("ToolCallReactorService") - if provider.get_service(ToolCallReactorMiddleware) is None: - missing_components.append("ToolCallReactorMiddleware") - - if not missing_components: - return provider - - logger = logging.getLogger(__name__) - logger.warning( - "DI provider missing tool call reactor components: %s. 
Re-registering core services.", - ", ".join(missing_components), - ) - - services = get_service_collection() - descriptors = getattr(services, "_descriptors", {}) - - preserved_descriptors: dict[type, Any] = {} - for key in (AppConfig, cast(type, IConfig)): - descriptor = descriptors.get(key) - if descriptor is not None: - preserved_descriptors[key] = descriptor - - register_core_services(services) - - descriptors.update(preserved_descriptors) - - new_provider = services.build_service_provider() - set_service_provider(new_provider) - - still_missing: list[str] = [] - if new_provider.get_service(ToolCallReactorService) is None: - still_missing.append("ToolCallReactorService") - if new_provider.get_service(ToolCallReactorMiddleware) is None: - still_missing.append("ToolCallReactorMiddleware") - - if still_missing: - raise ServiceResolutionError( - "Failed to register required Tool Call Reactor services.", - details={"missing_components": still_missing}, - ) - - return new_provider - - -def register_core_services( - services: ServiceCollection, app_config: AppConfig | None = None -) -> None: - """Register core services with the service collection. - - Args: - services: The service collection to register services with - app_config: Optional application configuration - """ - logger: logging.Logger = logging.getLogger(__name__) - # Register AppConfig and IConfig - if app_config is not None: - services.add_instance(AppConfig, app_config) - # Also register it as IConfig for interface resolution - try: - services.add_instance( - cast(type, IConfig), - app_config, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IConfig interface: {e}") - # Continue without interface registration if it fails - else: - # Register default AppConfig as IConfig for testing and basic functionality - default_config = AppConfig() - services.add_instance(AppConfig, default_config) - try: - services.add_instance( - cast(type, IConfig), - default_config, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register default IConfig interface: {e}") - # Continue without interface registration if it fails - - # Helper wrappers to make registration idempotent and provide debug logging - - def _registered(service_type: type) -> bool: - desc = getattr(services, "_descriptors", None) - return desc is not None and service_type in desc - - def _add_singleton( - service_type: type, - implementation_type: type | None = None, - implementation_factory: Callable[[IServiceProvider], Any] | None = None, - ) -> None: - if _registered(service_type): - logger.debug( - "Skipping registration of %s; already present", - getattr(service_type, "__name__", str(service_type)), - ) - return - services.add_singleton( - service_type, implementation_type, implementation_factory - ) - - def _add_instance(service_type: type, instance: Any) -> None: - if _registered(service_type): - logger.debug( - "Skipping instance registration of %s; already present", - getattr(service_type, "__name__", str(service_type)), - ) - return - services.add_instance(service_type, instance) - - # Register session resolver - _add_singleton(DefaultSessionResolver) - # Register both the concrete type and the interface - _add_singleton(ISessionResolver, DefaultSessionResolver) # type: ignore[type-abstract] - - # Register CommandService with factory - def _command_service_factory(provider: IServiceProvider) -> ICommandService: - from src.core.commands.parser import CommandParser - from 
src.core.commands.service import NewCommandService - from src.core.services.command_policy_service import CommandPolicyService - from src.core.services.command_state_service import CommandStateService - from src.core.services.session_service_impl import SessionService - - session_service = provider.get_required_service(SessionService) - command_parser = provider.get_required_service(CommandParser) - config = provider.get_required_service(AppConfig) - app_state = provider.get_service(cast(type, IApplicationState)) - state_service = provider.get_required_service(CommandStateService) - policy_service = provider.get_required_service(CommandPolicyService) - return NewCommandService( - session_service, - command_parser, - strict_command_detection=config.strict_command_detection, - app_state=app_state, - command_state_service=state_service, - command_policy_service=policy_service, - config=config, - ) - - # Register CommandService and bind to interface - _add_singleton(ICommandService, implementation_factory=_command_service_factory) # type: ignore[type-abstract] - - # Register CommandParser - from src.core.commands.parser import CommandParser - from src.core.interfaces.command_parser_interface import ICommandParser - - _add_singleton(ICommandParser, CommandParser) # type: ignore[type-abstract] - _add_singleton(CommandParser, CommandParser) # Also register concrete type - - # Ensure command handlers are imported so their @command decorators register them - try: - import importlib - import pkgutil - - package_name = "src.core.commands.handlers" - package = importlib.import_module(package_name) - for m in pkgutil.iter_modules(package.__path__): # type: ignore[attr-defined] - importlib.import_module(f"{package_name}.{m.name}") - except Exception: - logging.getLogger(__name__).warning( - "Failed to import command handlers for registration", exc_info=True - ) - - # Register session service factory - def _session_service_factory(provider: IServiceProvider) -> SessionService: - # Import here to avoid circular imports - from src.core.repositories.in_memory_session_repository import ( - InMemorySessionRepository, - ) - - # Create repository - repository: InMemorySessionRepository = InMemorySessionRepository() - - # Return session service - return SessionService(repository) - - # Register session service and bind to interface - _add_singleton(SessionService, implementation_factory=_session_service_factory) - - try: - services.add_singleton( - cast(type, ISessionService), implementation_factory=_session_service_factory - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register ISessionService interface: {e}") - # Continue if concrete SessionService is registered - - # Register command state service - from src.core.interfaces.command_state_service_interface import ( - ICommandStateService, - ) - from src.core.services.command_state_service import CommandStateService - - def _command_state_service_factory( - provider: IServiceProvider, - ) -> CommandStateService: - session = provider.get_required_service(SessionService) - return CommandStateService(session) - - _add_singleton( - CommandStateService, implementation_factory=_command_state_service_factory - ) - - try: - services.add_singleton( - cast(type, ICommandStateService), - implementation_factory=lambda provider: provider.get_required_service( - CommandStateService - ), - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register ICommandStateService interface: {e}") - # Continue if 
concrete CommandStateService is registered - - # Register command policy service - from src.core.interfaces.command_policy_service_interface import ( - ICommandPolicyService, - ) - from src.core.services.command_policy_service import CommandPolicyService - - def _command_policy_service_factory( - provider: IServiceProvider, - ) -> CommandPolicyService: - cfg = provider.get_required_service(AppConfig) - app_state = provider.get_service(cast(type, IApplicationState)) - return CommandPolicyService(cfg, app_state) - - _add_singleton( - CommandPolicyService, implementation_factory=_command_policy_service_factory - ) - - try: - services.add_singleton( - cast(type, ICommandPolicyService), - implementation_factory=lambda provider: provider.get_required_service( - CommandPolicyService - ), - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register ICommandPolicyService interface: {e}") - # Continue if concrete CommandPolicyService is registered - - # Register command processor - def _command_processor_factory(provider: IServiceProvider) -> ICommandProcessor: - # Get command service - from typing import cast - - from src.core.commands.tool_call_command_processor import ( - ToolCallCommandProcessor, - ) - from src.core.services.delegating_command_processor import ( - DelegatingCommandProcessor, - ) - - command_service: ICommandService = provider.get_required_service( - cast(type, ICommandService) - ) - - # Create the processors - text_command_processor = CommandProcessor(command_service) - tool_call_command_processor = ToolCallCommandProcessor(command_service) - - # Return the delegating processor - return DelegatingCommandProcessor( - tool_call_command_processor=tool_call_command_processor, - text_command_processor=text_command_processor, - ) - - # Register command processor and bind to interface - try: - services.add_singleton( - cast(type, ICommandProcessor), - implementation_factory=_command_processor_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register ICommandProcessor interface: {e}") - # Continue without interface registration if it fails - - # Register backend processor - def _backend_processor_factory(provider: IServiceProvider) -> BackendProcessor: - # Get backend service and session service - from typing import cast - - backend_service: IBackendService = provider.get_required_service( - cast(type, IBackendService) - ) - session_service: ISessionService = provider.get_required_service( - cast(type, ISessionService) - ) - app_state: IApplicationState = provider.get_required_service( - cast(type, IApplicationState) - ) - - # Return backend processor - return BackendProcessor(backend_service, session_service, app_state) - - # Register backend processor and bind to interface - _add_singleton(BackendProcessor, implementation_factory=_backend_processor_factory) - - try: - services.add_singleton( - cast(type, IBackendProcessor), - implementation_factory=_backend_processor_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IBackendProcessor interface: {e}") - # Continue if concrete BackendProcessor is registered - - # Register response handlers - _add_singleton(DefaultNonStreamingResponseHandler) - _add_singleton(DefaultStreamingResponseHandler) - - try: - services.add_singleton( - cast(type, INonStreamingResponseHandler), DefaultNonStreamingResponseHandler - ) - services.add_singleton( - cast(type, IStreamingResponseHandler), 
DefaultStreamingResponseHandler - ) - except Exception as e: - logger.warning(f"Failed to register response handler interfaces: {e}") - # Continue if concrete handlers are registered - - # Register MiddlewareApplicationManager and IMiddlewareApplicationManager with configured middleware list - def _middleware_application_manager_factory( - provider: IServiceProvider, - ) -> MiddlewareApplicationManager: - from src.core.app.middleware.json_repair_middleware import JsonRepairMiddleware - from src.core.app.middleware.tool_call_repair_middleware import ( - ToolCallRepairMiddleware, - ) - from src.core.config.app_config import AppConfig - from src.core.services.empty_response_middleware import ( - EmptyResponseMiddleware, - ) - from src.core.services.middleware_application_manager import ( - MiddlewareApplicationManager, - ) - from src.core.services.tool_call_loop_middleware import ( - ToolCallLoopDetectionMiddleware, - ) - - cfg: AppConfig = provider.get_required_service(AppConfig) - middlewares: list[IResponseMiddleware] = [] - - try: - if getattr(cfg.empty_response, "enabled", True): - middlewares.append( - EmptyResponseMiddleware( - enabled=True, - max_retries=getattr(cfg.empty_response, "max_retries", 1), - ) - ) - except Exception as e: - logging.getLogger(__name__).warning( - f"Error configuring EmptyResponseMiddleware: {e}", exc_info=True - ) - - # Edit-precision response-side detection (optional) - try: - from src.core.services.edit_precision_response_middleware import ( - EditPrecisionResponseMiddleware, - ) - - app_state = provider.get_required_service(ApplicationStateService) - middlewares.append(EditPrecisionResponseMiddleware(app_state)) - except Exception as e: - logging.getLogger(__name__).warning( - f"Error configuring EditPrecisionResponseMiddleware: {e}", - exc_info=True, - ) - - # Think tags fix middleware (optional) - try: - if getattr(cfg.session, "fix_think_tags_enabled", False): - from src.core.services.think_tags_fix_middleware import ( - ThinkTagsFixMiddleware, - ) - - # Configure streaming buffer size from config - buffer_size = getattr( - cfg.session, "fix_think_tags_streaming_buffer_size", 4096 - ) - middlewares.append( - ThinkTagsFixMiddleware( - enabled=True, streaming_buffer_size=buffer_size - ) - ) - except Exception as e: - logging.getLogger(__name__).warning( - f"Error configuring ThinkTagsFixMiddleware: {e}", - exc_info=True, - ) - - if getattr(cfg.session, "json_repair_enabled", False): - json_service: JsonRepairService = provider.get_required_service( - JsonRepairService - ) - middlewares.append(JsonRepairMiddleware(cfg, json_service)) - - if getattr(cfg.session, "tool_call_repair_enabled", True): - tcr_service: ToolCallRepairService = provider.get_required_service( - ToolCallRepairService - ) - middlewares.append(ToolCallRepairMiddleware(cfg, tcr_service)) - - try: - middlewares.append(ToolCallLoopDetectionMiddleware()) - except Exception as e: - logging.getLogger(__name__).warning( - f"Error configuring ToolCallLoopDetectionMiddleware: {e}", exc_info=True - ) - - # Add tool call reactor middleware - try: - tool_call_reactor_middleware = provider.get_required_service( - ToolCallReactorMiddleware - ) - middlewares.append(tool_call_reactor_middleware) - except Exception as e: - logging.getLogger(__name__).warning( - f"Error configuring ToolCallReactorMiddleware: {e}", exc_info=True - ) - - # Dangerous command prevention will be handled by Tool Call Reactor handler. - # Keeping old middleware disabled to avoid duplicate processing. 
- - return MiddlewareApplicationManager(middlewares) - - _add_singleton( - MiddlewareApplicationManager, - implementation_factory=_middleware_application_manager_factory, - ) - try: - services.add_singleton( - cast(type, IMiddlewareApplicationManager), - implementation_factory=_middleware_application_manager_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning( - f"Failed to register IMiddlewareApplicationManager interface: {e}" - ) - # Continue if concrete MiddlewareApplicationManager is registered - - # Register MiddlewareApplicationProcessor used inside the streaming pipeline - def _middleware_application_processor_factory( - provider: IServiceProvider, - ) -> MiddlewareApplicationProcessor: - manager: MiddlewareApplicationManager = provider.get_required_service( - MiddlewareApplicationManager - ) - app_state: IApplicationState = provider.get_required_service( - IApplicationState # type: ignore[type-abstract] - ) - - import os - - from src.core.domain.configuration.loop_detection_config import ( - LoopDetectionConfiguration, - ) - from src.tool_call_loop.config import ToolCallLoopConfig - - env_config = ToolCallLoopConfig.from_env_vars(dict(os.environ)) - loop_config = ( - LoopDetectionConfiguration() - .with_tool_loop_detection_enabled(env_config.enabled) - .with_tool_loop_max_repeats(env_config.max_repeats) - .with_tool_loop_ttl_seconds(env_config.ttl_seconds) - .with_tool_loop_mode(env_config.mode) - ) - - return MiddlewareApplicationProcessor( - manager._middleware, - default_loop_config=loop_config, - app_state=app_state, - ) - - _add_singleton( - MiddlewareApplicationProcessor, - implementation_factory=_middleware_application_processor_factory, - ) - - # Register response processor - def _response_processor_factory(provider: IServiceProvider) -> ResponseProcessor: - from typing import cast - - app_state: IApplicationState = provider.get_required_service( - cast(type, IApplicationState) - ) - stream_normalizer: IStreamNormalizer = provider.get_required_service( - cast(type, IStreamNormalizer) - ) - response_parser: IResponseParser = provider.get_required_service( - cast(type, IResponseParser) - ) - middleware_application_manager: IMiddlewareApplicationManager = ( - provider.get_required_service(cast(type, IMiddlewareApplicationManager)) - ) - - # Get the middleware manager to access the middleware list - middleware_manager: MiddlewareApplicationManager = ( - provider.get_required_service(MiddlewareApplicationManager) - ) - - # Get loop detector for non-streaming responses - from src.core.interfaces.loop_detector_interface import ILoopDetector - - loop_detector = provider.get_service(cast(type, ILoopDetector)) - - return ResponseProcessor( - response_parser=response_parser, - middleware_application_manager=middleware_application_manager, - app_state=app_state, - loop_detector=loop_detector, - stream_normalizer=stream_normalizer, - middleware_list=middleware_manager._middleware, - ) - - # Register response processor and bind to interface - _add_singleton( - ResponseProcessor, implementation_factory=_response_processor_factory - ) - - try: - services.add_singleton( - cast(type, IResponseProcessor), - implementation_factory=_response_processor_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IResponseProcessor interface: {e}") - # Continue if concrete ResponseProcessor is registered - - def _application_state_factory( - provider: IServiceProvider, - ) -> ApplicationStateService: - # Create 
application state service - return ApplicationStateService() - - # Register app settings - def _app_settings_factory(provider: IServiceProvider) -> AppSettings: - # Get app_state from IApplicationState if available - app_state: Any | None = None - try: - app_state_service: IApplicationState | None = provider.get_service( - ApplicationStateService - ) - if app_state_service: - app_state = app_state_service.get_setting("service_provider") - except Exception as e: - logger.debug(f"Could not get app_state from ApplicationStateService: {e}") - app_state = None - - # Create app settings - return AppSettings(app_state) - - # Register app settings and bind to interface - _add_singleton(AppSettings, implementation_factory=_app_settings_factory) - - try: - services.add_singleton( - cast(type, IAppSettings), implementation_factory=_app_settings_factory - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IAppSettings interface: {e}") - # Continue if concrete AppSettings is registered - - # Register application state service - _add_singleton(ApplicationStateService) - - try: - services.add_singleton( - cast(type, IApplicationState), - implementation_factory=_application_state_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IApplicationState interface: {e}") - # Continue if concrete ApplicationStateService is registered - - # Register secure state service - def _secure_state_factory(provider: IServiceProvider) -> SecureStateService: - app_state = provider.get_required_service(ApplicationStateService) - return SecureStateService(app_state) - - _add_singleton(SecureStateService, implementation_factory=_secure_state_factory) - - try: - services.add_singleton( - cast(type, ISecureStateAccess), implementation_factory=_secure_state_factory - ) # type: ignore[type-abstract] - services.add_singleton( - cast(type, ISecureStateModification), - implementation_factory=_secure_state_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register secure state interfaces: {e}") - # Continue if concrete SecureStateService is registered - - # Register secure command factory - def _secure_command_factory(provider: IServiceProvider) -> SecureCommandFactory: - secure_state = provider.get_required_service(SecureStateService) - return SecureCommandFactory( - state_reader=secure_state, state_modifier=secure_state - ) - - _add_singleton(SecureCommandFactory, implementation_factory=_secure_command_factory) - - # Register conversation fingerprint service - from src.core.services.conversation_fingerprint_service import ( - ConversationFingerprintService, - ) - - _add_singleton(ConversationFingerprintService) - - # Register session manager - def _session_manager_factory(provider: IServiceProvider) -> SessionManager: - session_service = provider.get_required_service(ISessionService) # type: ignore[type-abstract] - session_resolver = provider.get_required_service(ISessionResolver) # type: ignore[type-abstract] - # Get session repository for fingerprint tracking - session_repository = provider.get_service(cast(type, ISessionRepository)) # type: ignore[type-abstract] - fingerprint_service = provider.get_required_service( - ConversationFingerprintService - ) - return SessionManager( - session_service, - session_resolver, - session_repository=session_repository, - fingerprint_service=fingerprint_service, - ) - - _add_singleton(SessionManager, implementation_factory=_session_manager_factory) - - 
try: - services.add_singleton( - cast(type, ISessionManager), implementation_factory=_session_manager_factory - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register ISessionManager interface: {e}") - # Continue if concrete SessionManager is registered - - # Register agent response formatter - def _agent_response_formatter_factory( - provider: IServiceProvider, - ) -> AgentResponseFormatter: - session_service = provider.get_service(SessionService) - return AgentResponseFormatter(session_service=session_service) - - _add_singleton( - AgentResponseFormatter, implementation_factory=_agent_response_formatter_factory - ) - - try: - services.add_singleton( - cast(type, IAgentResponseFormatter), - implementation_factory=_agent_response_formatter_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IAgentResponseFormatter interface: {e}") - # Continue if concrete AgentResponseFormatter is registered - - # Register response manager - def _response_manager_factory(provider: IServiceProvider) -> ResponseManager: - agent_response_formatter = provider.get_required_service(IAgentResponseFormatter) # type: ignore[type-abstract] - session_service = provider.get_required_service(ISessionService) # type: ignore[type-abstract] - return ResponseManager(agent_response_formatter, session_service) - - _add_singleton(ResponseManager, implementation_factory=_response_manager_factory) - - try: - services.add_singleton( - cast(type, IResponseManager), - implementation_factory=_response_manager_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IResponseManager interface: {e}") - # Continue if concrete ResponseManager is registered - - # Register backend request manager - def _backend_request_manager_factory( - provider: IServiceProvider, - ) -> BackendRequestManager: - backend_processor = provider.get_required_service(IBackendProcessor) # type: ignore[type-abstract] - response_processor = provider.get_required_service(IResponseProcessor) # type: ignore[type-abstract] - wire_capture = provider.get_required_service(IWireCapture) # type: ignore[type-abstract] - return BackendRequestManager( - backend_processor, response_processor, wire_capture - ) - - _add_singleton( - BackendRequestManager, implementation_factory=_backend_request_manager_factory - ) - - try: - services.add_singleton( - cast(type, IBackendRequestManager), - implementation_factory=_backend_request_manager_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IBackendRequestManager interface: {e}") - # Continue if concrete BackendRequestManager is registered - - # Register stream normalizer - def _stream_normalizer_factory(provider: IServiceProvider) -> StreamNormalizer: - # Retrieve all stream processors in the correct order - try: - from src.core.config.app_config import AppConfig - - app_config: AppConfig = provider.get_required_service(AppConfig) - - # Optional JSON repair processor (enabled via config) - json_repair_processor = None - if getattr(app_config.session, "json_repair_enabled", False): - json_repair_processor = provider.get_required_service( - JsonRepairProcessor - ) - tool_call_repair_processor = None - if getattr(app_config.session, "tool_call_repair_enabled", True): - tool_call_repair_processor = provider.get_required_service( - ToolCallRepairProcessor - ) - loop_detection_processor = None - try: - loop_detection_processor = 
provider.get_required_service( - LoopDetectionProcessor - ) - logger.debug( - "LoopDetectionProcessor successfully registered for streaming" - ) - except Exception as e: - logger.warning( - f"Failed to register LoopDetectionProcessor for streaming: {e}" - ) - loop_detection_processor = None - middleware_application_processor = provider.get_required_service( - MiddlewareApplicationProcessor - ) - content_accumulation_processor = provider.get_required_service( - ContentAccumulationProcessor - ) - - processors: list[IStreamProcessor] = [] - # Prefer JSON repair first so JSON blocks are valid - if json_repair_processor is not None: - processors.append(json_repair_processor) - # Then text loop detection - if loop_detection_processor is not None: - processors.append(loop_detection_processor) - # Then tool-call repair - if tool_call_repair_processor is not None: - processors.append(tool_call_repair_processor) - # Middleware and accumulation - processors.append(middleware_application_processor) - processors.append(content_accumulation_processor) - except Exception as e: - logger.warning( - f"Error creating stream processors: {e}. Using default configuration." - ) - # Create minimal configuration with just content accumulation - # Use default 10MB buffer limit for fallback - content_accumulation_processor = ContentAccumulationProcessor( - max_buffer_bytes=10 * 1024 * 1024 - ) - processors = [content_accumulation_processor] - - return StreamNormalizer(processors) - - _add_singleton(StreamNormalizer, implementation_factory=_stream_normalizer_factory) - - try: - services.add_singleton( - cast(type, IStreamNormalizer), - implementation_factory=_stream_normalizer_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IStreamNormalizer interface: {e}") - # Continue if concrete StreamNormalizer is registered - - # Register ResponseParser - def _response_parser_factory(provider: IServiceProvider) -> ResponseParser: - - return ResponseParser() - - _add_singleton(ResponseParser, implementation_factory=_response_parser_factory) - try: - services.add_singleton( - cast(type, IResponseParser), implementation_factory=_response_parser_factory - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IResponseParser interface: {e}") - # Continue if concrete ResponseParser is registered - - # Register individual stream processors - def _loop_detection_processor_factory( - provider: IServiceProvider, - ) -> LoopDetectionProcessor: - from src.core.interfaces.loop_detector_interface import ILoopDetector - - # Create a factory function that creates new detector instances - # This ensures each session gets its own isolated detector - def create_detector() -> ILoopDetector: - return provider.get_required_service(cast(type, ILoopDetector)) - - return LoopDetectionProcessor(loop_detector_factory=create_detector) - - _add_singleton( - LoopDetectionProcessor, implementation_factory=_loop_detection_processor_factory - ) - - # Register ContentAccumulationProcessor with configured buffer limit - def _content_accumulation_processor_factory( - provider: IServiceProvider, - ) -> ContentAccumulationProcessor: - from src.core.config.app_config import AppConfig - - config: AppConfig = provider.get_required_service(AppConfig) - buffer_cap = getattr( - config.session, "content_accumulation_buffer_cap_bytes", 10 * 1024 * 1024 - ) - return ContentAccumulationProcessor(max_buffer_bytes=buffer_cap) - - _add_singleton( - ContentAccumulationProcessor, - 
implementation_factory=_content_accumulation_processor_factory, - ) - - # Register JSON repair service and processor - def _json_repair_service_factory(provider: IServiceProvider) -> JsonRepairService: - return JsonRepairService() - - _add_singleton( - JsonRepairService, implementation_factory=_json_repair_service_factory - ) - - # Register StructuredOutputMiddleware - def _structured_output_middleware_factory( - provider: IServiceProvider, - ) -> StructuredOutputMiddleware: - json_repair_service: JsonRepairService = provider.get_required_service( - JsonRepairService - ) - return StructuredOutputMiddleware(json_repair_service) - - _add_singleton( - StructuredOutputMiddleware, - implementation_factory=_structured_output_middleware_factory, - ) - - def _json_repair_processor_factory( - provider: IServiceProvider, - ) -> JsonRepairProcessor: - from src.core.config.app_config import AppConfig - - config: AppConfig = provider.get_required_service(AppConfig) - service: JsonRepairService = provider.get_required_service(JsonRepairService) - return JsonRepairProcessor( - repair_service=service, - buffer_cap_bytes=getattr( - config.session, "json_repair_buffer_cap_bytes", 64 * 1024 - ), - strict_mode=getattr(config.session, "json_repair_strict_mode", False), - schema=getattr(config.session, "json_repair_schema", None), - enabled=getattr(config.session, "json_repair_enabled", False), - ) - - _add_singleton( - JsonRepairProcessor, implementation_factory=_json_repair_processor_factory - ) - - # Wire capture service is registered in CoreServicesStage using BufferedWireCapture. - # Intentionally avoid legacy StructuredWireCapture registration here to keep - # the active format consistent across the app. - - # Register tool call repair service (if not already registered elsewhere as a concrete type) - def _tool_call_repair_service_factory( - provider: IServiceProvider, - ) -> ToolCallRepairService: - return ToolCallRepairService() - - _add_singleton( - ToolCallRepairService, implementation_factory=_tool_call_repair_service_factory - ) - - # Register TranslationService (dependency of BackendService) - def _translation_service_factory(provider: IServiceProvider) -> TranslationService: - return TranslationService() - - _add_singleton( - TranslationService, implementation_factory=_translation_service_factory - ) - - # Register ITranslationService interface to resolve to the same singleton instance - def _translation_service_interface_factory( - provider: IServiceProvider, - ) -> TranslationService: - return provider.get_required_service(TranslationService) - - _add_singleton( - cast(type, ITranslationService), - implementation_factory=_translation_service_interface_factory, - ) - - # Register assessment services if enabled - if app_config and app_config.assessment.enabled: - logger.info( - "LLM Assessment System ACTIVATED - Monitoring conversations for unproductive patterns" - ) - - # Initialize assessment prompts first - from src.core.services.assessment_prompts import initialize_prompts - - try: - initialize_prompts() - logger.info("Assessment prompts loaded successfully") - except Exception as e: - logger.error(f"Failed to load assessment prompts: {e}") - raise - - # Import assessment services only when needed to avoid circular imports - from src.core.interfaces.assessment_service_interface import ( - IAssessmentBackendService, - IAssessmentRepository, - IAssessmentService, - ITurnCounterService, - ) - from src.core.repositories.assessment_repository import ( - InMemoryAssessmentRepository, - ) - from 
src.core.services.assessment_backend_service import ( - AssessmentBackendService, - ) - from src.core.services.assessment_service import AssessmentService - from src.core.services.turn_counter_service import TurnCounterService - - # Assessment repository - def _assessment_repository_factory( - provider: IServiceProvider, - ) -> InMemoryAssessmentRepository: - return InMemoryAssessmentRepository() - - _add_singleton( - IAssessmentRepository, implementation_factory=_assessment_repository_factory # type: ignore[type-abstract] - ) - - # Turn counter service - def _turn_counter_service_factory( - provider: IServiceProvider, - ) -> TurnCounterService: - repository = provider.get_required_service(IAssessmentRepository) # type: ignore[type-abstract] - config = provider.get_required_service(AppConfig).assessment - return TurnCounterService(repository, config) - - _add_singleton( - ITurnCounterService, implementation_factory=_turn_counter_service_factory # type: ignore[type-abstract] - ) - - # Assessment backend service - - def _assessment_backend_service_factory( - provider: IServiceProvider, - ) -> AssessmentBackendService: - backend_service = provider.get_required_service(IBackendService) # type: ignore[type-abstract] - config = provider.get_required_service(AppConfig).assessment - return AssessmentBackendService(backend_service, config) - - _add_singleton( - IAssessmentBackendService, - implementation_factory=_assessment_backend_service_factory, # type: ignore[type-abstract] - ) - - # Core assessment service - - def _assessment_service_factory( - provider: IServiceProvider, - ) -> AssessmentService: - backend_service = provider.get_required_service(IAssessmentBackendService) # type: ignore[type-abstract] - config = provider.get_required_service(AppConfig).assessment - return AssessmentService(backend_service, config) - - _add_singleton( - IAssessmentService, implementation_factory=_assessment_service_factory # type: ignore[type-abstract] - ) - - logger.info( - f"Assessment services registered: backend={app_config.assessment.backend}, " - f"model={app_config.assessment.model}, threshold={app_config.assessment.turn_threshold}" - ) - - try: - services.add_singleton( - cast(type, IToolCallRepairService), - implementation_factory=_tool_call_repair_service_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IToolCallRepairService interface: {e}") - # Continue if concrete ToolCallRepairService is registered - - # Register tool call repair processor - def _tool_call_repair_processor_factory( - provider: IServiceProvider, - ) -> ToolCallRepairProcessor: - tool_call_repair_service = provider.get_required_service(IToolCallRepairService) # type: ignore[type-abstract] - return ToolCallRepairProcessor(tool_call_repair_service) - - _add_singleton( - ToolCallRepairProcessor, - implementation_factory=_tool_call_repair_processor_factory, - ) - - # Register dangerous command service - def _dangerous_command_service_factory( - provider: IServiceProvider, - ) -> DangerousCommandService: - from src.core.config.app_config import AppConfig - from src.core.domain.configuration.dangerous_command_config import ( - DEFAULT_DANGEROUS_COMMAND_CONFIG, - ) - from src.core.services.dangerous_command_service import ( - DangerousCommandService, - ) - - provider.get_required_service(AppConfig) - return DangerousCommandService(DEFAULT_DANGEROUS_COMMAND_CONFIG) - - _add_singleton( - DangerousCommandService, - implementation_factory=_dangerous_command_service_factory, - ) - - # 
Register pytest compression service - def _pytest_compression_service_factory( - provider: IServiceProvider, - ) -> PytestCompressionService: - from src.core.services.pytest_compression_service import ( - PytestCompressionService, - ) - - provider.get_required_service(AppConfig) - return PytestCompressionService() - - _add_singleton( - PytestCompressionService, - implementation_factory=_pytest_compression_service_factory, - ) - - # Register tool access policy service - from src.core.services.tool_access_policy_service import ToolAccessPolicyService - - def _tool_access_policy_service_factory( - provider: IServiceProvider, - ) -> ToolAccessPolicyService: - from src.core.config.app_config import AppConfig - - app_config: AppConfig = provider.get_required_service(AppConfig) - reactor_config = app_config.session.tool_call_reactor - - # Get global overrides from session config (set by CLI parameters) - global_overrides = getattr( - app_config.session, "tool_access_global_overrides", None - ) - return ToolAccessPolicyService( - reactor_config, global_overrides=global_overrides - ) - - _add_singleton( - ToolAccessPolicyService, - implementation_factory=_tool_access_policy_service_factory, - ) - - # Register tool call reactor services - def _tool_call_history_tracker_factory( - provider: IServiceProvider, - ) -> InMemoryToolCallHistoryTracker: - return InMemoryToolCallHistoryTracker() - - _add_singleton( - InMemoryToolCallHistoryTracker, - implementation_factory=_tool_call_history_tracker_factory, - ) - - def _tool_call_reactor_factory( - provider: IServiceProvider, - ) -> ToolCallReactorService: - from src.core.config.app_config import AppConfig - - history_tracker = provider.get_required_service(InMemoryToolCallHistoryTracker) - reactor = ToolCallReactorService(history_tracker) - - # Get configuration - app_config: AppConfig = provider.get_required_service(AppConfig) - reactor_config = app_config.session.tool_call_reactor - - # Register default handlers if enabled - if reactor_config.enabled: - from src.core.services.tool_call_handlers.config_steering_handler import ( - ConfigSteeringHandler, - ) - from src.core.services.tool_call_handlers.dangerous_command_handler import ( - DangerousCommandHandler, - ) - from src.core.services.tool_call_handlers.pytest_full_suite_handler import ( - PytestFullSuiteHandler, - ) - - # Register config-driven steering handler (includes synthesized legacy apply_diff rule when enabled) - try: - # Build effective rules from config while avoiding expensive deep copy. - # Since steering_rules are configuration data (immutable during runtime), - # we can safely use a shallow copy for better performance. - effective_rules = ( - (reactor_config.steering_rules or []).copy() - if reactor_config.steering_rules - else [] - ) - - # Synthesize legacy apply_diff rule if enabled and missing - if getattr(reactor_config, "apply_diff_steering_enabled", True): - has_apply_rule = False - for r in effective_rules: - triggers = (r or {}).get("triggers") or {} - tnames = triggers.get("tool_names") or [] - phrases = triggers.get("phrases") or [] - if "apply_diff" in tnames or any( - isinstance(p, str) and "apply_diff" in p for p in phrases - ): - has_apply_rule = True - break - if not has_apply_rule: - effective_rules.append( - { - "name": "apply_diff_to_patch_file", - "enabled": True, - "priority": 100, - "triggers": { - "tool_names": ["apply_diff"], - "phrases": [], - }, - "message": ( - reactor_config.apply_diff_steering_message - or ( - "You tried to use apply_diff tool. 
Please prefer to use patch_file tool instead, " - "as it is superior to apply_diff and provides automated Python QA checks." - ) - ), - "rate_limit": { - "calls_per_window": 1, - "window_seconds": reactor_config.apply_diff_steering_rate_limit_seconds, - }, - } - ) - - if effective_rules: - config_handler = ConfigSteeringHandler(rules=effective_rules) - try: - reactor.register_handler_sync(config_handler) - except Exception as e: - logger.warning( - f"Failed to register config steering handler: {e}", - exc_info=True, - ) - except Exception as e: - logger.warning( - "Failed to register steering handlers: %s", e, exc_info=True - ) - - # Register DangerousCommandHandler if enabled in session config - try: - if getattr( - app_config.session, "dangerous_command_prevention_enabled", True - ): - dangerous_service = provider.get_required_service( - DangerousCommandService - ) - dangerous_handler = DangerousCommandHandler( - dangerous_service, - steering_message=getattr( - app_config.session, - "dangerous_command_steering_message", - None, - ), - enabled=True, - ) - try: - reactor.register_handler_sync(dangerous_handler) - except Exception as e: - logger.warning( - f"Failed to register dangerous command handler: {e}", - exc_info=True, - ) - except Exception as e: - logger.warning( - f"Failed to register DangerousCommandHandler: {e}", exc_info=True - ) - - # Register PytestFullSuiteHandler if enabled - try: - if getattr(reactor_config, "pytest_full_suite_steering_enabled", False): - steering_message = getattr( - reactor_config, "pytest_full_suite_steering_message", None - ) - pytest_full_suite_handler = PytestFullSuiteHandler( - message=steering_message, - enabled=True, - ) - try: - reactor.register_handler_sync(pytest_full_suite_handler) - except Exception as e: - logger.warning( - f"Failed to register pytest full-suite handler: {e}", - exc_info=True, - ) - except Exception as e: - logger.warning( - f"Failed to register PytestFullSuiteHandler: {e}", exc_info=True - ) - - # Register PytestContextSavingHandler if enabled - try: - if getattr(reactor_config, "pytest_context_saving_enabled", False): - from src.core.services.tool_call_handlers.pytest_context_saving_handler import ( - PytestContextSavingHandler, - ) - - context_saving_handler = PytestContextSavingHandler(enabled=True) - try: - reactor.register_handler_sync(context_saving_handler) - except Exception as e: - logger.warning( - f"Failed to register pytest context saving handler: {e}", - exc_info=True, - ) - except Exception as e: - logger.warning( - f"Failed to register PytestContextSavingHandler: {e}", exc_info=True - ) - - # Register PytestCompressionHandler if enabled in session config - try: - if getattr(app_config.session, "pytest_compression_enabled", True): - from src.core.services.tool_call_handlers.pytest_compression_handler import ( - PytestCompressionHandler, - ) - - pytest_compression_service = provider.get_required_service( - PytestCompressionService - ) - session_service = provider.get_required_service(SessionService) - pytest_handler = PytestCompressionHandler( - pytest_compression_service, - session_service, - enabled=True, - ) - try: - reactor.register_handler_sync(pytest_handler) - except Exception as e: - logger.warning( - f"Failed to register pytest compression handler: {e}", - exc_info=True, - ) - except Exception as e: - logger.warning( - f"Failed to register PytestCompressionHandler: {e}", exc_info=True - ) - - # Register ToolAccessControlHandler if access policies are configured - try: - from 
src.core.services.tool_access_policy_service import ( - ToolAccessPolicyService, - ) - from src.core.services.tool_call_handlers.tool_access_control_handler import ( - ToolAccessControlHandler, - ) - - # Get the policy service - policy_service = provider.get_required_service(ToolAccessPolicyService) - - # Only register if there are policies configured - if policy_service._policies: - tool_access_handler = ToolAccessControlHandler( - policy_service=policy_service, - priority=90, # After dangerous-command handler (100) - reactor_service=reactor, # Pass reactor for telemetry - ) - try: - reactor.register_handler_sync(tool_access_handler) - logger.info( - f"Registered ToolAccessControlHandler with priority 90 " - f"({len(policy_service._policies)} policies loaded)" - ) - except Exception as e: - logger.warning( - f"Failed to register tool access control handler: {e}", - exc_info=True, - ) - except Exception as e: - logger.warning( - f"Failed to register ToolAccessControlHandler: {e}", exc_info=True - ) - - return reactor - - _add_singleton( - ToolCallReactorService, - implementation_factory=_tool_call_reactor_factory, - ) - - def _tool_call_reactor_middleware_factory( - provider: IServiceProvider, - ) -> ToolCallReactorMiddleware: - from src.core.config.app_config import AppConfig - - reactor = provider.get_required_service(ToolCallReactorService) - - # Get configuration to determine if middleware should be enabled - app_config: AppConfig = provider.get_required_service(AppConfig) - enabled = app_config.session.tool_call_reactor.enabled - - return ToolCallReactorMiddleware(reactor, enabled=enabled, priority=-10) - - _add_singleton( - ToolCallReactorMiddleware, - implementation_factory=_tool_call_reactor_middleware_factory, - ) - - # Register PathValidationService - def _path_validation_service_factory( - provider: IServiceProvider, - ) -> PathValidationService: - return PathValidationService() - - _add_singleton( - PathValidationService, implementation_factory=_path_validation_service_factory - ) - _add_singleton( - IPathValidator, # type: ignore[type-abstract] - implementation_factory=lambda p: p.get_required_service(PathValidationService), - ) - - # Register FileSandboxingHandler - def _file_sandboxing_handler_factory( - provider: IServiceProvider, - ) -> FileSandboxingHandler: - config = provider.get_required_service(AppConfig) - path_validator = provider.get_required_service(IPathValidator) # type: ignore[type-abstract] - session_service = provider.get_required_service(ISessionService) # type: ignore[type-abstract] - - return FileSandboxingHandler( - config=config.sandboxing, - path_validator=path_validator, - session_service=session_service, - ) - - _add_singleton( - FileSandboxingHandler, implementation_factory=_file_sandboxing_handler_factory - ) - - # Register backend service - def _backend_service_factory(provider: IServiceProvider) -> BackendService: - # Import required modules - import httpx - - from src.core.services.backend_factory import BackendFactory - from src.core.services.rate_limiter import RateLimiter - - # Get or create dependencies - httpx_client: httpx.AsyncClient | None = provider.get_service(httpx.AsyncClient) - if httpx_client is None: - try: - httpx_client = httpx.AsyncClient( - http2=True, - timeout=httpx.Timeout( - connect=10.0, read=60.0, write=60.0, pool=60.0 - ), - limits=httpx.Limits( - max_connections=100, max_keepalive_connections=20 - ), - trust_env=False, - ) - except ImportError: - httpx_client = httpx.AsyncClient( - http2=False, - timeout=httpx.Timeout( 
- connect=10.0, read=60.0, write=60.0, pool=60.0 - ), - limits=httpx.Limits( - max_connections=100, max_keepalive_connections=20 - ), - trust_env=False, - ) - - # Get app config - app_config: AppConfig = provider.get_required_service(AppConfig) - - backend_factory: BackendFactory = provider.get_required_service(BackendFactory) - - # Resolve the rate limiter from the DI container when available - rate_limiter: IRateLimiter | None = provider.get_service(RateLimiter) - if rate_limiter is None: - rate_limiter = provider.get_service(cast(type, IRateLimiter)) - if rate_limiter is None: - logging.getLogger(__name__).warning( - "RateLimiter service not registered; creating transient instance" - ) - rate_limiter = RateLimiter() - - # Get application state service - app_state: IApplicationState = provider.get_required_service(IApplicationState) # type: ignore[type-abstract] - - # Get failover coordinator (optional for test environments) - failover_coordinator = None - try: - failover_coordinator = provider.get_service(IFailoverCoordinator) # type: ignore[type-abstract] - except Exception as e: - logger.debug(f"FailoverCoordinator not available: {e}") - - # Get backend config provider or create one - backend_config_provider = None - try: - backend_config_provider = provider.get_service(IBackendConfigProvider) # type: ignore[type-abstract] - except Exception as e: - logger.debug( - f"BackendConfigProvider not available, will create default: {e}" - ) - - # If not available, create one with the app config - if backend_config_provider is None: - from src.core.services.backend_config_provider import BackendConfigProvider - - backend_config_provider = BackendConfigProvider(app_config) - - # Optionally build a failover strategy based on feature flag - failover_strategy = None - try: - if ( - app_state.get_use_failover_strategy() - and failover_coordinator is not None - ): - from src.core.services.failover_strategy import DefaultFailoverStrategy - - failover_strategy = DefaultFailoverStrategy(failover_coordinator) - except (AttributeError, ImportError, TypeError) as e: - logging.getLogger(__name__).debug( - "Failed to enable failover strategy: %s", e, exc_info=True - ) - - # Return backend service - return BackendService( - backend_factory, - rate_limiter, - app_config, - session_service=provider.get_required_service(SessionService), - app_state=app_state, - backend_config_provider=backend_config_provider, - failover_coordinator=failover_coordinator, - failover_strategy=failover_strategy, - wire_capture=provider.get_required_service(IWireCapture), # type: ignore[type-abstract] - ) - - # Register backend service and bind to interface - _add_singleton(BackendService, implementation_factory=_backend_service_factory) - - try: - services.add_singleton( - cast(type, IBackendService), implementation_factory=_backend_service_factory - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IBackendService interface: {e}") - # Continue if concrete BackendService is registered - - # Register FailoverService first (dependency of FailoverCoordinator) - def _failover_service_factory(provider: IServiceProvider) -> FailoverService: - # FailoverService constructor takes failover_routes dict, defaulting to empty - return FailoverService(failover_routes={}) - - _add_singleton(FailoverService, implementation_factory=_failover_service_factory) - - # Register failover coordinator (if not already registered elsewhere as a concrete type) - def _failover_coordinator_factory( - provider: 
IServiceProvider, - ) -> FailoverCoordinator: - from src.core.services.failover_coordinator import FailoverCoordinator - from src.core.services.failover_service import FailoverService - - failover_service = provider.get_required_service(FailoverService) - return FailoverCoordinator(failover_service) - - from src.core.services.failover_coordinator import FailoverCoordinator - - _add_singleton( - FailoverCoordinator, implementation_factory=_failover_coordinator_factory - ) - - try: - from src.core.interfaces.failover_interface import IFailoverCoordinator - - services.add_singleton( - cast(type, IFailoverCoordinator), - implementation_factory=_failover_coordinator_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IFailoverCoordinator interface: {e}") - # Continue if concrete FailoverCoordinator is registered - - # Register request processor - def _request_processor_factory(provider: IServiceProvider) -> RequestProcessor: - # Get required services - command_processor = provider.get_required_service(ICommandProcessor) # type: ignore[type-abstract] - session_manager = provider.get_required_service(ISessionManager) # type: ignore[type-abstract] - backend_request_manager = provider.get_required_service(IBackendRequestManager) # type: ignore[type-abstract] - response_manager = provider.get_required_service(IResponseManager) # type: ignore[type-abstract] - app_state = provider.get_service(IApplicationState) # type: ignore[type-abstract] - - # Return request processor with decomposed services - return RequestProcessor( - command_processor, - session_manager, - backend_request_manager, - response_manager, - app_state=app_state, - ) - - # Register request processor and bind to interface - _add_singleton(RequestProcessor, implementation_factory=_request_processor_factory) - - try: - _add_singleton( - cast(type, IRequestProcessor), - implementation_factory=_request_processor_factory, - ) # type: ignore[type-abstract] - except Exception as e: - logger.warning(f"Failed to register IRequestProcessor interface: {e}") - # Continue if concrete RequestProcessor is registered - - -def get_service(service_type: type[T]) -> T | None: - """Get a service from the global service provider. - - Args: - service_type: The type of service to get - - Returns: - The service instance, or None if the service is not registered - """ - provider = get_or_build_service_provider() - return provider.get_service(service_type) # type: ignore - - -def get_required_service(service_type: type[T]) -> T: - """Get a required service from the global service provider. - - Args: - service_type: The type of service to get - - Returns: - The service instance - - Raises: - Exception: If the service is not registered - """ - provider = get_or_build_service_provider() - return provider.get_required_service(service_type) # type: ignore +""" +Services and DI container configuration. + +This module provides functions for configuring the DI container with services +and resolving services from the container. 
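+
+Typical wiring (an illustrative sketch, not an exhaustive setup):
+
+    services = ServiceCollection()
+    register_core_services(services, AppConfig())
+    provider = services.build_service_provider()
+    backend = provider.get_required_service(BackendService)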
+""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Callable +from typing import Any, TypeVar, cast + +from src.core.common.exceptions import ServiceResolutionError +from src.core.config.app_config import AppConfig +from src.core.di.container import ServiceCollection +from src.core.domain.streaming_response_processor import ( + IStreamProcessor, + LoopDetectionProcessor, +) +from src.core.interfaces.agent_response_formatter_interface import ( + IAgentResponseFormatter, +) +from src.core.interfaces.app_settings_interface import IAppSettings +from src.core.interfaces.application_state_interface import IApplicationState +from src.core.interfaces.backend_config_provider_interface import ( + IBackendConfigProvider, +) +from src.core.interfaces.backend_processor_interface import IBackendProcessor +from src.core.interfaces.backend_request_manager_interface import ( + IBackendRequestManager, +) +from src.core.interfaces.backend_service_interface import IBackendService +from src.core.interfaces.command_processor_interface import ICommandProcessor +from src.core.interfaces.command_service_interface import ICommandService +from src.core.interfaces.configuration_interface import IConfig +from src.core.interfaces.di_interface import IServiceProvider +from src.core.interfaces.middleware_application_manager_interface import ( + IMiddlewareApplicationManager, +) +from src.core.interfaces.path_validator_interface import IPathValidator +from src.core.interfaces.rate_limiter_interface import IRateLimiter +from src.core.interfaces.repositories_interface import ISessionRepository +from src.core.interfaces.request_processor_interface import IRequestProcessor +from src.core.interfaces.response_handler_interface import ( + INonStreamingResponseHandler, + IStreamingResponseHandler, +) +from src.core.interfaces.response_manager_interface import IResponseManager +from src.core.interfaces.response_parser_interface import IResponseParser +from src.core.interfaces.response_processor_interface import ( + IResponseMiddleware, + IResponseProcessor, +) +from src.core.interfaces.session_manager_interface import ISessionManager +from src.core.interfaces.session_resolver_interface import ISessionResolver +from src.core.interfaces.session_service_interface import ISessionService +from src.core.interfaces.state_provider_interface import ( + ISecureStateAccess, + ISecureStateModification, +) +from src.core.interfaces.streaming_response_processor_interface import IStreamNormalizer +from src.core.interfaces.tool_call_repair_service_interface import ( + IToolCallRepairService, +) +from src.core.interfaces.translation_service_interface import ITranslationService +from src.core.interfaces.wire_capture_interface import IWireCapture +from src.core.services.app_settings_service import AppSettings +from src.core.services.application_state_service import ApplicationStateService +from src.core.services.backend_processor import BackendProcessor +from src.core.services.backend_request_manager_service import BackendRequestManager +from src.core.services.backend_service import BackendService +from src.core.services.command_processor import CommandProcessor +from src.core.services.dangerous_command_service import DangerousCommandService +from src.core.services.failover_service import FailoverService +from src.core.services.file_sandboxing_handler import FileSandboxingHandler +from src.core.services.json_repair_service import JsonRepairService +from src.core.services.middleware_application_manager 
import ( + MiddlewareApplicationManager, +) +from src.core.services.path_validation_service import PathValidationService +from src.core.services.pytest_compression_service import PytestCompressionService +from src.core.services.request_processor_service import RequestProcessor +from src.core.services.response_handlers import ( + DefaultNonStreamingResponseHandler, + DefaultStreamingResponseHandler, +) +from src.core.services.response_manager_service import ( + AgentResponseFormatter, + ResponseManager, +) +from src.core.services.response_parser_service import ResponseParser +from src.core.services.response_processor_service import ResponseProcessor +from src.core.services.secure_command_factory import SecureCommandFactory +from src.core.services.secure_state_service import SecureStateService +from src.core.services.session_manager_service import SessionManager +from src.core.services.session_resolver_service import DefaultSessionResolver +from src.core.services.session_service_impl import SessionService +from src.core.services.streaming.content_accumulation_processor import ( + ContentAccumulationProcessor, +) +from src.core.services.streaming.json_repair_processor import JsonRepairProcessor +from src.core.services.streaming.middleware_application_processor import ( + MiddlewareApplicationProcessor, +) +from src.core.services.streaming.stream_normalizer import StreamNormalizer +from src.core.services.streaming.tool_call_repair_processor import ( + ToolCallRepairProcessor, +) +from src.core.services.structured_output_middleware import StructuredOutputMiddleware +from src.core.services.tool_call_reactor_middleware import ToolCallReactorMiddleware +from src.core.services.tool_call_reactor_service import ( + InMemoryToolCallHistoryTracker, + ToolCallReactorService, +) +from src.core.services.tool_call_repair_service import ToolCallRepairService +from src.core.services.translation_service import TranslationService + +T = TypeVar("T") + +# Global service collection +_service_collection: ServiceCollection | None = None +_service_provider: IServiceProvider | None = None + + +def _get_di_diagnostics() -> bool: + """Get DI diagnostics setting from environment.""" + return os.getenv("DI_STRICT_DIAGNOSTICS", "false").lower() in ( + "true", + "1", + "yes", + ) + + +def get_service_collection() -> ServiceCollection: + """Get the global service collection. + + Returns: + The global service collection + """ + global _service_collection + if _service_collection is None: + _service_collection = ServiceCollection() + # Ensure core services are registered into the global collection early. + # This makes DI shape consistent across processes/tests and avoids many + # order-dependent failures. register_core_services is idempotent. + try: + register_core_services(_service_collection, None) + except Exception as exc: + logging.getLogger(__name__).exception( + "Failed to register core services into global service collection" + ) + _service_collection = None + raise ServiceResolutionError( + "Failed to register core services", + details={ + "error_type": type(exc).__name__, + "error_message": str(exc), + }, + ) from exc + return _service_collection + + +def get_or_build_service_provider() -> IServiceProvider: + """Get the global service provider or build one if it doesn't exist. 
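+
+    The provider is built from the global service collection on first use and
+    cached in a module-level global, so subsequent calls return the same instance.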
+ + Returns: + The global service provider + """ + global _service_provider + if _service_provider is None: + if _get_di_diagnostics(): + logging.getLogger("llm.di").info( + "Building service provider; descriptors=%d", + len(get_service_collection()._descriptors), + ) + _service_provider = get_service_collection().build_service_provider() + return _service_provider + + +def set_service_provider(provider: IServiceProvider) -> None: + """Set the global service provider (used for tests/late init). + + Args: + provider: The ServiceProvider instance to set as the global provider + """ + global _service_provider + _service_provider = provider + + +def get_service_provider() -> IServiceProvider: + """Return the global service provider, building it if necessary. + + This is a compatibility wrapper used by callers that expect a + `get_service_provider` symbol. + """ + provider = get_or_build_service_provider() + return _ensure_tool_call_reactor_services(provider) + + +def _ensure_tool_call_reactor_services( + provider: IServiceProvider, +) -> IServiceProvider: + """Ensure the provider can resolve ToolCallReactor components. + + Args: + provider: The current service provider instance. + + Returns: + A provider that can resolve the ToolCallReactor service and middleware. + + Raises: + ServiceResolutionError: If re-registration fails to provide the required services. + """ + + from src.core.services.tool_call_reactor_middleware import ToolCallReactorMiddleware + from src.core.services.tool_call_reactor_service import ToolCallReactorService + + missing_components: list[str] = [] + + if provider.get_service(ToolCallReactorService) is None: + missing_components.append("ToolCallReactorService") + if provider.get_service(ToolCallReactorMiddleware) is None: + missing_components.append("ToolCallReactorMiddleware") + + if not missing_components: + return provider + + logger = logging.getLogger(__name__) + logger.warning( + "DI provider missing tool call reactor components: %s. Re-registering core services.", + ", ".join(missing_components), + ) + + services = get_service_collection() + descriptors = getattr(services, "_descriptors", {}) + + preserved_descriptors: dict[type, Any] = {} + for key in (AppConfig, cast(type, IConfig)): + descriptor = descriptors.get(key) + if descriptor is not None: + preserved_descriptors[key] = descriptor + + register_core_services(services) + + descriptors.update(preserved_descriptors) + + new_provider = services.build_service_provider() + set_service_provider(new_provider) + + still_missing: list[str] = [] + if new_provider.get_service(ToolCallReactorService) is None: + still_missing.append("ToolCallReactorService") + if new_provider.get_service(ToolCallReactorMiddleware) is None: + still_missing.append("ToolCallReactorMiddleware") + + if still_missing: + raise ServiceResolutionError( + "Failed to register required Tool Call Reactor services.", + details={"missing_components": still_missing}, + ) + + return new_provider + + +def register_core_services( + services: ServiceCollection, app_config: AppConfig | None = None +) -> None: + """Register core services with the service collection. 
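+
+    Registration is idempotent: service types that already have a descriptor in
+    the collection are skipped, so this function may safely run more than once.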
+ + Args: + services: The service collection to register services with + app_config: Optional application configuration + """ + logger: logging.Logger = logging.getLogger(__name__) + # Register AppConfig and IConfig + if app_config is not None: + services.add_instance(AppConfig, app_config) + # Also register it as IConfig for interface resolution + try: + services.add_instance( + cast(type, IConfig), + app_config, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register IConfig interface: {e}") + # Continue without interface registration if it fails + else: + # Register default AppConfig as IConfig for testing and basic functionality + default_config = AppConfig() + services.add_instance(AppConfig, default_config) + try: + services.add_instance( + cast(type, IConfig), + default_config, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register default IConfig interface: {e}") + # Continue without interface registration if it fails + + # Helper wrappers to make registration idempotent and provide debug logging + + def _registered(service_type: type) -> bool: + desc = getattr(services, "_descriptors", None) + return desc is not None and service_type in desc + + def _add_singleton( + service_type: type, + implementation_type: type | None = None, + implementation_factory: Callable[[IServiceProvider], Any] | None = None, + ) -> None: + if _registered(service_type): + logger.debug( + "Skipping registration of %s; already present", + getattr(service_type, "__name__", str(service_type)), + ) + return + services.add_singleton( + service_type, implementation_type, implementation_factory + ) + + def _add_instance(service_type: type, instance: Any) -> None: + if _registered(service_type): + logger.debug( + "Skipping instance registration of %s; already present", + getattr(service_type, "__name__", str(service_type)), + ) + return + services.add_instance(service_type, instance) + + # Register session resolver + _add_singleton(DefaultSessionResolver) + # Register both the concrete type and the interface + _add_singleton(ISessionResolver, DefaultSessionResolver) # type: ignore[type-abstract] + + # Register CommandService with factory + def _command_service_factory(provider: IServiceProvider) -> ICommandService: + from src.core.commands.parser import CommandParser + from src.core.commands.service import NewCommandService + from src.core.services.command_policy_service import CommandPolicyService + from src.core.services.command_state_service import CommandStateService + from src.core.services.session_service_impl import SessionService + + session_service = provider.get_required_service(SessionService) + command_parser = provider.get_required_service(CommandParser) + config = provider.get_required_service(AppConfig) + app_state = provider.get_service(cast(type, IApplicationState)) + state_service = provider.get_required_service(CommandStateService) + policy_service = provider.get_required_service(CommandPolicyService) + return NewCommandService( + session_service, + command_parser, + strict_command_detection=config.strict_command_detection, + app_state=app_state, + command_state_service=state_service, + command_policy_service=policy_service, + config=config, + ) + + # Register CommandService and bind to interface + _add_singleton(ICommandService, implementation_factory=_command_service_factory) # type: ignore[type-abstract] + + # Register CommandParser + from src.core.commands.parser import CommandParser + from 
src.core.interfaces.command_parser_interface import ICommandParser + + _add_singleton(ICommandParser, CommandParser) # type: ignore[type-abstract] + _add_singleton(CommandParser, CommandParser) # Also register concrete type + + # Ensure command handlers are imported so their @command decorators register them + try: + import importlib + import pkgutil + + package_name = "src.core.commands.handlers" + package = importlib.import_module(package_name) + for m in pkgutil.iter_modules(package.__path__): # type: ignore[attr-defined] + importlib.import_module(f"{package_name}.{m.name}") + except Exception: + logging.getLogger(__name__).warning( + "Failed to import command handlers for registration", exc_info=True + ) + + # Register session service factory + def _session_service_factory(provider: IServiceProvider) -> SessionService: + # Import here to avoid circular imports + from src.core.repositories.in_memory_session_repository import ( + InMemorySessionRepository, + ) + + # Create repository + repository: InMemorySessionRepository = InMemorySessionRepository() + + # Return session service + return SessionService(repository) + + # Register session service and bind to interface + _add_singleton(SessionService, implementation_factory=_session_service_factory) + + try: + services.add_singleton( + cast(type, ISessionService), implementation_factory=_session_service_factory + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register ISessionService interface: {e}") + # Continue if concrete SessionService is registered + + # Register command state service + from src.core.interfaces.command_state_service_interface import ( + ICommandStateService, + ) + from src.core.services.command_state_service import CommandStateService + + def _command_state_service_factory( + provider: IServiceProvider, + ) -> CommandStateService: + session = provider.get_required_service(SessionService) + return CommandStateService(session) + + _add_singleton( + CommandStateService, implementation_factory=_command_state_service_factory + ) + + try: + services.add_singleton( + cast(type, ICommandStateService), + implementation_factory=lambda provider: provider.get_required_service( + CommandStateService + ), + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register ICommandStateService interface: {e}") + # Continue if concrete CommandStateService is registered + + # Register command policy service + from src.core.interfaces.command_policy_service_interface import ( + ICommandPolicyService, + ) + from src.core.services.command_policy_service import CommandPolicyService + + def _command_policy_service_factory( + provider: IServiceProvider, + ) -> CommandPolicyService: + cfg = provider.get_required_service(AppConfig) + app_state = provider.get_service(cast(type, IApplicationState)) + return CommandPolicyService(cfg, app_state) + + _add_singleton( + CommandPolicyService, implementation_factory=_command_policy_service_factory + ) + + try: + services.add_singleton( + cast(type, ICommandPolicyService), + implementation_factory=lambda provider: provider.get_required_service( + CommandPolicyService + ), + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register ICommandPolicyService interface: {e}") + # Continue if concrete CommandPolicyService is registered + + # Register command processor + def _command_processor_factory(provider: IServiceProvider) -> ICommandProcessor: + # Get command service + from typing import cast + + 
from src.core.commands.tool_call_command_processor import ( + ToolCallCommandProcessor, + ) + from src.core.services.delegating_command_processor import ( + DelegatingCommandProcessor, + ) + + command_service: ICommandService = provider.get_required_service( + cast(type, ICommandService) + ) + + # Create the processors + text_command_processor = CommandProcessor(command_service) + tool_call_command_processor = ToolCallCommandProcessor(command_service) + + # Return the delegating processor + return DelegatingCommandProcessor( + tool_call_command_processor=tool_call_command_processor, + text_command_processor=text_command_processor, + ) + + # Register command processor and bind to interface + try: + services.add_singleton( + cast(type, ICommandProcessor), + implementation_factory=_command_processor_factory, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register ICommandProcessor interface: {e}") + # Continue without interface registration if it fails + + # Register backend processor + def _backend_processor_factory(provider: IServiceProvider) -> BackendProcessor: + # Get backend service and session service + from typing import cast + + backend_service: IBackendService = provider.get_required_service( + cast(type, IBackendService) + ) + session_service: ISessionService = provider.get_required_service( + cast(type, ISessionService) + ) + app_state: IApplicationState = provider.get_required_service( + cast(type, IApplicationState) + ) + + # Return backend processor + return BackendProcessor(backend_service, session_service, app_state) + + # Register backend processor and bind to interface + _add_singleton(BackendProcessor, implementation_factory=_backend_processor_factory) + + try: + services.add_singleton( + cast(type, IBackendProcessor), + implementation_factory=_backend_processor_factory, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register IBackendProcessor interface: {e}") + # Continue if concrete BackendProcessor is registered + + # Register response handlers + _add_singleton(DefaultNonStreamingResponseHandler) + _add_singleton(DefaultStreamingResponseHandler) + + try: + services.add_singleton( + cast(type, INonStreamingResponseHandler), DefaultNonStreamingResponseHandler + ) + services.add_singleton( + cast(type, IStreamingResponseHandler), DefaultStreamingResponseHandler + ) + except Exception as e: + logger.warning(f"Failed to register response handler interfaces: {e}") + # Continue if concrete handlers are registered + + # Register MiddlewareApplicationManager and IMiddlewareApplicationManager with configured middleware list + def _middleware_application_manager_factory( + provider: IServiceProvider, + ) -> MiddlewareApplicationManager: + from src.core.app.middleware.json_repair_middleware import JsonRepairMiddleware + from src.core.app.middleware.tool_call_repair_middleware import ( + ToolCallRepairMiddleware, + ) + from src.core.config.app_config import AppConfig + from src.core.services.empty_response_middleware import ( + EmptyResponseMiddleware, + ) + from src.core.services.middleware_application_manager import ( + MiddlewareApplicationManager, + ) + from src.core.services.tool_call_loop_middleware import ( + ToolCallLoopDetectionMiddleware, + ) + + cfg: AppConfig = provider.get_required_service(AppConfig) + middlewares: list[IResponseMiddleware] = [] + + try: + if getattr(cfg.empty_response, "enabled", True): + middlewares.append( + EmptyResponseMiddleware( + enabled=True, + 
max_retries=getattr(cfg.empty_response, "max_retries", 1), + ) + ) + except Exception as e: + logging.getLogger(__name__).warning( + f"Error configuring EmptyResponseMiddleware: {e}", exc_info=True + ) + + # Edit-precision response-side detection (optional) + try: + from src.core.services.edit_precision_response_middleware import ( + EditPrecisionResponseMiddleware, + ) + + app_state = provider.get_required_service(ApplicationStateService) + middlewares.append(EditPrecisionResponseMiddleware(app_state)) + except Exception as e: + logging.getLogger(__name__).warning( + f"Error configuring EditPrecisionResponseMiddleware: {e}", + exc_info=True, + ) + + # Think tags fix middleware (optional) + try: + if getattr(cfg.session, "fix_think_tags_enabled", False): + from src.core.services.think_tags_fix_middleware import ( + ThinkTagsFixMiddleware, + ) + + # Configure streaming buffer size from config + buffer_size = getattr( + cfg.session, "fix_think_tags_streaming_buffer_size", 4096 + ) + middlewares.append( + ThinkTagsFixMiddleware( + enabled=True, streaming_buffer_size=buffer_size + ) + ) + except Exception as e: + logging.getLogger(__name__).warning( + f"Error configuring ThinkTagsFixMiddleware: {e}", + exc_info=True, + ) + + if getattr(cfg.session, "json_repair_enabled", False): + json_service: JsonRepairService = provider.get_required_service( + JsonRepairService + ) + middlewares.append(JsonRepairMiddleware(cfg, json_service)) + + if getattr(cfg.session, "tool_call_repair_enabled", True): + tcr_service: ToolCallRepairService = provider.get_required_service( + ToolCallRepairService + ) + middlewares.append(ToolCallRepairMiddleware(cfg, tcr_service)) + + try: + middlewares.append(ToolCallLoopDetectionMiddleware()) + except Exception as e: + logging.getLogger(__name__).warning( + f"Error configuring ToolCallLoopDetectionMiddleware: {e}", exc_info=True + ) + + # Add tool call reactor middleware + try: + tool_call_reactor_middleware = provider.get_required_service( + ToolCallReactorMiddleware + ) + middlewares.append(tool_call_reactor_middleware) + except Exception as e: + logging.getLogger(__name__).warning( + f"Error configuring ToolCallReactorMiddleware: {e}", exc_info=True + ) + + # Dangerous command prevention will be handled by Tool Call Reactor handler. + # Keeping old middleware disabled to avoid duplicate processing. 
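+        # Resulting middleware order when every feature above is enabled:
+        # empty-response -> edit-precision -> think-tags fix -> JSON repair ->
+        # tool-call repair -> tool-call loop detection -> tool call reactor.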
+ + return MiddlewareApplicationManager(middlewares) + + _add_singleton( + MiddlewareApplicationManager, + implementation_factory=_middleware_application_manager_factory, + ) + try: + services.add_singleton( + cast(type, IMiddlewareApplicationManager), + implementation_factory=_middleware_application_manager_factory, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning( + f"Failed to register IMiddlewareApplicationManager interface: {e}" + ) + # Continue if concrete MiddlewareApplicationManager is registered + + # Register MiddlewareApplicationProcessor used inside the streaming pipeline + def _middleware_application_processor_factory( + provider: IServiceProvider, + ) -> MiddlewareApplicationProcessor: + manager: MiddlewareApplicationManager = provider.get_required_service( + MiddlewareApplicationManager + ) + app_state: IApplicationState = provider.get_required_service( + IApplicationState # type: ignore[type-abstract] + ) + + import os + + from src.core.domain.configuration.loop_detection_config import ( + LoopDetectionConfiguration, + ) + from src.tool_call_loop.config import ToolCallLoopConfig + + env_config = ToolCallLoopConfig.from_env_vars(dict(os.environ)) + loop_config = ( + LoopDetectionConfiguration() + .with_tool_loop_detection_enabled(env_config.enabled) + .with_tool_loop_max_repeats(env_config.max_repeats) + .with_tool_loop_ttl_seconds(env_config.ttl_seconds) + .with_tool_loop_mode(env_config.mode) + ) + + return MiddlewareApplicationProcessor( + manager._middleware, + default_loop_config=loop_config, + app_state=app_state, + ) + + _add_singleton( + MiddlewareApplicationProcessor, + implementation_factory=_middleware_application_processor_factory, + ) + + # Register response processor + def _response_processor_factory(provider: IServiceProvider) -> ResponseProcessor: + from typing import cast + + app_state: IApplicationState = provider.get_required_service( + cast(type, IApplicationState) + ) + stream_normalizer: IStreamNormalizer = provider.get_required_service( + cast(type, IStreamNormalizer) + ) + response_parser: IResponseParser = provider.get_required_service( + cast(type, IResponseParser) + ) + middleware_application_manager: IMiddlewareApplicationManager = ( + provider.get_required_service(cast(type, IMiddlewareApplicationManager)) + ) + + # Get the middleware manager to access the middleware list + middleware_manager: MiddlewareApplicationManager = ( + provider.get_required_service(MiddlewareApplicationManager) + ) + + # Get loop detector for non-streaming responses + from src.core.interfaces.loop_detector_interface import ILoopDetector + + loop_detector = provider.get_service(cast(type, ILoopDetector)) + + return ResponseProcessor( + response_parser=response_parser, + middleware_application_manager=middleware_application_manager, + app_state=app_state, + loop_detector=loop_detector, + stream_normalizer=stream_normalizer, + middleware_list=middleware_manager._middleware, + ) + + # Register response processor and bind to interface + _add_singleton( + ResponseProcessor, implementation_factory=_response_processor_factory + ) + + try: + services.add_singleton( + cast(type, IResponseProcessor), + implementation_factory=_response_processor_factory, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register IResponseProcessor interface: {e}") + # Continue if concrete ResponseProcessor is registered + + def _application_state_factory( + provider: IServiceProvider, + ) -> ApplicationStateService: + # Create 
application state service by resolving the concrete registration,
+        # so the IApplicationState binding shares the same singleton instance
+        return provider.get_required_service(ApplicationStateService)
+
+    # Register app settings
+    def _app_settings_factory(provider: IServiceProvider) -> AppSettings:
+        # Get app_state from ApplicationStateService if available
+        app_state: Any | None = None
+        try:
+            app_state_service: ApplicationStateService | None = provider.get_service(
+                ApplicationStateService
+            )
+            if app_state_service:
+                app_state = app_state_service.get_setting("service_provider")
+        except Exception as e:
+            logger.debug(f"Could not get app_state from ApplicationStateService: {e}")
+            app_state = None
+
+        # Create app settings
+        return AppSettings(app_state)
+
+    # Register app settings and bind to interface
+    _add_singleton(AppSettings, implementation_factory=_app_settings_factory)
+
+    try:
+        services.add_singleton(
+            cast(type, IAppSettings), implementation_factory=_app_settings_factory
+        )  # type: ignore[type-abstract]
+    except Exception as e:
+        logger.warning(f"Failed to register IAppSettings interface: {e}")
+        # Continue if concrete AppSettings is registered
+
+    # Register application state service
+    _add_singleton(ApplicationStateService)
+
+    try:
+        services.add_singleton(
+            cast(type, IApplicationState),
+            implementation_factory=_application_state_factory,
+        )  # type: ignore[type-abstract]
+    except Exception as e:
+        logger.warning(f"Failed to register IApplicationState interface: {e}")
+        # Continue if concrete ApplicationStateService is registered
+
+    # Register secure state service
+    def _secure_state_factory(provider: IServiceProvider) -> SecureStateService:
+        app_state = provider.get_required_service(ApplicationStateService)
+        return SecureStateService(app_state)
+
+    _add_singleton(SecureStateService, implementation_factory=_secure_state_factory)
+
+    try:
+        services.add_singleton(
+            cast(type, ISecureStateAccess), implementation_factory=_secure_state_factory
+        )  # type: ignore[type-abstract]
+        services.add_singleton(
+            cast(type, ISecureStateModification),
+            implementation_factory=_secure_state_factory,
+        )  # type: ignore[type-abstract]
+    except Exception as e:
+        logger.warning(f"Failed to register secure state interfaces: {e}")
+        # Continue if concrete SecureStateService is registered
+
+    # Register secure command factory
+    def _secure_command_factory(provider: IServiceProvider) -> SecureCommandFactory:
+        secure_state = provider.get_required_service(SecureStateService)
+        return SecureCommandFactory(
+            state_reader=secure_state, state_modifier=secure_state
+        )
+
+    _add_singleton(SecureCommandFactory, implementation_factory=_secure_command_factory)
+
+    # Register conversation fingerprint service
+    from src.core.services.conversation_fingerprint_service import (
+        ConversationFingerprintService,
+    )
+
+    _add_singleton(ConversationFingerprintService)
+
+    # Register session manager
+    def _session_manager_factory(provider: IServiceProvider) -> SessionManager:
+        session_service = provider.get_required_service(ISessionService)  # type: ignore[type-abstract]
+        session_resolver = provider.get_required_service(ISessionResolver)  # type: ignore[type-abstract]
+        # Get session repository for fingerprint tracking
+        session_repository = provider.get_service(cast(type, ISessionRepository))  # type: ignore[type-abstract]
+        fingerprint_service = provider.get_required_service(
+            ConversationFingerprintService
+        )
+        return SessionManager(
+            session_service,
+            session_resolver,
+            session_repository=session_repository,
+            fingerprint_service=fingerprint_service,
+        )
+
+    _add_singleton(SessionManager, implementation_factory=_session_manager_factory)
+
+    
try: + services.add_singleton( + cast(type, ISessionManager), implementation_factory=_session_manager_factory + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register ISessionManager interface: {e}") + # Continue if concrete SessionManager is registered + + # Register agent response formatter + def _agent_response_formatter_factory( + provider: IServiceProvider, + ) -> AgentResponseFormatter: + session_service = provider.get_service(SessionService) + return AgentResponseFormatter(session_service=session_service) + + _add_singleton( + AgentResponseFormatter, implementation_factory=_agent_response_formatter_factory + ) + + try: + services.add_singleton( + cast(type, IAgentResponseFormatter), + implementation_factory=_agent_response_formatter_factory, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register IAgentResponseFormatter interface: {e}") + # Continue if concrete AgentResponseFormatter is registered + + # Register response manager + def _response_manager_factory(provider: IServiceProvider) -> ResponseManager: + agent_response_formatter = provider.get_required_service(IAgentResponseFormatter) # type: ignore[type-abstract] + session_service = provider.get_required_service(ISessionService) # type: ignore[type-abstract] + return ResponseManager(agent_response_formatter, session_service) + + _add_singleton(ResponseManager, implementation_factory=_response_manager_factory) + + try: + services.add_singleton( + cast(type, IResponseManager), + implementation_factory=_response_manager_factory, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register IResponseManager interface: {e}") + # Continue if concrete ResponseManager is registered + + # Register backend request manager + def _backend_request_manager_factory( + provider: IServiceProvider, + ) -> BackendRequestManager: + backend_processor = provider.get_required_service(IBackendProcessor) # type: ignore[type-abstract] + response_processor = provider.get_required_service(IResponseProcessor) # type: ignore[type-abstract] + wire_capture = provider.get_required_service(IWireCapture) # type: ignore[type-abstract] + return BackendRequestManager( + backend_processor, response_processor, wire_capture + ) + + _add_singleton( + BackendRequestManager, implementation_factory=_backend_request_manager_factory + ) + + try: + services.add_singleton( + cast(type, IBackendRequestManager), + implementation_factory=_backend_request_manager_factory, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register IBackendRequestManager interface: {e}") + # Continue if concrete BackendRequestManager is registered + + # Register stream normalizer + def _stream_normalizer_factory(provider: IServiceProvider) -> StreamNormalizer: + # Retrieve all stream processors in the correct order + try: + from src.core.config.app_config import AppConfig + + app_config: AppConfig = provider.get_required_service(AppConfig) + + # Optional JSON repair processor (enabled via config) + json_repair_processor = None + if getattr(app_config.session, "json_repair_enabled", False): + json_repair_processor = provider.get_required_service( + JsonRepairProcessor + ) + tool_call_repair_processor = None + if getattr(app_config.session, "tool_call_repair_enabled", True): + tool_call_repair_processor = provider.get_required_service( + ToolCallRepairProcessor + ) + loop_detection_processor = None + try: + loop_detection_processor = 
provider.get_required_service(
+                    LoopDetectionProcessor
+                )
+                logger.debug(
+                    "LoopDetectionProcessor resolved for streaming pipeline"
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to resolve LoopDetectionProcessor for streaming: {e}"
+                )
+                loop_detection_processor = None
+            middleware_application_processor = provider.get_required_service(
+                MiddlewareApplicationProcessor
+            )
+            content_accumulation_processor = provider.get_required_service(
+                ContentAccumulationProcessor
+            )
+
+            processors: list[IStreamProcessor] = []
+            # Prefer JSON repair first so JSON blocks are valid
+            if json_repair_processor is not None:
+                processors.append(json_repair_processor)
+            # Then text loop detection
+            if loop_detection_processor is not None:
+                processors.append(loop_detection_processor)
+            # Then tool-call repair
+            if tool_call_repair_processor is not None:
+                processors.append(tool_call_repair_processor)
+            # Middleware and accumulation
+            processors.append(middleware_application_processor)
+            processors.append(content_accumulation_processor)
+        except Exception as e:
+            logger.warning(
+                f"Error creating stream processors: {e}. Using default configuration."
+            )
+            # Create minimal configuration with just content accumulation
+            # Use default 10MB buffer limit for fallback
+            content_accumulation_processor = ContentAccumulationProcessor(
+                max_buffer_bytes=10 * 1024 * 1024
+            )
+            processors = [content_accumulation_processor]
+
+        return StreamNormalizer(processors)
+
+    _add_singleton(StreamNormalizer, implementation_factory=_stream_normalizer_factory)
+
+    try:
+        services.add_singleton(
+            cast(type, IStreamNormalizer),
+            implementation_factory=_stream_normalizer_factory,
+        )  # type: ignore[type-abstract]
+    except Exception as e:
+        logger.warning(f"Failed to register IStreamNormalizer interface: {e}")
+        # Continue if concrete StreamNormalizer is registered
+
+    # Register ResponseParser
+    def _response_parser_factory(provider: IServiceProvider) -> ResponseParser:
+        return ResponseParser()
+
+    _add_singleton(ResponseParser, implementation_factory=_response_parser_factory)
+    try:
+        services.add_singleton(
+            cast(type, IResponseParser), implementation_factory=_response_parser_factory
+        )  # type: ignore[type-abstract]
+    except Exception as e:
+        logger.warning(f"Failed to register IResponseParser interface: {e}")
+        # Continue if concrete ResponseParser is registered
+
+    # Register individual stream processors
+    def _loop_detection_processor_factory(
+        provider: IServiceProvider,
+    ) -> LoopDetectionProcessor:
+        from src.core.interfaces.loop_detector_interface import ILoopDetector
+
+        # Create a factory function that creates new detector instances.
+        # This ensures each session gets its own isolated detector.
+        def create_detector() -> ILoopDetector:
+            return provider.get_required_service(cast(type, ILoopDetector))
+
+        app_config = provider.get_required_service(AppConfig)
+        session_ttl = getattr(
+            app_config.session, "loop_detection_session_ttl_seconds", 300
+        )
+
+        return LoopDetectionProcessor(
+            loop_detector_factory=create_detector,
+            session_ttl_seconds=session_ttl,
+        )
+
+    _add_singleton(
+        LoopDetectionProcessor, implementation_factory=_loop_detection_processor_factory
+    )
+
+    # Register ContentAccumulationProcessor with configured buffer limit
+    def _content_accumulation_processor_factory(
+        provider: IServiceProvider,
+    ) -> ContentAccumulationProcessor:
+        from src.core.config.app_config import AppConfig
+
+        config: AppConfig = provider.get_required_service(AppConfig)
+        buffer_cap = getattr(
+            
config.session, "content_accumulation_buffer_cap_bytes", 10 * 1024 * 1024 + ) + return ContentAccumulationProcessor(max_buffer_bytes=buffer_cap) + + _add_singleton( + ContentAccumulationProcessor, + implementation_factory=_content_accumulation_processor_factory, + ) + + # Register JSON repair service and processor + def _json_repair_service_factory(provider: IServiceProvider) -> JsonRepairService: + return JsonRepairService() + + _add_singleton( + JsonRepairService, implementation_factory=_json_repair_service_factory + ) + + # Register StructuredOutputMiddleware + def _structured_output_middleware_factory( + provider: IServiceProvider, + ) -> StructuredOutputMiddleware: + json_repair_service: JsonRepairService = provider.get_required_service( + JsonRepairService + ) + return StructuredOutputMiddleware(json_repair_service) + + _add_singleton( + StructuredOutputMiddleware, + implementation_factory=_structured_output_middleware_factory, + ) + + def _json_repair_processor_factory( + provider: IServiceProvider, + ) -> JsonRepairProcessor: + from src.core.config.app_config import AppConfig + + config: AppConfig = provider.get_required_service(AppConfig) + service: JsonRepairService = provider.get_required_service(JsonRepairService) + return JsonRepairProcessor( + repair_service=service, + buffer_cap_bytes=getattr( + config.session, "json_repair_buffer_cap_bytes", 64 * 1024 + ), + strict_mode=getattr(config.session, "json_repair_strict_mode", False), + schema=getattr(config.session, "json_repair_schema", None), + enabled=getattr(config.session, "json_repair_enabled", False), + ) + + _add_singleton( + JsonRepairProcessor, implementation_factory=_json_repair_processor_factory + ) + + # Wire capture service is registered in CoreServicesStage using BufferedWireCapture. + # Intentionally avoid legacy StructuredWireCapture registration here to keep + # the active format consistent across the app. 
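+    # Note: with default session settings the JsonRepairProcessor registered
+    # above starts disabled (json_repair_enabled defaults to False) and uses a
+    # 64 KiB buffer cap; both can be overridden via the session config.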
+ + # Register tool call repair service (if not already registered elsewhere as a concrete type) + def _tool_call_repair_service_factory( + provider: IServiceProvider, + ) -> ToolCallRepairService: + return ToolCallRepairService() + + _add_singleton( + ToolCallRepairService, implementation_factory=_tool_call_repair_service_factory + ) + + # Register TranslationService (dependency of BackendService) + def _translation_service_factory(provider: IServiceProvider) -> TranslationService: + return TranslationService() + + _add_singleton( + TranslationService, implementation_factory=_translation_service_factory + ) + + # Register ITranslationService interface to resolve to the same singleton instance + def _translation_service_interface_factory( + provider: IServiceProvider, + ) -> TranslationService: + return provider.get_required_service(TranslationService) + + _add_singleton( + cast(type, ITranslationService), + implementation_factory=_translation_service_interface_factory, + ) + + # Register assessment services if enabled + if app_config and app_config.assessment.enabled: + logger.info( + "LLM Assessment System ACTIVATED - Monitoring conversations for unproductive patterns" + ) + + # Initialize assessment prompts first + from src.core.services.assessment_prompts import initialize_prompts + + try: + initialize_prompts() + logger.info("Assessment prompts loaded successfully") + except Exception as e: + logger.error(f"Failed to load assessment prompts: {e}") + raise + + # Import assessment services only when needed to avoid circular imports + from src.core.interfaces.assessment_service_interface import ( + IAssessmentBackendService, + IAssessmentRepository, + IAssessmentService, + ITurnCounterService, + ) + from src.core.repositories.assessment_repository import ( + InMemoryAssessmentRepository, + ) + from src.core.services.assessment_backend_service import ( + AssessmentBackendService, + ) + from src.core.services.assessment_service import AssessmentService + from src.core.services.turn_counter_service import TurnCounterService + + # Assessment repository + def _assessment_repository_factory( + provider: IServiceProvider, + ) -> InMemoryAssessmentRepository: + return InMemoryAssessmentRepository() + + _add_singleton( + IAssessmentRepository, implementation_factory=_assessment_repository_factory # type: ignore[type-abstract] + ) + + # Turn counter service + def _turn_counter_service_factory( + provider: IServiceProvider, + ) -> TurnCounterService: + repository = provider.get_required_service(IAssessmentRepository) # type: ignore[type-abstract] + config = provider.get_required_service(AppConfig).assessment + return TurnCounterService(repository, config) + + _add_singleton( + ITurnCounterService, implementation_factory=_turn_counter_service_factory # type: ignore[type-abstract] + ) + + # Assessment backend service + + def _assessment_backend_service_factory( + provider: IServiceProvider, + ) -> AssessmentBackendService: + backend_service = provider.get_required_service(IBackendService) # type: ignore[type-abstract] + config = provider.get_required_service(AppConfig).assessment + return AssessmentBackendService(backend_service, config) + + _add_singleton( + IAssessmentBackendService, + implementation_factory=_assessment_backend_service_factory, # type: ignore[type-abstract] + ) + + # Core assessment service + + def _assessment_service_factory( + provider: IServiceProvider, + ) -> AssessmentService: + backend_service = provider.get_required_service(IAssessmentBackendService) # type: 
ignore[type-abstract] + config = provider.get_required_service(AppConfig).assessment + return AssessmentService(backend_service, config) + + _add_singleton( + IAssessmentService, implementation_factory=_assessment_service_factory # type: ignore[type-abstract] + ) + + logger.info( + f"Assessment services registered: backend={app_config.assessment.backend}, " + f"model={app_config.assessment.model}, threshold={app_config.assessment.turn_threshold}" + ) + + try: + services.add_singleton( + cast(type, IToolCallRepairService), + implementation_factory=_tool_call_repair_service_factory, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register IToolCallRepairService interface: {e}") + # Continue if concrete ToolCallRepairService is registered + + # Register tool call repair processor + def _tool_call_repair_processor_factory( + provider: IServiceProvider, + ) -> ToolCallRepairProcessor: + tool_call_repair_service = provider.get_required_service(IToolCallRepairService) # type: ignore[type-abstract] + return ToolCallRepairProcessor(tool_call_repair_service) + + _add_singleton( + ToolCallRepairProcessor, + implementation_factory=_tool_call_repair_processor_factory, + ) + + # Register dangerous command service + def _dangerous_command_service_factory( + provider: IServiceProvider, + ) -> DangerousCommandService: + from src.core.config.app_config import AppConfig + from src.core.domain.configuration.dangerous_command_config import ( + DEFAULT_DANGEROUS_COMMAND_CONFIG, + ) + from src.core.services.dangerous_command_service import ( + DangerousCommandService, + ) + + provider.get_required_service(AppConfig) + return DangerousCommandService(DEFAULT_DANGEROUS_COMMAND_CONFIG) + + _add_singleton( + DangerousCommandService, + implementation_factory=_dangerous_command_service_factory, + ) + + # Register pytest compression service + def _pytest_compression_service_factory( + provider: IServiceProvider, + ) -> PytestCompressionService: + from src.core.services.pytest_compression_service import ( + PytestCompressionService, + ) + + provider.get_required_service(AppConfig) + return PytestCompressionService() + + _add_singleton( + PytestCompressionService, + implementation_factory=_pytest_compression_service_factory, + ) + + # Register tool access policy service + from src.core.services.tool_access_policy_service import ToolAccessPolicyService + + def _tool_access_policy_service_factory( + provider: IServiceProvider, + ) -> ToolAccessPolicyService: + from src.core.config.app_config import AppConfig + + app_config: AppConfig = provider.get_required_service(AppConfig) + reactor_config = app_config.session.tool_call_reactor + + # Get global overrides from session config (set by CLI parameters) + global_overrides = getattr( + app_config.session, "tool_access_global_overrides", None + ) + return ToolAccessPolicyService( + reactor_config, global_overrides=global_overrides + ) + + _add_singleton( + ToolAccessPolicyService, + implementation_factory=_tool_access_policy_service_factory, + ) + + # Register tool call reactor services + def _tool_call_history_tracker_factory( + provider: IServiceProvider, + ) -> InMemoryToolCallHistoryTracker: + return InMemoryToolCallHistoryTracker() + + _add_singleton( + InMemoryToolCallHistoryTracker, + implementation_factory=_tool_call_history_tracker_factory, + ) + + def _tool_call_reactor_factory( + provider: IServiceProvider, + ) -> ToolCallReactorService: + from src.core.config.app_config import AppConfig + + history_tracker = 
provider.get_required_service(InMemoryToolCallHistoryTracker)
+        reactor = ToolCallReactorService(history_tracker)
+
+        # Get configuration
+        app_config: AppConfig = provider.get_required_service(AppConfig)
+        reactor_config = app_config.session.tool_call_reactor
+
+        # Register default handlers if enabled
+        if reactor_config.enabled:
+            from src.core.services.tool_call_handlers.config_steering_handler import (
+                ConfigSteeringHandler,
+            )
+            from src.core.services.tool_call_handlers.dangerous_command_handler import (
+                DangerousCommandHandler,
+            )
+            from src.core.services.tool_call_handlers.pytest_full_suite_handler import (
+                PytestFullSuiteHandler,
+            )
+
+            # Register config-driven steering handler (includes the synthesized legacy apply_diff rule when enabled)
+            try:
+                # Build effective rules from config while avoiding an expensive deep copy.
+                # Since steering_rules are configuration data (immutable at runtime),
+                # a shallow copy is safe and faster.
+                effective_rules = (
+                    (reactor_config.steering_rules or []).copy()
+                    if reactor_config.steering_rules
+                    else []
+                )
+
+                # Synthesize the legacy apply_diff rule if enabled and missing
+                if getattr(reactor_config, "apply_diff_steering_enabled", True):
+                    has_apply_rule = False
+                    for r in effective_rules:
+                        triggers = (r or {}).get("triggers") or {}
+                        tnames = triggers.get("tool_names") or []
+                        phrases = triggers.get("phrases") or []
+                        if "apply_diff" in tnames or any(
+                            isinstance(p, str) and "apply_diff" in p for p in phrases
+                        ):
+                            has_apply_rule = True
+                            break
+                    if not has_apply_rule:
+                        effective_rules.append(
+                            {
+                                "name": "apply_diff_to_patch_file",
+                                "enabled": True,
+                                "priority": 100,
+                                "triggers": {
+                                    "tool_names": ["apply_diff"],
+                                    "phrases": [],
+                                },
+                                "message": (
+                                    reactor_config.apply_diff_steering_message
+                                    or (
+                                        "You tried to use the apply_diff tool. Please prefer the patch_file tool instead, "
+                                        "as it is superior to apply_diff and provides automated Python QA checks."
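+                                        # Default steering text; used only when
+                                        # apply_diff_steering_message is unset.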
+ ) + ), + "rate_limit": { + "calls_per_window": 1, + "window_seconds": reactor_config.apply_diff_steering_rate_limit_seconds, + }, + } + ) + + if effective_rules: + config_handler = ConfigSteeringHandler(rules=effective_rules) + try: + reactor.register_handler_sync(config_handler) + except Exception as e: + logger.warning( + f"Failed to register config steering handler: {e}", + exc_info=True, + ) + except Exception as e: + logger.warning( + "Failed to register steering handlers: %s", e, exc_info=True + ) + + # Register DangerousCommandHandler if enabled in session config + try: + if getattr( + app_config.session, "dangerous_command_prevention_enabled", True + ): + dangerous_service = provider.get_required_service( + DangerousCommandService + ) + dangerous_handler = DangerousCommandHandler( + dangerous_service, + steering_message=getattr( + app_config.session, + "dangerous_command_steering_message", + None, + ), + enabled=True, + ) + try: + reactor.register_handler_sync(dangerous_handler) + except Exception as e: + logger.warning( + f"Failed to register dangerous command handler: {e}", + exc_info=True, + ) + except Exception as e: + logger.warning( + f"Failed to register DangerousCommandHandler: {e}", exc_info=True + ) + + # Register PytestFullSuiteHandler if enabled + try: + if getattr(reactor_config, "pytest_full_suite_steering_enabled", False): + steering_message = getattr( + reactor_config, "pytest_full_suite_steering_message", None + ) + pytest_full_suite_handler = PytestFullSuiteHandler( + message=steering_message, + enabled=True, + ) + try: + reactor.register_handler_sync(pytest_full_suite_handler) + except Exception as e: + logger.warning( + f"Failed to register pytest full-suite handler: {e}", + exc_info=True, + ) + except Exception as e: + logger.warning( + f"Failed to register PytestFullSuiteHandler: {e}", exc_info=True + ) + + # Register PytestContextSavingHandler if enabled + try: + if getattr(reactor_config, "pytest_context_saving_enabled", False): + from src.core.services.tool_call_handlers.pytest_context_saving_handler import ( + PytestContextSavingHandler, + ) + + context_saving_handler = PytestContextSavingHandler(enabled=True) + try: + reactor.register_handler_sync(context_saving_handler) + except Exception as e: + logger.warning( + f"Failed to register pytest context saving handler: {e}", + exc_info=True, + ) + except Exception as e: + logger.warning( + f"Failed to register PytestContextSavingHandler: {e}", exc_info=True + ) + + # Register PytestCompressionHandler if enabled in session config + try: + if getattr(app_config.session, "pytest_compression_enabled", True): + from src.core.services.tool_call_handlers.pytest_compression_handler import ( + PytestCompressionHandler, + ) + + pytest_compression_service = provider.get_required_service( + PytestCompressionService + ) + session_service = provider.get_required_service(SessionService) + pytest_handler = PytestCompressionHandler( + pytest_compression_service, + session_service, + enabled=True, + ) + try: + reactor.register_handler_sync(pytest_handler) + except Exception as e: + logger.warning( + f"Failed to register pytest compression handler: {e}", + exc_info=True, + ) + except Exception as e: + logger.warning( + f"Failed to register PytestCompressionHandler: {e}", exc_info=True + ) + + # Register ToolAccessControlHandler if access policies are configured + try: + from src.core.services.tool_access_policy_service import ( + ToolAccessPolicyService, + ) + from 
src.core.services.tool_call_handlers.tool_access_control_handler import ( + ToolAccessControlHandler, + ) + + # Get the policy service + policy_service = provider.get_required_service(ToolAccessPolicyService) + + # Only register if there are policies configured + if policy_service._policies: + tool_access_handler = ToolAccessControlHandler( + policy_service=policy_service, + priority=90, # After dangerous-command handler (100) + reactor_service=reactor, # Pass reactor for telemetry + ) + try: + reactor.register_handler_sync(tool_access_handler) + logger.info( + f"Registered ToolAccessControlHandler with priority 90 " + f"({len(policy_service._policies)} policies loaded)" + ) + except Exception as e: + logger.warning( + f"Failed to register tool access control handler: {e}", + exc_info=True, + ) + except Exception as e: + logger.warning( + f"Failed to register ToolAccessControlHandler: {e}", exc_info=True + ) + + return reactor + + _add_singleton( + ToolCallReactorService, + implementation_factory=_tool_call_reactor_factory, + ) + + def _tool_call_reactor_middleware_factory( + provider: IServiceProvider, + ) -> ToolCallReactorMiddleware: + from src.core.config.app_config import AppConfig + + reactor = provider.get_required_service(ToolCallReactorService) + + # Get configuration to determine if middleware should be enabled + app_config: AppConfig = provider.get_required_service(AppConfig) + enabled = app_config.session.tool_call_reactor.enabled + + return ToolCallReactorMiddleware(reactor, enabled=enabled, priority=-10) + + _add_singleton( + ToolCallReactorMiddleware, + implementation_factory=_tool_call_reactor_middleware_factory, + ) + + # Register PathValidationService + def _path_validation_service_factory( + provider: IServiceProvider, + ) -> PathValidationService: + return PathValidationService() + + _add_singleton( + PathValidationService, implementation_factory=_path_validation_service_factory + ) + _add_singleton( + IPathValidator, # type: ignore[type-abstract] + implementation_factory=lambda p: p.get_required_service(PathValidationService), + ) + + # Register FileSandboxingHandler + def _file_sandboxing_handler_factory( + provider: IServiceProvider, + ) -> FileSandboxingHandler: + config = provider.get_required_service(AppConfig) + path_validator = provider.get_required_service(IPathValidator) # type: ignore[type-abstract] + session_service = provider.get_required_service(ISessionService) # type: ignore[type-abstract] + + return FileSandboxingHandler( + config=config.sandboxing, + path_validator=path_validator, + session_service=session_service, + ) + + _add_singleton( + FileSandboxingHandler, implementation_factory=_file_sandboxing_handler_factory + ) + + # Register backend service + def _backend_service_factory(provider: IServiceProvider) -> BackendService: + # Import required modules + import httpx + + from src.core.services.backend_factory import BackendFactory + from src.core.services.rate_limiter import RateLimiter + + # Get or create dependencies + httpx_client: httpx.AsyncClient | None = provider.get_service(httpx.AsyncClient) + if httpx_client is None: + try: + httpx_client = httpx.AsyncClient( + http2=True, + timeout=httpx.Timeout( + connect=10.0, read=60.0, write=60.0, pool=60.0 + ), + limits=httpx.Limits( + max_connections=100, max_keepalive_connections=20 + ), + trust_env=False, + ) + except ImportError: + httpx_client = httpx.AsyncClient( + http2=False, + timeout=httpx.Timeout( + connect=10.0, read=60.0, write=60.0, pool=60.0 + ), + limits=httpx.Limits( + 
max_connections=100, max_keepalive_connections=20 + ), + trust_env=False, + ) + + # Get app config + app_config: AppConfig = provider.get_required_service(AppConfig) + + backend_factory: BackendFactory = provider.get_required_service(BackendFactory) + + # Resolve the rate limiter from the DI container when available + rate_limiter: IRateLimiter | None = provider.get_service(RateLimiter) + if rate_limiter is None: + rate_limiter = provider.get_service(cast(type, IRateLimiter)) + if rate_limiter is None: + logging.getLogger(__name__).warning( + "RateLimiter service not registered; creating transient instance" + ) + rate_limiter = RateLimiter() + + # Get application state service + app_state: IApplicationState = provider.get_required_service(IApplicationState) # type: ignore[type-abstract] + + # Get failover coordinator (optional for test environments) + failover_coordinator = None + try: + failover_coordinator = provider.get_service(IFailoverCoordinator) # type: ignore[type-abstract] + except Exception as e: + logger.debug(f"FailoverCoordinator not available: {e}") + + # Get backend config provider or create one + backend_config_provider = None + try: + backend_config_provider = provider.get_service(IBackendConfigProvider) # type: ignore[type-abstract] + except Exception as e: + logger.debug( + f"BackendConfigProvider not available, will create default: {e}" + ) + + # If not available, create one with the app config + if backend_config_provider is None: + from src.core.services.backend_config_provider import BackendConfigProvider + + backend_config_provider = BackendConfigProvider(app_config) + + # Optionally build a failover strategy based on feature flag + failover_strategy = None + try: + if ( + app_state.get_use_failover_strategy() + and failover_coordinator is not None + ): + from src.core.services.failover_strategy import DefaultFailoverStrategy + + failover_strategy = DefaultFailoverStrategy(failover_coordinator) + except (AttributeError, ImportError, TypeError) as e: + logging.getLogger(__name__).debug( + "Failed to enable failover strategy: %s", e, exc_info=True + ) + + # Return backend service + return BackendService( + backend_factory, + rate_limiter, + app_config, + session_service=provider.get_required_service(SessionService), + app_state=app_state, + backend_config_provider=backend_config_provider, + failover_coordinator=failover_coordinator, + failover_strategy=failover_strategy, + wire_capture=provider.get_required_service(IWireCapture), # type: ignore[type-abstract] + ) + + # Register backend service and bind to interface + _add_singleton(BackendService, implementation_factory=_backend_service_factory) + + try: + services.add_singleton( + cast(type, IBackendService), implementation_factory=_backend_service_factory + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register IBackendService interface: {e}") + # Continue if concrete BackendService is registered + + # Register FailoverService first (dependency of FailoverCoordinator) + def _failover_service_factory(provider: IServiceProvider) -> FailoverService: + # FailoverService constructor takes failover_routes dict, defaulting to empty + return FailoverService(failover_routes={}) + + _add_singleton(FailoverService, implementation_factory=_failover_service_factory) + + # Register failover coordinator (if not already registered elsewhere as a concrete type) + def _failover_coordinator_factory( + provider: IServiceProvider, + ) -> FailoverCoordinator: + from 
src.core.services.failover_coordinator import FailoverCoordinator + from src.core.services.failover_service import FailoverService + + failover_service = provider.get_required_service(FailoverService) + return FailoverCoordinator(failover_service) + + from src.core.services.failover_coordinator import FailoverCoordinator + + _add_singleton( + FailoverCoordinator, implementation_factory=_failover_coordinator_factory + ) + + try: + from src.core.interfaces.failover_interface import IFailoverCoordinator + + services.add_singleton( + cast(type, IFailoverCoordinator), + implementation_factory=_failover_coordinator_factory, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register IFailoverCoordinator interface: {e}") + # Continue if concrete FailoverCoordinator is registered + + # Register request processor + def _request_processor_factory(provider: IServiceProvider) -> RequestProcessor: + # Get required services + command_processor = provider.get_required_service(ICommandProcessor) # type: ignore[type-abstract] + session_manager = provider.get_required_service(ISessionManager) # type: ignore[type-abstract] + backend_request_manager = provider.get_required_service(IBackendRequestManager) # type: ignore[type-abstract] + response_manager = provider.get_required_service(IResponseManager) # type: ignore[type-abstract] + app_state = provider.get_service(IApplicationState) # type: ignore[type-abstract] + + # Return request processor with decomposed services + return RequestProcessor( + command_processor, + session_manager, + backend_request_manager, + response_manager, + app_state=app_state, + ) + + # Register request processor and bind to interface + _add_singleton(RequestProcessor, implementation_factory=_request_processor_factory) + + try: + _add_singleton( + cast(type, IRequestProcessor), + implementation_factory=_request_processor_factory, + ) # type: ignore[type-abstract] + except Exception as e: + logger.warning(f"Failed to register IRequestProcessor interface: {e}") + # Continue if concrete RequestProcessor is registered + + +def get_service(service_type: type[T]) -> T | None: + """Get a service from the global service provider. + + Args: + service_type: The type of service to get + + Returns: + The service instance, or None if the service is not registered + """ + provider = get_or_build_service_provider() + return provider.get_service(service_type) # type: ignore + + +def get_required_service(service_type: type[T]) -> T: + """Get a required service from the global service provider. + + Args: + service_type: The type of service to get + + Returns: + The service instance + + Raises: + Exception: If the service is not registered + """ + provider = get_or_build_service_provider() + return provider.get_required_service(service_type) # type: ignore diff --git a/src/core/domain/streaming_response_processor.py b/src/core/domain/streaming_response_processor.py index c80b72bf0..b2718a262 100644 --- a/src/core/domain/streaming_response_processor.py +++ b/src/core/domain/streaming_response_processor.py @@ -1,153 +1,194 @@ -""" -Streaming response processor interfaces and utilities. - -This module provides interfaces and utilities for processing streaming -responses in a consistent way, regardless of the source or format. 
-""" - -from __future__ import annotations - -import logging -from collections.abc import Awaitable, Callable - -from src.core.app.constants.logging_constants import TRACE_LEVEL -from src.core.ports.streaming import IStreamProcessor, StreamingContent - -logger = logging.getLogger(__name__) - - -from src.core.interfaces.loop_detector_interface import ILoopDetector -from src.core.services.streaming.stream_utils import get_stream_id as _get_stream_id -from src.loop_detection.event import LoopDetectionEvent - - -class LoopDetectionProcessor(IStreamProcessor): - """Stream processor that checks for repetitive patterns in the content and handles API cancellation. - - This implementation uses a hash-based loop detection mechanism and integrates - with the backend's cancellation system to properly break loops with token waste prevention. - - IMPORTANT: This processor maintains per-session detector instances to ensure that - loop detection state is never shared between different sessions. - """ - - def __init__( - self, - loop_detector_factory: Callable[[], ILoopDetector], - cancel_callback: Callable[[], Awaitable[None]] | None = None, - ) -> None: - """Initialize loop detection processor. - - Args: - loop_detector_factory: Factory function to create new loop detector instances per session. - cancel_callback: Optional callback to trigger API cancellation when loop is detected. - """ - self.loop_detector_factory = loop_detector_factory - self.cancel_callback = cancel_callback - # Per-session detector instances to ensure isolation - self._session_detectors: dict[str, ILoopDetector] = {} - - def _get_detector_for_session(self, session_id: str) -> ILoopDetector: - """Get or create a loop detector for the given session. - - Args: - session_id: The session identifier - - Returns: - A loop detector instance dedicated to this session - """ - if session_id not in self._session_detectors: - detector = self.loop_detector_factory() - self._session_detectors[session_id] = detector - logger.debug(f"Created new loop detector for session {session_id}") - return self._session_detectors[session_id] - - def cleanup_session(self, session_id: str) -> None: - """Clean up detector instance for a completed session. - - Args: - session_id: The session identifier to clean up - """ - if session_id in self._session_detectors: - del self._session_detectors[session_id] - logger.debug(f"Cleaned up loop detector for session {session_id}") - - async def process(self, content: StreamingContent) -> StreamingContent: - """Process a streaming content chunk and check for loops. - - Args: - content: The content to process. - - Returns: - The processed content, with API cancellation if a loop is detected. - """ - if content.is_empty and not content.is_done: - return content - - # Ensure a stable stream identifier so metadata stays consistent across processors. - stream_id = _get_stream_id(content) - - # Prefer an explicit session identifier when provided; otherwise fall back to stream. - raw_session = content.metadata.get("session_id") or content.metadata.get("id") - session_id = str(raw_session) if raw_session else str(stream_id) - - # Get the detector instance for this specific session - loop_detector = self._get_detector_for_session(session_id) - - # Process the content for loop detection - # Ensure content is a string for loop detector - content_str = content.content - if logger.isEnabledFor(TRACE_LEVEL): - logger.log( - TRACE_LEVEL, - f"LoopDetectionProcessor processing chunk for session {session_id}: '{content_str[:50]}...' 
(length: {len(content_str)})", - ) - detection_event = loop_detector.process_chunk(content_str) - - # Clean up detector when stream is done - if content.is_done: - self.cleanup_session(session_id) - - if detection_event: - logger.warning( - f"Loop detected in streaming response by LoopDetectionProcessor: pattern='{detection_event.pattern[:50]}...', " - f"repetitions={detection_event.repetition_count}, total_length={detection_event.total_length}" - ) - - # Trigger API cancellation if callback is available - if self.cancel_callback is not None: - logger.info( - f"Triggering API cancellation due to loop detection: pattern='{detection_event.pattern[:50]}', repetitions={detection_event.repetition_count}" - ) - try: - await self.cancel_callback() - except Exception as e: - logger.error( - f"Failed to trigger API cancellation: {e}", exc_info=True - ) - - return self._create_cancellation_content(detection_event) - else: - # No loop detected, pass through the content - return content - - def _create_cancellation_content( - self, detection_event: LoopDetectionEvent - ) -> StreamingContent: - """Create a StreamingContent object with a cancellation message.""" - payload = ( - "[Response cancelled: Loop detected - Pattern " - f"'{detection_event.pattern[:30]}...' repeated " - f"{detection_event.repetition_count} times]" - ) - - return StreamingContent( - content=payload, - is_done=True, - is_cancellation=True, - metadata={ - "loop_detected": True, - "pattern": detection_event.pattern, - "repetition_count": detection_event.repetition_count, - }, - ) +""" +"""Streaming response processor interfaces and utilities. + +This module provides interfaces and utilities for processing streaming +responses in a consistent way, regardless of the source or format. +""" + +from __future__ import annotations + +import logging +import time +from collections.abc import Awaitable, Callable + +from src.core.app.constants.logging_constants import TRACE_LEVEL +from src.core.ports.streaming import IStreamProcessor, StreamingContent + +logger = logging.getLogger(__name__) + + +from src.core.interfaces.loop_detector_interface import ILoopDetector +from src.core.services.streaming.stream_utils import get_stream_id as _get_stream_id +from src.loop_detection.event import LoopDetectionEvent + + +class LoopDetectionProcessor(IStreamProcessor): + """Stream processor that checks for repetitive patterns in the content and handles API cancellation. + + This implementation uses a hash-based loop detection mechanism and integrates + with the backend's cancellation system to properly break loops with token waste prevention. + + IMPORTANT: This processor maintains per-session detector instances to ensure that + loop detection state is never shared between different sessions. + """ + + def __init__( + self, + loop_detector_factory: Callable[[], ILoopDetector], + cancel_callback: Callable[[], Awaitable[None]] | None = None, + *, + session_ttl_seconds: int = 300, + time_provider: Callable[[], float] | None = None, + ) -> None: + """Initialize loop detection processor. + + Args: + loop_detector_factory: Factory function to create new loop detector instances per session. + cancel_callback: Optional callback to trigger API cancellation when loop is detected. + session_ttl_seconds: Maximum idle time (in seconds) before a detector is + automatically discarded. Defaults to 300 seconds. Set to 0 to disable TTL cleanup. + time_provider: Optional callable returning the current time in seconds. 
+    def _cleanup_stale_sessions(self, current_time: float) -> None:
+        """Prune detector instances that have been inactive beyond the TTL."""
+
+        if self._session_ttl_seconds <= 0:
+            return
+
+        stale_sessions = [
+            session_id
+            for session_id, last_seen in self._session_last_activity.items()
+            if current_time - last_seen >= self._session_ttl_seconds
+        ]
+
+        for session_id in stale_sessions:
+            if session_id in self._session_detectors:
+                logger.debug(
+                    "Pruning stale loop detector for session %s after %ss of inactivity",
+                    session_id,
+                    self._session_ttl_seconds,
+                )
+            self._session_detectors.pop(session_id, None)
+            self._session_last_activity.pop(session_id, None)
+
+    def _get_detector_for_session(
+        self, session_id: str, current_time: float
+    ) -> ILoopDetector:
+        """Get or create a loop detector for the given session.
+
+        Args:
+            session_id: The session identifier
+            current_time: Current timestamp, used to refresh this session's
+                activity and prune stale sessions before lookup
+
+        Returns:
+            A loop detector instance dedicated to this session
+        """
+        self._cleanup_stale_sessions(current_time)
+
+        if session_id not in self._session_detectors:
+            detector = self.loop_detector_factory()
+            self._session_detectors[session_id] = detector
+            logger.debug(f"Created new loop detector for session {session_id}")
+        self._session_last_activity[session_id] = current_time
+        return self._session_detectors[session_id]
+
+    def cleanup_session(self, session_id: str) -> None:
+        """Clean up detector instance for a completed session.
+
+        Args:
+            session_id: The session identifier to clean up
+        """
+        if session_id in self._session_detectors:
+            del self._session_detectors[session_id]
+            logger.debug(f"Cleaned up loop detector for session {session_id}")
+        self._session_last_activity.pop(session_id, None)
+
+    async def process(self, content: StreamingContent) -> StreamingContent:
+        """Process a streaming content chunk and check for loops.
+
+        Args:
+            content: The content to process.
+
+        Returns:
+            The processed content, with API cancellation if a loop is detected.
+        """
+        if content.is_empty and not content.is_done:
+            return content
+
+        # Ensure a stable stream identifier so metadata stays consistent across processors.
+        stream_id = _get_stream_id(content)
+
+        # Prefer an explicit session identifier when provided; otherwise fall back to stream.
+        raw_session = content.metadata.get("session_id") or content.metadata.get("id")
+        session_id = str(raw_session) if raw_session else str(stream_id)
+
+        current_time = self._time_provider()
+
+        # Get the detector instance for this specific session
+        loop_detector = self._get_detector_for_session(session_id, current_time)
+
+        # Process the content for loop detection
+        # Ensure content is a string for loop detector
+        content_str = content.content
+        if logger.isEnabledFor(TRACE_LEVEL):
+            logger.log(
+                TRACE_LEVEL,
+                f"LoopDetectionProcessor processing chunk for session {session_id}: '{content_str[:50]}...'
(length: {len(content_str)})", + ) + detection_event = loop_detector.process_chunk(content_str) + + # Clean up detector when stream is done + if content.is_done: + self.cleanup_session(session_id) + + if detection_event: + logger.warning( + f"Loop detected in streaming response by LoopDetectionProcessor: pattern='{detection_event.pattern[:50]}...', " + f"repetitions={detection_event.repetition_count}, total_length={detection_event.total_length}" + ) + + # Trigger API cancellation if callback is available + if self.cancel_callback is not None: + logger.info( + f"Triggering API cancellation due to loop detection: pattern='{detection_event.pattern[:50]}', repetitions={detection_event.repetition_count}" + ) + try: + await self.cancel_callback() + except Exception as e: + logger.error( + f"Failed to trigger API cancellation: {e}", exc_info=True + ) + + return self._create_cancellation_content(detection_event) + else: + # No loop detected, pass through the content + return content + + def _create_cancellation_content( + self, detection_event: LoopDetectionEvent + ) -> StreamingContent: + """Create a StreamingContent object with a cancellation message.""" + payload = ( + "[Response cancelled: Loop detected - Pattern " + f"'{detection_event.pattern[:30]}...' repeated " + f"{detection_event.repetition_count} times]" + ) + + return StreamingContent( + content=payload, + is_done=True, + is_cancellation=True, + metadata={ + "loop_detected": True, + "pattern": detection_event.pattern, + "repetition_count": detection_event.repetition_count, + }, + ) diff --git a/tests/unit/loop_detection/test_session_isolation.py b/tests/unit/loop_detection/test_session_isolation.py index 6fe9d4143..d84525073 100644 --- a/tests/unit/loop_detection/test_session_isolation.py +++ b/tests/unit/loop_detection/test_session_isolation.py @@ -1,368 +1,469 @@ -""" -Test cases for loop detection session isolation. - -These tests ensure that loop detector state is never shared between different sessions, -preventing state contamination and ensuring each session has independent loop detection. 
-""" - -import pytest -from src.core.domain.streaming_response_processor import LoopDetectionProcessor -from src.core.ports.streaming import StreamingContent -from src.loop_detection.hybrid_detector import HybridLoopDetector - - -class TestLoopDetectionSessionIsolation: - """Test suite for verifying session isolation in loop detection.""" - - @pytest.fixture - def detector_factory(self): - """Factory function to create new detector instances.""" - - def create_detector(): - short_config = { - "content_loop_threshold": 6, - "content_chunk_size": 50, - "max_history_length": 4096, - } - return HybridLoopDetector(short_detector_config=short_config) - - return create_detector - - @pytest.fixture - def processor(self, detector_factory): - """Create a LoopDetectionProcessor with factory.""" - return LoopDetectionProcessor(loop_detector_factory=detector_factory) - - @pytest.mark.asyncio - async def test_different_sessions_have_independent_detectors( - self, processor, detector_factory - ): - """Test that different sessions get different detector instances.""" - # Create content for two different sessions - content_session_a = StreamingContent( - content="test", metadata={"session_id": "session-a"} - ) - content_session_b = StreamingContent( - content="test", metadata={"session_id": "session-b"} - ) - - # Process content for both sessions - await processor.process(content_session_a) - await processor.process(content_session_b) - - # Verify that two different detector instances were created - assert "session-a" in processor._session_detectors - assert "session-b" in processor._session_detectors - assert ( - processor._session_detectors["session-a"] - is not processor._session_detectors["session-b"] - ) - - @pytest.mark.asyncio - async def test_session_state_does_not_leak_between_sessions(self, processor): - """Test that loop detection state from one session doesn't affect another.""" - # Session A: Send repetitive content that should accumulate state - session_a_content = "AAAAAAAAAA" * 10 # 100 A's - for _ in range(5): - content = StreamingContent( - content=session_a_content, metadata={"session_id": "session-a"} - ) - await processor.process(content) - - # Session B: Send different content - should start with clean state - session_b_content = "BBBBBBBBBB" * 10 # 100 B's - content = StreamingContent( - content=session_b_content, metadata={"session_id": "session-b"} - ) - await processor.process(content) - - # Verify that session B's detector has no history from session A - detector_a = processor._session_detectors["session-a"] - detector_b = processor._session_detectors["session-b"] - - # Session A should have accumulated content - history_a = detector_a.short_detector.stream_content_history - assert "A" in history_a - assert len(history_a) > 0 - - # Session B should only have its own content, not session A's - history_b = detector_b.short_detector.stream_content_history - assert "B" in history_b - assert "A" not in history_b - - @pytest.mark.asyncio - async def test_loop_detection_in_one_session_does_not_affect_another( - self, processor - ): - """Test that detecting a loop in one session doesn't trigger in another.""" - # Session A: Send content that will trigger loop detection - loop_content = "IIIIIIII" # 8 I's - for _ in range(15): # Send enough to trigger detection - content = StreamingContent( - content=loop_content, metadata={"session_id": "session-a"} - ) - result = await processor.process(content) - if result.is_cancellation: - break - - # Session B: Send normal content - should NOT be 
affected by session A's loop - normal_content = "This is normal text without any loops." - content = StreamingContent( - content=normal_content, metadata={"session_id": "session-b"} - ) - result = await processor.process(content) - - # Session B should process normally, not be cancelled - assert not result.is_cancellation - assert result.content == normal_content - - @pytest.mark.asyncio - async def test_session_cleanup_removes_detector(self, processor): - """Test that detector is cleaned up when session completes.""" - session_id = "test-session" - - # Send some content - content = StreamingContent( - content="test content", metadata={"session_id": session_id} - ) - await processor.process(content) - - # Verify detector was created - assert session_id in processor._session_detectors - - # Send done marker - done_content = StreamingContent( - content="", is_done=True, metadata={"session_id": session_id} - ) - await processor.process(done_content) - - # Verify detector was cleaned up - assert session_id not in processor._session_detectors - - @pytest.mark.asyncio - async def test_concurrent_sessions_maintain_isolation(self, processor): - """Test that multiple concurrent sessions maintain independent state.""" - sessions = ["session-1", "session-2", "session-3"] - - # Send different content to each session concurrently - for i, session_id in enumerate(sessions): - # Each session gets different repeated character - char = chr(ord("A") + i) # A, B, C - content = StreamingContent( - content=char * 50, metadata={"session_id": session_id} - ) - await processor.process(content) - - # Verify each session has its own detector with its own content - for i, session_id in enumerate(sessions): - detector = processor._session_detectors[session_id] - history = detector.short_detector.stream_content_history - expected_char = chr(ord("A") + i) - - # Each session should only have its own character - assert expected_char in history - # And should not have other sessions' characters - for j, other_session in enumerate(sessions): # noqa: B007 - if i != j: - other_char = chr(ord("A") + j) - assert other_char not in history - - @pytest.mark.asyncio - async def test_same_session_reuses_detector(self, processor): - """Test that the same session reuses its detector instance.""" - session_id = "test-session" - - # Send first chunk - content1 = StreamingContent( - content="first chunk", metadata={"session_id": session_id} - ) - await processor.process(content1) - detector1 = processor._session_detectors[session_id] - - # Send second chunk - content2 = StreamingContent( - content="second chunk", metadata={"session_id": session_id} - ) - await processor.process(content2) - detector2 = processor._session_detectors[session_id] - - # Should be the same detector instance - assert detector1 is detector2 - - # And should have accumulated both chunks - history = detector1.short_detector.stream_content_history - assert "first chunk" in history - assert "second chunk" in history - - @pytest.mark.asyncio - async def test_session_without_id_uses_generated_stream_id(self, processor): - """Test that content without session_id generates a unique stream_id.""" - # Send content without session_id - content = StreamingContent(content="test content", metadata={}) - await processor.process(content) - - # Should create detector with a generated stream_id - assert len(processor._session_detectors) == 1 - # The generated stream_id should be a UUID hex string (32 characters) - session_key = next(iter(processor._session_detectors.keys())) - 
assert len(session_key) == 32 # UUID hex without dashes - - @pytest.mark.asyncio - async def test_stream_id_fallback_when_no_session_id(self, processor): - """Test that stream_id is used as fallback when session_id is not present.""" - stream_id = "stream-123" - - # Send content with stream_id but no session_id - content = StreamingContent( - content="test content", metadata={"stream_id": stream_id} - ) - await processor.process(content) - - # Should create detector using stream_id - assert stream_id in processor._session_detectors - - @pytest.mark.asyncio - async def test_multiple_cleanup_calls_are_safe(self, processor): - """Test that cleaning up the same session multiple times doesn't cause errors.""" - session_id = "test-session" - - # Create a detector - content = StreamingContent(content="test", metadata={"session_id": session_id}) - await processor.process(content) - assert session_id in processor._session_detectors - - # Clean up multiple times - processor.cleanup_session(session_id) - processor.cleanup_session(session_id) # Should not raise error - processor.cleanup_session(session_id) # Should not raise error - - assert session_id not in processor._session_detectors - - @pytest.mark.asyncio - async def test_detector_state_persists_within_session(self, processor): - """Test that detector state accumulates correctly within a single session.""" - session_id = "test-session" - - # Send multiple chunks of DIFFERENT content to avoid triggering loop detection - for i in range(10): - content = StreamingContent( - content=f"Chunk {i} with unique content here.", - metadata={"session_id": session_id}, - ) - await processor.process(content) - - # Verify that content accumulated in the detector - detector = processor._session_detectors[session_id] - history = detector.short_detector.stream_content_history - - # Should have accumulated all chunks - assert "Chunk 0" in history - assert "Chunk 9" in history - assert len(history) > 200 # Should have accumulated substantial content - - @pytest.mark.asyncio - async def test_factory_creates_fresh_detectors(self, detector_factory): - """Test that the factory function creates independent detector instances.""" - detector1 = detector_factory() - detector2 = detector_factory() - - # Should be different instances - assert detector1 is not detector2 - - # Should have independent state - detector1.process_chunk("test1") - detector2.process_chunk("test2") - - history1 = detector1.short_detector.stream_content_history - history2 = detector2.short_detector.stream_content_history - - assert "test1" in history1 - assert "test1" not in history2 - assert "test2" in history2 - assert "test2" not in history1 - - -class TestLoopDetectionRegressionPrevention: - """Tests to prevent regression to shared detector state.""" - - @pytest.mark.asyncio - async def test_processor_does_not_share_single_detector_instance(self): - """ - REGRESSION TEST: Ensure processor doesn't use a single shared detector. - - This test would FAIL if someone reverts to the old implementation where - a single detector instance was shared across all sessions. 
- """ - - # Create processor with factory - def create_detector(): - return HybridLoopDetector() - - processor = LoopDetectionProcessor(loop_detector_factory=create_detector) - - # Process content for two sessions - content_a = StreamingContent( - content="AAAA", metadata={"session_id": "session-a"} - ) - content_b = StreamingContent( - content="BBBB", metadata={"session_id": "session-b"} - ) - - await processor.process(content_a) - await processor.process(content_b) - - # CRITICAL: Must have separate detector instances - detector_a = processor._session_detectors["session-a"] - detector_b = processor._session_detectors["session-b"] - - # This assertion would FAIL if using shared detector - assert detector_a is not detector_b, ( - "REGRESSION: Detector instances are shared between sessions! " - "Each session must have its own isolated detector instance." - ) - - @pytest.mark.asyncio - async def test_detector_state_is_not_global(self): - """ - REGRESSION TEST: Ensure detector state is not stored globally. - - This test would FAIL if detector state was stored in a class variable - or module-level variable instead of per-instance. - """ - - def create_detector(): - return HybridLoopDetector() - - processor = LoopDetectionProcessor(loop_detector_factory=create_detector) - - # Session A accumulates state - for _ in range(5): - content = StreamingContent( - content="AAAA", metadata={"session_id": "session-a"} - ) - await processor.process(content) - - # Session B should start fresh - content_b = StreamingContent( - content="BBBB", metadata={"session_id": "session-b"} - ) - await processor.process(content_b) - - # Get histories - history_a = processor._session_detectors[ - "session-a" - ].short_detector.stream_content_history - history_b = processor._session_detectors[ - "session-b" - ].short_detector.stream_content_history - - # This assertion would FAIL if state was global - assert "A" not in history_b, ( - "REGRESSION: Session B's detector contains Session A's content! " - "Detector state is being shared globally instead of per-session." - ) - - assert "B" not in history_a, ( - "REGRESSION: Session A's detector contains Session B's content! " - "Detector state is being shared globally instead of per-session." - ) +""" +Test cases for loop detection session isolation. + +These tests ensure that loop detector state is never shared between different sessions, +preventing state contamination and ensuring each session has independent loop detection. 
+""" + +from typing import Any + +import pytest +from src.core.domain.streaming_response_processor import LoopDetectionProcessor +from src.core.interfaces.loop_detector_interface import ( + ILoopDetector, + LoopDetectionResult, +) +from src.core.ports.streaming import StreamingContent +from src.loop_detection.event import LoopDetectionEvent +from src.loop_detection.hybrid_detector import HybridLoopDetector + + +class _FakeTime: + """Deterministic time provider for TTL tests.""" + + def __init__(self, start: float = 0.0) -> None: + self._current = start + + def advance(self, seconds: float) -> None: + self._current += seconds + + def now(self) -> float: + return self._current + + +class _NoopLoopDetector(ILoopDetector): + """Simple loop detector stub for exercising processor lifecycle logic.""" + + def __init__(self) -> None: + self._seen_chunks: list[str] = [] + + def is_enabled(self) -> bool: + return True + + def process_chunk(self, chunk: str) -> LoopDetectionEvent | None: + self._seen_chunks.append(chunk) + return None + + def reset(self) -> None: + self._seen_chunks.clear() + + def get_loop_history(self) -> list[Any]: + return [] + + def get_current_state(self) -> dict[str, Any]: + return {"chunks": list(self._seen_chunks)} + + async def check_for_loops(self, content: str) -> LoopDetectionResult: + return LoopDetectionResult(has_loop=False) + + +class TestLoopDetectionSessionIsolation: + """Test suite for verifying session isolation in loop detection.""" + + @pytest.fixture + def detector_factory(self): + """Factory function to create new detector instances.""" + + def create_detector(): + short_config = { + "content_loop_threshold": 6, + "content_chunk_size": 50, + "max_history_length": 4096, + } + return HybridLoopDetector(short_detector_config=short_config) + + return create_detector + + @pytest.fixture + def processor(self, detector_factory): + """Create a LoopDetectionProcessor with factory.""" + return LoopDetectionProcessor(loop_detector_factory=detector_factory) + + @pytest.mark.asyncio + async def test_different_sessions_have_independent_detectors( + self, processor, detector_factory + ): + """Test that different sessions get different detector instances.""" + # Create content for two different sessions + content_session_a = StreamingContent( + content="test", metadata={"session_id": "session-a"} + ) + content_session_b = StreamingContent( + content="test", metadata={"session_id": "session-b"} + ) + + # Process content for both sessions + await processor.process(content_session_a) + await processor.process(content_session_b) + + # Verify that two different detector instances were created + assert "session-a" in processor._session_detectors + assert "session-b" in processor._session_detectors + assert ( + processor._session_detectors["session-a"] + is not processor._session_detectors["session-b"] + ) + + @pytest.mark.asyncio + async def test_session_state_does_not_leak_between_sessions(self, processor): + """Test that loop detection state from one session doesn't affect another.""" + # Session A: Send repetitive content that should accumulate state + session_a_content = "AAAAAAAAAA" * 10 # 100 A's + for _ in range(5): + content = StreamingContent( + content=session_a_content, metadata={"session_id": "session-a"} + ) + await processor.process(content) + + # Session B: Send different content - should start with clean state + session_b_content = "BBBBBBBBBB" * 10 # 100 B's + content = StreamingContent( + content=session_b_content, metadata={"session_id": "session-b"} + ) + await 
processor.process(content) + + # Verify that session B's detector has no history from session A + detector_a = processor._session_detectors["session-a"] + detector_b = processor._session_detectors["session-b"] + + # Session A should have accumulated content + history_a = detector_a.short_detector.stream_content_history + assert "A" in history_a + assert len(history_a) > 0 + + # Session B should only have its own content, not session A's + history_b = detector_b.short_detector.stream_content_history + assert "B" in history_b + assert "A" not in history_b + + @pytest.mark.asyncio + async def test_loop_detection_in_one_session_does_not_affect_another( + self, processor + ): + """Test that detecting a loop in one session doesn't trigger in another.""" + # Session A: Send content that will trigger loop detection + loop_content = "IIIIIIII" # 8 I's + for _ in range(15): # Send enough to trigger detection + content = StreamingContent( + content=loop_content, metadata={"session_id": "session-a"} + ) + result = await processor.process(content) + if result.is_cancellation: + break + + # Session B: Send normal content - should NOT be affected by session A's loop + normal_content = "This is normal text without any loops." + content = StreamingContent( + content=normal_content, metadata={"session_id": "session-b"} + ) + result = await processor.process(content) + + # Session B should process normally, not be cancelled + assert not result.is_cancellation + assert result.content == normal_content + + @pytest.mark.asyncio + async def test_session_cleanup_removes_detector(self, processor): + """Test that detector is cleaned up when session completes.""" + session_id = "test-session" + + # Send some content + content = StreamingContent( + content="test content", metadata={"session_id": session_id} + ) + await processor.process(content) + + # Verify detector was created + assert session_id in processor._session_detectors + + # Send done marker + done_content = StreamingContent( + content="", is_done=True, metadata={"session_id": session_id} + ) + await processor.process(done_content) + + # Verify detector was cleaned up + assert session_id not in processor._session_detectors + + @pytest.mark.asyncio + async def test_concurrent_sessions_maintain_isolation(self, processor): + """Test that multiple concurrent sessions maintain independent state.""" + sessions = ["session-1", "session-2", "session-3"] + + # Send different content to each session concurrently + for i, session_id in enumerate(sessions): + # Each session gets different repeated character + char = chr(ord("A") + i) # A, B, C + content = StreamingContent( + content=char * 50, metadata={"session_id": session_id} + ) + await processor.process(content) + + # Verify each session has its own detector with its own content + for i, session_id in enumerate(sessions): + detector = processor._session_detectors[session_id] + history = detector.short_detector.stream_content_history + expected_char = chr(ord("A") + i) + + # Each session should only have its own character + assert expected_char in history + # And should not have other sessions' characters + for j, other_session in enumerate(sessions): # noqa: B007 + if i != j: + other_char = chr(ord("A") + j) + assert other_char not in history + + @pytest.mark.asyncio + async def test_same_session_reuses_detector(self, processor): + """Test that the same session reuses its detector instance.""" + session_id = "test-session" + + # Send first chunk + content1 = StreamingContent( + content="first chunk", 
metadata={"session_id": session_id} + ) + await processor.process(content1) + detector1 = processor._session_detectors[session_id] + + # Send second chunk + content2 = StreamingContent( + content="second chunk", metadata={"session_id": session_id} + ) + await processor.process(content2) + detector2 = processor._session_detectors[session_id] + + # Should be the same detector instance + assert detector1 is detector2 + + # And should have accumulated both chunks + history = detector1.short_detector.stream_content_history + assert "first chunk" in history + assert "second chunk" in history + + @pytest.mark.asyncio + async def test_session_without_id_uses_generated_stream_id(self, processor): + """Test that content without session_id generates a unique stream_id.""" + # Send content without session_id + content = StreamingContent(content="test content", metadata={}) + await processor.process(content) + + # Should create detector with a generated stream_id + assert len(processor._session_detectors) == 1 + # The generated stream_id should be a UUID hex string (32 characters) + session_key = next(iter(processor._session_detectors.keys())) + assert len(session_key) == 32 # UUID hex without dashes + + @pytest.mark.asyncio + async def test_stream_id_fallback_when_no_session_id(self, processor): + """Test that stream_id is used as fallback when session_id is not present.""" + stream_id = "stream-123" + + # Send content with stream_id but no session_id + content = StreamingContent( + content="test content", metadata={"stream_id": stream_id} + ) + await processor.process(content) + + # Should create detector using stream_id + assert stream_id in processor._session_detectors + + @pytest.mark.asyncio + async def test_multiple_cleanup_calls_are_safe(self, processor): + """Test that cleaning up the same session multiple times doesn't cause errors.""" + session_id = "test-session" + + # Create a detector + content = StreamingContent(content="test", metadata={"session_id": session_id}) + await processor.process(content) + assert session_id in processor._session_detectors + + # Clean up multiple times + processor.cleanup_session(session_id) + processor.cleanup_session(session_id) # Should not raise error + processor.cleanup_session(session_id) # Should not raise error + + assert session_id not in processor._session_detectors + + @pytest.mark.asyncio + async def test_detector_state_persists_within_session(self, processor): + """Test that detector state accumulates correctly within a single session.""" + session_id = "test-session" + + # Send multiple chunks of DIFFERENT content to avoid triggering loop detection + for i in range(10): + content = StreamingContent( + content=f"Chunk {i} with unique content here.", + metadata={"session_id": session_id}, + ) + await processor.process(content) + + # Verify that content accumulated in the detector + detector = processor._session_detectors[session_id] + history = detector.short_detector.stream_content_history + + # Should have accumulated all chunks + assert "Chunk 0" in history + assert "Chunk 9" in history + assert len(history) > 200 # Should have accumulated substantial content + + @pytest.mark.asyncio + async def test_factory_creates_fresh_detectors(self, detector_factory): + """Test that the factory function creates independent detector instances.""" + detector1 = detector_factory() + detector2 = detector_factory() + + # Should be different instances + assert detector1 is not detector2 + + # Should have independent state + detector1.process_chunk("test1") + 
detector2.process_chunk("test2") + + history1 = detector1.short_detector.stream_content_history + history2 = detector2.short_detector.stream_content_history + + assert "test1" in history1 + assert "test1" not in history2 + assert "test2" in history2 + assert "test2" not in history1 + + +class TestLoopDetectionRegressionPrevention: + """Tests to prevent regression to shared detector state.""" + + @pytest.mark.asyncio + async def test_processor_does_not_share_single_detector_instance(self): + """ + REGRESSION TEST: Ensure processor doesn't use a single shared detector. + + This test would FAIL if someone reverts to the old implementation where + a single detector instance was shared across all sessions. + """ + + # Create processor with factory + def create_detector(): + return HybridLoopDetector() + + processor = LoopDetectionProcessor(loop_detector_factory=create_detector) + + # Process content for two sessions + content_a = StreamingContent( + content="AAAA", metadata={"session_id": "session-a"} + ) + content_b = StreamingContent( + content="BBBB", metadata={"session_id": "session-b"} + ) + + await processor.process(content_a) + await processor.process(content_b) + + # CRITICAL: Must have separate detector instances + detector_a = processor._session_detectors["session-a"] + detector_b = processor._session_detectors["session-b"] + + # This assertion would FAIL if using shared detector + assert detector_a is not detector_b, ( + "REGRESSION: Detector instances are shared between sessions! " + "Each session must have its own isolated detector instance." + ) + + @pytest.mark.asyncio + async def test_detector_state_is_not_global(self): + """ + REGRESSION TEST: Ensure detector state is not stored globally. + + This test would FAIL if detector state was stored in a class variable + or module-level variable instead of per-instance. + """ + + def create_detector(): + return HybridLoopDetector() + + processor = LoopDetectionProcessor(loop_detector_factory=create_detector) + + # Session A accumulates state + for _ in range(5): + content = StreamingContent( + content="AAAA", metadata={"session_id": "session-a"} + ) + await processor.process(content) + + # Session B should start fresh + content_b = StreamingContent( + content="BBBB", metadata={"session_id": "session-b"} + ) + await processor.process(content_b) + + # Get histories + history_a = processor._session_detectors[ + "session-a" + ].short_detector.stream_content_history + history_b = processor._session_detectors[ + "session-b" + ].short_detector.stream_content_history + + # This assertion would FAIL if state was global + assert "A" not in history_b, ( + "REGRESSION: Session B's detector contains Session A's content! " + "Detector state is being shared globally instead of per-session." + ) + + assert "B" not in history_a, ( + "REGRESSION: Session A's detector contains Session B's content! " + "Detector state is being shared globally instead of per-session." 
+ ) + + +class TestLoopDetectionProcessorMemoryManagement: + """Tests covering memory-leak prevention mechanics in the processor.""" + + @pytest.mark.asyncio + async def test_stale_sessions_are_pruned_after_ttl(self) -> None: + fake_time = _FakeTime() + processor = LoopDetectionProcessor( + loop_detector_factory=_NoopLoopDetector, + session_ttl_seconds=5, + time_provider=fake_time.now, + ) + + async def send(session_id: str) -> None: + content = StreamingContent( + content="chunk", metadata={"session_id": session_id} + ) + await processor.process(content) + + await send("session-a") + fake_time.advance(1) + await send("session-b") + + assert set(processor._session_detectors) == {"session-a", "session-b"} + + # Advance time past the TTL without touching session-a + fake_time.advance(6) + await send("session-b") + + assert "session-a" not in processor._session_detectors + assert "session-a" not in processor._session_last_activity + assert "session-b" in processor._session_detectors + + @pytest.mark.asyncio + async def test_zero_ttl_disables_cleanup(self) -> None: + fake_time = _FakeTime() + processor = LoopDetectionProcessor( + loop_detector_factory=_NoopLoopDetector, + session_ttl_seconds=0, + time_provider=fake_time.now, + ) + + content = StreamingContent( + content="chunk", metadata={"session_id": "session-x"} + ) + await processor.process(content) + + # Even after a large time jump, the session should still be present + fake_time.advance(10_000) + await processor.process( + StreamingContent(content="chunk", metadata={"session_id": "session-x"}) + ) + + assert set(processor._session_detectors) == {"session-x"}
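
For reference, here is a minimal usage sketch of the TTL behavior these tests exercise. It is illustrative only: it assumes the module paths introduced by this patch (`src.core.domain.streaming_response_processor`, `src.core.ports.streaming`, `src.loop_detection.hybrid_detector`) and, like the tests above, inspects the private `_session_detectors` mapping; the manual `clock` dict plays the role of the `_FakeTime` helper.

```python
import asyncio

from src.core.domain.streaming_response_processor import LoopDetectionProcessor
from src.core.ports.streaming import StreamingContent
from src.loop_detection.hybrid_detector import HybridLoopDetector


async def main() -> None:
    clock = {"now": 0.0}  # manual clock so TTL expiry can be driven deterministically

    processor = LoopDetectionProcessor(
        loop_detector_factory=HybridLoopDetector,
        session_ttl_seconds=300,
        time_provider=lambda: clock["now"],
    )

    # The first chunk lazily creates a detector for "session-a" at t=0.
    await processor.process(
        StreamingContent(content="hello", metadata={"session_id": "session-a"})
    )

    # After more than 300 idle seconds, the next chunk (any session) sweeps it.
    clock["now"] = 301.0
    await processor.process(
        StreamingContent(content="world", metadata={"session_id": "session-b"})
    )

    assert "session-a" not in processor._session_detectors
    assert "session-b" in processor._session_detectors


asyncio.run(main())
```

Because pruning happens on access rather than on a timer, the sweep cost is bounded by the number of tracked sessions and no event-loop task has to be cancelled on shutdown.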