Skip to content

Commit 54824d2

Browse files
committed
Readd env removed by gitignore
1 parent 267c93d commit 54824d2

File tree

6 files changed

+432
-1
lines changed

6 files changed

+432
-1
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ coverage.xml
5252
# Virtual environments
5353
.env
5454
.venv
55-
env/
5655
venv/
5756
ENV/
5857
env.bak/

src/core/env/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Core environment interfaces and types."""
8+
9+
from .interfaces import Environment, Transform, Tool, ToolRegistry
10+
from .types import (
11+
Action, CodeAction, Observation, CodeObservation,
12+
State, CodeState, ExecutionResult
13+
)
14+
from .base_transforms import CompositeTransform, NullTransform
15+
from .code_execution_environment import CodeExecutionEnvironment
16+
17+
# Public API of the package, grouped by the module each name is imported from.
__all__ = [
    # Core interfaces
    "Environment",
    "Transform",
    "Tool",
    "ToolRegistry",
    # Types
    "Action",
    "CodeAction",
    "Observation",
    "CodeObservation",
    "State",
    "CodeState",
    "ExecutionResult",
    # Base transforms
    "CompositeTransform",
    "NullTransform",
    # Base environment implementation
    "CodeExecutionEnvironment",
]

src/core/env/base_transforms.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Base transform implementations for composing environment-specific transforms."""
8+
9+
from .interfaces import Transform
10+
from .types import Observation
11+
12+
13+
class CompositeTransform(Transform):
    """Combines multiple transforms into a single transform.

    The transforms run in the order given; each one receives the observation
    returned by the one before it.

    Args:
        transforms: The transforms to chain. The sequence is copied at
            construction time, so later changes to the caller's own list do
            not affect this instance.
    """

    def __init__(self, transforms: list[Transform]):
        # Take a private copy: holding the caller's list by reference would let
        # outside mutations silently rewrite the pipeline, and a one-shot
        # iterable would be exhausted after the first call.
        self.transforms = list(transforms)

    def __call__(self, observation: Observation) -> Observation:
        """Pass ``observation`` through every transform in order.

        Args:
            observation: The input observation.

        Returns:
            The result of the last transform, or ``observation`` itself when
            there are no transforms.
        """
        for transform in self.transforms:
            observation = transform(observation)
        return observation
23+
24+
25+
class NullTransform(Transform):
    """Identity transform: hands every observation back untouched."""

    def __call__(self, observation: Observation) -> Observation:
        """Return ``observation`` exactly as it was received."""
        return observation

src/core/env/code_execution_environment.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import json
8+
import uuid
9+
from typing import Any, Dict, Literal
10+
11+
from ..docker.docker_executor import DockerExecutor
12+
from .interfaces import Environment, Transform
13+
from .types import CodeAction, CodeObservation, CodeState, Action, Observation, State
14+
15+
16+
class CodeExecutionEnvironment(Environment):
    """Environment for executing Python code actions using Docker.

    Code is run through a ``DockerExecutor`` built from ``docker_image`` and
    ``timeout_seconds``; per-episode bookkeeping lives in a ``CodeState``.

    Args:
        transform: Optional transform applied to every observation returned
            by ``reset`` and ``step``.
        docker_image: Image name handed to ``DockerExecutor``.
        timeout_seconds: Timeout value handed to ``DockerExecutor``.
    """

    # Maximum number of stdout characters shown by the text renderers.
    _STDOUT_PREVIEW_CHARS = 100

    def __init__(
        self,
        transform: Transform | None = None,
        docker_image: str = "python:3.11-slim",
        timeout_seconds: int = 30
    ):
        super().__init__(transform)
        self.docker_image = docker_image
        self.timeout_seconds = timeout_seconds
        self.executor = DockerExecutor(docker_image, timeout_seconds)
        self._state = CodeState()

    def reset(self) -> Observation:
        """Reset environment and start fresh Docker session.

        Returns:
            The initial observation (no execution result), after the
            configured transform has been applied.

        Raises:
            RuntimeError: If the executor fails to start a new session. The
                original exception is attached as ``__cause__``.
        """
        # Stop any existing session
        self.executor.stop_session()

        # Initialize fresh state
        self._state = CodeState(
            episode_id=str(uuid.uuid4()),
            step_count=0
        )

        # Start new Docker session
        try:
            self.executor.start_session()
        except Exception as e:
            # Fail hard as requested; chain explicitly so the root cause stays
            # visible in the traceback.
            raise RuntimeError(f"Failed to start Docker session: {e}") from e

        # Return initial observation
        observation = CodeObservation(
            execution_result=None,
            available_tools=[]  # TODO: populate from MCP registry
        )

        return self._apply_transform(observation)

    def step(self, action: Action) -> Observation:
        """Execute code action and return observation.

        Args:
            action: The action to run; must be a ``CodeAction``.

        Returns:
            An observation carrying the execution result, after the
            configured transform has been applied.

        Raises:
            ValueError: If ``action`` is not a ``CodeAction``.
        """
        if not isinstance(action, CodeAction):
            raise ValueError(f"Expected CodeAction, got {type(action)}")

        # Execute the code
        execution_result = self.executor.execute_code(action.code)

        # Update state
        self._state.step_count += 1
        self._state.action_history.append(action)
        self._state.result_history.append(execution_result)

        # Create observation
        observation = CodeObservation(
            execution_result=execution_result,
            available_tools=[]  # TODO: populate from MCP registry
        )

        return self._apply_transform(observation)

    def render(self, mode: Literal["human", "raw", "ansi"] = "human") -> Any:
        """Render current environment state.

        Args:
            mode: ``"raw"`` returns the underlying dict, ``"ansi"`` a colored
                string, anything else (including ``"human"``) a plain string.

        Returns:
            A dict for ``"raw"``; otherwise a newline-joined string.
        """
        # Best effort: a failing variable dump is reported inside the render
        # output rather than aborting the render.
        try:
            variables = self.executor.get_variable_dump()
        except Exception as e:
            variables = {"error": f"Failed to get variables: {e}"}

        history = self._state.result_history
        render_data = {
            "episode_id": self._state.episode_id,
            "step_count": self._state.step_count,
            "variables": variables,
            "last_result": history[-1] if history else None
        }

        if mode == "raw":
            return render_data
        elif mode == "ansi":
            return self._render_ansi(render_data)
        else:  # mode == "human"
            return self._render_human(render_data)

    def close(self) -> None:
        """Close environment and clean up Docker container."""
        self.executor.stop_session()

    @property
    def state(self) -> State:
        """Get current environment state."""
        return self._state

    @staticmethod
    def _short_episode_id(episode_id: Any) -> str:
        """Return the first 8 characters of ``episode_id`` followed by ``...``."""
        # The state built in __init__ is created without an episode id, so it
        # may be unset until reset() runs (CodeState's default is not visible
        # here -- TODO confirm). Treat a falsy id as empty instead of failing
        # on the slice.
        return f"{(episode_id or '')[:8]}..."

    @classmethod
    def _preview_stdout(cls, stdout: str) -> str:
        """Return ``stdout`` cut to the preview length, marking a cut with ``...``."""
        limit = cls._STDOUT_PREVIEW_CHARS
        # Only add the ellipsis when something was actually dropped; appending
        # it unconditionally made short output look truncated.
        if len(stdout) <= limit:
            return stdout
        return f"{stdout[:limit]}..."

    def _render_human(self, data: Dict[str, Any]) -> str:
        """Render in human-readable format.

        Args:
            data: The dict assembled by ``render`` (keys ``episode_id``,
                ``step_count``, ``variables``, ``last_result``).

        Returns:
            The rendered lines joined with newlines.
        """
        lines = []
        lines.append(f"=== Code Environment (Episode: {self._short_episode_id(data['episode_id'])}) ===")
        lines.append(f"Steps: {data['step_count']}")

        if data.get("last_result"):
            result = data["last_result"]
            lines.append(f"Last execution: {'✓ Success' if result.success else '✗ Failed'}")
            if result.stdout:
                lines.append(f"Output: {self._preview_stdout(result.stdout)}")
            if not result.success and result.exception_message:
                lines.append(f"Error: {result.exception_message}")

        lines.append("\n--- Variables ---")
        variables = data.get("variables", {})
        # NOTE(review): "error" doubles as the failure marker set by render();
        # a user variable literally named `error` would be reported as a
        # failure here. Left as-is because the executor's own dump format is
        # not visible from this file.
        if "error" in variables:
            lines.append(f"Error getting variables: {variables['error']}")
        else:
            for name, value in sorted(variables.items()):
                lines.append(f"{name}: {value}")

        return "\n".join(lines)

    def _render_ansi(self, data: Dict[str, Any]) -> str:
        """Render in ANSI terminal format with colors.

        Args:
            data: The dict assembled by ``render`` (keys ``episode_id``,
                ``step_count``, ``variables``, ``last_result``).

        Returns:
            The rendered lines, with ANSI escape codes, joined with newlines.
        """
        lines = []

        # ANSI color codes
        BLUE = "\033[34m"
        GREEN = "\033[32m"
        RED = "\033[31m"
        YELLOW = "\033[33m"
        RESET = "\033[0m"
        BOLD = "\033[1m"

        lines.append(f"{BOLD}{BLUE}=== Code Environment ==={RESET}")
        lines.append(f"Episode: {self._short_episode_id(data['episode_id'])}")
        lines.append(f"Steps: {YELLOW}{data['step_count']}{RESET}")

        if data.get("last_result"):
            result = data["last_result"]
            status_color = GREEN if result.success else RED
            status_text = "Success" if result.success else "Failed"
            lines.append(f"Last execution: {status_color}{status_text}{RESET}")

            if result.stdout:
                lines.append(f"Output: {self._preview_stdout(result.stdout)}")
            if not result.success and result.exception_message:
                lines.append(f"{RED}Error: {result.exception_message}{RESET}")

        lines.append(f"\n{BOLD}--- Variables ---{RESET}")
        variables = data.get("variables", {})
        # NOTE(review): same "error"-key ambiguity as in _render_human.
        if "error" in variables:
            lines.append(f"{RED}Error getting variables: {variables['error']}{RESET}")
        else:
            for name, value in sorted(variables.items()):
                lines.append(f"{YELLOW}{name}{RESET}: {value}")

        return "\n".join(lines)

src/core/env/interfaces.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from abc import ABC, abstractmethod
8+
from typing import Any
9+
10+
from .types import Action, Observation, State
11+
12+
13+
class Transform(ABC):
    """Callable that maps an observation to a (possibly modified) observation.

    Used to attach rewards, metrics, or other augmentations. This mirrors the
    TorchRL pattern: an observation goes in and an observation comes out,
    which keeps reward computation and observation augmentation composable.
    """

    @abstractmethod
    def __call__(self, observation: Observation) -> Observation:
        """Transform an observation.

        Args:
            observation: The input observation

        Returns:
            The transformed observation
        """
32+
33+
34+
class Environment(ABC):
    """Abstract base for environments following the Gym/Gymnasium API.

    Subclasses provide ``reset``, ``step`` and the ``state`` property, and
    can route observations through ``_apply_transform``.

    Args:
        transform: Optional transform to apply to observations
    """

    def __init__(self, transform: Transform | None = None):
        self.transform = transform

    @abstractmethod
    def reset(self) -> Observation:
        """Reset the environment and return initial observation."""

    @abstractmethod
    def step(self, action: Action) -> Observation:
        """Take a step in the environment."""

    @property
    @abstractmethod
    def state(self) -> State:
        """Get the current environment state."""

    def _apply_transform(self, observation: Observation) -> Observation:
        """Return ``observation`` run through the transform, if one is set."""
        if self.transform is None:
            return observation
        return self.transform(observation)
65+
66+
67+
class Tool(ABC):
    """Abstract callable representing a tool usable during code execution."""

    @abstractmethod
    def __call__(self, *args, **kwargs) -> Any:
        """Execute the tool."""
74+
75+
76+
class ToolRegistry:
    """Name-to-tool mapping for the tools exposed to code execution."""

    def __init__(self):
        # Maps each registered name to its tool object.
        self._tools: dict[str, Any] = {}

    def register(self, name: str, tool: Any):
        """Store ``tool`` under ``name``; an existing entry is replaced."""
        self._tools[name] = tool

    def get(self, name: str) -> Any | None:
        """Return the tool registered as ``name``, or ``None`` if absent."""
        return self._tools.get(name)

    def get_all(self) -> dict[str, Any]:
        """Return a shallow copy of the full name-to-tool mapping."""
        return dict(self._tools)

    def get_names(self) -> list[str]:
        """Return the registered names in registration order."""
        return list(self._tools)

0 commit comments

Comments
 (0)