diff --git a/examples/01_standalone_sdk/27_llm_fallback.py b/examples/01_standalone_sdk/27_llm_fallback.py
new file mode 100644
index 0000000000..d3025668fb
--- /dev/null
+++ b/examples/01_standalone_sdk/27_llm_fallback.py
@@ -0,0 +1,111 @@
+"""
+Example demonstrating LLM fallback functionality using FallbackRouter.
+
+This example shows how to configure multiple language models with automatic
+fallback capability. If the primary model fails (due to rate limits, timeouts,
+or service unavailability), the system automatically falls back to secondary
+models.
+
+Use cases:
+- High availability: Ensure your application continues working even if one
+  provider has an outage
+- Rate limit handling: Automatically switch to a backup model when you hit
+  rate limits
+- Cost optimization: Use expensive models as primary but have cheaper backups
+"""
+
+import os
+
+from pydantic import SecretStr
+
+from openhands.sdk import (
+    LLM,
+    Agent,
+    Conversation,
+    Message,
+    TextContent,
+    get_logger,
+)
+from openhands.sdk.llm.router import FallbackRouter
+from openhands.tools.preset.default import get_default_tools
+
+
+logger = get_logger(__name__)
+
+# Configure API credentials
+api_key = os.getenv("LLM_API_KEY")
+assert api_key is not None, "LLM_API_KEY environment variable is not set."
+model = os.getenv("LLM_MODEL", "claude-sonnet-4-20250514")
+base_url = os.getenv("LLM_BASE_URL")
+
+# Configure LLMs for fallback
+# First model: A powerful but potentially rate-limited model
+primary_llm = LLM(
+    usage_id="primary",
+    model=model,
+    base_url=base_url,
+    api_key=SecretStr(api_key),
+)
+
+# Second model: A reliable alternative model
+# In a real scenario, this might be a different provider or cheaper model
+fallback_llm = LLM(
+    usage_id="fallback",
+    model="openhands/devstral-small-2507",
+    base_url=base_url,
+    api_key=SecretStr(api_key),
+)
+
+# Create FallbackRouter with a list of LLMs
+# Models will be tried in the order they appear in the list,
+# similar to how litellm handles fallbacks
+fallback_router = FallbackRouter(
+    usage_id="fallback-router",
+    llms=[primary_llm, fallback_llm],
+)
+
+# Configure agent with fallback router
+tools = get_default_tools()
+agent = Agent(llm=fallback_router, tools=tools)
+
+# Create conversation
+conversation = Conversation(agent=agent, workspace=os.getcwd())
+
+# Send a message - the router will automatically try models in order,
+# falling back if one fails
+conversation.send_message(
+    message=Message(
+        role="user",
+        content=[
+            TextContent(
+                text=(
+                    "Hello! Can you tell me what the current date is? "
+                    "You can use the bash tool to run the 'date' command."
+                )
+            )
+        ],
+    )
+)
+
+# Run the conversation
+conversation.run()
+
+# Display results
+print("=" * 100)
+print("Conversation completed successfully!")
+if fallback_router.active_llm:
+    print(f"Active model used: {fallback_router.active_llm.model}")
+else:
+    print("No active model (no completions made)")
+
+# Report costs
+metrics = conversation.conversation_stats.get_combined_metrics()
+print(f"Total cost: ${metrics.accumulated_cost:.4f}")
+print(f"Total tokens: {metrics.accumulated_token_usage}")
+
+print("\n" + "=" * 100)
+print("Key features demonstrated:")
+print("1. Automatic fallback when primary model fails")
+print("2. Transparent switching between models")
+print("3. Cost and usage tracking across all models")
+print("4. Works seamlessly with agents and tools")
diff --git a/openhands-sdk/openhands/sdk/llm/router/__init__.py b/openhands-sdk/openhands/sdk/llm/router/__init__.py
index 37e7baca4a..425b2a8f9b 100644
--- a/openhands-sdk/openhands/sdk/llm/router/__init__.py
+++ b/openhands-sdk/openhands/sdk/llm/router/__init__.py
@@ -1,4 +1,5 @@
 from openhands.sdk.llm.router.base import RouterLLM
+from openhands.sdk.llm.router.impl.fallback import FallbackRouter
 from openhands.sdk.llm.router.impl.multimodal import MultimodalRouter
 from openhands.sdk.llm.router.impl.random import RandomRouter
 
@@ -7,4 +8,5 @@
     "RouterLLM",
     "RandomRouter",
     "MultimodalRouter",
+    "FallbackRouter",
 ]
diff --git a/openhands-sdk/openhands/sdk/llm/router/impl/fallback.py b/openhands-sdk/openhands/sdk/llm/router/impl/fallback.py
new file mode 100644
index 0000000000..be02124366
--- /dev/null
+++ b/openhands-sdk/openhands/sdk/llm/router/impl/fallback.py
@@ -0,0 +1,121 @@
+from collections.abc import Sequence
+
+from pydantic import field_validator, model_validator
+
+from openhands.sdk.llm.llm import LLM
+from openhands.sdk.llm.llm_response import LLMResponse
+from openhands.sdk.llm.message import Message
+from openhands.sdk.llm.router.base import RouterLLM
+from openhands.sdk.logger import get_logger
+from openhands.sdk.tool.tool import ToolDefinition
+
+
+logger = get_logger(__name__)
+
+
+class FallbackRouter(RouterLLM):
+    """
+    A RouterLLM implementation that provides fallback capability across multiple
+    language models. When the first model fails due to rate limits, timeouts,
+    or service unavailability, it automatically falls back to subsequent models.
+
+    Similar to litellm's fallback approach, models are tried in the order provided.
+    If all models fail, the exception from the last model is raised.
+
+    Example:
+        >>> primary = LLM(model="gpt-4", usage_id="primary")
+        >>> fallback = LLM(model="gpt-3.5-turbo", usage_id="fallback")
+        >>> router = FallbackRouter(
+        ...     usage_id="fallback-router",
+        ...     llms=[primary, fallback]
+        ... )
+        >>> # Will try models in order until one succeeds
+        >>> response = router.completion(messages)
+    """
+
+    router_name: str = "fallback_router"
+    llms: list[LLM]
+
+    @model_validator(mode="before")
+    @classmethod
+    def _convert_llms_to_routing(cls, values: dict) -> dict:
+        """Convert llms list to llms_for_routing dict for base class compatibility."""
+        if "llms" in values and "llms_for_routing" not in values:
+            llms = values["llms"]
+            values["llms_for_routing"] = {f"llm_{i}": llm for i, llm in enumerate(llms)}
+        return values
+
+    @field_validator("llms")
+    @classmethod
+    def _validate_llms(cls, llms: list[LLM]) -> list[LLM]:
+        """Ensure at least one LLM is provided."""
+        if not llms:
+            raise ValueError("FallbackRouter requires at least one LLM")
+        return llms
+
+    def select_llm(self, messages: list[Message]) -> str:  # noqa: ARG002
+        """
+        For fallback router, we always start with the first model.
+        The fallback logic is implemented in the completion() method.
+        """
+        return "llm_0"
+
+    def completion(
+        self,
+        messages: list[Message],
+        tools: Sequence[ToolDefinition] | None = None,
+        return_metrics: bool = False,
+        add_security_risk_prediction: bool = False,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Try models in order until one succeeds. If a model raises an exception
+        (e.g. rate limits, timeouts, service errors), fall back to the next model.
+        """
+        last_exception = None
+
+        for i, llm in enumerate(self.llms):
+            is_last_model = i == len(self.llms) - 1
+
+            try:
+                logger.info(
+                    f"FallbackRouter: Attempting completion with model "
+                    f"{i + 1}/{len(self.llms)} ({llm.model}, usage_id={llm.usage_id})"
+                )
+                self.active_llm = llm
+
+                response = llm.completion(
+                    messages=messages,
+                    tools=tools,
+                    _return_metrics=return_metrics,
+                    add_security_risk_prediction=add_security_risk_prediction,
+                    **kwargs,
+                )
+
+                logger.info(
+                    f"FallbackRouter: Successfully completed with model "
+                    f"{llm.model} (usage_id={llm.usage_id})"
+                )
+                return response
+
+            except Exception as e:
+                last_exception = e
+                logger.warning(
+                    f"FallbackRouter: Model {llm.model} (usage_id={llm.usage_id}) "
+                    f"failed with {type(e).__name__}: {str(e)}"
+                )
+
+                if is_last_model:
+                    logger.error(
+                        "FallbackRouter: All models failed. Raising last exception."
+                    )
+                    raise
+                else:
+                    logger.info(
+                        "FallbackRouter: Falling back to model "
+                        f"{i + 2}/{len(self.llms)}..."
+                    )
+
+        # This should never happen, but satisfy type checker
+        assert last_exception is not None
+        raise last_exception
diff --git a/tests/sdk/llm/test_fallback_router.py b/tests/sdk/llm/test_fallback_router.py
new file mode 100644
index 0000000000..c9ea5c5c5a
--- /dev/null
+++ b/tests/sdk/llm/test_fallback_router.py
@@ -0,0 +1,289 @@
+"""Tests for FallbackRouter functionality."""
+
+from unittest.mock import patch
+
+import pytest
+from litellm.exceptions import (
+    APIConnectionError,
+    RateLimitError,
+    ServiceUnavailableError,
+)
+from litellm.types.utils import (
+    Choices,
+    Message as LiteLLMMessage,
+    ModelResponse,
+    Usage,
+)
+from pydantic import SecretStr
+
+from openhands.sdk.llm import LLM, Message, TextContent
+from openhands.sdk.llm.exceptions import LLMServiceUnavailableError
+from openhands.sdk.llm.router import FallbackRouter
+
+
+def create_mock_response(content: str = "Test response", model: str = "test-model"):
+    """Helper function to create properly structured mock responses."""
+    return ModelResponse(
+        id="test-id",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=LiteLLMMessage(content=content, role="assistant"),
+            )
+        ],
+        created=1234567890,
+        model=model,
+        object="chat.completion",
+        system_fingerprint="test",
+        usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
+    )
+
+
+@pytest.fixture
+def primary_llm():
+    """Create a primary LLM for testing."""
+    return LLM(
+        model="gpt-4",
+        api_key=SecretStr("test-key"),
+        usage_id="primary",
+        num_retries=0,  # Disable retries for faster tests
+    )
+
+
+@pytest.fixture
+def fallback_llm():
+    """Create a fallback LLM for testing."""
+    return LLM(
+        model="gpt-3.5-turbo",
+        api_key=SecretStr("test-key"),
+        usage_id="fallback",
+        num_retries=0,  # Disable retries for faster tests
+    )
+
+
+@pytest.fixture
+def test_messages():
+    """Create test messages."""
+    return [Message(role="user", content=[TextContent(text="Hello")])]
+
+
+def test_fallback_router_creation(primary_llm, fallback_llm):
+    """Test that FallbackRouter can be created with a list of models."""
+    router = FallbackRouter(
+        usage_id="test-router",
+        llms=[primary_llm, fallback_llm],
+    )
+    assert router.router_name == "fallback_router"
+    assert len(router.llms) == 2
+    assert router.llms[0] == primary_llm
+    assert router.llms[1] == fallback_llm
+
+
+def test_fallback_router_requires_at_least_one_llm():
+    """Test that FallbackRouter requires at least one LLM."""
+    with pytest.raises(ValueError, match="at least one LLM"):
+        FallbackRouter(
+            usage_id="test-router",
+            llms=[],
+        )
+
+
+def test_fallback_router_success_with_primary(
+    primary_llm, fallback_llm, test_messages
+):
+    """Test that router uses primary model when it succeeds."""
+    router = FallbackRouter(
+        usage_id="test-router",
+        llms=[primary_llm, fallback_llm],
+    )
+
+    mock_response = create_mock_response(content="Primary response", model="gpt-4")
+
+    with patch.object(
+        primary_llm, "_transport_call", return_value=mock_response
+    ) as mock_primary:
+        response = router.completion(messages=test_messages)
+
+        # Verify primary was called
+        mock_primary.assert_called_once()
+
+        # Verify response is from primary
+        assert isinstance(response.message.content[0], TextContent)
+        assert response.message.content[0].text == "Primary response"
+        assert router.active_llm == primary_llm
+
+
+def test_fallback_router_falls_back_on_rate_limit(
+    primary_llm, fallback_llm, test_messages
+):
+    """Test that router falls back to secondary model on rate limit error."""
+    router = FallbackRouter(
+        usage_id="test-router",
+        llms=[primary_llm, fallback_llm],
+    )
+
+    mock_fallback_response = create_mock_response(
+        content="Fallback response", model="gpt-3.5-turbo"
+    )
+
+    with (
+        patch.object(
+            primary_llm,
+            "_transport_call",
+            side_effect=RateLimitError(
+                message="Rate limit exceeded",
+                model="gpt-4",
+                llm_provider="openai",
+            ),
+        ) as mock_primary,
+        patch.object(
+            fallback_llm, "_transport_call", return_value=mock_fallback_response
+        ) as mock_fallback,
+    ):
+        response = router.completion(messages=test_messages)
+
+        # Verify both were called
+        mock_primary.assert_called_once()
+        mock_fallback.assert_called_once()
+
+        # Verify response is from fallback
+        assert isinstance(response.message.content[0], TextContent)
+        assert response.message.content[0].text == "Fallback response"
+        assert router.active_llm == fallback_llm
+
+
+def test_fallback_router_falls_back_on_connection_error(
+    primary_llm, fallback_llm, test_messages
+):
+    """Test that router falls back on API connection error."""
+    router = FallbackRouter(
+        usage_id="test-router",
+        llms=[primary_llm, fallback_llm],
+    )
+
+    mock_fallback_response = create_mock_response(
+        content="Fallback response", model="gpt-3.5-turbo"
+    )
+
+    with (
+        patch.object(
+            primary_llm,
+            "_transport_call",
+            side_effect=APIConnectionError(
+                message="Connection failed",
+                model="gpt-4",
+                llm_provider="openai",
+            ),
+        ),
+        patch.object(
+            fallback_llm, "_transport_call", return_value=mock_fallback_response
+        ),
+    ):
+        response = router.completion(messages=test_messages)
+        assert isinstance(response.message.content[0], TextContent)
+        assert response.message.content[0].text == "Fallback response"
+        assert router.active_llm == fallback_llm
+
+
+def test_fallback_router_raises_when_all_fail(
+    primary_llm, fallback_llm, test_messages
+):
+    """Test that router raises exception when all models fail."""
+    router = FallbackRouter(
+        usage_id="test-router",
+        llms=[primary_llm, fallback_llm],
+    )
+
+    with (
+        patch.object(
+            primary_llm,
+            "_transport_call",
+            side_effect=ServiceUnavailableError(
+                message="Service unavailable",
+                model="gpt-4",
+                llm_provider="openai",
+            ),
+        ),
+        patch.object(
+            fallback_llm,
+            "_transport_call",
+            side_effect=ServiceUnavailableError(
+                message="Service unavailable",
+                model="gpt-3.5-turbo",
+                llm_provider="openai",
+            ),
+        ),
+    ):
+        with pytest.raises(LLMServiceUnavailableError):
+            router.completion(messages=test_messages)
+
+
+def test_fallback_router_with_multiple_fallbacks(test_messages):
+    """Test router with multiple fallback models."""
+    primary = LLM(
+        model="gpt-4",
+        api_key=SecretStr("test-key"),
+        usage_id="primary",
+        num_retries=0,
+    )
+    fallback1 = LLM(
+        model="gpt-3.5-turbo",
+        api_key=SecretStr("test-key"),
+        usage_id="fallback1",
+        num_retries=0,
+    )
+    fallback2 = LLM(
+        model="gpt-3.5-turbo-16k",
+        api_key=SecretStr("test-key"),
+        usage_id="fallback2",
+        num_retries=0,
+    )
+
+    router = FallbackRouter(
+        usage_id="test-router",
+        llms=[primary, fallback1, fallback2],
+    )
+
+    mock_response = create_mock_response(
+        content="Fallback2 response", model="gpt-3.5-turbo-16k"
+    )
+
+    with (
+        patch.object(
+            primary,
+            "_transport_call",
+            side_effect=RateLimitError(
+                message="Rate limit", model="gpt-4", llm_provider="openai"
+            ),
+        ) as mock_primary,
+        patch.object(
+            fallback1,
+            "_transport_call",
+            side_effect=RateLimitError(
+                message="Rate limit", model="gpt-3.5-turbo", llm_provider="openai"
+            ),
+        ) as mock_fallback1,
+        patch.object(
+            fallback2, "_transport_call", return_value=mock_response
+        ) as mock_fallback2,
+    ):
+        response = router.completion(messages=test_messages)
+
+        # Verify all three were tried
+        mock_primary.assert_called_once()
+        mock_fallback1.assert_called_once()
+        mock_fallback2.assert_called_once()
+
+        # Verify response is from fallback2
+        assert isinstance(response.message.content[0], TextContent)
+        assert response.message.content[0].text == "Fallback2 response"
+        assert router.active_llm == fallback2
+
+
+def test_fallback_router_select_llm_returns_first(primary_llm, fallback_llm):
+    """Test that select_llm always returns the first LLM's key."""
+    router = FallbackRouter(
+        usage_id="test-router",
+        llms=[primary_llm, fallback_llm],
+    )
+
+    messages = [Message(role="user", content=[TextContent(text="Test")])]
+    selected = router.select_llm(messages)
+    assert selected == "llm_0"
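
For reference, a minimal sketch (outside the patch above) of how the llms -> llms_for_routing mapping built by _convert_llms_to_routing could be exercised directly. It assumes llms_for_routing is exposed as a plain dict attribute on the router, as the validator implies, and the test name and "backup" usage_id are illustrative only.

# Sketch only: checks the llm_0/llm_1 routing keys that _convert_llms_to_routing
# derives from the llms list (assumes llms_for_routing is a plain dict attribute).
from pydantic import SecretStr

from openhands.sdk.llm import LLM, Message, TextContent
from openhands.sdk.llm.router import FallbackRouter


def test_fallback_router_routing_keys_follow_list_order():
    llms = [
        LLM(model="gpt-4", api_key=SecretStr("test-key"), usage_id="primary"),
        LLM(model="gpt-3.5-turbo", api_key=SecretStr("test-key"), usage_id="backup"),
    ]
    router = FallbackRouter(usage_id="test-router", llms=llms)

    # Keys are assigned as llm_0, llm_1, ... in the order the LLMs were given.
    assert list(router.llms_for_routing) == ["llm_0", "llm_1"]
    assert router.llms_for_routing["llm_0"] == llms[0]

    # select_llm always points at the first entry; fallback happens in completion().
    messages = [Message(role="user", content=[TextContent(text="Hi")])]
    assert router.select_llm(messages) == "llm_0"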