docs/source/reference/collectors.rst

Lines changed: 291 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,299 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
117117

118118
Policy copy decision tree in Collectors.
119119

120-
Weight Synchronization in Distributed Environments
120+
Weight Synchronization using Weight Update Schemes
121121
--------------------------------------------------
122122

123+
RL pipelines are typically split into two big computational buckets: training and inference.
While the inference pipeline sends data to the training one, the training pipeline needs to occasionally
synchronize its weights with the inference one.
In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
used in both instances. From there, anything can happen:

- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
  `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
  for its instance of the policy.
- In some cases, the environment or the postprocessing hooks can rely on the usage of a model which itself needs
  synchronization. This means that there can be multiple ends in the data transfer API and one needs to think beyond
  policy-to-policy weight synchronization strategies.
- In the LLM world, the inference engine and the training one are very different: they will use different libraries,
  kernels and calling APIs (e.g., `generate` vs. `forward`). The weight format can also be drastically different
  (quantized vs. non-quantized).
  This makes the weight synchronization much more complex, as one cannot simply dump and load a state dict on both ends.
- One typically also has to choose who initiates a transfer: should the inference engine actively ask for new weights,
  or should the trainer push its weights to the workers? An intermediate approach is to store the weights on an
  intermediary server and let the workers fetch them when necessary.

TorchRL tries to account for each of these problems in a flexible manner. We identify four basic components in a weight
transfer:

- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer;
- A `Receiver` class that casts the weights to the destination module (policy or other utility module);
- A `Transport` class that implements the actual transfer of the weights (through shared memory, NCCL or anything else);
- A `Scheme` that defines which sender, receiver and transport have to be used and how to initialize them.

Each of these classes is detailed below.

Usage Examples
~~~~~~~~~~~~~~

.. note::
    **Runnable versions** of these examples are available in the repository:

    - `examples/collectors/weight_sync_standalone.py <https://github.com/pytorch/rl/blob/main/examples/collectors/weight_sync_standalone.py>`_: Standalone weight synchronization
    - `examples/collectors/weight_sync_collectors.py <https://github.com/pytorch/rl/blob/main/examples/collectors/weight_sync_collectors.py>`_: Collector integration

Using Weight Update Schemes Independently
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Weight update schemes can be used outside of collectors for custom synchronization scenarios. Here's a basic example:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch import multiprocessing as mp
    from tensordict import TensorDict
    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    # Create a simple policy
    policy = nn.Linear(4, 2)

    # Example 1: Multiprocess weight synchronization with state_dict
    # --------------------------------------------------------------
    # On the main process side (trainer):
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    sender = scheme.create_sender()

    # Register worker pipes
    parent_pipe, child_pipe = mp.Pipe()
    sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)

    # Send weights to workers
    weights = policy.state_dict()
    sender.update_weights(weights)

    # On the worker process side:
    # receiver = scheme.create_receiver()
    # receiver.register_model(policy)
    # receiver.register_worker_transport(child_pipe)
    # # Receive and apply weights
    # result = receiver._transport.receive_weights(timeout=5.0)
    # if result is not None:
    #     model_id, weights = result
    #     receiver.apply_weights(weights)

    # Example 2: Shared memory weight synchronization
    # ------------------------------------------------
    # Create shared memory scheme with auto-registration
    shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
    shared_sender = shared_scheme.create_sender()

    # Register worker pipes for lazy registration
    parent_pipe2, child_pipe2 = mp.Pipe()
    shared_sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe2)

    # Send weights (automatically creates shared buffer on first send)
    weights_td = TensorDict.from_module(policy)
    shared_sender.update_weights(weights_td)

    # Workers automatically see updates via shared memory!

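The receiver side above is only shown as comments. Putting those calls together, a minimal sketch of how they
might be hosted in a worker process is given below; the ``run_worker`` function and the process wiring are
illustrative (they are not part of the TorchRL API) and assume a matching scheme can be constructed on the
worker side.

.. code-block:: python

    import torch.nn as nn
    from torch import multiprocessing as mp
    from torchrl.weight_update import MultiProcessWeightSyncScheme


    def run_worker(child_pipe):
        # Each worker holds its own copy of the policy
        policy = nn.Linear(4, 2)

        # Build a receiver from a scheme matching the trainer's
        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
        receiver = scheme.create_receiver()
        receiver.register_model(policy)
        receiver.register_worker_transport(child_pipe)

        # Block until the trainer pushes weights, then apply them to the policy
        result = receiver._transport.receive_weights(timeout=10.0)
        if result is not None:
            model_id, weights = result
            receiver.apply_weights(weights)


    if __name__ == "__main__":
        parent_pipe, child_pipe = mp.Pipe()
        worker = mp.Process(target=run_worker, args=(child_pipe,))
        worker.start()

        # Trainer side, as in the example above
        policy = nn.Linear(4, 2)
        scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
        sender = scheme.create_sender()
        sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)
        sender.update_weights(policy.state_dict())

        worker.join()
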
Using Weight Update Schemes with Collectors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization
across multiple inference workers:

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
        NoWeightSyncScheme,
    )

    # Create environment and policy
    env = GymEnv("CartPole-v1")
    policy = TensorDictModule(
        nn.Linear(env.observation_spec["observation"].shape[-1],
                  env.action_spec.shape[-1]),
        in_keys=["observation"],
        out_keys=["action"],
    )

    # Example 1: Single collector with multiprocess scheme
    # -----------------------------------------------------
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")

    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("CartPole-v1"),
        policy=policy,
        frames_per_batch=64,
        total_frames=1000,
        weight_sync_schemes={"policy": scheme},
    )

    # Collect data and update weights periodically
    for i, data in enumerate(collector):
        # ... training step with data ...

        # Update policy weights every N iterations
        if i % 10 == 0:
            new_weights = policy.state_dict()
            collector.update_policy_weights_(new_weights)

    collector.shutdown()

    # Example 2: Multiple collectors with shared memory
    # --------------------------------------------------
    # Shared memory is more efficient for frequent updates
    shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)

    collector = MultiSyncDataCollector(
        create_env_fn=[
            lambda: GymEnv("CartPole-v1"),
            lambda: GymEnv("CartPole-v1"),
            lambda: GymEnv("CartPole-v1"),
        ],
        policy=policy,
        frames_per_batch=192,
        total_frames=10000,
        weight_sync_schemes={"policy": shared_scheme},
    )

    # Workers automatically see weight updates via shared memory
    for data in collector:
        # ... training ...
        collector.update_policy_weights_(TensorDict.from_module(policy))

    collector.shutdown()

    # Example 3: Multiple models (policy + value network)
    # ---------------------------------------------------
    value_net = TensorDictModule(
        nn.Linear(env.observation_spec["observation"].shape[-1], 1),
        in_keys=["observation"],
        out_keys=["value"],
    )

    weight_sync_schemes = {
        "policy": MultiProcessWeightSyncScheme(strategy="state_dict"),
        "value": MultiProcessWeightSyncScheme(strategy="state_dict"),
    }

    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("CartPole-v1"),
        policy=policy,
        frames_per_batch=64,
        total_frames=1000,
        weight_sync_schemes=weight_sync_schemes,
    )

    # Update multiple models independently
    collector.update_policy_weights_(
        {"policy": policy.state_dict(), "value": value_net.state_dict()}
    )

    collector.shutdown()

    # Example 4: Disable weight synchronization
    # ------------------------------------------
    # Useful for debugging or when using a shared policy reference
    no_sync_scheme = NoWeightSyncScheme()

    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("CartPole-v1"),
        policy=policy,
        frames_per_batch=64,
        total_frames=1000,
        weight_sync_schemes={"policy": no_sync_scheme},
    )

.. note::
    When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all
    processes share the same memory buffers. This is ideal for frequent weight updates but requires all
    processes to be on the same machine.

.. note::
    The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state
    dictionaries, while ``"tensordict"`` uses the TensorDict format, which can be more efficient for structured
    models and supports advanced features like lazy initialization.

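For reference, here is a minimal sketch of what the two formats look like when extracted from a module; both
constructs already appear in the examples above and are shown here side by side.

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict

    policy = nn.Linear(4, 2)

    # strategy="state_dict": weights travel as a plain PyTorch state dict
    weights_sd = policy.state_dict()

    # strategy="tensordict": weights travel as a TensorDict mirroring the module structure
    weights_td = TensorDict.from_module(policy)
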
Weight Senders
~~~~~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSender
    RayModuleTransformSender

Weight Receivers
~~~~~~~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightReceiver
    RayModuleTransformReceiver

Transports
~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    TransportBackend
    MPTransport
    SharedMemTransport
    RayTransport
    RayActorTransport
    RPCTransport
    DistributedTransport

Schemes
~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSyncScheme
    MultiProcessWeightSyncScheme
    SharedMemWeightSyncScheme
    NoWeightSyncScheme
    RayWeightSyncScheme
    RayModuleTransformScheme
    RPCWeightSyncScheme
    DistributedWeightSyncScheme

Legacy: Weight Synchronization in Distributed Environments
----------------------------------------------------------

.. warning::
    The `WeightUpdater` is considered legacy as of the 0.11 release and will be deprecated soon.
    The weight update schemes, which provide more flexibility and better compatibility with heavy
    weight transfers (e.g., LLMs), are to be preferred.

In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.
