|
2 | 2 |
|
3 | 3 | import pytest |
4 | 4 |
|
| 5 | +from strands import Agent, tool |
5 | 6 | from strands.agent.state import AgentState |
| 7 | +from strands.types.content import Messages |
| 8 | + |
| 9 | +from ...fixtures.mocked_model_provider import MockedModelProvider |
6 | 10 |
|
7 | 11 |
|
8 | 12 | def test_set_and_get(): |
@@ -109,3 +113,33 @@ def test_initial_state(): |
109 | 113 | assert state.get("key1") == "value1" |
110 | 114 | assert state.get("key2") == "value2" |
111 | 115 | assert state.get() == initial |
| 116 | + |
| 117 | + |
| 118 | +def test_agent_state_update_from_tool(): |
| 119 | + @tool |
| 120 | + def update_state(agent: Agent): |
| 121 | + agent.state.set("hello", "world") |
| 122 | + agent.state.set("foo", "baz") |
| 123 | + |
| 124 | + agent_messages: Messages = [ |
| 125 | + { |
| 126 | + "role": "assistant", |
| 127 | + "content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}], |
| 128 | + }, |
| 129 | + {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, |
| 130 | + ] |
| 131 | + mocked_model_provider = MockedModelProvider(agent_messages) |
| 132 | + |
| 133 | + agent = Agent( |
| 134 | + model=mocked_model_provider, |
| 135 | + tools=[update_state], |
| 136 | + state={"foo": "bar"}, |
| 137 | + ) |
| 138 | + |
| 139 | + assert agent.state.get("hello") is None |
| 140 | + assert agent.state.get("foo") == "bar" |
| 141 | + |
| 142 | + agent("Invoke Mocked!") |
| 143 | + |
| 144 | + assert agent.state.get("hello") == "world" |
| 145 | + assert agent.state.get("foo") == "baz" |
0 commit comments