"""Episode JSON Logger for saving detailed episode information."""

import hashlib
import json
from pathlib import Path
from typing import Any

from rllm.agents.agent import Episode

810
class EpisodeLogger:
    """Logger to save episodes to individual JSON files with step and data hash."""

    def __init__(self, base_dir: str, subdirectory: str = "episodes"):
        """Initialize the episode logger.

        Args:
            base_dir: Base directory for episode logs. Can be configured via
                config.trainer.episode_log_dir
                (default: "logs/${trainer.project_name}/${trainer.experiment_name}")
            subdirectory: Subdirectory within base_dir for episodes (default: "episodes")
                Final path will be: {base_dir}/{subdirectory}/
        """
        self.log_dir = Path(base_dir) / subdirectory
        self.log_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def compute_task_hash(task: Any, length: int = 8) -> str:
        """Compute a hash from the task data.

        Args:
            task: The task dictionary or data
            length: Length of the hash to use (default 8 chars)

        Returns:
            Hash string
        """
        # Convert task to a stable string representation so the same task
        # always yields the same hash regardless of dict insertion order.
        task_str = json.dumps(task, sort_keys=True, default=str)
        # Compute SHA256 hash of the canonical representation.
        hash_obj = hashlib.sha256(task_str.encode("utf-8"))
        # Return first `length` characters of the hex digest.
        return hash_obj.hexdigest()[:length]

    def get_step_dir(self, step: int, mode: str = "train", epoch: int = 0) -> Path:
        """Get the directory path for a specific training or validation step.

        Args:
            step: Current training/validation step
            mode: Mode identifier ('train' or 'val'), defaults to 'train'
            epoch: Current epoch number, defaults to 0

        Returns:
            Path object for the step directory (created if missing)
        """
        step_dir = self.log_dir / f"{mode}_step_{step}_epoch_{epoch}"
        step_dir.mkdir(parents=True, exist_ok=True)
        return step_dir

    def get_episode_filename(self, episode: Episode, step: int) -> str:
        """Generate filename for an episode.

        Format: episode_hash{task_hash}_id{episode_id}.json

        Args:
            episode: The episode to save
            step: Current training step (not used in filename, but kept for compatibility)

        Returns:
            Filename string
        """
        task_hash = self.compute_task_hash(episode.task)
        # Clean episode_id to make it filesystem-safe: ':' and '/' are the
        # separators used in ids and are illegal/ambiguous in paths.
        episode_id_safe = str(episode.id).replace(":", "_").replace("/", "_")

        filename = f"episode_hash{task_hash}_id{episode_id_safe}.json"
        return filename

    def log_episode(self, episode: Episode, step: int, mode: str = "train", epoch: int = 0):
        """Log a single episode to its own JSON file in a step-specific directory.

        Args:
            episode: The episode to log
            step: Current training/validation step
            mode: Mode identifier ('train' or 'val'), defaults to 'train'
            epoch: Current epoch number, defaults to 0

        Raises:
            Exception: re-raised after printing a traceback if the file write fails.
        """
        episode_data = {
            "training_step": step,
            "epoch": epoch,
            "episode_id": episode.id,
            "task": episode.task,
            "task_hash": self.compute_task_hash(episode.task),
            "is_correct": episode.is_correct,
            "termination_reason": episode.termination_reason.value if episode.termination_reason else None,
            "metrics": episode.metrics,
            "timing": episode.info.get("timing", {}),
            "trajectories": [],
        }

        for traj in episode.trajectories:
            traj_data = {
                "name": traj.name,
                "uid": traj.uid,
                "reward": traj.reward,
                "num_steps": len(traj.steps),
                "timing": traj.info.get("timing", {}),
                "steps": [
                    {
                        "observation": step.observation,
                        "thought": step.thought,
                        "action": step.action,
                        "reward": step.reward,
                        "done": step.done,
                        "model_response": step.model_response,
                        "chat_completions": step.chat_completions,
                        "timing": step.info.get("timing", {}),  # Add step-level timing
                    }
                    for step in traj.steps
                ],
            }
            episode_data["trajectories"].append(traj_data)

        # Write to individual file in step-specific directory
        step_dir = self.get_step_dir(step, mode, epoch)
        filename = self.get_episode_filename(episode, step)
        filepath = step_dir / filename

        try:
            with open(filepath, "w") as f:
                # default=str makes non-JSON-serializable values (e.g. enums,
                # numpy scalars) degrade to their string form instead of raising.
                json_str = json.dumps(episode_data, indent=2, default=str)
                f.write(json_str + "\n")
                f.flush()  # Ensure data is written to disk
        except Exception as e:
            print(f"Error writing episode to {filepath}: {e}")
            import traceback

            traceback.print_exc()
            raise

    def log_episodes(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0):
        """Log multiple episodes, each to its own file.

        Args:
            episodes: List of episodes to log
            step: Current training/validation step
            mode: Mode identifier ('train' or 'val'), defaults to 'train'
            epoch: Current epoch number, defaults to 0
        """
        for i, episode in enumerate(episodes):
            try:
                self.log_episode(episode, step, mode, epoch)
                print(f"[EpisodeLogger] Successfully logged episode {i + 1}/{len(episodes)}: {episode.id}")
            except Exception as e:
                print(f"[EpisodeLogger] Failed to log episode {i + 1}/{len(episodes)}: {e}")
                raise

    def log_episodes_batch(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0, batch_summary: bool = True):
        """Log multiple episodes and optionally create a batch summary in step-specific directory.

        Args:
            episodes: List of episodes to log
            step: Current training/validation step
            mode: Mode identifier ('train' or 'val'), defaults to 'train'
            epoch: Current epoch number, defaults to 0
            batch_summary: Whether to also write a batch_summary.json, defaults to True
        """
        # Log individual episodes
        self.log_episodes(episodes, step, mode, epoch)

        # Optionally create batch summary in step-specific directory
        if batch_summary and episodes:
            summary_data = {
                "training_step": step,
                "epoch": epoch,
                "mode": mode,
                "num_episodes": len(episodes),
                "episode_files": [self.get_episode_filename(ep, step) for ep in episodes],
                "summary_stats": {
                    "total_correct": sum(1 for ep in episodes if ep.is_correct),
                    "total_incorrect": sum(1 for ep in episodes if not ep.is_correct),
                    "accuracy": sum(1 for ep in episodes if ep.is_correct) / len(episodes) if episodes else 0,
                    "avg_trajectories_per_episode": sum(len(ep.trajectories) for ep in episodes) / len(episodes) if episodes else 0,
                },
            }

            step_dir = self.get_step_dir(step, mode, epoch)
            summary_file = step_dir / "batch_summary.json"
            with open(summary_file, "w") as f:
                json.dump(summary_data, f, indent=2)