From 52420be67923c89d76720bcdefaa143907182f59 Mon Sep 17 00:00:00 2001 From: Tim Date: Sat, 8 Nov 2025 14:16:53 +0100 Subject: [PATCH 1/6] Add prediction support for MistralModel and associated tests. --- .../pydantic_ai/models/mistral.py | 33 ++++++++- .../test_mistral_chat_with_prediction.yaml | 67 +++++++++++++++++++ tests/models/test_mistral.py | 59 ++++++++++++++++ 3 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 770c8ff6ca..0837536de5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -77,6 +77,10 @@ ) from mistralai.models.assistantmessage import AssistantMessage as MistralAssistantMessage from mistralai.models.function import Function as MistralFunction + from mistralai.models.prediction import ( + Prediction as MistralPrediction, + PredictionTypedDict as MistralPredictionTypedDict, + ) from mistralai.models.systemmessage import SystemMessage as MistralSystemMessage from mistralai.models.toolmessage import ToolMessage as MistralToolMessage from mistralai.models.usermessage import UserMessage as MistralUserMessage @@ -114,8 +118,13 @@ class MistralModelSettings(ModelSettings, total=False): """Settings used for a Mistral model request.""" # ALL FIELDS MUST BE `mistral_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. + mistral_prediction: str | MistralPrediction | MistralPredictionTypedDict | None + """Prediction content for the model to use as a prefix. See Predictive outputs. - # This class is a placeholder for any future mistral-specific settings + This feature is currently only supported for certain Mistral models. See the model cards at Models. + For example, it is supported for the latest Mistral Serie Large (> 2), Medium (> 3), Small (> 3) and Pixtral models, + but not for reasoning or coding models yet. + """ @dataclass(init=False) @@ -241,6 +250,7 @@ async def _completions_create( timeout_ms=self._get_timeout_ms(model_settings.get('timeout')), random_seed=model_settings.get('seed', UNSET), stop=model_settings.get('stop_sequences', None), + prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)), http_headers={'User-Agent': get_user_agent()}, ) except SDKError as e: @@ -281,6 +291,7 @@ async def _stream_completions_create( presence_penalty=model_settings.get('presence_penalty'), frequency_penalty=model_settings.get('frequency_penalty'), stop=model_settings.get('stop_sequences', None), + prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)), http_headers={'User-Agent': get_user_agent()}, ) @@ -298,6 +309,7 @@ async def _stream_completions_create( 'type': 'json_object' }, # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9 stream=True, + prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)), http_headers={'User-Agent': get_user_agent()}, ) @@ -307,6 +319,7 @@ async def _stream_completions_create( model=str(self._model_name), messages=mistral_messages, stream=True, + prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)), http_headers={'User-Agent': get_user_agent()}, ) assert response, 'A unexpected empty response from Mistral.' @@ -427,6 +440,24 @@ def _map_tool_call(t: ToolCallPart) -> MistralToolCall: function=MistralFunctionCall(name=t.tool_name, arguments=t.args or {}), ) + @staticmethod + def _map_setting_prediction( + prediction: str | MistralPredictionTypedDict | MistralPrediction | None, + ) -> MistralPrediction | None: + """Maps various prediction input types to a MistralPrediction object.""" + if not prediction: + return None + if isinstance(prediction, MistralPrediction): + return prediction + elif isinstance(prediction, str): + return MistralPrediction(content=prediction) + elif isinstance(prediction, dict): + return MistralPrediction.model_validate(prediction) + else: + raise RuntimeError( + f'Unsupported prediction type: {type(prediction)} for MistralModelSettings. Expected str, dict, or MistralPrediction.' + ) + def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage: """Get a message with an example of the expected output format.""" examples: list[dict[str, Any]] = [] diff --git a/tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml b/tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml new file mode 100644 index 0000000000..c3e106b036 --- /dev/null +++ b/tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml @@ -0,0 +1,67 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '246' + content-type: + - application/json + host: + - api.mistral.ai + method: POST + parsed_body: + messages: + - content: + - text: Correct only the math, respond with no explanation, no formatting. + type: text + - text: The result of 21+21=99 + type: text + role: user + model: mistral-small-latest + n: 1 + stream: false + top_p: 1.0 + uri: https://api.mistral.ai/v1/chat/completions + response: + headers: + access-control-allow-origin: + - '*' + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '321' + content-type: + - application/json + mistral-correlation-id: + - 019a639b-7bf4-7481-96af-078cd1a7d277 + strict-transport-security: + - max-age=15552000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + message: + content: The result of 21+21=42 + role: assistant + tool_calls: null + created: 1762607725 + id: a7952046ef794d1697627b54231df17a + model: mistral-small-latest + object: chat.completion + usage: + completion_tokens: 13 + prompt_tokens: 28 + total_tokens: 41 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 76ae344c5b..138d3e25cb 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -54,6 +54,10 @@ ToolCall as MistralToolCall, ) from mistralai.types.basemodel import Unset as MistralUnset + from mistralai.models.prediction import ( + Prediction as MistralPrediction, + PredictionTypedDict as MistralPredictionTypedDict, + ) from pydantic_ai.models.mistral import MistralModel, MistralStreamedResponse from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings @@ -1677,8 +1681,44 @@ async def get_location(loc_name: str) -> str: ##################### ## Test methods ##################### +# --- _map_setting_prediction -------------------------------------------------- +@pytest.fixture +def example_dict() -> MistralPredictionTypedDict: + """Fixture providing a typed dict for prediction.""" + return {"type": "content", "content": "foo"} + +@pytest.fixture +def example_prediction() -> MistralPrediction: + """Fixture providing a MistralPrediction object.""" + return MistralPrediction(content="bar") + +@pytest.mark.parametrize( + "input_value,expected_content", + [ + ("plain text", "plain text"), + ("example_prediction", "bar"), + ("example_dict", "foo"), + (None, None), + ], +) +def test_map_setting_prediction_valid(request, input_value, expected_content): + """ + Accepted input types (str, dict, MistralPrediction, None) must be correctlyconverted to a MistralPrediction or None. + """ + # If the parameter is a fixture name, resolve it using request + if isinstance(input_value, str) and input_value in {"example_dict", "example_prediction"}: + input_value = request.getfixturevalue(input_value) + + result = MistralModel._map_setting_prediction(input_value) # pyright: ignore[reportPrivateUsage] + + if input_value is None: + assert result is None + else: + assert isinstance(result, MistralPrediction) + assert result.content == expected_content +# ----------------------------------------------------- def test_generate_user_output_format_complex(mistral_api_key: str): """ Single test that includes properties exercising every branch @@ -2263,3 +2303,22 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist ), ] ) + +@pytest.mark.vcr() +async def test_mistral_chat_with_prediction(allow_model_requests: None, mistral_api_key: str): + """Test chat completion with prediction parameter using a math query.""" + from pydantic_ai.models.mistral import MistralModelSettings + + model = MistralModel( + 'mistral-small-latest', + provider=MistralProvider(api_key=mistral_api_key) + ) + prediction = "The result of 21+21=99" + settings = MistralModelSettings(prediction=prediction) + agent = Agent(model, model_settings=settings) + + result = await agent.run(['Correct only the math, respond with no explanation, no formatting.',"The result of 21+21=99"]) + + # Verify that the response uses the expected prediction + assert 'The result of 21+21=' in result.output + assert '42' in result.output From 71e3b824f13676d55839d7ff25c34e14b6cd6644 Mon Sep 17 00:00:00 2001 From: Tim Date: Sat, 8 Nov 2025 14:55:34 +0100 Subject: [PATCH 2/6] Update tests and model documentation for revised Mistral prediction handling. --- .../pydantic_ai/models/mistral.py | 3 +- .../test_mistral_chat_with_prediction.yaml | 23 +++++----- tests/models/test_mistral.py | 43 +++++++++---------- 3 files changed, 35 insertions(+), 34 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 0837536de5..9a125f5645 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -122,8 +122,7 @@ class MistralModelSettings(ModelSettings, total=False): """Prediction content for the model to use as a prefix. See Predictive outputs. This feature is currently only supported for certain Mistral models. See the model cards at Models. - For example, it is supported for the latest Mistral Serie Large (> 2), Medium (> 3), Small (> 3) and Pixtral models, - but not for reasoning or coding models yet. + As of now, codestral-latest and mistral-large-2411 support [predicted outputs](https://docs.mistral.ai/capabilities/predicted_outputs). """ diff --git a/tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml b/tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml index c3e106b036..a2ace9839c 100644 --- a/tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml +++ b/tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml @@ -8,7 +8,7 @@ interactions: connection: - keep-alive content-length: - - '246' + - '315' content-type: - application/json host: @@ -17,13 +17,16 @@ interactions: parsed_body: messages: - content: - - text: Correct only the math, respond with no explanation, no formatting. + - text: Correct the math, keep everything else. No explanation, no formatting. type: text - text: The result of 21+21=99 type: text role: user - model: mistral-small-latest + model: mistral-large-2411 n: 1 + prediction: + content: The result of 21+21=99 + type: content stream: false top_p: 1.0 uri: https://api.mistral.ai/v1/chat/completions @@ -36,11 +39,11 @@ interactions: connection: - keep-alive content-length: - - '321' + - '319' content-type: - application/json mistral-correlation-id: - - 019a639b-7bf4-7481-96af-078cd1a7d277 + - 019a63b7-40ba-70cb-94d0-84f036d7c76f strict-transport-security: - max-age=15552000; includeSubDomains; preload transfer-encoding: @@ -53,14 +56,14 @@ interactions: content: The result of 21+21=42 role: assistant tool_calls: null - created: 1762607725 - id: a7952046ef794d1697627b54231df17a - model: mistral-small-latest + created: 1762609545 + id: 6c36e8b6c3c145bd8ada32f9bd0f6be9 + model: mistral-large-2411 object: chat.completion usage: completion_tokens: 13 - prompt_tokens: 28 - total_tokens: 41 + prompt_tokens: 33 + total_tokens: 46 status: code: 200 message: OK diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 138d3e25cb..4f2a9d7a2a 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -53,11 +53,11 @@ SDKError, ToolCall as MistralToolCall, ) - from mistralai.types.basemodel import Unset as MistralUnset from mistralai.models.prediction import ( Prediction as MistralPrediction, PredictionTypedDict as MistralPredictionTypedDict, ) + from mistralai.types.basemodel import Unset as MistralUnset from pydantic_ai.models.mistral import MistralModel, MistralStreamedResponse from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings @@ -1681,44 +1681,45 @@ async def get_location(loc_name: str) -> str: ##################### ## Test methods ##################### -# --- _map_setting_prediction -------------------------------------------------- @pytest.fixture def example_dict() -> MistralPredictionTypedDict: """Fixture providing a typed dict for prediction.""" - return {"type": "content", "content": "foo"} + return {'type': 'content', 'content': 'foo'} @pytest.fixture def example_prediction() -> MistralPrediction: """Fixture providing a MistralPrediction object.""" - return MistralPrediction(content="bar") + return MistralPrediction(content='bar') @pytest.mark.parametrize( - "input_value,expected_content", + 'input_value,expected_content', [ - ("plain text", "plain text"), - ("example_prediction", "bar"), - ("example_dict", "foo"), + ('plain text', 'plain text'), + ('example_prediction', 'bar'), + ('example_dict', 'foo'), (None, None), ], ) -def test_map_setting_prediction_valid(request, input_value, expected_content): +def test_map_setting_prediction_valid(request: pytest.FixtureRequest, input_value: str, expected_content: str | None): """ Accepted input types (str, dict, MistralPrediction, None) must be correctlyconverted to a MistralPrediction or None. """ # If the parameter is a fixture name, resolve it using request - if isinstance(input_value, str) and input_value in {"example_dict", "example_prediction"}: - input_value = request.getfixturevalue(input_value) + resolved_value: str | MistralPredictionTypedDict | MistralPrediction | None = input_value + if isinstance(input_value, str) and input_value in {'example_dict', 'example_prediction'}: + resolved_value = request.getfixturevalue(input_value) - result = MistralModel._map_setting_prediction(input_value) # pyright: ignore[reportPrivateUsage] + result = MistralModel._map_setting_prediction(resolved_value) # pyright: ignore[reportPrivateUsage] - if input_value is None: + if resolved_value is None: assert result is None else: assert isinstance(result, MistralPrediction) assert result.content == expected_content -# ----------------------------------------------------- + + def test_generate_user_output_format_complex(mistral_api_key: str): """ Single test that includes properties exercising every branch @@ -2304,21 +2305,19 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist ] ) + @pytest.mark.vcr() async def test_mistral_chat_with_prediction(allow_model_requests: None, mistral_api_key: str): """Test chat completion with prediction parameter using a math query.""" from pydantic_ai.models.mistral import MistralModelSettings - model = MistralModel( - 'mistral-small-latest', - provider=MistralProvider(api_key=mistral_api_key) - ) - prediction = "The result of 21+21=99" - settings = MistralModelSettings(prediction=prediction) + model = MistralModel('mistral-large-2411', provider=MistralProvider(api_key=mistral_api_key)) + prediction = 'The result of 21+21=99' + settings = MistralModelSettings(mistral_prediction=prediction) agent = Agent(model, model_settings=settings) - result = await agent.run(['Correct only the math, respond with no explanation, no formatting.',"The result of 21+21=99"]) + result = await agent.run(['Correct the math, keep everything else. No explanation, no formatting.', prediction]) # Verify that the response uses the expected prediction - assert 'The result of 21+21=' in result.output assert '42' in result.output + assert 'The result of 21+21=' in result.output From c2f2e3febcd0bd63e14e706c447fce333a4ae30c Mon Sep 17 00:00:00 2001 From: Tim Date: Sun, 9 Nov 2025 13:37:30 +0100 Subject: [PATCH 3/6] Add test for unsupported prediction types in MistralModel and fix minor doc typo --- pydantic_ai_slim/pydantic_ai/models/mistral.py | 2 +- tests/models/test_mistral.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 9a125f5645..e301e1c793 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -122,7 +122,7 @@ class MistralModelSettings(ModelSettings, total=False): """Prediction content for the model to use as a prefix. See Predictive outputs. This feature is currently only supported for certain Mistral models. See the model cards at Models. - As of now, codestral-latest and mistral-large-2411 support [predicted outputs](https://docs.mistral.ai/capabilities/predicted_outputs). + As of now, codestral-latest and mistral-large-2411 support [predicted outputs](https://docs.mistral.ai/capabilities/predicted_outputs). """ diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 4f2a9d7a2a..39d31e99d5 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -1719,6 +1719,13 @@ def test_map_setting_prediction_valid(request: pytest.FixtureRequest, input_valu assert isinstance(result, MistralPrediction) assert result.content == expected_content + def test_map_setting_prediction_unsupported_type(): + """Test that _map_setting_prediction raises RuntimeError for unsupported types.""" + with pytest.raises( + RuntimeError, match='Unsupported prediction type.*int.*Expected str, dict, or MistralPrediction' + ): + MistralModel._map_setting_prediction(123) # pyright: ignore[reportPrivateUsage] + def test_generate_user_output_format_complex(mistral_api_key: str): """ From 7dc5f87eb035b513d75d0131d84dac3c2a6775d9 Mon Sep 17 00:00:00 2001 From: Tim Date: Sun, 9 Nov 2025 13:48:27 +0100 Subject: [PATCH 4/6] Update test to expand pyright ignore rules for argument type enforcement for Unit-Test Method --- tests/models/test_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 39d31e99d5..6baa6302b4 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -1724,7 +1724,7 @@ def test_map_setting_prediction_unsupported_type(): with pytest.raises( RuntimeError, match='Unsupported prediction type.*int.*Expected str, dict, or MistralPrediction' ): - MistralModel._map_setting_prediction(123) # pyright: ignore[reportPrivateUsage] + MistralModel._map_setting_prediction(123) # pyright: ignore[reportPrivateUsage, reportArgumentType] def test_generate_user_output_format_complex(mistral_api_key: str): From 2f6da178393b069756ed8e0636aa09b7ec47e947 Mon Sep 17 00:00:00 2001 From: Tim Date: Sun, 9 Nov 2025 14:06:31 +0100 Subject: [PATCH 5/6] Fix indentation in test for unsupported prediction types in MistralModel --- tests/models/test_mistral.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 6baa6302b4..32041ab726 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -1719,12 +1719,12 @@ def test_map_setting_prediction_valid(request: pytest.FixtureRequest, input_valu assert isinstance(result, MistralPrediction) assert result.content == expected_content - def test_map_setting_prediction_unsupported_type(): - """Test that _map_setting_prediction raises RuntimeError for unsupported types.""" - with pytest.raises( - RuntimeError, match='Unsupported prediction type.*int.*Expected str, dict, or MistralPrediction' - ): - MistralModel._map_setting_prediction(123) # pyright: ignore[reportPrivateUsage, reportArgumentType] +def test_map_setting_prediction_unsupported_type(): + """Test that _map_setting_prediction raises RuntimeError for unsupported types.""" + with pytest.raises( + RuntimeError, match='Unsupported prediction type.*int.*Expected str, dict, or MistralPrediction' + ): + MistralModel._map_setting_prediction(123) # pyright: ignore[reportPrivateUsage, reportArgumentType] def test_generate_user_output_format_complex(mistral_api_key: str): From 2820f673e9109ba1885928b26d6deb264a3fea90 Mon Sep 17 00:00:00 2001 From: Tim Date: Sun, 9 Nov 2025 14:24:20 +0100 Subject: [PATCH 6/6] Add missing newline in test_mistral.py for readability in tests --- tests/models/test_mistral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 32041ab726..3d770023e2 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -1719,6 +1719,7 @@ def test_map_setting_prediction_valid(request: pytest.FixtureRequest, input_valu assert isinstance(result, MistralPrediction) assert result.content == expected_content + def test_map_setting_prediction_unsupported_type(): """Test that _map_setting_prediction raises RuntimeError for unsupported types.""" with pytest.raises(