Skip to content

Commit 729c1c0

Browse files
authored
feat: Toolset Context Managers (#228)
* Add tool lifecycle management for agents * Fixing docs that broke during rebase * Some logic cleanup
1 parent ccac0d7 commit 729c1c0

File tree

4 files changed

+319
-5
lines changed

4 files changed

+319
-5
lines changed

docs/sdk/agent_tools.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Inheriting from this class provides:
2121
- Pydantic's declarative syntax for defining state (fields).
2222
- Automatic application of the `@configurable` decorator.
2323
- A `get_tools` method for discovering methods decorated with `@dreadnode.tool_method`.
24+
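- Support for async context management (`async with`), with automatic re-entrancy handling: nested entries are reference-counted, so the toolset's own enter/exit logic runs only once.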
2425

2526
### name
2627

dreadnode/agent/agent.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import re
44
import typing as t
5-
from contextlib import aclosing, asynccontextmanager
5+
from contextlib import AsyncExitStack, aclosing, asynccontextmanager
66
from copy import deepcopy
77
from textwrap import dedent
88

@@ -875,8 +875,16 @@ async def stream(
875875
commit: CommitBehavior = "always",
876876
) -> t.AsyncIterator[t.AsyncGenerator[AgentEvent, None]]:
877877
thread = thread or self.thread
878-
async with aclosing(self._stream_traced(thread, user_input, commit=commit)) as stream:
879-
yield stream
878+
879+
async with AsyncExitStack() as stack:
880+
# Ensure all tools are properly entered if they
881+
# are context managers before we start using them
882+
for tool_container in self.tools:
883+
if hasattr(tool_container, "__aenter__") and hasattr(tool_container, "__aexit__"):
884+
await stack.enter_async_context(tool_container)
885+
886+
async with aclosing(self._stream_traced(thread, user_input, commit=commit)) as stream:
887+
yield stream
880888

881889
async def run(
882890
self,

dreadnode/agent/tools/base.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
2+
import functools
13
import typing as t
24

3-
from pydantic import ConfigDict
5+
from pydantic import ConfigDict, PrivateAttr
46
from rigging import tools
57
from rigging.tools.base import ToolMethod as RiggingToolMethod
68

@@ -171,18 +173,82 @@ class Toolset(Model):
171173
- Pydantic's declarative syntax for defining state (fields).
172174
- Automatic application of the `@configurable` decorator.
173175
- A `get_tools` method for discovering methods decorated with `@dreadnode.tool_method`.
176+
- Support for async context management, with automatic re-entrancy handling.
174177
"""
175178

179+
model_config = ConfigDict(arbitrary_types_allowed=True, use_attribute_docstrings=True)
180+
176181
variant: str | None = None
177182
"""The variant for filtering tools available in this toolset."""
178183

179-
model_config = ConfigDict(arbitrary_types_allowed=True, use_attribute_docstrings=True)
184+
# Context manager magic
185+
_entry_ref_count: int = PrivateAttr(default=0)
186+
_context_handle: object = PrivateAttr(default=None)
187+
_entry_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
180188

181189
@property
def name(self) -> str:
    """The name of the toolset, derived from the class name."""
    # Resolve the concrete (sub)class at call time so each subclass reports its own name.
    return type(self).__name__
185193

194+
def __init_subclass__(cls, **kwargs: t.Any) -> None:
    """
    Make the context-manager protocol of every subclass re-entrant.

    Any `__aenter__`/`__enter__` and `__aexit__`/`__exit__` defined directly on
    the subclass are replaced by async wrappers that reference-count entries:
    the original enter runs only on the first entry, and the original exit only
    when the last matching exit happens. This lets an Agent auto-enter toolsets
    without breaking code that already entered the toolset manually.

    Raises:
        TypeError: If the subclass defines an enter method without an exit
            method (or vice versa), or defines both the sync and async
            flavour of the same method.
    """
    super().__init_subclass__(**kwargs)

    # Only look at methods defined on this class itself, not inherited ones.
    original_aenter = cls.__dict__.get("__aenter__")
    original_enter = cls.__dict__.get("__enter__")
    original_aexit = cls.__dict__.get("__aexit__")
    original_exit = cls.__dict__.get("__exit__")

    has_enter = callable(original_aenter) or callable(original_enter)
    has_exit = callable(original_aexit) or callable(original_exit)

    if has_enter and not has_exit:
        raise TypeError(
            f"{cls.__name__} defining __aenter__ or __enter__ must also define __aexit__ or __exit__"
        )
    if has_exit and not has_enter:
        raise TypeError(
            f"{cls.__name__} defining __aexit__ or __exit__ must also define __aenter__ or __enter__"
        )
    if original_aenter and original_enter:
        raise TypeError(f"{cls.__name__} cannot define both __aenter__ and __enter__")
    if original_aexit and original_exit:
        raise TypeError(f"{cls.__name__} cannot define both __aexit__ and __exit__")

    if not has_enter and hasattr(cls, "__aenter__") and hasattr(cls, "__aexit__"):
        # This class adds no lifecycle methods of its own but inherits wrappers
        # from a parent. Installing fresh no-op wrappers here would shadow them
        # and silently skip the parent's real enter/exit logic.
        return

    # NOTE(review): the wrappers hold `self._entry_lock` while running the original
    # enter/exit. If an override calls `super().__aenter__()` and the parent is also
    # wrapped, that call would try to take the same (non re-entrant) asyncio.Lock.

    async def aenter_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.Any:
        async with self._entry_lock:
            if self._entry_ref_count == 0:
                # First entry: run the real setup and remember what it returned.
                handle = None
                if original_aenter:
                    handle = await original_aenter(self, *args, **kwargs)
                elif original_enter:
                    handle = original_enter(self, *args, **kwargs)
                self._context_handle = handle if handle is not None else self
            self._entry_ref_count += 1
            # Nested entries receive the same handle as the first one.
            return self._context_handle

    wrapped_enter = original_aenter or original_enter
    if wrapped_enter is not None:
        # Copy metadata only when there is something to copy it from.
        functools.update_wrapper(aenter_wrapper, wrapped_enter)

    cls.__aenter__ = aenter_wrapper  # type: ignore[attr-defined]

    async def aexit_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.Any:
        async with self._entry_lock:
            if self._entry_ref_count <= 0:
                # Unbalanced exit: nothing to tear down. Keep the counter from
                # going negative, which would block every future real enter.
                return None

            self._entry_ref_count -= 1
            if self._entry_ref_count > 0:
                # Still entered elsewhere; defer the real teardown.
                return None

            try:
                # Propagate the original's return value so that exception
                # suppression (a truthy result) keeps working.
                if original_aexit:
                    return await original_aexit(self, *args, **kwargs)
                if original_exit:
                    return original_exit(self, *args, **kwargs)
                return None
            finally:
                # Drop the handle even if the original exit raised.
                self._context_handle = None

    wrapped_exit = original_aexit or original_exit
    if wrapped_exit is not None:
        functools.update_wrapper(aexit_wrapper, wrapped_exit)

    cls.__aexit__ = aexit_wrapper  # type: ignore[attr-defined]
251+
186252
def get_tools(self, *, variant: str | None = None) -> list[AnyTool]:
187253
variant = variant or self.variant
188254

tests/test_agent_lifecycle.py

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
import asyncio
2+
import inspect
3+
import typing as t
4+
5+
import pytest
6+
7+
from dreadnode.agent import Agent
8+
from dreadnode.agent.tools import Toolset, tool, tool_method
9+
10+
if t.TYPE_CHECKING:
11+
from dreadnode.agent.tools.base import AnyTool
12+
13+
# This is the state tracker that will record the order of events.
14+
event_log: list[str] = []
15+
16+
17+
class AsyncCMToolSet(Toolset):
    """
    Scenario 1: a Toolset that implements the async context-manager protocol.

    The counters record how often the real enter/exit bodies execute, so tests
    can check that each runs exactly once.
    """

    # Incremented by the real __aenter__ / __aexit__ bodies respectively.
    enter_count: int = 0
    exit_count: int = 0

    async def __aenter__(self) -> "AsyncCMToolSet":
        event_log.append("async_tool_enter_start")
        # Suspend briefly so the start/end markers bracket genuinely async work.
        await asyncio.sleep(0.01)
        self.enter_count += 1
        event_log.append("async_tool_enter_end")
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        event_log.append("async_tool_exit_start")
        await asyncio.sleep(0.01)
        self.exit_count += 1
        event_log.append("async_tool_exit_end")

    @tool_method
    async def do_work(self) -> str:
        """A sample method for the agent to call."""
        event_log.append("async_tool_method_called")
        return "async work done"
44+
45+
46+
class SyncCMToolSet(Toolset):
    """
    Scenario 2: a Toolset that only implements synchronous __enter__/__exit__.

    Used to check that the sync methods are still invoked, in order, when the
    toolset is driven through the async lifecycle.
    """

    def __enter__(self) -> "SyncCMToolSet":
        event_log.append("sync_tool_enter")
        return self

    def __exit__(self, *exc_info: object) -> None:
        event_log.append("sync_tool_exit")

    @tool_method
    def do_blocking_work(self) -> str:
        """A sample sync method."""
        event_log.append("sync_tool_method_called")
        return "sync work done"
64+
65+
66+
# NOTE(review): the docstring below is presumably picked up by `@tool` as the tool
# description, so it is left untouched - confirm before rewording it.
@tool
async def stateless_tool() -> str:
    """
    Scenario 3: A simple, stateless tool.
    Tests that the lifecycle manager ignores it.
    """
    # Record the call so the test can confirm the tool was actually invoked.
    event_log.append("stateless_tool_called")
    return "stateless work done"
74+
75+
76+
class StandaloneCMToolset(Toolset):
    """
    Scenario 4 (revised): a Toolset that doubles as a standalone async context manager.

    It subclasses Toolset so that the Agent's validator keeps it in the tool list.
    """

    async def __aenter__(self) -> "StandaloneCMToolset":
        event_log.append("standalone_cm_enter")
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        event_log.append("standalone_cm_exit")

    @tool_method
    async def do_standalone_work(self) -> str:
        """The callable part of the tool."""
        event_log.append("standalone_cm_called")
        return "standalone work done"
94+
95+
96+
class ReturnValueToolSet(Toolset):
    """
    Scenario 5: a Toolset whose __aenter__ hands back an object other than itself.

    Used to check that the value bound by `async with ... as x` is whatever
    the toolset's own __aenter__ returned.
    """

    class Handle:
        """Minimal stand-in for a resource handle returned on entry."""

        def __init__(self, message: str) -> None:
            self.message = message

    async def __aenter__(self) -> "ReturnValueToolSet.Handle":
        event_log.append("return_value_tool_enter")
        # Deliberately a fresh Handle rather than `self`.
        handle = self.Handle("special handle")
        return handle

    async def __aexit__(self, *exc_info: object) -> None:
        event_log.append("return_value_tool_exit")
113+
114+
115+
# --- Mock Agent to Control Execution ---
116+
117+
118+
class MockAgent(Agent):
    """
    Agent double that never contacts an LLM.

    Its traced stream simply invokes every tool in `all_tools` once, logging
    the start and end of the simulated run.
    """

    async def _stream_traced(  # type: ignore[override]
        self,
        thread: object,  # noqa: ARG002
        user_input: str,  # noqa: ARG002
        *,
        commit: bool = True,  # noqa: ARG002
    ) -> t.AsyncIterator[str]:
        event_log.append("agent_run_start")
        for candidate in self.all_tools:
            outcome = candidate()
            # Async tools hand back an awaitable that still has to be driven.
            if inspect.isawaitable(outcome):
                await outcome
        event_log.append("agent_run_end")
        # A single placeholder event keeps the stream consumer's loop happy.
        yield "dummy_event"
140+
141+
142+
# --- The Tests ---
143+
144+
145+
@pytest.mark.asyncio
async def test_agent_manages_all_lifecycle_scenarios() -> None:
    """
    Integration test: the Agent enters every context-manager toolset in list
    order, runs, and then tears them down in reverse (LIFO) order.
    """
    event_log.clear()

    # Build the mixed collection of tool containers handed to the agent.
    async_tool = AsyncCMToolSet()
    sync_tool = SyncCMToolSet()
    standalone_toolset = StandaloneCMToolset()
    agent_tools: list[AnyTool | Toolset] = [
        async_tool,
        sync_tool,
        stateless_tool,
        standalone_toolset,
    ]
    agent = MockAgent(name="test_agent", tools=agent_tools)

    # Drive one full run; consuming the stream is what triggers the tool calls.
    async with agent.stream("test input") as stream:
        event_log.append("stream_context_active")
        async for _ in stream:
            pass

    # Split the log: tool-call events are checked for presence only,
    # everything else is checked for exact ordering.
    lifecycle_events: list[str] = []
    tool_call_events: list[str] = []
    for event in event_log:
        bucket = tool_call_events if event.endswith("_called") else lifecycle_events
        bucket.append(event)

    assert lifecycle_events == [
        # Setup phase, in the order the tools were listed
        "async_tool_enter_start",
        "async_tool_enter_end",
        "sync_tool_enter",
        "standalone_cm_enter",
        # Agent execution phase
        "stream_context_active",
        "agent_run_start",
        "agent_run_end",
        # Teardown phase, last-in first-out
        "standalone_cm_exit",
        "sync_tool_exit",
        "async_tool_exit_start",
        "async_tool_exit_end",
    ]
    assert sorted(tool_call_events) == sorted(
        [
            "async_tool_method_called",
            "sync_tool_method_called",
            "stateless_tool_called",
            "standalone_cm_called",
        ]
    )

    # The real enter/exit bodies must each have run exactly once.
    assert async_tool.enter_count == 1
    assert async_tool.exit_count == 1
209+
210+
211+
@pytest.mark.asyncio
async def test_toolset_idempotency_wrapper() -> None:
    """
    Unit test for the re-entrancy wrapper: nested entries must not re-run the
    real enter logic, and the real exit logic must run only when the outermost
    context closes.
    """
    # Named `toolset` so the imported `tool` decorator is not shadowed.
    toolset = AsyncCMToolSet()

    # Nesting simulates the agent entering a toolset that the user
    # has already entered manually.
    async with toolset as outer_handle:
        assert toolset.enter_count == 1
        async with toolset as inner_handle:
            assert toolset.enter_count == 1  # Should NOT have increased
            assert inner_handle is outer_handle  # Should return the same handle

        # Leaving the inner context must not tear the toolset down while the
        # outer context is still active.
        assert toolset.exit_count == 0
        assert toolset.enter_count == 1

    assert toolset.exit_count == 1  # Exit logic should only have run once
228+
229+
230+
@pytest.mark.asyncio
async def test_toolset_return_value_is_honored() -> None:
    """
    The object bound by `async with ... as handle` must be the one the
    toolset's own __aenter__ returned.
    """
    toolset = ReturnValueToolSet()

    async with toolset as handle:
        assert isinstance(handle, ReturnValueToolSet.Handle)
        assert handle.message == "special handle"

0 commit comments

Comments
 (0)