Skip to content

Commit 5eb0e22

Browse files
authored
Merge pull request #68 from meta-pytorch/mortimer/resetcode
[BUG] reset coding env properly
2 parents e8acf40 + d6036ad commit 5eb0e22

File tree

3 files changed

+169
-2
lines changed

3 files changed

+169
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
55
[project]
66
name = "openenv"
77
version = "0.1.0"
8-
requires-python = ">=3.8"
8+
requires-python = ">=3.10"
99
dependencies = [
1010
"torch>=1.9.0",
1111
"numpy>=1.19.0",

src/envs/coding_env/server/python_codeact_env.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313

1414
import uuid
1515

16-
from core.env_server import Action, Environment, Observation, Transform
16+
from core.env_server import Action, Environment, Observation
1717
from core.tools import PyExecutor
1818

1919
from ..models import CodeAction, CodeObservation, CodeState
2020
from .transforms import create_safe_coding_transform
2121

22+
2223
class PythonCodeActEnv(Environment):
2324
"""
2425
Python Code Action Environment for executing code and tracking state.
@@ -61,6 +62,12 @@ def reset(self) -> Observation:
6162
# Add last_exit_code to state
6263
self._state.last_exit_code = 0
6364

65+
# Reset executor to clear any previously defined variables/functions
66+
self._executor = PyExecutor()
67+
68+
# Reset transform to clear any accumulated state
69+
self.transform = create_safe_coding_transform()
70+
6471
# Return initial observation
6572
observation = CodeObservation(
6673
stdout="",
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Test that PythonCodeActEnv.reset() properly resets executor state."""
8+
9+
import sys
10+
from pathlib import Path
11+
12+
from envs.coding_env.models import CodeAction
13+
from envs.coding_env.server.python_codeact_env import PythonCodeActEnv
14+
15+
# Add src to path
16+
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
17+
18+
19+
def test_reset_clears_executor_state():
20+
"""Test that reset() clears functions and variables defined in
21+
previous execution."""
22+
env = PythonCodeActEnv()
23+
24+
# Initial reset
25+
obs = env.reset()
26+
assert obs.exit_code == 0
27+
assert env.state.step_count == 0
28+
29+
# Define a function in the executor
30+
action1 = CodeAction(code="def my_function():\n return 'Hello from function'\n")
31+
obs1 = env.step(action1)
32+
assert obs1.exit_code == 0
33+
34+
# Call the function to verify it exists
35+
action2 = CodeAction(code="result = my_function()\nprint(result)")
36+
obs2 = env.step(action2)
37+
assert obs2.exit_code == 0
38+
assert "Hello from function" in obs2.stdout
39+
40+
# Reset the environment
41+
obs_reset = env.reset()
42+
assert obs_reset.exit_code == 0
43+
assert env.state.step_count == 0
44+
45+
# Try to call the function again - should fail because executor was reset
46+
action3 = CodeAction(code="result = my_function()\nprint(result)")
47+
obs3 = env.step(action3)
48+
49+
# Should get an error because my_function is no longer defined
50+
assert obs3.exit_code == 1 # Error exit code
51+
assert "my_function" in obs3.stderr or "NameError" in obs3.stderr
52+
53+
54+
def test_reset_clears_variables():
55+
"""Test that reset() clears variables defined in previous execution."""
56+
env = PythonCodeActEnv()
57+
58+
# Initial reset
59+
env.reset()
60+
61+
# Define a variable
62+
action1 = CodeAction(code="my_variable = 42\n")
63+
obs1 = env.step(action1)
64+
assert obs1.exit_code == 0
65+
66+
# Use the variable to verify it exists
67+
action2 = CodeAction(code="print(my_variable)")
68+
obs2 = env.step(action2)
69+
assert obs2.exit_code == 0
70+
assert "42" in obs2.stdout
71+
72+
# Reset the environment
73+
env.reset()
74+
75+
# Try to use the variable again - should fail
76+
action3 = CodeAction(code="print(my_variable)")
77+
obs3 = env.step(action3)
78+
79+
# Should get an error because my_variable is no longer defined
80+
assert obs3.exit_code == 1
81+
assert "my_variable" in obs3.stderr or "NameError" in obs3.stderr
82+
83+
84+
def test_reset_clears_imports():
85+
"""Test that reset() clears imported modules from previous execution."""
86+
env = PythonCodeActEnv()
87+
88+
# Initial reset
89+
env.reset()
90+
91+
# Import a module and define an alias
92+
action1 = CodeAction(code="import math as m\n")
93+
obs1 = env.step(action1)
94+
assert obs1.exit_code == 0
95+
96+
# Use the alias to verify it exists
97+
action2 = CodeAction(code="print(m.pi)")
98+
obs2 = env.step(action2)
99+
assert obs2.exit_code == 0
100+
assert "3.14" in obs2.stdout
101+
102+
# Reset the environment
103+
env.reset()
104+
105+
# Try to use the alias again - should fail
106+
action3 = CodeAction(code="print(m.pi)")
107+
obs3 = env.step(action3)
108+
109+
# Should get an error because 'm' is no longer defined
110+
assert obs3.exit_code == 1
111+
assert (
112+
"NameError" in obs3.stderr
113+
or "'m'" in obs3.stderr
114+
or "variable `m` is not defined" in obs3.stderr
115+
)
116+
117+
118+
def test_reset_preserves_step_count_reset():
119+
"""Test that reset() properly resets step count."""
120+
env = PythonCodeActEnv()
121+
122+
# Initial reset
123+
env.reset()
124+
assert env.state.step_count == 0
125+
126+
# Execute some steps
127+
for i in range(5):
128+
action = CodeAction(code=f"print({i})")
129+
env.step(action)
130+
131+
assert env.state.step_count == 5
132+
133+
# Reset should reset step count
134+
env.reset()
135+
assert env.state.step_count == 0
136+
137+
# Execute another step
138+
action = CodeAction(code="print('test')")
139+
env.step(action)
140+
assert env.state.step_count == 1
141+
142+
143+
def test_reset_changes_episode_id():
144+
"""Test that reset() generates a new episode ID."""
145+
env = PythonCodeActEnv()
146+
147+
# Initial reset
148+
env.reset()
149+
episode_id_1 = env.state.episode_id
150+
151+
# Execute some steps
152+
action = CodeAction(code="print('test')")
153+
env.step(action)
154+
155+
# Reset and get new episode ID
156+
env.reset()
157+
episode_id_2 = env.state.episode_id
158+
159+
# Episode IDs should be different
160+
assert episode_id_1 != episode_id_2

0 commit comments

Comments
 (0)