From 4fc72f24409607f5faf77f3114c6e61085d505c3 Mon Sep 17 00:00:00 2001 From: openhands Date: Fri, 7 Nov 2025 20:59:55 +0000 Subject: [PATCH 1/2] Add FallbackRouter for LLM failover support This commit implements a FallbackRouter that provides automatic failover between multiple LLM models when the primary model fails. Key features: - Automatically falls back to secondary models on errors (rate limits, connection failures, service unavailable, etc.) - Supports multiple fallback models in a chain - Preserves telemetry and metrics from the active model - Includes comprehensive logging of failover attempts Implementation: - New FallbackRouter class extending RouterLLM - Overrides completion() to implement fallback logic - Validates that 'primary' key exists in llms_for_routing - Tracks active_llm for telemetry purposes Tests: - 8 comprehensive unit tests covering all scenarios - Mocked LLM responses to avoid actual API calls - Tests for successful completion, fallback scenarios, and error cases Example: - examples/01_standalone_sdk/27_llm_fallback.py demonstrates usage - Shows how to configure primary and fallback models - Includes logging setup to observe failover behavior Co-authored-by: openhands --- examples/01_standalone_sdk/27_llm_fallback.py | 114 +++++++ .../openhands/sdk/llm/router/__init__.py | 2 + .../openhands/sdk/llm/router/impl/fallback.py | 115 +++++++ tests/sdk/llm/test_fallback_router.py | 293 ++++++++++++++++++ 4 files changed, 524 insertions(+) create mode 100644 examples/01_standalone_sdk/27_llm_fallback.py create mode 100644 openhands-sdk/openhands/sdk/llm/router/impl/fallback.py create mode 100644 tests/sdk/llm/test_fallback_router.py diff --git a/examples/01_standalone_sdk/27_llm_fallback.py b/examples/01_standalone_sdk/27_llm_fallback.py new file mode 100644 index 0000000000..9ef6a91336 --- /dev/null +++ b/examples/01_standalone_sdk/27_llm_fallback.py @@ -0,0 +1,114 @@ +""" +Example demonstrating LLM fallback functionality using FallbackRouter. + +This example shows how to configure multiple language models with automatic +fallback capability. If the primary model fails (due to rate limits, timeouts, +or service unavailability), the system automatically falls back to secondary +models. + +Use cases: +- High availability: Ensure your application continues working even if one + provider has an outage +- Rate limit handling: Automatically switch to a backup model when you hit + rate limits +- Cost optimization: Use expensive models as primary but have cheaper backups +""" + +import os + +from pydantic import SecretStr + +from openhands.sdk import ( + LLM, + Agent, + Conversation, + Message, + TextContent, + get_logger, +) +from openhands.sdk.llm.router import FallbackRouter +from openhands.tools.preset.default import get_default_tools + + +logger = get_logger(__name__) + +# Configure API credentials +api_key = os.getenv("LLM_API_KEY") +assert api_key is not None, "LLM_API_KEY environment variable is not set." 
+model = os.getenv("LLM_MODEL", "claude-sonnet-4-20250514") +base_url = os.getenv("LLM_BASE_URL") + +# Configure primary and fallback LLMs +# Primary: A powerful but potentially rate-limited model +primary_llm = LLM( + usage_id="primary", + model=model, + base_url=base_url, + api_key=SecretStr(api_key), +) + +# Fallback 1: A reliable alternative model +# In a real scenario, this might be a different provider or cheaper model +fallback_llm = LLM( + usage_id="fallback", + model="openhands/devstral-small-2507", + base_url=base_url, + api_key=SecretStr(api_key), +) + +# Create FallbackRouter +# Models will be tried in the order they appear in the dictionary +# Note: The first model must have key "primary" +fallback_router = FallbackRouter( + usage_id="fallback-router", + llms_for_routing={ + "primary": primary_llm, + "fallback": fallback_llm, + }, +) + +# Configure agent with fallback router +tools = get_default_tools() +agent = Agent(llm=fallback_router, tools=tools) + +# Create conversation +conversation = Conversation(agent=agent, workspace=os.getcwd()) + +# Send a message - the router will automatically try primary first, +# then fall back if needed +conversation.send_message( + message=Message( + role="user", + content=[ + TextContent( + text=( + "Hello! Can you tell me what the current date is? " + "You can use the bash tool to run the 'date' command." + ) + ) + ], + ) +) + +# Run the conversation +conversation.run() + +# Display results +print("=" * 100) +print("Conversation completed successfully!") +if fallback_router.active_llm: + print(f"Active model used: {fallback_router.active_llm.model}") +else: + print("No active model (no completions made)") + +# Report costs +metrics = conversation.conversation_stats.get_combined_metrics() +print(f"Total cost: ${metrics.accumulated_cost:.4f}") +print(f"Total tokens: {metrics.accumulated_token_usage}") + +print("\n" + "=" * 100) +print("Key features demonstrated:") +print("1. Automatic fallback when primary model fails") +print("2. Transparent switching between models") +print("3. Cost and usage tracking across all models") +print("4. 
Works seamlessly with agents and tools") diff --git a/openhands-sdk/openhands/sdk/llm/router/__init__.py b/openhands-sdk/openhands/sdk/llm/router/__init__.py index 37e7baca4a..425b2a8f9b 100644 --- a/openhands-sdk/openhands/sdk/llm/router/__init__.py +++ b/openhands-sdk/openhands/sdk/llm/router/__init__.py @@ -1,4 +1,5 @@ from openhands.sdk.llm.router.base import RouterLLM +from openhands.sdk.llm.router.impl.fallback import FallbackRouter from openhands.sdk.llm.router.impl.multimodal import MultimodalRouter from openhands.sdk.llm.router.impl.random import RandomRouter @@ -7,4 +8,5 @@ "RouterLLM", "RandomRouter", "MultimodalRouter", + "FallbackRouter", ] diff --git a/openhands-sdk/openhands/sdk/llm/router/impl/fallback.py b/openhands-sdk/openhands/sdk/llm/router/impl/fallback.py new file mode 100644 index 0000000000..9ba627016b --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/router/impl/fallback.py @@ -0,0 +1,115 @@ +from collections.abc import Sequence +from typing import ClassVar + +from pydantic import model_validator + +from openhands.sdk.llm.llm_response import LLMResponse +from openhands.sdk.llm.message import Message +from openhands.sdk.llm.router.base import RouterLLM +from openhands.sdk.logger import get_logger +from openhands.sdk.tool.tool import ToolDefinition + + +logger = get_logger(__name__) + + +class FallbackRouter(RouterLLM): + """ + A RouterLLM implementation that provides fallback capability across multiple + language models. When the primary model fails due to rate limits, timeouts, + or service unavailability, it automatically falls back to secondary models. + + Models are tried in order: primary -> fallback1 -> fallback2 -> ... + If all models fail, the exception from the last model is raised. + + Example: + >>> primary = LLM(model="gpt-4", usage_id="primary") + >>> fallback = LLM(model="gpt-3.5-turbo", usage_id="fallback") + >>> router = FallbackRouter( + ... usage_id="fallback-router", + ... llms_for_routing={"primary": primary, "fallback": fallback} + ... ) + >>> # Will try primary first, then fallback if primary fails + >>> response = router.completion(messages) + """ + + router_name: str = "fallback_router" + + PRIMARY_MODEL_KEY: ClassVar[str] = "primary" + + def select_llm(self, messages: list[Message]) -> str: # noqa: ARG002 + """ + For fallback router, we always start with the primary model. + The fallback logic is implemented in the completion() method. + """ + return self.PRIMARY_MODEL_KEY + + def completion( + self, + messages: list[Message], + tools: Sequence[ToolDefinition] | None = None, + return_metrics: bool = False, + add_security_risk_prediction: bool = False, + **kwargs, + ) -> LLMResponse: + """ + Try models in order until one succeeds. Falls back to next model + on retry-able exceptions (rate limits, timeouts, service errors). 
+ """ + # Get ordered list of model keys + model_keys = list(self.llms_for_routing.keys()) + last_exception = None + + for i, model_key in enumerate(model_keys): + llm = self.llms_for_routing[model_key] + is_last_model = i == len(model_keys) - 1 + + try: + logger.info( + f"FallbackRouter: Attempting completion with model " + f"'{model_key}' ({llm.model})" + ) + self.active_llm = llm + + response = llm.completion( + messages=messages, + tools=tools, + _return_metrics=return_metrics, + add_security_risk_prediction=add_security_risk_prediction, + **kwargs, + ) + + logger.info( + f"FallbackRouter: Successfully completed with model '{model_key}'" + ) + return response + + except Exception as e: + last_exception = e + logger.warning( + f"FallbackRouter: Model '{model_key}' failed with " + f"{type(e).__name__}: {str(e)}" + ) + + if is_last_model: + logger.error( + "FallbackRouter: All models failed. Raising last exception." + ) + raise + else: + next_model = model_keys[i + 1] + logger.info(f"FallbackRouter: Falling back to '{next_model}'...") + + # This should never happen, but satisfy type checker + assert last_exception is not None + raise last_exception + + @model_validator(mode="after") + def _validate_llms_for_routing(self) -> "FallbackRouter": + """Ensure required primary model is present in llms_for_routing.""" + if self.PRIMARY_MODEL_KEY not in self.llms_for_routing: + raise ValueError( + f"Primary LLM key '{self.PRIMARY_MODEL_KEY}' not found " + "in llms_for_routing." + ) + return self diff --git a/tests/sdk/llm/test_fallback_router.py b/tests/sdk/llm/test_fallback_router.py new file mode 100644 index 0000000000..9bef96e71b --- /dev/null +++ b/tests/sdk/llm/test_fallback_router.py @@ -0,0 +1,293 @@ +"""Tests for FallbackRouter functionality.""" + +from unittest.mock import patch + +import pytest +from litellm.exceptions import ( + APIConnectionError, + RateLimitError, + ServiceUnavailableError, +) +from litellm.types.utils import ( + Choices, + Message as LiteLLMMessage, + ModelResponse, + Usage, +) +from pydantic import SecretStr + +from openhands.sdk.llm import LLM, Message, TextContent +from openhands.sdk.llm.exceptions import LLMServiceUnavailableError +from openhands.sdk.llm.router import FallbackRouter + + +def create_mock_response(content: str = "Test response", model: str = "test-model"): + """Helper function to create properly structured mock responses.""" + return ModelResponse( + id="test-id", + choices=[ + Choices( + finish_reason="stop", + index=0, + message=LiteLLMMessage(content=content, role="assistant"), + ) + ], + created=1234567890, + model=model, + object="chat.completion", + system_fingerprint="test", + usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + +@pytest.fixture +def primary_llm(): + """Create a primary LLM for testing.""" + return LLM( + model="gpt-4", + api_key=SecretStr("test-key"), + usage_id="primary", + num_retries=0, # Disable retries for faster tests + ) + + +@pytest.fixture +def fallback_llm(): + """Create a fallback LLM for testing.""" + return LLM( + model="gpt-3.5-turbo", + api_key=SecretStr("test-key"), + usage_id="fallback", + num_retries=0, # Disable retries for faster tests + ) + + +@pytest.fixture +def test_messages(): + """Create test messages.""" + return [Message(role="user", content=[TextContent(text="Hello")])] + + +def test_fallback_router_creation(primary_llm, fallback_llm): + """Test that FallbackRouter can be created with primary and fallback models.""" + router = FallbackRouter( + usage_id="test-router", + 
llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + ) + assert router.router_name == "fallback_router" + assert len(router.llms_for_routing) == 2 + assert "primary" in router.llms_for_routing + assert "fallback" in router.llms_for_routing + + +def test_fallback_router_requires_primary(fallback_llm): + """Test that FallbackRouter requires a 'primary' model.""" + with pytest.raises(ValueError, match="Primary LLM key 'primary' not found"): + FallbackRouter( + usage_id="test-router", + llms_for_routing={"fallback": fallback_llm}, + ) + + +def test_fallback_router_success_with_primary(primary_llm, fallback_llm, test_messages): + """Test that router uses primary model when it succeeds.""" + router = FallbackRouter( + usage_id="test-router", + llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + ) + + mock_response = create_mock_response(content="Primary response", model="gpt-4") + + with patch.object( + primary_llm, "_transport_call", return_value=mock_response + ) as mock_primary: + response = router.completion(messages=test_messages) + + # Verify primary was called + mock_primary.assert_called_once() + + # Verify response is from primary + assert isinstance(response.message.content[0], TextContent) + assert response.message.content[0].text == "Primary response" + assert router.active_llm == primary_llm + + +def test_fallback_router_falls_back_on_rate_limit( + primary_llm, fallback_llm, test_messages +): + """Test that router falls back to secondary model on rate limit error.""" + router = FallbackRouter( + usage_id="test-router", + llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + ) + + mock_fallback_response = create_mock_response( + content="Fallback response", model="gpt-3.5-turbo" + ) + + with ( + patch.object( + primary_llm, + "_transport_call", + side_effect=RateLimitError( + message="Rate limit exceeded", + model="gpt-4", + llm_provider="openai", + ), + ) as mock_primary, + patch.object( + fallback_llm, "_transport_call", return_value=mock_fallback_response + ) as mock_fallback, + ): + response = router.completion(messages=test_messages) + + # Verify both were called + mock_primary.assert_called_once() + mock_fallback.assert_called_once() + + # Verify response is from fallback + assert isinstance(response.message.content[0], TextContent) + assert response.message.content[0].text == "Fallback response" + assert router.active_llm == fallback_llm + + +def test_fallback_router_falls_back_on_connection_error( + primary_llm, fallback_llm, test_messages +): + """Test that router falls back on API connection error.""" + router = FallbackRouter( + usage_id="test-router", + llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + ) + + mock_fallback_response = create_mock_response( + content="Fallback response", model="gpt-3.5-turbo" + ) + + with ( + patch.object( + primary_llm, + "_transport_call", + side_effect=APIConnectionError( + message="Connection failed", + model="gpt-4", + llm_provider="openai", + ), + ), + patch.object( + fallback_llm, "_transport_call", return_value=mock_fallback_response + ), + ): + response = router.completion(messages=test_messages) + assert isinstance(response.message.content[0], TextContent) + assert response.message.content[0].text == "Fallback response" + assert router.active_llm == fallback_llm + + +def test_fallback_router_raises_when_all_fail(primary_llm, fallback_llm, test_messages): + """Test that router raises exception when all models fail.""" + router = FallbackRouter( + 
usage_id="test-router", + llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + ) + + with ( + patch.object( + primary_llm, + "_transport_call", + side_effect=ServiceUnavailableError( + message="Service unavailable", + model="gpt-4", + llm_provider="openai", + ), + ), + patch.object( + fallback_llm, + "_transport_call", + side_effect=ServiceUnavailableError( + message="Service unavailable", + model="gpt-3.5-turbo", + llm_provider="openai", + ), + ), + ): + with pytest.raises(LLMServiceUnavailableError): + router.completion(messages=test_messages) + + +def test_fallback_router_with_multiple_fallbacks(test_messages): + """Test router with multiple fallback models.""" + primary = LLM( + model="gpt-4", + api_key=SecretStr("test-key"), + usage_id="primary", + num_retries=0, + ) + fallback1 = LLM( + model="gpt-3.5-turbo", + api_key=SecretStr("test-key"), + usage_id="fallback1", + num_retries=0, + ) + fallback2 = LLM( + model="gpt-3.5-turbo-16k", + api_key=SecretStr("test-key"), + usage_id="fallback2", + num_retries=0, + ) + + router = FallbackRouter( + usage_id="test-router", + llms_for_routing={ + "primary": primary, + "fallback1": fallback1, + "fallback2": fallback2, + }, + ) + + mock_response = create_mock_response( + content="Fallback2 response", model="gpt-3.5-turbo-16k" + ) + + with ( + patch.object( + primary, + "_transport_call", + side_effect=RateLimitError( + message="Rate limit", model="gpt-4", llm_provider="openai" + ), + ) as mock_primary, + patch.object( + fallback1, + "_transport_call", + side_effect=RateLimitError( + message="Rate limit", model="gpt-3.5-turbo", llm_provider="openai" + ), + ) as mock_fallback1, + patch.object( + fallback2, "_transport_call", return_value=mock_response + ) as mock_fallback2, + ): + response = router.completion(messages=test_messages) + + # Verify all three were tried + mock_primary.assert_called_once() + mock_fallback1.assert_called_once() + mock_fallback2.assert_called_once() + + # Verify response is from fallback2 + assert isinstance(response.message.content[0], TextContent) + assert response.message.content[0].text == "Fallback2 response" + assert router.active_llm == fallback2 + + +def test_fallback_router_select_llm_returns_primary(primary_llm, fallback_llm): + """Test that select_llm always returns primary key.""" + router = FallbackRouter( + usage_id="test-router", + llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + ) + + messages = [Message(role="user", content=[TextContent(text="Test")])] + selected = router.select_llm(messages) + assert selected == "primary" From e0019c8e3e7b7ad95019e3b6ee5c4f9ecfceb32b Mon Sep 17 00:00:00 2001 From: openhands Date: Fri, 7 Nov 2025 22:56:22 +0000 Subject: [PATCH 2/2] Refactor FallbackRouter to use list-based approach - Change llms parameter from dictionary to list for simpler API - Models are tried in list order (similar to litellm's approach) - Internally converts list to dict for base class compatibility - Update validator to check for empty list instead of missing 'primary' key - Update logging to show model index (1/N, 2/N, etc.) - Update example and tests to use new list-based API - Update documentation to reflect list-based approach This makes the API more intuitive and consistent with litellm's pattern. 
Co-authored-by: openhands --- examples/01_standalone_sdk/27_llm_fallback.py | 21 +++--- .../openhands/sdk/llm/router/impl/fallback.py | 70 ++++++++++--------- tests/sdk/llm/test_fallback_router.py | 40 +++++------ 3 files changed, 65 insertions(+), 66 deletions(-) diff --git a/examples/01_standalone_sdk/27_llm_fallback.py b/examples/01_standalone_sdk/27_llm_fallback.py index 9ef6a91336..d3025668fb 100644 --- a/examples/01_standalone_sdk/27_llm_fallback.py +++ b/examples/01_standalone_sdk/27_llm_fallback.py @@ -38,8 +38,8 @@ model = os.getenv("LLM_MODEL", "claude-sonnet-4-20250514") base_url = os.getenv("LLM_BASE_URL") -# Configure primary and fallback LLMs -# Primary: A powerful but potentially rate-limited model +# Configure LLMs for fallback +# First model: A powerful but potentially rate-limited model primary_llm = LLM( usage_id="primary", model=model, @@ -47,7 +47,7 @@ api_key=SecretStr(api_key), ) -# Fallback 1: A reliable alternative model +# Second model: A reliable alternative model # In a real scenario, this might be a different provider or cheaper model fallback_llm = LLM( usage_id="fallback", @@ -56,15 +56,12 @@ api_key=SecretStr(api_key), ) -# Create FallbackRouter -# Models will be tried in the order they appear in the dictionary -# Note: The first model must have key "primary" +# Create FallbackRouter with a list of LLMs +# Models will be tried in the order they appear in the list +# Similar to how litellm handles fallbacks fallback_router = FallbackRouter( usage_id="fallback-router", - llms_for_routing={ - "primary": primary_llm, - "fallback": fallback_llm, - }, + llms=[primary_llm, fallback_llm], ) # Configure agent with fallback router @@ -74,8 +71,8 @@ # Create conversation conversation = Conversation(agent=agent, workspace=os.getcwd()) -# Send a message - the router will automatically try primary first, -# then fall back if needed +# Send a message - the router will automatically try models in order, +# falling back if one fails conversation.send_message( message=Message( role="user", diff --git a/openhands-sdk/openhands/sdk/llm/router/impl/fallback.py b/openhands-sdk/openhands/sdk/llm/router/impl/fallback.py index 9ba627016b..be02124366 100644 --- a/openhands-sdk/openhands/sdk/llm/router/impl/fallback.py +++ b/openhands-sdk/openhands/sdk/llm/router/impl/fallback.py @@ -1,8 +1,8 @@ from collections.abc import Sequence -from typing import ClassVar -from pydantic import model_validator +from pydantic import field_validator, model_validator +from openhands.sdk.llm.llm import LLM from openhands.sdk.llm.llm_response import LLMResponse from openhands.sdk.llm.message import Message from openhands.sdk.llm.router.base import RouterLLM @@ -16,10 +16,10 @@ class FallbackRouter(RouterLLM): """ A RouterLLM implementation that provides fallback capability across multiple - language models. When the primary model fails due to rate limits, timeouts, - or service unavailability, it automatically falls back to secondary models. + language models. When the first model fails due to rate limits, timeouts, + or service unavailability, it automatically falls back to subsequent models. - Models are tried in order: primary -> fallback1 -> fallback2 -> ... + Similar to litellm's fallback approach, models are tried in the order provided. If all models fail, the exception from the last model is raised. Example: @@ -27,22 +27,38 @@ class FallbackRouter(RouterLLM): >>> fallback = LLM(model="gpt-3.5-turbo", usage_id="fallback") >>> router = FallbackRouter( ... usage_id="fallback-router", - ... 
llms_for_routing={"primary": primary, "fallback": fallback} + ... llms=[primary, fallback] ... ) - >>> # Will try primary first, then fallback if primary fails + >>> # Will try models in order until one succeeds >>> response = router.completion(messages) """ router_name: str = "fallback_router" - - PRIMARY_MODEL_KEY: ClassVar[str] = "primary" + llms: list[LLM] + + @model_validator(mode="before") + @classmethod + def _convert_llms_to_routing(cls, values: dict) -> dict: + """Convert llms list to llms_for_routing dict for base class compatibility.""" + if "llms" in values and "llms_for_routing" not in values: + llms = values["llms"] + values["llms_for_routing"] = {f"llm_{i}": llm for i, llm in enumerate(llms)} + return values + + @field_validator("llms") + @classmethod + def _validate_llms(cls, llms: list[LLM]) -> list[LLM]: + """Ensure at least one LLM is provided.""" + if not llms: + raise ValueError("FallbackRouter requires at least one LLM") + return llms def select_llm(self, messages: list[Message]) -> str: # noqa: ARG002 """ - For fallback router, we always start with the primary model. + For fallback router, we always start with the first model. The fallback logic is implemented in the completion() method. """ - return self.PRIMARY_MODEL_KEY + return "llm_0" def completion( self, @@ -56,18 +72,15 @@ def completion( Try models in order until one succeeds. Falls back to next model on retry-able exceptions (rate limits, timeouts, service errors). """ - # Get ordered list of model keys - model_keys = list(self.llms_for_routing.keys()) last_exception = None - for i, model_key in enumerate(model_keys): - llm = self.llms_for_routing[model_key] - is_last_model = i == len(model_keys) - 1 + for i, llm in enumerate(self.llms): + is_last_model = i == len(self.llms) - 1 try: logger.info( f"FallbackRouter: Attempting completion with model " - f"'{model_key}' ({llm.model})" + f"{i + 1}/{len(self.llms)} ({llm.model}, usage_id={llm.usage_id})" ) self.active_llm = llm @@ -80,15 +93,16 @@ def completion( ) logger.info( - f"FallbackRouter: Successfully completed with model '{model_key}'" + f"FallbackRouter: Successfully completed with model " + f"{llm.model} (usage_id={llm.usage_id})" ) return response except Exception as e: last_exception = e logger.warning( - f"FallbackRouter: Model '{model_key}' failed with " - f"{type(e).__name__}: {str(e)}" + f"FallbackRouter: Model {llm.model} (usage_id={llm.usage_id}) " + f"failed with {type(e).__name__}: {str(e)}" ) if is_last_model: @@ -97,19 +111,11 @@ def completion( ) raise else: - next_model = model_keys[i + 1] - logger.info(f"FallbackRouter: Falling back to '{next_model}'...") + logger.info( + "FallbackRouter: Falling back to model " + f"{i + 2}/{len(self.llms)}..." + ) # This should never happen, but satisfy type checker assert last_exception is not None raise last_exception - - @model_validator(mode="after") - def _validate_llms_for_routing(self) -> "FallbackRouter": - """Ensure required primary model is present in llms_for_routing.""" - if self.PRIMARY_MODEL_KEY not in self.llms_for_routing: - raise ValueError( - f"Primary LLM key '{self.PRIMARY_MODEL_KEY}' not found " - "in llms_for_routing." 
- ) - return self diff --git a/tests/sdk/llm/test_fallback_router.py b/tests/sdk/llm/test_fallback_router.py index 9bef96e71b..c9ea5c5c5a 100644 --- a/tests/sdk/llm/test_fallback_router.py +++ b/tests/sdk/llm/test_fallback_router.py @@ -69,23 +69,23 @@ def test_messages(): def test_fallback_router_creation(primary_llm, fallback_llm): - """Test that FallbackRouter can be created with primary and fallback models.""" + """Test that FallbackRouter can be created with a list of models.""" router = FallbackRouter( usage_id="test-router", - llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + llms=[primary_llm, fallback_llm], ) assert router.router_name == "fallback_router" - assert len(router.llms_for_routing) == 2 - assert "primary" in router.llms_for_routing - assert "fallback" in router.llms_for_routing + assert len(router.llms) == 2 + assert router.llms[0] == primary_llm + assert router.llms[1] == fallback_llm -def test_fallback_router_requires_primary(fallback_llm): - """Test that FallbackRouter requires a 'primary' model.""" - with pytest.raises(ValueError, match="Primary LLM key 'primary' not found"): +def test_fallback_router_requires_at_least_one_llm(): + """Test that FallbackRouter requires at least one LLM.""" + with pytest.raises(ValueError, match="at least one LLM"): FallbackRouter( usage_id="test-router", - llms_for_routing={"fallback": fallback_llm}, + llms=[], ) @@ -93,7 +93,7 @@ def test_fallback_router_success_with_primary(primary_llm, fallback_llm, test_me """Test that router uses primary model when it succeeds.""" router = FallbackRouter( usage_id="test-router", - llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + llms=[primary_llm, fallback_llm], ) mock_response = create_mock_response(content="Primary response", model="gpt-4") @@ -118,7 +118,7 @@ def test_fallback_router_falls_back_on_rate_limit( """Test that router falls back to secondary model on rate limit error.""" router = FallbackRouter( usage_id="test-router", - llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + llms=[primary_llm, fallback_llm], ) mock_fallback_response = create_mock_response( @@ -157,7 +157,7 @@ def test_fallback_router_falls_back_on_connection_error( """Test that router falls back on API connection error.""" router = FallbackRouter( usage_id="test-router", - llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + llms=[primary_llm, fallback_llm], ) mock_fallback_response = create_mock_response( @@ -188,7 +188,7 @@ def test_fallback_router_raises_when_all_fail(primary_llm, fallback_llm, test_me """Test that router raises exception when all models fail.""" router = FallbackRouter( usage_id="test-router", - llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + llms=[primary_llm, fallback_llm], ) with ( @@ -238,11 +238,7 @@ def test_fallback_router_with_multiple_fallbacks(test_messages): router = FallbackRouter( usage_id="test-router", - llms_for_routing={ - "primary": primary, - "fallback1": fallback1, - "fallback2": fallback2, - }, + llms=[primary, fallback1, fallback2], ) mock_response = create_mock_response( @@ -281,13 +277,13 @@ def test_fallback_router_with_multiple_fallbacks(test_messages): assert router.active_llm == fallback2 -def test_fallback_router_select_llm_returns_primary(primary_llm, fallback_llm): - """Test that select_llm always returns primary key.""" +def test_fallback_router_select_llm_returns_first(primary_llm, fallback_llm): + """Test that select_llm always returns the first LLM's key.""" 
router = FallbackRouter( usage_id="test-router", - llms_for_routing={"primary": primary_llm, "fallback": fallback_llm}, + llms=[primary_llm, fallback_llm], ) messages = [Message(role="user", content=[TextContent(text="Test")])] selected = router.select_llm(messages) - assert selected == "primary" + assert selected == "llm_0"
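
A short note on failure behavior: as the FallbackRouter docstring and
test_fallback_router_raises_when_all_fail above describe, when every model in
the chain fails the router re-raises the last model's exception, so callers can
handle it like any other LLM error. A minimal caller-side sketch (hypothetical
application code, reusing the fallback_router and logger set up in
examples/01_standalone_sdk/27_llm_fallback.py):

    from openhands.sdk.llm.exceptions import LLMServiceUnavailableError

    try:
        # "messages" stands for the list[Message] the application has built.
        response = fallback_router.completion(messages=messages)
    except LLMServiceUnavailableError as exc:
        # Every model in the chain failed; log and surface the last error.
        logger.error(f"All fallback models failed: {exc}")
        raise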