
Commit f56cb74

add per-episode logging and timing metrics to workflow trainer

1 parent: 2956f86

13 files changed: +506 −18 lines

rllm/engine/agent_workflow_engine.py

Lines changed: 39 additions & 3 deletions

@@ -22,7 +22,7 @@


 class AgentWorkflowEngine:
-    def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, **kwargs):
+    def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, episode_logger=None, **kwargs):
         """Initialize the AgentWorkflowEngine.

         Args:
@@ -33,6 +33,7 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en
             n_parallel_tasks: Number of parallel workflow instances to maintain.
             retry_limit: Maximum number of retry attempts for failed tasks.
             raise_on_error: Whether to raise exceptions on permanent failures.
+            episode_logger: Optional logger for saving episode data to files.
             **kwargs: Additional keyword arguments.
         """
         self.workflow_cls = workflow_cls
@@ -48,7 +49,25 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en
         self.n_parallel_tasks = n_parallel_tasks
         self.executor = ThreadPoolExecutor(max_workers=self.n_parallel_tasks)
         self.workflow_queue = None
-
+
+        # Episode logging support
+        self.episode_logger = episode_logger
+        self.current_step = 0
+        self.current_epoch = 0
+        self.current_mode = "train"  # "train" or "val"
+
+    def set_training_step(self, step: int, mode: str = "train", epoch: int = 0):
+        """Set current training step for episode logging.
+
+        Args:
+            step: Current training step number
+            mode: Mode identifier ('train' or 'val'), defaults to 'train'
+            epoch: Current epoch number, defaults to 0
+        """
+        self.current_step = step
+        self.current_mode = mode
+        self.current_epoch = epoch
+
     async def initialize_pool(self):
         """Initialize the workflow pool with parallel workflow instances.

@@ -154,6 +173,18 @@ async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = No
         sorted_tasks = sorted(task_states.keys(), key=lambda task_id: task_states[task_id]["idx"])
         for task_id in sorted_tasks:
             results.extend(task_states[task_id]["episodes"])
+
+        # Log episodes if logger is provided
+        if self.episode_logger is not None:
+            try:
+                logger.info(f"Logging {len(results)} episodes to step={self.current_step}, mode={self.current_mode}, epoch={self.current_epoch}")
+                self.episode_logger.log_episodes_batch(results, self.current_step, self.current_mode, self.current_epoch)
+                logger.info(f"Successfully logged {len(results)} episodes")
+            except Exception as e:
+                logger.error(f"Failed to log episodes: {e}")
+                import traceback
+                traceback.print_exc()
+
         return results

     async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":
@@ -167,12 +198,17 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":
             DataProto: Transformed results compatible with Verl training.
         """
         self.rollout_engine.wake_up()
-        if batch.meta_info.get("validate", False):
+        is_validation = batch.meta_info.get("validate", False)
+        if is_validation:
             self.rollout_engine.validate = True
+            self.current_mode = "val"
+        else:
+            self.current_mode = "train"
         tasks = batch.non_tensor_batch["extra_info"].tolist()
         task_ids = batch.non_tensor_batch["task_ids"].tolist()
         results = await self.execute_tasks(tasks, task_ids, **kwargs)  # list of Episodes
         self.rollout_engine.validate = False
+        self.current_mode = "train"
         self.rollout_engine.sleep()
         return self.transform_results_for_verl(results, task_ids)
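For orientation, the two new hooks are meant to be used together: the caller tags the upcoming rollout with set_training_step, and execute_tasks then files the finished episodes under that step. A minimal driver sketch, not part of this commit; MyWorkflow, rollout_engine, and tasks are hypothetical placeholders:

import asyncio

from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.utils.episode_logger import EpisodeLogger

async def main():
    engine = AgentWorkflowEngine(
        workflow_cls=MyWorkflow,          # hypothetical Workflow subclass
        workflow_args={},
        rollout_engine=rollout_engine,    # assumed to be configured elsewhere
        episode_logger=EpisodeLogger(base_dir="logs/demo/run1"),
    )
    await engine.initialize_pool()

    # Tag the rollout first so the logged episodes land under
    # logs/demo/run1/episodes/train_step_1_epoch_0/.
    engine.set_training_step(step=1, mode="train", epoch=0)
    episodes = await engine.execute_tasks(tasks)  # tasks: list[dict]

asyncio.run(main())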

rllm/trainer/config/_generated_agent_ppo_trainer.yaml

Lines changed: 2 additions & 0 deletions

@@ -201,6 +201,8 @@ trainer:
   val_before_train: true
   val_only: false
   test_freq: -1
+  log_episodes: false
+  episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}
   critic_warmup: 0
   default_hdfs_dir: null
   del_local_ckpt_after_load: false
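The ${trainer.project_name} and ${trainer.experiment_name} references are OmegaConf interpolations, so the default log directory tracks whatever names the run uses. A small sketch of that resolution; the names here are illustrative:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "trainer": {
        "project_name": "my_project",        # illustrative value
        "experiment_name": "my_experiment",  # illustrative value
        "episode_log_dir": "logs/${trainer.project_name}/${trainer.experiment_name}",
    }
})
print(cfg.trainer.episode_log_dir)  # -> logs/my_project/my_experiment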

rllm/trainer/config/agent_ppo_trainer.yaml

Lines changed: 5 additions & 1 deletion

@@ -64,4 +64,8 @@ rllm:
   fireworks:
     deployment_id: null
     model_id_prefix: test-model
-    concurrency: 32
+    concurrency: 32
+
+trainer:
+  log_episodes: false
+  episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}

rllm/trainer/config/agent_ppo_trainer_megatron.yaml

Lines changed: 5 additions & 1 deletion

@@ -58,4 +58,8 @@ rllm:
   mask_timeout: True
   rejection_sample:
     enable: False
-    multiplier: 1
+    multiplier: 1
+
+trainer:
+  log_episodes: false
+  episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}

rllm/trainer/config/agent_sft_trainer.yaml

Lines changed: 5 additions & 1 deletion

@@ -9,4 +9,8 @@ defaults:

 data:
   rllm:
-    tokenize_and_mask_method: cumulative
+    tokenize_and_mask_method: cumulative
+
+trainer:
+  log_episodes: false
+  episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}
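All four trainer configs now carry the same pair of keys, so episode logging can be flipped per run instead of per file. A sketch of the equivalent toggle as an OmegaConf dotlist merge; that this repo's launcher accepts Hydra-style command-line overrides of the same shape is an assumption:

from omegaconf import OmegaConf

base = OmegaConf.create({"trainer": {"log_episodes": False}})
# Same shape as a command-line override trainer.log_episodes=true.
cli = OmegaConf.from_dotlist(["trainer.log_episodes=true"])
cfg = OmegaConf.merge(base, cli)
assert cfg.trainer.log_episodes is True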

rllm/trainer/verl/agent_workflow_trainer.py

Lines changed: 18 additions & 0 deletions

@@ -12,6 +12,7 @@

 from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
 from rllm.engine.rollout.verl_engine import VerlEngine
+from rllm.utils.episode_logger import EpisodeLogger
 from rllm.workflows.workflow import TerminationReason
 from verl import DataProto
 from verl.protocol import pad_dataproto_to_divisor
@@ -76,13 +77,24 @@ def init_workers(self):
             tokenizer=self.tokenizer,
         )

+        # Create episode logger if enabled in config
+        episode_logger = None
+        if self.config.trainer.get('log_episodes', False):
+            # Get episode log directory from config, default to "logs/my_project/my_experiment"
+            episode_log_dir = self.config.trainer.get('episode_log_dir', f'logs/{self.config.trainer.project_name}/{self.config.trainer.experiment_name}')
+            episode_logger = EpisodeLogger(
+                base_dir=episode_log_dir,
+                subdirectory="episodes"
+            )
+
         self.agent_execution_engine = AgentWorkflowEngine(
             workflow_cls=self.workflow_class,
             workflow_args=self.workflow_args,
             rollout_engine=rollout_engine,
             config=self.config,
             n_parallel_tasks=self.config.rllm.workflow.n_parallel_tasks,
             retry_limit=self.config.rllm.workflow.retry_limit,
+            episode_logger=episode_logger,
         )

         # init workflow workers
@@ -111,6 +123,7 @@ def fit_agent(self):

         start_time = time.time()
         if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+            self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=0)
             val_metrics = self._validate_agent()
             pprint(f"Initial validation metrics: {val_metrics}")
             logger.log(data=val_metrics, step=self.global_steps)
@@ -145,6 +158,9 @@ def fit_agent(self):

             new_batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"], non_tensor_batch_keys=["raw_prompt_ids"])

+            # Update training step in engine for episode logging
+            self.agent_execution_engine.set_training_step(self.global_steps, mode="train", epoch=epoch)
+
             with marked_timer("step", timing_raw):
                 # generate trajectories
                 final_gen_batch_output = self.generate_trajectories(batch=new_batch, timing_raw=timing_raw)
@@ -391,6 +407,7 @@ def fit_agent(self):
                 # validate
                 if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and self.global_steps % self.config.trainer.test_freq == 0:
                     with marked_timer("testing", timing_raw, color="green"):
+                        self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=epoch)
                         val_metrics: dict = self._validate_agent()
                     metrics.update(val_metrics)

@@ -455,6 +472,7 @@ def fit_agent(self):
             if self.global_steps >= self.total_training_steps:
                 # perform validation after training
                 if self.val_reward_fn is not None:
+                    self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=epoch)
                     val_metrics = self._validate_agent()
                     pprint(f"Final validation metrics: {val_metrics}")
                     logger.log(data=val_metrics, step=self.global_steps)
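Note that init_workers reads the flag with self.config.trainer.get('log_episodes', False), so configs that predate this commit and lack the key keep running with logging disabled rather than erroring. A minimal sketch of that fallback, assuming an OmegaConf config like the trainer's:

from omegaconf import OmegaConf

# An older config without the new trainer.log_episodes key.
old_cfg = OmegaConf.create({"trainer": {"project_name": "p", "experiment_name": "e"}})
# Key absent -> .get returns the default, so no EpisodeLogger is created.
assert old_cfg.trainer.get("log_episodes", False) is False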

rllm/utils/__init__.py (new file)

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
"""Utilities for the rllm package."""

from rllm.utils.episode_logger import EpisodeLogger

__all__ = ["EpisodeLogger"]

rllm/utils/episode_logger.py (new file)

Lines changed: 192 additions & 0 deletions

@@ -0,0 +1,192 @@
"""Episode JSON Logger for saving detailed episode information."""
import json
import hashlib
from pathlib import Path
from typing import Any
from rllm.agents.agent import Episode


class EpisodeLogger:
    """Logger to save episodes to individual JSON files with step and data hash."""

    def __init__(self, base_dir: str, subdirectory: str = "episodes"):
        """Initialize the episode logger.

        Args:
            base_dir: Base directory for episode logs. Can be configured via
                config.trainer.episode_log_dir
                (default: "logs/${trainer.project_name}/${trainer.experiment_name}")
            subdirectory: Subdirectory within base_dir for episodes (default: "episodes")
                Final path will be: {base_dir}/{subdirectory}/
        """
        self.log_dir = Path(base_dir) / subdirectory
        self.log_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def compute_task_hash(task: Any, length: int = 8) -> str:
        """Compute a hash from the task data.

        Args:
            task: The task dictionary or data
            length: Length of the hash to use (default 8 chars)

        Returns:
            Hash string
        """
        # Convert task to a stable string representation
        task_str = json.dumps(task, sort_keys=True, default=str)
        # Compute SHA256 hash
        hash_obj = hashlib.sha256(task_str.encode('utf-8'))
        # Return first `length` characters of hex digest
        return hash_obj.hexdigest()[:length]

    def get_step_dir(self, step: int, mode: str = "train", epoch: int = 0) -> Path:
        """Get the directory path for a specific training or validation step.

        Args:
            step: Current training/validation step
            mode: Mode identifier ('train' or 'val'), defaults to 'train'
            epoch: Current epoch number, defaults to 0

        Returns:
            Path object for the step directory
        """
        step_dir = self.log_dir / f"{mode}_step_{step}_epoch_{epoch}"
        step_dir.mkdir(parents=True, exist_ok=True)
        return step_dir

    def get_episode_filename(self, episode: Episode, step: int) -> str:
        """Generate filename for an episode.

        Format: episode_hash{task_hash}_id{episode_id}.json

        Args:
            episode: The episode to save
            step: Current training step (not used in filename, but kept for compatibility)

        Returns:
            Filename string
        """
        task_hash = self.compute_task_hash(episode.task)
        # Clean episode_id to make it filesystem-safe
        episode_id_safe = str(episode.id).replace(':', '_').replace('/', '_')

        filename = f"episode_hash{task_hash}_id{episode_id_safe}.json"
        return filename

    def log_episode(self, episode: Episode, step: int, mode: str = "train", epoch: int = 0):
        """Log a single episode to its own JSON file in a step-specific directory.

        Args:
            episode: The episode to log
            step: Current training/validation step
            mode: Mode identifier ('train' or 'val'), defaults to 'train'
            epoch: Current epoch number, defaults to 0
        """
        episode_data = {
            'training_step': step,
            'epoch': epoch,
            'episode_id': episode.id,
            'task': episode.task,
            'task_hash': self.compute_task_hash(episode.task),
            'is_correct': episode.is_correct,
            'termination_reason': episode.termination_reason.value if episode.termination_reason else None,
            'metrics': episode.metrics,
            'timing': episode.info.get('timing', {}),
            'trajectories': []
        }

        for traj in episode.trajectories:
            traj_data = {
                'name': traj.name,
                'uid': traj.uid,
                'reward': traj.reward,
                'num_steps': len(traj.steps),
                'timing': traj.info.get('timing', {}),
                'steps': [
                    {
                        'observation': step.observation,
                        'thought': step.thought,
                        'action': step.action,
                        'reward': step.reward,
                        'done': step.done,
                        'model_response': step.model_response,
                        'chat_completions': step.chat_completions,
                        'timing': step.info.get('timing', {}),  # Add step-level timing
                    }
                    for step in traj.steps
                ]
            }
            episode_data['trajectories'].append(traj_data)

        # Write to individual file in step-specific directory
        step_dir = self.get_step_dir(step, mode, epoch)
        filename = self.get_episode_filename(episode, step)
        filepath = step_dir / filename

        try:
            with open(filepath, 'w') as f:
                json_str = json.dumps(episode_data, indent=2, default=str)
                f.write(json_str + '\n')
                f.flush()  # Ensure data is written to disk
        except Exception as e:
            print(f"Error writing episode to {filepath}: {e}")
            import traceback
            traceback.print_exc()
            raise

    def log_episodes(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0):
        """Log multiple episodes, each to its own file.

        Args:
            episodes: List of episodes to log
            step: Current training/validation step
            mode: Mode identifier ('train' or 'val'), defaults to 'train'
            epoch: Current epoch number, defaults to 0
        """
        print(f"[EpisodeLogger] Logging {len(episodes)} episodes for step={step}, mode={mode}, epoch={epoch}")
        for i, episode in enumerate(episodes):
            try:
                self.log_episode(episode, step, mode, epoch)
                print(f"[EpisodeLogger] Successfully logged episode {i+1}/{len(episodes)}: {episode.id}")
            except Exception as e:
                print(f"[EpisodeLogger] Failed to log episode {i+1}/{len(episodes)}: {e}")
                raise

    def log_episodes_batch(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0, batch_summary: bool = True):
        """Log multiple episodes and optionally create a batch summary in step-specific directory.

        Args:
            episodes: List of episodes to log
            step: Current training/validation step
            mode: Mode identifier ('train' or 'val'), defaults to 'train'
            epoch: Current epoch number, defaults to 0
            batch_summary: Whether to create a summary file for the batch
        """
        # Log individual episodes
        self.log_episodes(episodes, step, mode, epoch)

        # Optionally create batch summary in step-specific directory
        if batch_summary and episodes:
            summary_data = {
                'training_step': step,
                'epoch': epoch,
                'mode': mode,
                'num_episodes': len(episodes),
                'episode_files': [
                    self.get_episode_filename(ep, step) for ep in episodes
                ],
                'summary_stats': {
                    'total_correct': sum(1 for ep in episodes if ep.is_correct),
                    'total_incorrect': sum(1 for ep in episodes if not ep.is_correct),
                    'accuracy': sum(1 for ep in episodes if ep.is_correct) / len(episodes) if episodes else 0,
                    'avg_trajectories_per_episode': sum(len(ep.trajectories) for ep in episodes) / len(episodes) if episodes else 0,
                }
            }

            step_dir = self.get_step_dir(step, mode, epoch)
            summary_file = step_dir / "batch_summary.json"
            with open(summary_file, 'w') as f:
                json.dump(summary_data, f, indent=2)
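A short usage sketch built only from the methods above; the path and task dict are illustrative. It shows the deterministic task hash and the per-step directory layout the engine writes into:

from rllm.utils import EpisodeLogger

ep_logger = EpisodeLogger(base_dir="logs/my_project/my_experiment")

task = {"question": "What is 2 + 2?", "answer": "4"}
# The same task dict always yields the same 8-char prefix of its SHA256,
# which makes reruns of one task easy to group across steps.
print(EpisodeLogger.compute_task_hash(task))  # e.g. '3f2a9c1b' (depends on the task)

# Step directories are created on demand, one per (mode, step, epoch) triple.
print(ep_logger.get_step_dir(step=10, mode="val", epoch=1))
# -> logs/my_project/my_experiment/episodes/val_step_10_epoch_1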
