Commit d649c36

fix: enforce allowed_models during inference requests (#4197)
The `allowed_models` configuration was only applied when listing models via the `/v1/models` endpoint; the actual inference requests never checked the restriction. Users could therefore bypass the intended cost controls entirely by naming any model the provider supports directly in an inference call. The fix adds validation to all three inference methods (chat completions, completions, and embeddings) that checks the requested model against the `allowed_models` list before the provider API call is made.

### Test plan

Added unit tests covering the allowed, disallowed, and unrestricted cases for all three inference methods.
1 parent b6ce242 commit d649c36
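
The guard at the center of the change is a single method shared by all three code paths. A minimal standalone sketch of its semantics (a plain dataclass stands in for the real `RemoteInferenceProviderConfig`; the actual implementation lives on `OpenAIMixin` and reads `self.config.allowed_models`, as the diff below shows):

```python
# Standalone sketch only: _Config is a simplified stand-in for the provider
# config; the committed code implements this check as
# OpenAIMixin._validate_model_allowed.
from dataclasses import dataclass


@dataclass
class _Config:
    allowed_models: list[str] | None = None  # None means "no restriction"


def validate_model_allowed(config: _Config, provider_model_id: str) -> None:
    if config.allowed_models is not None and provider_model_id not in config.allowed_models:
        raise ValueError(
            f"Model '{provider_model_id}' is not in the allowed models list. "
            f"Allowed models: {config.allowed_models}"
        )


validate_model_allowed(_Config(allowed_models=["gpt-4"]), "gpt-4")  # passes
validate_model_allowed(_Config(), "any-model")  # passes: None disables the check
try:
    validate_model_allowed(_Config(allowed_models=[]), "gpt-4")  # empty list blocks every model
except ValueError as e:
    print(e)
```

Note the asymmetry this implies: `allowed_models=None` permits everything, while `allowed_models=[]` permits nothing.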

2 files changed: 126 additions, 4 deletions

src/llama_stack/providers/utils/inference/openai_mixin.py

Lines changed: 25 additions & 3 deletions
```diff
@@ -213,6 +213,19 @@ def _get_api_key_from_config_or_provider_data(self) -> str | None:
 
         return api_key
 
+    def _validate_model_allowed(self, provider_model_id: str) -> None:
+        """
+        Validate that the model is in the allowed_models list if configured.
+
+        :param provider_model_id: The provider-specific model ID to validate
+        :raises ValueError: If the model is not in the allowed_models list
+        """
+        if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
+            raise ValueError(
+                f"Model '{provider_model_id}' is not in the allowed models list. "
+                f"Allowed models: {self.config.allowed_models}"
+            )
+
     async def _get_provider_model_id(self, model: str) -> str:
         """
         Get the provider-specific model ID from the model store.
@@ -259,8 +272,11 @@ async def openai_completion(
         Direct OpenAI completion API call.
         """
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         completion_kwargs = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -292,6 +308,9 @@ async def openai_chat_completion(
         """
         Direct OpenAI chat completion API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         messages = params.messages
 
         if self.download_images:
@@ -313,7 +332,7 @@ async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
             messages = [await _localize_image_url(m) for m in messages]
 
         request_params = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             messages=messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -351,10 +370,13 @@ async def openai_embeddings(
         """
         Direct OpenAI embeddings API call.
        """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         # Build request params conditionally to avoid NotGiven/Omit type mismatch
         # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
         request_params: dict[str, Any] = {
-            "model": await self._get_provider_model_id(params.model),
+            "model": provider_model_id,
             "input": params.input,
         }
         if params.encoding_format is not None:
```
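
Two details worth noting: the model is resolved once via `_get_provider_model_id` and the resolved `provider_model_id` is reused in the request, so validation applies to the provider-side model name rather than the user-facing alias, and the check runs before any provider API call is made. As a hedged illustration of how the restriction would be switched on, assuming `allowed_models` can be passed as a keyword to the config (the tests below set `mixin.config.allowed_models` directly):

```python
# Hedged sketch: assumes RemoteInferenceProviderConfig accepts allowed_models
# as a keyword argument and that its other fields are optional.
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig

config = RemoteInferenceProviderConfig(allowed_models=["gpt-4", "text-embedding-ada-002"])
# An OpenAIMixin-based adapter built with this config will now raise ValueError
# from openai_chat_completion / openai_completion / openai_embeddings for any other model.
```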

tests/unit/providers/utils/inference/test_openai_mixin.py

Lines changed: 101 additions & 1 deletion
```diff
@@ -15,7 +15,14 @@
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
+from llama_stack_api import (
+    Model,
+    ModelType,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
 
 
 class OpenAIMixinImpl(OpenAIMixin):
@@ -834,3 +841,96 @@ def test_error_message_includes_correct_field_names(self, mixin_with_provider_da
         error_message = str(exc_info.value)
         assert "test_api_key" in error_message
         assert "x-llamastack-provider-data" in error_message
+
+
+class TestOpenAIMixinAllowedModelsInference:
+    """Test cases for allowed_models enforcement during inference requests"""
+
+    async def test_inference_with_allowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods succeed with allowed models"""
+        mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+        mock_embedding_response = MagicMock()
+        mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
+        mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
+        mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+            # Test completion
+            await mixin.openai_completion(
+                OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
+            )
+            mock_client.completions.create.assert_called_once()
+
+            # Test embeddings
+            await mixin.openai_embeddings(
+                OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
+            )
+            mock_client.embeddings.create.assert_called_once()
+
+    async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods fail with disallowed models"""
+        mixin.config.allowed_models = ["gpt-4"]
+
+        mock_client = MagicMock()
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
+
+            # Test completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
+                )
+
+            # Test embeddings with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
+                await mixin.openai_embeddings(
+                    OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
+                )
+
+        mock_client.chat.completions.create.assert_not_called()
+        mock_client.completions.create.assert_not_called()
+        mock_client.embeddings.create.assert_not_called()
+
+    async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
+        """Test that inference succeeds when allowed_models is None, and that an empty list blocks all models"""
+        # Test with None (no restrictions)
+        assert mixin.config.allowed_models is None
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+        # Test with empty list (blocks all models)
+        mixin.config.allowed_models = []
+        with mock_client_context(mixin, mock_client):
+            with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
```
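
The new tests can be run in isolation with pytest's keyword filter, for example `pytest tests/unit/providers/utils/inference/test_openai_mixin.py -k AllowedModelsInference`.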
