From db6488b379b93f71cca0a2b9595531eaccc69045 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 19 Nov 2025 12:12:28 -0800 Subject: [PATCH] fix: enforce allowed_models during inference requests The `allowed_models` configuration was only filtering the model list endpoint but not enforcing restrictions during actual inference requests. This allowed users to bypass the restriction by directly requesting models not in the allowed list, potentially accessing expensive models when only cheaper ones were intended. This change adds validation to all inference methods (`openai_chat_completion`, `openai_completion`, `openai_embeddings`) to reject requests for disallowed models with a clear error message. **Implementation:** - Added `_validate_model_allowed()` helper method that checks if a model is in the `allowed_models` list - Called validation in all three inference methods before making API requests - Validation occurs after resolving the provider model ID to ensure consistency **Test Plan:** - Added unit tests verifying all inference methods respect `allowed_models` - Tests cover allowed models (success), disallowed models (rejection), and no restrictions (None allows all, empty list blocks all) - All existing tests continue to pass Fixes GHSA-5rjj-4jp6-fw39 --- .../providers/utils/inference/openai_mixin.py | 28 ++++- .../utils/inference/test_openai_mixin.py | 102 +++++++++++++++++- 2 files changed, 126 insertions(+), 4 deletions(-) diff --git a/src/llama_stack/providers/utils/inference/openai_mixin.py b/src/llama_stack/providers/utils/inference/openai_mixin.py index 559ac90ceb..30511a3419 100644 --- a/src/llama_stack/providers/utils/inference/openai_mixin.py +++ b/src/llama_stack/providers/utils/inference/openai_mixin.py @@ -213,6 +213,19 @@ def _get_api_key_from_config_or_provider_data(self) -> str | None: return api_key + def _validate_model_allowed(self, provider_model_id: str) -> None: + """ + Validate that the model is in the allowed_models list if configured. + + :param provider_model_id: The provider-specific model ID to validate + :raises ValueError: If the model is not in the allowed_models list + """ + if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models: + raise ValueError( + f"Model '{provider_model_id}' is not in the allowed models list. " + f"Allowed models: {self.config.allowed_models}" + ) + async def _get_provider_model_id(self, model: str) -> str: """ Get the provider-specific model ID from the model store. @@ -259,8 +272,11 @@ async def openai_completion( Direct OpenAI completion API call. """ # TODO: fix openai_completion to return type compatible with OpenAI's API response + provider_model_id = await self._get_provider_model_id(params.model) + self._validate_model_allowed(provider_model_id) + completion_kwargs = await prepare_openai_completion_params( - model=await self._get_provider_model_id(params.model), + model=provider_model_id, prompt=params.prompt, best_of=params.best_of, echo=params.echo, @@ -292,6 +308,9 @@ async def openai_chat_completion( """ Direct OpenAI chat completion API call. """ + provider_model_id = await self._get_provider_model_id(params.model) + self._validate_model_allowed(provider_model_id) + messages = params.messages if self.download_images: @@ -313,7 +332,7 @@ async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam: messages = [await _localize_image_url(m) for m in messages] request_params = await prepare_openai_completion_params( - model=await self._get_provider_model_id(params.model), + model=provider_model_id, messages=messages, frequency_penalty=params.frequency_penalty, function_call=params.function_call, @@ -351,10 +370,13 @@ async def openai_embeddings( """ Direct OpenAI embeddings API call. """ + provider_model_id = await self._get_provider_model_id(params.model) + self._validate_model_allowed(provider_model_id) + # Build request params conditionally to avoid NotGiven/Omit type mismatch # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven request_params: dict[str, Any] = { - "model": await self._get_provider_model_id(params.model), + "model": provider_model_id, "input": params.input, } if params.encoding_format is not None: diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 5b13a75f48..02d44f2ba6 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -15,7 +15,14 @@ from llama_stack.core.request_headers import request_provider_data_context from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin -from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam +from llama_stack_api import ( + Model, + ModelType, + OpenAIChatCompletionRequestWithExtraBody, + OpenAICompletionRequestWithExtraBody, + OpenAIEmbeddingsRequestWithExtraBody, + OpenAIUserMessageParam, +) class OpenAIMixinImpl(OpenAIMixin): @@ -834,3 +841,96 @@ def test_error_message_includes_correct_field_names(self, mixin_with_provider_da error_message = str(exc_info.value) assert "test_api_key" in error_message assert "x-llamastack-provider-data" in error_message + + +class TestOpenAIMixinAllowedModelsInference: + """Test cases for allowed_models enforcement during inference requests""" + + async def test_inference_with_allowed_models(self, mixin, mock_client_context): + """Test that all inference methods succeed with allowed models""" + mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"] + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + mock_client.completions.create = AsyncMock(return_value=MagicMock()) + mock_embedding_response = MagicMock() + mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5) + mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response) + + with mock_client_context(mixin, mock_client): + # Test chat completion + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")] + ) + ) + mock_client.chat.completions.create.assert_called_once() + + # Test completion + await mixin.openai_completion( + OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello") + ) + mock_client.completions.create.assert_called_once() + + # Test embeddings + await mixin.openai_embeddings( + OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text") + ) + mock_client.embeddings.create.assert_called_once() + + async def test_inference_with_disallowed_models(self, mixin, mock_client_context): + """Test that all inference methods fail with disallowed models""" + mixin.config.allowed_models = ["gpt-4"] + + mock_client = MagicMock() + + with mock_client_context(mixin, mock_client): + # Test chat completion with disallowed model + with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")] + ) + ) + + # Test completion with disallowed model + with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"): + await mixin.openai_completion( + OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello") + ) + + # Test embeddings with disallowed model + with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"): + await mixin.openai_embeddings( + OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text") + ) + + mock_client.chat.completions.create.assert_not_called() + mock_client.completions.create.assert_not_called() + mock_client.embeddings.create.assert_not_called() + + async def test_inference_with_no_restrictions(self, mixin, mock_client_context): + """Test that inference succeeds when allowed_models is None or empty list blocks all""" + # Test with None (no restrictions) + assert mixin.config.allowed_models is None + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + with mock_client_context(mixin, mock_client): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")] + ) + ) + mock_client.chat.completions.create.assert_called_once() + + # Test with empty list (blocks all models) + mixin.config.allowed_models = [] + with mock_client_context(mixin, mock_client): + with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")] + ) + )