Skip to content

Commit 8374bb4

Browse files
authored
[BugFix] Fix wrong assertion about collector and buffer (#3176)
1 parent 5f1eb2c commit 8374bb4

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torchrl/trainers/algorithms/configs/trainers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer:
101101
elif replay_buffer is not None:
102102
collector = collector(replay_buffer=replay_buffer)
103103
elif getattr(collector, "replay_buffer", None) is None:
104-
if collector.replay_buffer is None or replay_buffer is None:
104+
if async_collection and (
105+
collector.replay_buffer is None or replay_buffer is None
106+
):
105107
raise ValueError(
106108
"replay_buffer must be provided when async_collection is True"
107109
)
@@ -230,7 +232,7 @@ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
230232
collector = collector()
231233
else:
232234
collector = collector(replay_buffer=replay_buffer)
233-
elif getattr(collector, "replay_buffer", None) is None:
235+
elif async_collection and getattr(collector, "replay_buffer", None) is None:
234236
raise RuntimeError(
235237
"replay_buffer must be provided when async_collection is True"
236238
)

0 commit comments

Comments
 (0)