diff --git a/docs/source/reference/llms.rst b/docs/source/reference/llms.rst index 98edf335704..327303fe338 100644 --- a/docs/source/reference/llms.rst +++ b/docs/source/reference/llms.rst @@ -633,7 +633,7 @@ Collectors .. _Collectors: TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`) -that are tailored for LLM use cases. We also provide dedicated updaters for some inference engines. +that are tailored for LLM use cases. We also provide weight synchronization schemes for vLLM inference engines. See :ref:`ref_collectors` for more details on the collector API. In brief, the idea of a collector is to isolate the inference part of the pipeline in a dedicated class. @@ -649,8 +649,126 @@ Collectors are defined by the following parameters and features: In other cases, the collector can be iterated over to collect data. - **Steps**: A collector is built with a certain number of steps budget, as well as a number of steps to be included in each batch yield during collection. -- **Weight Updater**: Weight updaters are the classes that update the policy weights. Isolating the weight update - in a dedicated class allows to easily implement different weight update strategies depending on the policy specification. +- **Weight Synchronization Schemes**: Weight sync schemes handle the synchronization of weights between the training model + and the inference engine. The new scheme-based approach provides flexible, high-performance weight updates for vLLM and + other inference backends. + +vLLM Weight Synchronization Schemes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +TorchRL provides two weight synchronization schemes for vLLM engines, offering different trade-offs between +performance and simplicity: + +**1. NCCL-Based Synchronization** (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme`) + +Uses NCCL collectives for high-bandwidth GPU-to-GPU weight transfers. Best for: + +- High-frequency weight updates +- Large models where transfer speed is critical +- Setups with GPU interconnect (NVLink, InfiniBand) + +**2. Double-Buffer Synchronization** (:class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`) + +Uses memory-mapped file storage for asynchronous weight transfers. Best for: + +- Simpler setup without NCCL coordination +- Distributed setups with shared filesystems (NFS) +- Cases where update frequency is lower + +**Usage Example with NCCL:** + +.. 
code-block:: python + + from torchrl.collectors.llm import RayLLMCollector + from torchrl.weight_update.llm import VLLMWeightSyncScheme + from torchrl.modules.llm import AsyncVLLM, vLLMWrapper + + # Create vLLM engine + vllm_engine = AsyncVLLM.from_pretrained( + "Qwen/Qwen2.5-7B", + num_devices=2, + num_replicas=2, + ) + policy = vLLMWrapper(vllm_engine, input_mode="history") + + # Create NCCL weight sync scheme + weight_sync_scheme = VLLMWeightSyncScheme( + master_address="localhost", + master_port=29500, + gpus_per_replica=2, # tp_size × dp_size × pp_size + num_replicas=2, + strategy="state_dict" + ) + + # Create collector with weight sync scheme + collector = RayLLMCollector( + env=make_env, + policy=policy, + dialog_turns_per_batch=256, + total_dialog_turns=10000, + weight_sync_schemes={"policy": weight_sync_scheme}, + track_policy_version=True, + ) + + # During training, get the sender and update weights + sender = collector._weight_senders["policy"] + sender.register_model(training_model) + + # Initialize collective group (must be called before first update) + metadata = get_model_metadata(training_model) + sender.init_all_workers_group(metadata, vllm_engine=vllm_engine) + + # Update weights during training + for i, data in enumerate(collector): + # ... training step ... + if i % 10 == 0: + sender.update_weights() # Broadcasts via NCCL + +**Usage Example with Double-Buffer:** + +.. code-block:: python + + from torchrl.collectors.llm import RayLLMCollector + from torchrl.weight_update.llm import VLLMDoubleBufferSyncScheme + from torchrl.modules.llm import AsyncVLLM, vLLMWrapper + + # Create vLLM engine + vllm_engine = AsyncVLLM.from_pretrained( + "Qwen/Qwen2.5-7B", + num_devices=2, + num_replicas=1, + ) + policy = vLLMWrapper(vllm_engine, input_mode="history") + + # Create double-buffer weight sync scheme + weight_sync_scheme = VLLMDoubleBufferSyncScheme( + remote_addr="/tmp/weights", # Or "/mnt/shared/weights" for NFS + num_threads=128, + strategy="state_dict" + ) + + # Create collector with weight sync scheme + collector = RayLLMCollector( + env=make_env, + policy=policy, + dialog_turns_per_batch=256, + total_dialog_turns=10000, + weight_sync_schemes={"policy": weight_sync_scheme}, + track_policy_version=True, + ) + + # During training, get the sender and receiver + sender = collector._weight_senders["policy"] + sender.register_model(training_model) + + # No initialization needed for double-buffer scheme! + + # Update weights during training + for i, data in enumerate(collector): + # ... training step ... + if i % 10 == 0: + sender.update_weights() # Writes to shared storage + # vLLM workers can poll and apply: receiver.poll_and_apply() Policy Version Tracking ~~~~~~~~~~~~~~~~~~~~~~~ @@ -662,19 +780,52 @@ transform, or a boolean to the collector constructor. >>> from torchrl.envs.llm.transforms import PolicyVersion >>> from torchrl.collectors.llm import LLMCollector - >>> from torchrl.collectors.llm.weight_update import vLLMUpdater + >>> from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata >>> env = make_env() # place your code here >>> policy = make_policy() # place your code here - >>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True) - >>> # init the updater - >>> collector.weight_updater.init(...) - >>> # the version is incremented after each weight update - >>> collector.update_policy_weights_(state_dict=...) 
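+    >>> # Create the NCCL weight sync scheme (single replica, one GPU per replica in this example)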
+ >>> scheme = VLLMWeightSyncScheme(master_port=29500, gpus_per_replica=1, num_replicas=1) + >>> collector = LLMCollector(env, policy=policy, weight_sync_schemes={"policy": scheme}, track_policy_version=True) + >>> # Get the sender and register model + >>> sender = collector._weight_senders["policy"] + >>> sender.register_model(training_model) + >>> # Initialize the collective group + >>> metadata = get_model_metadata(training_model) + >>> sender.init_all_workers_group(metadata, vllm_engine=policy.model) + >>> # Update weights + >>> sender.update_weights() >>> print(collector.policy_version_tracker.version) >>> # the policy version is written in the data >>> for data in collector: ... print(data["policy_version"]) +.. currentmodule:: torchrl.weight_update.llm + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + VLLMWeightSyncScheme + VLLMWeightSender + VLLMWeightReceiver + VLLMCollectiveTransport + VLLMDoubleBufferSyncScheme + VLLMDoubleBufferWeightSender + VLLMDoubleBufferWeightReceiver + VLLMDoubleBufferTransport + get_model_metadata + +Legacy Weight Updaters (Deprecated) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. deprecated:: 0.11 + The `vLLMUpdater` and `vLLMUpdaterV2` classes are deprecated in favor of the new weight synchronization schemes + (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme` and :class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`). + These schemes provide better performance, more flexibility, and cleaner integration with collectors. + The legacy updaters will be removed in a future release. + + The legacy weight updaters (`vLLMUpdater` and `vLLMUpdaterV2`) are still available but are no longer recommended. + Please migrate to the new weight synchronization schemes shown above. + .. currentmodule:: torchrl.collectors.llm .. autosummary:: diff --git a/examples/collectors/multi_weight_updates.py b/examples/collectors/multi_weight_updates.py new file mode 100644 index 00000000000..7011e7f4879 --- /dev/null +++ b/examples/collectors/multi_weight_updates.py @@ -0,0 +1,115 @@ +"""Example of updating weights of several models at once in a multiprocessed data collector. + +This example demonstrates: +1. Using different weight sync schemes for different models +2. Updating the policy (via pipes with MultiProcessWeightSyncScheme) +3. Updating Ray-based transforms in env and replay buffer (via RayModuleTransformScheme) +4. Atomic multi-model weight updates using weights_dict + +Note: +- Ray actors are shared across all workers, so RayModuleTransformScheme uses a + single transport rather than per-worker pipes. +- When using transform_factory with a replay buffer, delayed_init automatically defaults + to True for proper serialization in multiprocessing contexts. +- extend_buffer defaults to True in all collectors, extending the buffer with entire + rollouts rather than individual frames for better compatibility with postprocessing. 
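+
+The key call demonstrated in ``main()`` below is the atomic multi-model update, which routes
+each entry of ``weights_dict`` to the weight sync scheme registered under the same model id:
+
+    collector.update_policy_weights_(
+        weights_dict={
+            "policy": policy_weights,
+            "env.transform[0].module": env_module_weights,
+            "replay_buffer.transform[0].module": rb_module_weights,
+        }
+    )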
+""" + +from functools import partial + +import torch.nn as nn +from tensordict import TensorDict +from tensordict.nn import TensorDictModule + +from torchrl.collectors import MultiSyncDataCollector +from torchrl.data import LazyTensorStorage, ReplayBuffer +from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.transforms.module import ModuleTransform +from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme + + +def make_module(): + # A module that transforms the observations + return TensorDictModule( + nn.Linear(3, 3), in_keys=["observation"], out_keys=["observation"] + ) + + +def policy_factory(): + # A module that produces the actions + return TensorDictModule( + nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"] + ) + + +def make_env(): + env_module = ModuleTransform( + module_factory=make_module, inverse=False, no_grad=True + ) + return GymEnv("Pendulum-v1").append_transform(env_module) + + +def main(): + rb = ReplayBuffer( + storage=LazyTensorStorage(10000, shared_init=True), + transform_factory=partial( + ModuleTransform, + module_factory=make_module, + inverse=True, + no_grad=True, + ), + # delayed_init automatically defaults to True when transform_factory is provided + ) + + policy = policy_factory() + + weight_sync_schemes = { + "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), + "replay_buffer.transform[0].module": MultiProcessWeightSyncScheme( + strategy="tensordict" + ), + "env.transform[0].module": MultiProcessWeightSyncScheme(strategy="tensordict"), + } + + collector = MultiSyncDataCollector( + create_env_fn=[make_env, make_env], + policy_factory=policy_factory, + total_frames=2000, + max_frames_per_traj=50, + frames_per_batch=200, + init_random_frames=-1, + device="cpu", + storing_device="cpu", + weight_sync_schemes=weight_sync_schemes, + replay_buffer=rb, + local_init_rb=True, + # extend_buffer=True is the default for MultiSyncDataCollector + ) + + policy_weights = TensorDict.from_module(policy).data + env_module_weights = TensorDict.from_module(make_module()).data + rb_module_weights = TensorDict.from_module(make_module()).data + + for i, _data in enumerate(collector): + env_module_weights.zero_() + rb_module_weights.zero_() + policy_weights.zero_() + + collector.update_policy_weights_( + weights_dict={ + "policy": policy_weights, + "env.transform[0].module": env_module_weights, + "replay_buffer.transform[0].module": rb_module_weights, + } + ) + + assert len(rb) == i * 200 + 200 + + if i >= 10: + break + + collector.shutdown() + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/expert-iteration/ei_utils.py b/sota-implementations/expert-iteration/ei_utils.py index 552776fa6b7..c6732c1763e 100644 --- a/sota-implementations/expert-iteration/ei_utils.py +++ b/sota-implementations/expert-iteration/ei_utils.py @@ -15,10 +15,10 @@ from torch import device as torch_device, dtype as torch_dtype from torchrl._utils import logger as torchrl_logger -from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater from torchrl.envs.llm import RetrieveLogProb from torchrl.envs.llm.datasets.ifeval import IFEvalEnv from torchrl.modules.llm import TransformersWrapper, vLLMWrapper +from torchrl.weight_update.llm import VLLMWeightSyncScheme from transformers.models.auto.modeling_auto import AutoModelForCausalLM from transformers.tokenization_utils import PreTrainedTokenizer @@ -479,42 +479,40 @@ def get_hf_model( torch.set_default_dtype(original_dtype) -def make_weight_updater( - policy_training=None, +def 
make_weight_sync_scheme( master_address=None, master_port=None, - model_metadata=None, - vllm_tp_size=None, -) -> vLLMUpdater: - """Creates a vLLM weight updater for the policy. + vllm_tp_size=1, +) -> VLLMWeightSyncScheme: + """Creates a vLLM weight synchronization scheme using NCCL collectives. - This function can be used in two ways: - 1. Synchronous mode (expert-iteration-sync.py): Pass policy_training to get an initialized updater with metadata - 2. Async mode (expert-iteration-async.py): Pass master_address, master_port, model_metadata, and remote_actor + This function creates a weight sync scheme that uses NCCL for high-performance + GPU-to-GPU weight transfers from the training model to vLLM inference workers. Args: - policy_training (Optional[TransformersWrapper]): The training policy model. Required for sync mode. - master_address (Optional[str]): Ray master address for async mode. - master_port (Optional[int]): Ray master port for async mode. - model_metadata (Optional[dict]): Model metadata for async mode. If not provided but policy_training is, - it will be extracted from the policy. - vllm_tp_size (Optional[int]): vLLM tensor parallel size. If not provided, will be set to 1. + master_address (Optional[str]): Address of the master node for distributed init. + Defaults to "localhost". + master_port (Optional[int]): Port of the master node for distributed init. + If None, will auto-assign. + vllm_tp_size (int): vLLM tensor parallel size (gpus_per_replica). Defaults to 1. Returns: - vLLMUpdater: An instance of the weight updater configured to update - the vLLM worker's weights. + VLLMWeightSyncScheme: A weight sync scheme configured for the vLLM engine. """ - if model_metadata is None and policy_training is not None: - # Extract metadata from training policy - model_metadata = { - k: (v.dtype, v.shape) for k, v in policy_training.model.state_dict().items() - } + if master_address is None: + master_address = "localhost" + + torchrl_logger.info( + f"Creating VLLMWeightSyncScheme with tp_size={vllm_tp_size}, " + f"master_address={master_address}, master_port={master_port}" + ) - return vLLMUpdater( + return VLLMWeightSyncScheme( master_address=master_address, master_port=master_port, - model_metadata=model_metadata, - vllm_tp_size=vllm_tp_size, + gpus_per_replica=vllm_tp_size, + num_replicas=1, # For expert iteration, typically 1 replica + strategy="state_dict", ) diff --git a/sota-implementations/expert-iteration/expert-iteration-async.py b/sota-implementations/expert-iteration/expert-iteration-async.py index 89e93223b75..75f8d39462d 100644 --- a/sota-implementations/expert-iteration/expert-iteration-async.py +++ b/sota-implementations/expert-iteration/expert-iteration-async.py @@ -13,9 +13,9 @@ import hydra from torchrl import torchrl_logger -from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater from torchrl.data.llm.history import History from torchrl.record.loggers.wandb import WandbLogger +from torchrl.weight_update.llm import get_model_metadata try: import ray @@ -33,7 +33,7 @@ get_train_model, log_training_metrics, make_env, - make_weight_updater, + make_weight_sync_scheme, RemoteDataLogger, ) from omegaconf import DictConfig @@ -115,26 +115,39 @@ def train( if cfg.model.compile: loss_fn = torch.compile(loss_fn) - # Get metadata - model_metadata = vLLMUpdater.get_model_metadata(policy_training) + # Get vLLM engine from the inference policy + # Note: In expert iteration, the inference policy is typically created in get_inference_model + # We need to get the vLLM 
engine from the collector's policy or create it + # For now, we'll use the approach similar to GRPO with explicit scheme creation - # Create weight updater with remote LLM - weight_updater: vLLMUpdater = make_weight_updater( + # Create weight sync scheme + weight_sync_scheme = make_weight_sync_scheme( master_address="localhost", # Since we're running locally master_port=None, # Will auto-assign an open port - model_metadata=model_metadata, vllm_tp_size=cfg.inference_model.num_devices if cfg.inference_model.num_devices is not None else len(cfg.inference_model.get("devices", [1])), ) - collector.weight_updater = weight_updater - # Initialize the weight updater - weight_updater.init(model_metadata=model_metadata) + # Set up weight sender + torchrl_logger.info("Setting up weight synchronization scheme...") + sender = weight_sync_scheme.create_sender() + sender.register_model(policy_training) - # First update the weights + # Get vLLM engine reference from collector's policy + # The collector has the policy which wraps the vLLM engine + vllm_engine = collector.policy.model if hasattr(collector, "policy") else None + if vllm_engine is None: + raise RuntimeError("Could not get vLLM engine from collector policy") + + # Initialize collective group + torchrl_logger.info("Initializing collective group...") + metadata = get_model_metadata(policy_training) + sender.init_all_workers_group(metadata, vllm_engine=vllm_engine) + + # First weight update with timeit("update_policy_weights"): - weight_updater.push_weights(policy_training) + sender.update_weights() timeit.print(prefix="First update_policy_weights_ time") timeit.reset() @@ -329,7 +342,7 @@ def train( if step % cfg.train.weight_update_frequency == 0: with timeit("update_policy_weights"): torchrl_logger.info("Updating policy weights...") - weight_updater.push_weights(policy_training) + sender.update_weights() # TODO: do we need this? Does it interfere with other processes? 
# torch.cuda.empty_cache() gc.collect() diff --git a/sota-implementations/expert-iteration/expert-iteration-sync.py b/sota-implementations/expert-iteration/expert-iteration-sync.py index 6670adc1201..126c188b6e9 100644 --- a/sota-implementations/expert-iteration/expert-iteration-sync.py +++ b/sota-implementations/expert-iteration/expert-iteration-sync.py @@ -13,9 +13,9 @@ import hydra from torchrl import torchrl_logger -from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater from torchrl.data.llm.history import History from torchrl.record.loggers.wandb import WandbLogger +from torchrl.weight_update.llm import get_model_metadata try: import ray @@ -35,7 +35,7 @@ get_train_model, log_training_metrics, make_env, - make_weight_updater, + make_weight_sync_scheme, RemoteDataLogger, ) from omegaconf import DictConfig @@ -114,26 +114,33 @@ def train( if cfg.model.compile: loss_fn = torch.compile(loss_fn) - # Get metadata - model_metadata = vLLMUpdater.get_model_metadata(policy_training) - - # Create weight updater with remote LLM - weight_updater: vLLMUpdater = make_weight_updater( + # Create weight sync scheme + weight_sync_scheme = make_weight_sync_scheme( master_address="localhost", # Since we're running locally master_port=None, # Will auto-assign an open port - model_metadata=model_metadata, vllm_tp_size=cfg.inference_model.num_devices if cfg.inference_model.num_devices is not None else len(cfg.inference_model.get("devices", [1])), ) - collector.weight_updater = weight_updater - # Initialize the weight updater - weight_updater.init(model_metadata=model_metadata) + # Set up weight sender + torchrl_logger.info("Setting up weight synchronization scheme...") + sender = weight_sync_scheme.create_sender() + sender.register_model(policy_training) + + # Get vLLM engine reference from collector's policy + vllm_engine = collector.policy.model if hasattr(collector, "policy") else None + if vllm_engine is None: + raise RuntimeError("Could not get vLLM engine from collector policy") + + # Initialize collective group + torchrl_logger.info("Initializing collective group...") + metadata = get_model_metadata(policy_training) + sender.init_all_workers_group(metadata, vllm_engine=vllm_engine) - # First update the weights + # First weight update with timeit("update_policy_weights"): - weight_updater.push_weights(policy_training) + sender.update_weights() timeit.print(prefix="First update_policy_weights_ time") timeit.reset() @@ -333,7 +340,7 @@ def train( ): with timeit("update_policy_weights"): torchrl_logger.info("Updating policy weights...") - weight_updater.push_weights(policy_training) + sender.update_weights() torch.cuda.empty_cache() gc.collect() # Checkpointing disabled to prevent disk space issues @@ -356,7 +363,7 @@ def train( # If weight_update_frequency is not set, we update the weights after each batch with timeit("update_policy_weights"): torchrl_logger.info("Updating policy weights...") - weight_updater.push_weights(policy_training) + sender.update_weights() torch.cuda.empty_cache() gc.collect() diff --git a/sota-implementations/grpo/grpo-async.py b/sota-implementations/grpo/grpo-async.py index a1553dda742..e94d25c56fc 100644 --- a/sota-implementations/grpo/grpo-async.py +++ b/sota-implementations/grpo/grpo-async.py @@ -14,9 +14,9 @@ import hydra from torchrl import torchrl_logger -from torchrl.collectors.llm.weight_update.vllm_v2 import vLLMUpdaterV2 from torchrl.data.llm.history import History from torchrl.record.loggers.wandb import WandbLogger +from torchrl.weight_update.llm import 
get_model_metadata try: import ray @@ -35,7 +35,7 @@ get_train_model, log_training_metrics, make_env, - make_weight_updater, + make_weight_sync_scheme, ) from omegaconf import DictConfig @@ -110,14 +110,28 @@ def train( loss_fn = torch.compile(loss_fn) vllm_engine = inference_policy.model - weight_updater: vLLMUpdaterV2 = make_weight_updater(vllm_engine=vllm_engine) - for collector in tqdm.tqdm(collectors, desc="Setting weight updater"): - collector.weight_updater = weight_updater - torchrl_logger.info("Initializing weight updater...") - weight_updater.init() + # Create weight sync scheme for the collectors + weight_sync_scheme = make_weight_sync_scheme(vllm_engine=vllm_engine) + + # Set up weight sync scheme for collectors + # Note: We need to get the sender after the collectors are created + # For now, we'll update the collectors to use the scheme + torchrl_logger.info("Setting up weight synchronization scheme...") + + # We'll need to manually set up the sender since collectors were already created + # without the scheme. In production, collectors should be created with weight_sync_schemes parameter. + sender = weight_sync_scheme.create_sender() + sender.register_model(policy_training) + + # Initialize collective group + torchrl_logger.info("Initializing collective group...") + metadata = get_model_metadata(policy_training) + sender.init_all_workers_group(metadata, vllm_engine=vllm_engine) + + # First weight update with timeit("update_policy_weights"): - weight_updater.push_weights_from_transformers(policy_training) + sender.update_weights() torchrl_logger.info("Completed first update_policy_weights. Starting collectors...") timeit.print(prefix="First update_policy_weights_ time") timeit.reset() @@ -248,7 +262,7 @@ def train( if step % cfg.train.weight_update_frequency == 0: with timeit("update_policy_weights"): torchrl_logger.info("Updating policy weights...") - weight_updater.push_weights_from_transformers(policy_training) + sender.update_weights() # TODO: do we need this? Does it interfere with other processes? 
# torch.cuda.empty_cache() gc.collect() diff --git a/sota-implementations/grpo/grpo-sync.py b/sota-implementations/grpo/grpo-sync.py index 3e3df9d63cf..309581d6c75 100644 --- a/sota-implementations/grpo/grpo-sync.py +++ b/sota-implementations/grpo/grpo-sync.py @@ -13,9 +13,9 @@ import hydra from torchrl import torchrl_logger -from torchrl.collectors.llm.weight_update.vllm_v2 import vLLMUpdaterV2 from torchrl.data.llm.history import History from torchrl.record.loggers.wandb import WandbLogger +from torchrl.weight_update.llm import get_model_metadata try: import ray @@ -36,7 +36,7 @@ get_train_model, log_training_metrics, make_env, - make_weight_updater, + make_weight_sync_scheme, ) from omegaconf import DictConfig @@ -111,12 +111,23 @@ def train( loss_fn = torch.compile(loss_fn) vllm_engine = inference_policy.model - weight_updater: vLLMUpdaterV2 = make_weight_updater(vllm_engine=vllm_engine) - collector.weight_updater = weight_updater - weight_updater.init() + # Create weight sync scheme + weight_sync_scheme = make_weight_sync_scheme(vllm_engine=vllm_engine) + + # Set up weight sender + torchrl_logger.info("Setting up weight synchronization scheme...") + sender = weight_sync_scheme.create_sender() + sender.register_model(policy_training) + + # Initialize collective group + torchrl_logger.info("Initializing collective group...") + metadata = get_model_metadata(policy_training) + sender.init_all_workers_group(metadata, vllm_engine=vllm_engine) + + # First weight update with timeit("update_policy_weights"): - weight_updater.push_weights_from_transformers(policy_training) + sender.update_weights() timeit.print(prefix="First update_policy_weights_ time") timeit.reset() @@ -267,7 +278,7 @@ def train( with timeit("update_policy_weights"): torchrl_logger.info("Updating policy weights...") - weight_updater.push_weights_from_transformers(policy_training) + sender.update_weights() # TODO: do we need this? Does it interfere with other processes? # torch.cuda.empty_cache() gc.collect() diff --git a/sota-implementations/grpo/grpo_utils.py b/sota-implementations/grpo/grpo_utils.py index d217710bc81..5b05136fc0b 100644 --- a/sota-implementations/grpo/grpo_utils.py +++ b/sota-implementations/grpo/grpo_utils.py @@ -15,10 +15,10 @@ from torch import device as torch_device, dtype as torch_dtype from torchrl._utils import logger as torchrl_logger, timeit -from torchrl.collectors.llm.weight_update.vllm_v2 import vLLMUpdaterV2 from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL from torchrl.envs.llm.datasets.ifeval import IFEvalEnv from torchrl.modules.llm import TransformersWrapper, vLLMWrapper +from torchrl.weight_update.llm import VLLMWeightSyncScheme from transformers.models.auto.modeling_auto import AutoModelForCausalLM from transformers.tokenization_utils import PreTrainedTokenizer @@ -544,23 +544,41 @@ def get_hf_model( torch.set_default_dtype(original_dtype) -def make_weight_updater( +def make_weight_sync_scheme( vllm_engine, -) -> vLLMUpdaterV2: - """Creates a vLLM weight updater for the policy using the new V2 API. +) -> VLLMWeightSyncScheme: + """Creates a vLLM weight synchronization scheme using NCCL collectives. - The V2 updater is much simpler - it just needs a vLLM engine that implements - the RLvLLMEngine interface (like RayLLMWorker, LocalLLMWrapper, or AsyncVLLM). + This function creates a weight sync scheme that uses NCCL for high-performance + GPU-to-GPU weight transfers from the training model to vLLM inference workers. 
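+
+    Example (sketch, mirroring the usage in ``grpo-sync.py``; ``inference_policy`` and
+    ``policy_training`` are assumed to exist in the caller's scope)::
+
+        from torchrl.weight_update.llm import get_model_metadata
+
+        scheme = make_weight_sync_scheme(vllm_engine=inference_policy.model)
+        sender = scheme.create_sender()
+        sender.register_model(policy_training)
+        metadata = get_model_metadata(policy_training)
+        sender.init_all_workers_group(metadata, vllm_engine=inference_policy.model)
+        sender.update_weights()  # broadcasts the training weights to the vLLM workers via NCCL
+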
Args: - vllm_engine: A vLLM engine implementing the RLvLLMEngine interface. + vllm_engine: A vLLM engine implementing the RLvLLMEngine interface + (like RayLLMWorker, LocalLLMWrapper, or AsyncVLLM). This is typically obtained from the inference policy's model attribute. Returns: - vLLMUpdaterV2: An instance of the weight updater configured to update - the vLLM worker's weights through the engine's own methods. + VLLMWeightSyncScheme: A weight sync scheme configured for the vLLM engine. """ - return vLLMUpdaterV2(vllm_engine=vllm_engine) + # Get configuration from the vLLM engine + tp_size = vllm_engine.get_tp_size() + num_replicas = getattr(vllm_engine, "num_replicas", 1) + master_address = vllm_engine.get_master_address() + master_port = vllm_engine.get_master_port() + + torchrl_logger.info( + f"Creating VLLMWeightSyncScheme with tp_size={tp_size}, " + f"num_replicas={num_replicas}, master_address={master_address}, " + f"master_port={master_port}" + ) + + return VLLMWeightSyncScheme( + master_address=master_address, + master_port=master_port, + gpus_per_replica=tp_size, + num_replicas=num_replicas, + strategy="state_dict", + ) def compute_device_allocation(cfg): diff --git a/test/test_collector.py b/test/test_collector.py index a47d0a8aba0..bb0c0330bf7 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1486,6 +1486,10 @@ def env_fn(seed): @pytest.mark.parametrize("cudagraph", [False, True]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") def test_update_weights(self, use_async, cudagraph): + from torchrl.weight_update.weight_sync_schemes import ( + MultiProcessWeightSyncScheme, + ) + def create_env(): return ContinuousActionVecMockEnv() @@ -1506,6 +1510,7 @@ def create_env(): frames_per_batch=20, cat_results="stack", cudagraph_policy=cudagraph, + weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()}, ) try: # collect state_dict @@ -1553,6 +1558,69 @@ def create_env(): collector.shutdown() del collector + @pytest.mark.parametrize( + "use_async", [True] + ) # MultiSync has known indexing issues with SharedMem + def test_update_weights_shared_mem(self, use_async): + """Test shared memory weight synchronization scheme.""" + from tensordict import TensorDict + from torchrl.weight_update.weight_sync_schemes import SharedMemWeightSyncScheme + + def create_env(): + return ContinuousActionVecMockEnv() + + n_actions = ContinuousActionVecMockEnv().action_spec.shape[-1] + policy = SafeModule( + torch.nn.LazyLinear(n_actions), in_keys=["observation"], out_keys=["action"] + ) + policy(create_env().reset()) + + # Get policy weights and put them in shared memory + policy_weights = TensorDict.from_module(policy) + policy_weights.share_memory_() + + # Create shared memory weight sync scheme + weight_sync_scheme = SharedMemWeightSyncScheme() + weight_sync_scheme.register_shared_weights("policy", policy_weights) + + collector_class = ( + MultiSyncDataCollector if not use_async else MultiaSyncDataCollector + ) + collector = collector_class( + [create_env] * 3, + policy=policy, + frames_per_batch=20, + cat_results="stack", + weight_sync_schemes={"policy": weight_sync_scheme}, + ) + try: + # Collect first batch + for _ in collector: + break + + # Change policy weights + old_weight = policy.module.weight.data.clone() + for p in policy.parameters(): + p.data += torch.randn_like(p) + new_weight = policy.module.weight.data.clone() + + # Verify weights changed + assert not torch.allclose(old_weight, new_weight) + + # Update weights using shared memory + 
collector.update_policy_weights_() + + # Collect another batch - should use new weights + for _ in collector: + break + + # Verify shared memory was updated + assert torch.allclose(policy_weights["module", "weight"], new_weight) + + finally: + collector.shutdown() + del collector + @pytest.mark.parametrize("num_env", [1, 2]) @pytest.mark.parametrize("env_name", ["vec"]) @pytest.mark.parametrize("frames_per_batch_worker", [[10, 10], [15, 5]]) @@ -2209,23 +2277,23 @@ def env_fn(seed): @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") @pytest.mark.parametrize( - "collector_class", + "collector_class,num_envs", [ - SyncDataCollector, - MultiaSyncDataCollector, - functools.partial(MultiSyncDataCollector, cat_results="stack"), + (SyncDataCollector, 1), + (MultiaSyncDataCollector, 1), + (functools.partial(MultiSyncDataCollector, cat_results="stack"), 1), + (MultiaSyncDataCollector, 2), + (functools.partial(MultiSyncDataCollector, cat_results="stack"), 2), ], ) class TestAutoWrap: - num_envs = 1 - @pytest.fixture def env_maker(self): from torchrl.envs.libs.gym import GymEnv return lambda: GymEnv(PENDULUM_VERSIONED()) - def _create_collector_kwargs(self, env_maker, collector_class, policy): + def _create_collector_kwargs(self, env_maker, collector_class, policy, num_envs): collector_kwargs = { "create_env_fn": env_maker, "policy": policy, @@ -2235,7 +2303,7 @@ def _create_collector_kwargs(self, env_maker, collector_class, policy): if collector_class is not SyncDataCollector: collector_kwargs["create_env_fn"] = [ - collector_kwargs["create_env_fn"] for _ in range(self.num_envs) + collector_kwargs["create_env_fn"] for _ in range(num_envs) ] return collector_kwargs @@ -2243,7 +2311,7 @@ def _create_collector_kwargs(self, env_maker, collector_class, policy): @pytest.mark.parametrize("multiple_outputs", [True, False]) @pytest.mark.parametrize("device", get_default_devices()) def test_auto_wrap_modules( - self, collector_class, multiple_outputs, env_maker, device + self, collector_class, multiple_outputs, env_maker, device, num_envs ): policy = WrappablePolicy( out_features=env_maker().action_spec.shape[-1], @@ -2253,33 +2321,40 @@ def test_auto_wrap_modules( policy(env_maker().reset().get("observation")) collector = collector_class( - **self._create_collector_kwargs(env_maker, collector_class, policy), + **self._create_collector_kwargs( + env_maker, collector_class, policy, num_envs + ), device=device, ) - out_keys = ["action"] - if multiple_outputs: - out_keys.extend(f"output{i}" for i in range(1, 4)) - - if collector_class is SyncDataCollector: - assert isinstance(collector.policy, TensorDictModule) - assert collector.policy.out_keys == out_keys - # this does not work now that we force the device of the policy - # assert collector.policy.module is policy + try: + out_keys = ["action"] + if multiple_outputs: + out_keys.extend(f"output{i}" for i in range(1, 4)) - for i, data in enumerate(collector): - if i == 0: - assert (data["action"] != 0).any() - for p in policy.parameters(): - p.data.zero_() - assert p.device == torch.device("cpu") - collector.update_policy_weights_() - elif i == 4: - assert (data["action"] == 0).all() - break + if collector_class is SyncDataCollector: + assert isinstance(collector._wrapped_policy, TensorDictModule) + assert collector._wrapped_policy.out_keys == out_keys + # this does not work now that we force the device of the policy + # assert collector.policy.module is policy - collector.shutdown() - del collector + for i, data in enumerate(collector): + # 
Debug: iteration {i} + if i == 0: + assert (data["action"] != 0).any() + for p in policy.parameters(): + p.data.zero_() + assert p.device == torch.device("cpu") + # Debug: updating policy weights + collector.update_policy_weights_() + # Debug: updated policy weights + elif i == 4: + assert (data["action"] == 0).all() + break + finally: + # Debug: shutting down collector + collector.shutdown() + del collector # Deprecated as from v0.3 # def test_no_wrap_compatible_module(self, collector_class, env_maker): @@ -2314,14 +2389,16 @@ def test_auto_wrap_modules( # collector.shutdown() # del collector - def test_auto_wrap_error(self, collector_class, env_maker): + def test_auto_wrap_error(self, collector_class, env_maker, num_envs): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) with pytest.raises( TypeError, match=("Arguments to policy.forward are incompatible with entries in"), ): collector_class( - **self._create_collector_kwargs(env_maker, collector_class, policy) + **self._create_collector_kwargs( + env_maker, collector_class, policy, num_envs + ) ) @@ -2779,13 +2856,22 @@ def forward(self, td): ], ) def test_param_sync(self, give_weights, collector, policy_device, env_device): + from torchrl.weight_update.weight_sync_schemes import ( + MultiProcessWeightSyncScheme, + ) + policy = TestUpdateParams.Policy().to(policy_device) env = EnvCreator(lambda: TestUpdateParams.DummyEnv(device=env_device)) device = env().device env = [env] col = collector( - env, policy, device=device, total_frames=200, frames_per_batch=10 + env, + policy, + device=device, + total_frames=200, + frames_per_batch=10, + weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()}, ) try: for i, data in enumerate(col): @@ -2833,6 +2919,10 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device): def test_param_sync_mixed_device( self, give_weights, collector, policy_device, env_device ): + from torchrl.weight_update.weight_sync_schemes import ( + MultiProcessWeightSyncScheme, + ) + with torch.device("cpu"): policy = TestUpdateParams.Policy() policy.param = nn.Parameter(policy.param.data.to(policy_device)) @@ -2842,7 +2932,12 @@ def test_param_sync_mixed_device( device = env().device env = [env] col = collector( - env, policy, device=device, total_frames=200, frames_per_batch=10 + env, + policy, + device=device, + total_frames=200, + frames_per_batch=10, + weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()}, ) try: for i, data in enumerate(col): @@ -3865,6 +3960,10 @@ def test_start_multi(self, total_frames, cls): "cls", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] ) def test_start_update_policy(self, total_frames, cls): + from torchrl.weight_update.weight_sync_schemes import ( + MultiProcessWeightSyncScheme, + ) + rb = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000)) env = CountingEnv() m = nn.Linear(env.observation_spec["observation"].shape[-1], 1) @@ -3882,12 +3981,19 @@ def test_start_update_policy(self, total_frames, cls): td = TensorDict.from_module(policy).data.clone() if cls != SyncDataCollector: env = [CountingEnv] * 2 + + # Add weight sync schemes for multi-process collectors + kwargs = {} + if cls != SyncDataCollector: + kwargs["weight_sync_schemes"] = {"policy": MultiProcessWeightSyncScheme()} + collector = cls( env, policy, replay_buffer=rb, total_frames=total_frames, frames_per_batch=16, + **kwargs, ) try: collector.start() @@ -3913,4 +4019,6 @@ def test_start_update_policy(self, total_frames, cls): if __name__ == 
"__main__": args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) + pytest.main( + [__file__, "--capture", "no", "--exitfirst", "--timeout", "180"] + unknown + ) diff --git a/test/test_distributed.py b/test/test_distributed.py index 1f03d385607..6183132394e 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -102,26 +102,29 @@ def _start_worker(cls): @classmethod def _test_distributed_collector_basic(cls, queue, frames_per_batch): - cls._start_worker() - env = ContinuousActionVecMockEnv - policy = RandomPolicy(env().action_spec) - torchrl_logger.info("creating collector") - collector = cls.distributed_class()( - [env] * 2, - policy, - total_frames=1000, - frames_per_batch=frames_per_batch, - **cls.distributed_kwargs(), - ) - total = 0 - torchrl_logger.info("getting data...") - for data in collector: - total += data.numel() - assert data.numel() == frames_per_batch - assert data.names[-1] == "time" - collector.shutdown() - assert total == 1000 - queue.put("passed") + try: + cls._start_worker() + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) + torchrl_logger.info("creating collector") + collector = cls.distributed_class()( + [env] * 2, + policy, + total_frames=1000, + frames_per_batch=frames_per_batch, + **cls.distributed_kwargs(), + ) + total = 0 + torchrl_logger.info("getting data...") + for data in collector: + total += data.numel() + assert data.numel() == frames_per_batch + assert data.names[-1] == "time" + collector.shutdown() + assert total == 1000 + queue.put("passed") + except Exception as e: + queue.put(f"not passed: {str(e)}") @pytest.mark.parametrize("frames_per_batch", [50, 100]) def test_distributed_collector_basic(self, frames_per_batch): @@ -143,23 +146,26 @@ def test_distributed_collector_basic(self, frames_per_batch): @classmethod def _test_distributed_collector_mult(cls, queue, frames_per_batch): - cls._start_worker() - env = ContinuousActionVecMockEnv - policy = RandomPolicy(env().action_spec) - collector = cls.distributed_class()( - [env] * 2, - policy, - total_frames=1000, - frames_per_batch=frames_per_batch, - **cls.distributed_kwargs(), - ) - total = 0 - for data in collector: - total += data.numel() - assert data.numel() == frames_per_batch - collector.shutdown() - assert total == -frames_per_batch * (1000 // -frames_per_batch) - queue.put("passed") + try: + cls._start_worker() + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) + collector = cls.distributed_class()( + [env] * 2, + policy, + total_frames=1000, + frames_per_batch=frames_per_batch, + **cls.distributed_kwargs(), + ) + total = 0 + for data in collector: + total += data.numel() + assert data.numel() == frames_per_batch + collector.shutdown() + assert total == -frames_per_batch * (1000 // -frames_per_batch) + queue.put("passed") + except Exception as e: + queue.put(f"not passed: {e}") def test_distributed_collector_mult(self, frames_per_batch=200): """Testing multiple nodes.""" @@ -181,24 +187,27 @@ def test_distributed_collector_mult(self, frames_per_batch=200): @classmethod def _test_distributed_collector_sync(cls, queue, sync): - frames_per_batch = 50 - env = ContinuousActionVecMockEnv - policy = RandomPolicy(env().action_spec) - collector = cls.distributed_class()( - [env] * 2, - policy, - total_frames=200, - frames_per_batch=frames_per_batch, - sync=sync, - **cls.distributed_kwargs(), - ) - total = 0 - for data in collector: - total += data.numel() - 
assert data.numel() == frames_per_batch - collector.shutdown() - assert total == 200 - queue.put("passed") + try: + frames_per_batch = 50 + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) + collector = cls.distributed_class()( + [env] * 2, + policy, + total_frames=200, + frames_per_batch=frames_per_batch, + sync=sync, + **cls.distributed_kwargs(), + ) + total = 0 + for data in collector: + total += data.numel() + assert data.numel() == frames_per_batch + collector.shutdown() + assert total == 200 + queue.put("passed") + except Exception as e: + queue.put(f"not passed: {str(e)}") @pytest.mark.parametrize("sync", [False, True]) def test_distributed_collector_sync(self, sync): @@ -220,24 +229,27 @@ def test_distributed_collector_sync(self, sync): @classmethod def _test_distributed_collector_class(cls, queue, collector_class): - frames_per_batch = 50 - env = ContinuousActionVecMockEnv - policy = RandomPolicy(env().action_spec) - collector = cls.distributed_class()( - [env] * 2, - policy, - collector_class=collector_class, - total_frames=200, - frames_per_batch=frames_per_batch, - **cls.distributed_kwargs(), - ) - total = 0 - for data in collector: - total += data.numel() - assert data.numel() == frames_per_batch - collector.shutdown() - assert total == 200 - queue.put("passed") + try: + frames_per_batch = 50 + env = ContinuousActionVecMockEnv + policy = RandomPolicy(env().action_spec) + collector = cls.distributed_class()( + [env] * 2, + policy, + collector_class=collector_class, + total_frames=200, + frames_per_batch=frames_per_batch, + **cls.distributed_kwargs(), + ) + total = 0 + for data in collector: + total += data.numel() + assert data.numel() == frames_per_batch + collector.shutdown() + assert total == 200 + queue.put("passed") + except Exception as e: + queue.put(f"not passed: {str(e)}") @pytest.mark.parametrize( "collector_class", @@ -266,42 +278,45 @@ def test_distributed_collector_class(self, collector_class): @classmethod def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): - frames_per_batch = 50 - total_frames = 300 - env = CountingEnv - policy = CountingPolicy() - if collector_class is MultiaSyncDataCollector: - # otherwise we may collect data from a collector that has not yet been - # updated - n_collectors = 1 - else: - n_collectors = 2 - collector = cls.distributed_class()( - [env] * n_collectors, - policy, - collector_class=collector_class, - total_frames=total_frames, - frames_per_batch=frames_per_batch, - sync=sync, - **cls.distributed_kwargs(), - ) - total = 0 - first_batch = None - last_batch = None - for i, data in enumerate(collector): - total += data.numel() - assert data.numel() == frames_per_batch - if i == 0: - first_batch = data - policy.weight.data += 1 - collector.update_policy_weights_() - elif total == total_frames - frames_per_batch: - last_batch = data - assert (first_batch["action"] == 1).all(), first_batch["action"] - assert (last_batch["action"] == 2).all(), last_batch["action"] - collector.shutdown() - assert total == total_frames - queue.put("passed") + try: + frames_per_batch = 50 + total_frames = 300 + env = CountingEnv + policy = CountingPolicy() + if collector_class is MultiaSyncDataCollector: + # otherwise we may collect data from a collector that has not yet been + # updated + n_collectors = 1 + else: + n_collectors = 2 + collector = cls.distributed_class()( + [env] * n_collectors, + policy, + collector_class=collector_class, + total_frames=total_frames, + frames_per_batch=frames_per_batch, 
+ sync=sync, + **cls.distributed_kwargs(), + ) + total = 0 + first_batch = None + last_batch = None + for i, data in enumerate(collector): + total += data.numel() + assert data.numel() == frames_per_batch + if i == 0: + first_batch = data + policy.weight.data += 1 + collector.update_policy_weights_() + elif total == total_frames - frames_per_batch: + last_batch = data + assert (first_batch["action"] == 1).all(), first_batch["action"] + assert (last_batch["action"] == 2).all(), last_batch["action"] + collector.shutdown() + assert total == total_frames + queue.put("passed") + except Exception as e: + queue.put(f"not passed: {str(e)}") @pytest.mark.parametrize( "collector_class", @@ -470,7 +485,6 @@ def distributed_kwargs(cls) -> dict: ray_init_config["runtime_env"] = { "working_dir": os.path.dirname(__file__), "env_vars": {"PYTHONPATH": os.path.dirname(__file__)}, - "pip": ["ray"], } # for ray workers remote_configs = { "num_cpus": 1, diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 9c2d2025087..5da43accf7c 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -10,8 +10,10 @@ import torch import torch.nn as nn from tensordict import TensorDict +from tensordict.nn import TensorDictModule from torch import multiprocessing as mp - +from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector +from torchrl.envs import GymEnv from torchrl.weight_update.weight_sync_schemes import ( _resolve_model, MPTransport, @@ -269,6 +271,226 @@ def test_no_weight_sync_scheme(self): transport.send_weights("policy", weights) +class TestCollectorIntegration: + @pytest.fixture + def simple_env(self): + return GymEnv("CartPole-v1") + + @pytest.fixture + def simple_policy(self, simple_env): + return TensorDictModule( + nn.Linear( + simple_env.observation_spec["observation"].shape[-1], + simple_env.action_spec.shape[-1], + ), + in_keys=["observation"], + out_keys=["action"], + ) + + def test_syncdatacollector_multiprocess_scheme(self, simple_policy): + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=simple_policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes={"policy": scheme}, + ) + + new_weights = simple_policy.state_dict() + with torch.no_grad(): + for key in new_weights: + new_weights[key].fill_(1.0) + + collector.update_policy_weights_(new_weights) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + + def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + collector = MultiSyncDataCollector( + create_env_fn=[ + lambda: GymEnv("CartPole-v1"), + lambda: GymEnv("CartPole-v1"), + ], + policy=simple_policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes={"policy": scheme}, + ) + + new_weights = simple_policy.state_dict() + with torch.no_grad(): + for key in new_weights: + new_weights[key].fill_(1.0) + + collector.update_policy_weights_(new_weights) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + + def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy): + scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + + collector = MultiSyncDataCollector( + create_env_fn=[ + lambda: GymEnv("CartPole-v1"), + lambda: GymEnv("CartPole-v1"), + ], + policy=simple_policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes={"policy": 
scheme}, + ) + + new_weights = TensorDict.from_module(simple_policy) + with torch.no_grad(): + new_weights["module"]["weight"].fill_(1.0) + new_weights["module"]["bias"].fill_(1.0) + + collector.update_policy_weights_(new_weights) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + + def test_collector_no_weight_sync(self, simple_policy): + scheme = NoWeightSyncScheme() + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=simple_policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes={"policy": scheme}, + ) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + + +class TestMultiModelUpdates: + def test_multi_model_state_dict_updates(self): + env = GymEnv("CartPole-v1") + + policy = TensorDictModule( + nn.Linear( + env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] + ), + in_keys=["observation"], + out_keys=["action"], + ) + + value = TensorDictModule( + nn.Linear(env.observation_spec["observation"].shape[-1], 1), + in_keys=["observation"], + out_keys=["value"], + ) + + weight_sync_schemes = { + "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), + "value": MultiProcessWeightSyncScheme(strategy="state_dict"), + } + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes=weight_sync_schemes, + ) + + policy_weights = policy.state_dict() + value_weights = value.state_dict() + + with torch.no_grad(): + for key in policy_weights: + policy_weights[key].fill_(1.0) + for key in value_weights: + value_weights[key].fill_(2.0) + + collector.update_policy_weights_( + weights_dict={ + "policy": policy_weights, + "value": value_weights, + } + ) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + env.close() + + def test_multi_model_tensordict_updates(self): + env = GymEnv("CartPole-v1") + + policy = TensorDictModule( + nn.Linear( + env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] + ), + in_keys=["observation"], + out_keys=["action"], + ) + + value = TensorDictModule( + nn.Linear(env.observation_spec["observation"].shape[-1], 1), + in_keys=["observation"], + out_keys=["value"], + ) + + weight_sync_schemes = { + "policy": MultiProcessWeightSyncScheme(strategy="tensordict"), + "value": MultiProcessWeightSyncScheme(strategy="tensordict"), + } + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes=weight_sync_schemes, + ) + + policy_weights = TensorDict.from_module(policy) + value_weights = TensorDict.from_module(value) + + with torch.no_grad(): + policy_weights["module"]["weight"].fill_(1.0) + policy_weights["module"]["bias"].fill_(1.0) + value_weights["module"]["weight"].fill_(2.0) + value_weights["module"]["bias"].fill_(2.0) + + collector.update_policy_weights_( + weights_dict={ + "policy": policy_weights, + "value": value_weights, + } + ) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + env.close() + + class TestHelpers: def test_resolve_model_simple(self): class Context: diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 34937c22d47..355e6e98db0 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -51,11 +51,7 @@ VERBOSE, ) from torchrl.collectors.utils import 
split_trajectories -from torchrl.collectors.weight_update import ( - MultiProcessedWeightUpdater, - VanillaWeightUpdater, - WeightUpdaterBase, -) +from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data import ReplayBuffer from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import _do_nothing, EnvBase @@ -70,6 +66,13 @@ RandomPolicy, set_exploration_type, ) +from torchrl.weight_update.weight_sync_schemes import ( + _resolve_model, + MultiProcessWeightSyncScheme, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) try: from torch.compiler import cudagraph_mark_step_begin @@ -157,6 +160,9 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): compiled_policy: bool cudagraphed_policy: bool _weight_updater: WeightUpdaterBase | None = None + _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None + _weight_senders: dict[str, WeightSender] | None = None + _weight_receivers: dict[str, WeightReceiver] | None = None verbose: bool = False @property @@ -197,15 +203,6 @@ def _get_policy_and_device( if policy_device is NO_DEFAULT: policy_device = self.policy_device - if not self.trust_policy: - env = getattr(self, "env", None) - policy = _make_compatible_policy( - policy, - getattr(env, "observation_spec", None), - env=env, - env_maker=env_maker, - env_maker_kwargs=env_maker_kwargs, - ) if not policy_device: return policy, None @@ -294,11 +291,59 @@ def async_shutdown( """ return self.shutdown(timeout=timeout, close_env=close_env) + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + Args: + weights: Either already-extracted weights or a model to extract from. + model_id: The model identifier for resolving string paths. + + Returns: + Extracted weights in the appropriate format. + """ + scheme = ( + self._weight_sync_schemes.get(model_id) + if self._weight_sync_schemes + else None + ) + + if weights is None: + if model_id == "policy" and hasattr(self, "policy_weights"): + return self.policy_weights + elif model_id == "policy" and hasattr(self, "_policy_weights_dict"): + policy_device = ( + self.policy_device + if not isinstance(self.policy_device, (list, tuple)) + else self.policy_device[0] + ) + return self._policy_weights_dict.get(policy_device) + return None + + if scheme is None: + return weights + + from torchrl.weight_update.weight_sync_schemes import ( + _resolve_model, + WeightStrategy, + ) + + strategy = WeightStrategy(extract_as=scheme.strategy) + + if isinstance(weights, nn.Module): + return strategy.extract_weights(weights) + elif isinstance(weights, str): + model = _resolve_model(self, weights) + return strategy.extract_weights(model) + else: + return weights + def update_policy_weights_( self, policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, **kwargs, ) -> None: """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. @@ -317,9 +362,15 @@ def update_policy_weights_( worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the workers that need to be updated. This is relevant when the collector has more than one worker associated with it. + model_id (str | None, optional): The model identifier to update. If provided, only updates this specific + model. 
Cannot be used together with weights_dict. + weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating + multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes. + Cannot be used together with model_id or policy_or_weights. Raises: TypeError: If `worker_ids` is provided but no `weight_updater` is configured. + ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict). .. note:: Users should extend the `WeightUpdaterBase` classes to customize the weight update logic for specific use cases. This method should not be overwritten. @@ -335,9 +386,83 @@ def update_policy_weights_( ) policy_or_weights = kwargs.pop("policy_weights") - self.weight_updater( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) + if weights_dict is not None and model_id is not None: + raise ValueError("Cannot specify both 'weights_dict' and 'model_id'") + + if weights_dict is not None and policy_or_weights is not None: + raise ValueError( + "Cannot specify both 'weights_dict' and 'policy_or_weights'" + ) + + # Priority: new weight sync schemes > old weight updater system + if self._weight_senders: + if weights_dict is not None: + for target_model_id, weights in weights_dict.items(): + if target_model_id not in self._weight_senders: + raise KeyError( + f"Model '{target_model_id}' not found in registered weight senders. " + f"Available models: {list(self._weight_senders.keys())}" + ) + processed_weights = self._extract_weights_if_needed( + weights, target_model_id + ) + self._weight_senders[target_model_id].update_weights( + processed_weights + ) + elif model_id is not None: + if model_id not in self._weight_senders: + raise KeyError( + f"Model '{model_id}' not found in registered weight senders. 
" + f"Available models: {list(self._weight_senders.keys())}" + ) + processed_weights = self._extract_weights_if_needed( + policy_or_weights, model_id + ) + self._weight_senders[model_id].update_weights(processed_weights) + else: + if "policy" in self._weight_senders: + processed_weights = self._extract_weights_if_needed( + policy_or_weights, "policy" + ) + self._weight_senders["policy"].update_weights(processed_weights) + elif len(self._weight_senders) == 1: + single_model_id = next(iter(self._weight_senders.keys())) + single_sender = self._weight_senders[single_model_id] + processed_weights = self._extract_weights_if_needed( + policy_or_weights, single_model_id + ) + single_sender.update_weights(processed_weights) + else: + for target_model_id, sender in self._weight_senders.items(): + processed_weights = self._extract_weights_if_needed( + policy_or_weights, target_model_id + ) + sender.update_weights(processed_weights) + + elif self._weight_updater is not None: + # Fall back to old weight updater system + self.weight_updater( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + else: + # No weight updater configured + # For single-process collectors, apply weights locally if explicitly provided + if policy_or_weights is not None: + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + # Use WeightStrategy to apply weights properly + strategy = WeightStrategy(extract_as="tensordict") + + # Extract weights if needed + if isinstance(policy_or_weights, nn.Module): + weights = strategy.extract_weights(policy_or_weights) + else: + weights = policy_or_weights + + # Apply to local policy + if hasattr(self, "policy") and isinstance(self.policy, nn.Module): + strategy.apply_weights(self.policy, weights) + # Otherwise, no action needed - policy is local and changes are immediately visible def __iter__(self) -> Iterator[TensorDictBase]: try: @@ -547,14 +672,19 @@ class SyncDataCollector(DataCollectorBase): but populate the buffer instead. Defaults to ``None``. - .. seealso:: By default, the buffer is populated every time a (batch of) frames is collected. - If the buffer needs to be extended with entire rollouts, set `extend_buffer` to `True`. + .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts. + If the buffer needs to be populated with individual frames as they are collected, + set ``extend_buffer=False`` (deprecated). - .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` is prohibited unless + .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms. extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not - with single steps. Defaults to `False`. + with single steps. Defaults to `True`. + + .. note:: Setting this to `False` is deprecated and will be removed in a future version. + Extending the buffer with entire rollouts is the recommended approach for better + compatibility with postprocessing and trajectory splitting. trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules and ``False`` otherwise. 
@@ -636,6 +766,8 @@ class SyncDataCollector(DataCollectorBase): """ + _ignore_rb: bool = False + def __init__( self, create_env_fn: ( @@ -664,7 +796,8 @@ def __init__( set_truncated: bool = False, use_buffers: bool | None = None, replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = False, + extend_buffer: bool = True, + local_init_rb: bool | None = None, trust_policy: bool | None = None, compile_policy: bool | dict[str, Any] | None = None, cudagraph_policy: bool | dict[str, Any] | None = None, @@ -672,6 +805,7 @@ def __init__( weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, track_policy_version: bool = False, **kwargs, ): @@ -794,33 +928,38 @@ def __init__( # Policy version tracking setup self.policy_version_tracker = track_policy_version - if PolicyVersion is not None: - if isinstance(track_policy_version, bool) and track_policy_version: - from torchrl.envs.batched_envs import BatchedEnvBase + if isinstance(track_policy_version, bool) and track_policy_version: + from torchrl.envs.batched_envs import BatchedEnvBase - if isinstance(self.env, BatchedEnvBase): - raise RuntimeError( - "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " - "and pass that transform to the collector." - ) - self.policy_version_tracker = PolicyVersion() - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - elif hasattr( - track_policy_version, "increment_version" - ): # Check if it's a PolicyVersion instance - self.policy_version_tracker = track_policy_version - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - else: - self.policy_version_tracker = None - else: - if track_policy_version: - raise ImportError( - "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." + if isinstance(self.env, BatchedEnvBase): + raise RuntimeError( + "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " + "and pass that transform to the collector." ) + self.policy_version_tracker = PolicyVersion() + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + elif hasattr( + track_policy_version, "increment_version" + ): # Check if it's a PolicyVersion instance + self.policy_version_tracker = track_policy_version + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + else: self.policy_version_tracker = None self.replay_buffer = replay_buffer self.extend_buffer = extend_buffer - if self.replay_buffer is not None: + + # Handle local_init_rb deprecation for SyncDataCollector + if local_init_rb is None: + local_init_rb = False # Default for SyncDataCollector + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + self.local_init_rb = local_init_rb + + if self.replay_buffer is not None and not self._ignore_rb: if postproc is not None and not self.extend_buffer: raise TypeError( "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." 
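The hunks above flip the ``extend_buffer`` default to ``True`` and add a ``local_init_rb`` shim that emits a ``FutureWarning`` when a replay buffer is passed with the legacy ``local_init_rb=False`` default. A minimal usage sketch, assuming a Gym backend is available (the environment, policy and buffer sizes are illustrative placeholders, not part of this diff):

.. code-block:: python

   from torchrl.collectors import SyncDataCollector
   from torchrl.data import LazyTensorStorage, ReplayBuffer
   from torchrl.envs import GymEnv
   from torchrl.envs.utils import RandomPolicy

   env = GymEnv("Pendulum-v1")
   rb = ReplayBuffer(storage=LazyTensorStorage(10_000))
   collector = SyncDataCollector(
       env,
       RandomPolicy(env.action_spec),
       frames_per_batch=64,
       total_frames=1_024,
       replay_buffer=rb,     # rollouts are written to the buffer during collection
       extend_buffer=True,   # new default: extend with entire rollouts
       local_init_rb=True,   # opt into storage-level init; avoids the FutureWarning
   )
   for _ in collector:       # iterating drives collection; data lands in rb
       pass
   assert len(rb) > 0
   collector.shutdown()

Passing ``extend_buffer=False`` still works but follows the deprecated frame-by-frame path, which is incompatible with ``postproc`` and ``split_trajs=True``.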
@@ -852,23 +991,37 @@ def __init__( if hasattr(self.env, "register_collector"): self.env.register_collector(self) - (self.policy, self.get_weights_fn,) = self._get_policy_and_device( + self._original_policy = policy + (policy, self.get_weights_fn,) = self._get_policy_and_device( policy=policy, ) - if isinstance(self.policy, nn.Module): + + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + if isinstance(self._wrapped_policy, nn.Module): self.policy_weights = TensorDict.from_module( - self.policy, as_module=True + self._wrapped_policy, as_module=True ).data else: self.policy_weights = TensorDict() if self.compiled_policy: - self.policy = compile_with_warmup( - self.policy, **self.compiled_policy_kwargs + self._wrapped_policy = compile_with_warmup( + self._wrapped_policy, **self.compiled_policy_kwargs ) if self.cudagraphed_policy: - self.policy = CudaGraphModule( - self.policy, + self._wrapped_policy = CudaGraphModule( + self._wrapped_policy, in_keys=[], out_keys=[], device=self.policy_device, @@ -975,16 +1128,44 @@ def __init__( self._frames = 0 self._iter = -1 - if weight_updater is None: - weight_updater = VanillaWeightUpdater( - weight_getter=self.get_weights_fn, policy_weights=self.policy_weights - ) - elif not isinstance(weight_updater, WeightUpdaterBase): - raise TypeError( - f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." - ) + # Set up weight synchronization - prefer new schemes over legacy updater + # For single-process SyncDataCollector, no weight sync is needed (policy is local) + # Weight sync schemes are only needed for multi-process/distributed collectors + if weight_sync_schemes is not None: + # Use new simplified weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + + # For single-process collectors, we don't need senders/receivers + # The policy is local and changes are immediately visible + # Senders will be set up in multiprocess collectors during _run_processes + + self.weight_updater = None # Don't use legacy system + elif weight_updater is not None: + # Use legacy weight updater system if explicitly provided + if not isinstance(weight_updater, WeightUpdaterBase): + if callable(weight_updater): + weight_updater = weight_updater() + else: + raise TypeError( + f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." + ) - self.weight_updater = weight_updater + warnings.warn( + "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. 
" + "This will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} + else: + # No weight sync needed for single-process collectors + # The policy is local and changes are immediately visible + self.weight_updater = None + self._weight_sync_schemes = None + self._weight_senders = {} @property def _traj_pool(self): @@ -1034,18 +1215,18 @@ def _maybe_make_final_rollout(self, make_rollout: bool): self._policy_output_keys = set() if ( make_rollout - and hasattr(self.policy, "spec") - and self.policy.spec is not None - and all(v is not None for v in self.policy.spec.values(True, True)) + and hasattr(self._wrapped_policy, "spec") + and self._wrapped_policy.spec is not None + and all(v is not None for v in self._wrapped_policy.spec.values(True, True)) ): if any( key not in self._final_rollout.keys(isinstance(key, tuple)) - for key in self.policy.spec.keys(True, True) + for key in self._wrapped_policy.spec.keys(True, True) ): # if policy spec is non-empty, all the values are not None and the keys # match the out_keys we assume the user has given all relevant information # the policy could have more keys than the env: - policy_spec = self.policy.spec + policy_spec = self._wrapped_policy.spec if policy_spec.ndim < self._final_rollout.ndim: policy_spec = policy_spec.expand(self._final_rollout.shape) for key, spec in policy_spec.items(True, True): @@ -1055,10 +1236,10 @@ def _maybe_make_final_rollout(self, make_rollout: bool): self._final_rollout.set(key, spec.zero()) elif ( not make_rollout - and hasattr(self.policy, "out_keys") - and self.policy.out_keys + and hasattr(self._wrapped_policy, "out_keys") + and self._wrapped_policy.out_keys ): - self._policy_output_keys = list(self.policy.out_keys) + self._policy_output_keys = list(self._wrapped_policy.out_keys) else: if make_rollout: # otherwise, we perform a small number of steps with the policy to @@ -1078,7 +1259,7 @@ def _maybe_make_final_rollout(self, make_rollout: bool): ) # to test if values have changed in-place if self.compiled_policy: cudagraph_mark_step_begin() - policy_output = self.policy(policy_input) + policy_output = self._wrapped_policy(policy_input) # check that we don't have exclusive keys, because they don't appear in keys def check_exclusive(val): @@ -1319,7 +1500,7 @@ def cuda_check(tensor: torch.Tensor): event.record() event.synchronize() yield tensordict_out - elif self.replay_buffer is not None: + elif self.replay_buffer is not None and not self._ignore_rb: self.replay_buffer.extend(tensordict_out) if self.verbose: torchrl_logger.info( @@ -1539,7 +1720,7 @@ def rollout(self) -> TensorDictBase: # we still do the assignment for security if self.compiled_policy: cudagraph_mark_step_begin() - policy_output = self.policy(policy_input) + policy_output = self._wrapped_policy(policy_input) if self.compiled_policy: policy_output = policy_output.clone() if self._shuttle is not policy_output: @@ -1575,7 +1756,11 @@ def rollout(self) -> TensorDictBase: next_data.clear_device_() self._shuttle.set("next", next_data) - if self.replay_buffer is not None and not self.extend_buffer: + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): self.replay_buffer.add(self._shuttle) if self._increment_frames(self._shuttle.numel()): return @@ -1606,7 +1791,11 @@ def rollout(self) -> TensorDictBase: self.interruptor is not None and self.interruptor.collection_stopped() ): - if 
self.replay_buffer is not None and not self.extend_buffer: + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): return result = self._final_rollout if self._use_buffers: @@ -1643,7 +1832,11 @@ def rollout(self) -> TensorDictBase: self._final_rollout.ndim - 1, out=self._final_rollout, ) - elif self.replay_buffer is not None and not self.extend_buffer: + elif ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): return else: result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) @@ -1775,7 +1968,7 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: def __repr__(self) -> str: try: env_str = indent(f"env={self.env}", 4 * " ") - policy_str = indent(f"policy={self.policy}", 4 * " ") + policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") td_out_str = indent( f"td_out={getattr(self, '_final_rollout', None)}", 4 * " " ) @@ -1820,7 +2013,7 @@ def get_policy_version(self) -> str | int | None: def getattr_policy(self, attr): """Get an attribute from the policy.""" # send command to policy to return the attr - return getattr(self.policy, attr) + return getattr(self._wrapped_policy, attr) def getattr_env(self, attr): """Get an attribute from the environment.""" @@ -2002,6 +2195,11 @@ class _MultiDataCollector(DataCollectorBase): but populate the buffer instead. Defaults to ``None``. extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not with single steps. Defaults to `True` for multiprocessed data collectors. + local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize + the replay buffer in the main process (legacy behavior). If ``True``, the storage-level + coordination will handle initialization with real data from worker processes. + Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning. + This parameter is deprecated and will be removed in v0.12. trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules and ``False`` otherwise. @@ -2021,6 +2219,8 @@ class _MultiDataCollector(DataCollectorBase): If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, which handles weight synchronization across multiple processes. Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): A dictionary of weight sync schemes for the different models. + If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. 
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track @@ -2064,6 +2264,7 @@ def __init__( replay_buffer: ReplayBuffer | None = None, extend_buffer: bool = True, replay_buffer_chunk: bool | None = None, + local_init_rb: bool | None = None, trust_policy: bool | None = None, compile_policy: bool | dict[str, Any] | None = None, cudagraph_policy: bool | dict[str, Any] | None = None, @@ -2071,6 +2272,7 @@ def __init__( weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, track_policy_version: bool = False, ): self.closed = True @@ -2142,6 +2344,21 @@ def __init__( self._use_buffers = use_buffers self.replay_buffer = replay_buffer + + # Handle local_init_rb deprecation + if local_init_rb is None: + # v0.11: Default to False (current behavior), show deprecation warning + # v0.12: Default to True (new behavior) + local_init_rb = False # Will become True in 0.12 + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + + self.local_init_rb = local_init_rb + self._check_replay_buffer_init() if replay_buffer_chunk is not None: if extend_buffer is None: @@ -2186,17 +2403,20 @@ def __init__( if type(policy_new_device) is not type(policy): policy = policy_new_device weights = ( - TensorDict.from_module(policy_new_device).data + TensorDict.from_module(policy_new_device) if isinstance(policy_new_device, nn.Module) else TensorDict() ) self._policy_weights_dict[policy_device] = weights self._get_weights_fn = get_weights_fn if weight_updater is None: - weight_updater = MultiProcessedWeightUpdater( - get_server_weights=self._get_weights_fn, - policy_weights=self._policy_weights_dict, - ) + # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default + if weight_sync_schemes is None: + weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} + # Don't create legacy weight updater if we have schemes + else: + # Legacy weight updater was explicitly provided + pass elif weight_updater is None: warnings.warn( "weight_updater is None, but policy_factory is provided. This means that the server will " @@ -2207,7 +2427,18 @@ def __init__( "This will work whenever your inference and training policies are nn.Module instances with similar structures." 
) - self.weight_updater = weight_updater + # Set up weight synchronization - prefer new schemes over legacy updater + if weight_sync_schemes is not None: + # Use new simplified weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + # Senders will be created in _run_processes when pipes are available + self.weight_updater = None # Don't use legacy system + else: + # Fall back to legacy weight updater system + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} # Policy version tracking setup self.policy_version_tracker = track_policy_version @@ -2301,8 +2532,17 @@ def __init__( def _check_replay_buffer_init(self): if self.replay_buffer is None: return - is_init = getattr(self.replay_buffer._storage, "initialized", True) + is_init = hasattr(self.replay_buffer, "_storage") and getattr( + self.replay_buffer._storage, "initialized", True + ) if not is_init: + if self.local_init_rb: + # New behavior: storage handles all coordination itself + # Nothing to do here - the storage will coordinate during first write + self.replay_buffer.share() + return + + # Legacy behavior: fake tensordict initialization if isinstance(self.create_env_fn[0], EnvCreator): fake_td = self.create_env_fn[0].meta_data.tensordict elif isinstance(self.create_env_fn[0], EnvBase): @@ -2394,6 +2634,15 @@ def _run_processes(self) -> None: 1, torch.get_num_threads() - total_workers ) # 1 more thread for this proc + # Initialize weight senders for multiprocess collectors + if self._weight_sync_schemes: + # Create one sender per model using scheme's factory method + for model_id, scheme in self._weight_sync_schemes.items(): + sender = scheme.create_sender() + sender._model_id = model_id + if hasattr(sender, "set_context"): + sender.set_context(self, model_id) + self._weight_senders[model_id] = sender torch.set_num_threads(self.num_threads) queue_out = mp.Queue(self._queue_len) # sends data from proc to main self.procs = [] @@ -2465,6 +2714,7 @@ def _run_processes(self) -> None: "postproc": self.postprocs if self.replay_buffer is not None else None, + "weight_sync_schemes": self._weight_sync_schemes, } proc = _ProcessNoWarn( target=_main_async_collector, @@ -2474,6 +2724,21 @@ def _run_processes(self) -> None: # proc.daemon can't be set as daemonic processes may be launched by the process itself try: proc.start() + except TypeError as err: + if "cannot pickle" in str(err): + raise RuntimeError( + "A non-serializable object was passed to the collector workers." + ) from err + except RuntimeError as err: + if "Cowardly refusing to serialize non-leaf tensor" in str(err): + raise RuntimeError( + "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. " + "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n" + "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n" + "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead." 
+ ) from err + else: + raise err except _pickle.PicklingError as err: if "" in str(err): raise RuntimeError( @@ -2488,11 +2753,62 @@ def _run_processes(self) -> None: pipe_child.close() self.procs.append(proc) self.pipes.append(pipe_parent) - for pipe_parent in self.pipes: + + # Register worker with senders + if self._weight_senders: + for _, sender in self._weight_senders.items(): + sender.register_worker(i, pipe_parent) + + for i, pipe_parent in enumerate(self.pipes): pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) - msg = pipe_parent.recv() + try: + msg = pipe_parent.recv() + except EOFError as e: + raise RuntimeError( + f"Worker {i} failed to initialize and closed the connection before sending status. " + f"This typically indicates that the worker process crashed during initialization. " + f"Check the worker process logs for the actual error." + ) from e if msg != "instantiated": - raise RuntimeError(msg) + # Check if it's an error dict from worker + if isinstance(msg, dict) and msg.get("error"): + # Reconstruct the exception from the worker + exc_type_name = msg["exception_type"] + exc_msg = msg["exception_msg"] + traceback_str = msg["traceback"] + + # Try to get the actual exception class + exc_class = None + exc_module = msg["exception_module"] + + if exc_module == "builtins": + # Get from builtins + import builtins + + exc_class = getattr(builtins, exc_type_name, None) + else: + # Try to import from the module + try: + import importlib + + mod = importlib.import_module(exc_module) + exc_class = getattr(mod, exc_type_name, None) + except Exception: + pass + + # Re-raise with original exception type if possible + if exc_class is not None: + raise exc_class( + f"{exc_msg}\n\nWorker traceback:\n{traceback_str}" + ) + else: + # Fall back to RuntimeError if we can't get the original type + raise RuntimeError( + f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}" + ) + else: + # Legacy string error message + raise RuntimeError(msg) self.queue_out = queue_out self.closed = False @@ -3057,6 +3373,7 @@ def iterator(self) -> Iterator[TensorDictBase]: msg = "continue_random" else: msg = "continue" + # Debug: sending 'continue' self.pipes[idx].send((None, msg)) self._iter += 1 @@ -3120,13 +3437,13 @@ def iterator(self) -> Iterator[TensorDictBase]: # mask buffers if cat, and create a mask if stack if cat_results != "stack": buffers = {} - for idx, buffer in self.buffers.items(): + for worker_idx, buffer in self.buffers.items(): valid = buffer.get(("collector", "traj_ids")) != -1 if valid.ndim > 2: valid = valid.flatten(0, -2) if valid.ndim == 2: valid = valid.any(0) - buffers[idx] = buffer[..., valid] + buffers[worker_idx] = buffer[..., valid] else: for buffer in self.buffers.values(): with buffer.unlock_(): @@ -3138,6 +3455,11 @@ def iterator(self) -> Iterator[TensorDictBase]: else: buffers = self.buffers + # Skip frame counting if this worker didn't send data this iteration + # (happens when reusing buffers or on first iteration with some workers) + if idx not in buffers: + continue + workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() if workers_frames[idx] >= self.total_frames: @@ -3156,7 +3478,7 @@ def iterator(self) -> Iterator[TensorDictBase]: # we have to correct the traj_ids to make sure that they don't overlap # We can count the number of frames collected for free in this loop n_collected = 0 - for idx in range(self.num_workers): + for idx in buffers.keys(): buffer = buffers[idx] traj_ids = buffer.get(("collector", "traj_ids")) if preempt: 
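With the constructor changes above, a multiprocess collector now defaults to ``{"policy": MultiProcessWeightSyncScheme()}``, builds one sender per model in ``_run_processes``, and registers every worker pipe with that sender; workers then apply incoming weights through the ``update_weights`` message handled further down in ``_main_async_collector``. A minimal training-loop sketch, assuming a Gym backend (the environment and policy are illustrative placeholders):

.. code-block:: python

   import torch.nn as nn
   from tensordict.nn import TensorDictModule
   from torchrl.collectors import MultiSyncDataCollector
   from torchrl.envs import GymEnv
   from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme

   env_maker = lambda: GymEnv("Pendulum-v1")
   policy = TensorDictModule(
       nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
   )
   collector = MultiSyncDataCollector(
       [env_maker, env_maker],
       policy,
       frames_per_batch=64,
       total_frames=1_024,
       # Explicit here for clarity; this is also the default scheme.
       weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
   )
   for data in collector:
       ...  # optimize the local copy of the policy
       # Extracts the policy weights and sends them to every registered worker pipe.
       collector.update_policy_weights_(policy)
   collector.shutdown()

When several models are registered, ``model_id`` targets a single sender and ``weights_dict`` updates several models in one call; the two arguments are mutually exclusive, as enforced in the ``update_policy_weights_`` hunk earlier in this diff.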
@@ -3775,6 +4097,7 @@ def _main_async_collector( policy_factory: Callable | None = None, collector_class: type | Callable[[], DataCollectorBase] | None = None, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ) -> None: if collector_class is None: collector_class = SyncDataCollector @@ -3782,40 +4105,77 @@ def _main_async_collector( # init variables that will be cleared when closing collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None - inner_collector = collector_class( - create_env_fn, - create_env_kwargs=create_env_kwargs, - policy=policy, - policy_factory=policy_factory, - total_frames=-1, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - reset_at_each_iter=reset_at_each_iter, - postproc=postproc, - split_trajs=False, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - exploration_type=exploration_type, - reset_when_done=reset_when_done, - return_same_td=replay_buffer is None, - interruptor=interruptor, - set_truncated=set_truncated, - use_buffers=use_buffers, - replay_buffer=replay_buffer if not extend_buffer else None, - extend_buffer=False, - traj_pool=traj_pool, - trust_policy=trust_policy, - compile_policy=compile_policy, - cudagraph_policy=cudagraph_policy, - no_cuda_sync=no_cuda_sync, - ) - use_buffers = inner_collector._use_buffers - if verbose: - torchrl_logger.info("Sync data collector created") - dc_iter = iter(inner_collector) - j = 0 - pipe_child.send("instantiated") + try: + collector_class._ignore_rb = extend_buffer + inner_collector = collector_class( + create_env_fn, + create_env_kwargs=create_env_kwargs, + policy=policy, + policy_factory=policy_factory, + total_frames=-1, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + postproc=postproc, + split_trajs=False, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + exploration_type=exploration_type, + reset_when_done=reset_when_done, + return_same_td=replay_buffer is None, + interruptor=interruptor, + set_truncated=set_truncated, + use_buffers=use_buffers, + replay_buffer=replay_buffer, + extend_buffer=False, + traj_pool=traj_pool, + trust_policy=trust_policy, + compile_policy=compile_policy, + cudagraph_policy=cudagraph_policy, + no_cuda_sync=no_cuda_sync, + weight_sync_schemes=weight_sync_schemes, + ) + + # Set up weight receivers for worker process + if weight_sync_schemes: + inner_collector._weight_receivers = {} + for model_id, scheme in weight_sync_schemes.items(): + receiver = scheme.create_receiver() + receiver.set_context(inner_collector) + receiver.register_worker_transport(pipe_child) + + model = _resolve_model(inner_collector, model_id) + receiver.register_model(model) + + inner_collector._weight_receivers[model_id] = receiver + else: + inner_collector._weight_receivers = {} + + use_buffers = inner_collector._use_buffers + if verbose: + torchrl_logger.info("Sync data collector created") + dc_iter = iter(inner_collector) + j = 0 + pipe_child.send("instantiated") + except Exception as e: + # Send error information to main process + # We send a dict with the exception info so we can recreate it in the main process + import traceback + + error_info = { + "error": True, + "exception_type": type(e).__name__, + "exception_module": type(e).__module__, + "exception_msg": str(e), + "traceback": traceback.format_exc(), + } + 
try: + pipe_child.send(error_info) + except Exception: + # If pipe is broken, nothing we can do + pass + return has_timed_out = False counter = 0 @@ -3888,7 +4248,81 @@ def _main_async_collector( data_in = None # TODO: this does not work with random frames msg = "continue" + # Note: The "continue" message handling has been moved below after update_weights handling + # to allow falling through from update_weights to continue + + if msg == "update": + torchrl_logger.info(f"worker {idx} updating the params...") + inner_collector.update_policy_weights_(policy_weights=data_in) + pipe_child.send((j, "updated")) + has_timed_out = False + continue + + if msg == "register_shared_weights": + # Shared memory lazy registration: main process sends buffer reference + if verbose: + torchrl_logger.info( + f"worker {idx} received shared memory buffer registration" + ) + model_id, shared_buffer = data_in + + # Store the shared buffer reference for this model + # The receiver will use this buffer for all future weight accesses + if ( + inner_collector._weight_receivers + and model_id in inner_collector._weight_receivers + ): + # Update receiver's buffer reference + receiver = inner_collector._weight_receivers[model_id] + # Store the shared buffer - the model's parameters should point to this + if hasattr(receiver, "_shared_weights"): + receiver._shared_weights[model_id] = shared_buffer + + # Apply the buffer to the model immediately + receiver.apply_weights(shared_buffer) + + if verbose: + torchrl_logger.info( + f"worker {idx} registered shared buffer for model '{model_id}'" + ) + else: + torchrl_logger.warning( + f"worker {idx} received shared buffer for unknown model '{model_id}'" + ) + + # Send acknowledgment back to main process + pipe_child.send((None, "registered")) + has_timed_out = False + continue + + if msg == "update_weights": + # New weight update protocol for simplified weight sync system + if verbose: + torchrl_logger.info( + f"worker {idx} received weight update via new protocol" + ) + model_id, weights = data_in + + # Apply weights using the appropriate receiver for this model + if ( + inner_collector._weight_receivers + and model_id in inner_collector._weight_receivers + ): + inner_collector._weight_receivers[model_id].apply_weights(weights) + else: + torchrl_logger.warning( + f"worker {idx} received weights for unknown model '{model_id}'" + ) + + # After applying weights, we continue collecting immediately as if we received + # a "continue" message. This ensures the worker keeps collecting data without + # waiting for an explicit continue from the main process. 
+ has_timed_out = False + msg = "continue" + # Now check if we should continue collecting + if msg in ("continue", "continue_random"): + # This block handles both explicit continue messages and implicit ones after weight updates if msg == "continue_random": inner_collector.init_random_frames = float("inf") else: @@ -3980,14 +4414,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): has_timed_out = True continue - elif msg == "update": - torchrl_logger.info(f"worker {idx} updating the params...") - inner_collector.update_policy_weights_(policy_weights=data_in) - pipe_child.send((j, "updated")) - has_timed_out = False - continue - - elif msg == "seed": + if msg == "seed": data_in, static_seed = data_in new_seed = inner_collector.set_seed(data_in, static_seed=static_seed) torch.manual_seed(data_in) diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index df35a782d75..494af927f4e 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -37,6 +37,7 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator +from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme SUBMITIT_ERR = None try: @@ -461,6 +462,7 @@ def __init__( weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ): if collector_class == "async": @@ -566,14 +568,61 @@ def __init__( self._init_workers() self._make_container() - if weight_updater is None: - weight_updater = DistributedWeightUpdater( - store=self._store, - policy_weights=self.policy_weights, - num_workers=self.num_workers, - sync=self._sync, + + # Set up weight synchronization - prefer new schemes over legacy updater + if weight_updater is None and weight_sync_schemes is None: + # Default to Distributed weight sync scheme for distributed collectors + from torchrl.weight_update.weight_sync_schemes import ( + DistributedWeightSyncScheme, ) - self.weight_updater = weight_updater + + weight_sync_schemes = { + "policy": DistributedWeightSyncScheme(backend=backend, sync=self._sync) + } + + if weight_sync_schemes is not None: + # Use new weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + + # Set up weight senders now that remote collectors exist + for model_id, scheme in self._weight_sync_schemes.items(): + sender = scheme.create_sender() + sender._model_id = model_id + + # Create transports for each remote collector + for i in range(self.num_workers): + rank = i + 1 # Workers are 1-indexed in distributed + transport = scheme.create_transport((self._store, rank)) + sender._transports[i] = transport + + # Set context and register model + if hasattr(sender, "set_context"): + sender.set_context(self, model_id) + + # Store reference to source model for automatic extraction + if ( + model_id == "policy" + and hasattr(self, "policy") + and self.policy is not None + ): + sender._source_model = self.policy + + self._weight_senders[model_id] = sender + + self.weight_updater = None + else: + # Fall back to legacy weight updater system + if weight_updater is None: + weight_updater = DistributedWeightUpdater( + store=self._store, + policy_weights=self.policy_weights, + num_workers=self.num_workers, + sync=self._sync, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} @property def device(self) -> 
list[torch.device]: @@ -928,6 +977,34 @@ def _next_async(self, total_frames, trackers): break return data, total_frames + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + For distributed collectors, when weights is None and we have a weight sync scheme, + extract fresh weights from the tracked policy model. + """ + scheme = ( + self._weight_sync_schemes.get(model_id) + if self._weight_sync_schemes + else None + ) + + if weights is None and scheme is not None: + # Extract fresh weights from the source model + sender = self._weight_senders.get(model_id) + if ( + sender + and hasattr(sender, "_source_model") + and sender._source_model is not None + ): + # For distributed collectors, we need TensorDict format for isend/irecv + from tensordict import TensorDict + + return TensorDict.from_module(sender._source_model).data.lock_() + + # Fall back to base class implementation + return super()._extract_weights_if_needed(weights, model_id) + def set_seed(self, seed: int, static_seed: bool = False) -> int: for i in range(self.num_workers): rank = i + 1 diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 97955c7c8b5..f2e00828dd3 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -28,6 +28,7 @@ from torchrl.data import ReplayBuffer from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator +from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme RAY_ERR = None try: @@ -258,11 +259,15 @@ class RayCollector(DataCollectorBase): .. note:: although it is not enfoced (to allow users to implement their own replay buffer class), a :class:`~torchrl.data.RayReplayBuffer` instance should be used here. - weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` + weight_updater (WeightUpdaterBase or constructor, optional): (Deprecated) An instance of :class:`~torchrl.collectors.WeightUpdaterBase` or its subclass, responsible for updating the policy weights on remote inference workers managed by Ray. If not provided, a :class:`~torchrl.collectors.RayWeightUpdater` will be used by default, leveraging Ray's distributed capabilities. Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary mapping model identifiers to + :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` instances. + This is the recommended way to configure weight synchronization. If not provided, + defaults to ``{"policy": RayWeightSyncScheme()}``. 
Examples: >>> from torch import nn @@ -326,6 +331,7 @@ def __init__( weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ): self.frames_per_batch = frames_per_batch if remote_configs is None: @@ -436,9 +442,9 @@ def check_list_length_consistency(*lists): if not isinstance(policy_factory, Sequence): policy_factory = [policy_factory] * len(create_env_fn) self.policy_factory = policy_factory - self._local_policy = policy - if isinstance(self._local_policy, nn.Module): - policy_weights = TensorDict.from_module(self._local_policy) + self.policy = policy # Store policy for weight extraction + if isinstance(policy, nn.Module): + policy_weights = TensorDict.from_module(policy) policy_weights = policy_weights.data.lock_() else: policy_weights = TensorDict(lock=True) @@ -476,7 +482,10 @@ def check_list_length_consistency(*lists): # update collector kwargs for i, collector_kwarg in enumerate(self.collector_kwargs): - collector_kwarg["policy_factory"] = policy_factory[i] + # Don't pass policy_factory if we have a policy - remote collectors need the policy object + # to be able to apply weight updates + if policy is None: + collector_kwarg["policy_factory"] = policy_factory[i] collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_collectors @@ -510,19 +519,84 @@ def check_list_length_consistency(*lists): collector_kwargs, remote_configs, ) - if weight_updater is None: - weight_updater = RayWeightUpdater( - policy_weights=policy_weights, - remote_collectors=self.remote_collectors, - max_interval=self.max_weight_update_interval, - ) - self.weight_updater = weight_updater + # Set up weight synchronization - prefer new schemes over legacy updater + if weight_updater is None and weight_sync_schemes is None: + # Default to Ray weight sync scheme for Ray collectors + from torchrl.weight_update.weight_sync_schemes import RayWeightSyncScheme - # Print info of all remote workers - pending_samples = [ - e.print_remote_collector_info.remote() for e in self.remote_collectors - ] - ray.wait(pending_samples) + weight_sync_schemes = {"policy": RayWeightSyncScheme()} + + if weight_sync_schemes is not None: + # Use new weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + + # Set up weight senders now that remote collectors exist + for model_id, scheme in self._weight_sync_schemes.items(): + sender = scheme.create_sender() + sender._model_id = model_id + + # Register each remote collector as a separate worker + # This follows the same pattern as multiprocess collectors + for worker_idx, remote_collector in enumerate(self.remote_collectors): + # Create a transport for this specific collector + # Pass the collector as context so the transport knows which one to talk to + sender.register_worker(worker_idx, remote_collector) + + # Set context and register model + if hasattr(sender, "set_context"): + sender.set_context(self, model_id) + + # Store reference to source model for automatic extraction + if model_id == "policy": + sender._source_model = self.policy + + self._weight_senders[model_id] = sender + + self.weight_updater = None # Don't use legacy system + else: + # Fall back to legacy weight updater system + if weight_updater is None: + weight_updater = RayWeightUpdater( + policy_weights=policy_weights, + remote_collectors=self.remote_collectors, + 
max_interval=self.max_weight_update_interval, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} + + # Print info of all remote workers (fire and forget - no need to wait) + for e in self.remote_collectors: + e.print_remote_collector_info.remote() + + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + For Ray collectors, when weights is None and we have a weight sync scheme, + extract fresh weights from the tracked policy model. + """ + scheme = ( + self._weight_sync_schemes.get(model_id) + if self._weight_sync_schemes + else None + ) + + if weights is None and scheme is not None: + # Extract fresh weights from the source model + sender = self._weight_senders.get(model_id) + if ( + sender + and hasattr(sender, "_source_model") + and sender._source_model is not None + ): + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + strategy = WeightStrategy(extract_as=scheme.strategy) + return strategy.extract_weights(sender._source_model) + + # Fall back to base class behavior + return super()._extract_weights_if_needed(weights, model_id) @property def num_workers(self): diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 5336b588232..9d2cf36c0cf 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -42,6 +42,7 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator +from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme SUBMITIT_ERR = None try: @@ -308,6 +309,7 @@ def __init__( weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ): if collector_class == "async": collector_class = MultiaSyncDataCollector @@ -413,15 +415,63 @@ def __init__( tensorpipe_options ) self._init() - if weight_updater is None: - weight_updater = RPCWeightUpdater( - collector_infos=self.collector_infos, - collector_class=self.collector_class, - collector_rrefs=self.collector_rrefs, - policy_weights=self.policy_weights, - num_workers=self.num_workers, - ) - self.weight_updater = weight_updater + + # Set up weight synchronization - prefer new schemes over legacy updater + if weight_updater is None and weight_sync_schemes is None: + # Default to RPC weight sync scheme for RPC collectors + from torchrl.weight_update.weight_sync_schemes import RPCWeightSyncScheme + + weight_sync_schemes = {"policy": RPCWeightSyncScheme()} + + if weight_sync_schemes is not None: + # Use new weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + + # Set up weight senders now that remote collectors exist + for model_id, scheme in self._weight_sync_schemes.items(): + sender = scheme.create_sender() + sender._model_id = model_id + + # Create transports for each remote collector + for i in range(self.num_workers): + transport = scheme.create_transport( + ( + self.collector_infos[i], + self.collector_rrefs[i], + self.collector_class, + ) + ) + sender._transports[i] = transport + + # Set context and register model + if hasattr(sender, "set_context"): + sender.set_context(self, model_id) + + # Store reference to source model for automatic extraction + if ( + model_id == "policy" + and hasattr(self, "policy") + and self.policy is not None + ): + sender._source_model = 
self.policy + + self._weight_senders[model_id] = sender + + self.weight_updater = None + else: + # Fall back to legacy weight updater system + if weight_updater is None: + weight_updater = RPCWeightUpdater( + collector_infos=self.collector_infos, + collector_class=self.collector_class, + collector_rrefs=self.collector_rrefs, + policy_weights=self.policy_weights, + num_workers=self.num_workers, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} @property def device(self) -> list[torch.device]: @@ -764,6 +814,34 @@ def _next_sync_rpc(self): self._collected_frames += data.numel() return data + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + For RPC collectors, when weights is None and we have a weight sync scheme, + extract fresh weights from the tracked policy model. + """ + scheme = ( + self._weight_sync_schemes.get(model_id) + if self._weight_sync_schemes + else None + ) + + if weights is None and scheme is not None: + # Extract fresh weights from the source model + sender = self._weight_senders.get(model_id) + if ( + sender + and hasattr(sender, "_source_model") + and sender._source_model is not None + ): + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + strategy = WeightStrategy(extract_as=scheme.strategy) + return strategy.extract_weights(sender._source_model) + + # Fall back to base class implementation + return super()._extract_weights_if_needed(weights, model_id) + def set_seed(self, seed: int, static_seed: bool = False) -> int: for worker in self.collector_infos: seed = rpc.rpc_sync(worker, self.collector_class.set_seed, args=(seed,))