Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions src/llama_stack/providers/utils/inference/openai_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,19 @@ def _get_api_key_from_config_or_provider_data(self) -> str | None:

return api_key

def _validate_model_allowed(self, provider_model_id: str) -> None:
"""
Validate that the model is in the allowed_models list if configured.

:param provider_model_id: The provider-specific model ID to validate
:raises ValueError: If the model is not in the allowed_models list
"""
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
raise ValueError(
f"Model '{provider_model_id}' is not in the allowed models list. "
f"Allowed models: {self.config.allowed_models}"
)

async def _get_provider_model_id(self, model: str) -> str:
"""
Get the provider-specific model ID from the model store.
Expand Down Expand Up @@ -259,8 +272,11 @@ async def openai_completion(
Direct OpenAI completion API call.
"""
# TODO: fix openai_completion to return type compatible with OpenAI's API response
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)

completion_kwargs = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model),
model=provider_model_id,
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
Expand Down Expand Up @@ -292,6 +308,9 @@ async def openai_chat_completion(
"""
Direct OpenAI chat completion API call.
"""
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)

messages = params.messages

if self.download_images:
Expand All @@ -313,7 +332,7 @@ async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
messages = [await _localize_image_url(m) for m in messages]

request_params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model),
model=provider_model_id,
messages=messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
Expand Down Expand Up @@ -351,10 +370,13 @@ async def openai_embeddings(
"""
Direct OpenAI embeddings API call.
"""
provider_model_id = await self._get_provider_model_id(params.model)
self._validate_model_allowed(provider_model_id)

# Build request params conditionally to avoid NotGiven/Omit type mismatch
# The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
request_params: dict[str, Any] = {
"model": await self._get_provider_model_id(params.model),
"model": provider_model_id,
"input": params.input,
}
if params.encoding_format is not None:
Expand Down
102 changes: 101 additions & 1 deletion tests/unit/providers/utils/inference/test_openai_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack_api import (
Model,
ModelType,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIUserMessageParam,
)


class OpenAIMixinImpl(OpenAIMixin):
Expand Down Expand Up @@ -834,3 +841,96 @@ def test_error_message_includes_correct_field_names(self, mixin_with_provider_da
error_message = str(exc_info.value)
assert "test_api_key" in error_message
assert "x-llamastack-provider-data" in error_message


class TestOpenAIMixinAllowedModelsInference:
"""Test cases for allowed_models enforcement during inference requests"""

async def test_inference_with_allowed_models(self, mixin, mock_client_context):
"""Test that all inference methods succeed with allowed models"""
mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]

mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
mock_client.completions.create = AsyncMock(return_value=MagicMock())
mock_embedding_response = MagicMock()
mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)

with mock_client_context(mixin, mock_client):
# Test chat completion
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
mock_client.chat.completions.create.assert_called_once()

# Test completion
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
)
mock_client.completions.create.assert_called_once()

# Test embeddings
await mixin.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
)
mock_client.embeddings.create.assert_called_once()

async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
"""Test that all inference methods fail with disallowed models"""
mixin.config.allowed_models = ["gpt-4"]

mock_client = MagicMock()

with mock_client_context(mixin, mock_client):
# Test chat completion with disallowed model
with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)

# Test completion with disallowed model
with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
)

# Test embeddings with disallowed model
with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
await mixin.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
)

mock_client.chat.completions.create.assert_not_called()
mock_client.completions.create.assert_not_called()
mock_client.embeddings.create.assert_not_called()

async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
"""Test that inference succeeds when allowed_models is None or empty list blocks all"""
# Test with None (no restrictions)
assert mixin.config.allowed_models is None
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

with mock_client_context(mixin, mock_client):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
mock_client.chat.completions.create.assert_called_once()

# Test with empty list (blocks all models)
mixin.config.allowed_models = []
with mock_client_context(mixin, mock_client):
with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
Loading