
Commit f4b5d77

Merge pull request #288 from thwu1/nightly
[feat] Tinker Workflow Trainer
2 parents 8e51f12 + 18a2c72 commit f4b5d77

File tree: 9 files changed (+557, −33 lines)

examples/math_tinker/train_math_tinker.py
Lines changed: 2 additions & 23 deletions

@@ -16,23 +16,6 @@
 from rllm.trainer.tinker.tinker_agent_trainer import TinkerAgentTrainer
 
 
-class SimpleDataLoader:
-    """Simple reusable dataloader."""
-
-    def __init__(self, dataset, batch_size):
-        self.dataset = dataset
-        self.batch_size = batch_size
-
-    def __iter__(self):
-        for i in range(0, len(self.dataset), self.batch_size):
-            yield self.dataset[i : i + self.batch_size]
-
-
-def create_dataloader(dataset, batch_size):
-    """Create a simple reusable dataloader from dataset."""
-    return SimpleDataLoader(dataset, batch_size)
-
-
 @hydra.main(version_base=None, config_path="../../rllm/trainer/config", config_name="tinker_agent_trainer")
 def main(config: DictConfig):
     """
@@ -48,19 +31,15 @@ def main(config: DictConfig):
     if train_dataset is None or test_dataset is None:
         raise ValueError("Datasets not found! Please run prepare_tinker_math_dataset.py first:\n python -m examples.math_tinker.prepare_tinker_math_dataset")
 
-    # Create dataloaders
-    train_dataloader = create_dataloader(train_dataset, config.data.train_batch_size)
-    test_dataloader = create_dataloader(test_dataset, config.data.val_batch_size)
-
     # Create trainer (uses separated components internally)
     trainer = TinkerAgentTrainer(
         config=config,
         agent_class=MathAgentWithFewshot,
         env_class=SingleTurnEnvironment,
         agent_args={"use_fewshot": True},
         env_args={"reward_fn": math_reward_fn},
-        train_dataloader=train_dataloader,
-        val_dataloader=test_dataloader,
+        train_dataset=train_dataset,
+        val_dataset=test_dataset,
     )
 
     # Train (all orchestration handled internally by TinkerAgentTrainer)
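
With this change, batching moves out of the example script: TinkerAgentTrainer now receives train_dataset and val_dataset directly, so the removed SimpleDataLoader logic presumably lives inside the trainer. A minimal sketch of equivalent internal batching, assuming the trainer simply slices the dataset by config.data.train_batch_size (iter_batches is a hypothetical helper, not an rLLM API):

# Hypothetical helper (not an rLLM API): batching equivalent to the removed SimpleDataLoader.
def iter_batches(dataset, batch_size):
    # Yield successive slices of batch_size examples until the dataset is exhausted.
    for i in range(0, len(dataset), batch_size):
        yield dataset[i : i + batch_size]

# e.g. inside a trainer loop: for batch in iter_batches(train_dataset, config.data.train_batch_size): ...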

examples/solver_judge/train_solver_judge_flow.sh
mode changed: 100644 → 100755
Lines changed: 7 additions & 5 deletions

@@ -10,7 +10,7 @@ python3 -m examples.solver_judge.train_solver_judge_flow \
     data.train_batch_size=64 \
     data.max_prompt_length=2048 \
     data.max_response_length=1024 \
-    actor_rollout_ref.model.path=Qwen/Qwen3-0.6B \
+    actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507 \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.model.use_remove_padding=True \
     actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
@@ -31,12 +31,13 @@ python3 -m examples.solver_judge.train_solver_judge_flow \
     actor_rollout_ref.rollout.name=vllm \
     actor_rollout_ref.rollout.mode="async" \
     actor_rollout_ref.rollout.enforce_eager=False \
-    actor_rollout_ref.rollout.temperature=0.6 \
+    actor_rollout_ref.rollout.temperature=1.0 \
+    actor_rollout_ref.rollout.top_p=1.0 \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
     actor_rollout_ref.rollout.n=4 \
     actor_rollout_ref.rollout.val_kwargs.n=1 \
-    actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
-    actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
+    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
+    actor_rollout_ref.rollout.val_kwargs.top_p=1.0 \
     actor_rollout_ref.ref.fsdp_config.param_offload=True \
     algorithm.adv_estimator=grpo \
     rllm.compact_filtering.enable=False \
@@ -59,6 +60,7 @@ python3 -m examples.solver_judge.train_solver_judge_flow \
     trainer.test_freq=10 \
     trainer.default_hdfs_dir=null \
     trainer.total_epochs=100 \
-    rllm.workflow.use_workflow=True
+    rllm.workflow.use_workflow=True \
+    +ray_init._temp_dir=/home/tianhao/tmp
 
 pkill -9 -f 'ray::WorkerDict'
Lines changed: 149 additions & 0 deletions

@@ -0,0 +1,149 @@
import asyncio
import json
import os

# Import countdown-specific modules
import sys
from copy import deepcopy

import tinker
from solver_judge_flow import SolverJudgeWorkflow
from transformers import AutoTokenizer

from rllm.data.dataset import DatasetRegistry
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.engine.rollout.tinker_engine import TinkerEngine
from rllm.rewards.countdown_reward import countdown_reward_fn

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "countdown"))


def load_data(n=1):
    """Load countdown data using the Dataset interface."""
    dataset = DatasetRegistry.load_dataset("countdown", "test")
    if dataset is None:
        print("Dataset not found, preparing dataset...")
        from prepare_countdown_data import prepare_countdown_data

        _, dataset, _, _ = prepare_countdown_data()

    data = []
    for idx, example in enumerate(dataset):
        processed = process_countdown_fn(example, idx)
        for i in range(n):
            data.append(deepcopy(processed))
    return data


def process_countdown_fn(example, idx):
    """Process countdown example into the expected format."""
    question = example["question"]
    target = example["target"]
    nums = example["nums"]

    # Create ground truth in the format expected by countdown_reward_fn
    ground_truth = {"target": target, "numbers": nums}

    task = {"question": question, "ground_truth": ground_truth, "idx": idx, "data_source": "countdown", "target": target, "nums": nums}
    return task


def evaluate_results(results):
    """Evaluate the results and compute pass@k metrics."""
    from collections import defaultdict

    # Create a map to store correct answers per problem
    problem_correct_map = defaultdict(int)
    problem_total_map = defaultdict(int)

    # Count correct answers for each problem
    for episode in results:
        problem = episode.task["question"]

        # Use the episode-level is_correct flag set by the workflow
        is_correct = episode.is_correct

        problem_correct_map[problem] += int(is_correct)
        problem_total_map[problem] += 1

    # Calculate pass@1 and pass@k
    k = max(problem_total_map.values()) if problem_total_map else 1
    total_problems = len(problem_correct_map)

    if total_problems > 0:
        pass_at_1 = sum(problem_correct_map.values()) / sum(problem_total_map.values())
        pass_at_k = sum(1 for problem, correct in problem_correct_map.items() if correct > 0) / total_problems
    else:
        pass_at_1 = 0.0
        pass_at_k = 0.0

    print("Total unique problems:", total_problems)
    print("Average Pass@1 Accuracy:", pass_at_1)
    print(f"Average Pass@{k} Accuracy:", pass_at_k)


if __name__ == "__main__":
    import os

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    # Configuration
    n_parallel_tasks = 4
    n_solutions = 2  # Number of solutions to generate per problem

    model_name = "Qwen/Qwen3-8B"
    service_client = tinker.ServiceClient(base_url=None)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    rollout_engine = TinkerEngine(
        base_url=None,
        model_name=model_name,
        tokenizer=tokenizer,
        service_client=service_client,
        max_prompt_length=2048,
        max_response_length=1024,
        sampling_params={"temperature": 0.6, "top_p": 0.95},
    )
    training_client = service_client.create_lora_training_client(
        base_model=model_name,
        rank=4,
    )
    sampler_future = training_client.save_weights_for_sampler(name="000000")
    sampler_result = sampler_future.result()
    sampling_client = training_client.create_sampling_client(sampler_result.path)

    rollout_engine.set_sampling_client(sampling_client)

    engine = AgentWorkflowEngine(
        workflow_cls=SolverJudgeWorkflow,
        workflow_args={
            "n_solutions": n_solutions,
            "reward_function": countdown_reward_fn,
        },
        rollout_engine=rollout_engine,
        config=None,
        n_parallel_tasks=n_parallel_tasks,
        retry_limit=1,
    )

    # Load countdown tasks
    tasks = load_data(n=1)
    print(f"Loaded {len(tasks)} countdown tasks")
    tasks = tasks[:4]

    results = asyncio.run(engine.execute_tasks(tasks))
    import pdb

    pdb.set_trace()

    print(results[1])

    # Evaluate results (rewards are already assigned in the workflow)
    print("Evaluating results...")
    evaluate_results(results)

    # Save results
    os.makedirs("logs", exist_ok=True)
    with open("logs/solver_judge_countdown.json", "w") as f:
        json.dump([episode.to_dict() for episode in results], f, indent=4)

    print("\nResults saved to logs/solver_judge_countdown.json")
Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
import hydra

from examples.solver_judge.solver_judge_flow import SolverJudgeWorkflow
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.countdown_reward import countdown_reward_fn
from rllm.trainer.tinker.tinker_workflow_trainer import TinkerWorkflowTrainer


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="tinker_workflow_trainer", version_base=None)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("countdown", "train")
    test_dataset = DatasetRegistry.load_dataset("countdown", "test")

    trainer = TinkerWorkflowTrainer(
        workflow_class=SolverJudgeWorkflow,
        workflow_args={
            "n_solutions": 2,
            "reward_function": countdown_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=test_dataset,
    )
    trainer.fit_agent()


if __name__ == "__main__":
    main()
Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
set -x

MODEL_PATH=Qwen/Qwen3-4B-Instruct-2507

python3 -m examples.solver_judge_tinker.train_solver_judge_flow_tinker \
    model.name=$MODEL_PATH \
    model.lora_rank=32 \
    training.group_size=4 \
    training.learning_rate=4e-5 \
    sampling.temperature=1.0 \
    sampling.top_p=1.0 \
    algorithm.adv_estimator=grpo \
    algorithm.norm_adv_by_std_in_grpo=true \
    data.max_prompt_length=2048 \
    data.max_response_length=1024 \
    data.train_batch_size=64 \
    data.val_batch_size=512 \
    trainer.total_epochs=100 \
    trainer.logger=['wandb'] \
    trainer.project_name='solver-judge-workflow' \
    trainer.experiment_name='countdown-solver-judge-tinker-norm-by-std' \
    trainer.val_before_train=False \
    trainer.test_freq=10 \
    trainer.save_freq=20 \
    trainer.default_local_dir='/tmp/countdown-solver-judge-tinker-norm-by-std'
Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
# Tinker Backend Configuration for rLLM
# This config is used when training agents with Tinker backend
# Default settings match tinker_cookbook.recipes.math_rl for MATH dataset

# Tinker-specific settings
tinker_base_url: null  # Tinker service URL (null for local)

# Model Configuration
model:
  name: "Qwen/Qwen3-8B"  # Default model for MATH dataset
  lora_rank: 32
  train_unembed: true  # Train LoRA on output embedding layer (set to false for Fireworks compatibility)
  train_attn: true  # Train LoRA on attention layers
  train_mlp: true  # Train LoRA on MLP layers

# Training Configuration
training:
  group_size: 16  # Number of rollouts per prompt (for GRPO)
  learning_rate: 2e-5  # 2e-5 for MATH dataset
  beta1: 0.9
  beta2: 0.95
  eps: 1e-8
  max_length: 32768
  num_minibatches: 1

# Sampling Configuration
sampling:
  temperature: 0.6
  top_p: 0.95

# Algorithm Configuration (compatible with verl)
algorithm:
  adv_estimator: grpo  # REINFORCE, GRPO
  gamma: 1.0
  lam: 0.95
  norm_adv_by_std_in_grpo: false  # math_rl doesn't normalize by std

workflow:
  n_parallel_tasks: 256
  retry_limit: 3

# Data Configuration
data:
  train_files: null
  val_files: null
  max_prompt_length: 2048
  max_response_length: 2048
  train_batch_size: 64
  val_batch_size: 32

# Trainer Configuration
trainer:
  total_epochs: 10
  logger: ['console']  # Options: 'console', 'wandb', 'tensorboard'
  project_name: 'rllm-tinker'
  experiment_name: 'default'
  test_freq: 5
  save_freq:
  reward_broadcast: 'step'
  val_before_train: true
  default_local_dir: '/tmp/rllm-tinker-checkpoints'

# Hydra configuration
hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
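
The group_size, adv_estimator, and norm_adv_by_std_in_grpo settings above control GRPO-style group-normalized advantages. A minimal sketch of that behavior, assuming the standard GRPO formulation (this is illustrative, not the rLLM/Tinker implementation; the function name is hypothetical):

import numpy as np

# Sketch: GRPO advantages for one prompt's group of rollouts.
# `rewards` holds one scalar reward per rollout (group_size rollouts per prompt).
def grpo_advantages(rewards, norm_by_std=False, eps=1e-6):
    rewards = np.asarray(rewards, dtype=np.float32)
    adv = rewards - rewards.mean()      # subtract the group-mean baseline
    if norm_by_std:                     # norm_adv_by_std_in_grpo: also divide by the group std
        adv = adv / (rewards.std() + eps)
    return adv

# e.g. grpo_advantages([1.0, 0.0, 0.0, 1.0]) -> [0.5, -0.5, -0.5, 0.5]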
