55
66from __future__ import annotations
77
8+ import asyncio
89import warnings
910from typing import Callable , Iterator , OrderedDict
1011
2122 SyncDataCollector ,
2223)
2324from torchrl .collectors .utils import _NON_NN_POLICY_WEIGHTS , split_trajectories
25+ from torchrl .data import ReplayBuffer
2426from torchrl .envs .common import EnvBase
2527from torchrl .envs .env_creator import EnvCreator
2628
@@ -256,6 +258,11 @@ class RayCollector(DataCollectorBase):
256258 parameters being updated for a certain time even if ``update_after_each_batch``
257259 is turned on.
258260 Defaults to -1 (no forced update).
261+ replay_buffer (RayReplayBuffer, optional): if provided, the collector will not yield tensordicts
262+ but populate the buffer instead. Defaults to ``None``.
263+
264+ .. note:: although it is not enforced (to allow users to implement their own replay buffer class), a
265+ :class:`~torchrl.data.RayReplayBuffer` instance should be used here.
259266
260267 Examples:
261268 >>> from torch import nn
@@ -312,7 +319,9 @@ def __init__(
312319 num_collectors : int = None ,
313320 update_after_each_batch = False ,
314321 max_weight_update_interval = - 1 ,
322+ replay_buffer : ReplayBuffer = None ,
315323 ):
324+ self .frames_per_batch = frames_per_batch
316325 if remote_configs is None :
317326 remote_configs = DEFAULT_REMOTE_CLASS_CONFIG
318327
@@ -321,6 +330,14 @@ def __init__(
321330
322331 if collector_kwargs is None :
323332 collector_kwargs = {}
333+ if replay_buffer is not None :
334+ if isinstance (collector_kwargs , dict ):
335+ collector_kwargs .setdefault ("replay_buffer" , replay_buffer )
336+ else :
337+ collector_kwargs = [
338+ ck .setdefault ("replay_buffer" , replay_buffer )
339+ for ck in collector_kwargs
340+ ]
324341
325342 # Make sure input parameters are consistent
326343 def check_consistency_with_num_collectors (param , param_name , num_collectors ):
@@ -386,7 +403,8 @@ def check_list_length_consistency(*lists):
386403 raise RuntimeError (
387404 "ray library not found, unable to create a DistributedCollector. "
388405 ) from RAY_ERR
389- ray .init (** ray_init_config )
406+ if not ray .is_initialized ():
407+ ray .init (** ray_init_config )
390408 if not ray .is_initialized ():
391409 raise RuntimeError ("Ray could not be initialized." )
392410
@@ -400,6 +418,7 @@ def check_list_length_consistency(*lists):
400418 collector_class .as_remote = as_remote
401419 collector_class .print_remote_collector_info = print_remote_collector_info
402420
421+ self .replay_buffer = replay_buffer
403422 self ._local_policy = policy
404423 if isinstance (self ._local_policy , nn .Module ):
405424 policy_weights = TensorDict .from_module (self ._local_policy )
@@ -557,7 +576,7 @@ def add_collectors(
557576 policy ,
558577 other_params ,
559578 )
560- self ._remote_collectors .extend ([ collector ] )
579+ self ._remote_collectors .append ( collector )
561580
562581 def local_policy (self ):
563582 """Returns local collector."""
@@ -577,17 +596,33 @@ def stop_remote_collectors(self):
577596 ) # This will interrupt any running tasks on the actor, causing them to fail immediately
578597
579598 def iterator (self ):
599+ def proc (data ):
600+ if self .split_trajs :
601+ data = split_trajectories (data )
602+ if self .postproc is not None :
603+ data = self .postproc (data )
604+ return data
605+
580606 if self ._sync :
581- data = self ._sync_iterator ()
607+ meth = self ._sync_iterator
582608 else :
583- data = self ._async_iterator ()
609+ meth = self ._async_iterator
610+ yield from (proc (data ) for data in meth ())
584611
585- if self .split_trajs :
586- data = split_trajectories (data )
587- if self .postproc is not None :
588- data = self .postproc (data )
612+ async def _asyncio_iterator (self ):
613+ def proc (data ):
614+ if self .split_trajs :
615+ data = split_trajectories (data )
616+ if self .postproc is not None :
617+ data = self .postproc (data )
618+ return data
589619
590- return data
620+ if self ._sync :
621+ for d in self ._sync_iterator ():
622+ yield proc (d )
623+ else :
624+ for d in self ._async_iterator ():
625+ yield proc (d )
591626
592627 def _sync_iterator (self ) -> Iterator [TensorDictBase ]:
593628 """Collects one data batch per remote collector in each iteration."""
@@ -634,7 +669,30 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]:
634669 ):
635670 self .update_policy_weights_ (rank )
636671
637- self .shutdown ()
672+ if self ._task is None :
673+ self .shutdown ()
674+
675+ _task = None
676+
677+ def start (self ):
678+ """Starts the RayCollector."""
679+ if self .replay_buffer is None :
680+ raise RuntimeError ("Replay buffer must be defined for asyncio execution." )
681+ if self ._task is None or self ._task .done ():
682+ loop = asyncio .get_event_loop ()
683+ self ._task = loop .create_task (self ._run_iterator_silently ())
684+
685+ async def _run_iterator_silently (self ):
686+ async for _ in self ._asyncio_iterator ():
687+ # Process each item silently
688+ continue
689+
690+ async def async_shutdown (self ):
691+ """Finishes processes started by ray.init() during async execution."""
692+ if self ._task is not None :
693+ await self ._task
694+ self .stop_remote_collectors ()
695+ ray .shutdown ()
638696
639697 def _async_iterator (self ) -> Iterator [TensorDictBase ]:
640698 """Collects a data batch from a single remote collector in each iteration."""
@@ -658,7 +716,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
658716 ray .internal .free (
659717 [future ]
660718 ) # should not be necessary, deleted automatically when ref count is down to 0
661- self .collected_frames += out_td . numel ()
719+ self .collected_frames += self . frames_per_batch
662720
663721 yield out_td
664722
@@ -689,8 +747,8 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
689747 # object_ref=ref,
690748 # force=False,
691749 # )
692-
693- self .shutdown ()
750+ if self . _task is None :
751+ self .shutdown ()
694752
695753 def update_policy_weights_ (self , worker_rank = None ) -> None :
696754 """Updates the weights of the worker nodes.
0 commit comments