diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py
index 770c8ff6ca..e301e1c793 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -77,6 +77,10 @@
     )
     from mistralai.models.assistantmessage import AssistantMessage as MistralAssistantMessage
     from mistralai.models.function import Function as MistralFunction
+    from mistralai.models.prediction import (
+        Prediction as MistralPrediction,
+        PredictionTypedDict as MistralPredictionTypedDict,
+    )
     from mistralai.models.systemmessage import SystemMessage as MistralSystemMessage
     from mistralai.models.toolmessage import ToolMessage as MistralToolMessage
     from mistralai.models.usermessage import UserMessage as MistralUserMessage
@@ -114,8 +118,12 @@ class MistralModelSettings(ModelSettings, total=False):
     """Settings used for a Mistral model request."""

     # ALL FIELDS MUST BE `mistral_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    mistral_prediction: str | MistralPrediction | MistralPredictionTypedDict | None
+    """Prediction content for the model to use as a prefix. See [predicted outputs](https://docs.mistral.ai/capabilities/predicted_outputs).

-    # This class is a placeholder for any future mistral-specific settings
+    This feature is currently only supported by certain Mistral models; as of now, codestral-latest
+    and mistral-large-2411 support it. See the Mistral model cards for the current list.
+    """


 @dataclass(init=False)
@@ -241,6 +249,7 @@ async def _completions_create(
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                 random_seed=model_settings.get('seed', UNSET),
                 stop=model_settings.get('stop_sequences', None),
+                prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
                 http_headers={'User-Agent': get_user_agent()},
             )
         except SDKError as e:
@@ -281,6 +290,7 @@ async def _stream_completions_create(
             presence_penalty=model_settings.get('presence_penalty'),
             frequency_penalty=model_settings.get('frequency_penalty'),
             stop=model_settings.get('stop_sequences', None),
+            prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
             http_headers={'User-Agent': get_user_agent()},
         )

@@ -298,6 +308,7 @@ async def _stream_completions_create(
                 'type': 'json_object'
             },  # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9
             stream=True,
+            prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
             http_headers={'User-Agent': get_user_agent()},
         )

@@ -307,6 +318,7 @@ async def _stream_completions_create(
             model=str(self._model_name),
             messages=mistral_messages,
             stream=True,
+            prediction=self._map_setting_prediction(model_settings.get('mistral_prediction', None)),
             http_headers={'User-Agent': get_user_agent()},
         )
         assert response, 'A unexpected empty response from Mistral.'
@@ -427,6 +439,24 @@ def _map_tool_call(t: ToolCallPart) -> MistralToolCall:
             function=MistralFunctionCall(name=t.tool_name, arguments=t.args or {}),
         )

+    @staticmethod
+    def _map_setting_prediction(
+        prediction: str | MistralPredictionTypedDict | MistralPrediction | None,
+    ) -> MistralPrediction | None:
+        """Map the accepted prediction input types to a MistralPrediction object."""
+        if not prediction:
+            return None
+        if isinstance(prediction, MistralPrediction):
+            return prediction
+        elif isinstance(prediction, str):
+            return MistralPrediction(content=prediction)
+        elif isinstance(prediction, dict):
+            return MistralPrediction.model_validate(prediction)
+        else:
+            raise RuntimeError(
+                f'Unsupported prediction type: {type(prediction)} for MistralModelSettings. Expected str, dict, or MistralPrediction.'
+            )
+
     def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
         """Get a message with an example of the expected output format."""
         examples: list[dict[str, Any]] = []
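For reviewers: a minimal end-to-end sketch of the new setting, assuming `mistral-large-2411` and a `MISTRAL_API_KEY` environment variable (the prompt mirrors the new test below; this snippet is illustrative, not part of the diff):

```python
import os

from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel, MistralModelSettings
from pydantic_ai.providers.mistral import MistralProvider

# A plain string is accepted; _map_setting_prediction wraps it in a
# MistralPrediction before the chat request is sent.
model = MistralModel(
    'mistral-large-2411',
    provider=MistralProvider(api_key=os.environ['MISTRAL_API_KEY']),
)
settings = MistralModelSettings(mistral_prediction='The result of 21+21=99')
agent = Agent(model, model_settings=settings)

result = agent.run_sync('Correct the math, keep everything else. No explanation, no formatting.')
print(result.output)  # expected: the predicted prefix reused with the arithmetic fixed
```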
diff --git a/tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml b/tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml
new file mode 100644
index 0000000000..a2ace9839c
--- /dev/null
+++ b/tests/models/cassettes/test_mistral/test_mistral_chat_with_prediction.yaml
@@ -0,0 +1,70 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '315'
+      content-type:
+      - application/json
+      host:
+      - api.mistral.ai
+    method: POST
+    parsed_body:
+      messages:
+      - content:
+        - text: Correct the math, keep everything else. No explanation, no formatting.
+          type: text
+        - text: The result of 21+21=99
+          type: text
+        role: user
+      model: mistral-large-2411
+      n: 1
+      prediction:
+        content: The result of 21+21=99
+        type: content
+      stream: false
+      top_p: 1.0
+    uri: https://api.mistral.ai/v1/chat/completions
+  response:
+    headers:
+      access-control-allow-origin:
+      - '*'
+      alt-svc:
+      - h3=":443"; ma=86400
+      connection:
+      - keep-alive
+      content-length:
+      - '319'
+      content-type:
+      - application/json
+      mistral-correlation-id:
+      - 019a63b7-40ba-70cb-94d0-84f036d7c76f
+      strict-transport-security:
+      - max-age=15552000; includeSubDomains; preload
+      transfer-encoding:
+      - chunked
+    parsed_body:
+      choices:
+      - finish_reason: stop
+        index: 0
+        message:
+          content: The result of 21+21=42
+          role: assistant
+          tool_calls: null
+      created: 1762609545
+      id: 6c36e8b6c3c145bd8ada32f9bd0f6be9
+      model: mistral-large-2411
+      object: chat.completion
+      usage:
+        completion_tokens: 13
+        prompt_tokens: 33
+        total_tokens: 46
+    status:
+      code: 200
+      message: OK
+version: 1
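The cassette confirms the wire format: the setting is serialized under `prediction` as `{'type': 'content', 'content': ...}`. A quick sketch of how the three accepted input forms normalize to that payload (mirrors the unit tests below; values are illustrative):

```python
from mistralai.models.prediction import Prediction as MistralPrediction

from pydantic_ai.models.mistral import MistralModel

prefix = 'The result of 21+21=99'

# str, typed dict, and SDK model all map to the same MistralPrediction.
for value in (
    prefix,
    {'type': 'content', 'content': prefix},
    MistralPrediction(content=prefix),
):
    mapped = MistralModel._map_setting_prediction(value)  # pyright: ignore[reportPrivateUsage]
    assert mapped is not None and mapped.content == prefix
```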
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index 76ae344c5b..3d770023e2 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -53,6 +53,10 @@
         SDKError,
         ToolCall as MistralToolCall,
     )
+    from mistralai.models.prediction import (
+        Prediction as MistralPrediction,
+        PredictionTypedDict as MistralPredictionTypedDict,
+    )
     from mistralai.types.basemodel import Unset as MistralUnset

     from pydantic_ai.models.mistral import MistralModel, MistralStreamedResponse
@@ -1677,6 +1681,51 @@ async def get_location(loc_name: str) -> str:
 #####################
 ## Test methods
 #####################
+@pytest.fixture
+def example_dict() -> MistralPredictionTypedDict:
+    """Fixture providing a typed dict for prediction."""
+    return {'type': 'content', 'content': 'foo'}
+
+
+@pytest.fixture
+def example_prediction() -> MistralPrediction:
+    """Fixture providing a MistralPrediction object."""
+    return MistralPrediction(content='bar')
+
+
+@pytest.mark.parametrize(
+    'input_value,expected_content',
+    [
+        ('plain text', 'plain text'),
+        ('example_prediction', 'bar'),
+        ('example_dict', 'foo'),
+        (None, None),
+    ],
+)
+def test_map_setting_prediction_valid(request: pytest.FixtureRequest, input_value: str, expected_content: str | None):
+    """
+    Accepted input types (str, dict, MistralPrediction, None) must be correctly converted to a MistralPrediction or None.
+    """
+    # If the parameter is a fixture name, resolve it using request
+    resolved_value: str | MistralPredictionTypedDict | MistralPrediction | None = input_value
+    if isinstance(input_value, str) and input_value in {'example_dict', 'example_prediction'}:
+        resolved_value = request.getfixturevalue(input_value)
+
+    result = MistralModel._map_setting_prediction(resolved_value)  # pyright: ignore[reportPrivateUsage]
+
+    if resolved_value is None:
+        assert result is None
+    else:
+        assert isinstance(result, MistralPrediction)
+        assert result.content == expected_content
+
+
+def test_map_setting_prediction_unsupported_type():
+    """Test that _map_setting_prediction raises RuntimeError for unsupported types."""
+    with pytest.raises(
+        RuntimeError, match='Unsupported prediction type.*int.*Expected str, dict, or MistralPrediction'
+    ):
+        MistralModel._map_setting_prediction(123)  # pyright: ignore[reportPrivateUsage, reportArgumentType]


 def test_generate_user_output_format_complex(mistral_api_key: str):
@@ -2263,3 +2312,20 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist
         ),
     ]
 )
+
+
+@pytest.mark.vcr()
+async def test_mistral_chat_with_prediction(allow_model_requests: None, mistral_api_key: str):
+    """Test chat completion with prediction parameter using a math query."""
+    from pydantic_ai.models.mistral import MistralModelSettings
+
+    model = MistralModel('mistral-large-2411', provider=MistralProvider(api_key=mistral_api_key))
+    prediction = 'The result of 21+21=99'
+    settings = MistralModelSettings(mistral_prediction=prediction)
+    agent = Agent(model, model_settings=settings)
+
+    result = await agent.run(['Correct the math, keep everything else. No explanation, no formatting.', prediction])
+
+    # Verify the response keeps the predicted prefix and corrects the math
+    assert '42' in result.output
+    assert 'The result of 21+21=' in result.output
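Note that the mapped prediction is forwarded at all four call sites, including both streaming paths, so the setting also applies to streamed runs. A streaming sketch under the same assumptions as the first snippet:

```python
import asyncio
import os

from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel, MistralModelSettings
from pydantic_ai.providers.mistral import MistralProvider


async def main() -> None:
    model = MistralModel(
        'mistral-large-2411',
        provider=MistralProvider(api_key=os.environ['MISTRAL_API_KEY']),
    )
    settings = MistralModelSettings(mistral_prediction='The result of 21+21=99')
    agent = Agent(model, model_settings=settings)

    # _stream_completions_create passes the prediction on both streaming paths.
    async with agent.run_stream('Correct the math, keep everything else.') as run:
        async for text in run.stream_text():
            print(text)


asyncio.run(main())
```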