
Commit 963b417

Merge branch 'proxy' into nightly
2 parents 976d74c + 639247b


50 files changed: +13577 −4 lines

examples/omni_trainer/README.md

Lines changed: 100 additions & 0 deletions
# Omni Trainer

This example demonstrates how to use the Omni Trainer for reinforcement learning with language models.

## Prerequisites

### 1. Install Verl

Run the installation script:
```bash
bash scripts/install_verl.sh
```

**Important:** Install `torch==2.6.0` when installing Verl. After `install_verl.sh` finishes, install `vllm==0.10.0`; this bumps torch to 2.7.1, which is expected.

**Troubleshooting:**
- If you encounter issues with `flash_attn`, reinstall it with:
```bash
pip install flash-attn --no-build-isolation
```

- If you encounter errors with Ray, try:
```bash
pip install ray==2.48.0
```

### 2. Install Episodic

Download and install [episodic](https://github.com/agentica-org/episodic):
```bash
cd rllm/sdk/episodic-sdk
pip install -e .
```
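To confirm the install, check that the CLI entry point is on your `PATH` (this assumes the package exposes the `episodic` command used in the Setup section below, with the usual `--help` flag):
```bash
episodic --help
```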
### 3. Verify Dependencies

Check that your websocket version is >= 15.0 (version 13.x will not work).
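One way to check (assuming the dependency in question is the `websockets` Python package):
```bash
pip show websockets                  # look for Version: >= 15.0
pip install -U "websockets>=15.0"    # upgrade if needed
```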
## Setup

### 1. Launch the Context Store

Start the episodic context store server:

```bash
episodic serve --db-path /tmp/episodic.db  # choose a local path for better performance
```

### 2. Deploy the LiteLLM Proxy

In a separate terminal, start the LiteLLM proxy server:

```bash
#!/bin/bash

# Set ulimit first
ulimit -n 65536

# Set aiohttp connection limits
export AIOHTTP_CONNECTOR_LIMIT=4096
export AIOHTTP_KEEPALIVE_TIMEOUT=60

# Verify the limits are set
echo "Current ulimit -n: $(ulimit -n)"
echo "AIOHTTP_CONNECTOR_LIMIT: $AIOHTTP_CONNECTOR_LIMIT"
echo "AIOHTTP_KEEPALIVE_TIMEOUT: $AIOHTTP_KEEPALIVE_TIMEOUT"
echo "Starting LiteLLM proxy..."

# Start the proxy
python scripts/litellm_proxy_server.py \
    --config litellm_proxy_config_autogen.yaml \
    --host 127.0.0.1 \
    --port 4000 \
    --state-dir /tmp/litellm_proxy \
    --cs-endpoint http://localhost:8000 \
    --cs-api-key "your-api-key-here" \
    --project rllm-agent-omni-engine \
    --admin-token my-shared-secret
```
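Once the proxy is up, a quick sanity check that it is answering (this assumes the wrapper exposes the standard OpenAI-compatible routes that LiteLLM serves):
```bash
# List the models the proxy is serving via the OpenAI-compatible API.
curl -s http://127.0.0.1:4000/v1/models
```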
## Running the Examples

Once both the context store and LiteLLM proxy are running, you can execute one of the training examples.

### Hendrycks Math Training

This is the simplest example: a single agent and a single turn.

```bash
./train_hendrycks_math.sh
```

### Solver-Judge Flow Training

This is a more involved example, with two agents and more complex grouping logic.

```bash
./train_solver_judge_flow.sh
```
Lines changed: 47 additions & 0 deletions
import hydra

from rllm.data.dataset import DatasetRegistry
from rllm.rewards.reward_fn import math_reward_fn
from rllm.sdk.shortcuts import get_chat_client
from rllm.trainer.agent_trainer import AgentTrainer
from rllm.workflows.simple_workflow import SimpleWorkflow


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
    test_dataset = DatasetRegistry.load_dataset("math500", "test")

    # Define run function that recreates the client inside to avoid closure capture
    # This ensures the function is fully serializable for Ray
    def run(
        question: str,
        ground_truth: str,
        base_url: str = "http://localhost:4000/v1",
        api_key: str = "EMPTY",
        **kwargs,
    ):
        # Recreate the client inside the function to avoid serialization issues
        # This ensures the function doesn't capture non-serializable objects
        client = get_chat_client(base_url=base_url, api_key=api_key)
        response = client.chat.completions.create(
            model="vllm/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
            messages=[
                {"role": "user", "content": question},
            ],
        )
        response_text = response.choices[0].message.content
        reward = math_reward_fn({"response": response_text, "ground_truth": ground_truth}, response_text).reward
        return reward

    trainer = AgentTrainer(
        config=config,
        train_dataset=train_dataset,
        val_dataset=test_dataset,
        agent_run_func=run,
    )
    trainer.train()


if __name__ == "__main__":
    main()
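As a hedged aside, the `run` function above can be exercised on its own to confirm that the proxy and reward path work end to end. It is defined inside `main()`, so for a quick check you would paste it (with the same imports) into a scratch script; the values below are illustrative only:

```python
# Hypothetical smoke test, assuming the LiteLLM proxy from the README is
# serving on http://localhost:4000/v1 with the model name used above.
reward = run(
    question="What is 2 + 2?",
    ground_truth="4",
    base_url="http://localhost:4000/v1",
    api_key="EMPTY",
)
print("reward:", reward)  # math_reward_fn scores the model's answer against ground_truth
```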
Lines changed: 73 additions & 0 deletions
set -x

export VLLM_ATTENTION_BACKEND=FLASH_ATTN
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
export VLLM_USE_V1=1
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
export VLLM_ENGINE_ITERATION_TIMEOUT_S=100000000000

# Find the directory where rllm package is located
RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")

MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

python3 -m examples.omni_trainer.simple_math.train_hendrycks_math \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=32 \
    data.val_batch_size=512 \
    data.max_prompt_length=2048 \
    data.max_response_length=2048 \
    actor_rollout_ref.model.path=$MODEL_PATH \
    actor_rollout_ref.hybrid_engine=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.strategy=fsdp2 \
    actor_rollout_ref.actor.loss_agg_mode=token-mean \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.clip_ratio_high=0.28 \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode="async" \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.n=16 \
    actor_rollout_ref.rollout.temperature=0.6 \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.9 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.entropy_coeff=0 \
    algorithm.kl_ctrl.kl_coef=0.001 \
    rllm.mask_truncated_samples=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='rllm-agent' \
    trainer.experiment_name='simple-math-simple-workflow' \
    trainer.val_before_train=True \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=200 \
    trainer.test_freq=10 \
    trainer.default_hdfs_dir=null \
    rllm.agent.max_steps=1 \
    rllm.stepwise_advantage.enable=False \
    rllm.workflow.use_workflow=True \
    trainer.total_epochs=100 \
    +rllm.proxy.host=127.0.0.1 \
    +rllm.proxy.port=4000 \
    +rllm.proxy.auto_start=False \
    +rllm.proxy.admin_token=my-shared-secret \
    +rllm.run_name=rllm-agent-omni-engine \
    +context_store.endpoint=http://localhost:8000 \
    +context_store.api_key=your-api-key-here
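A note on the override syntax: the keys prefixed with `+` are appended to the base `agent_ppo_trainer` config because it does not define them (standard Hydra behaviour), while unprefixed keys override values the config already has. The `+rllm.proxy.*` and `+context_store.*` values must match the proxy and context store launched in the README (host 127.0.0.1, port 4000, the shared admin token, and the context store endpoint and API key).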
Lines changed: 116 additions & 0 deletions
import asyncio
import re
import uuid

from rllm.agents.agent import Episode, Trajectory
from rllm.engine import RolloutEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.sdk import get_chat_client_async, session, set_reward_async
from rllm.workflows.workflow import Workflow


class Solver:
    def __init__(self, **kwargs):
        self.client = get_chat_client_async(base_url="http://localhost:4000/v1", api_key="EMPTY", model="vllm/Qwen/Qwen3-4B-Instruct-2507")

    async def generate_solution(self, problem: str) -> Trajectory:
        with session(agent="solver", groupby_key=str(uuid.uuid4())):
            messages = [{"role": "user", "content": f"{problem}. Output the final answer within <answer>...</answer>"}]
            response = await self.client.chat.completions.create(
                messages=messages,
                temperature=1,
                max_tokens=1000,
            )

            content = response.choices[0].message.content
            return response.id, self._parse_solver_response(content)

    async def generate_solutions(self, problem: str, n_solutions: int = 2) -> list[Trajectory]:
        tasks = [asyncio.create_task(self.generate_solution(problem)) for _ in range(n_solutions)]
        return await asyncio.gather(*tasks)

    def _parse_solver_response(self, response: str) -> str:
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.IGNORECASE | re.DOTALL)
        if answer_match:
            return f"<answer>{answer_match.group(1).strip()}</answer>"
        else:
            return "No solution found"


class Judge:
    def __init__(self, **kwargs):
        self.client = get_chat_client_async(base_url="http://localhost:4000/v1", api_key="EMPTY", model="vllm/Qwen/Qwen3-4B-Instruct-2507")

    async def judge_solutions(self, problem: str, solutions: list[str]) -> Trajectory:
        with session(agent="judge"):
            messages = [{"role": "user", "content": self._create_judge_prompt(problem, solutions)}]
            response = await self.client.chat.completions.create(
                messages=messages,
                temperature=1,
                max_tokens=1000,
            )
            content = response.choices[0].message.content
            return response.id, self._parse_judge_response(content, solutions)

    def _parse_judge_response(self, response: str, solutions: list[str]) -> str:
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.IGNORECASE | re.DOTALL)
        if answer_match:
            answer_text = answer_match.group(1).strip()
            try:
                solution_index = int(answer_text)
                return solutions[solution_index - 1]
            except (ValueError, IndexError):
                return ""
        return ""

    def _create_judge_prompt(self, problem: str, solutions: list[str]) -> str:
        """Create a prompt for the judge to evaluate solutions."""
        prompt = f"""You are an expert verifier. Given a countdown problem and multiple solution attempts, select a correct solution.
Problem:
{problem}
Solutions to evaluate:
"""
        for i, solution in enumerate(solutions, 1):
            prompt += f"\nSolution {i}:\n{solution}\n"

        prompt += """
A correct solution must satisfy the following criteria:
1. The solution uses only the given numbers.
2. Each number is used exactly once.
3. Only basic arithmetic operations (+, -, *, /) are used.
4. The calculation results in the target number.
5. The final answer is clearly marked within <answer>...</answer> tags.
Output the index of your selected solution within <answer>...</answer> tags, e.g., <answer>1</answer> for the first solution, <answer>2</answer> for the second solution, etc. If multiple solutions are correct, output the index of the first correct solution."""
        return prompt


class SolverJudgeWorkflow(Workflow):
    def __init__(self, rollout_engine: RolloutEngine, n_solutions: int = 2, reward_function: RewardFunction = None, **kwargs):
        super().__init__(rollout_engine, **kwargs)
        self.n_solutions = n_solutions
        self.reward_function = reward_function
        self.solver = Solver()
        self.judge = Judge()

    async def run(self, task: dict, uid: str, **kwargs) -> Episode:
        self.reset(task, uid)
        problem = task["question"]

        # Step 1: Solver generates multiple solutions in parallel
        solver_trajectories = await self.solver.generate_solutions(problem, self.n_solutions)

        # Assign rewards to solver trajectories
        solutions = []
        for response_id, solution in solver_trajectories:
            solutions.append(solution)
            reward = self.reward_function(task, solution).reward
            await set_reward_async(response_id, reward=reward)

        # Step 2: Judge selects the best solution
        response_id, selected_solution = await self.judge.judge_solutions(problem, solutions)

        # Evaluate the selected solution
        reward_result = self.reward_function(task, selected_solution)

        await set_reward_async(response_id, reward=reward_result.reward)
        return reward_result.reward
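For orientation, a hedged illustration of the kind of task this workflow handles (the exact dataset fields beyond `task["question"]` are an assumption): a countdown problem asks for an arithmetic expression over the given numbers that reaches a target, and the solver is expected to wrap its expression in `<answer>...</answer>`:

```python
# Illustrative (hypothetical) countdown task and solver output.
task = {"question": "Using the numbers 3, 7, and 25 exactly once, reach the target 46."}
solver_output = "3 * 7 = 21 and 21 + 25 = 46, so <answer>(3 * 7) + 25</answer>"

# _parse_solver_response keeps only "<answer>(3 * 7) + 25</answer>", and the
# reward function (countdown_reward_fn in the trainer below) scores that answer.
```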
Lines changed: 35 additions & 0 deletions
import hydra

from examples.omni_trainer.solver_judge_workflow.simple_solver_judge_flow import SolverJudgeWorkflow
from rllm.data.dataset import DatasetRegistry
from rllm.rewards.countdown_reward import countdown_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


async def run_workflow(**kwargs) -> float:
    task = kwargs
    workflow = SolverJudgeWorkflow(rollout_engine=None, executor=None, n_solutions=2, reward_function=countdown_reward_fn)
    return await workflow.run(task, "")


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="agent_ppo_trainer", version_base=None)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("countdown", "train")
    test_dataset = DatasetRegistry.load_dataset("countdown", "test")

    trainer = AgentTrainer(
        agent_run_func=run_workflow,
        workflow_class=SolverJudgeWorkflow,
        workflow_args={
            "n_solutions": 2,
            "reward_function": countdown_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=test_dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()
