Commit 02d4bfd

[Trainers] SAC Trainer and algorithms

ghstack-source-id: 54b9450 Pull-Request: #3172
1 parent fbdbb61 commit 02d4bfd

12 files changed: +786 -49 lines changed
docs/source/reference/trainers.rst

Lines changed: 83 additions & 39 deletions
@@ -183,6 +183,7 @@ Trainer and hooks
     Trainer
     TrainerHookBase
     UpdateWeights
+    TargetNetUpdaterHook


 Algorithm-specific trainers (Experimental)
@@ -202,37 +203,54 @@ into complete training solutions with sensible defaults and comprehensive config
     :template: rl_template.rst

     PPOTrainer
+    SACTrainer

-PPOTrainer
-~~~~~~~~~~
+Algorithm Trainers
+~~~~~~~~~~~~~~~~~~

-The :class:`~torchrl.trainers.algorithms.PPOTrainer` provides a complete PPO training solution
-with configurable defaults and a comprehensive configuration system built on Hydra.
+TorchRL provides high-level algorithm trainers that offer complete training solutions with minimal code.
+These trainers feature comprehensive configuration systems built on Hydra, enabling both simple usage
+and sophisticated customization.
+
+**Currently Available:**
+
+- :class:`~torchrl.trainers.algorithms.PPOTrainer` - Proximal Policy Optimization
+- :class:`~torchrl.trainers.algorithms.SACTrainer` - Soft Actor-Critic

 **Key Features:**

-- Complete training pipeline with environment setup, data collection, and optimization
-- Extensive configuration system using dataclasses and Hydra
-- Built-in logging for rewards, actions, and training statistics
-- Modular design built on existing TorchRL components
-- **Minimal code**: Complete SOTA implementation in just ~20 lines!
+- **Complete pipeline**: Environment setup, data collection, and optimization
+- **Hydra configuration**: Extensive dataclass-based configuration system
+- **Built-in logging**: Rewards, actions, and algorithm-specific metrics
+- **Modular design**: Built on existing TorchRL components
+- **Minimal code**: Complete SOTA implementations in ~20 lines!

 .. warning::
-    This is an experimental feature. The API may change in future versions.
-    We welcome feedback and contributions to help improve this implementation!
+    Algorithm trainers are experimental features. The API may change in future versions.
+    We welcome feedback and contributions to help improve these implementations!

-**Quick Start - Command Line Interface:**
+Quick Start Examples
+^^^^^^^^^^^^^^^^^^^^
+
+**PPO Training:**

 .. code-block:: bash

-    # Basic usage - train PPO on Pendulum-v1 with default settings
+    # Train PPO on Pendulum-v1 with default settings
     python sota-implementations/ppo_trainer/train.py

+**SAC Training:**
+
+.. code-block:: bash
+
+    # Train SAC on a continuous control task
+    python sota-implementations/sac_trainer/train.py
+
 **Custom Configuration:**

 .. code-block:: bash

-    # Override specific parameters via command line
+    # Override parameters for any algorithm
     python sota-implementations/ppo_trainer/train.py \
       trainer.total_frames=2000000 \
       training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
@@ -243,32 +261,34 @@ with configurable defaults and a comprehensive configuration system built on Hyd

 .. code-block:: bash

-    # Switch to a different environment and logger
-    python sota-implementations/ppo_trainer/train.py \
-      env=gym \
+    # Switch environment and logger for any trainer
+    python sota-implementations/sac_trainer/train.py \
       training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
-      logger=tensorboard
+      logger=tensorboard \
+      logger.exp_name=sac_walker2d

-**See All Options:**
+**View Configuration Options:**

 .. code-block:: bash

-    # View all available configuration options
+    # See all available options for any trainer
     python sota-implementations/ppo_trainer/train.py --help
+    python sota-implementations/sac_trainer/train.py --help

-**Configuration Groups:**
+Universal Configuration System
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-The PPOTrainer configuration is organized into logical groups:
+All algorithm trainers share a unified configuration architecture organized into logical groups:

-- **Environment**: ``env_cfg__env_name``, ``env_cfg__backend``, ``env_cfg__device``
-- **Networks**: ``actor_network__network__num_cells``, ``critic_network__module__num_cells``
-- **Training**: ``total_frames``, ``clip_norm``, ``num_epochs``, ``optimizer_cfg__lr``
-- **Logging**: ``log_rewards``, ``log_actions``, ``log_observations``
+- **Environment**: ``training_env.create_env_fn.base_env.env_name``, ``training_env.num_workers``
+- **Networks**: ``networks.policy_network.num_cells``, ``networks.value_network.num_cells``
+- **Training**: ``trainer.total_frames``, ``trainer.clip_norm``, ``optimizer.lr``
+- **Data**: ``collector.frames_per_batch``, ``replay_buffer.batch_size``, ``replay_buffer.storage.max_size``
+- **Logging**: ``logger.exp_name``, ``logger.project``, ``trainer.log_interval``

 **Working Example:**

-The `sota-implementations/ppo_trainer/ <https://github.com/pytorch/rl/tree/main/sota-implementations/ppo_trainer>`_
-directory contains a complete, working PPO implementation that demonstrates the simplicity and power of the trainer system:
+All trainer implementations follow the same simple pattern:

 .. code-block:: python

@@ -283,33 +303,57 @@ directory contains a complete, working PPO implementation that demonstrates the
     if __name__ == "__main__":
         main()

-*Complete PPO training with full configurability in ~20 lines!*
+*Complete algorithm training with full configurability in ~20 lines!*

-**Configuration Classes:**
+Configuration Classes
+^^^^^^^^^^^^^^^^^^^^^

-The PPOTrainer uses a hierarchical configuration system with these main config classes.
+The trainer system uses a hierarchical configuration system with shared components.

 .. note::
     The configuration system requires Python 3.10+ due to its use of modern type annotation syntax.

-- **Trainer**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig`
+**Algorithm-Specific Trainers:**
+
+- **PPO**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig`
+- **SAC**: :class:`~torchrl.trainers.algorithms.configs.trainers.SACTrainerConfig`
+
+**Shared Configuration Components:**
+
 - **Environment**: :class:`~torchrl.trainers.algorithms.configs.envs_libs.GymEnvConfig`, :class:`~torchrl.trainers.algorithms.configs.envs.BatchedEnvConfig`
 - **Networks**: :class:`~torchrl.trainers.algorithms.configs.modules.MLPConfig`, :class:`~torchrl.trainers.algorithms.configs.modules.TanhNormalModelConfig`
 - **Data**: :class:`~torchrl.trainers.algorithms.configs.data.TensorDictReplayBufferConfig`, :class:`~torchrl.trainers.algorithms.configs.collectors.MultiaSyncDataCollectorConfig`
-- **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`
+- **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`, :class:`~torchrl.trainers.algorithms.configs.objectives.SACLossConfig`
 - **Optimizers**: :class:`~torchrl.trainers.algorithms.configs.utils.AdamConfig`, :class:`~torchrl.trainers.algorithms.configs.utils.AdamWConfig`
 - **Logging**: :class:`~torchrl.trainers.algorithms.configs.logging.WandbLoggerConfig`, :class:`~torchrl.trainers.algorithms.configs.logging.TensorboardLoggerConfig`

+Algorithm-Specific Features
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+**PPOTrainer:**
+
+- On-policy learning with advantage estimation
+- Policy clipping and value function optimization
+- Configurable number of epochs per batch
+- Built-in GAE (Generalized Advantage Estimation)
+
+**SACTrainer:**
+
+- Off-policy learning with replay buffer
+- Entropy-regularized policy optimization
+- Target network soft updates
+- Continuous action space optimization
+
 **Future Development:**

-This is the first of many planned algorithm-specific trainers. Future releases will include:
+The trainer system is actively expanding. Upcoming features include:

-- Additional algorithms: SAC, TD3, DQN, A2C, and more
-- Full integration of all TorchRL components within the configuration system
-- Enhanced configuration validation and error reporting
-- Distributed training support for high-level trainers
+- Additional algorithms: TD3, DQN, A2C, DDPG, and more
+- Enhanced distributed training support
+- Advanced configuration validation and error reporting
+- Integration with more TorchRL ecosystem components

-See the complete `configuration system documentation <https://github.com/pytorch/rl/tree/main/torchrl/trainers/algorithms/configs>`_ for all available options.
+See the complete `configuration system documentation <https://github.com/pytorch/rl/tree/main/torchrl/trainers/algorithms/configs>`_ for all available options and examples.


 Builders
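
The SACTrainer features documented above include target network soft updates, which this commit wires in through the new TargetNetUpdaterHook. For readers unfamiliar with the mechanism, here is a minimal, framework-agnostic sketch of the soft (Polyak) update rule; it only illustrates the technique and is not TorchRL's implementation, and the soft_update helper name is hypothetical.

# Illustrative sketch of the soft (Polyak) target update used by SAC:
#     theta_target <- tau * theta + (1 - tau) * theta_target
# With a small tau (the SAC config below uses tau = 0.001), the target networks
# track the online networks slowly, stabilizing the Q-value bootstrap targets.
import torch
from torch import nn


@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.001) -> None:
    for p_target, p_source in zip(target.parameters(), source.parameters()):
        # p_target <- (1 - tau) * p_target + tau * p_source
        p_target.lerp_(p_source, tau)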
sota-implementations/sac_trainer/config/config.yaml

Lines changed: 146 additions & 0 deletions
# SAC Trainer Configuration for HalfCheetah-v4
# This configuration uses the new configurable trainer system and matches SOTA SAC implementation

defaults:

  - transform@transform0: step_counter
  - transform@transform1: double_to_float

  - env@training_env: batched_env
  - env@training_env.create_env_fn: transformed_env
  - env@training_env.create_env_fn.base_env: gym
  - transform@training_env.create_env_fn.transform: compose

  - model@models.policy_model: tanh_normal
  - model@models.value_model: value
  - model@models.qvalue_model: value

  - network@networks.policy_network: mlp
  - network@networks.value_network: mlp
  - network@networks.qvalue_network: mlp

  - collector@collector: multi_async

  - replay_buffer@replay_buffer: base
  - storage@replay_buffer.storage: lazy_tensor
  - writer@replay_buffer.writer: round_robin
  - sampler@replay_buffer.sampler: random
  - trainer@trainer: sac
  - optimizer@optimizer: adam
  - loss@loss: sac
  - target_net_updater@target_net_updater: soft
  - logger@logger: wandb
  - _self_

# Network configurations
networks:
  policy_network:
    out_features: 12  # HalfCheetah action space is 6-dimensional (loc + scale)
    in_features: 17   # HalfCheetah observation space is 17-dimensional
    num_cells: [256, 256]

  value_network:
    out_features: 1   # Value output
    in_features: 17   # HalfCheetah observation space
    num_cells: [256, 256]

  qvalue_network:
    out_features: 1   # Q-value output
    in_features: 23   # HalfCheetah observation space (17) + action space (6)
    num_cells: [256, 256]

# Model configurations
models:
  policy_model:
    return_log_prob: true
    in_keys: ["observation"]
    param_keys: ["loc", "scale"]
    out_keys: ["action"]
    network: ${networks.policy_network}

  qvalue_model:
    in_keys: ["observation", "action"]
    out_keys: ["state_action_value"]
    network: ${networks.qvalue_network}

transform0:
  max_steps: 1000
  step_count_key: "step_count"

transform1:
  # DoubleToFloatTransform - converts double precision to float to fix dtype mismatch
  in_keys: null
  out_keys: null

training_env:
  num_workers: 4
  create_env_fn:
    base_env:
      env_name: HalfCheetah-v4
    transform:
      transforms:
        - ${transform0}
        - ${transform1}
    _partial_: true

# Loss configuration
loss:
  actor_network: ${models.policy_model}
  qvalue_network: ${models.qvalue_model}
  target_entropy: "auto"
  loss_function: l2
  alpha_init: 1.0
  delay_qvalue: true
  num_qvalue_nets: 2

target_net_updater:
  tau: 0.001

# Optimizer configuration
optimizer:
  lr: 3.0e-4

# Collector configuration
collector:
  create_env_fn: ${training_env}
  policy: ${models.policy_model}
  total_frames: 1_000_000
  frames_per_batch: 1000
  num_workers: 4
  init_random_frames: 25000
  track_policy_version: true

# Replay buffer configuration
replay_buffer:
  storage:
    max_size: 1_000_000
    device: cpu
    ndim: 1
  sampler:
  writer:
    compilable: false
  batch_size: 256

logger:
  exp_name: sac_halfcheetah_v4
  offline: false
  project: torchrl-sota-implementations

# Trainer configuration
trainer:
  collector: ${collector}
  optimizer: ${optimizer}
  replay_buffer: ${replay_buffer}
  target_net_updater: ${target_net_updater}
  loss_module: ${loss}
  logger: ${logger}
  total_frames: 1_000_000
  frame_skip: 1
  clip_grad_norm: false  # SAC typically doesn't use gradient clipping
  clip_norm: null
  progress_bar: true
  seed: 42
  save_trainer_interval: 25000  # Match SOTA eval_iter
  log_interval: 25000
  save_trainer_file: null
  optim_steps_per_batch: 64  # Match SOTA utd_ratio
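
The configuration above leans on Hydra/OmegaConf interpolation (for example ${networks.policy_network} and ${models.policy_model}) so that a node can be defined once and referenced by several components. The following self-contained sketch shows how such interpolations resolve; the keys are abbreviated copies of the ones above and the values are placeholders.

# Sketch of OmegaConf interpolation as used in the SAC config above.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "networks": {"policy_network": {"num_cells": [256, 256], "out_features": 12}},
        "models": {"policy_model": {"network": "${networks.policy_network}"}},
    }
)

# Resolving the interpolation makes the model node carry the same values as the
# network node it points to.
resolved = OmegaConf.to_container(cfg, resolve=True)
assert resolved["models"]["policy_model"]["network"]["num_cells"] == [256, 256]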
sota-implementations/sac_trainer/train.py

Lines changed: 21 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import hydra
import torchrl
from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403


@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    def print_reward(td):
        torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")

    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.register_op(dest="batch_process", op=print_reward)
    trainer.train()


if __name__ == "__main__":
    main()
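
The script registers a single hook at the "batch_process" hook point. Additional hooks can be attached the same way; the sketch below extends the script with a second, hypothetical hook that logs the step_count entry written by the step_counter transform from the config. The extra hook is illustrative and not part of this commit, and it assumes the collected batch carries td["next", "step_count"].

# Sketch: train.py above plus one extra "batch_process" hook. Only print_step_count
# is new; it assumes the step_counter transform from the config writes
# td["next", "step_count"] into the collected batches.
import hydra
import torchrl
from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403


@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    def print_reward(td):
        torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")

    def print_step_count(td):
        torchrl.logger.info(f"max step_count: {td['next', 'step_count'].max()}")

    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.register_op(dest="batch_process", op=print_reward)
    trainer.register_op(dest="batch_process", op=print_step_count)
    trainer.train()


if __name__ == "__main__":
    main()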

torchrl/trainers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
     ReplayBufferTrainer,
     RewardNormalizer,
     SelectKeys,
+    TargetNetUpdaterHook,
     Trainer,
     TrainerHookBase,
     UpdateWeights,
@@ -37,4 +38,5 @@
     "Trainer",
     "TrainerHookBase",
     "UpdateWeights",
+    "TargetNetUpdaterHook",
 ]

torchrl/trainers/algorithms/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -6,5 +6,6 @@
 from __future__ import annotations

 from .ppo import PPOTrainer
+from .sac import SACTrainer

-__all__ = ["PPOTrainer"]
+__all__ = ["PPOTrainer", "SACTrainer"]
