diff --git a/newrelic/api/error_trace.py b/newrelic/api/error_trace.py index db63c54316..aaa12b50e3 100644 --- a/newrelic/api/error_trace.py +++ b/newrelic/api/error_trace.py @@ -15,6 +15,7 @@ import functools from newrelic.api.time_trace import current_trace, notice_error +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object @@ -43,17 +44,31 @@ def __exit__(self, exc, value, tb): ) -def ErrorTraceWrapper(wrapped, ignore=None, expected=None, status_code=None): - def wrapper(wrapped, instance, args, kwargs): - parent = current_trace() +def ErrorTraceWrapper(wrapped, ignore=None, expected=None, status_code=None, async_wrapper=None): + def literal_wrapper(wrapped, instance, args, kwargs): + # Determine if the wrapped function is async or sync + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) + # Sync function path + if not wrapper: + parent = current_trace() + if not parent: + # No active tracing context so just call the wrapped function directly + return wrapped(*args, **kwargs) + # Async function path + else: + # For async functions, the async wrapper will handle trace context propagation + parent = None - if parent is None: - return wrapped(*args, **kwargs) + trace = ErrorTrace(ignore, expected, status_code, parent=parent) + + if wrapper: + # The async wrapper handles the context management for us + return wrapper(wrapped, trace)(*args, **kwargs) - with ErrorTrace(ignore, expected, status_code, parent=parent): + with trace: return wrapped(*args, **kwargs) - return FunctionWrapper(wrapped, wrapper) + return FunctionWrapper(wrapped, literal_wrapper) def error_trace(ignore=None, expected=None, status_code=None): diff --git a/newrelic/common/llm_utils.py b/newrelic/common/llm_utils.py new file mode 100644 index 0000000000..eebdacfc7f --- /dev/null +++ b/newrelic/common/llm_utils.py @@ -0,0 +1,24 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
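+# Shared helper for LLM instrumentation: collects llm.* custom attributes (and any
+# LLM context attributes) from the active transaction for attachment to LLM events.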
+ + +def _get_llm_metadata(transaction): + # Grab LLM-related custom attributes off of the transaction to store as metadata on LLM events + custom_attrs_dict = transaction._custom_params + llm_metadata_dict = {key: value for key, value in custom_attrs_dict.items() if key.startswith("llm.")} + llm_context_attrs = getattr(transaction, "_llm_context_attrs", None) + if llm_context_attrs: + llm_metadata_dict.update(llm_context_attrs) + + return llm_metadata_dict diff --git a/newrelic/config.py b/newrelic/config.py index 21ce996f6c..ff2d85e359 100644 --- a/newrelic/config.py +++ b/newrelic/config.py @@ -2946,6 +2946,13 @@ def _process_module_builtin_defaults(): "newrelic.hooks.mlmodel_autogen", "instrument_autogen_agentchat_agents__assistant_agent", ) + _process_module_definition("strands.agent.agent", "newrelic.hooks.mlmodel_strands", "instrument_agent_agent") + _process_module_definition( + "strands.tools.executors._executor", "newrelic.hooks.mlmodel_strands", "instrument_tools_executors__executor" + ) + _process_module_definition("strands.tools.registry", "newrelic.hooks.mlmodel_strands", "instrument_tools_registry") + _process_module_definition("strands.models.bedrock", "newrelic.hooks.mlmodel_strands", "instrument_models_bedrock") + _process_module_definition("mcp.client.session", "newrelic.hooks.adapter_mcp", "instrument_mcp_client_session") _process_module_definition( "mcp.server.fastmcp.tools.tool_manager", diff --git a/newrelic/hooks/mlmodel_strands.py b/newrelic/hooks/mlmodel_strands.py new file mode 100644 index 0000000000..bf849fd717 --- /dev/null +++ b/newrelic/hooks/mlmodel_strands.py @@ -0,0 +1,492 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +import uuid + +from newrelic.api.error_trace import ErrorTraceWrapper +from newrelic.api.function_trace import FunctionTrace +from newrelic.api.time_trace import current_trace, get_trace_linking_metadata +from newrelic.api.transaction import current_transaction +from newrelic.common.llm_utils import _get_llm_metadata +from newrelic.common.object_names import callable_name +from newrelic.common.object_wrapper import ObjectProxy, wrap_function_wrapper +from newrelic.common.package_version_utils import get_package_version +from newrelic.common.signature import bind_args +from newrelic.core.config import global_settings +from newrelic.core.context import ContextOf + +_logger = logging.getLogger(__name__) +STRANDS_VERSION = get_package_version("strands-agents") + +RECORD_EVENTS_FAILURE_LOG_MESSAGE = "Exception occurred in Strands instrumentation: Failed to record LLM events. Please report this issue to New Relic Support." +TOOL_OUTPUT_FAILURE_LOG_MESSAGE = "Exception occurred in Strands instrumentation: Failed to record output of tool call. Please report this issue to New Relic Support." +AGENT_EVENT_FAILURE_LOG_MESSAGE = "Exception occurred in Strands instrumentation: Failed to record agent data. Please report this issue to New Relic Support." 
+TOOL_EXTRACTOR_FAILURE_LOG_MESSAGE = "Exception occurred in Strands instrumentation: Failed to extract tool information. If the issue persists, report this issue to New Relic support.\n" + + +def wrap_agent__call__(wrapped, instance, args, kwargs): + trace = current_trace() + if not trace: + return wrapped(*args, **kwargs) + + try: + bound_args = bind_args(wrapped, args, kwargs) + # Make a copy of the invocation state before we mutate it + if "invocation_state" in bound_args: + invocation_state = bound_args["invocation_state"] = dict(bound_args["invocation_state"] or {}) + + # Attempt to save the current transaction context into the invocation state dictionary + invocation_state["_nr_transaction"] = trace + except Exception: + return wrapped(*args, **kwargs) + else: + return wrapped(**bound_args) + + +async def wrap_agent_invoke_async(wrapped, instance, args, kwargs): + # If there's already a transaction, don't propagate anything here + if current_transaction(): + return await wrapped(*args, **kwargs) + + try: + # Grab the trace context we should be running under and pass it to ContextOf + bound_args = bind_args(wrapped, args, kwargs) + invocation_state = bound_args["invocation_state"] or {} + trace = invocation_state.pop("_nr_transaction", None) + except Exception: + return await wrapped(*args, **kwargs) + + # If we find a transaction to propagate, use it. Otherwise, just call wrapped. + if trace: + with ContextOf(trace=trace): + return await wrapped(*args, **kwargs) + else: + return await wrapped(*args, **kwargs) + + +def wrap_stream_async(wrapped, instance, args, kwargs): + transaction = current_transaction() + if not transaction: + return wrapped(*args, **kwargs) + + settings = transaction.settings or global_settings() + if not settings.ai_monitoring.enabled: + return wrapped(*args, **kwargs) + + # Framework metric also used for entity tagging in the UI + transaction.add_ml_model_info("Strands", STRANDS_VERSION) + transaction._add_agent_attribute("llm", True) + + func_name = callable_name(wrapped) + agent_name = getattr(instance, "name", "agent") + function_trace_name = f"{func_name}/{agent_name}" + + ft = FunctionTrace(name=function_trace_name, group="Llm/agent/Strands") + ft.__enter__() + linking_metadata = get_trace_linking_metadata() + agent_id = str(uuid.uuid4()) + + try: + return_val = wrapped(*args, **kwargs) + except Exception: + raise + + # For streaming responses, wrap with proxy and attach metadata + try: + # For streaming responses, wrap with proxy and attach metadata + proxied_return_val = AsyncGeneratorProxy( + return_val, _record_agent_event_on_stop_iteration, _handle_agent_streaming_completion_error + ) + proxied_return_val._nr_ft = ft + proxied_return_val._nr_metadata = linking_metadata + proxied_return_val._nr_strands_attrs = {"agent_name": agent_name, "agent_id": agent_id} + return proxied_return_val + except Exception: + # If proxy creation fails, clean up the function trace and return original value + ft.__exit__(*sys.exc_info()) + return return_val + + +def _record_agent_event_on_stop_iteration(self, transaction): + if hasattr(self, "_nr_ft"): + # Use saved linking metadata to maintain correct span association + linking_metadata = self._nr_metadata or get_trace_linking_metadata() + self._nr_ft.__exit__(None, None, None) + + try: + strands_attrs = getattr(self, "_nr_strands_attrs", {}) + + # If there are no strands attrs exit early as there's no data to record. 
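+            # (These attrs are cleared once recorded, so repeated exhaustion of the proxy will not double-report.)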
+ if not strands_attrs: + return + + agent_name = strands_attrs.get("agent_name", "agent") + agent_id = strands_attrs.get("agent_id") + agent_event_dict = _construct_base_agent_event_dict(agent_name, agent_id, transaction, linking_metadata) + agent_event_dict["duration"] = self._nr_ft.duration * 1000 + transaction.record_custom_event("LlmAgent", agent_event_dict) + + except Exception: + _logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True) + finally: + # Clear cached data to prevent memory leaks and duplicate reporting + if hasattr(self, "_nr_strands_attrs"): + self._nr_strands_attrs.clear() + + +def _record_tool_event_on_stop_iteration(self, transaction): + if hasattr(self, "_nr_ft"): + # Use saved linking metadata to maintain correct span association + linking_metadata = self._nr_metadata or get_trace_linking_metadata() + self._nr_ft.__exit__(None, None, None) + + try: + strands_attrs = getattr(self, "_nr_strands_attrs", {}) + + # If there are no strands attrs exit early as there's no data to record. + if not strands_attrs: + return + + try: + tool_results = strands_attrs.get("tool_results", []) + except Exception: + tool_results = None + _logger.warning(TOOL_OUTPUT_FAILURE_LOG_MESSAGE, exc_info=True) + + tool_event_dict = _construct_base_tool_event_dict( + strands_attrs, tool_results, transaction, linking_metadata + ) + tool_event_dict["duration"] = self._nr_ft.duration * 1000 + transaction.record_custom_event("LlmTool", tool_event_dict) + + except Exception: + _logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True) + finally: + # Clear cached data to prevent memory leaks and duplicate reporting + if hasattr(self, "_nr_strands_attrs"): + self._nr_strands_attrs.clear() + + +def _construct_base_tool_event_dict(strands_attrs, tool_results, transaction, linking_metadata): + try: + try: + tool_output = tool_results[-1]["content"][0] if tool_results else None + error = tool_results[-1]["status"] == "error" + except Exception: + tool_output = None + error = False + _logger.warning(TOOL_OUTPUT_FAILURE_LOG_MESSAGE, exc_info=True) + + tool_name = strands_attrs.get("tool_name", "tool") + tool_id = strands_attrs.get("tool_id") + run_id = strands_attrs.get("run_id") + tool_input = strands_attrs.get("tool_input") + agent_name = strands_attrs.get("agent_name", "agent") + settings = transaction.settings or global_settings() + + tool_event_dict = { + "id": tool_id, + "run_id": run_id, + "name": tool_name, + "span_id": linking_metadata.get("span.id"), + "trace_id": linking_metadata.get("trace.id"), + "agent_name": agent_name, + "vendor": "strands", + "ingest_source": "Python", + } + # Set error flag if the status shows an error was caught, + # it will be reported further down in the instrumentation. 
+ if error: + tool_event_dict["error"] = True + + if settings.ai_monitoring.record_content.enabled: + tool_event_dict["input"] = tool_input + # In error cases, the output will hold the error message + tool_event_dict["output"] = tool_output + tool_event_dict.update(_get_llm_metadata(transaction)) + except Exception: + tool_event_dict = {} + _logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True) + + return tool_event_dict + + +def _construct_base_agent_event_dict(agent_name, agent_id, transaction, linking_metadata): + try: + agent_event_dict = { + "id": agent_id, + "name": agent_name, + "span_id": linking_metadata.get("span.id"), + "trace_id": linking_metadata.get("trace.id"), + "vendor": "strands", + "ingest_source": "Python", + } + agent_event_dict.update(_get_llm_metadata(transaction)) + except Exception: + _logger.warning(AGENT_EVENT_FAILURE_LOG_MESSAGE, exc_info=True) + agent_event_dict = {} + + return agent_event_dict + + +def _handle_agent_streaming_completion_error(self, transaction): + if hasattr(self, "_nr_ft"): + strands_attrs = getattr(self, "_nr_strands_attrs", {}) + + # If there are no strands attrs exit early as there's no data to record. + if not strands_attrs: + self._nr_ft.__exit__(*sys.exc_info()) + return + + # Use saved linking metadata to maintain correct span association + linking_metadata = self._nr_metadata or get_trace_linking_metadata() + + try: + agent_name = strands_attrs.get("agent_name", "agent") + agent_id = strands_attrs.get("agent_id") + + # Notice the error on the function trace + self._nr_ft.notice_error(attributes={"agent_id": agent_id}) + self._nr_ft.__exit__(*sys.exc_info()) + + # Create error event + agent_event_dict = _construct_base_agent_event_dict(agent_name, agent_id, transaction, linking_metadata) + agent_event_dict.update({"duration": self._nr_ft.duration * 1000, "error": True}) + transaction.record_custom_event("LlmAgent", agent_event_dict) + + except Exception: + _logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True) + finally: + # Clear cached data to prevent memory leaks + if hasattr(self, "_nr_strands_attrs"): + self._nr_strands_attrs.clear() + + +def _handle_tool_streaming_completion_error(self, transaction): + if hasattr(self, "_nr_ft"): + strands_attrs = getattr(self, "_nr_strands_attrs", {}) + + # If there are no strands attrs exit early as there's no data to record. + if not strands_attrs: + self._nr_ft.__exit__(*sys.exc_info()) + return + + # Use saved linking metadata to maintain correct span association + linking_metadata = self._nr_metadata or get_trace_linking_metadata() + + try: + tool_id = strands_attrs.get("tool_id") + + # We expect this to never have any output since this is an error case, + # but if it does we will report it. 
+ try: + tool_results = strands_attrs.get("tool_results", []) + except Exception: + tool_results = None + _logger.warning(TOOL_OUTPUT_FAILURE_LOG_MESSAGE, exc_info=True) + + # Notice the error on the function trace + self._nr_ft.notice_error(attributes={"tool_id": tool_id}) + self._nr_ft.__exit__(*sys.exc_info()) + + # Create error event + tool_event_dict = _construct_base_tool_event_dict( + strands_attrs, tool_results, transaction, linking_metadata + ) + tool_event_dict["duration"] = self._nr_ft.duration * 1000 + # Ensure error flag is set to True in case the tool_results did not indicate an error + if "error" not in tool_event_dict: + tool_event_dict["error"] = True + + transaction.record_custom_event("LlmTool", tool_event_dict) + + except Exception: + _logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True) + finally: + # Clear cached data to prevent memory leaks + if hasattr(self, "_nr_strands_attrs"): + self._nr_strands_attrs.clear() + + +def wrap_tool_executor__stream(wrapped, instance, args, kwargs): + transaction = current_transaction() + if not transaction: + return wrapped(*args, **kwargs) + + settings = transaction.settings or global_settings() + if not settings.ai_monitoring.enabled: + return wrapped(*args, **kwargs) + + # Framework metric also used for entity tagging in the UI + transaction.add_ml_model_info("Strands", STRANDS_VERSION) + transaction._add_agent_attribute("llm", True) + + # Grab tool data + try: + bound_args = bind_args(wrapped, args, kwargs) + agent_name = getattr(bound_args.get("agent"), "name", "agent") + tool_use = bound_args.get("tool_use", {}) + + run_id = tool_use.get("toolUseId", "") + tool_name = tool_use.get("name", "tool") + _input = tool_use.get("input") + tool_input = str(_input) if _input else None + tool_results = bound_args.get("tool_results", []) + except Exception: + tool_name = "tool" + _logger.warning(TOOL_EXTRACTOR_FAILURE_LOG_MESSAGE, exc_info=True) + + func_name = callable_name(wrapped) + function_trace_name = f"{func_name}/{tool_name}" + + ft = FunctionTrace(name=function_trace_name, group="Llm/tool/Strands") + ft.__enter__() + linking_metadata = get_trace_linking_metadata() + tool_id = str(uuid.uuid4()) + + try: + return_val = wrapped(*args, **kwargs) + except Exception: + raise + + try: + # Wrap return value with proxy and attach metadata for later access + proxied_return_val = AsyncGeneratorProxy( + return_val, _record_tool_event_on_stop_iteration, _handle_tool_streaming_completion_error + ) + proxied_return_val._nr_ft = ft + proxied_return_val._nr_metadata = linking_metadata + proxied_return_val._nr_strands_attrs = { + "tool_results": tool_results, + "tool_name": tool_name, + "tool_id": tool_id, + "run_id": run_id, + "tool_input": tool_input, + "agent_name": agent_name, + } + return proxied_return_val + except Exception: + # If proxy creation fails, clean up the function trace and return original value + ft.__exit__(*sys.exc_info()) + return return_val + + +class AsyncGeneratorProxy(ObjectProxy): + def __init__(self, wrapped, on_stop_iteration, on_error): + super().__init__(wrapped) + self._nr_on_stop_iteration = on_stop_iteration + self._nr_on_error = on_error + + def __aiter__(self): + self._nr_wrapped_iter = self.__wrapped__.__aiter__() + return self + + async def __anext__(self): + transaction = current_transaction() + if not transaction: + return await self._nr_wrapped_iter.__anext__() + + return_val = None + try: + return_val = await self._nr_wrapped_iter.__anext__() + except StopAsyncIteration: + 
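+            # Generator exhausted normally: finalize the function trace, record the LLM event, then re-raise so iteration ends.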
self._nr_on_stop_iteration(self, transaction) + raise + except Exception: + self._nr_on_error(self, transaction) + raise + return return_val + + async def aclose(self): + return await super().aclose() + + +def wrap_ToolRegister_register_tool(wrapped, instance, args, kwargs): + bound_args = bind_args(wrapped, args, kwargs) + bound_args["tool"]._tool_func = ErrorTraceWrapper(bound_args["tool"]._tool_func) + return wrapped(*args, **kwargs) + + +def wrap_bedrock_model_stream(wrapped, instance, args, kwargs): + """Stores trace context on the messages argument to be retrieved by the _stream() instrumentation.""" + trace = current_trace() + if not trace: + return wrapped(*args, **kwargs) + + settings = trace.settings or global_settings() + if not settings.ai_monitoring.enabled: + return wrapped(*args, **kwargs) + + try: + bound_args = bind_args(wrapped, args, kwargs) + except Exception: + return wrapped(*args, **kwargs) + + if "messages" in bound_args and isinstance(bound_args["messages"], list): + bound_args["messages"].append({"newrelic_trace": trace}) + + return wrapped(*args, **kwargs) + + +def wrap_bedrock_model__stream(wrapped, instance, args, kwargs): + """Retrieves trace context stored on the messages argument and propagates it to the new thread.""" + try: + bound_args = bind_args(wrapped, args, kwargs) + except Exception: + return wrapped(*args, **kwargs) + + if ( + "messages" in bound_args + and isinstance(bound_args["messages"], list) + and bound_args["messages"] # non-empty list + and "newrelic_trace" in bound_args["messages"][-1] + ): + trace_message = bound_args["messages"].pop() + with ContextOf(trace=trace_message["newrelic_trace"]): + return wrapped(*args, **kwargs) + + return wrapped(*args, **kwargs) + + +def instrument_agent_agent(module): + if hasattr(module, "Agent"): + if hasattr(module.Agent, "__call__"): # noqa: B004 + wrap_function_wrapper(module, "Agent.__call__", wrap_agent__call__) + if hasattr(module.Agent, "invoke_async"): + wrap_function_wrapper(module, "Agent.invoke_async", wrap_agent_invoke_async) + if hasattr(module.Agent, "stream_async"): + wrap_function_wrapper(module, "Agent.stream_async", wrap_stream_async) + + +def instrument_tools_executors__executor(module): + if hasattr(module, "ToolExecutor"): + if hasattr(module.ToolExecutor, "_stream"): + wrap_function_wrapper(module, "ToolExecutor._stream", wrap_tool_executor__stream) + + +def instrument_tools_registry(module): + if hasattr(module, "ToolRegistry"): + if hasattr(module.ToolRegistry, "register_tool"): + wrap_function_wrapper(module, "ToolRegistry.register_tool", wrap_ToolRegister_register_tool) + + +def instrument_models_bedrock(module): + # This instrumentation only exists to pass trace context due to bedrock models using a separate thread. 
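+    # BedrockModel.stream() stashes the active trace on the outgoing messages list; BedrockModel._stream(),
+    # which runs on a separate thread, pops it back off and restores the context via ContextOf so spans stay linked.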
+ if hasattr(module, "BedrockModel"): + if hasattr(module.BedrockModel, "stream"): + wrap_function_wrapper(module, "BedrockModel.stream", wrap_bedrock_model_stream) + if hasattr(module.BedrockModel, "_stream"): + wrap_function_wrapper(module, "BedrockModel._stream", wrap_bedrock_model__stream) diff --git a/tests/mlmodel_strands/_mock_model_provider.py b/tests/mlmodel_strands/_mock_model_provider.py index e4c9e79930..ef60e13bad 100644 --- a/tests/mlmodel_strands/_mock_model_provider.py +++ b/tests/mlmodel_strands/_mock_model_provider.py @@ -41,7 +41,7 @@ def __init__(self, agent_responses): def format_chunk(self, event): return event - def format_request(self, messages, tool_specs=None, system_prompt=None): + def format_request(self, messages, tool_specs=None, system_prompt=None, **kwargs): return None def get_config(self): @@ -53,7 +53,7 @@ def update_config(self, **model_config): async def structured_output(self, output_model, prompt, system_prompt=None, **kwargs): pass - async def stream(self, messages, tool_specs=None, system_prompt=None): + async def stream(self, messages, tool_specs=None, system_prompt=None, **kwargs): events = self.map_agent_message_to_events(self.agent_responses[self.index]) for event in events: yield event diff --git a/tests/mlmodel_strands/conftest.py b/tests/mlmodel_strands/conftest.py index b810161f6a..a2ad9b8dd0 100644 --- a/tests/mlmodel_strands/conftest.py +++ b/tests/mlmodel_strands/conftest.py @@ -14,6 +14,7 @@ import pytest from _mock_model_provider import MockedModelProvider +from testing_support.fixture.event_loop import event_loop as loop from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture from testing_support.ml_testing_utils import set_trace_info @@ -50,15 +51,33 @@ def single_tool_model(): @pytest.fixture -def single_tool_model_error(): +def single_tool_model_runtime_error_coro(): model = MockedModelProvider( [ { "role": "assistant", "content": [ - {"text": "Calling add_exclamation tool"}, + {"text": "Calling throw_exception_coro tool"}, + # Set arguments to an invalid type to trigger error in tool + {"toolUse": {"name": "throw_exception_coro", "toolUseId": "123", "input": {"message": "Hello"}}}, + ], + }, + {"role": "assistant", "content": [{"text": "Success!"}]}, + ] + ) + return model + + +@pytest.fixture +def single_tool_model_runtime_error_agen(): + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + {"text": "Calling throw_exception_agen tool"}, # Set arguments to an invalid type to trigger error in tool - {"toolUse": {"name": "add_exclamation", "toolUseId": "123", "input": {"message": 12}}}, + {"toolUse": {"name": "throw_exception_agen", "toolUseId": "123", "input": {"message": "Hello"}}}, ], }, {"role": "assistant", "content": [{"text": "Success!"}]}, diff --git a/tests/mlmodel_strands/test_agent.py b/tests/mlmodel_strands/test_agent.py new file mode 100644 index 0000000000..af685668ad --- /dev/null +++ b/tests/mlmodel_strands/test_agent.py @@ -0,0 +1,427 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from strands import Agent, tool +from testing_support.fixtures import reset_core_stats_engine, validate_attributes +from testing_support.ml_testing_utils import ( + disabled_ai_monitoring_record_content_settings, + disabled_ai_monitoring_settings, + events_with_context_attrs, + tool_events_sans_content, +) +from testing_support.validators.validate_custom_event import validate_custom_event_count +from testing_support.validators.validate_custom_events import validate_custom_events +from testing_support.validators.validate_error_trace_attributes import validate_error_trace_attributes +from testing_support.validators.validate_transaction_error_event_count import validate_transaction_error_event_count +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics + +from newrelic.api.background_task import background_task +from newrelic.api.llm_custom_attributes import WithLlmCustomAttributes +from newrelic.common.object_names import callable_name +from newrelic.common.object_wrapper import transient_function_wrapper + +tool_recorded_event = [ + ( + {"type": "LlmTool"}, + { + "id": None, + "run_id": "123", + "output": "{'text': 'Hello!'}", + "name": "add_exclamation", + "agent_name": "my_agent", + "span_id": None, + "trace_id": "trace-id", + "input": "{'message': 'Hello'}", + "vendor": "strands", + "ingest_source": "Python", + "duration": None, + }, + ) +] + +tool_recorded_event_forced_internal_error = [ + ( + {"type": "LlmTool"}, + { + "id": None, + "run_id": "123", + "name": "add_exclamation", + "agent_name": "my_agent", + "span_id": None, + "trace_id": "trace-id", + "input": "{'message': 'Hello'}", + "vendor": "strands", + "ingest_source": "Python", + "duration": None, + "error": True, + }, + ) +] + +tool_recorded_event_error_coro = [ + ( + {"type": "LlmTool"}, + { + "id": None, + "run_id": "123", + "name": "throw_exception_coro", + "agent_name": "my_agent", + "span_id": None, + "trace_id": "trace-id", + "input": "{'message': 'Hello'}", + "vendor": "strands", + "ingest_source": "Python", + "error": True, + "output": "{'text': 'Error: RuntimeError - Oops'}", + "duration": None, + }, + ) +] + + +tool_recorded_event_error_agen = [ + ( + {"type": "LlmTool"}, + { + "id": None, + "run_id": "123", + "name": "throw_exception_agen", + "agent_name": "my_agent", + "span_id": None, + "trace_id": "trace-id", + "input": "{'message': 'Hello'}", + "vendor": "strands", + "ingest_source": "Python", + "error": True, + "output": "{'text': 'Error: RuntimeError - Oops'}", + "duration": None, + }, + ) +] + + +agent_recorded_event = [ + ( + {"type": "LlmAgent"}, + { + "id": None, + "name": "my_agent", + "span_id": None, + "trace_id": "trace-id", + "vendor": "strands", + "ingest_source": "Python", + "duration": None, + }, + ) +] + +agent_recorded_event_error = [ + ( + {"type": "LlmAgent"}, + { + "id": None, + "name": "my_agent", + "span_id": None, + "trace_id": "trace-id", + "vendor": "strands", + "ingest_source": "Python", + "error": True, + "duration": None, + }, + ) +] + + +# Example tool for testing purposes +@tool +async def add_exclamation(message: str) -> str: + return f"{message}!" 
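+# The tools below raise intentionally so the tests can exercise the coroutine and async-generator tool error paths.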
+ + +@tool +async def throw_exception_coro(message: str) -> str: + raise RuntimeError("Oops") + + +@tool +async def throw_exception_agen(message: str) -> str: + raise RuntimeError("Oops") + yield + + +@reset_core_stats_engine() +@validate_custom_events(events_with_context_attrs(tool_recorded_event)) +@validate_custom_events(events_with_context_attrs(agent_recorded_event)) +@validate_custom_event_count(count=2) +@validate_transaction_metrics( + "test_agent:test_agent_invoke", + scoped_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/add_exclamation", 1), + ], + rollup_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/add_exclamation", 1), + ], + background_task=True, +) +@validate_attributes("agent", ["llm"]) +@background_task() +def test_agent_invoke(set_trace_info, single_tool_model): + set_trace_info() + my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation]) + + with WithLlmCustomAttributes({"context": "attr"}): + response = my_agent('Add an exclamation to the word "Hello"') + assert response.message["content"][0]["text"] == "Success!" + assert response.metrics.tool_metrics["add_exclamation"].success_count == 1 + + +@reset_core_stats_engine() +@validate_custom_events(tool_recorded_event) +@validate_custom_events(agent_recorded_event) +@validate_custom_event_count(count=2) +@validate_transaction_metrics( + "test_agent:test_agent_invoke_async", + scoped_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/add_exclamation", 1), + ], + rollup_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/add_exclamation", 1), + ], + background_task=True, +) +@validate_attributes("agent", ["llm"]) +@background_task() +def test_agent_invoke_async(loop, set_trace_info, single_tool_model): + set_trace_info() + my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation]) + + async def _test(): + response = await my_agent.invoke_async('Add an exclamation to the word "Hello"') + assert response.message["content"][0]["text"] == "Success!" 
+ assert response.metrics.tool_metrics["add_exclamation"].success_count == 1 + + loop.run_until_complete(_test()) + + +@reset_core_stats_engine() +@validate_custom_events(tool_recorded_event) +@validate_custom_events(agent_recorded_event) +@validate_custom_event_count(count=2) +@validate_transaction_metrics( + "test_agent:test_agent_stream_async", + scoped_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/add_exclamation", 1), + ], + rollup_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/add_exclamation", 1), + ], + background_task=True, +) +@validate_attributes("agent", ["llm"]) +@background_task() +def test_agent_stream_async(loop, set_trace_info, single_tool_model): + set_trace_info() + my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation]) + + async def _test(): + response = my_agent.stream_async('Add an exclamation to the word "Hello"') + messages = [event["message"]["content"] async for event in response if "message" in event] + + assert len(messages) == 3 + assert messages[0][0]["text"] == "Calling add_exclamation tool" + assert messages[0][1]["toolUse"]["name"] == "add_exclamation" + assert messages[1][0]["toolResult"]["content"][0]["text"] == "Hello!" + assert messages[2][0]["text"] == "Success!" + + loop.run_until_complete(_test()) + + +@reset_core_stats_engine() +@disabled_ai_monitoring_record_content_settings +@validate_custom_events(agent_recorded_event) +@validate_custom_events(tool_events_sans_content(tool_recorded_event)) +@validate_custom_event_count(count=2) +@validate_transaction_metrics( + "test_agent:test_agent_invoke_no_content", + scoped_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/add_exclamation", 1), + ], + rollup_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/add_exclamation", 1), + ], + background_task=True, +) +@validate_attributes("agent", ["llm"]) +@background_task() +def test_agent_invoke_no_content(set_trace_info, single_tool_model): + set_trace_info() + my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation]) + + response = my_agent('Add an exclamation to the word "Hello"') + assert response.message["content"][0]["text"] == "Success!" + assert response.metrics.tool_metrics["add_exclamation"].success_count == 1 + + +@disabled_ai_monitoring_settings +@reset_core_stats_engine() +@validate_custom_event_count(count=0) +@background_task() +def test_agent_invoke_disabled_ai_monitoring_events(set_trace_info, single_tool_model): + set_trace_info() + my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation]) + + response = my_agent('Add an exclamation to the word "Hello"') + assert response.message["content"][0]["text"] == "Success!" 
+ assert response.metrics.tool_metrics["add_exclamation"].success_count == 1 + + +@reset_core_stats_engine() +@validate_transaction_error_event_count(1) +@validate_error_trace_attributes(callable_name(ValueError), exact_attrs={"agent": {}, "intrinsic": {}, "user": {}}) +@validate_custom_events(agent_recorded_event_error) +@validate_custom_event_count(count=1) +@validate_transaction_metrics( + "test_agent:test_agent_invoke_error", + scoped_metrics=[("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1)], + rollup_metrics=[("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1)], + background_task=True, +) +@validate_attributes("agent", ["llm"]) +@background_task() +def test_agent_invoke_error(set_trace_info, single_tool_model): + # Add a wrapper to intentionally force an error in the Agent code + @transient_function_wrapper("strands.agent.agent", "Agent._convert_prompt_to_messages") + def _wrap_convert_prompt_to_messages(wrapped, instance, args, kwargs): + raise ValueError("Oops") + + @_wrap_convert_prompt_to_messages + def _test(): + set_trace_info() + my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation]) + my_agent('Add an exclamation to the word "Hello"') # raises ValueError + + with pytest.raises(ValueError): + _test() + + +@reset_core_stats_engine() +@validate_transaction_error_event_count(1) +@validate_error_trace_attributes(callable_name(RuntimeError), exact_attrs={"agent": {}, "intrinsic": {}, "user": {}}) +@validate_custom_events(tool_recorded_event_error_coro) +@validate_custom_event_count(count=2) +@validate_transaction_metrics( + "test_agent:test_agent_invoke_tool_coro_runtime_error", + scoped_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/throw_exception_coro", 1), + ], + rollup_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/throw_exception_coro", 1), + ], + background_task=True, +) +@validate_attributes("agent", ["llm"]) +@background_task() +def test_agent_invoke_tool_coro_runtime_error(set_trace_info, single_tool_model_runtime_error_coro): + set_trace_info() + my_agent = Agent(name="my_agent", model=single_tool_model_runtime_error_coro, tools=[throw_exception_coro]) + + response = my_agent('Add an exclamation to the word "Hello"') + assert response.message["content"][0]["text"] == "Success!" 
+ assert response.metrics.tool_metrics["throw_exception_coro"].error_count == 1 + + +@reset_core_stats_engine() +@validate_transaction_error_event_count(1) +@validate_error_trace_attributes(callable_name(RuntimeError), exact_attrs={"agent": {}, "intrinsic": {}, "user": {}}) +@validate_custom_events(tool_recorded_event_error_agen) +@validate_custom_event_count(count=2) +@validate_transaction_metrics( + "test_agent:test_agent_invoke_tool_agen_runtime_error", + scoped_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/throw_exception_agen", 1), + ], + rollup_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/throw_exception_agen", 1), + ], + background_task=True, +) +@validate_attributes("agent", ["llm"]) +@background_task() +def test_agent_invoke_tool_agen_runtime_error(set_trace_info, single_tool_model_runtime_error_agen): + set_trace_info() + my_agent = Agent(name="my_agent", model=single_tool_model_runtime_error_agen, tools=[throw_exception_agen]) + + response = my_agent('Add an exclamation to the word "Hello"') + assert response.message["content"][0]["text"] == "Success!" + assert response.metrics.tool_metrics["throw_exception_agen"].error_count == 1 + + +@reset_core_stats_engine() +@validate_transaction_error_event_count(1) +@validate_error_trace_attributes(callable_name(ValueError), exact_attrs={"agent": {}, "intrinsic": {}, "user": {}}) +@validate_custom_events(agent_recorded_event) +@validate_custom_events(tool_recorded_event_forced_internal_error) +@validate_custom_event_count(count=2) +@validate_transaction_metrics( + "test_agent:test_agent_tool_forced_exception", + scoped_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/add_exclamation", 1), + ], + rollup_metrics=[ + ("Llm/agent/Strands/strands.agent.agent:Agent.stream_async/my_agent", 1), + ("Llm/tool/Strands/strands.tools.executors._executor:ToolExecutor._stream/add_exclamation", 1), + ], + background_task=True, +) +@validate_attributes("agent", ["llm"]) +@background_task() +def test_agent_tool_forced_exception(set_trace_info, single_tool_model): + # Add a wrapper to intentionally force an error in the ToolExecutor._stream code to hit the exception path in + # the AsyncGeneratorProxy + @transient_function_wrapper("strands.hooks.events", "BeforeToolCallEvent.__init__") + def _wrap_BeforeToolCallEvent_init(wrapped, instance, args, kwargs): + raise ValueError("Oops") + + @_wrap_BeforeToolCallEvent_init + def _test(): + set_trace_info() + my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation]) + my_agent('Add an exclamation to the word "Hello"') + + # This will not explicitly raise a ValueError when running the test but we are still able to capture it in the error trace + _test() + + +@reset_core_stats_engine() +@validate_custom_event_count(count=0) +def test_agent_invoke_outside_txn(single_tool_model): + my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation]) + + response = my_agent('Add an exclamation to the word "Hello"') + assert response.message["content"][0]["text"] == "Success!" 
+ assert response.metrics.tool_metrics["add_exclamation"].success_count == 1 diff --git a/tests/mlmodel_strands/test_simple.py b/tests/mlmodel_strands/test_simple.py deleted file mode 100644 index ae24003fab..0000000000 --- a/tests/mlmodel_strands/test_simple.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2010 New Relic, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from strands import Agent, tool - -from newrelic.api.background_task import background_task - - -# Example tool for testing purposes -@tool -def add_exclamation(message: str) -> str: - return f"{message}!" - - -# TODO: Remove this file once all real tests are in place - - -@background_task() -def test_simple_run_agent(set_trace_info, single_tool_model): - set_trace_info() - my_agent = Agent(name="my_agent", model=single_tool_model, tools=[add_exclamation]) - - response = my_agent("Run the tools.") - assert response.message["content"][0]["text"] == "Success!" - assert response.metrics.tool_metrics["add_exclamation"].success_count == 1 diff --git a/tests/testing_support/fixtures.py b/tests/testing_support/fixtures.py index 3d93e06e30..540e44f70c 100644 --- a/tests/testing_support/fixtures.py +++ b/tests/testing_support/fixtures.py @@ -797,7 +797,7 @@ def _bind_params(transaction, *args, **kwargs): transaction = _bind_params(*args, **kwargs) error_events = transaction.error_events(instance.stats_table) - assert len(error_events) == num_errors + assert len(error_events) == num_errors, f"Expected: {num_errors}, Got: {len(error_events)}" for sample in error_events: assert isinstance(sample, list) assert len(sample) == 3 diff --git a/tests/testing_support/validators/validate_custom_event.py b/tests/testing_support/validators/validate_custom_event.py index deeef7fb25..c3cf78032a 100644 --- a/tests/testing_support/validators/validate_custom_event.py +++ b/tests/testing_support/validators/validate_custom_event.py @@ -61,7 +61,9 @@ def _validate_custom_event_count(wrapped, instance, args, kwargs): raise else: stats = core_application_stats_engine(None) - assert stats.custom_events.num_samples == count + assert stats.custom_events.num_samples == count, ( + f"Expected: {count}, Got: {stats.custom_events.num_samples}. 
Events: {list(stats.custom_events)}" + ) return result diff --git a/tests/testing_support/validators/validate_error_event_collector_json.py b/tests/testing_support/validators/validate_error_event_collector_json.py index d1cec3a558..27ea76f3a3 100644 --- a/tests/testing_support/validators/validate_error_event_collector_json.py +++ b/tests/testing_support/validators/validate_error_event_collector_json.py @@ -52,7 +52,7 @@ def _validate_error_event_collector_json(wrapped, instance, args, kwargs): error_events = decoded_json[2] - assert len(error_events) == num_errors + assert len(error_events) == num_errors, f"Expected: {num_errors}, Got: {len(error_events)}" for event in error_events: # event is an array containing intrinsics, user-attributes, # and agent-attributes diff --git a/tests/testing_support/validators/validate_transaction_error_event_count.py b/tests/testing_support/validators/validate_transaction_error_event_count.py index b41a52330f..f5e8c0b206 100644 --- a/tests/testing_support/validators/validate_transaction_error_event_count.py +++ b/tests/testing_support/validators/validate_transaction_error_event_count.py @@ -28,7 +28,9 @@ def _validate_error_event_on_stats_engine(wrapped, instance, args, kwargs): raise else: error_events = list(instance.error_events) - assert len(error_events) == num_errors + assert len(error_events) == num_errors, ( + f"Expected: {num_errors}, Got: {len(error_events)}. Errors: {error_events}" + ) return result
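For reference, a minimal sketch of how this instrumentation is exercised end to end, mirroring the tests above (the test suite injects MockedModelProvider; a real deployment would configure an actual model provider such as Bedrock, and the model is omitted here for brevity):

    from strands import Agent, tool

    from newrelic.api.background_task import background_task


    @tool
    def add_exclamation(message: str) -> str:
        return f"{message}!"


    @background_task()
    def run_agent():
        # With ai_monitoring enabled, Agent.stream_async and ToolExecutor._stream are traced,
        # producing LlmAgent/LlmTool custom events plus Llm/agent/Strands and Llm/tool/Strands metrics.
        my_agent = Agent(name="my_agent", tools=[add_exclamation])
        return my_agent('Add an exclamation to the word "Hello"')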