@@ -35,11 +35,11 @@ class LLMCollector(SyncDataCollector):
3535
3636 .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
3737
38- steps_per_batch (int): A keyword-only argument representing the total
39- number of elements in a batch; -1 is never ending (until shutdown) .
40- total_steps (int): A keyword-only argument representing the total
41- number of steps returned by the collector
42- during its lifespan .
38+ dialog_turns_per_batch (int, optional): A keyword-only argument representing the total
39+ number of elements in a batch. It is always required except when `yield_completed_trajectories=True`.
40+ total_dialog_turns (int): A keyword-only argument representing the total
41+ number of steps returned by the collector during its lifespan. A value of -1 means the collector never ends (until shutdown).
42+ Defaults to -1.
4343 yield_completed_trajectories (bool, optional): whether to yield batches of rollouts with a given number of steps
4444 (`yield_completed_trajectories=False`, default) or single, completed trajectories
4545 (`yield_completed_trajectories=True`).
@@ -149,7 +149,7 @@ def __init__(
149149 policy : Callable [[TensorDictBase ], TensorDictBase ] | None = None ,
150150 policy_factory : Callable [[], Callable [[TensorDictBase ], TensorDictBase ]]
151151 | None = None ,
152- dialog_turns_per_batch : int ,
152+ dialog_turns_per_batch : int | None = None ,
153153 yield_only_last_steps : bool | None = None ,
154154 yield_completed_trajectories : bool | None = None ,
155155 postproc : Callable [[TensorDictBase ], TensorDictBase ] | None = None ,
@@ -172,6 +172,8 @@ def __init__(
172172 elif queue is not None :
173173 # disguise the queue as a replay buffer
174174 replay_buffer = _QueueAsRB (queue )
175+ if dialog_turns_per_batch is None and yield_completed_trajectories :
176+ dialog_turns_per_batch = 0
175177 super ().__init__ (
176178 create_env_fn = env ,
177179 policy = policy ,
0 commit comments