
Commit f974973

Merge pull request #282 from togethercomputer/qingyang/per-episode-logging
Per Episode Logging Feature
2 parents 2956f86 + 38dda57 commit f974973

13 files changed: +488 −17 lines

rllm/engine/agent_workflow_engine.py

Lines changed: 38 additions & 2 deletions
@@ -22,7 +22,7 @@
 
 
 class AgentWorkflowEngine:
-    def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, **kwargs):
+    def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, episode_logger=None, **kwargs):
         """Initialize the AgentWorkflowEngine.
 
         Args:
@@ -33,6 +33,7 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en
             n_parallel_tasks: Number of parallel workflow instances to maintain.
             retry_limit: Maximum number of retry attempts for failed tasks.
             raise_on_error: Whether to raise exceptions on permanent failures.
+            episode_logger: Optional logger for saving episode data to files.
             **kwargs: Additional keyword arguments.
         """
         self.workflow_cls = workflow_cls
@@ -49,6 +50,24 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en
         self.executor = ThreadPoolExecutor(max_workers=self.n_parallel_tasks)
         self.workflow_queue = None
 
+        # Episode logging support
+        self.episode_logger = episode_logger
+        self.current_step = 0
+        self.current_epoch = 0
+        self.current_mode = "train"  # "train" or "val"
+
+    def set_training_step(self, step: int, mode: str = "train", epoch: int = 0):
+        """Set current training step for episode logging.
+
+        Args:
+            step: Current training step number
+            mode: Mode identifier ('train' or 'val'), defaults to 'train'
+            epoch: Current epoch number, defaults to 0
+        """
+        self.current_step = step
+        self.current_mode = mode
+        self.current_epoch = epoch
+
     async def initialize_pool(self):
         """Initialize the workflow pool with parallel workflow instances.
 
@@ -154,6 +173,18 @@ async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = No
         sorted_tasks = sorted(task_states.keys(), key=lambda task_id: task_states[task_id]["idx"])
         for task_id in sorted_tasks:
             results.extend(task_states[task_id]["episodes"])
+
+        # Log episodes if logger is provided
+        if self.episode_logger is not None:
+            try:
+                logger.info(f"Logging {len(results)} episodes to step={self.current_step}, mode={self.current_mode}, epoch={self.current_epoch}")
+                self.episode_logger.log_episodes_batch(results, self.current_step, self.current_mode, self.current_epoch)
+            except Exception as e:
+                logger.error(f"Failed to log episodes: {e}")
+                import traceback
+
+                traceback.print_exc()
+
         return results
 
     async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":
@@ -167,12 +198,17 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":
             DataProto: Transformed results compatible with Verl training.
         """
         self.rollout_engine.wake_up()
-        if batch.meta_info.get("validate", False):
+        is_validation = batch.meta_info.get("validate", False)
+        if is_validation:
             self.rollout_engine.validate = True
+            self.current_mode = "val"
+        else:
+            self.current_mode = "train"
         tasks = batch.non_tensor_batch["extra_info"].tolist()
         task_ids = batch.non_tensor_batch["task_ids"].tolist()
         results = await self.execute_tasks(tasks, task_ids, **kwargs)  # list of Episodes
         self.rollout_engine.validate = False
+        self.current_mode = "train"
         self.rollout_engine.sleep()
         return self.transform_results_for_verl(results, task_ids)
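
The engine treats episode_logger as a duck-typed dependency: at the end of execute_tasks it only calls log_episodes_batch(results, step, mode, epoch), and any exception is caught and printed so logging can never crash a rollout. A minimal sketch of that contract (the stub class and the elided constructor arguments are hypothetical):

class DummyEpisodeLogger:
    """Stub that satisfies the only method AgentWorkflowEngine calls on episode_logger."""

    def log_episodes_batch(self, episodes, step, mode="train", epoch=0):
        print(f"[stub] would log {len(episodes)} episodes at {mode} step {step}, epoch {epoch}")


# Hypothetical wiring; workflow_cls, workflow_args, and rollout_engine come from the trainer:
# engine = AgentWorkflowEngine(workflow_cls=MyWorkflow, workflow_args={}, rollout_engine=rollout,
#                              episode_logger=DummyEpisodeLogger())
# engine.set_training_step(step=42, mode="train", epoch=1)
# episodes = await engine.execute_tasks(tasks)  # log_episodes_batch fires once all tasks finish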

rllm/trainer/config/_generated_agent_ppo_trainer.yaml

Lines changed: 2 additions & 0 deletions
@@ -201,6 +201,8 @@ trainer:
   val_before_train: true
   val_only: false
   test_freq: -1
+  log_episodes: false
+  episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}
   critic_warmup: 0
   default_hdfs_dir: null
   del_local_ckpt_after_load: false
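
The ${trainer.project_name} and ${trainer.experiment_name} references are OmegaConf-style interpolations, resolved when episode_log_dir is read. A small sketch of how the default resolves (assuming the config is loaded with OmegaConf, as Hydra-based verl trainers do; the project and experiment names are made up):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "trainer": {
        "project_name": "my_project",
        "experiment_name": "my_experiment",
        "log_episodes": True,
        "episode_log_dir": "logs/${trainer.project_name}/${trainer.experiment_name}",
    }
})
print(cfg.trainer.episode_log_dir)  # logs/my_project/my_experiment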

rllm/trainer/config/agent_ppo_trainer.yaml

Lines changed: 5 additions & 1 deletion
@@ -64,4 +64,8 @@ rllm:
   fireworks:
     deployment_id: null
     model_id_prefix: test-model
-    concurrency: 32
+    concurrency: 32
+
+trainer:
+  log_episodes: false
+  episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}

rllm/trainer/config/agent_ppo_trainer_megatron.yaml

Lines changed: 5 additions & 1 deletion
@@ -58,4 +58,8 @@ rllm:
   mask_timeout: True
   rejection_sample:
     enable: False
-    multiplier: 1
+    multiplier: 1
+
+trainer:
+  log_episodes: false
+  episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}

rllm/trainer/config/agent_sft_trainer.yaml

Lines changed: 5 additions & 1 deletion
@@ -9,4 +9,8 @@ defaults:
 
 data:
   rllm:
-    tokenize_and_mask_method: cumulative
+    tokenize_and_mask_method: cumulative
+
+trainer:
+  log_episodes: false
+  episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}

rllm/trainer/verl/agent_workflow_trainer.py

Lines changed: 15 additions & 0 deletions
@@ -12,6 +12,7 @@
 
 from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
 from rllm.engine.rollout.verl_engine import VerlEngine
+from rllm.utils.episode_logger import EpisodeLogger
 from rllm.workflows.workflow import TerminationReason
 from verl import DataProto
 from verl.protocol import pad_dataproto_to_divisor
@@ -76,13 +77,21 @@ def init_workers(self):
             tokenizer=self.tokenizer,
         )
 
+        # Create episode logger if enabled in config
+        episode_logger = None
+        if self.config.trainer.get("log_episodes", False):
+            # Get episode log directory from config, default to "logs/my_project/my_experiment"
+            episode_log_dir = self.config.trainer.get("episode_log_dir", f"logs/{self.config.trainer.project_name}/{self.config.trainer.experiment_name}")
+            episode_logger = EpisodeLogger(base_dir=episode_log_dir, subdirectory="episodes")
+
         self.agent_execution_engine = AgentWorkflowEngine(
             workflow_cls=self.workflow_class,
             workflow_args=self.workflow_args,
             rollout_engine=rollout_engine,
             config=self.config,
             n_parallel_tasks=self.config.rllm.workflow.n_parallel_tasks,
             retry_limit=self.config.rllm.workflow.retry_limit,
+            episode_logger=episode_logger,
         )
 
         # init workflow workers
@@ -111,6 +120,7 @@ def fit_agent(self):
 
         start_time = time.time()
         if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+            self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=0)
             val_metrics = self._validate_agent()
             pprint(f"Initial validation metrics: {val_metrics}")
             logger.log(data=val_metrics, step=self.global_steps)
@@ -145,6 +155,9 @@
 
                 new_batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"], non_tensor_batch_keys=["raw_prompt_ids"])
 
+                # Update training step in engine for episode logging
+                self.agent_execution_engine.set_training_step(self.global_steps, mode="train", epoch=epoch)
+
                 with marked_timer("step", timing_raw):
                     # generate trajectories
                     final_gen_batch_output = self.generate_trajectories(batch=new_batch, timing_raw=timing_raw)
@@ -391,6 +404,7 @@
                 # validate
                 if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and self.global_steps % self.config.trainer.test_freq == 0:
                     with marked_timer("testing", timing_raw, color="green"):
+                        self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=epoch)
                         val_metrics: dict = self._validate_agent()
                     metrics.update(val_metrics)
 
@@ -455,6 +469,7 @@
                 if self.global_steps >= self.total_training_steps:
                     # perform validation after training
                     if self.val_reward_fn is not None:
+                        self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=epoch)
                         val_metrics = self._validate_agent()
                         pprint(f"Final validation metrics: {val_metrics}")
                         logger.log(data=val_metrics, step=self.global_steps)
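
With trainer.log_episodes enabled, every training and validation step ends up as one directory of JSON files under {episode_log_dir}/episodes. A hedged sketch for inspecting a finished run (the paths are hypothetical):

from pathlib import Path

log_root = Path("logs/my_project/my_experiment/episodes")
for step_dir in sorted(log_root.glob("*_step_*_epoch_*")):
    n_episodes = len(list(step_dir.glob("episode_hash*.json")))
    has_summary = (step_dir / "batch_summary.json").exists()
    print(f"{step_dir.name}: {n_episodes} episodes, batch_summary={has_summary}")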

rllm/utils/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+"""Utilities for the rllm package."""
+
+from rllm.utils.episode_logger import EpisodeLogger
+
+__all__ = ["EpisodeLogger"]
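
The re-export means user code can import the logger from the package root rather than the submodule. A one-line check (the demo directory is arbitrary):

from rllm.utils import EpisodeLogger

episode_logger = EpisodeLogger(base_dir="logs/demo")
print(episode_logger.log_dir)  # logs/demo/episodes — the directory is created eagerly in __init__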

rllm/utils/episode_logger.py

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
+"""Episode JSON Logger for saving detailed episode information."""
+
+import hashlib
+import json
+from pathlib import Path
+from typing import Any
+
+from rllm.agents.agent import Episode
+
+
+class EpisodeLogger:
+    """Logger to save episodes to individual JSON files with step and data hash."""
+
+    def __init__(self, base_dir: str, subdirectory: str = "episodes"):
+        """Initialize the episode logger.
+
+        Args:
+            base_dir: Base directory for episode logs. Can be configured via
+                config.trainer.episode_log_dir
+                (default: "logs/${trainer.project_name}/${trainer.experiment_name}")
+            subdirectory: Subdirectory within base_dir for episodes (default: "episodes")
+                Final path will be: {base_dir}/{subdirectory}/
+        """
+        self.log_dir = Path(base_dir) / subdirectory
+        self.log_dir.mkdir(parents=True, exist_ok=True)
+
+    @staticmethod
+    def compute_task_hash(task: Any, length: int = 8) -> str:
+        """Compute a hash from the task data.
+
+        Args:
+            task: The task dictionary or data
+            length: Length of the hash to use (default 8 chars)
+
+        Returns:
+            Hash string
+        """
+        # Convert task to a stable string representation
+        task_str = json.dumps(task, sort_keys=True, default=str)
+        # Compute SHA256 hash
+        hash_obj = hashlib.sha256(task_str.encode("utf-8"))
+        # Return first `length` characters of hex digest
+        return hash_obj.hexdigest()[:length]
+
+    def get_step_dir(self, step: int, mode: str = "train", epoch: int = 0) -> Path:
+        """Get the directory path for a specific training or validation step.
+
+        Args:
+            step: Current training/validation step
+            mode: Mode identifier ('train' or 'val'), defaults to 'train'
+            epoch: Current epoch number, defaults to 0
+
+        Returns:
+            Path object for the step directory
+        """
+        step_dir = self.log_dir / f"{mode}_step_{step}_epoch_{epoch}"
+        step_dir.mkdir(parents=True, exist_ok=True)
+        return step_dir
+
+    def get_episode_filename(self, episode: Episode, step: int) -> str:
+        """Generate filename for an episode.
+
+        Format: episode_hash{task_hash}_id{episode_id}.json
+
+        Args:
+            episode: The episode to save
+            step: Current training step (not used in filename, but kept for compatibility)
+
+        Returns:
+            Filename string
+        """
+        task_hash = self.compute_task_hash(episode.task)
+        # Clean episode_id to make it filesystem-safe
+        episode_id_safe = str(episode.id).replace(":", "_").replace("/", "_")
+
+        filename = f"episode_hash{task_hash}_id{episode_id_safe}.json"
+        return filename
+
+    def log_episode(self, episode: Episode, step: int, mode: str = "train", epoch: int = 0):
+        """Log a single episode to its own JSON file in a step-specific directory.
+
+        Args:
+            episode: The episode to log
+            step: Current training/validation step
+            mode: Mode identifier ('train' or 'val'), defaults to 'train'
+            epoch: Current epoch number, defaults to 0
+        """
+        episode_data = {"training_step": step, "epoch": epoch, "episode_id": episode.id, "task": episode.task, "task_hash": self.compute_task_hash(episode.task), "is_correct": episode.is_correct, "termination_reason": episode.termination_reason.value if episode.termination_reason else None, "metrics": episode.metrics, "timing": episode.info.get("timing", {}), "trajectories": []}
+
+        for traj in episode.trajectories:
+            traj_data = {
+                "name": traj.name,
+                "uid": traj.uid,
+                "reward": traj.reward,
+                "num_steps": len(traj.steps),
+                "timing": traj.info.get("timing", {}),
+                "steps": [
+                    {
+                        "observation": step.observation,
+                        "thought": step.thought,
+                        "action": step.action,
+                        "reward": step.reward,
+                        "done": step.done,
+                        "model_response": step.model_response,
+                        "chat_completions": step.chat_completions,
+                        "timing": step.info.get("timing", {}),  # Add step-level timing
+                    }
+                    for step in traj.steps
+                ],
+            }
+            episode_data["trajectories"].append(traj_data)
+
+        # Write to individual file in step-specific directory
+        step_dir = self.get_step_dir(step, mode, epoch)
+        filename = self.get_episode_filename(episode, step)
+        filepath = step_dir / filename
+
+        try:
+            with open(filepath, "w") as f:
+                json_str = json.dumps(episode_data, indent=2, default=str)
+                f.write(json_str + "\n")
+                f.flush()  # Ensure data is written to disk
+        except Exception as e:
+            print(f"Error writing episode to {filepath}: {e}")
+            import traceback
+
+            traceback.print_exc()
+            raise
+
+    def log_episodes(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0):
+        """Log multiple episodes, each to its own file.
+
+        Args:
+            episodes: List of episodes to log
+            step: Current training/validation step
+            mode: Mode identifier ('train' or 'val'), defaults to 'train'
+            epoch: Current epoch number, defaults to 0
+        """
+        print(f"[EpisodeLogger] Logging {len(episodes)} episodes for step={step}, mode={mode}, epoch={epoch}")
+        for i, episode in enumerate(episodes):
+            try:
+                self.log_episode(episode, step, mode, epoch)
+                print(f"[EpisodeLogger] Successfully logged episode {i + 1}/{len(episodes)}: {episode.id}")
+            except Exception as e:
+                print(f"[EpisodeLogger] Failed to log episode {i + 1}/{len(episodes)}: {e}")
+                raise
+
+    def log_episodes_batch(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0, batch_summary: bool = True):
+        """Log multiple episodes and optionally create a batch summary in step-specific directory.
+
+        Args:
+            episodes: List of episodes to log
+            step: Current training/validation step
+            mode: Mode identifier ('train' or 'val'), defaults to 'train'
+            epoch: Current epoch number, defaults to 0
+            batch_summary: Whether to create a summary file for the batch
+        """
+        # Log individual episodes
+        self.log_episodes(episodes, step, mode, epoch)
+
+        # Optionally create batch summary in step-specific directory
+        if batch_summary and episodes:
+            summary_data = {
+                "training_step": step,
+                "epoch": epoch,
+                "mode": mode,
+                "num_episodes": len(episodes),
+                "episode_files": [self.get_episode_filename(ep, step) for ep in episodes],
+                "summary_stats": {
+                    "total_correct": sum(1 for ep in episodes if ep.is_correct),
+                    "total_incorrect": sum(1 for ep in episodes if not ep.is_correct),
+                    "accuracy": sum(1 for ep in episodes if ep.is_correct) / len(episodes) if episodes else 0,
+                    "avg_trajectories_per_episode": sum(len(ep.trajectories) for ep in episodes) / len(episodes) if episodes else 0,
+                },
+            }
+
+            step_dir = self.get_step_dir(step, mode, epoch)
+            summary_file = step_dir / "batch_summary.json"
+            with open(summary_file, "w") as f:
+                json.dump(summary_data, f, indent=2)
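
Because compute_task_hash serializes the task with sort_keys=True before hashing, the same task maps to the same 8-character prefix at every step, so one task can be traced across training. A standalone sketch of the hashing and naming scheme using only the stdlib (the task dict and episode id are made up):

import hashlib
import json
from pathlib import Path


def task_hash(task, length=8):
    # Mirrors EpisodeLogger.compute_task_hash: stable JSON dump -> SHA-256 -> truncated hex.
    task_str = json.dumps(task, sort_keys=True, default=str)
    return hashlib.sha256(task_str.encode("utf-8")).hexdigest()[:length]


task = {"question": "What is 2 + 2?", "answer": "4"}
h = task_hash(task)
# Same layout the logger produces: {base_dir}/episodes/{mode}_step_{step}_epoch_{epoch}/
print(Path("logs/proj/exp/episodes") / "train_step_10_epoch_0" / f"episode_hash{h}_idtask0.json")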

rllm/workflows/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -11,6 +11,8 @@
     "TerminationEvent",
     "SingleTurnWorkflow",
     "MultiTurnWorkflow",
+    "CumulativeWorkflow",
+    "TimingTrackingMixin",
 ]
 
 
@@ -23,4 +25,12 @@ def __getattr__(name):
         from .multi_turn_workflow import MultiTurnWorkflow as _Multi
 
         return _Multi
+    if name == "CumulativeWorkflow":
+        from .cumulative_workflow import CumulativeWorkflow as _Cumulative
+
+        return _Cumulative
+    if name == "TimingTrackingMixin":
+        from .timing_mixin import TimingTrackingMixin as _Mixin
+
+        return _Mixin
     raise AttributeError(name)
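
The two additions follow the module-level __getattr__ hook (PEP 562) already used for the other workflow classes, so importing rllm.workflows stays cheap and each class is imported on first attribute access. A self-contained sketch of the pattern (the package and class names are placeholders):

# lazy_pkg/__init__.py — placeholder package showing the PEP 562 lazy-import pattern
__all__ = ["HeavyClass"]


def __getattr__(name):
    # Only called when normal attribute lookup fails, i.e. on first access,
    # so the submodule import cost is deferred past `import lazy_pkg`.
    if name == "HeavyClass":
        from .heavy_module import HeavyClass as _Heavy

        return _Heavy
    raise AttributeError(name)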
