Skip to content

Commit 6b8adea

Browse files
monoxgas and Copilot authored
feat: Backoff hooks (#197)
* Add backoff hook for agents. Add unique session_id for an agent run session to events.
* Add missing backoff.py file
* Fixing start_time
* Update dreadnode/agent/hooks/backoff.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

--------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b4b8fcb commit 6b8adea

File tree

5 files changed

+154
-2
lines changed

5 files changed

+154
-2
lines changed

dreadnode/agent/agent.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
import rigging as rg
77
from pydantic import ConfigDict, Field, PrivateAttr, SkipValidation, field_validator
8-
from rigging.message import inject_system_content # can't access via rg
8+
from rigging.message import inject_system_content  # can't access via rg
9+
from ulid import ULID
910

1011
from dreadnode.agent.error import MaxStepsError
1112
from dreadnode.agent.events import (
@@ -275,6 +276,7 @@ async def _stream( # noqa: PLR0912, PLR0915
275276
) -> t.AsyncGenerator[AgentEvent, None]:
276277
events: list[AgentEvent] = []
277278
stop_conditions = self.stop_conditions
279+
session_id = ULID()
278280

279281
# Event dispatcher
280282

@@ -368,6 +370,7 @@ async def _dispatch(event: AgentEvent) -> t.AsyncIterator[AgentEvent]:
368370
"unknown",
369371
)
370372
reacted_event = Reacted(
373+
session_id=session_id,
371374
agent=self,
372375
thread=thread,
373376
messages=messages,
@@ -395,6 +398,7 @@ async def _process_tool_call(
395398
) -> t.AsyncGenerator[AgentEvent, None]:
396399
async for event in _dispatch(
397400
ToolStart(
401+
session_id=session_id,
398402
agent=self,
399403
thread=thread,
400404
messages=messages,
@@ -416,6 +420,7 @@ async def _process_tool_call(
416420
except Exception as e:
417421
async for event in _dispatch(
418422
AgentError(
423+
session_id=session_id,
419424
agent=self,
420425
thread=thread,
421426
messages=messages,
@@ -432,6 +437,7 @@ async def _process_tool_call(
432437

433438
async for event in _dispatch(
434439
ToolEnd(
440+
session_id=session_id,
435441
agent=self,
436442
thread=thread,
437443
messages=messages,
@@ -447,6 +453,7 @@ async def _process_tool_call(
447453

448454
async for event in _dispatch(
449455
AgentStart(
456+
session_id=session_id,
450457
agent=self,
451458
thread=thread,
452459
messages=messages,
@@ -464,6 +471,7 @@ async def _process_tool_call(
464471
try:
465472
async for event in _dispatch(
466473
StepStart(
474+
session_id=session_id,
467475
agent=self,
468476
thread=thread,
469477
messages=messages,
@@ -479,6 +487,7 @@ async def _process_tool_call(
479487
if step_chat.failed and step_chat.error:
480488
async for event in _dispatch(
481489
AgentError(
490+
session_id=session_id,
482491
agent=self,
483492
thread=thread,
484493
messages=messages,
@@ -493,6 +502,7 @@ async def _process_tool_call(
493502

494503
async for event in _dispatch(
495504
GenerationEnd(
505+
session_id=session_id,
496506
agent=self,
497507
thread=thread,
498508
messages=messages,
@@ -516,6 +526,7 @@ async def _process_tool_call(
516526

517527
async for event in _dispatch(
518528
AgentStalled(
529+
session_id=session_id,
519530
agent=self,
520531
thread=thread,
521532
messages=messages,
@@ -579,6 +590,7 @@ async def _process_tool_call(
579590
thread.events.extend(events)
580591

581592
yield AgentEnd(
593+
session_id=session_id,
582594
agent=self,
583595
thread=thread,
584596
messages=messages,

dreadnode/agent/events.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from rich.rule import Rule
1010
from rich.table import Table
1111
from rich.text import Text
12+
from ulid import ULID
1213

1314
from dreadnode.agent.format import format_message
1415
from dreadnode.agent.reactions import (
@@ -39,6 +40,8 @@ class AgentEvent:
3940
)
4041
"""The timestamp of when the event occurred (UTC)."""
4142

43+
session_id: ULID = field(repr=False)
44+
"""The unique identifier for the agent run session."""
4245
agent: "Agent" = field(repr=False)
4346
"""The agent associated with this event."""
4447
thread: "Thread" = field(repr=False)

dreadnode/agent/hooks/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dreadnode.agent.hooks.backoff import backoff_on_error, backoff_on_ratelimit
12
from dreadnode.agent.hooks.base import (
23
Hook,
34
retry_with_feedback,
@@ -6,6 +7,8 @@
67

78
__all__ = [
89
"Hook",
10+
"backoff_on_error",
11+
"backoff_on_ratelimit",
912
"retry_with_feedback",
1013
"summarize_when_long",
1114
]

dreadnode/agent/hooks/backoff.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import asyncio
2+
import random
3+
import time
4+
import typing as t
5+
from dataclasses import dataclass
6+
7+
from loguru import logger
8+
9+
from dreadnode.agent.events import AgentError, AgentEvent, StepStart
10+
from dreadnode.agent.reactions import Reaction, Retry
11+
12+
if t.TYPE_CHECKING:
13+
from ulid import ULID
14+
15+
from dreadnode.agent.hooks.base import Hook
16+
17+
18+
@dataclass
19+
class BackoffState:
20+
tries: int = 0
21+
start_time: float | None = None
22+
last_step_seen: int = -1
23+
24+
def reset(self, step: int = -1) -> None:
25+
self.tries = 0
26+
self.start_time = None
27+
self.last_step_seen = step
28+
29+
30+
def backoff_on_error(
31+
exception_types: type[Exception] | t.Iterable[type[Exception]],
32+
*,
33+
max_tries: int = 8,
34+
max_time: float = 300.0,
35+
base_factor: float = 1.0,
36+
jitter: bool = True,
37+
) -> "Hook":
38+
"""
39+
Creates a hook that retries with exponential backoff when specific errors occur.
40+
41+
It listens for `AgentError` events and, if the error matches, waits for an
42+
exponentially increasing duration before issuing a `Retry` reaction.
43+
44+
Args:
45+
exception_types: An exception type or iterable of types to catch.
46+
max_tries: The maximum number of retries before giving up.
47+
max_time: The maximum total time in seconds to wait before giving up.
48+
base_factor: The base duration (in seconds) for the backoff calculation.
49+
jitter: If True, adds a random jitter to the wait time to prevent synchronized retries.
50+
51+
Returns:
52+
An agent hook that implements the backoff logic.
53+
"""
54+
exceptions = (
55+
tuple(exception_types) if isinstance(exception_types, t.Iterable) else (exception_types,)
56+
)
57+
58+
session_states: dict[ULID, BackoffState] = {}
59+
60+
async def backoff_hook(event: "AgentEvent") -> "Reaction | None":
61+
state = session_states.setdefault(event.session_id, BackoffState())
62+
63+
if isinstance(event, StepStart):
64+
if event.step > state.last_step_seen:
65+
state.reset(event.step)
66+
return None
67+
68+
if not isinstance(event, AgentError) or not isinstance(event.error, exceptions):
69+
return None
70+
71+
if state.start_time is None:
72+
state.start_time = time.monotonic()
73+
74+
if state.tries >= max_tries:
75+
logger.warning(
76+
f"Backoff aborted for session {event.session_id}: maximum tries ({max_tries}) exceeded."
77+
)
78+
return None
79+
80+
if (time.monotonic() - state.start_time) >= max_time:
81+
logger.warning(
82+
f"Backoff aborted for session {event.session_id}: maximum time ({max_time:.2f}s) exceeded."
83+
)
84+
return None
85+
86+
state.tries += 1
87+
88+
seconds = base_factor * (2 ** (state.tries - 1))
89+
if jitter:
90+
seconds += random.uniform(0, base_factor) # noqa: S311 # nosec
91+
92+
logger.warning(
93+
f"Backing off for {seconds:.2f}s (try {state.tries}/{max_tries}) on session {event.session_id} due to error: {event.error}"
94+
)
95+
96+
await asyncio.sleep(seconds)
97+
return Retry()
98+
99+
return backoff_hook
100+
101+
102+
def backoff_on_ratelimit(
    *,
    max_tries: int = 8,
    max_time: float = 300.0,
    base_factor: float = 1.0,
    jitter: bool = True,
) -> "Hook":
    """
    A convenient default backoff hook for common, ephemeral LLM errors.

    This hook retries on `litellm.exceptions.RateLimitError` and `litellm.exceptions.APIError`
    with an exponential backoff strategy. With the default arguments it gives up
    after 8 tries or 300 seconds (5 minutes).

    See `backoff_on_error` for more details.

    Args:
        max_tries: The maximum number of retries before giving up.
        max_time: The maximum total time in seconds to wait before giving up.
        base_factor: The base duration (in seconds) for the backoff calculation.
        jitter: If True, adds a random jitter to the wait time to prevent synchronized retries.

    Returns:
        An agent hook that implements the backoff logic.
    """
    # Imported inside the function, so litellm is loaded only when this
    # factory is actually called.
    import litellm.exceptions

    return backoff_on_error(
        (litellm.exceptions.RateLimitError, litellm.exceptions.APIError),
        max_time=max_time,
        max_tries=max_tries,
        base_factor=base_factor,
        jitter=jitter,
    )

dreadnode/agent/reactions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class Continue(Reaction):
1919

2020
@dataclass
2121
class Retry(Reaction):
22-
messages: list[rg.Message] | None = Field(None, repr=False)
22+
messages: list[rg.Message] | None = Field(default=None, repr=False)
2323

2424

2525
@dataclass

0 commit comments

Comments
 (0)