@@ -5,10 +5,10 @@
 from __future__ import annotations
 
 import argparse
+import asyncio
 import contextlib
 import functools
 import gc
-
 import importlib
 import os
 import subprocess
@@ -52,7 +52,7 @@
     MultiSyncDataCollector,
 )
 
-from torchrl.collectors.llm_collector import LLMCollector
+from torchrl.collectors.llm import LLMCollector
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
     Composite,
@@ -3391,11 +3391,11 @@ def test_collector_rb_sync(self):
         assert assert_allclose_td(rbdata0, rbdata1)
 
     @pytest.mark.skipif(not _has_gym, reason="requires gym.")
-    @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+    @pytest.mark.parametrize("extend_buffer", [False, True])
     @pytest.mark.parametrize("env_creator", [False, True])
     @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
     def test_collector_rb_multisync(
-        self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+        self, extend_buffer, env_creator, storagetype, tmpdir
     ):
         if not env_creator:
             env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
@@ -3420,7 +3420,7 @@ def test_collector_rb_multisync(
             replay_buffer=rb,
             total_frames=256,
             frames_per_batch=32,
-            replay_buffer_chunk=replay_buffer_chunk,
+            extend_buffer=extend_buffer,
         )
         torch.manual_seed(0)
         pred_len = 0
@@ -3430,7 +3430,7 @@ def test_collector_rb_multisync(
             assert len(rb) == pred_len
         collector.shutdown()
         assert len(rb) == 256
-        if not replay_buffer_chunk:
+        if not extend_buffer:
             steps_counts = rb["step_count"].squeeze().split(16)
             collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
             for step_count, ids in zip(steps_counts, collector_ids):
@@ -3442,11 +3442,11 @@ def test_collector_rb_multisync(
                 assert (idsdiff >= 0).all()
 
     @pytest.mark.skipif(not _has_gym, reason="requires gym.")
-    @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+    @pytest.mark.parametrize("extend_buffer", [False, True])
     @pytest.mark.parametrize("env_creator", [False, True])
     @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
     def test_collector_rb_multiasync(
-        self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+        self, extend_buffer, env_creator, storagetype, tmpdir
     ):
         if not env_creator:
             env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
@@ -3471,7 +3471,7 @@ def test_collector_rb_multiasync(
             replay_buffer=rb,
             total_frames=256,
             frames_per_batch=16,
-            replay_buffer_chunk=replay_buffer_chunk,
+            extend_buffer=extend_buffer,
         )
         torch.manual_seed(0)
         pred_len = 0
@@ -3481,7 +3481,7 @@ def test_collector_rb_multiasync(
             assert len(rb) >= pred_len
         collector.shutdown()
         assert len(rb) == 256
-        if not replay_buffer_chunk:
+        if not extend_buffer:
             steps_counts = rb["step_count"].squeeze().split(16)
             collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
             for step_count, ids in zip(steps_counts, collector_ids):
@@ -3575,6 +3575,18 @@ def vllm_instance(self):
         tokenizer.pad_token = tokenizer.eos_token
         return llm_model
 
+    @pytest.fixture(scope="module")
+    def vllm_instance_opt(self):
+        try:
+            import vllm
+        except ImportError:
+            pytest.skip(reason="missing vllm")
+
+        llm_model = vllm.LLM("facebook/opt-125m")
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
     @pytest.fixture(scope="module")
     def transformers_instance(self):
         from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
@@ -3618,12 +3630,11 @@ def test_llm_collector_with_transformers(
         self._run_collector_test(total_steps, rb, policy, tokenizer)
 
     def _run_collector_test(self, total_steps, rb, policy, tokenizer):
-        bsz = 1
+        bsz = 4
         dataloader = DummyStrDataLoader(bsz)
 
         env = LLMEnv.from_dataloader(
             dataloader=dataloader,
-            tokenizer=tokenizer,
             str2str=True,
             batch_size=bsz,
             group_repeats=True,
@@ -3650,15 +3661,142 @@ def _run_collector_test(self, total_steps, rb, policy, tokenizer):
 
         if rb is not None:
             # Now check the buffer
-            assert len(rb) == total_steps
-            sample = rb.sample(1)
+            assert len(rb) >= total_steps
+            sample = rb.sample(4)
+            assert sample.shape == (4,)
+            assert not sample._has_exclusive_keys
             # Should match length
-            assert len(sample["text"]) == 1
+            assert len(sample["text"]) == 4
+            # assert len(sample["text"][0]) == 10, sample["text"][0]
             # Should be non-empty
             assert sample["text_response"] is not None
+            for i in range(4):
+                # Check that there are more chars in the next step
+                assert len(sample["text"][i]) < len(sample["next", "text"][i])
         else:
             stack = torch.cat(stack)
-            assert stack.numel() == total_steps
+            assert not stack._has_exclusive_keys
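+            # Descriptive note: -(total_steps // -4) is ceiling division, so with a batch
+            # size of 4 the frame count is rounded up to the next multiple of 4 (at least 4)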
+            assert stack.numel() == max(-(total_steps // -4) * 4, 4)
+            stack = stack.view(-1)
+            for i in range(stack.numel()):
+                # Check that there are more chars in the next step
+                assert len(stack["text"][i]) < len(stack["next", "text"][i])
+        assert collector._frames >= total_steps
+
+    def test_llm_collector_start(self, vllm_instance):
+        asyncio.run(self._async_run_collector_test(vllm_instance))
+
+    async def _async_run_collector_test(self, vllm_instance):
+        total_steps = 20
+        policy = vLLMWrapper(vllm_instance)
+        vllm_instance.get_tokenizer()
+        bsz = 4
+        dataloader = DummyStrDataLoader(bsz)
+
+        env = LLMEnv.from_dataloader(
+            dataloader=dataloader,
+            str2str=True,
+            batch_size=bsz,
+            group_repeats=True,
+        )
+
+        rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
+        collector = LLMCollector(
+            env=env,
+            policy_factory=lambda: policy,
+            steps_per_batch=env.batch_size[0],
+            replay_buffer=rb,
+            total_steps=total_steps,
+        )
+        collector.start()
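+        # start() runs collection in the background; the loop below polls the replay buffer as it fills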
+
+        i = 0
+        wait = 0
+        while True:
+            while not len(rb):
+                await asyncio.sleep(1)  # Use asyncio.sleep instead of time.sleep
+                wait += 1
+                if wait > 20:
+                    raise RuntimeError
+            sample = rb.sample(10)
+            for i in range(sample.numel()):
+                # Check that there are more chars in the next step
+                assert len(sample["text"][i]) < len(sample["next", "text"][i])
+            assert not sample._has_exclusive_keys, sample
+            await asyncio.sleep(0.1)  # Use asyncio.sleep instead of time.sleep
+            i += 1
+            if i == 5:
+                break
+        assert collector._frames >= total_steps
+
+        await collector.async_shutdown()
+
+    @pytest.mark.slow
+    @pytest.mark.parametrize("rb", [False, True])
+    @pytest.mark.parametrize("yield_only_last_steps", [False, True])
+    def test_llm_collector_completed(
+        self, vllm_instance_opt, rb, yield_only_last_steps
+    ):
+        policy = vLLMWrapper(vllm_instance_opt)
+        tokenizer = vllm_instance_opt.get_tokenizer()
+        bsz = 4
+        total_steps = 20
+        dataloader = DummyStrDataLoader(bsz)
+
+        env = LLMEnv.from_dataloader(
+            dataloader=dataloader,
+            str2str=True,
+            batch_size=bsz,
+            group_repeats=True,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+        # To make sure the env breaks at some point
+        env = env.append_transform(StepCounter(max_steps=100))
+
+        if rb:
+            rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
+        else:
+            rb = None
+        collector = LLMCollector(
+            env=env,
+            policy_factory=lambda: policy,
+            steps_per_batch=env.batch_size[0],
+            replay_buffer=rb,
+            total_steps=total_steps,
+            yield_completed_trajectories=True,
+            yield_only_last_steps=yield_only_last_steps,
+        )
+        assert collector.yield_completed_trajectories
+        assert collector.yield_only_last_steps is yield_only_last_steps
+
+        cur_total_steps = 0
+        has_found_one_with_more_steps = False
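+        # With a replay buffer the collector writes completed trajectories to the buffer and
+        # yields None; without one, it yields the completed trajectories directly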
+        for data in collector:
+            if rb is None:
+                assert data.ndim == 1
+                assert (data["next", "step_count"] < 99).all()
+                cur_total_steps += data.numel()
+                for i in range(data.numel()):
+                    # Check that there are more chars in the next step
+                    assert len(data["text"][i]) < len(data["next", "text"][i])
+                if yield_only_last_steps:
+                    assert data.shape == (1,)
+                else:
+                    has_found_one_with_more_steps |= data.numel() > 1
+            else:
+                assert data is None
+                sample = rb.sample(5)
+                for i in range(sample.numel()):
+                    # Check that there are more chars in the next step
+                    assert len(sample["text"][i]) < len(sample["next", "text"][i])
+                assert sample.ndim == 1
+                assert sample.shape == (5,)
+                assert (sample["next", "step_count"] < 99).all()
+                cur_total_steps += 1
+            assert collector._frames >= cur_total_steps
+        if rb is None and not yield_only_last_steps:
+            assert has_found_one_with_more_steps
+        assert collector._frames >= total_steps
 
 
 if __name__ == "__main__":