
Commit d2e2639

[Feature] Collectors - Weight Sync Scheme Integration
ghstack-source-id: 8caabb1 Pull-Request: #3187
1 parent 7f19046 commit d2e2639

File tree

7 files changed: +1203 -310 lines changed
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
"""Example of updating weights of several models at once in a multiprocessed data collector.

This example demonstrates:
1. Using different weight sync schemes for different models
2. Updating the policy (via pipes with MultiProcessWeightSyncScheme)
3. Updating the transforms of the env and the replay buffer (here also via
   MultiProcessWeightSyncScheme; Ray-based transforms would use RayModuleTransformScheme)
4. Atomic multi-model weight updates using weights_dict

Note:
- Ray actors are shared across all workers, so RayModuleTransformScheme uses a
  single transport rather than per-worker pipes.
- When using transform_factory with a replay buffer, delayed_init automatically defaults
  to True for proper serialization in multiprocessing contexts.
- extend_buffer defaults to True in all collectors, extending the buffer with entire
  rollouts rather than individual frames for better compatibility with postprocessing.
"""

from functools import partial

import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

from torchrl.collectors import MultiSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms.module import ModuleTransform
from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme


def make_module():
    # A module that transforms the observations
    return TensorDictModule(
        nn.Linear(3, 3), in_keys=["observation"], out_keys=["observation"]
    )


def policy_factory():
    # A module that produces the actions
    return TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )


def make_env():
    env_module = ModuleTransform(
        module_factory=make_module, inverse=False, no_grad=True
    )
    return GymEnv("Pendulum-v1").append_transform(env_module)


def main():
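    # The buffer's transform is built from a factory (rather than an instance) so it
    # can be re-created in worker processes; with transform_factory, delayed_init
    # defaults to True (see module docstring).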
    rb = ReplayBuffer(
        storage=LazyTensorStorage(10000, shared_init=True),
        transform_factory=partial(
            ModuleTransform,
            module_factory=make_module,
            inverse=True,
            no_grad=True,
        ),
        # delayed_init automatically defaults to True when transform_factory is provided
    )

    policy = policy_factory()

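    # Each key names the target model by its path inside the collector; the strategy
    # selects how the weights are packaged for transfer (state_dict vs. TensorDict).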
    weight_sync_schemes = {
        "policy": MultiProcessWeightSyncScheme(strategy="state_dict"),
        "replay_buffer.transform[0].module": MultiProcessWeightSyncScheme(
            strategy="tensordict"
        ),
        "env.transform[0].module": MultiProcessWeightSyncScheme(strategy="tensordict"),
    }

    collector = MultiSyncDataCollector(
        create_env_fn=[make_env, make_env],
        policy_factory=policy_factory,
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        init_random_frames=-1,
        device="cpu",
        storing_device="cpu",
        weight_sync_schemes=weight_sync_schemes,
        replay_buffer=rb,
        local_init_rb=True,
        # extend_buffer=True is the default for MultiSyncDataCollector
    )

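    # Detached copies of each model's parameters, used as the update payloads below.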
    policy_weights = TensorDict.from_module(policy).data
    env_module_weights = TensorDict.from_module(make_module()).data
    rb_module_weights = TensorDict.from_module(make_module()).data

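    # Zeroing the local copies stands in for a training step: it produces a visible
    # change that should propagate to every worker on the next sync.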
    for i, _data in enumerate(collector):
        env_module_weights.zero_()
        rb_module_weights.zero_()
        policy_weights.zero_()

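        # A single call updates all three models atomically via weights_dict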
        collector.update_policy_weights_(
            weights_dict={
                "policy": policy_weights,
                "env.transform[0].module": env_module_weights,
                "replay_buffer.transform[0].module": rb_module_weights,
            }
        )

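        # extend_buffer=True (the default): each iteration extends the buffer
        # with a full frames_per_batch (200 frames) rollout batch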
        assert len(rb) == i * 200 + 200

        if i >= 10:
            break

    collector.shutdown()


if __name__ == "__main__":
    main()

0 commit comments