Commit e8d2be1 (parent f974973)

merge AgentTrainer with PipelineAgentTrainer

6 files changed: +89 additions, −79 deletions

examples/eval_protocol/train_frozen_lake_flow.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -2,15 +2,15 @@
 
 from examples.eval_protocol.frozen_lake_flow import FrozenLakeWorkflow
 from rllm.data.dataset import DatasetRegistry
-from rllm.trainer.pipeline_agent_trainer import PipelineAgentTrainer
+from rllm.trainer.agent_trainer import AgentTrainer
 
 
 @hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
 def main(config):
     train_dataset = DatasetRegistry.load_dataset("frozen_lake_eval_protocol", "train")
     test_dataset = DatasetRegistry.load_dataset("frozen_lake_eval_protocol", "test")
 
-    trainer = PipelineAgentTrainer(
+    trainer = AgentTrainer(
         workflow_class=FrozenLakeWorkflow,
         workflow_args={
             "lite_llm_prefix": "fireworks_ai/",
@@ -21,6 +21,7 @@ def main(config):
         config=config,
         train_dataset=train_dataset,
         val_dataset=test_dataset,
+        backend="fireworks",
     )
     trainer.train()
```

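For any other script still on the old API, the same two-line migration applies: swap the import and pass the new `backend` argument. A minimal sketch of the pattern, with `MyWorkflow` as a hypothetical placeholder:

```python
# Before: from rllm.trainer.pipeline_agent_trainer import PipelineAgentTrainer
from rllm.trainer.agent_trainer import AgentTrainer

# MyWorkflow stands in for any workflow class (e.g., FrozenLakeWorkflow above)
trainer = AgentTrainer(
    workflow_class=MyWorkflow,
    config=config,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    backend="fireworks",  # selects the pipeline behavior formerly in PipelineAgentTrainer
)
trainer.train()
```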
examples/fireworks_math/README.md

Lines changed: 12 additions & 2 deletions

````diff
@@ -1,14 +1,24 @@
 ## Before Running Your Training Job
 
+First, install the Fireworks SDK and export your FIREWORKS_API_KEY:
+
+```bash
+pip install fireworks-ai
+```
+
+```bash
+export FIREWORKS_API_KEY=<YOUR_FIREWORKS_API_KEY>
+```
+
 Before starting your training, create a **Fireworks deployment**.
 
-I recommend installing **firectl** by following the guide here:
+We recommend installing **firectl** by following the guide here:
 [firectl Documentation](https://docs.fireworks.ai/tools-sdks/firectl/firectl)
 
 Then, create your deployment:
 
 ```bash
-firectl create deployment accounts/fireworks/models/qwen3-30b-a3b-instruct-2507 --enable-hot-reload-latest-addon --deployment-id <YOUR_CUSTOM_DEPLOYMENT_ID> --accelerator-type NVIDIA_H100_80GB
+firectl create deployment accounts/fireworks/models/qwen3-4b --enable-hot-reload-latest-addon --deployment-id <YOUR_CUSTOM_DEPLOYMENT_ID> --accelerator-type NVIDIA_H100_80GB
 ```
 
 ---
````

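A quick pre-flight check that the key is actually visible to the training process can save a failed run. A minimal sketch using only the standard library (not part of the commit):

```python
import os

# Fail fast if the key was not exported in this shell
if not os.environ.get("FIREWORKS_API_KEY"):
    raise SystemExit("FIREWORKS_API_KEY is not set; export it before training")
```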
examples/fireworks_math/train_fireworks_math.py

Lines changed: 8 additions & 4 deletions

```diff
@@ -5,8 +5,11 @@
 from rllm.engine.rollout.rollout_engine import ModelOutput
 from rllm.rewards.reward_fn import math_reward_fn
 from rllm.rewards.reward_types import RewardOutput
-from rllm.trainer.pipeline_agent_trainer import PipelineAgentTrainer
-from rllm.workflows.single_turn_workflow import SingleTurnWorkflow
+from rllm.trainer.agent_trainer import AgentTrainer
+from rllm.workflows.simple_workflow import SimpleWorkflow
+
+# from rllm.agents.math_agent import MathAgent
+# from rllm.environments.base.single_turn_env import SingleTurnEnvironment
 
 
 def math_workflow_reward_fn(task_info: dict, action: str) -> RewardOutput:
@@ -22,8 +25,8 @@ def main(config):
     train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
     test_dataset = DatasetRegistry.load_dataset("math500", "test")
 
-    trainer = PipelineAgentTrainer(
-        workflow_class=SingleTurnWorkflow,
+    trainer = AgentTrainer(
+        workflow_class=SimpleWorkflow,
         workflow_args={
             "reward_function": math_workflow_reward_fn,
             "max_prompt_length": config.data.max_prompt_length,
@@ -32,6 +35,7 @@ def main(config):
         config=config,
         train_dataset=train_dataset,
         val_dataset=test_dataset,
+        backend="fireworks",
     )
     trainer.train()
```

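The body of `math_workflow_reward_fn` sits outside the diff hunks. A plausible minimal implementation, assuming `math_reward_fn` already accepts the same `(task_info, action)` pair, would simply delegate:

```python
from rllm.rewards.reward_fn import math_reward_fn
from rllm.rewards.reward_types import RewardOutput


def math_workflow_reward_fn(task_info: dict, action: str) -> RewardOutput:
    # Delegation sketch (assumption): reuse the stock math reward unchanged
    return math_reward_fn(task_info, action)
```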
examples/fireworks_math/train_fireworks_math.sh

Lines changed: 7 additions & 7 deletions

```diff
@@ -16,7 +16,7 @@ python3 -m examples.fireworks_math.train_fireworks_math \
     data.train_batch_size=8 \
     data.val_batch_size=512 \
     data.max_prompt_length=4096 \
-    data.max_response_length=16384 \
+    data.max_response_length=2048 \
     actor_rollout_ref.model.lora_rank=32 \
     actor_rollout_ref.model.lora_alpha=32 \
     actor_rollout_ref.rollout.load_format=safetensors \
@@ -54,7 +54,7 @@ python3 -m examples.fireworks_math.train_fireworks_math \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
     actor_rollout_ref.actor.entropy_coeff=0 \
     algorithm.kl_ctrl.kl_coef=0.001 \
-    rllm.compact_filtering.enable=True \
+    rllm.compact_filtering.enable=False \
     rllm.compact_filtering.mask_max_prompt_length_exceeded=True \
     rllm.compact_filtering.mask_max_response_length_exceeded=True \
     rllm.compact_filtering.mask_max_turns_exceeded=False \
@@ -68,14 +68,14 @@ python3 -m examples.fireworks_math.train_fireworks_math \
     trainer.project_name='rllm-fireworks-workflow' \
     trainer.experiment_name='fireworks-hendrycks-math-4b' \
     trainer.max_actor_ckpt_to_keep=2 \
-    trainer.val_before_train=True \
-    trainer.n_gpus_per_node=8 \
-    +trainer.n_training_gpus_per_node=8 \
+    trainer.val_before_train=False \
+    trainer.n_gpus_per_node=2 \
+    +trainer.n_training_gpus_per_node=2 \
     trainer.nnodes=1 \
     trainer.save_freq=1 \
     trainer.test_freq=10 \
     trainer.default_hdfs_dir=null \
     trainer.total_epochs=100 \
     rllm.workflow.use_workflow=True \
-    fireworks.deployment_id=qwen3-4b-3 \
-    fireworks.model_id_prefix=test-math-qwen3-4b-3
+    fireworks.deployment_id=wtk15cs9 \
+    fireworks.model_id_prefix=qwen3-4b-math
```

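Per the `AgentTrainer` docstring in the next file, the same overrides can also be passed programmatically through the `config` argument instead of Hydra flags. A sketch mirroring a few of the values above:

```python
# Dot-notation overrides, as described in AgentTrainer's `config` docstring
config_overrides = {
    "data.train_batch_size": 8,
    "data.max_response_length": 2048,
    "rllm.compact_filtering.enable": False,
    "trainer.n_gpus_per_node": 2,
}

# Or the equivalent "key=value" list form
config_overrides_list = ["data.train_batch_size=8", "data.max_response_length=2048"]
```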
rllm/trainer/agent_trainer.py

Lines changed: 59 additions & 4 deletions

```diff
@@ -1,16 +1,22 @@
-from typing import Any
+from typing import Any, Literal
 
 import ray
 
 from rllm.data import Dataset
 from rllm.trainer.verl.ray_runtime_env import get_ppo_ray_runtime_env
 from rllm.trainer.verl.train_agent_ppo import TaskRunner
+from rllm.trainer.verl.train_workflow_pipeline import PipelineTaskRunner
+from verl.trainer.constants_ppo import get_ppo_ray_runtime_env as get_fireworks_ray_runtime_env
 
 
 class AgentTrainer:
     """
     A wrapper class that allows users to easily train custom agents with custom environments
     without having to directly interact with the underlying training infrastructure.
+
+    Supports two backends:
+    - 'verl' (default): Standard training backend supporting both workflow and agent/env classes
+    - 'fireworks': Pipeline-based training backend optimized for workflow-based training
     """
 
     def __init__(
@@ -24,23 +30,39 @@ def __init__(
         config: dict[str, Any] | list[str] | None = None,
         train_dataset: Dataset | None = None,
         val_dataset: Dataset | None = None,
+        backend: Literal["verl", "fireworks"] = "verl",
     ):
         """
         Initialize the AgentTrainer.
 
         Args:
+            workflow_class: The workflow class to use for training
+            workflow_args: Optional arguments to pass to the workflow class
             agent_class: The custom agent class to use for training
             env_class: The custom environment class to use for training
+            agent_args: Optional arguments to pass to the agent class
+            env_args: Optional arguments to pass to the environment class
             config: Configuration overrides to apply to the default config
                 Can be a dictionary with dot notation keys (e.g., {"data.train_batch_size": 8})
                 or a list of strings in the format "key=value" (e.g., ["data.train_batch_size=8"])
             train_dataset: Optional train dataset to use
            val_dataset: Optional validation dataset to use
-            agent_args: Optional arguments to pass to the agent class
-            env_args: Optional arguments to pass to the environment class
+            backend: Training backend to use ('verl' or 'fireworks'). Default is 'verl'
         """
+        # Validate backend
+        if backend not in ["verl", "fireworks"]:
+            raise ValueError(f"backend must be either 'verl' or 'fireworks', got '{backend}'")
+
+        self.backend = backend
+
+        # Validate backend-specific requirements
+        if backend == "fireworks":
+            if agent_class is not None or env_class is not None:
+                raise ValueError("The 'fireworks' backend only supports workflow_class. agent_class and env_class are not supported. Use workflow_args to configure agent and environment.")
+            if agent_args is not None or env_args is not None:
+                raise ValueError("The 'fireworks' backend does not support agent_args or env_args. Use workflow_args to configure the workflow.")
 
-        if workflow_class is not None and config.rllm.workflow.use_workflow:
+        if workflow_class is not None and config is not None and hasattr(config, "rllm") and hasattr(config.rllm, "workflow") and config.rllm.workflow.use_workflow:
             if agent_class is not None:
                 raise ValueError("agent_class is not supported when using workflow, instead use workflow_args['agent_cls']")
             if agent_args is not None:
@@ -66,6 +88,21 @@ def __init__(
             self.config.data.val_files = val_dataset.get_verl_data_path()
 
     def train(self):
+        """
+        Start the training process using the specified backend.
+        """
+        if self.backend == "verl":
+            self._train_with_verl()
+        elif self.backend == "fireworks":
+            self._train_with_fireworks()
+        else:
+            raise ValueError(f"Unknown backend: {self.backend}")
+
+    def _train_with_verl(self):
+        """
+        Train using the standard verl backend.
+        Supports both workflow-based and agent/env-based training.
+        """
         # Check if Ray is not initialized
         if not ray.is_initialized():
             # read off all the `ray_init` settings from the config
@@ -88,3 +125,21 @@ def train(self):
                 env_args=self.env_args,
             )
         )
+
+    def _train_with_fireworks(self):
+        """
+        Train using the fireworks (pipeline) backend.
+        Optimized for workflow-based training with the Fireworks API.
+        """
+        if not ray.is_initialized():
+            ray.init(runtime_env=get_fireworks_ray_runtime_env(), num_cpus=self.config.ray_init.num_cpus)
+
+        runner = PipelineTaskRunner.remote()
+
+        ray.get(
+            runner.run.remote(
+                config=self.config,
+                workflow_class=self.workflow_class,
+                workflow_args=self.workflow_args,
+            )
+        )
```

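Taken together, the merged class dispatches on `backend` at `train()` time. A usage sketch covering both paths (`MyWorkflow`, `MyAgent`, `MyEnv`, and `my_reward_fn` are hypothetical placeholders):

```python
from rllm.trainer.agent_trainer import AgentTrainer

# Workflow-based training on the Fireworks pipeline backend
trainer = AgentTrainer(
    workflow_class=MyWorkflow,  # hypothetical workflow class
    workflow_args={"reward_function": my_reward_fn},
    config=config,
    backend="fireworks",
)
trainer.train()  # dispatches to _train_with_fireworks()

# Agent/env-based training stays on the default verl backend
trainer = AgentTrainer(
    agent_class=MyAgent,  # hypothetical; rejected if backend="fireworks"
    env_class=MyEnv,      # hypothetical
    config=config,
)
trainer.train()  # dispatches to _train_with_verl()
```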
rllm/trainer/pipeline_agent_trainer.py

Lines changed: 0 additions & 60 deletions
This file was deleted.
