@@ -3,8 +3,6 @@
 import ray

 from rllm.data import Dataset
-from rllm.trainer.verl.ray_runtime_env import get_ppo_ray_runtime_env
-from rllm.trainer.verl.train_agent_ppo import TaskRunner


 class AgentTrainer:
@@ -42,13 +40,21 @@ def __init__(
4240 """
4341 if workflow_class is not None :
4442 if agent_class is not None :
45- raise ValueError ("agent_class is not supported when using workflow, instead use workflow_args['agent_cls']" )
43+ raise ValueError (
44+ "agent_class is not supported when using workflow, instead use workflow_args['agent_cls']"
45+ )
4646 if agent_args is not None :
47- raise ValueError ("agent_args is not supported when using workflow, instead use workflow_args['agent_args']" )
47+ raise ValueError (
48+ "agent_args is not supported when using workflow, instead use workflow_args['agent_args']"
49+ )
4850 if env_class is not None :
49- raise ValueError ("env_class is not supported when using workflow, instead use workflow_args['env_cls']" )
51+ raise ValueError (
52+ "env_class is not supported when using workflow, instead use workflow_args['env_cls']"
53+ )
5054 if env_args is not None :
51- raise ValueError ("env_args is not supported when using workflow, instead use workflow_args['env_args']" )
55+ raise ValueError (
56+ "env_args is not supported when using workflow, instead use workflow_args['env_args']"
57+ )
5258
5359 self .workflow_class = workflow_class
5460 self .workflow_args = workflow_args or {}
@@ -63,11 +69,15 @@ def __init__(
         self.val_dataset = val_dataset
         self.backend = backend

-        assert self.backend in ["verl", "tinker"], f"Unsupported backend: {self.backend}, must be one of ['verl', 'tinker']"
+        assert self.backend in [
+            "verl", "tinker"
+        ], f"Unsupported backend: {self.backend}, must be one of ['verl', 'tinker']"

-        if train_dataset is not None and self.config is not None and hasattr(self.config, "data"):
+        if train_dataset is not None and self.config is not None and hasattr(
+                self.config, "data"):
             self.config.data.train_files = train_dataset.get_verl_data_path()
-        if val_dataset is not None and self.config is not None and hasattr(self.config, "data"):
+        if val_dataset is not None and self.config is not None and hasattr(
+                self.config, "data"):
             self.config.data.val_files = val_dataset.get_verl_data_path()

     def train(self):
@@ -101,14 +111,20 @@ def _train_tinker(self):
         trainer.fit_agent()

     def _train_verl(self):
+        from rllm.trainer.verl.ray_runtime_env import get_ppo_ray_runtime_env
+        from rllm.trainer.verl.train_agent_ppo import TaskRunner
         # Check if Ray is not initialized
         if not ray.is_initialized():
             # read off all the `ray_init` settings from the config
             if self.config is not None and hasattr(self.config, "ray_init"):
-                ray_init_settings = {k: v for k, v in self.config.ray_init.items() if v is not None}
+                ray_init_settings = {
+                    k: v
+                    for k, v in self.config.ray_init.items() if v is not None
+                }
             else:
                 ray_init_settings = {}
-            ray.init(runtime_env=get_ppo_ray_runtime_env(), **ray_init_settings)
+            ray.init(runtime_env=get_ppo_ray_runtime_env(),
+                     **ray_init_settings)

         runner = TaskRunner.remote()
@@ -121,5 +137,4 @@ def _train_verl(self):
                 env_class=self.env_class,
                 agent_args=self.agent_args,
                 env_args=self.env_args,
-            )
-        )
+            ))
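
For context, the substantive change in this diff (beyond the line wrapping) is moving the verl imports from module scope into _train_verl, next to the existing ray.init call that is already guarded by ray.is_initialized(). Below is a minimal, runnable sketch of those two patterns, not rLLM code: the Trainer class is a toy, `decimal` is a placeholder for the heavy verl dependencies, and `_initialized` / `_init_runtime` are hypothetical stand-ins for Ray's global state and init call.

    _initialized = False  # stands in for ray.is_initialized()

    def _init_runtime(**settings):
        # Idempotent init, mirroring `if not ray.is_initialized(): ray.init(...)`.
        global _initialized
        if not _initialized:
            print(f"initializing runtime with {settings}")
            _initialized = True

    class Trainer:
        def __init__(self, backend: str):
            assert backend in ["verl", "tinker"], f"Unsupported backend: {backend}"
            self.backend = backend

        def train(self):
            if self.backend == "verl":
                self._train_verl()

        def _train_verl(self):
            # Deferred import: importing the module that defines Trainer stays
            # cheap, and "tinker" users never need this dependency installed.
            import decimal  # placeholder for the rllm.trainer.verl imports
            _init_runtime(num_cpus=4)
            _init_runtime(num_cpus=4)  # no-op on the second call
            print("training with", decimal.Decimal("0.5"))

    Trainer("verl").train()

The same guard also keeps train() safe when the caller has already initialized Ray with its own settings before constructing the trainer.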