Skip to content

Commit c636400

Browse files
remove verl dependencies for Tinker trainer
1 parent f019950 commit c636400

File tree

6 files changed

+638
-61
lines changed

6 files changed

+638
-61
lines changed

examples/solver_judge_tinker/train_solver_judge_flow_tinker.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ set -x
22

33
MODEL_PATH=Qwen/Qwen3-4B-Instruct-2507
44

5-
python3 -m examples.solver_judge_tinker.train_solver_judge_flow_tinker \
5+
python -m examples.solver_judge_tinker.train_solver_judge_flow_tinker \
66
model.name=$MODEL_PATH \
77
model.lora_rank=32 \
88
training.group_size=4 \

rllm/engine/agent_workflow_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111

1212
from rllm.agents.agent import Episode
1313
from rllm.engine.rollout import ModelOutput, RolloutEngine
14-
from rllm.engine.rollout.verl_engine import VerlEngine
1514
from rllm.misc import colorful_print
1615
from rllm.workflows.workflow import TerminationReason, Workflow
1716

1817
# Avoid a hard dependency on verl at import time; the imports below are needed only for type annotations
1918
if TYPE_CHECKING:
19+
from rllm.engine.rollout.verl_engine import VerlEngine
2020
from verl import DataProto
2121

2222
logger = logging.getLogger(__name__)

rllm/trainer/agent_trainer.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import ray
44

55
from rllm.data import Dataset
6-
from rllm.trainer.verl.ray_runtime_env import get_ppo_ray_runtime_env
7-
from rllm.trainer.verl.train_agent_ppo import TaskRunner
86

97

108
class AgentTrainer:
@@ -42,13 +40,21 @@ def __init__(
4240
"""
4341
if workflow_class is not None:
4442
if agent_class is not None:
45-
raise ValueError("agent_class is not supported when using workflow, instead use workflow_args['agent_cls']")
43+
raise ValueError(
44+
"agent_class is not supported when using workflow, instead use workflow_args['agent_cls']"
45+
)
4646
if agent_args is not None:
47-
raise ValueError("agent_args is not supported when using workflow, instead use workflow_args['agent_args']")
47+
raise ValueError(
48+
"agent_args is not supported when using workflow, instead use workflow_args['agent_args']"
49+
)
4850
if env_class is not None:
49-
raise ValueError("env_class is not supported when using workflow, instead use workflow_args['env_cls']")
51+
raise ValueError(
52+
"env_class is not supported when using workflow, instead use workflow_args['env_cls']"
53+
)
5054
if env_args is not None:
51-
raise ValueError("env_args is not supported when using workflow, instead use workflow_args['env_args']")
55+
raise ValueError(
56+
"env_args is not supported when using workflow, instead use workflow_args['env_args']"
57+
)
5258

5359
self.workflow_class = workflow_class
5460
self.workflow_args = workflow_args or {}
@@ -63,11 +69,15 @@ def __init__(
6369
self.val_dataset = val_dataset
6470
self.backend = backend
6571

66-
assert self.backend in ["verl", "tinker"], f"Unsupported backend: {self.backend}, must be one of ['verl', 'tinker']"
72+
assert self.backend in [
73+
"verl", "tinker"
74+
], f"Unsupported backend: {self.backend}, must be one of ['verl', 'tinker']"
6775

68-
if train_dataset is not None and self.config is not None and hasattr(self.config, "data"):
76+
if train_dataset is not None and self.config is not None and hasattr(
77+
self.config, "data"):
6978
self.config.data.train_files = train_dataset.get_verl_data_path()
70-
if val_dataset is not None and self.config is not None and hasattr(self.config, "data"):
79+
if val_dataset is not None and self.config is not None and hasattr(
80+
self.config, "data"):
7181
self.config.data.val_files = val_dataset.get_verl_data_path()
7282

7383
def train(self):
@@ -101,14 +111,20 @@ def _train_tinker(self):
101111
trainer.fit_agent()
102112

103113
def _train_verl(self):
114+
from rllm.trainer.verl.ray_runtime_env import get_ppo_ray_runtime_env
115+
from rllm.trainer.verl.train_agent_ppo import TaskRunner
104116
# Initialize Ray only if it is not already initialized
105117
if not ray.is_initialized():
106118
# Collect the `ray_init` settings from the config, skipping entries whose value is None
107119
if self.config is not None and hasattr(self.config, "ray_init"):
108-
ray_init_settings = {k: v for k, v in self.config.ray_init.items() if v is not None}
120+
ray_init_settings = {
121+
k: v
122+
for k, v in self.config.ray_init.items() if v is not None
123+
}
109124
else:
110125
ray_init_settings = {}
111-
ray.init(runtime_env=get_ppo_ray_runtime_env(), **ray_init_settings)
126+
ray.init(runtime_env=get_ppo_ray_runtime_env(),
127+
**ray_init_settings)
112128

113129
runner = TaskRunner.remote()
114130

@@ -121,5 +137,4 @@ def _train_verl(self):
121137
env_class=self.env_class,
122138
agent_args=self.agent_args,
123139
env_args=self.env_args,
124-
)
125-
)
140+
))

0 commit comments

Comments
 (0)