Skip to content

Commit 387ed89

Browse files
authored
fix: prevent UnboundLocalError for output_object_hash in task output linking (#224)
**Added:** - Added comprehensive tests for Task output_object_hash linking logic, covering cases where logging is enabled or disabled, async tasks, exception handling, complex and None outputs, inherited log settings, and entrypoint tasks (`tests/test_task_output_linking.py`) **Changed:** - Initialize `output_object_hash = None` before conditional logic to prevent referencing it before assignment, ensuring robust task output linking logic (`dreadnode/task.py`) - Renamed parameter `x` to `_x` in `failing_task` within exception handling test to indicate it is unused and improve code clarity in `tests/test_task_output_linking.py`
1 parent 70dcb0a commit 387ed89

File tree

2 files changed

+147
-6
lines changed

2 files changed

+147
-6
lines changed

dreadnode/task.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,7 @@ async def run_always(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: #
538538

539539
# Log the output
540540

541+
output_object_hash = None
541542
if log_output and (
542543
not isinstance(self.log_inputs, Inherited) or seems_useful_to_serialize(output)
543544
):
@@ -546,13 +547,12 @@ async def run_always(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: #
546547
output,
547548
attributes={"auto": True},
548549
)
549-
elif run is not None:
550550
# Link the output to the inputs
551-
for input_object_hash in input_object_hashes:
552-
run.link_objects(output_object_hash, input_object_hash)
553-
554-
if create_run:
555-
run.log_output("output", output, attributes={"auto": True})
551+
if run is not None:
552+
for input_object_hash in input_object_hashes:
553+
run.link_objects(output_object_hash, input_object_hash)
554+
elif run is not None and create_run:
555+
run.log_output("output", output, attributes={"auto": True})
556556

557557
# Score and check assertions
558558

tests/test_task_output_linking.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""
2+
Tests for Task output object hash linking logic.
3+
4+
This test module covers the fix for ENG-3549, which addresses a scope issue
5+
where output_object_hash could be referenced before definition.
6+
7+
The bug: output_object_hash was only defined inside the `if log_output and ...`
8+
block, but was referenced in the subsequent `elif` block, causing UnboundLocalError.
9+
10+
The fix: Initialize output_object_hash = None before use.
11+
"""
12+
13+
import pytest
14+
15+
from dreadnode import task
16+
17+
18+
@pytest.mark.asyncio
19+
async def test_task_with_log_output_true() -> None:
20+
"""Test that a task with log_output=True executes without errors."""
21+
22+
@task(log_inputs=True, log_output=True)
23+
def sample_task(x: int) -> int:
24+
return x * 2
25+
26+
result = await sample_task.run_always(5)
27+
assert result.output == 10
28+
29+
30+
@pytest.mark.asyncio
31+
async def test_task_with_log_output_false() -> None:
32+
"""Edge case where output_object_hash would not be defined in buggy code."""
33+
34+
@task(log_inputs=True, log_output=False)
35+
def sample_task(x: int) -> int:
36+
return x * 2
37+
38+
result = await sample_task.run_always(5)
39+
assert result.output == 10
40+
41+
42+
@pytest.mark.asyncio
43+
async def test_task_with_no_logging() -> None:
44+
"""
45+
Core bug scenario: no logging means output_object_hash would be
46+
referenced before definition in the original buggy code.
47+
"""
48+
49+
@task(log_inputs=False, log_output=False)
50+
def sample_task(x: int) -> int:
51+
return x * 2
52+
53+
result = await sample_task.run_always(5)
54+
assert result.output == 10
55+
56+
57+
@pytest.mark.asyncio
58+
async def test_task_with_multiple_inputs() -> None:
59+
"""Test that linking logic handles multiple input hashes properly."""
60+
61+
@task(log_inputs=True, log_output=True)
62+
def sample_task(x: int, y: int, z: int) -> int:
63+
return x + y + z
64+
65+
result = await sample_task.run_always(1, 2, 3)
66+
assert result.output == 6
67+
68+
69+
@pytest.mark.asyncio
70+
async def test_async_task_execution() -> None:
71+
"""Test that the fix works correctly for async tasks."""
72+
73+
@task(log_inputs=True, log_output=True)
74+
async def async_sample_task(x: int) -> int:
75+
return x * 2
76+
77+
result = await async_sample_task.run_always(5)
78+
assert result.output == 10
79+
80+
81+
@pytest.mark.asyncio
82+
async def test_task_with_inherited_log_settings() -> None:
83+
"""Test inherited logging settings (the default, most common usage)."""
84+
85+
@task
86+
def sample_task(x: int) -> int:
87+
return x * 2
88+
89+
result = await sample_task.run_always(5)
90+
assert result.output == 10
91+
92+
93+
@pytest.mark.asyncio
94+
async def test_task_exception_handling() -> None:
95+
"""Test that exceptions don't cause issues with output_object_hash logic."""
96+
97+
@task(log_inputs=True, log_output=True)
98+
def failing_task(_x: int) -> int:
99+
raise ValueError("Intentional test error")
100+
101+
result = await failing_task.run_always(5)
102+
103+
assert result.exception is not None
104+
assert isinstance(result.exception, ValueError)
105+
assert "Intentional test error" in str(result.exception)
106+
107+
108+
@pytest.mark.asyncio
109+
async def test_task_with_complex_output() -> None:
110+
"""Test that tasks returning complex types work correctly."""
111+
112+
@task(log_inputs=True, log_output=True)
113+
def complex_task(x: int) -> dict[str, int]:
114+
return {"result": x * 2, "input": x}
115+
116+
result = await complex_task.run_always(5)
117+
assert result.output == {"result": 10, "input": 5}
118+
119+
120+
@pytest.mark.asyncio
121+
async def test_task_with_none_output() -> None:
122+
"""Test None outputs (may be handled differently in serialization)."""
123+
124+
@task(log_inputs=True, log_output=True)
125+
def none_task(x: int) -> None:
126+
pass
127+
128+
result = await none_task.run_always(5)
129+
assert result.output is None
130+
131+
132+
@pytest.mark.asyncio
133+
async def test_task_entrypoint_behavior() -> None:
134+
"""Test entrypoint tasks (create_run=True path) with log_output=False."""
135+
136+
@task(log_inputs=True, log_output=False, entrypoint=True)
137+
def entrypoint_task(x: int) -> int:
138+
return x * 2
139+
140+
result = await entrypoint_task.run_always(5)
141+
assert result.output == 10

0 commit comments

Comments
 (0)