diff --git a/src/llama_stack/providers/utils/inference/openai_mixin.py b/src/llama_stack/providers/utils/inference/openai_mixin.py
index 559ac90ceb..30511a3419 100644
--- a/src/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/src/llama_stack/providers/utils/inference/openai_mixin.py
@@ -213,6 +213,19 @@ def _get_api_key_from_config_or_provider_data(self) -> str | None:
 
         return api_key
 
+    def _validate_model_allowed(self, provider_model_id: str) -> None:
+        """
+        Validate that the model is in the allowed_models list, if configured.
+
+        :param provider_model_id: The provider-specific model ID to validate
+        :raises ValueError: If the model is not in the allowed_models list
+        """
+        if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
+            raise ValueError(
+                f"Model '{provider_model_id}' is not in the allowed models list. "
+                f"Allowed models: {self.config.allowed_models}"
+            )
+
     async def _get_provider_model_id(self, model: str) -> str:
         """
         Get the provider-specific model ID from the model store.
@@ -259,8 +272,11 @@ async def openai_completion(
         Direct OpenAI completion API call.
         """
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         completion_kwargs = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
            prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -292,6 +308,9 @@ async def openai_chat_completion(
         """
         Direct OpenAI chat completion API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         messages = params.messages
 
         if self.download_images:
@@ -313,7 +332,7 @@ async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
             messages = [await _localize_image_url(m) for m in messages]
 
         request_params = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             messages=messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -351,10 +370,13 @@ async def openai_embeddings(
         """
         Direct OpenAI embeddings API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         # Build request params conditionally to avoid NotGiven/Omit type mismatch
         # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
         request_params: dict[str, Any] = {
-            "model": await self._get_provider_model_id(params.model),
+            "model": provider_model_id,
             "input": params.input,
         }
         if params.encoding_format is not None:
diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py
index 5b13a75f48..02d44f2ba6 100644
--- a/tests/unit/providers/utils/inference/test_openai_mixin.py
+++ b/tests/unit/providers/utils/inference/test_openai_mixin.py
@@ -15,7 +15,14 @@
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
+from llama_stack_api import (
+    Model,
+    ModelType,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
 
 
 class OpenAIMixinImpl(OpenAIMixin):
@@ -834,3 +841,96 @@ def test_error_message_includes_correct_field_names(self, mixin_with_provider_da
         error_message = str(exc_info.value)
         assert "test_api_key" in error_message
         assert "x-llamastack-provider-data" in error_message
+
+
+class TestOpenAIMixinAllowedModelsInference:
+    """Test cases for allowed_models enforcement during inference requests"""
+
+    async def test_inference_with_allowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods succeed with allowed models"""
+        mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+        mock_embedding_response = MagicMock()
+        mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
+        mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
+        mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+            # Test completion
+            await mixin.openai_completion(
+                OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
+            )
+            mock_client.completions.create.assert_called_once()
+
+            # Test embeddings
+            await mixin.openai_embeddings(
+                OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
+            )
+            mock_client.embeddings.create.assert_called_once()
+
+    async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods fail with disallowed models"""
+        mixin.config.allowed_models = ["gpt-4"]
+
+        mock_client = MagicMock()
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
+
+            # Test completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
+                )
+
+            # Test embeddings with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
+                await mixin.openai_embeddings(
+                    OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
+                )
+
+            mock_client.chat.completions.create.assert_not_called()
+            mock_client.completions.create.assert_not_called()
+            mock_client.embeddings.create.assert_not_called()
+
+    async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
+        """Test that allowed_models=None imposes no restrictions while an empty list blocks all models"""
+        # Test with None (no restrictions)
+        assert mixin.config.allowed_models is None
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+        # Test with empty list (blocks all models)
+        mixin.config.allowed_models = []
+        with mock_client_context(mixin, mock_client):
+            with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
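
Note on the semantics this patch introduces: each inference path resolves the user-facing model alias to a provider-specific ID first, then validates that resolved ID, so entries in allowed_models must use the provider's own model names; allowed_models=None disables the check entirely, while an empty list rejects every model. The standalone sketch below (not part of the patch; ProviderConfig and RestrictedProvider are hypothetical stand-ins for the mixin and its RemoteInferenceProviderConfig) illustrates that check in isolation:

    # Minimal sketch of the allowed_models gate, under the assumptions above.
    from dataclasses import dataclass


    @dataclass
    class ProviderConfig:
        # None means "no restriction"; an empty list blocks every model.
        allowed_models: list[str] | None = None


    class RestrictedProvider:
        def __init__(self, config: ProviderConfig) -> None:
            self.config = config

        def _validate_model_allowed(self, provider_model_id: str) -> None:
            # Same condition as the patch: enforce only when a list is configured.
            if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
                raise ValueError(
                    f"Model '{provider_model_id}' is not in the allowed models list. "
                    f"Allowed models: {self.config.allowed_models}"
                )


    if __name__ == "__main__":
        RestrictedProvider(ProviderConfig())._validate_model_allowed("any-model")       # None: passes
        RestrictedProvider(ProviderConfig(["gpt-4"]))._validate_model_allowed("gpt-4")  # listed: passes
        try:
            RestrictedProvider(ProviderConfig([]))._validate_model_allowed("gpt-4")     # empty list: raises
        except ValueError as exc:
            print(exc)

Because validation happens after alias resolution but before any request parameters are built, a disallowed model fails fast and the underlying OpenAI client is never invoked, which is exactly what the assert_not_called checks in the new tests verify.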