
Commit 37fdd3b

[Feature] Collectors - Weight Sync Scheme Integration
ghstack-source-id: b1907cd
Pull-Request: #3187
1 parent cd7dcaa · commit 37fdd3b

File tree

9 files changed: +1319 -314 lines
Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@
"""Example of updating weights of several models at once in a multiprocessed data collector.

This example demonstrates:
1. Using different weight sync schemes for different models
2. Updating the policy (via pipes with MultiProcessWeightSyncScheme)
3. Updating Ray-based transforms in env and replay buffer (via RayModuleTransformScheme)
4. Atomic multi-model weight updates using weights_dict

Note:
- Ray actors are shared across all workers, so RayModuleTransformScheme uses a
  single transport rather than per-worker pipes.
- When using transform_factory with a replay buffer, delayed_init automatically defaults
  to True for proper serialization in multiprocessing contexts.
- extend_buffer defaults to True in all collectors, extending the buffer with entire
  rollouts rather than individual frames for better compatibility with postprocessing.
"""

from functools import partial

import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

from torchrl.collectors import MultiSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms.module import ModuleTransform
from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme


def make_module():
    # A module that transforms the observations
    return TensorDictModule(
        nn.Linear(3, 3), in_keys=["observation"], out_keys=["observation"]
    )


def policy_factory():
    # A module that produces the actions
    return TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )


def make_env():
    env_module = ModuleTransform(
        module_factory=make_module, inverse=False, no_grad=True
    )
    return GymEnv("Pendulum-v1").append_transform(env_module)


def main():
    rb = ReplayBuffer(
        storage=LazyTensorStorage(10000, shared_init=True),
        transform_factory=partial(
            ModuleTransform,
            module_factory=make_module,
            inverse=True,
            no_grad=True,
        ),
        # delayed_init automatically defaults to True when transform_factory is provided
    )

    policy = policy_factory()

    weight_sync_schemes = {
        "policy": MultiProcessWeightSyncScheme(strategy="state_dict"),
        "replay_buffer.transform[0].module": MultiProcessWeightSyncScheme(
            strategy="tensordict"
        ),
        "env.transform[0].module": MultiProcessWeightSyncScheme(strategy="tensordict"),
    }

    collector = MultiSyncDataCollector(
        create_env_fn=[make_env, make_env],
        policy_factory=policy_factory,
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        init_random_frames=-1,
        device="cpu",
        storing_device="cpu",
        weight_sync_schemes=weight_sync_schemes,
        replay_buffer=rb,
        local_init_rb=True,
        # extend_buffer=True is the default for MultiSyncDataCollector
    )

    policy_weights = TensorDict.from_module(policy).data
    env_module_weights = TensorDict.from_module(make_module()).data
    rb_module_weights = TensorDict.from_module(make_module()).data

    for i, _data in enumerate(collector):
        env_module_weights.zero_()
        rb_module_weights.zero_()
        policy_weights.zero_()

        collector.update_policy_weights_(
            weights_dict={
                "policy": policy_weights,
                "env.transform[0].module": env_module_weights,
                "replay_buffer.transform[0].module": rb_module_weights,
            }
        )

        assert len(rb) == i * 200 + 200

        if i >= 10:
            break

    collector.shutdown()


if __name__ == "__main__":
    main()
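The update loop above captures each model's weights once, before iterating. That works because TensorDict.from_module(model).data holds views of the module's live parameter storage rather than copies, so policy_weights.zero_() simultaneously zeroes the trainer-side policy and produces the payload handed to update_policy_weights_. A self-contained sketch of that tensordict behaviour (standard tensordict/torch API only, nothing collector-specific; the toy module mirrors the example's policy):

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)

# `.data` strips the nn.Parameter wrapper but keeps the same underlying storage.
weights = TensorDict.from_module(policy).data

# Mutating the module's parameters is reflected in `weights` (and vice versa),
# so a TensorDict captured once before the collection loop always holds the
# current values when it is handed to the collector.
with torch.no_grad():
    policy.module.weight.add_(1.0)
assert (weights["module", "weight"] == policy.module.weight).all()

# Zeroing the TensorDict zeroes the module too, which is what the example's
# `policy_weights.zero_()` relies on.
weights.zero_()
assert (policy.module.weight == 0).all()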
examples/collectors/weight_sync_standalone.py

Lines changed: 3 additions & 3 deletions

@@ -129,7 +129,7 @@ def example_multiprocess_sync():
     print(
         f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}"
     )
-    print(f"✓ Weight synchronization successful!")
+    print("✓ Weight synchronization successful!")


 def example_shared_memory_sync():
@@ -163,7 +163,7 @@ def example_shared_memory_sync():
     weights_td["weight"].fill_(2.0)
     weights_td["bias"].fill_(1.0)

-    print(f"Main: Sending weights via shared memory...")
+    print("Main: Sending weights via shared memory...")
     sender.update_weights(weights_td)

     # Workers automatically see updates via shared memory!
@@ -179,7 +179,7 @@ def example_shared_memory_sync():
     print(
         f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}"
     )
-    print(f"✓ Shared memory synchronization successful!")
+    print("✓ Shared memory synchronization successful!")


 def main():
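The shared-memory hunks above depend on the fact that a TensorDict whose tensors live in shared memory can be written in place, and every process mapping those buffers observes the new values, which is why the example's comment says workers see the update automatically. A minimal, self-contained sketch of just that mechanism with plain tensordict (no scheme or sender API; how sender.update_weights fills such buffers internally is not shown in this diff and not assumed here):

import torch.nn as nn
from tensordict import TensorDict

# Build a weights TensorDict for a small model and move it to shared memory.
model = nn.Linear(4, 4)
weights_td = TensorDict.from_module(model).data.clone()
weights_td.share_memory_()

# A process that was handed `weights_td` (e.g. a spawned worker) maps the same
# buffers, so an in-place update in the main process...
weights_td["weight"].fill_(2.0)
weights_td["bias"].fill_(1.0)

# ...is the value every such process now reads, with no extra message passing.
assert (weights_td["weight"] == 2.0).all() and (weights_td["bias"] == 1.0).all()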
