Skip to content

Commit 4a74399

Browse files
committed
Fix pytest
1 parent cc48d9d commit 4a74399

File tree

8 files changed

+402
-25
lines changed

8 files changed

+402
-25
lines changed

dreadnode/agent/agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
tools_to_json_with_tag_transform,
2121
)
2222

23-
from dreadnode import log_inputs, log_metric, log_outputs, score, task_span
2423
from dreadnode.agent.error import MaxStepsError
2524
from dreadnode.agent.events import (
2625
AgentEnd,
@@ -606,6 +605,8 @@ async def _process_tool_call(
606605
)
607606

608607
def _log_event_metrics(self, event: AgentEvent) -> None:
608+
from dreadnode import log_metric
609+
609610
if isinstance(event, AgentEnd):
610611
log_metric("steps_taken", min(0, event.result.steps - 1))
611612
log_metric(f"stop_{event.stop_reason}", 1)
@@ -642,6 +643,8 @@ async def _stream_in_task(
642643
*,
643644
commit: CommitBehavior = "on-success",
644645
) -> t.AsyncGenerator[AgentEvent, None]:
646+
from dreadnode import log_inputs, log_outputs, score, task_span
647+
645648
hooks = self._get_hooks()
646649
tool_names = [t.name for t in self.all_tools]
647650
stop_names = [s.name for s in self.stop_conditions]

dreadnode/agent/tools/tasking.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from loguru import logger
22

3-
from dreadnode import log_metric
43
from dreadnode.agent.reactions import Fail, Finish
54
from dreadnode.agent.tools.base import tool
65

@@ -28,6 +27,7 @@ async def finish_task(success: bool, summary: str) -> None: # noqa: ARG001
2827
* **Honest Status**: Accurately report the success or failure of the overall task. If any part of the task failed or was not completed, `success` should be `False`.
2928
* **Comprehensive Summary**: The `summary` should be a complete and detailed markdown-formatted report of everything you did, including steps taken, tools used, and the final outcome. This is your final report to the user.
3029
"""
30+
from dreadnode import log_metric
3131

3232
log_func = logger.success if success else logger.warning
3333
log_func(f"Agent finished the task (success={success})")
@@ -42,6 +42,7 @@ async def give_up_on_task(reason: str) -> None: # noqa: ARG001
4242
"""
4343
Give up on your task.
4444
"""
45+
from dreadnode import log_metric
4546

4647
logger.info("Agent gave up on the task")
4748
log_metric("task_give_up", 1)

dreadnode/optimization/study.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from pydantic import ConfigDict, FilePath, PrivateAttr
66

7-
from dreadnode import log_inputs, log_metric, log_outputs, log_params, run, task_span
87
from dreadnode.eval import Eval
98
from dreadnode.eval.result import EvalResult
109
from dreadnode.eval.sample import InputDataset
@@ -233,6 +232,8 @@ async def _stream(self) -> t.AsyncGenerator[StudyEvent[CandidateT], None]: # no
233232
)
234233

235234
def _log_event_metrics(self, event: StudyEvent[t.Any]) -> None:
235+
from dreadnode import log_metric
236+
236237
if isinstance(event, TrialComplete):
237238
trial = event.trial
238239
if trial.status == "success":
@@ -246,6 +247,8 @@ def _log_event_metrics(self, event: StudyEvent[t.Any]) -> None:
246247
log_metric("best_score", event.trial.score, step=event.trial.step)
247248

248249
async def _stream_traced(self) -> t.AsyncGenerator[StudyEvent[CandidateT], None]:
250+
from dreadnode import log_inputs, log_outputs, log_params, run, task_span
251+
249252
objective_name = (
250253
self.objective
251254
if isinstance(self.objective, str)

dreadnode/scorers/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import typing as t
22

3-
from transforms import pipeline # type: ignore[import-not-found]
3+
from transformers import pipeline # type: ignore[import-not-found]
44

55
from dreadnode.meta import Config
66
from dreadnode.metric import Metric

dreadnode/scorers/crucible.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import aiohttp
55

6-
from dreadnode import tag
76
from dreadnode.metric import Metric
87
from dreadnode.scorers import Scorer
98

@@ -31,6 +30,8 @@ def contains_crucible_flag(
3130
platform_url: str = "https://platform.dreadnode.io",
3231
name: str = "contains_crucible_flag",
3332
) -> Scorer[t.Any]:
33+
from dreadnode import tag
34+
3435
async def evaluate(
3536
obj: t.Any,
3637
*,

dreadnode/task.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
import typing_extensions as te
77
from opentelemetry.trace import Tracer
88

9-
from dreadnode import score
10-
from dreadnode.airt.target import CustomTarget
11-
from dreadnode.eval.eval import Eval
129
from dreadnode.meta.context import Context
1310
from dreadnode.meta.types import Component, ConfigInfo
1411
from dreadnode.scorers.base import Scorer, ScorerCallable, ScorersLike
@@ -23,7 +20,9 @@
2320
)
2421

2522
if t.TYPE_CHECKING:
23+
from dreadnode.airt.target.custom import CustomTarget
2624
from dreadnode.eval.eval import (
25+
Eval,
2726
InputDataset,
2827
InputDatasetProcessor,
2928
)
@@ -379,6 +378,8 @@ def as_eval(
379378
scorers: "ScorersLike[R] | None" = None,
380379
assert_scores: list[str] | t.Literal[True] | None = None,
381380
) -> "Eval[t.Any, R]":
381+
from dreadnode.eval.eval import Eval
382+
382383
if isinstance(dataset, str):
383384
dataset = Path(dataset)
384385

@@ -402,6 +403,8 @@ def as_target(
402403
self,
403404
input_param_name: str | None = None,
404405
) -> "CustomTarget[R]":
406+
from dreadnode.airt.target.custom import CustomTarget
407+
405408
return CustomTarget(task=self, input_param_name=input_param_name)
406409

407410
async def run_always(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
@@ -417,6 +420,7 @@ async def run_always(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
417420
Returns:
418421
The span associated with task execution.
419422
"""
423+
from dreadnode import score
420424

421425
run = current_run_span.get()
422426

0 commit comments

Comments
 (0)