11import unittest .mock
22
3+ import pydantic
34import pytest
45
56import strands
@@ -58,6 +59,15 @@ def system_prompt():
5859 return "You are a helpful assistant"
5960
6061
62+ @pytest .fixture
63+ def test_output_model_cls ():
64+ class TestOutputModel (pydantic .BaseModel ):
65+ name : str
66+ age : int
67+
68+ return TestOutputModel
69+
70+
6171def test__init__model_configs (mistral_client , model_id , max_tokens ):
6272 _ = mistral_client
6373
@@ -440,35 +450,24 @@ def test_stream_other_error(mistral_client, model):
440450 list (model .stream ({}))
441451
442452
443- def test_structured_output_success (mistral_client , model ):
444- from pydantic import BaseModel
445-
446- class TestModel (BaseModel ):
447- name : str
448- age : int
453+ def test_structured_output_success (mistral_client , model , test_output_model_cls ):
454+ messages = [{"role" : "user" , "content" : [{"text" : "Extract data" }]}]
449455
450- # Mock successful response
451456 mock_response = unittest .mock .Mock ()
452457 mock_response .choices = [unittest .mock .Mock ()]
453458 mock_response .choices [0 ].message .tool_calls = [unittest .mock .Mock ()]
454459 mock_response .choices [0 ].message .tool_calls [0 ].function .arguments = '{"name": "John", "age": 30}'
455460
456461 mistral_client .chat .complete .return_value = mock_response
457462
458- prompt = [{"role" : "user" , "content" : [{"text" : "Extract data" }]}]
459- result = model .structured_output (TestModel , prompt )
460-
461- assert isinstance (result , TestModel )
462- assert result .name == "John"
463- assert result .age == 30
464-
463+ stream = model .structured_output (test_output_model_cls , messages )
465464
466- def test_structured_output_no_tool_calls (mistral_client , model ):
467- from pydantic import BaseModel
465+ tru_result = list (stream )[- 1 ]
466+ exp_result = {"output" : test_output_model_cls (name = "John" , age = 30 )}
467+ assert tru_result == exp_result
468468
469- class TestModel (BaseModel ):
470- name : str
471469
470+ def test_structured_output_no_tool_calls (mistral_client , model , test_output_model_cls ):
472471 mock_response = unittest .mock .Mock ()
473472 mock_response .choices = [unittest .mock .Mock ()]
474473 mock_response .choices [0 ].message .tool_calls = None
@@ -478,15 +477,11 @@ class TestModel(BaseModel):
478477 prompt = [{"role" : "user" , "content" : [{"text" : "Extract data" }]}]
479478
480479 with pytest .raises (ValueError , match = "No tool calls found in response" ):
481- model .structured_output (TestModel , prompt )
482-
480+ stream = model .structured_output (test_output_model_cls , prompt )
481+ next ( stream )
483482
484- def test_structured_output_invalid_json (mistral_client , model ):
485- from pydantic import BaseModel
486-
487- class TestModel (BaseModel ):
488- name : str
489483
484+ def test_structured_output_invalid_json (mistral_client , model , test_output_model_cls ):
490485 mock_response = unittest .mock .Mock ()
491486 mock_response .choices = [unittest .mock .Mock ()]
492487 mock_response .choices [0 ].message .tool_calls = [unittest .mock .Mock ()]
@@ -497,4 +492,5 @@ class TestModel(BaseModel):
497492 prompt = [{"role" : "user" , "content" : [{"text" : "Extract data" }]}]
498493
499494 with pytest .raises (ValueError , match = "Failed to parse tool call arguments into model" ):
500- model .structured_output (TestModel , prompt )
495+ stream = model .structured_output (test_output_model_cls , prompt )
496+ next (stream )
0 commit comments