Skip to content

Commit 37b8a67

Browse files
pgrayyWorkshop Participant
authored andcommitted
models - anthropic - async (#371)
1 parent fb40573 commit 37b8a67

File tree

3 files changed

+61
-26
lines changed

3 files changed

+61
-26
lines changed

src/strands/models/anthropic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_conf
7272
logger.debug("config=<%s> | initializing", self.config)
7373

7474
client_args = client_args or {}
75-
self.client = anthropic.Anthropic(**client_args)
75+
self.client = anthropic.AsyncAnthropic(**client_args)
7676

7777
@override
7878
def update_config(self, **model_config: Unpack[AnthropicConfig]) -> None: # type: ignore[override]
@@ -360,8 +360,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
360360
ModelThrottledException: If the request is throttled by Anthropic.
361361
"""
362362
try:
363-
with self.client.messages.stream(**request) as stream:
364-
for event in stream:
363+
async with self.client.messages.stream(**request) as stream:
364+
async for event in stream:
365365
if event.type in AnthropicModel.EVENT_TYPES:
366366
yield event.model_dump()
367367

tests-integ/test_model_anthropic.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from strands.models.anthropic import AnthropicModel
99

1010

11-
@pytest.fixture
11+
@pytest.fixture(scope="module")
1212
def model():
1313
return AnthropicModel(
1414
client_args={
@@ -19,7 +19,7 @@ def model():
1919
)
2020

2121

22-
@pytest.fixture
22+
@pytest.fixture(scope="module")
2323
def tools():
2424
@strands.tool
2525
def tool_time() -> str:
@@ -32,32 +32,67 @@ def tool_weather() -> str:
3232
return [tool_time, tool_weather]
3333

3434

35-
@pytest.fixture
35+
@pytest.fixture(scope="module")
3636
def system_prompt():
3737
return "You are an AI assistant."
3838

3939

40-
@pytest.fixture
40+
@pytest.fixture(scope="module")
4141
def agent(model, tools, system_prompt):
4242
return Agent(model=model, tools=tools, system_prompt=system_prompt)
4343

4444

45+
@pytest.fixture(scope="module")
46+
def weather():
47+
class Weather(BaseModel):
48+
"""Extracts the time and weather from the user's message with the exact strings."""
49+
50+
time: str
51+
weather: str
52+
53+
return Weather(time="12:00", weather="sunny")
54+
55+
4556
@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
46-
def test_agent(agent):
57+
def test_agent_invoke(agent):
4758
result = agent("What is the time and weather in New York?")
4859
text = result.message["content"][0]["text"].lower()
4960

5061
assert all(string in text for string in ["12:00", "sunny"])
5162

5263

5364
@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
54-
def test_structured_output(model):
55-
class Weather(BaseModel):
56-
time: str
57-
weather: str
65+
@pytest.mark.asyncio
66+
async def test_agent_invoke_async(agent):
67+
result = await agent.invoke_async("What is the time and weather in New York?")
68+
text = result.message["content"][0]["text"].lower()
69+
70+
assert all(string in text for string in ["12:00", "sunny"])
5871

59-
agent = Agent(model=model)
60-
result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
61-
assert isinstance(result, Weather)
62-
assert result.time == "12:00"
63-
assert result.weather == "sunny"
72+
73+
@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
74+
@pytest.mark.asyncio
75+
async def test_agent_stream_async(agent):
76+
stream = agent.stream_async("What is the time and weather in New York?")
77+
async for event in stream:
78+
_ = event
79+
80+
result = event["result"]
81+
text = result.message["content"][0]["text"].lower()
82+
83+
assert all(string in text for string in ["12:00", "sunny"])
84+
85+
86+
@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
87+
def test_structured_output(agent, weather):
88+
tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
89+
exp_weather = weather
90+
assert tru_weather == exp_weather
91+
92+
93+
@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
94+
@pytest.mark.asyncio
95+
async def test_agent_structured_output_async(agent, weather):
96+
tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
97+
exp_weather = weather
98+
assert tru_weather == exp_weather

tests/strands/models/test_anthropic.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
@pytest.fixture
1313
def anthropic_client():
14-
with unittest.mock.patch.object(strands.models.anthropic.anthropic, "Anthropic") as mock_client_cls:
14+
with unittest.mock.patch.object(strands.models.anthropic.anthropic, "AsyncAnthropic") as mock_client_cls:
1515
yield mock_client_cls.return_value
1616

1717

@@ -632,7 +632,7 @@ def test_format_chunk_unknown(model):
632632

633633

634634
@pytest.mark.asyncio
635-
async def test_stream(anthropic_client, model, alist):
635+
async def test_stream(anthropic_client, model, agenerator, alist):
636636
mock_event_1 = unittest.mock.Mock(
637637
type="message_start",
638638
dict=lambda: {"type": "message_start"},
@@ -653,9 +653,9 @@ async def test_stream(anthropic_client, model, alist):
653653
),
654654
)
655655

656-
mock_stream = unittest.mock.MagicMock()
657-
mock_stream.__iter__.return_value = iter([mock_event_1, mock_event_2, mock_event_3])
658-
anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream
656+
mock_context = unittest.mock.AsyncMock()
657+
mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3])
658+
anthropic_client.messages.stream.return_value = mock_context
659659

660660
request = {"model": "m1"}
661661
response = model.stream(request)
@@ -712,7 +712,7 @@ async def test_stream_bad_request_error(anthropic_client, model):
712712

713713

714714
@pytest.mark.asyncio
715-
async def test_structured_output(anthropic_client, model, test_output_model_cls, alist):
715+
async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist):
716716
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
717717

718718
events = [
@@ -756,9 +756,9 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls,
756756
),
757757
]
758758

759-
mock_stream = unittest.mock.MagicMock()
760-
mock_stream.__iter__.return_value = iter(events)
761-
anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream
759+
mock_context = unittest.mock.AsyncMock()
760+
mock_context.__aenter__.return_value = agenerator(events)
761+
anthropic_client.messages.stream.return_value = mock_context
762762

763763
stream = model.structured_output(test_output_model_cls, messages)
764764
events = await alist(stream)

0 commit comments

Comments
 (0)