@@ -183,6 +183,7 @@ Trainer and hooks
183183 Trainer
184184 TrainerHookBase
185185 UpdateWeights
186+ TargetNetUpdaterHook
186187
187188
188189Algorithm-specific trainers (Experimental)
@@ -202,37 +203,54 @@ into complete training solutions with sensible defaults and comprehensive config
202203 :template: rl_template.rst
203204
204205 PPOTrainer
206+ SACTrainer
205207
206- PPOTrainer
207- ~~~~~~~~~~
208+ Algorithm Trainers
209+ ~~~~~~~~~~~~~~~~~~
208210
209- The :class:`~torchrl.trainers.algorithms.PPOTrainer` provides a complete PPO training solution
210- with configurable defaults and a comprehensive configuration system built on Hydra.
211+ TorchRL provides high-level algorithm trainers that offer complete training solutions with minimal code.
212+ These trainers feature comprehensive configuration systems built on Hydra, enabling both simple usage
213+ and sophisticated customization.
214+
215+ **Currently Available:**
216+
217+ - :class:`~torchrl.trainers.algorithms.PPOTrainer` - Proximal Policy Optimization
218+ - :class:`~torchrl.trainers.algorithms.SACTrainer` - Soft Actor-Critic
211219
212220**Key Features:**
213221
214- - Complete training pipeline with environment setup, data collection, and optimization
215- - Extensive configuration system using dataclasses and Hydra
216- - Built-in logging for rewards, actions, and training statistics
217- - Modular design built on existing TorchRL components
218- - **Minimal code**: Complete SOTA implementation in just ~20 lines!
222+ - **Complete pipeline**: Environment setup, data collection, and optimization
223+ - **Hydra configuration**: Extensive dataclass-based configuration system
224+ - **Built-in logging**: Rewards, actions, and algorithm-specific metrics
225+ - **Modular design**: Built on existing TorchRL components
226+ - **Minimal code**: Complete SOTA implementations in ~20 lines!
219227
220228.. warning::
221- This is an experimental feature. The API may change in future versions.
222- We welcome feedback and contributions to help improve this implementation!
229+ Algorithm trainers are experimental features. The API may change in future versions.
230+ We welcome feedback and contributions to help improve these implementations!
223231
224- **Quick Start - Command Line Interface:**
232+ Quick Start Examples
233+ ^^^^^^^^^^^^^^^^^^^^
234+
235+ **PPO Training:**
225236
226237.. code-block:: bash
227238
228- # Basic usage - train PPO on Pendulum-v1 with default settings
239+ # Train PPO on Pendulum-v1 with default settings
229240 python sota-implementations/ppo_trainer/train.py
230241
242+ **SAC Training:**
243+
244+ .. code-block:: bash
245+
246+ # Train SAC on a continuous control task
247+ python sota-implementations/sac_trainer/train.py
248+
231249 **Custom Configuration:**
232250
233251.. code-block:: bash
234252
235- # Override specific parameters via command line
253+ # Override parameters for any algorithm
236254 python sota-implementations/ppo_trainer/train.py \
237255 trainer.total_frames=2000000 \
238256 training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
@@ -243,32 +261,34 @@ with configurable defaults and a comprehensive configuration system built on Hyd
243261
244262.. code-block:: bash
245263
246- # Switch to a different environment and logger
247- python sota-implementations/ppo_trainer/train.py \
248- env=gym \
264+ # Switch environment and logger for any trainer
265+ python sota-implementations/sac_trainer/train.py \
249266 training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
250- logger=tensorboard
267+ logger=tensorboard \
268+ logger.exp_name=sac_walker2d
251269
252- **See All Options:**
270+ **View Configuration Options:**
253271
254272.. code-block:: bash
255273
256- # View all available configuration options
274+ # See all available options for any trainer
257275 python sota-implementations/ppo_trainer/train.py --help
276+ python sota-implementations/sac_trainer/train.py --help
258277
259- **Configuration Groups:**
278+ Universal Configuration System
279+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
260280
261- The PPOTrainer configuration is organized into logical groups:
281+ All algorithm trainers share a unified configuration architecture organized into logical groups:
262282
263- - **Environment**: ``env_cfg__env_name``, ``env_cfg__backend``, ``env_cfg__device``
264- - **Networks**: ``actor_network__network__num_cells``, ``critic_network__module__num_cells``
265- - **Training**: ``total_frames``, ``clip_norm``, ``num_epochs``, ``optimizer_cfg__lr``
266- - **Logging**: ``log_rewards``, ``log_actions``, ``log_observations``
283+ - **Environment**: ``training_env.create_env_fn.base_env.env_name``, ``training_env.num_workers``
284+ - **Networks**: ``networks.policy_network.num_cells``, ``networks.value_network.num_cells``
285+ - **Training**: ``trainer.total_frames``, ``trainer.clip_norm``, ``optimizer.lr``
286+ - **Data**: ``collector.frames_per_batch``, ``replay_buffer.batch_size``, ``replay_buffer.storage.max_size``
287+ - **Logging**: ``logger.exp_name``, ``logger.project``, ``trainer.log_interval``
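
For illustration, here is how a few of the dotted keys above expand into a nested
configuration tree. This sketch uses OmegaConf directly with arbitrary values; it is
not part of any trainer script, just a way to visualize the grouping.

.. code-block:: python

    from omegaconf import OmegaConf

    # Dotted override keys (as passed on the command line) expand into
    # a nested configuration tree grouped by component.
    overrides = [
        "trainer.total_frames=2000000",
        "optimizer.lr=3e-4",
        "collector.frames_per_batch=1024",
        "logger.exp_name=my_experiment",
    ]
    cfg = OmegaConf.from_dotlist(overrides)
    print(OmegaConf.to_yaml(cfg))  # keys appear grouped under trainer/optimizer/collector/logger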
267288
268289 **Working Example:**
269290
270- The `sota-implementations/ppo_trainer/ <https://github.com/pytorch/rl/tree/main/sota-implementations/ppo_trainer>`_
271- directory contains a complete, working PPO implementation that demonstrates the simplicity and power of the trainer system:
291+ All trainer implementations follow the same simple pattern:
272292
273293.. code-block:: python
274294
@@ -283,33 +303,57 @@ directory contains a complete, working PPO implementation that demonstrates the
283303 if __name__ == "__main__":
284304     main()
285305
286- *Complete PPO training with full configurability in ~20 lines!*
306+ *Complete algorithm training with full configurability in ~20 lines!*
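
As a rough sketch of that pattern (illustrative only; the exact config path, config
name, and Hydra settings used by the shipped scripts may differ), a Hydra-based
trainer entry point looks roughly like this:

.. code-block:: python

    import hydra
    from hydra.utils import instantiate

    @hydra.main(config_path="config", config_name="config", version_base="1.1")
    def main(cfg) -> None:
        # The resolved config describes every component; instantiating the
        # trainer node builds the environment, collector, loss, and optimizer.
        trainer = instantiate(cfg.trainer)
        trainer.train()

    if __name__ == "__main__":
        main()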
287307
288- **Configuration Classes:**
308+ Configuration Classes
309+ ^^^^^^^^^^^^^^^^^^^^^
289310
290- The PPOTrainer uses a hierarchical configuration system with these main config classes.
311+ The trainers use a hierarchical configuration system built from shared components.
291312
292313.. note::
293314 The configuration system requires Python 3.10+ due to its use of modern type annotation syntax.
294315
295- - **Trainer**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig`
316+ **Algorithm-Specific Trainers:**
317+
318+ - **PPO**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig`
319+ - **SAC**: :class:`~torchrl.trainers.algorithms.configs.trainers.SACTrainerConfig`
320+
321+ **Shared Configuration Components:**
322+
296323- **Environment**: :class:`~torchrl.trainers.algorithms.configs.envs_libs.GymEnvConfig`, :class:`~torchrl.trainers.algorithms.configs.envs.BatchedEnvConfig`
297324- **Networks**: :class:`~torchrl.trainers.algorithms.configs.modules.MLPConfig`, :class:`~torchrl.trainers.algorithms.configs.modules.TanhNormalModelConfig`
298325- **Data**: :class:`~torchrl.trainers.algorithms.configs.data.TensorDictReplayBufferConfig`, :class:`~torchrl.trainers.algorithms.configs.collectors.MultiaSyncDataCollectorConfig`
299- - **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`
326+ - **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`, :class:`~torchrl.trainers.algorithms.configs.objectives.SACLossConfig`
300327- **Optimizers**: :class:`~torchrl.trainers.algorithms.configs.utils.AdamConfig`, :class:`~torchrl.trainers.algorithms.configs.utils.AdamWConfig`
301328- **Logging**: :class:`~torchrl.trainers.algorithms.configs.logging.WandbLoggerConfig`, :class:`~torchrl.trainers.algorithms.configs.logging.TensorboardLoggerConfig`
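
To make the dataclass-based hierarchy concrete, the following sketch shows how such
configs compose and accept dotted overrides. The dataclasses below are hypothetical,
heavily simplified stand-ins, not the actual TorchRL config schema.

.. code-block:: python

    from dataclasses import dataclass, field
    from omegaconf import OmegaConf

    @dataclass
    class OptimizerConfig:          # hypothetical stand-in for e.g. AdamConfig
        lr: float = 3.0e-4

    @dataclass
    class TrainerConfig:            # hypothetical stand-in for a trainer config
        total_frames: int = 1_000_000
        clip_norm: float = 1.0

    @dataclass
    class Config:                   # top-level config grouping the components
        trainer: TrainerConfig = field(default_factory=TrainerConfig)
        optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)

    cfg = OmegaConf.structured(Config())
    # The same dotted syntax used on the command line works programmatically.
    cfg.merge_with_dotlist(["trainer.total_frames=2000000", "optimizer.lr=1e-3"])
    print(OmegaConf.to_yaml(cfg))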
302329
330+ Algorithm-Specific Features
331+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
332+
333+ **PPOTrainer:**
334+
335+ - On-policy learning with advantage estimation
336+ - Policy clipping and value function optimization
337+ - Configurable number of epochs per batch
338+ - Built-in GAE (Generalized Advantage Estimation)
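
For reference, GAE estimates the advantage as an exponentially weighted sum of
temporal-difference errors (:math:`\gamma` is the discount factor and :math:`\lambda`
the GAE smoothing parameter):

.. math::

    \hat{A}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l},
    \qquad
    \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)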
339+
340+ **SACTrainer:**
341+
342+ - Off-policy learning with replay buffer
343+ - Entropy-regularized policy optimization
344+ - Target network soft updates (sketched below)
345+ - Continuous action space optimization
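
The target-network soft update mentioned above is a Polyak average of the online
parameters. The snippet below is a minimal, purely conceptual sketch of that
mechanism, not TorchRL's internal implementation (the trainer relies on its own
utilities, such as the ``TargetNetUpdaterHook`` listed earlier on this page):

.. code-block:: python

    import torch
    from torch import nn

    @torch.no_grad()
    def soft_update(target: nn.Module, online: nn.Module, tau: float = 0.005) -> None:
        """Polyak-average the online parameters into the target network."""
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            t_param.mul_(1.0 - tau).add_(o_param, alpha=tau)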
346+
303347 **Future Development:**
304348
305- This is the first of many planned algorithm-specific trainers. Future releases will include:
349+ The trainer system is actively expanding. Upcoming features include:
306350
307- - Additional algorithms: SAC, TD3, DQN, A2C, and more
308- - Full integration of all TorchRL components within the configuration system
309- - Enhanced configuration validation and error reporting
310- - Distributed training support for high-level trainers
351+ - Additional algorithms: TD3, DQN, A2C, DDPG, and more
352+ - Enhanced distributed training support
353+ - Advanced configuration validation and error reporting
354+ - Integration with more TorchRL ecosystem components
311355
312- See the complete `configuration system documentation <https://github.com/pytorch/rl/tree/main/torchrl/trainers/algorithms/configs>`_ for all available options.
356+ See the complete `configuration system documentation <https://github.com/pytorch/rl/tree/main/torchrl/trainers/algorithms/configs>`_ for all available options and examples.
313357
314358
315359Builders