@@ -4644,11 +4644,13 @@ def __next__(self):
46444644 @pytest .mark .parametrize ("batch_size" , [0 , 4 ])
46454645 @pytest .mark .parametrize ("device" , [None , "cpu" ])
46464646 def test_llm_env (self , str2str , batched , stack_method , device , batch_size ):
4647- env = LLMEnv (str2str = str2str , device = device )
4647+ env = LLMEnv (
4648+ str2str = str2str , device = device , has_attention = False , no_stack = False
4649+ )
46484650 if str2str :
46494651 primer = DataLoadingPrimer (
46504652 dataloader = self .DummyDataLoader (batch_size = batch_size ),
4651- data_keys = ["observation" ],
4653+ data_keys = [LLMEnv . _DEFAULT_STR_KEY ],
46524654 example_data = "a string!" ,
46534655 )
46544656 else :
@@ -4658,7 +4660,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
46584660 dataloader = self .DummyTensorDataLoader (
46594661 batch_size = batch_size , padding = True
46604662 ),
4661- data_keys = ["observation" ],
4663+ data_keys = [LLMEnv . _DEFAULT_TOKEN_KEY ],
46624664 data_specs = [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
46634665 stack_method = stack_method ,
46644666 )
@@ -4668,7 +4670,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
46684670 if batched :
46694671 td = env .reset (TensorDict (batch_size = [3 ]))
46704672 env .check_env_specs (break_when_any_done = "both" , tensordict = td )
4671- r = env .rollout (10 , tensordict = TensorDict (batch_size = [3 ]))
4673+ env .rollout (10 , tensordict = TensorDict (batch_size = [3 ]))
46724674 else :
46734675 env .check_env_specs (break_when_any_done = "both" )
46744676
@@ -4691,7 +4693,7 @@ def test_llm_from_dataloader(
46914693 if str2str :
46924694 kwargs = {
46934695 "dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4694- "data_keys" : ["observation" ],
4696+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
46954697 "example_data" : "a string!" ,
46964698 }
46974699 else :
@@ -4701,11 +4703,18 @@ def test_llm_from_dataloader(
47014703 "dataloader" : self .DummyTensorDataLoader (
47024704 padding = True , batch_size = batch_size
47034705 ),
4704- "data_keys" : ["observation" ],
4706+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
47054707 "data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
47064708 "stack_method" : stack_method ,
47074709 }
4708- kwargs .update ({"str2str" : str2str , "device" : device })
4710+ kwargs .update (
4711+ {
4712+ "str2str" : str2str ,
4713+ "device" : device ,
4714+ "has_attention" : False ,
4715+ "no_stack" : False ,
4716+ }
4717+ )
47094718 env = LLMEnv .from_dataloader (** kwargs )
47104719 assert not env .batch_locked
47114720 if batched :
@@ -4718,46 +4727,64 @@ def test_llm_from_dataloader(
47184727 def policy (td ):
47194728 if str2str :
47204729 if not td .shape :
4721- td ["action" ] = "<nothing>"
4730+ td [LLMEnv . _DEFAULT_ACTION_STR_KEY ] = "<nothing>"
47224731 else :
4723- td ["action" ] = NonTensorStack (
4732+ td [LLMEnv . _DEFAULT_ACTION_STR_KEY ] = NonTensorStack (
47244733 * ["<nothing>" for _ in range (td .shape [0 ])]
47254734 )
47264735 else :
4727- td ["action" ] = torch .ones (td .shape + (1 ,), dtype = torch .int64 )
4736+ td [LLMEnv ._DEFAULT_ACTION_TOKENS_KEY ] = torch .ones (
4737+ td .shape + (1 ,), dtype = torch .int64
4738+ )
47284739 return td
47294740
47304741 if batched :
47314742 # Tell the env that we want 3 sub-envs
47324743 r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = [3 ]))
47334744 assert r .ndim == 2
47344745 if str2str :
4735- assert isinstance (r [0 , 0 ]["observation" ], str )
4736- assert isinstance (r [0 , 1 ]["observation" ], str )
4746+ assert isinstance (r [0 , 0 ][LLMEnv . _DEFAULT_STR_KEY ], str )
4747+ assert isinstance (r [0 , 1 ][LLMEnv . _DEFAULT_STR_KEY ], str )
47374748 assert (
4738- r [0 , 0 ]["observation" ]
4739- == r [0 , 1 ]["observation" ][: - len (r [0 , 0 ]["action" ])]
4749+ r [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4750+ == r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4751+ : - len (r [0 , 0 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4752+ ]
47404753 )
47414754 assert (
4742- r [0 , 1 ]["observation" ]
4743- == r [0 , 2 ]["observation" ][: - len (r [0 , 1 ]["action" ])]
4755+ r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4756+ == r [0 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4757+ : - len (r [0 , 1 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4758+ ]
47444759 )
47454760 assert (
4746- r [- 1 , 0 ]["observation" ]
4747- == r [- 1 , 1 ]["observation" ][: - len (r [- 1 , 0 ]["action" ])]
4761+ r [- 1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4762+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4763+ : - len (r [- 1 , 0 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4764+ ]
47484765 )
47494766 assert (
4750- r [- 1 , 1 ]["observation" ]
4751- == r [- 1 , 2 ]["observation" ][: - len (r [- 1 , 1 ]["action" ])]
4767+ r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4768+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4769+ : - len (r [- 1 , 1 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4770+ ]
47524771 )
47534772 else :
4754- assert (r [0 , 0 ]["observation" ] == r [0 , 1 ]["observation" ][:- 1 ]).all ()
4755- assert (r [0 , 1 ]["observation" ] == r [0 , 2 ]["observation" ][:- 1 ]).all ()
47564773 assert (
4757- r [- 1 , 0 ]["observation" ] == r [- 1 , 1 ]["observation" ][:- 1 ]
4774+ r [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4775+ == r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4776+ ).all ()
4777+ assert (
4778+ r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4779+ == r [0 , 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
47584780 ).all ()
47594781 assert (
4760- r [- 1 , 1 ]["observation" ] == r [- 1 , 2 ]["observation" ][:- 1 ]
4782+ r [- 1 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4783+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4784+ ).all ()
4785+ assert (
4786+ r [- 1 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4787+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
47614788 ).all ()
47624789 else :
47634790 r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = []))
@@ -4783,7 +4810,7 @@ def test_llm_from_dataloader_repeats(
47834810 if str2str :
47844811 kwargs = {
47854812 "dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4786- "data_keys" : ["observation" ],
4813+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
47874814 "example_data" : "a string!" ,
47884815 "repeats" : repeats ,
47894816 }
@@ -4794,12 +4821,19 @@ def test_llm_from_dataloader_repeats(
47944821 "dataloader" : self .DummyTensorDataLoader (
47954822 padding = True , batch_size = batch_size
47964823 ),
4797- "data_keys" : ["observation" ],
4824+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
47984825 "data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
47994826 "stack_method" : stack_method ,
48004827 "repeats" : repeats ,
48014828 }
4802- kwargs .update ({"str2str" : str2str , "device" : device })
4829+ kwargs .update (
4830+ {
4831+ "str2str" : str2str ,
4832+ "device" : device ,
4833+ "has_attention" : False ,
4834+ "no_stack" : False ,
4835+ }
4836+ )
48034837 env = LLMEnv .from_dataloader (** kwargs )
48044838 assert env .transform .repeats == repeats
48054839
@@ -4809,13 +4843,15 @@ def test_llm_from_dataloader_repeats(
48094843 def policy (td ):
48104844 if str2str :
48114845 if not td .shape :
4812- td ["action" ] = "<nothing>"
4846+ td [LLMEnv . _DEFAULT_ACTION_STR_KEY ] = "<nothing>"
48134847 else :
4814- td ["action" ] = NonTensorStack (
4848+ td [LLMEnv . _DEFAULT_ACTION_STR_KEY ] = NonTensorStack (
48154849 * ["<nothing>" for _ in range (td .shape [0 ])]
48164850 )
48174851 else :
4818- td ["action" ] = torch .ones (td .shape + (1 ,), dtype = torch .int64 )
4852+ td [LLMEnv ._DEFAULT_ACTION_TOKENS_KEY ] = torch .ones (
4853+ td .shape + (1 ,), dtype = torch .int64
4854+ )
48194855 return td
48204856
48214857 if batched :
@@ -4831,34 +4867,58 @@ def policy(td):
48314867 r_reset = r [..., ::max_steps ]
48324868 if not batched :
48334869 if str2str :
4834- assert r_reset [..., 0 ]["observation" ] == r_reset [..., 1 ]["observation" ]
4835- assert r_reset [..., 0 ]["observation" ] == r_reset [..., 2 ]["observation" ]
4836- assert r_reset [..., 0 ]["observation" ] != r_reset [..., 3 ]["observation" ]
4870+ assert (
4871+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4872+ == r_reset [..., 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4873+ )
4874+ assert (
4875+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4876+ == r_reset [..., 2 ][LLMEnv ._DEFAULT_STR_KEY ]
4877+ )
4878+ assert (
4879+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4880+ != r_reset [..., 3 ][LLMEnv ._DEFAULT_STR_KEY ]
4881+ )
48374882 else :
48384883 assert (
4839- r_reset [..., 0 ]["observation" ] == r_reset [..., 1 ]["observation" ]
4884+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4885+ == r_reset [..., 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48404886 ).all ()
48414887 assert (
4842- r_reset [..., 0 ]["observation" ] == r_reset [..., 2 ]["observation" ]
4888+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4889+ == r_reset [..., 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48434890 ).all ()
48444891 assert (
4845- r_reset [..., 0 ]["observation" ] != r_reset [..., 3 ]["observation" ]
4892+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4893+ != r_reset [..., 3 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48464894 ).any ()
48474895 else :
48484896 # When batched, each block contains the 3 reset packs
48494897 if str2str :
4850- assert r_reset [0 , 0 ]["observation" ] == r_reset [1 , 0 ]["observation" ]
4851- assert r_reset [0 , 0 ]["observation" ] == r_reset [2 , 0 ]["observation" ]
4852- assert r_reset [0 , 0 ]["observation" ] != r_reset [0 , 1 ]["observation" ]
4898+ assert (
4899+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4900+ == r_reset [1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4901+ )
4902+ assert (
4903+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4904+ == r_reset [2 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4905+ )
4906+ assert (
4907+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4908+ != r_reset [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4909+ )
48534910 else :
48544911 assert (
4855- r_reset [0 , 0 ]["observation" ] == r_reset [1 , 0 ]["observation" ]
4912+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4913+ == r_reset [1 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48564914 ).all ()
48574915 assert (
4858- r_reset [0 , 0 ]["observation" ] == r_reset [2 , 0 ]["observation" ]
4916+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4917+ == r_reset [2 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48594918 ).all ()
48604919 assert (
4861- r_reset [0 , 0 ]["observation" ] != r_reset [0 , 1 ]["observation" ]
4920+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4921+ != r_reset [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48624922 ).any ()
48634923
48644924 @pytest .mark .parametrize (
@@ -4892,7 +4952,7 @@ def test_done_and_reward(
48924952 if str2str :
48934953 kwargs = {
48944954 "dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4895- "data_keys" : ["observation" ],
4955+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
48964956 "example_data" : "a string!" ,
48974957 "repeats" : repeats ,
48984958 "assign_reward" : assign_reward ,
@@ -4905,20 +4965,27 @@ def test_done_and_reward(
49054965 "dataloader" : self .DummyTensorDataLoader (
49064966 padding = True , batch_size = batch_size
49074967 ),
4908- "data_keys" : ["observation" ],
4968+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
49094969 "data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
49104970 "stack_method" : stack_method ,
49114971 "repeats" : repeats ,
49124972 "assign_reward" : assign_reward ,
49134973 "assign_done" : assign_done ,
49144974 }
4915- kwargs .update ({"str2str" : str2str , "device" : device })
4975+ kwargs .update (
4976+ {
4977+ "str2str" : str2str ,
4978+ "device" : device ,
4979+ "has_attention" : False ,
4980+ "no_stack" : False ,
4981+ }
4982+ )
49164983 env = LLMEnv .from_dataloader (** kwargs )
49174984 # We want to make sure that transforms that rely on the done state work appropriately
49184985 env .append_transform (StepCounter (max_steps = 10 ))
49194986
49204987 def policy (td ):
4921- td ["action" ] = torch .ones (
4988+ td [LLMEnv . _DEFAULT_ACTION_TOKENS_KEY ] = torch .ones (
49224989 td .shape + (torch .randint (10 , (1 ,)).item (),), dtype = torch .int64
49234990 )
49244991 return td
0 commit comments