-single collectors, multiple collectors, multiple models, and no synchronization.
+single collectors and multiple collectors.
1515"""
1616
17- import torch
1817import torch .nn as nn
1918from tensordict import TensorDict
2019from tensordict .nn import TensorDictModule
2322from torchrl .weight_update import (
2423 MultiProcessWeightSyncScheme ,
2524 SharedMemWeightSyncScheme ,
26- NoWeightSyncScheme ,
2725)
2826
2927
@@ -66,7 +64,7 @@ def example_single_collector_multiprocess():
         if i % 2 == 0:
             new_weights = policy.state_dict()
             collector.update_policy_weights_(new_weights)
-            print(f" → Updated policy weights")
+            print(" → Updated policy weights")

         if i >= 2:  # Just run a few iterations for demo
             break
@@ -116,7 +114,7 @@ def example_multi_collector_shared_memory():

         # Update weights frequently (shared memory makes this very fast)
         collector.update_policy_weights_(TensorDict.from_module(policy))
-        print(f" → Updated policy weights via shared memory")
+        print(" → Updated policy weights via shared memory")

         if i >= 1:  # Just run a couple iterations for demo
             break
@@ -125,115 +123,6 @@ def example_multi_collector_shared_memory():
     print("✓ Multi-collector with shared memory example completed!\n")


-def example_multiple_models():
-    """Example 3: Multiple models (policy + value network)."""
-    print("\n" + "=" * 70)
-    print("Example 3: Multiple Models (Policy + Value Network)")
-    print("=" * 70)
-
-    # Create environment
-    env = GymEnv("CartPole-v1")
-
-    # Create policy and value network
-    policy = TensorDictModule(
-        nn.Linear(
-            env.observation_spec["observation"].shape[-1],
-            env.action_spec.shape[-1]
-        ),
-        in_keys=["observation"],
-        out_keys=["action"],
-    )
-
-    value_net = TensorDictModule(
-        nn.Linear(
-            env.observation_spec["observation"].shape[-1],
-            1
-        ),
-        in_keys=["observation"],
-        out_keys=["value"],
-    )
-    env.close()
-
-    # Create separate schemes for each model
-    weight_sync_schemes = {
-        "policy": MultiProcessWeightSyncScheme(strategy="state_dict"),
-        "value": MultiProcessWeightSyncScheme(strategy="state_dict"),
-    }
-
-    print("Creating collector with multiple models...")
-    collector = SyncDataCollector(
-        create_env_fn=lambda: GymEnv("CartPole-v1"),
-        policy=policy,
-        frames_per_batch=64,
-        total_frames=200,
-        weight_sync_schemes=weight_sync_schemes,
-    )
-
-    print("Collecting data...")
-    for i, data in enumerate(collector):
-        print(f"Iteration {i}: Collected {data.numel()} transitions")
-
-        # Update both models independently
-        collector.update_policy_weights_(
-            {
-                "policy": policy.state_dict(),
-                "value": value_net.state_dict()
-            }
-        )
-        print(f" → Updated both policy and value network weights")
-
-        if i >= 1:
-            break
-
-    collector.shutdown()
-    print("✓ Multiple models example completed!\n")
-
-
-def example_no_weight_sync():
-    """Example 4: Disable weight synchronization."""
-    print("\n" + "=" * 70)
-    print("Example 4: Disable Weight Synchronization")
-    print("=" * 70)
-
-    # Create environment and policy
-    env = GymEnv("CartPole-v1")
-    policy = TensorDictModule(
-        nn.Linear(
-            env.observation_spec["observation"].shape[-1],
-            env.action_spec.shape[-1]
-        ),
-        in_keys=["observation"],
-        out_keys=["action"],
-    )
-    env.close()
-
-    # Useful for debugging or when using a shared policy reference
-    scheme = NoWeightSyncScheme()
-
-    print("Creating collector with no weight synchronization...")
-    collector = SyncDataCollector(
-        create_env_fn=lambda: GymEnv("CartPole-v1"),
-        policy=policy,
-        frames_per_batch=64,
-        total_frames=200,
-        weight_sync_schemes={"policy": scheme},
-    )
-
-    print("Collecting data (no weight updates)...")
-    for i, data in enumerate(collector):
-        print(f"Iteration {i}: Collected {data.numel()} transitions")
-
-        # Weight updates are no-ops with NoWeightSyncScheme
-        collector.update_policy_weights_(policy.state_dict())
-        print(f" → Weight update call was a no-op")
-
-        if i >= 1:
-            break
-
-    collector.shutdown()
-    print("✓ No weight sync example completed!\n")
-
-
 def main():
     """Run all examples."""
     print("\n" + "=" * 70)
@@ -250,17 +139,13 @@ def main():
     # Run examples
     example_single_collector_multiprocess()
     example_multi_collector_shared_memory()
-    example_multiple_models()
-    example_no_weight_sync()

     print("\n" + "=" * 70)
     print("All examples completed successfully!")
     print("=" * 70)
     print("\nKey takeaways:")
     print(" • MultiProcessWeightSyncScheme: Good for general multiprocess scenarios")
     print(" • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers")
-    print(" • Multiple models: Each model can have its own sync scheme")
-    print(" • NoWeightSyncScheme: Useful for debugging or shared policy references")
     print("=" * 70 + "\n")

