Skip to content

Commit b18104d

Browse files
authored
Strands Mock Model (#1551)
* Add strands to tox.ini * Add mock models for strands testing * Add simple test file to validate strands mocking
1 parent 59633ac commit b18104d

File tree

4 files changed

+287
-4
lines changed

4 files changed

+287
-4
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2010 New Relic, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Test setup derived from: https://github.com/strands-agents/sdk-python/blob/main/tests/fixtures/mocked_model_provider.py
16+
# strands Apache 2.0 license: https://github.com/strands-agents/sdk-python/blob/main/LICENSE
17+
18+
import json
19+
from typing import TypedDict
20+
21+
from strands.models import Model
22+
23+
24+
class RedactionMessage(TypedDict):
25+
redactedUserContent: str
26+
redactedAssistantContent: str
27+
28+
29+
class MockedModelProvider(Model):
30+
"""A mock implementation of the Model interface for testing purposes.
31+
32+
This class simulates a model provider by returning pre-defined agent responses
33+
in sequence. It implements the Model interface methods and provides functionality
34+
to stream mock responses as events.
35+
"""
36+
37+
def __init__(self, agent_responses):
38+
self.agent_responses = agent_responses
39+
self.index = 0
40+
41+
def format_chunk(self, event):
42+
return event
43+
44+
def format_request(self, messages, tool_specs=None, system_prompt=None):
45+
return None
46+
47+
def get_config(self):
48+
pass
49+
50+
def update_config(self, **model_config):
51+
pass
52+
53+
async def structured_output(self, output_model, prompt, system_prompt=None, **kwargs):
54+
pass
55+
56+
async def stream(self, messages, tool_specs=None, system_prompt=None):
57+
events = self.map_agent_message_to_events(self.agent_responses[self.index])
58+
for event in events:
59+
yield event
60+
61+
self.index += 1
62+
63+
def map_agent_message_to_events(self, agent_message):
64+
stop_reason = "end_turn"
65+
yield {"messageStart": {"role": "assistant"}}
66+
if agent_message.get("redactedAssistantContent"):
67+
yield {"redactContent": {"redactUserContentMessage": agent_message["redactedUserContent"]}}
68+
yield {"contentBlockStart": {"start": {}}}
69+
yield {"contentBlockDelta": {"delta": {"text": agent_message["redactedAssistantContent"]}}}
70+
yield {"contentBlockStop": {}}
71+
stop_reason = "guardrail_intervened"
72+
else:
73+
for content in agent_message["content"]:
74+
if "reasoningContent" in content:
75+
yield {"contentBlockStart": {"start": {}}}
76+
yield {"contentBlockDelta": {"delta": {"reasoningContent": content["reasoningContent"]}}}
77+
yield {"contentBlockStop": {}}
78+
if "text" in content:
79+
yield {"contentBlockStart": {"start": {}}}
80+
yield {"contentBlockDelta": {"delta": {"text": content["text"]}}}
81+
yield {"contentBlockStop": {}}
82+
if "toolUse" in content:
83+
stop_reason = "tool_use"
84+
yield {
85+
"contentBlockStart": {
86+
"start": {
87+
"toolUse": {
88+
"name": content["toolUse"]["name"],
89+
"toolUseId": content["toolUse"]["toolUseId"],
90+
}
91+
}
92+
}
93+
}
94+
yield {
95+
"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(content["toolUse"]["input"])}}}
96+
}
97+
yield {"contentBlockStop": {}}
98+
99+
yield {"messageStop": {"stopReason": stop_reason}}

tests/mlmodel_strands/conftest.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2010 New Relic, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from _mock_model_provider import MockedModelProvider
17+
from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture
18+
from testing_support.ml_testing_utils import set_trace_info
19+
20+
_default_settings = {
21+
"package_reporting.enabled": False, # Turn off package reporting for testing as it causes slowdowns.
22+
"transaction_tracer.explain_threshold": 0.0,
23+
"transaction_tracer.transaction_threshold": 0.0,
24+
"transaction_tracer.stack_trace_threshold": 0.0,
25+
"debug.log_data_collector_payloads": True,
26+
"debug.record_transaction_failure": True,
27+
"ai_monitoring.enabled": True,
28+
}
29+
30+
collector_agent_registration = collector_agent_registration_fixture(
31+
app_name="Python Agent Test (mlmodel_strands)", default_settings=_default_settings
32+
)
33+
34+
35+
@pytest.fixture
36+
def single_tool_model():
37+
model = MockedModelProvider(
38+
[
39+
{
40+
"role": "assistant",
41+
"content": [
42+
{"text": "Calling add_exclamation tool"},
43+
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Hello"}}},
44+
],
45+
},
46+
{"role": "assistant", "content": [{"text": "Success!"}]},
47+
]
48+
)
49+
return model
50+
51+
52+
@pytest.fixture
53+
def single_tool_model_error():
54+
model = MockedModelProvider(
55+
[
56+
{
57+
"role": "assistant",
58+
"content": [
59+
{"text": "Calling add_exclamation tool"},
60+
# Set arguments to an invalid type to trigger error in tool
61+
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": 12}}},
62+
],
63+
},
64+
{"role": "assistant", "content": [{"text": "Success!"}]},
65+
]
66+
)
67+
return model
68+
69+
70+
@pytest.fixture
71+
def multi_tool_model():
72+
model = MockedModelProvider(
73+
[
74+
{
75+
"role": "assistant",
76+
"content": [
77+
{"text": "Calling add_exclamation tool"},
78+
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Hello"}}},
79+
],
80+
},
81+
{
82+
"role": "assistant",
83+
"content": [
84+
{"text": "Calling compute_sum tool"},
85+
{"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 5, "b": 3}}},
86+
],
87+
},
88+
{
89+
"role": "assistant",
90+
"content": [
91+
{"text": "Calling add_exclamation tool"},
92+
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Goodbye"}}},
93+
],
94+
},
95+
{
96+
"role": "assistant",
97+
"content": [
98+
{"text": "Calling compute_sum tool"},
99+
{"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 123, "b": 2}}},
100+
],
101+
},
102+
{"role": "assistant", "content": [{"text": "Success!"}]},
103+
]
104+
)
105+
return model
106+
107+
108+
@pytest.fixture
109+
def multi_tool_model_error():
110+
model = MockedModelProvider(
111+
[
112+
{
113+
"role": "assistant",
114+
"content": [
115+
{"text": "Calling add_exclamation tool"},
116+
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Hello"}}},
117+
],
118+
},
119+
{
120+
"role": "assistant",
121+
"content": [
122+
{"text": "Calling compute_sum tool"},
123+
{"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 5, "b": 3}}},
124+
],
125+
},
126+
{
127+
"role": "assistant",
128+
"content": [
129+
{"text": "Calling add_exclamation tool"},
130+
{"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": "Goodbye"}}},
131+
],
132+
},
133+
{
134+
"role": "assistant",
135+
"content": [
136+
{"text": "Calling compute_sum tool"},
137+
# Set insufficient arguments to trigger error in tool
138+
{"toolUse": {"name": "compute_sum", "toolUseId": "123", "input": {"a": 123}}},
139+
],
140+
},
141+
{"role": "assistant", "content": [{"text": "Success!"}]},
142+
]
143+
)
144+
return model
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2010 New Relic, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from strands import Agent, tool
16+
17+
from newrelic.api.background_task import background_task
18+
19+
20+
# Example tool for testing purposes
21+
@tool
22+
def add_exclamation(message: str) -> str:
23+
return f"{message}!"
24+
25+
26+
# TODO: Remove this file once all real tests are in place
27+
28+
29+
@background_task()
30+
def test_simple_run_agent(set_trace_info, single_tool_model):
31+
set_trace_info()
32+
my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation])
33+
34+
response = my_agent("Run the tools.")
35+
assert response.message["content"][0]["text"] == "Success!"
36+
assert response.metrics.tool_metrics["add_exclamation"].success_count == 1

tox.ini

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ envlist =
182182
python-logger_structlog-{py38,py39,py310,py311,py312,py313,py314,pypy311}-structloglatest,
183183
python-mlmodel_autogen-{py310,py311,py312,py313,py314,pypy311}-autogen061,
184184
python-mlmodel_autogen-{py310,py311,py312,py313,py314,pypy311}-autogenlatest,
185+
python-mlmodel_strands-{py310,py311,py312,py313}-strandslatest,
185186
python-mlmodel_gemini-{py39,py310,py311,py312,py313,py314},
186187
python-mlmodel_langchain-{py39,py310,py311,py312,py313},
187188
;; Package not ready for Python 3.14 (type annotations not updated)
@@ -440,6 +441,8 @@ deps =
440441
mlmodel_langchain: faiss-cpu
441442
mlmodel_langchain: mock
442443
mlmodel_langchain: asyncio
444+
mlmodel_strands: strands-agents[openai]
445+
mlmodel_strands: strands-agents-tools
443446
logger_loguru-logurulatest: loguru
444447
logger_structlog-structloglatest: structlog
445448
messagebroker_pika-pikalatest: pika
@@ -510,6 +513,7 @@ changedir =
510513
application_celery: tests/application_celery
511514
component_djangorestframework: tests/component_djangorestframework
512515
component_flask_rest: tests/component_flask_rest
516+
component_graphenedjango: tests/component_graphenedjango
513517
component_graphqlserver: tests/component_graphqlserver
514518
component_tastypie: tests/component_tastypie
515519
coroutines_asyncio: tests/coroutines_asyncio
@@ -521,26 +525,26 @@ changedir =
521525
datastore_cassandradriver: tests/datastore_cassandradriver
522526
datastore_elasticsearch: tests/datastore_elasticsearch
523527
datastore_firestore: tests/datastore_firestore
524-
datastore_oracledb: tests/datastore_oracledb
525528
datastore_memcache: tests/datastore_memcache
529+
datastore_motor: tests/datastore_motor
526530
datastore_mysql: tests/datastore_mysql
527531
datastore_mysqldb: tests/datastore_mysqldb
532+
datastore_oracledb: tests/datastore_oracledb
528533
datastore_postgresql: tests/datastore_postgresql
529534
datastore_psycopg: tests/datastore_psycopg
530535
datastore_psycopg2: tests/datastore_psycopg2
531536
datastore_psycopg2cffi: tests/datastore_psycopg2cffi
532537
datastore_pylibmc: tests/datastore_pylibmc
533538
datastore_pymemcache: tests/datastore_pymemcache
534-
datastore_motor: tests/datastore_motor
535539
datastore_pymongo: tests/datastore_pymongo
536540
datastore_pymssql: tests/datastore_pymssql
537541
datastore_pymysql: tests/datastore_pymysql
538542
datastore_pyodbc: tests/datastore_pyodbc
539543
datastore_pysolr: tests/datastore_pysolr
540544
datastore_redis: tests/datastore_redis
541545
datastore_rediscluster: tests/datastore_rediscluster
542-
datastore_valkey: tests/datastore_valkey
543546
datastore_sqlite: tests/datastore_sqlite
547+
datastore_valkey: tests/datastore_valkey
544548
external_aiobotocore: tests/external_aiobotocore
545549
external_botocore: tests/external_botocore
546550
external_feedparser: tests/external_feedparser
@@ -561,7 +565,6 @@ changedir =
561565
framework_fastapi: tests/framework_fastapi
562566
framework_flask: tests/framework_flask
563567
framework_graphene: tests/framework_graphene
564-
component_graphenedjango: tests/component_graphenedjango
565568
framework_graphql: tests/framework_graphql
566569
framework_grpc: tests/framework_grpc
567570
framework_pyramid: tests/framework_pyramid
@@ -581,6 +584,7 @@ changedir =
581584
mlmodel_langchain: tests/mlmodel_langchain
582585
mlmodel_openai: tests/mlmodel_openai
583586
mlmodel_sklearn: tests/mlmodel_sklearn
587+
mlmodel_strands: tests/mlmodel_strands
584588
template_genshi: tests/template_genshi
585589
template_jinja2: tests/template_jinja2
586590
template_mako: tests/template_mako

0 commit comments

Comments
 (0)