88from strands .models .mistral import MistralModel
99
1010
11- @pytest .fixture
11+ @pytest .fixture ( scope = "module" )
1212def streaming_model ():
1313 return MistralModel (
1414 model_id = "mistral-medium-latest" ,
@@ -20,7 +20,7 @@ def streaming_model():
2020 )
2121
2222
23- @pytest .fixture
23+ @pytest .fixture ( scope = "module" )
2424def non_streaming_model ():
2525 return MistralModel (
2626 model_id = "mistral-medium-latest" ,
@@ -32,126 +32,101 @@ def non_streaming_model():
3232 )
3333
3434
35- @pytest .fixture
35+ @pytest .fixture ( scope = "module" )
3636def system_prompt ():
3737 return "You are an AI assistant that provides helpful and accurate information."
3838
3939
40- @pytest .fixture
41- def calculator_tool ():
42- @strands .tool
43- def calculator (expression : str ) -> float :
44- """Calculate the result of a mathematical expression."""
45- return eval (expression )
46-
47- return calculator
48-
49-
50- @pytest .fixture
51- def weather_tools ():
40+ @pytest .fixture (scope = "module" )
41+ def tools ():
5242 @strands .tool
5343 def tool_time () -> str :
54- """Get the current time."""
5544 return "12:00"
5645
5746 @strands .tool
5847 def tool_weather () -> str :
59- """Get the current weather."""
6048 return "sunny"
6149
6250 return [tool_time , tool_weather ]
6351
6452
65- @pytest .fixture
66- def streaming_agent (streaming_model ):
67- return Agent (model = streaming_model )
53+ @pytest .fixture ( scope = "module" )
54+ def streaming_agent (streaming_model , tools ):
55+ return Agent (model = streaming_model , tools = tools )
6856
6957
70- @pytest .fixture
71- def non_streaming_agent (non_streaming_model ):
72- return Agent (model = non_streaming_model )
58+ @pytest .fixture ( scope = "module" )
59+ def non_streaming_agent (non_streaming_model , tools ):
60+ return Agent (model = non_streaming_model , tools = tools )
7361
7462
75- @pytest .mark .skipif ("MISTRAL_API_KEY" not in os .environ , reason = "MISTRAL_API_KEY environment variable missing" )
76- def test_streaming_agent_basic (streaming_agent ):
77- """Test basic streaming agent functionality."""
78- result = streaming_agent ("Tell me about Agentic AI in one sentence." )
63+ @pytest .fixture (params = ["streaming_agent" , "non_streaming_agent" ])
64+ def agent (request ):
65+ return request .getfixturevalue (request .param )
7966
80- assert len (str (result )) > 0
81- assert hasattr (result , "message" )
82- assert "content" in result .message
8367
68+ @pytest .fixture (scope = "module" )
69+ def weather ():
70+ class Weather (BaseModel ):
71+ """Extracts the time and weather from the user's message with the exact strings."""
8472
85- @pytest .mark .skipif ("MISTRAL_API_KEY" not in os .environ , reason = "MISTRAL_API_KEY environment variable missing" )
86- def test_non_streaming_agent_basic (non_streaming_agent ):
87- """Test basic non-streaming agent functionality."""
88- result = non_streaming_agent ("Tell me about Agentic AI in one sentence." )
73+ time : str
74+ weather : str
8975
90- assert len (str (result )) > 0
91- assert hasattr (result , "message" )
92- assert "content" in result .message
76+ return Weather (time = "12:00" , weather = "sunny" )
9377
9478
9579@pytest .mark .skipif ("MISTRAL_API_KEY" not in os .environ , reason = "MISTRAL_API_KEY environment variable missing" )
96- def test_tool_use_streaming (streaming_model ):
97- """Test tool use with streaming model."""
98-
99- @strands .tool
100- def calculator (expression : str ) -> float :
101- """Calculate the result of a mathematical expression."""
102- return eval (expression )
103-
104- agent = Agent (model = streaming_model , tools = [calculator ])
105- result = agent ("What is the square root of 1764" )
80+ def test_agent_invoke (agent ):
81+ # TODO: https://github.com/strands-agents/sdk-python/issues/374
82+ # result = streaming_agent("What is the time and weather in New York?")
83+ result = agent ("What is the time in New York?" )
84+ text = result .message ["content" ][0 ]["text" ].lower ()
10685
107- # Verify the result contains the calculation
108- text_content = str (result ).lower ()
109- assert "42" in text_content
86+ # assert all(string in text for string in ["12:00", "sunny"])
87+ assert all (string in text for string in ["12:00" ])
11088
11189
11290@pytest .mark .skipif ("MISTRAL_API_KEY" not in os .environ , reason = "MISTRAL_API_KEY environment variable missing" )
113- def test_tool_use_non_streaming (non_streaming_model ):
114- """Test tool use with non-streaming model."""
91+ @pytest .mark .asyncio
92+ async def test_agent_invoke_async (agent ):
93+ # TODO: https://github.com/strands-agents/sdk-python/issues/374
94+ # result = await streaming_agent.invoke_async("What is the time and weather in New York?")
95+ result = await agent .invoke_async ("What is the time in New York?" )
96+ text = result .message ["content" ][0 ]["text" ].lower ()
11597
116- @strands .tool
117- def calculator (expression : str ) -> float :
118- """Calculate the result of a mathematical expression."""
119- return eval (expression )
120-
121- agent = Agent (model = non_streaming_model , tools = [calculator ], load_tools_from_directory = False )
122- result = agent ("What is the square root of 1764" )
123-
124- text_content = str (result ).lower ()
125- assert "42" in text_content
98+ # assert all(string in text for string in ["12:00", "sunny"])
99+ assert all (string in text for string in ["12:00" ])
126100
127101
128102@pytest .mark .skipif ("MISTRAL_API_KEY" not in os .environ , reason = "MISTRAL_API_KEY environment variable missing" )
129- def test_structured_output_streaming (streaming_model ):
130- """Test structured output with streaming model."""
131-
132- class Weather (BaseModel ):
133- time : str
134- weather : str
103+ @pytest .mark .asyncio
104+ async def test_agent_stream_async (agent ):
105+ # TODO: https://github.com/strands-agents/sdk-python/issues/374
106+ # stream = streaming_agent.stream_async("What is the time and weather in New York?")
107+ stream = agent .stream_async ("What is the time in New York?" )
108+ async for event in stream :
109+ _ = event
135110
136- agent = Agent ( model = streaming_model )
137- result = agent . structured_output ( Weather , "The time is 12:00 and the weather is sunny" )
111+ result = event [ "result" ]
112+ text = result . message [ "content" ][ 0 ][ "text" ]. lower ( )
138113
139- assert isinstance (result , Weather )
140- assert result .time == "12:00"
141- assert result .weather == "sunny"
114+ # assert all(string in text for string in ["12:00", "sunny"])
115+ assert all (string in text for string in ["12:00" ])
142116
143117
144118@pytest .mark .skipif ("MISTRAL_API_KEY" not in os .environ , reason = "MISTRAL_API_KEY environment variable missing" )
145- def test_structured_output_non_streaming (non_streaming_model ):
146- """Test structured output with non-streaming model."""
119+ def test_agent_structured_output (non_streaming_agent , weather ):
120+ tru_weather = non_streaming_agent .structured_output (type (weather ), "The time is 12:00 and the weather is sunny" )
121+ exp_weather = weather
122+ assert tru_weather == exp_weather
147123
148- class Weather (BaseModel ):
149- time : str
150- weather : str
151124
152- agent = Agent (model = non_streaming_model )
153- result = agent .structured_output (Weather , "The time is 12:00 and the weather is sunny" )
154-
155- assert isinstance (result , Weather )
156- assert result .time == "12:00"
157- assert result .weather == "sunny"
125+ @pytest .mark .skipif ("MISTRAL_API_KEY" not in os .environ , reason = "MISTRAL_API_KEY environment variable missing" )
126+ @pytest .mark .asyncio
127+ async def test_agent_structured_output_async (non_streaming_agent , weather ):
128+ tru_weather = await non_streaming_agent .structured_output_async (
129+ type (weather ), "The time is 12:00 and the weather is sunny"
130+ )
131+ exp_weather = weather
132+ assert tru_weather == exp_weather
0 commit comments