
Commit c82fe60

Merge pull request #84 from meta-pytorch/OpenEnv-Forge
[Torchforge] Open env forge example
2 parents 53fa6a9 + dbac0f0 commit c82fe60

File tree: 5 files changed, +1701 -2 lines

README.md

Lines changed: 6 additions & 2 deletions

@@ -6,6 +6,10 @@ An e2e framework for creating, deploying and using isolated execution environments
[![Discord](https://img.shields.io/badge/Discord-OpenEnv-7289da?style=flat&logo=discord&logoColor=white)](https://discord.gg/YsTYBh6PD9)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-pytorch/OpenEnv/blob/main/examples/OpenEnv_Tutorial.ipynb) **← Try the Interactive Tutorial!**

+ ---
+
+ **🚀 Featured Example:** Train LLMs to play BlackJack using [torchforge](https://github.com/meta-pytorch/torchforge) (PyTorch's agentic RL framework): [`examples/grpo_blackjack/`](examples/grpo_blackjack/)
+
## OpenEnv on partner platforms:

- [Lightning AI Studio](https://lightning.ai/environments?section=featured)

@@ -178,10 +182,10 @@ client.close() # Stops and removes container
- smolagents (for coding environment)

## Supported RL Tools
The goal of this project is to support a broad set of open and closed tools to help standardize the agentic RL community. If you have a project that supports OpenEnv environments, please put up a PR to add your tool name along with a link to your documentation.

### torchforge
- (coming soon)
+ See GRPO BlackJack training example: [`examples/grpo_blackjack/`](examples/grpo_blackjack/)

### TRL
(coming soon)

examples/grpo_blackjack/README.md

Lines changed: 191 additions & 0 deletions
# Training LLMs to Play BlackJack with GRPO + OpenEnv

This example demonstrates how to train language models to play BlackJack using **GRPO (Group Relative Policy Optimization)** and **OpenEnv**.

## 🎯 What This Example Shows

- **OpenEnv**: Universal RL environment interface for 70+ environments
- **GRPO**: Efficient RL algorithm (used by DeepSeek R1) that only needs 2 models instead of 3
- **Forge**: PyTorch-native agentic RL library for production training
- **End-to-End Training**: From random policy (~35% win rate) to trained agent

## 📁 Files

- `grpo_blackjack_tutorial.ipynb` - Interactive tutorial notebook (recommended starting point)
- `grpo_utils.py` - Production GRPO utilities and helper functions
- `blackjack.yaml` - Training configuration file
- `README.md` - This file

## 🚀 Quick Start

### Prerequisites

1. **Install OpenEnv**:
```bash
# Clone OpenEnv repo
git clone https://github.com/meta-pytorch/OpenEnv.git
cd OpenEnv
pip install -e .
```

2. **Install Forge** (PyTorch's agentic RL library):
```bash
git clone https://github.com/meta-pytorch/torchforge.git
cd torchforge
pip install -e .
```

3. **Start OpenEnv BlackJack Server**:
```bash
# In a separate terminal
export OPENENV_PATH="/path/to/OpenEnv/src"
export PYTHONPATH="${OPENENV_PATH}:${PYTHONPATH}"

OPENSPIEL_GAME=blackjack python -m envs.openspiel_env.server.app --port 8004
```

### Run the Tutorial

Open the Jupyter notebook:
```bash
jupyter notebook grpo_blackjack_tutorial.ipynb
```

Follow the cells to:
1. **Explore OpenEnv** - Connect to BlackJack environment
2. **Benchmark baseline** - Test random policy performance
3. **Learn about GRPO** - Understand the training algorithm
4. **Train with Forge** - Run production GRPO training
5. **Switch environments** - See how to train on other games

## 📚 What You'll Learn

### OpenEnv: Universal RL Environment Spec

OpenEnv is **not a game engine** - it's a **specification** that wraps ANY RL environment:

```python
# Same interface works for 70+ environments
result = env.reset()       # Start episode
result = env.step(action)  # Take action
state = env.state()        # Get state
env.close()                # Cleanup
```

Change one environment variable → train on different games!

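For orientation, one random-policy episode against the BlackJack server started above looks roughly like the sketch below. The class and attribute names (`OpenSpielEnv`, `OpenSpielAction`, `result.done`, `result.reward`, `result.observation.legal_actions`) and the import path are illustrative assumptions, not the documented client API; the tutorial notebook shows the real imports.

```python
import random

# Hypothetical client import -- see the notebook for the actual OpenEnv client.
from envs.openspiel_env import OpenSpielEnv, OpenSpielAction

env = OpenSpielEnv(base_url="http://localhost:8004")  # server started in the Prerequisites

result = env.reset()                                  # start a new BlackJack hand
while not result.done:
    legal = result.observation.legal_actions          # e.g. the ids for HIT / STAND
    action = OpenSpielAction(action_id=random.choice(legal))
    result = env.step(action)                         # play the chosen move
print("episode reward:", result.reward)               # +1 win, -1 loss, 0 push

env.close()                                           # release the server-side session
```
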
### Forge: PyTorch-Native Agentic RL

Forge handles all distributed systems complexity:
- **Generator (vLLM)**: Fast LLM inference
- **RLTrainer**: Distributed training with FSDP
- **ReplayBuffer**: Off-policy learning
- **ReferenceModel**: KL penalty computation
- **Torchstore**: Distributed weight management

You just write:
```python
trainer = await setup_forge_training("blackjack.yaml")
await trainer.run(steps=100)
```

Everything else is automated!

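One practical note: `setup_forge_training` and `trainer.run` are coroutines, so outside a notebook (where top-level `await` works) you need an event loop. A minimal sketch, assuming the helper is exposed by `grpo_utils.py` in this folder:

```python
import asyncio

from grpo_utils import setup_forge_training  # assumed location of the helper


async def main() -> None:
    trainer = await setup_forge_training("blackjack.yaml")
    await trainer.run(steps=100)


if __name__ == "__main__":
    asyncio.run(main())
```
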
## 🎓 Educational Resources

This tutorial is inspired by the excellent [Unsloth RL Guide](https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide). We highly recommend reading it for deeper insights!

### Further Reading

- **OpenEnv**: [GitHub](https://github.com/meta-pytorch/OpenEnv)
- **GRPO Paper**: [arXiv:2402.03300](https://arxiv.org/abs/2402.03300)
- **Forge**: [GitHub](https://github.com/meta-pytorch/torchforge) | [Docs](https://meta-pytorch.org/torchforge/)
- **Unsloth RL Guide**: [docs.unsloth.ai](https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide)

## 💡 Key Concepts

### "Patience Is All You Need" for RL

RL works through patience: as long as the correct answer has *any* non-zero probability, repeated sampling will eventually find it. In the meantime:
1. Learn from **bad answers** → decrease their probability
2. Learn from **good answers** → increase their probability

Over time, the model learns not just *what* to do, but *why* (the reasoning process).

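GRPO turns this into a concrete update rule with *group-relative advantages*: sample a group of episodes (here, `group_size` games), score each one, and normalize every reward against the group's mean and standard deviation. Above-average episodes get a positive advantage and their tokens become more likely; below-average ones get a negative advantage; no separate critic/value model is needed, which is why two models suffice instead of three. A minimal sketch of that normalization (the example's version lives in `ComputeAdvantages.compute()` in `grpo_utils.py`):

```python
import torch


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> torch.Tensor:
    """Normalize each episode's reward against its own group (GRPO-style)."""
    r = torch.tensor(rewards, dtype=torch.float32)
    return (r - r.mean()) / (r.std() + eps)


# One group of 4 BlackJack games with shaped rewards: win, loss, push, loss
print(group_relative_advantages([2.0, -1.0, 0.5, -1.0]))
# Positive entries push those responses' probabilities up in the policy update;
# negative entries push them down.
```
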
### Reward Functions

Reward functions tell the model what's good/bad. For BlackJack:

```python
def evaluate_response(prompt, response, game_reward):
    reward = float(game_reward)  # +1 (win), -1 (loss), 0 (push)

    # Reward shaping
    if game_reward > 0:
        reward = 2.0   # Wins more valuable
    elif game_reward == 0:
        reward = 0.5   # Pushes better than losses

    return reward
```

The key: **Reward functions must be verifiable**. You can verify "is the answer correct?" but not "is this creative?"

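As a quick sanity check, plugging the three possible game outcomes into the function above gives the shaped values the trainer actually optimizes for:

```python
# Environment outcome -> shaped training reward (using evaluate_response from above)
assert evaluate_response("prompt", "HIT", game_reward=1.0) == 2.0    # win
assert evaluate_response("prompt", "STAND", game_reward=0.0) == 0.5  # push
assert evaluate_response("prompt", "HIT", game_reward=-1.0) == -1.0  # loss
```
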
## 🔄 Switching to Other Games

The beauty of OpenEnv: **same code works for any environment!**

### Try Tic-Tac-Toe
```bash
OPENSPIEL_GAME=tic_tac_toe python -m envs.openspiel_env.server.app --port 8005
```
Update config: `server_url = "http://localhost:8005"`

### Try Chess
```bash
OPENSPIEL_GAME=chess python -m envs.openspiel_env.server.app --port 8006
```

### Try Atari
```bash
python -m envs.atari_env.server.app --game pong --port 8007
```

Everything else stays the same! Same GRPO code, same Forge infrastructure.

## 🛠️ Customization

All code is in `grpo_utils.py`:
- Modify `BlackJackReward.evaluate_response()` for reward shaping
- Adjust `ComputeAdvantages.compute()` for advantage computation
- Tweak `simple_grpo_loss()` for the KL penalty (beta parameter)
- Change `format_prompt()` for different prompt templates

Edit `blackjack.yaml` for (see the sketch after this list):
- Different model sizes (1B to 70B+)
- More training steps
- Larger group sizes
- Parallel rollout collection

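A sketch of the most commonly tuned keys, with the tutorial defaults taken from the config shown later on this page:

```yaml
# Frequently edited settings in blackjack.yaml (tutorial defaults shown)
model: "Qwen/Qwen3-1.7B"   # swap in a larger or smaller checkpoint
group_size: 4              # more parallel games per rollout
rollout_threads: 1         # parallel rollout collection
trainer:
  training:
    steps: 1000            # raise for longer production runs
```
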
## 📊 Expected Results

- **Random policy**: ~35% win rate
- **After GRPO training**: Improves toward optimal BlackJack strategy (~43% win rate)
- **Training time**: Varies based on model size and training steps

The model learns both strategy AND reasoning process (similar to DeepSeek R1's `<think>` tokens).

## 🤝 Credits

- **OpenEnv**: Meta PyTorch team
- **Forge**: Meta PyTorch team
- **GRPO**: DeepSeek research team
- **Tutorial inspiration**: Unsloth team

## 📝 License

This example follows the same license as the parent OpenEnv repository.

## 🙏 Acknowledgments

Big thanks to the **Unsloth team** for their educational approach to RL! This tutorial's GRPO section is heavily inspired by their excellent guide.

examples/grpo_blackjack/blackjack.yaml

Lines changed: 155 additions & 0 deletions

# BlackJack GRPO Training Configuration
# >>> python -m apps.grpo.blackjack_main --config apps/grpo/blackjack.yaml
#
# Prerequisites:
#   1. Start BlackJack server:
#      cd /Users/sanyambhutani/OpenEnv/OpenEnv
#      export PYTHONPATH="/Users/sanyambhutani/OpenEnv/OpenEnv/src:${PYTHONPATH}"
#      OPENSPIEL_GAME=blackjack python -m envs.openspiel_env.server.app
#
#   2. Run training:
#      python -m apps.grpo.blackjack_main --config apps/grpo/blackjack.yaml

# Global configuration
group_size: 4        # Number of parallel games per rollout
local_batch_size: 8  # Per-device batch size
max_req_tokens: 512  # Max tokens for prompt (BlackJack prompts are ~200-300 tokens)
max_res_tokens: 32   # Max tokens for response (just "HIT" or "STAND" + thinking)
model: "Qwen/Qwen3-1.7B"
off_by_n: 1          # Off-policy tolerance

# Main loop configuration
rollout_threads: 1   # Number of parallel rollout threads

# Observability configuration
metric_logging:
  wandb:
    project: "blackjack-grpo-tutorial"
    group: "blackjack_exp_${oc.env:USER}"
    reduce_across_ranks: True
  console:
    reduce_across_ranks: True

# BlackJack environment configuration
blackjack_env:
  server_url: "http://localhost:8004"
  model: ${model}

# Policy configuration (generator)
policy:
  engine_args:  # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
    model: ${model}
    tensor_parallel_size: 1
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_params:  # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
    n: 1  # Generate 1 response per game state (not group_size, since we play full games)
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0

# Trainer configuration
trainer:
  model:
    name: qwen3
    flavor: 1.7B
    hf_assets_path: hf://${model}
  optimizer:
    name: AdamW
    lr: 1e-5
    eps: 1e-8
  lr_scheduler:
    warmup_steps: 1
  training:
    local_batch_size: ${local_batch_size}
    seq_len: 1024  # Shorter than GSM8K since BlackJack episodes are shorter
    max_norm: 1.0
    steps: 1000  # Tutorial: 1000 steps (increase for production)
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
    disable_loss_parallel: true
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true
    last_save_in_hf: true
    interval: 500
    async_mode: "disabled"
  activation_checkpoint:
    mode: selective
    selective_ac_option: op

# Replay buffer configuration
replay_buffer:
  batch_size: ${local_batch_size}
  max_policy_age: ${off_by_n}
  dp_size: ${trainer.parallelism.data_parallel_shard_degree}

# Reference model configuration
ref_model:
  model:
    name: qwen3
    flavor: 1.7B
    hf_assets_path: hf://${model}
  training:
    seq_len: ${trainer.training.seq_len}
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    tensor_parallel_degree: 1
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true

# All resource allocations
services:
  policy:
    procs: ${policy.engine_args.tensor_parallel_size}
    num_replicas: 1
    mesh_name: policy
    with_gpus: true
  ref_model:
    procs: 1
    num_replicas: 1
    mesh_name: ref_model
    with_gpus: true
  reward_actor:
    procs: 1
    num_replicas: 1
    mesh_name: reward_actor
    with_gpus: false

actors:
  blackjack_env:
    procs: 1
    with_gpus: false
    mesh_name: blackjack_env
  trainer:
    procs: 1
    with_gpus: true
    mesh_name: trainer
  replay_buffer:
    procs: 1
    with_gpus: false
    mesh_name: replay_buffer
  compute_advantages:
    procs: 1
    with_gpus: false
    mesh_name: compute_advantages
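
The `${...}` references in this file (`${model}`, `${local_batch_size}`, `${oc.env:USER}`, `${trainer.parallelism.data_parallel_shard_degree}`) use OmegaConf-style interpolation, so editing a top-level value propagates everywhere it is referenced. A minimal sketch of inspecting the resolved values, assuming the file is saved as `examples/grpo_blackjack/blackjack.yaml` and `omegaconf` is installed (torchforge's own config loading may differ):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/grpo_blackjack/blackjack.yaml")

# Interpolations resolve on access:
print(cfg.model)                              # Qwen/Qwen3-1.7B
print(cfg.policy.sampling_params.max_tokens)  # 32, via ${max_res_tokens}
print(cfg.replay_buffer.dp_size)              # 1, via ${trainer.parallelism...}

# Fully resolved dump (requires env vars such as USER for ${oc.env:USER}):
print(OmegaConf.to_yaml(cfg, resolve=True))
```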
