Skip to content

Commit 38dda57

Browse files
committed
Fix the formatting
1 parent 22e4b11 commit 38dda57

File tree

5 files changed

+171
-187
lines changed

5 files changed

+171
-187
lines changed

rllm/engine/agent_workflow_engine.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en
4949
self.n_parallel_tasks = n_parallel_tasks
5050
self.executor = ThreadPoolExecutor(max_workers=self.n_parallel_tasks)
5151
self.workflow_queue = None
52-
52+
5353
# Episode logging support
5454
self.episode_logger = episode_logger
5555
self.current_step = 0
@@ -58,7 +58,7 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en
5858

5959
def set_training_step(self, step: int, mode: str = "train", epoch: int = 0):
6060
"""Set current training step for episode logging.
61-
61+
6262
Args:
6363
step: Current training step number
6464
mode: Mode identifier ('train' or 'val'), defaults to 'train'
@@ -67,7 +67,7 @@ def set_training_step(self, step: int, mode: str = "train", epoch: int = 0):
6767
self.current_step = step
6868
self.current_mode = mode
6969
self.current_epoch = epoch
70-
70+
7171
async def initialize_pool(self):
7272
"""Initialize the workflow pool with parallel workflow instances.
7373
@@ -173,7 +173,7 @@ async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = No
173173
sorted_tasks = sorted(task_states.keys(), key=lambda task_id: task_states[task_id]["idx"])
174174
for task_id in sorted_tasks:
175175
results.extend(task_states[task_id]["episodes"])
176-
176+
177177
# Log episodes if logger is provided
178178
if self.episode_logger is not None:
179179
try:
@@ -182,8 +182,9 @@ async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = No
182182
except Exception as e:
183183
logger.error(f"Failed to log episodes: {e}")
184184
import traceback
185+
185186
traceback.print_exc()
186-
187+
187188
return results
188189

189190
async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":

rllm/trainer/verl/agent_workflow_trainer.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,10 @@ def init_workers(self):
7979

8080
# Create episode logger if enabled in config
8181
episode_logger = None
82-
if self.config.trainer.get('log_episodes', False):
82+
if self.config.trainer.get("log_episodes", False):
8383
# Get episode log directory from config, default to "logs/my_project/my_experiment"
84-
episode_log_dir = self.config.trainer.get('episode_log_dir', f'logs/{self.config.trainer.project_name}/{self.config.trainer.experiment_name}')
85-
episode_logger = EpisodeLogger(
86-
base_dir=episode_log_dir,
87-
subdirectory="episodes"
88-
)
84+
episode_log_dir = self.config.trainer.get("episode_log_dir", f"logs/{self.config.trainer.project_name}/{self.config.trainer.experiment_name}")
85+
episode_logger = EpisodeLogger(base_dir=episode_log_dir, subdirectory="episodes")
8986

9087
self.agent_execution_engine = AgentWorkflowEngine(
9188
workflow_cls=self.workflow_class,

rllm/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,3 @@
33
from rllm.utils.episode_logger import EpisodeLogger
44

55
__all__ = ["EpisodeLogger"]
6-

rllm/utils/episode_logger.py

Lines changed: 65 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,143 +1,135 @@
11
"""Episode JSON Logger for saving detailed episode information."""
2-
import json
2+
33
import hashlib
4+
import json
45
from pathlib import Path
56
from typing import Any
7+
68
from rllm.agents.agent import Episode
79

810

911
class EpisodeLogger:
1012
"""Logger to save episodes to individual JSON files with step and data hash."""
11-
13+
1214
def __init__(self, base_dir: str, subdirectory: str = "episodes"):
1315
"""Initialize the episode logger.
14-
16+
1517
Args:
16-
base_dir: Base directory for episode logs. Can be configured via
17-
config.trainer.episode_log_dir
18+
base_dir: Base directory for episode logs. Can be configured via
19+
config.trainer.episode_log_dir
1820
(default: "logs/${trainer.project_name}/${trainer.experiment_name}")
1921
subdirectory: Subdirectory within base_dir for episodes (default: "episodes")
2022
Final path will be: {base_dir}/{subdirectory}/
2123
"""
2224
self.log_dir = Path(base_dir) / subdirectory
2325
self.log_dir.mkdir(parents=True, exist_ok=True)
24-
26+
2527
@staticmethod
2628
def compute_task_hash(task: Any, length: int = 8) -> str:
2729
"""Compute a hash from the task data.
28-
30+
2931
Args:
3032
task: The task dictionary or data
3133
length: Length of the hash to use (default 8 chars)
32-
34+
3335
Returns:
3436
Hash string
3537
"""
3638
# Convert task to a stable string representation
3739
task_str = json.dumps(task, sort_keys=True, default=str)
3840
# Compute SHA256 hash
39-
hash_obj = hashlib.sha256(task_str.encode('utf-8'))
41+
hash_obj = hashlib.sha256(task_str.encode("utf-8"))
4042
# Return first `length` characters of hex digest
4143
return hash_obj.hexdigest()[:length]
42-
44+
4345
def get_step_dir(self, step: int, mode: str = "train", epoch: int = 0) -> Path:
4446
"""Get the directory path for a specific training or validation step.
45-
47+
4648
Args:
4749
step: Current training/validation step
4850
mode: Mode identifier ('train' or 'val'), defaults to 'train'
4951
epoch: Current epoch number, defaults to 0
50-
52+
5153
Returns:
5254
Path object for the step directory
5355
"""
5456
step_dir = self.log_dir / f"{mode}_step_{step}_epoch_{epoch}"
5557
step_dir.mkdir(parents=True, exist_ok=True)
5658
return step_dir
57-
59+
5860
def get_episode_filename(self, episode: Episode, step: int) -> str:
5961
"""Generate filename for an episode.
60-
62+
6163
Format: episode_hash{task_hash}_id{episode_id}.json
62-
64+
6365
Args:
6466
episode: The episode to save
6567
step: Current training step (not used in filename, but kept for compatibility)
66-
68+
6769
Returns:
6870
Filename string
6971
"""
7072
task_hash = self.compute_task_hash(episode.task)
7173
# Clean episode_id to make it filesystem-safe
72-
episode_id_safe = str(episode.id).replace(':', '_').replace('/', '_')
73-
74+
episode_id_safe = str(episode.id).replace(":", "_").replace("/", "_")
75+
7476
filename = f"episode_hash{task_hash}_id{episode_id_safe}.json"
7577
return filename
76-
78+
7779
def log_episode(self, episode: Episode, step: int, mode: str = "train", epoch: int = 0):
7880
"""Log a single episode to its own JSON file in a step-specific directory.
79-
81+
8082
Args:
8183
episode: The episode to log
8284
step: Current training/validation step
8385
mode: Mode identifier ('train' or 'val'), defaults to 'train'
8486
epoch: Current epoch number, defaults to 0
8587
"""
86-
episode_data = {
87-
'training_step': step,
88-
'epoch': epoch,
89-
'episode_id': episode.id,
90-
'task': episode.task,
91-
'task_hash': self.compute_task_hash(episode.task),
92-
'is_correct': episode.is_correct,
93-
'termination_reason': episode.termination_reason.value if episode.termination_reason else None,
94-
'metrics': episode.metrics,
95-
'timing': episode.info.get('timing', {}),
96-
'trajectories': []
97-
}
98-
88+
episode_data = {"training_step": step, "epoch": epoch, "episode_id": episode.id, "task": episode.task, "task_hash": self.compute_task_hash(episode.task), "is_correct": episode.is_correct, "termination_reason": episode.termination_reason.value if episode.termination_reason else None, "metrics": episode.metrics, "timing": episode.info.get("timing", {}), "trajectories": []}
89+
9990
for traj in episode.trajectories:
10091
traj_data = {
101-
'name': traj.name,
102-
'uid': traj.uid,
103-
'reward': traj.reward,
104-
'num_steps': len(traj.steps),
105-
'timing': traj.info.get('timing', {}),
106-
'steps': [
92+
"name": traj.name,
93+
"uid": traj.uid,
94+
"reward": traj.reward,
95+
"num_steps": len(traj.steps),
96+
"timing": traj.info.get("timing", {}),
97+
"steps": [
10798
{
108-
'observation': step.observation,
109-
'thought': step.thought,
110-
'action': step.action,
111-
'reward': step.reward,
112-
'done': step.done,
113-
'model_response': step.model_response,
114-
'chat_completions': step.chat_completions,
115-
'timing': step.info.get('timing', {}), # Add step-level timing
99+
"observation": step.observation,
100+
"thought": step.thought,
101+
"action": step.action,
102+
"reward": step.reward,
103+
"done": step.done,
104+
"model_response": step.model_response,
105+
"chat_completions": step.chat_completions,
106+
"timing": step.info.get("timing", {}), # Add step-level timing
116107
}
117108
for step in traj.steps
118-
]
109+
],
119110
}
120-
episode_data['trajectories'].append(traj_data)
121-
111+
episode_data["trajectories"].append(traj_data)
112+
122113
# Write to individual file in step-specific directory
123114
step_dir = self.get_step_dir(step, mode, epoch)
124115
filename = self.get_episode_filename(episode, step)
125116
filepath = step_dir / filename
126-
117+
127118
try:
128-
with open(filepath, 'w') as f:
119+
with open(filepath, "w") as f:
129120
json_str = json.dumps(episode_data, indent=2, default=str)
130-
f.write(json_str + '\n')
121+
f.write(json_str + "\n")
131122
f.flush() # Ensure data is written to disk
132123
except Exception as e:
133124
print(f"Error writing episode to {filepath}: {e}")
134125
import traceback
126+
135127
traceback.print_exc()
136128
raise
137-
129+
138130
def log_episodes(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0):
139131
"""Log multiple episodes, each to its own file.
140-
132+
141133
Args:
142134
episodes: List of episodes to log
143135
step: Current training/validation step
@@ -148,14 +140,14 @@ def log_episodes(self, episodes: list[Episode], step: int, mode: str = "train",
148140
for i, episode in enumerate(episodes):
149141
try:
150142
self.log_episode(episode, step, mode, epoch)
151-
print(f"[EpisodeLogger] Successfully logged episode {i+1}/{len(episodes)}: {episode.id}")
143+
print(f"[EpisodeLogger] Successfully logged episode {i + 1}/{len(episodes)}: {episode.id}")
152144
except Exception as e:
153-
print(f"[EpisodeLogger] Failed to log episode {i+1}/{len(episodes)}: {e}")
145+
print(f"[EpisodeLogger] Failed to log episode {i + 1}/{len(episodes)}: {e}")
154146
raise
155-
147+
156148
def log_episodes_batch(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0, batch_summary: bool = True):
157149
"""Log multiple episodes and optionally create a batch summary in step-specific directory.
158-
150+
159151
Args:
160152
episodes: List of episodes to log
161153
step: Current training/validation step
@@ -165,28 +157,24 @@ def log_episodes_batch(self, episodes: list[Episode], step: int, mode: str = "tr
165157
"""
166158
# Log individual episodes
167159
self.log_episodes(episodes, step, mode, epoch)
168-
160+
169161
# Optionally create batch summary in step-specific directory
170162
if batch_summary and episodes:
171163
summary_data = {
172-
'training_step': step,
173-
'epoch': epoch,
174-
'mode': mode,
175-
'num_episodes': len(episodes),
176-
'episode_files': [
177-
self.get_episode_filename(ep, step) for ep in episodes
178-
],
179-
'summary_stats': {
180-
'total_correct': sum(1 for ep in episodes if ep.is_correct),
181-
'total_incorrect': sum(1 for ep in episodes if not ep.is_correct),
182-
'accuracy': sum(1 for ep in episodes if ep.is_correct) / len(episodes) if episodes else 0,
183-
'avg_trajectories_per_episode': sum(len(ep.trajectories) for ep in episodes) / len(episodes) if episodes else 0,
184-
}
164+
"training_step": step,
165+
"epoch": epoch,
166+
"mode": mode,
167+
"num_episodes": len(episodes),
168+
"episode_files": [self.get_episode_filename(ep, step) for ep in episodes],
169+
"summary_stats": {
170+
"total_correct": sum(1 for ep in episodes if ep.is_correct),
171+
"total_incorrect": sum(1 for ep in episodes if not ep.is_correct),
172+
"accuracy": sum(1 for ep in episodes if ep.is_correct) / len(episodes) if episodes else 0,
173+
"avg_trajectories_per_episode": sum(len(ep.trajectories) for ep in episodes) / len(episodes) if episodes else 0,
174+
},
185175
}
186-
176+
187177
step_dir = self.get_step_dir(step, mode, epoch)
188178
summary_file = step_dir / "batch_summary.json"
189-
with open(summary_file, 'w') as f:
179+
with open(summary_file, "w") as f:
190180
json.dump(summary_data, f, indent=2)
191-
192-

0 commit comments

Comments (0)