Commit 5f1eb2c

[Feature] Async collection within trainers
ghstack-source-id: f9b4d79
Pull-Request: #3173
Parent: 02d4bfd

19 files changed, +586 -51 lines changed

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
@@ -1123,6 +1123,7 @@ to be able to create this other composition:
     ExcludeTransform
     FiniteTensorDictCheck
     FlattenObservation
+    FlattenTensorDict
     FrameSkipTransform
     GrayScale
     Hash

docs/source/reference/trainers.rst

Lines changed: 1 addition & 0 deletions
@@ -184,6 +184,7 @@ Trainer and hooks
     TrainerHookBase
     UpdateWeights
     TargetNetUpdaterHook
+    UTDRHook
 
 
 Algorithm-specific trainers (Experimental)

sota-implementations/ppo_trainer/config/config.yaml

Lines changed: 8 additions & 0 deletions
@@ -5,6 +5,7 @@ defaults:
 
   - transform@transform0: noop_reset
   - transform@transform1: step_counter
+  - transform@transform2: reward_sum
 
   - env@training_env: batched_env
   - env@training_env.create_env_fn: transformed_env
@@ -64,6 +65,10 @@ transform1:
   max_steps: 200
   step_count_key: "step_count"
 
+transform2:
+  in_keys: ["reward"]
+  out_keys: ["reward_sum"]
+
 training_env:
   num_workers: 1
   create_env_fn:
@@ -73,6 +78,7 @@ training_env:
       transforms:
         - ${transform0}
         - ${transform1}
+        - ${transform2}
     _partial_: true
 
 # Loss configuration
@@ -92,6 +98,7 @@ collector:
   total_frames: 1_000_000
   frames_per_batch: 1024
   num_workers: 2
+  _partial_: true
 
 # Replay buffer configuration
 replay_buffer:
@@ -129,3 +136,4 @@ trainer:
   save_trainer_file: null
   optim_steps_per_batch: null
   num_epochs: 2
+  async_collection: false
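
Two of the additions above change how the objects get built rather than what they are: `_partial_: true` on the collector node and the new `async_collection` switch on the trainer. With `_partial_`, `hydra.utils.instantiate` hands back a `functools.partial` instead of a finished collector, so the trainer can complete construction itself (for example, to attach the replay buffer it owns) and decide how to run collection. A minimal sketch of that Hydra mechanic, using a toy `_target_` rather than the TorchRL collector:

# Sketch of the `_partial_: true` behavior; `collections.OrderedDict` is a
# stand-in target, not the collector class used by the trainer.
import functools

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({"_target_": "collections.OrderedDict", "_partial_": True})
factory = instantiate(cfg)          # returns functools.partial(OrderedDict)
assert isinstance(factory, functools.partial)
obj = factory()                     # the caller finishes construction later
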

sota-implementations/sac_trainer/config/config.yaml

Lines changed: 14 additions & 4 deletions
@@ -5,6 +5,7 @@ defaults:
 
   - transform@transform0: step_counter
   - transform@transform1: double_to_float
+  - transform@transform2: reward_sum
 
   - env@training_env: batched_env
   - env@training_env.create_env_fn: transformed_env
@@ -72,6 +73,11 @@ transform1:
   in_keys: null
   out_keys: null
 
+transform2:
+  # RewardSumTransform - sums up the rewards
+  in_keys: ["reward"]
+  out_keys: ["reward_sum"]
+
 training_env:
   num_workers: 4
   create_env_fn:
@@ -81,6 +87,7 @@ training_env:
       transforms:
         - ${transform0}
         - ${transform1}
+        - ${transform2}
     _partial_: true
 
 # Loss configuration
@@ -107,19 +114,21 @@ collector:
   total_frames: 1_000_000
   frames_per_batch: 1000
   num_workers: 4
-  init_random_frames: 25000
+  init_random_frames: 2500
   track_policy_version: true
+  _partial_: true
 
 # Replay buffer configuration
 replay_buffer:
   storage:
-    max_size: 1_000_000
+    max_size: 100_000
     device: cpu
     ndim: 1
   sampler:
   writer:
     compilable: false
-  batch_size: 256
+  batch_size: 64
+  shared: true
 
 logger:
   exp_name: sac_halfcheetah_v4
@@ -134,7 +143,7 @@ trainer:
   target_net_updater: ${target_net_updater}
   loss_module: ${loss}
   logger: ${logger}
-  total_frames: 1_000_000
+  total_frames: ${collector.total_frames}
   frame_skip: 1
   clip_grad_norm: false # SAC typically doesn't use gradient clipping
   clip_norm: null
@@ -144,3 +153,4 @@ trainer:
   log_interval: 25000
   save_trainer_file: null
   optim_steps_per_batch: 64 # Match SOTA utd_ratio
+  async_collection: false
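
The `shared: true` flag added to the replay buffer pairs with the multiprocessed collector: a buffer that worker processes write into must be shared before collection starts. The collectors.py change further down now warns and shares a non-shared buffer automatically; the sketch below does the same thing explicitly on the user side, guarding with `hasattr` exactly as that code does (the storage size is illustrative):

# Sketch: share the replay buffer up front instead of relying on the
# collector's warning path shown in the collectors.py diff below.
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(100_000, device="cpu"), batch_size=64)
if hasattr(rb, "shared") and not rb.shared:
    rb.share()  # make the buffer usable from the collector's worker processes
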
sota-implementations/sac_trainer/config/config_async.yaml

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+# SAC Trainer Configuration for HalfCheetah-v4
+# Run with `python sota-implementations/sac_trainer/train.py --config-name=config_async`
+# This configuration uses the new configurable trainer system and matches SOTA SAC implementation
+
+defaults:
+
+  - transform@transform0: step_counter
+  - transform@transform1: double_to_float
+  - transform@transform2: reward_sum
+  - transform@transform3: flatten_tensordict
+
+  - env@training_env: batched_env
+  - env@training_env.create_env_fn: transformed_env
+  - env@training_env.create_env_fn.base_env: gym
+  - transform@training_env.create_env_fn.transform: compose
+
+  - model@models.policy_model: tanh_normal
+  - model@models.value_model: value
+  - model@models.qvalue_model: value
+
+  - network@networks.policy_network: mlp
+  - network@networks.value_network: mlp
+  - network@networks.qvalue_network: mlp
+
+  - collector@collector: multi_async
+
+  - replay_buffer@replay_buffer: base
+  - storage@replay_buffer.storage: lazy_tensor
+  - writer@replay_buffer.writer: round_robin
+  - sampler@replay_buffer.sampler: random
+  - trainer@trainer: sac
+  - optimizer@optimizer: adam
+  - loss@loss: sac
+  - target_net_updater@target_net_updater: soft
+  - logger@logger: wandb
+  - _self_
+
+# Network configurations
+networks:
+  policy_network:
+    out_features: 12 # HalfCheetah action space is 6-dimensional (loc + scale) = 2 * 6
+    in_features: 17 # HalfCheetah observation space is 17-dimensional
+    num_cells: [256, 256]
+
+  value_network:
+    out_features: 1 # Value output
+    in_features: 17 # HalfCheetah observation space
+    num_cells: [256, 256]
+
+  qvalue_network:
+    out_features: 1 # Q-value output
+    in_features: 23 # HalfCheetah observation space (17) + action space (6)
+    num_cells: [256, 256]
+
+# Model configurations
+models:
+  policy_model:
+    return_log_prob: true
+    in_keys: ["observation"]
+    param_keys: ["loc", "scale"]
+    out_keys: ["action"]
+    network: ${networks.policy_network}
+    # Configure NormalParamExtractor for higher exploration
+    scale_mapping: "biased_softplus_2.0" # Higher bias for more exploration (default: 1.0)
+    scale_lb: 1e-2 # Minimum scale value (default: 1e-4)
+
+  qvalue_model:
+    in_keys: ["observation", "action"]
+    out_keys: ["state_action_value"]
+    network: ${networks.qvalue_network}
+
+transform0:
+  max_steps: 1000
+  step_count_key: "step_count"
+
+transform1:
+  # DoubleToFloatTransform - converts double precision to float to fix dtype mismatch
+  in_keys: null
+  out_keys: null
+
+transform2:
+  # RewardSumTransform - sums up the rewards
+  in_keys: ["reward"]
+  out_keys: ["reward_sum"]
+
+training_env:
+  num_workers: 4
+  create_env_fn:
+    base_env:
+      env_name: HalfCheetah-v4
+    transform:
+      transforms:
+        - ${transform0}
+        - ${transform1}
+        - ${transform2}
+    _partial_: true
+
+# Loss configuration
+loss:
+  actor_network: ${models.policy_model}
+  qvalue_network: ${models.qvalue_model}
+  target_entropy: "auto"
+  loss_function: l2
+  alpha_init: 1.0
+  delay_qvalue: true
+  num_qvalue_nets: 2
+
+target_net_updater:
+  tau: 0.001
+
+# Optimizer configuration
+optimizer:
+  lr: 3.0e-4
+
+# Collector configuration
+collector:
+  create_env_fn: ${training_env}
+  policy: ${models.policy_model}
+  total_frames: 5_000_000
+  frames_per_batch: 1000
+  num_workers: 8
+  # Incompatible with async collection
+  init_random_frames: 0
+  track_policy_version: true
+  extend_buffer: true
+  _partial_: true
+
+# Replay buffer configuration
+replay_buffer:
+  storage:
+    max_size: 10_000
+    device: cpu
+    ndim: 1
+  sampler:
+  writer:
+    compilable: false
+  batch_size: 256
+  shared: true
+  transform: ${transform3}
+
+logger:
+  exp_name: sac_halfcheetah_v4
+  offline: false
+  project: torchrl-sota-implementations
+
+# Trainer configuration
+trainer:
+  collector: ${collector}
+  optimizer: ${optimizer}
+  replay_buffer: ${replay_buffer}
+  target_net_updater: ${target_net_updater}
+  loss_module: ${loss}
+  logger: ${logger}
+  total_frames: ${collector.total_frames}
+  frame_skip: 1
+  clip_grad_norm: false # SAC typically doesn't use gradient clipping
+  clip_norm: null
+  progress_bar: true
+  seed: 42
+  save_trainer_interval: 25000 # Match SOTA eval_iter
+  log_interval: 25000
+  save_trainer_file: null
+  optim_steps_per_batch: 16 # Match SOTA utd_ratio
+  async_collection: true
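
With `async_collection: true`, collection and optimization are meant to overlap rather than alternate: the `multi_async` collector keeps filling the shared replay buffer while the trainer samples from it, which is also why `init_random_frames` stays at 0 above. This commit does not show the trainer's internals, so the following is only a thread-based sketch of the pattern, with a toy linear policy and small illustrative sizes in place of the configured SAC model:

# Thread-based sketch of asynchronous collection; not the Trainer's own code.
import threading

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
# Toy policy standing in for the configured tanh_normal SAC actor.
policy = TensorDictModule(
    nn.LazyLinear(env.action_spec.shape[-1]),
    in_keys=["observation"],
    out_keys=["action"],
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))
collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=10_000)

def collect() -> None:
    # Producer: keep streaming fresh frames into the buffer.
    for batch in collector:
        rb.extend(batch.reshape(-1))

thread = threading.Thread(target=collect, daemon=True)
thread.start()

# Consumer: sample and update as soon as enough data is available.
for _ in range(1_000):
    if len(rb) >= 256:
        sample = rb.sample(256)
        # ...compute the SAC losses on `sample` and step the optimizer here...

thread.join()
collector.shutdown()
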

sota-implementations/sac_trainer/train.py

Lines changed: 0 additions & 5 deletions
@@ -3,17 +3,12 @@
 # LICENSE file in the root directory of this source tree.
 
 import hydra
-import torchrl
 from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403
 
 
 @hydra.main(config_path="config", config_name="config", version_base="1.1")
 def main(cfg):
-    def print_reward(td):
-        torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")
-
     trainer = hydra.utils.instantiate(cfg.trainer)
-    trainer.register_op(dest="batch_process", op=print_reward)
     trainer.train()
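
The deleted lines above were a user-level hook rather than trainer machinery, so the same per-batch reward log can be restored without touching the library, by registering the callback on the `batch_process` stage exactly as the removed code did (the `__main__` guard is standard boilerplate, not shown in this diff):

# Sketch: reinstating the removed reward log as a user-side hook.
import hydra
import torchrl
from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403

@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    trainer = hydra.utils.instantiate(cfg.trainer)

    def print_reward(td):
        torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")

    trainer.register_op(dest="batch_process", op=print_reward)
    trainer.train()

if __name__ == "__main__":
    main()
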
torchrl/collectors/collectors.py

Lines changed: 16 additions & 4 deletions
@@ -60,6 +60,8 @@
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
 from torchrl.envs.common import _do_nothing, EnvBase
 from torchrl.envs.env_creator import EnvCreator
+
+from torchrl.envs.llm.transforms.policy_version import PolicyVersion
 from torchrl.envs.transforms import StepCounter, TransformedEnv
 from torchrl.envs.utils import (
     _aggregate_end_of_traj,
@@ -69,8 +71,6 @@
     set_exploration_type,
 )
 
-from torchrl.envs.llm.transforms.policy_version import PolicyVersion
-
 try:
     from torch.compiler import cudagraph_mark_step_begin
 except ImportError:
@@ -1818,13 +1818,20 @@ def get_policy_version(self) -> str | int | None:
         return self.policy_version
 
     def getattr_policy(self, attr):
+        """Get an attribute from the policy."""
         # send command to policy to return the attr
         return getattr(self.policy, attr)
 
     def getattr_env(self, attr):
+        """Get an attribute from the environment."""
         # send command to env to return the attr
         return getattr(self.env, attr)
 
+    def getattr_rb(self, attr):
+        """Get an attribute from the replay buffer."""
+        # send command to rb to return the attr
+        return getattr(self.replay_buffer, attr)
+
 
 class _MultiDataCollector(DataCollectorBase):
     """Runs a given number of DataCollectors on separate processes.
@@ -2153,6 +2160,7 @@ def __init__(
             and hasattr(replay_buffer, "shared")
             and not replay_buffer.shared
         ):
+            torchrl_logger.warning("Replay buffer is not shared. Sharing it.")
             replay_buffer.share()
 
         self._policy_weights_dict = {}
@@ -2306,8 +2314,8 @@ def _check_replay_buffer_init(self):
             fake_td["collector", "traj_ids"] = torch.zeros(
                 fake_td.shape, dtype=torch.long
             )
-
-            self.replay_buffer.add(fake_td)
+            # Use extend so that time-related transforms don't fail
+            self.replay_buffer.extend(fake_td.unsqueeze(-1))
             self.replay_buffer.empty()
 
     @classmethod
@@ -2841,6 +2849,10 @@ def getattr_env(self, attr):
 
         return result
 
+    def getattr_rb(self, attr):
+        """Get an attribute from the replay buffer."""
+        return getattr(self.replay_buffer, attr)
+
 
 @accept_remote_rref_udf_invocation
 class MultiSyncDataCollector(_MultiDataCollector):
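
`getattr_rb` mirrors the existing `getattr_policy` / `getattr_env` helpers: it forwards an attribute lookup to the replay buffer the collector writes into, including through the remote-invocation path of the multiprocessed collectors. A minimal sketch of the call, assuming the collector is built with a `replay_buffer` argument as the buffer-handling code in this diff implies; `write_count` is just an illustrative attribute name:

# Sketch: querying the collector-owned replay buffer through getattr_rb.
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
policy = TensorDictModule(
    nn.LazyLinear(env.action_spec.shape[-1]),
    in_keys=["observation"],
    out_keys=["action"],
)
rb = ReplayBuffer(storage=LazyTensorStorage(10_000))
collector = SyncDataCollector(
    env, policy, frames_per_batch=100, total_frames=200, replay_buffer=rb
)

for _ in collector:  # batches are written straight into `rb`
    # Same as rb.write_count here, but goes through the collector, which also
    # works when the buffer lives behind a worker process.
    print(collector.getattr_rb("write_count"))

collector.shutdown()
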

torchrl/data/replay_buffers/storages.py

Lines changed: 8 additions & 0 deletions
@@ -1172,6 +1172,10 @@ def max_size_along_dim0(data_shape):
 
         self._storage = out
         self.initialized = True
+        if hasattr(self._storage, "shape"):
+            torchrl_logger.info(
+                f"Initialized LazyTensorStorage with {self._storage.shape} shape"
+            )
 
 
 class LazyMemmapStorage(LazyTensorStorage):
@@ -1391,6 +1395,10 @@ def max_size_along_dim0(data_shape):
         else:
             out = _init_pytree(self.scratch_dir, max_size_along_dim0, data)
         self._storage = out
+        if hasattr(self._storage, "shape"):
+            torchrl_logger.info(
+                f"Initialized LazyMemmapStorage with {self._storage.shape} shape"
+            )
         self.initialized = True
 
     def get(self, index: int | Sequence[int] | slice) -> Any:
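
The new INFO line fires when the lazy storage first materializes, i.e. on the first write, which is exactly the moment the collector's `_check_replay_buffer_init` change above provokes with its throwaway batch. A quick sketch of when to expect it (key name and sizes are illustrative):

# Sketch: the "Initialized LazyTensorStorage with ... shape" log appears on
# the first write, once the backing tensors are allocated.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(1_000))
data = TensorDict({"observation": torch.randn(16, 4)}, batch_size=[16])
rb.extend(data)  # storage is initialized (and logged) here, not at construction
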
