Skip to content

Commit e2eb32b

Browse files
committed
Update
[ghstack-poisoned]
2 parents acc0c84 + 7b704d3 commit e2eb32b

File tree

3 files changed

+53
-159
lines changed

3 files changed

+53
-159
lines changed

docs/source/reference/collectors.rst

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ across multiple inference workers:
233233
from torchrl.weight_update import (
234234
MultiProcessWeightSyncScheme,
235235
SharedMemWeightSyncScheme,
236-
NoWeightSyncScheme,
237236
)
238237
239238
# Create environment and policy
@@ -292,47 +291,6 @@ across multiple inference workers:
292291
293292
collector.shutdown()
294293
295-
# Example 3: Multiple models (policy + value network)
296-
# ---------------------------------------------------
297-
value_net = TensorDictModule(
298-
nn.Linear(env.observation_spec["observation"].shape[-1], 1),
299-
in_keys=["observation"],
300-
out_keys=["value"],
301-
)
302-
303-
weight_sync_schemes = {
304-
"policy": MultiProcessWeightSyncScheme(strategy="state_dict"),
305-
"value": MultiProcessWeightSyncScheme(strategy="state_dict"),
306-
}
307-
308-
collector = SyncDataCollector(
309-
create_env_fn=lambda: GymEnv("CartPole-v1"),
310-
policy=policy,
311-
frames_per_batch=64,
312-
total_frames=1000,
313-
weight_sync_schemes=weight_sync_schemes,
314-
)
315-
316-
# Update multiple models independently
317-
collector.update_policy_weights_(
318-
{"policy": policy.state_dict(), "value": value_net.state_dict()}
319-
)
320-
321-
collector.shutdown()
322-
323-
# Example 4: Disable weight synchronization
324-
# ------------------------------------------
325-
# Useful for debugging or when using a shared policy reference
326-
no_sync_scheme = NoWeightSyncScheme()
327-
328-
collector = SyncDataCollector(
329-
create_env_fn=lambda: GymEnv("CartPole-v1"),
330-
policy=policy,
331-
frames_per_batch=64,
332-
total_frames=1000,
333-
weight_sync_schemes={"policy": no_sync_scheme},
334-
)
335-
336294
.. note::
337295
When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all
338296
processes share the same memory buffers. This is ideal for frequent weight updates but requires all

examples/collectors/weight_sync_collectors.py

Lines changed: 2 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
single collectors, multiple collectors, multiple models, and no synchronization.
1515
"""
1616

17-
import torch
1817
import torch.nn as nn
1918
from tensordict import TensorDict
2019
from tensordict.nn import TensorDictModule
@@ -23,7 +22,6 @@
2322
from torchrl.weight_update import (
2423
MultiProcessWeightSyncScheme,
2524
SharedMemWeightSyncScheme,
26-
NoWeightSyncScheme,
2725
)
2826

2927

@@ -66,7 +64,7 @@ def example_single_collector_multiprocess():
6664
if i % 2 == 0:
6765
new_weights = policy.state_dict()
6866
collector.update_policy_weights_(new_weights)
69-
print(f" → Updated policy weights")
67+
print(" → Updated policy weights")
7068

7169
if i >= 2: # Just run a few iterations for demo
7270
break
@@ -116,7 +114,7 @@ def example_multi_collector_shared_memory():
116114

117115
# Update weights frequently (shared memory makes this very fast)
118116
collector.update_policy_weights_(TensorDict.from_module(policy))
119-
print(f" → Updated policy weights via shared memory")
117+
print(" → Updated policy weights via shared memory")
120118

121119
if i >= 1: # Just run a couple iterations for demo
122120
break
@@ -125,115 +123,6 @@ def example_multi_collector_shared_memory():
125123
print("✓ Multi-collector with shared memory example completed!\n")
126124

127125

128-
def example_multiple_models():
129-
"""Example 3: Multiple models (policy + value network)."""
130-
print("\n" + "="*70)
131-
print("Example 3: Multiple Models (Policy + Value Network)")
132-
print("="*70)
133-
134-
# Create environment
135-
env = GymEnv("CartPole-v1")
136-
137-
# Create policy and value network
138-
policy = TensorDictModule(
139-
nn.Linear(
140-
env.observation_spec["observation"].shape[-1],
141-
env.action_spec.shape[-1]
142-
),
143-
in_keys=["observation"],
144-
out_keys=["action"],
145-
)
146-
147-
value_net = TensorDictModule(
148-
nn.Linear(
149-
env.observation_spec["observation"].shape[-1],
150-
1
151-
),
152-
in_keys=["observation"],
153-
out_keys=["value"],
154-
)
155-
env.close()
156-
157-
# Create separate schemes for each model
158-
weight_sync_schemes = {
159-
"policy": MultiProcessWeightSyncScheme(strategy="state_dict"),
160-
"value": MultiProcessWeightSyncScheme(strategy="state_dict"),
161-
}
162-
163-
print("Creating collector with multiple models...")
164-
collector = SyncDataCollector(
165-
create_env_fn=lambda: GymEnv("CartPole-v1"),
166-
policy=policy,
167-
frames_per_batch=64,
168-
total_frames=200,
169-
weight_sync_schemes=weight_sync_schemes,
170-
)
171-
172-
print("Collecting data...")
173-
for i, data in enumerate(collector):
174-
print(f"Iteration {i}: Collected {data.numel()} transitions")
175-
176-
# Update both models independently
177-
collector.update_policy_weights_(
178-
{
179-
"policy": policy.state_dict(),
180-
"value": value_net.state_dict()
181-
}
182-
)
183-
print(f" → Updated both policy and value network weights")
184-
185-
if i >= 1:
186-
break
187-
188-
collector.shutdown()
189-
print("✓ Multiple models example completed!\n")
190-
191-
192-
def example_no_weight_sync():
193-
"""Example 4: Disable weight synchronization."""
194-
print("\n" + "="*70)
195-
print("Example 4: Disable Weight Synchronization")
196-
print("="*70)
197-
198-
# Create environment and policy
199-
env = GymEnv("CartPole-v1")
200-
policy = TensorDictModule(
201-
nn.Linear(
202-
env.observation_spec["observation"].shape[-1],
203-
env.action_spec.shape[-1]
204-
),
205-
in_keys=["observation"],
206-
out_keys=["action"],
207-
)
208-
env.close()
209-
210-
# Useful for debugging or when using a shared policy reference
211-
scheme = NoWeightSyncScheme()
212-
213-
print("Creating collector with no weight synchronization...")
214-
collector = SyncDataCollector(
215-
create_env_fn=lambda: GymEnv("CartPole-v1"),
216-
policy=policy,
217-
frames_per_batch=64,
218-
total_frames=200,
219-
weight_sync_schemes={"policy": scheme},
220-
)
221-
222-
print("Collecting data (no weight updates)...")
223-
for i, data in enumerate(collector):
224-
print(f"Iteration {i}: Collected {data.numel()} transitions")
225-
226-
# Weight updates are no-ops with NoWeightSyncScheme
227-
collector.update_policy_weights_(policy.state_dict())
228-
print(f" → Weight update call was a no-op")
229-
230-
if i >= 1:
231-
break
232-
233-
collector.shutdown()
234-
print("✓ No weight sync example completed!\n")
235-
236-
237126
def main():
238127
"""Run all examples."""
239128
print("\n" + "="*70)
@@ -250,17 +139,13 @@ def main():
250139
# Run examples
251140
example_single_collector_multiprocess()
252141
example_multi_collector_shared_memory()
253-
example_multiple_models()
254-
example_no_weight_sync()
255142

256143
print("\n" + "="*70)
257144
print("All examples completed successfully!")
258145
print("="*70)
259146
print("\nKey takeaways:")
260147
print(" • MultiProcessWeightSyncScheme: Good for general multiprocess scenarios")
261148
print(" • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers")
262-
print(" • Multiple models: Each model can have its own sync scheme")
263-
print(" • NoWeightSyncScheme: Useful for debugging or shared policy references")
264149
print("="*70 + "\n")
265150

266151

versions.html

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
<html>
2+
<head>
3+
<meta charset="utf-8">
4+
5+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
6+
<link rel="stylesheet" href="main/_static/css/theme.css" type="text/css" />
7+
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Lato" type="text/css" />
8+
<link rel="stylesheet" href="main/_static/css/pytorch_theme.css" type="text/css" />
9+
<script src="main/_static/js/modernizr.min.js"></script>
10+
11+
12+
</head>
13+
<body>
14+
<div class="wy-nav-content">
15+
<div class="rst-content">
16+
<h1> PyTorch Documentation </h1>
17+
<div class="toctree-wrapper compound">
18+
<p class="caption"><span class="caption-text">Pick a version</span></p>
19+
<ul>
20+
<li class="toctree-l1">
21+
<a class="reference internal" href="main/">main (unstable)</a>
22+
</li>
23+
<li class="toctree-l1">
24+
<a class="reference internal" href="0.10/">v0.10 (stable release)</a>
25+
</li>
26+
<li class="toctree-l1">
27+
<a class="reference internal" href="0.9/">v0.9</a>
28+
</li>
29+
<li class="toctree-l1">
30+
<a class="reference internal" href="0.8/">v0.8</a>
31+
</li>
32+
<li class="toctree-l1">
33+
<a class="reference internal" href="0.7/">v0.7</a>
34+
</li>
35+
<li class="toctree-l1">
36+
<a class="reference internal" href="0.6/">v0.6</a>
37+
</li>
38+
<li class="toctree-l1">
39+
<a class="reference internal" href="0.5/">v0.5</a>
40+
</li>
41+
<li class="toctree-l1">
42+
<a class="reference internal" href="0.4/">v0.4</a>
43+
</li>
44+
</ul>
45+
<p>You can view previous versions of the torchrl documentation
46+
<a href="https://pytorch.org/rl/versions.html">here</a>.</p>
47+
48+
</div>
49+
</div></div>
50+
</body>
51+
</html>

0 commit comments

Comments
 (0)