# Training LLMs to Play BlackJack with GRPO + OpenEnv

This example demonstrates how to train language models to play BlackJack using **GRPO (Group Relative Policy Optimization)** and **OpenEnv**.

## 🎯 What This Example Shows

- **OpenEnv**: Universal RL environment interface for 70+ environments
- **GRPO**: Efficient RL algorithm (used by DeepSeek R1) that only needs 2 models instead of 3
- **Forge**: PyTorch-native agentic RL library for production training
- **End-to-End Training**: From a random policy (~35% win rate) to a trained agent

## 📁 Files

- `grpo_blackjack_tutorial.ipynb` - Interactive tutorial notebook (recommended starting point)
- `grpo_utils.py` - Production GRPO utilities and helper functions
- `blackjack.yaml` - Training configuration file
- `README.md` - This file

## 🚀 Quick Start

### Prerequisites

1. **Install OpenEnv**:
   ```bash
   # Clone OpenEnv repo
   git clone https://github.com/meta-pytorch/OpenEnv.git
   cd OpenEnv
   pip install -e .
   ```

2. **Install Forge** (PyTorch's agentic RL library):
   ```bash
   git clone https://github.com/meta-pytorch/torchforge.git
   cd torchforge
   pip install -e .
   ```

3. **Start the OpenEnv BlackJack server**:
   ```bash
   # In a separate terminal
   export OPENENV_PATH="/path/to/OpenEnv/src"
   export PYTHONPATH="${OPENENV_PATH}:${PYTHONPATH}"

   OPENSPIEL_GAME=blackjack python -m envs.openspiel_env.server.app --port 8004
   ```

### Run the Tutorial

Open the Jupyter notebook:
```bash
jupyter notebook grpo_blackjack_tutorial.ipynb
```

Follow the cells to:
1. **Explore OpenEnv** - Connect to the BlackJack environment
2. **Benchmark the baseline** - Measure random-policy performance
3. **Learn about GRPO** - Understand the training algorithm
4. **Train with Forge** - Run production GRPO training
5. **Switch environments** - See how to train on other games

## 📚 What You'll Learn

### OpenEnv: Universal RL Environment Spec

OpenEnv is **not a game engine** - it's a **specification** that wraps ANY RL environment:

```python
# The same interface works for 70+ environments
result = env.reset()       # Start an episode
result = env.step(action)  # Take an action
state = env.state()        # Get the current state
env.close()                # Clean up
```

Change one environment variable → train on a different game!
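
Because every environment exposes the same few calls, a rollout helper can be written once and reused across games. Here is a minimal sketch; the `play_episode` name and the `.done`/`.reward` fields on the step result are illustrative assumptions, not documented parts of the OpenEnv spec:

```python
from typing import Any, Callable

def play_episode(env: Any, choose_action: Callable[[Any], Any]) -> float:
    """Roll out one episode against any OpenEnv-compliant environment.

    Only the reset/step/close calls from the spec above are used. The
    `.done` and `.reward` attributes on the step result are assumptions
    about the result object, not guaranteed field names.
    """
    result = env.reset()                # start a fresh episode
    total_reward = 0.0
    while not getattr(result, "done", True):
        action = choose_action(result)  # policy decides (e.g., hit or stand)
        result = env.step(action)       # environment advances one turn
        total_reward += getattr(result, "reward", 0.0) or 0.0
    env.close()                         # release the server-side session
    return total_reward
```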

### Forge: PyTorch-Native Agentic RL

Forge handles all the distributed-systems complexity:
- **Generator (vLLM)**: Fast LLM inference
- **RLTrainer**: Distributed training with FSDP
- **ReplayBuffer**: Off-policy learning
- **ReferenceModel**: KL penalty computation
- **Torchstore**: Distributed weight management

You just write:
```python
trainer = await setup_forge_training("blackjack.yaml")
await trainer.run(steps=100)
```

Everything else is automated!

## 🎓 Educational Resources

This tutorial is inspired by the excellent [Unsloth RL Guide](https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide). We highly recommend reading it for deeper insights!

### Further Reading

- **OpenEnv**: [GitHub](https://github.com/meta-pytorch/OpenEnv)
- **GRPO Paper**: [arXiv:2402.03300](https://arxiv.org/abs/2402.03300)
- **Forge**: [GitHub](https://github.com/meta-pytorch/torchforge) | [Docs](https://meta-pytorch.org/torchforge/)
- **Unsloth RL Guide**: [docs.unsloth.ai](https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide)

## 💡 Key Concepts

### "Patience Is All You Need" for RL

RL works by patience: if the correct answer has *any* non-zero probability, the model will eventually find it through sampling. While waiting:
1. When it samples **bad answers** → decrease their probability
2. When it samples **good answers** → increase their probability

Over time, the model learns not just *what* to do, but *why* (the reasoning process).
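
GRPO turns this idea into a concrete update rule without a critic model: sample a *group* of responses for the same prompt, score each one, and judge every response against its own group's statistics. That is why GRPO needs only 2 models (policy + reference) instead of 3. Below is a minimal sketch of the group-relative advantage from the GRPO paper; the function name loosely mirrors `ComputeAdvantages.compute()` in `grpo_utils.py` but is not its actual implementation:

```python
import torch

def group_relative_advantages(rewards: list[float], eps: float = 1e-6) -> torch.Tensor:
    """Advantages for one group of responses sampled from the same prompt.

    Each response is standardized against its own group's mean and std, so
    no separate value/critic model is needed.
    """
    r = torch.tensor(rewards, dtype=torch.float32)
    return (r - r.mean()) / (r.std() + eps)  # standardize within the group

# Example: 4 rollouts of the same hand -> two wins, one push, one loss
print(group_relative_advantages([2.0, 2.0, 0.5, -1.0]))
```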

### Reward Functions

Reward functions tell the model what's good/bad. For BlackJack:

```python
def evaluate_response(prompt, response, game_reward):
    reward = float(game_reward)  # +1 (win), -1 (loss), 0 (push)

    # Reward shaping
    if game_reward > 0:
        reward = 2.0  # Wins more valuable
    elif game_reward == 0:
        reward = 0.5  # Pushes better than losses

    return reward
```

The key: **Reward functions must be verifiable**. You can verify "is the answer correct?" but not "is this creative?"
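
Another easy-to-add verifiable signal is a format check on the model's reply, since "does the response name exactly one legal action?" is a pure string check. This is a hypothetical add-on sketch; the HIT/STAND keywords and the bonus/penalty values are illustrative assumptions, not part of `grpo_utils.py`:

```python
def format_reward(response: str) -> float:
    """Hypothetical add-on reward: verifiable because it is a pure string check.

    The HIT/STAND keywords and the bonus/penalty values are assumptions,
    not the actual action format used by the environment.
    """
    text = response.upper()
    mentions_hit = "HIT" in text
    mentions_stand = "STAND" in text
    if mentions_hit != mentions_stand:   # exactly one legal action named
        return 0.2                       # small bonus for a parseable answer
    return -0.2                          # ambiguous or missing action
```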

## 🔄 Switching to Other Games

The beauty of OpenEnv: **same code works for any environment!**

### Try Tic-Tac-Toe
```bash
OPENSPIEL_GAME=tic_tac_toe python -m envs.openspiel_env.server.app --port 8005
```
Update config: `server_url = "http://localhost:8005"`

### Try Chess
```bash
OPENSPIEL_GAME=chess python -m envs.openspiel_env.server.app --port 8006
```

### Try Atari
```bash
python -m envs.atari_env.server.app --game pong --port 8007
```

Everything else stays the same! Same GRPO code, same Forge infrastructure.

## 🛠️ Customization

All the code is in `grpo_utils.py`:
- Modify `BlackJackReward.evaluate_response()` for reward shaping
- Adjust `ComputeAdvantages.compute()` for advantage computation
- Tweak `simple_grpo_loss()` for the KL penalty (`beta` parameter) - see the sketch after this list
- Change `format_prompt()` for different prompt templates
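
To make the `beta` knob concrete, here is a sketch of a GRPO-style loss with a KL penalty toward the reference model, following the formulation in the GRPO paper. The function name, signature, and tensor shapes are assumptions for illustration and are not copied from `grpo_utils.py`:

```python
import torch

def grpo_loss_sketch(
    logprobs: torch.Tensor,      # (batch, seq) log-probs from the policy being trained
    ref_logprobs: torch.Tensor,  # (batch, seq) log-probs from the frozen reference model
    advantages: torch.Tensor,    # (batch,) group-relative advantages
    padding_mask: torch.Tensor,  # (batch, seq) 1 for real tokens, 0 for padding
    beta: float = 0.1,
) -> torch.Tensor:
    # Per-token KL estimate toward the reference model (the estimator used in the GRPO paper)
    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
    # Policy-gradient term: ratio is 1 in value but carries gradients w.r.t. logprobs
    per_token = torch.exp(logprobs - logprobs.detach()) * advantages.unsqueeze(-1)
    per_token_loss = -(per_token - beta * kl)
    # Average over real tokens only, then over the batch
    token_counts = padding_mask.sum(dim=1).clamp(min=1)
    return ((per_token_loss * padding_mask).sum(dim=1) / token_counts).mean()
```

A larger `beta` keeps the policy close to the reference model; a smaller `beta` lets it drift further in search of reward.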

Edit `blackjack.yaml` for:
- Different model sizes (1B to 70B+)
- More training steps
- Larger group sizes
- Parallel rollout collection

## 📊 Expected Results

- **Random policy**: ~35% win rate
- **After GRPO training**: Improves toward optimal BlackJack strategy (~43% win rate)
- **Training time**: Varies based on model size and training steps

The model learns both strategy AND reasoning process (similar to DeepSeek R1's `<think>` tokens).

## 🤝 Credits

- **OpenEnv**: Meta PyTorch team
- **Forge**: Meta PyTorch team
- **GRPO**: DeepSeek research team
- **Tutorial inspiration**: Unsloth team

## 📝 License

This example follows the same license as the parent OpenEnv repository.

## 🙏 Acknowledgments

Big thanks to the **Unsloth team** for their educational approach to RL! This tutorial's GRPO section is heavily inspired by their excellent guide.