|
15 | 15 | from llama_stack.core.request_headers import request_provider_data_context |
16 | 16 | from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig |
17 | 17 | from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin |
18 | | -from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam |
| 18 | +from llama_stack_api import ( |
| 19 | + Model, |
| 20 | + ModelType, |
| 21 | + OpenAIChatCompletionRequestWithExtraBody, |
| 22 | + OpenAICompletionRequestWithExtraBody, |
| 23 | + OpenAIEmbeddingsRequestWithExtraBody, |
| 24 | + OpenAIUserMessageParam, |
| 25 | +) |
19 | 26 |
|
20 | 27 |
|
21 | 28 | class OpenAIMixinImpl(OpenAIMixin): |
@@ -834,3 +841,96 @@ def test_error_message_includes_correct_field_names(self, mixin_with_provider_da |
834 | 841 | error_message = str(exc_info.value) |
835 | 842 | assert "test_api_key" in error_message |
836 | 843 | assert "x-llamastack-provider-data" in error_message |
| 844 | + |
| 845 | + |
class TestOpenAIMixinAllowedModelsInference:
    """Test cases for allowed_models enforcement during inference requests."""

    async def test_inference_with_allowed_models(self, mixin, mock_client_context):
        """All inference methods succeed when the requested model is in allowed_models."""
        mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]

        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
        mock_client.completions.create = AsyncMock(return_value=MagicMock())
        # Embeddings need a structured response: the mixin reads .data and .usage.
        mock_embedding_response = MagicMock()
        mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
        mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
        mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)

        with mock_client_context(mixin, mock_client):
            # Chat completion with an allowed model reaches the client.
            await mixin.openai_chat_completion(
                OpenAIChatCompletionRequestWithExtraBody(
                    model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
                )
            )
            mock_client.chat.completions.create.assert_called_once()

            # Legacy completion with an allowed model reaches the client.
            await mixin.openai_completion(
                OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
            )
            mock_client.completions.create.assert_called_once()

            # Embeddings with an allowed model reach the client.
            await mixin.openai_embeddings(
                OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
            )
            mock_client.embeddings.create.assert_called_once()

    async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
        """All inference methods raise ValueError for models outside allowed_models."""
        mixin.config.allowed_models = ["gpt-4"]

        mock_client = MagicMock()

        with mock_client_context(mixin, mock_client):
            # Chat completion with a disallowed model is rejected before any client call.
            with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
                await mixin.openai_chat_completion(
                    OpenAIChatCompletionRequestWithExtraBody(
                        model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
                    )
                )

            # Legacy completion with a disallowed model is rejected.
            with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
                await mixin.openai_completion(
                    OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
                )

            # Embeddings with a disallowed model are rejected.
            with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
                await mixin.openai_embeddings(
                    OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
                )

            # None of the blocked requests may reach the underlying client.
            mock_client.chat.completions.create.assert_not_called()
            mock_client.completions.create.assert_not_called()
            mock_client.embeddings.create.assert_not_called()

    async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
        """Inference succeeds when allowed_models is None; an empty list blocks every model."""
        # allowed_models=None means "no restrictions" — any model is accepted.
        assert mixin.config.allowed_models is None
        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())

        with mock_client_context(mixin, mock_client):
            await mixin.openai_chat_completion(
                OpenAIChatCompletionRequestWithExtraBody(
                    model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
                )
            )
            mock_client.chat.completions.create.assert_called_once()

        # An empty list is an explicit deny-all, distinct from None.
        mixin.config.allowed_models = []
        with mock_client_context(mixin, mock_client):
            with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
                await mixin.openai_chat_completion(
                    OpenAIChatCompletionRequestWithExtraBody(
                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
                    )
                )
            # The blocked request must not have reached the client:
            # call count is unchanged from the earlier (None) case.
            mock_client.chat.completions.create.assert_called_once()
0 commit comments