Skip to content

Commit 9c5b66c

Browse files
committed
use agent trainer
1 parent a2e375b commit 9c5b66c

File tree

5 files changed

+52
-6
lines changed

5 files changed

+52
-6
lines changed

examples/math_tinker/train_math_tinker.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from examples.math_tinker.math_reward import math_reward_fn
1414
from rllm.data.dataset import DatasetRegistry
1515
from rllm.environments.base.single_turn_env import SingleTurnEnvironment
16-
from rllm.trainer.tinker.tinker_agent_trainer import TinkerAgentTrainer
16+
from rllm.trainer import AgentTrainer
1717

1818

1919
@hydra.main(version_base=None, config_path="../../rllm/trainer/config", config_name="tinker_agent_trainer")
@@ -32,18 +32,19 @@ def main(config: DictConfig):
3232
raise ValueError("Datasets not found! Please run prepare_tinker_math_dataset.py first:\n python -m examples.math_tinker.prepare_tinker_math_dataset")
3333

3434
# Create trainer (uses separated components internally)
35-
trainer = TinkerAgentTrainer(
35+
trainer = AgentTrainer(
3636
config=config,
3737
agent_class=MathAgentWithFewshot,
3838
env_class=SingleTurnEnvironment,
3939
agent_args={"use_fewshot": True},
4040
env_args={"reward_fn": math_reward_fn},
4141
train_dataset=train_dataset,
4242
val_dataset=test_dataset,
43+
backend="tinker",
4344
)
4445

4546
# Train (AgentTrainer dispatches to the Tinker backend trainer, which handles all orchestration internally)
46-
trainer.fit_agent()
47+
trainer.train()
4748

4849

4950
if __name__ == "__main__":

examples/solver_judge_tinker/train_solver_judge_flow_tinker.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
from examples.solver_judge.solver_judge_flow import SolverJudgeWorkflow
44
from rllm.data.dataset import DatasetRegistry
55
from rllm.rewards.countdown_reward import countdown_reward_fn
6-
from rllm.trainer.tinker.tinker_workflow_trainer import TinkerWorkflowTrainer
6+
from rllm.trainer import AgentTrainer
77

88

99
@hydra.main(config_path="pkg://rllm.trainer.config", config_name="tinker_workflow_trainer", version_base=None)
1010
def main(config):
1111
train_dataset = DatasetRegistry.load_dataset("countdown", "train")
1212
test_dataset = DatasetRegistry.load_dataset("countdown", "test")
1313

14-
trainer = TinkerWorkflowTrainer(
14+
trainer = AgentTrainer(
1515
workflow_class=SolverJudgeWorkflow,
1616
workflow_args={
1717
"n_solutions": 2,
@@ -20,8 +20,9 @@ def main(config):
2020
config=config,
2121
train_dataset=train_dataset,
2222
val_dataset=test_dataset,
23+
backend="tinker",
2324
)
24-
trainer.fit_agent()
25+
trainer.train()
2526

2627

2728
if __name__ == "__main__":

rllm/trainer/agent_trainer.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
config: dict[str, Any] | list[str] | None = None,
2525
train_dataset: Dataset | None = None,
2626
val_dataset: Dataset | None = None,
27+
backend: str = "verl",
2728
):
2829
"""
2930
Initialize the AgentTrainer.
@@ -59,13 +60,48 @@ def __init__(
5960
self.env_args = env_args or {}
6061

6162
self.config = config
63+
self.train_dataset = train_dataset
64+
self.val_dataset = val_dataset
65+
self.backend = backend
66+
67+
assert self.backend in ["verl", "tinker"], f"Unsupported backend: {self.backend}, must be one of ['verl', 'tinker']"
6268

6369
if train_dataset is not None and self.config is not None and hasattr(self.config, "data"):
6470
self.config.data.train_files = train_dataset.get_verl_data_path()
6571
if val_dataset is not None and self.config is not None and hasattr(self.config, "data"):
6672
self.config.data.val_files = val_dataset.get_verl_data_path()
6773

6874
def train(self):
75+
if self.backend == "verl":
76+
self._train_verl()
77+
elif self.backend == "tinker":
78+
self._train_tinker()
79+
80+
def _train_tinker(self):
81+
from rllm.trainer.tinker.tinker_agent_trainer import TinkerAgentTrainer
82+
from rllm.trainer.tinker.tinker_workflow_trainer import TinkerWorkflowTrainer
83+
84+
if self.config.rllm.workflow.use_workflow:
85+
trainer = TinkerWorkflowTrainer(
86+
config=self.config,
87+
workflow_class=self.workflow_class,
88+
workflow_args=self.workflow_args,
89+
train_dataset=self.train_dataset,
90+
val_dataset=self.val_dataset,
91+
)
92+
else:
93+
trainer = TinkerAgentTrainer(
94+
config=self.config,
95+
agent_class=self.agent_class,
96+
env_class=self.env_class,
97+
agent_args=self.agent_args,
98+
env_args=self.env_args,
99+
train_dataset=self.train_dataset,
100+
val_dataset=self.val_dataset,
101+
)
102+
trainer.fit_agent()
103+
104+
def _train_verl(self):
69105
# Check if Ray is not initialized
70106
if not ray.is_initialized():
71107
# read off all the `ray_init` settings from the config

rllm/trainer/config/tinker_agent_trainer.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ trainer:
6363
val_before_train: true
6464
default_local_dir: '/tmp/rllm-tinker-checkpoints'
6565

66+
rllm:
67+
workflow:
68+
use_workflow: false
69+
6670
# Hydra configuration
6771
hydra:
6872
run:

rllm/trainer/config/tinker_workflow_trainer.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ trainer:
6060
val_before_train: true
6161
default_local_dir: '/tmp/rllm-tinker-checkpoints'
6262

63+
rllm:
64+
workflow:
65+
use_workflow: true
66+
6367
# Hydra configuration
6468
hydra:
6569
run:

0 commit comments

Comments
 (0)