1414import re
1515import string
1616from collections import defaultdict
17+ from contextlib import nullcontext
1718from functools import partial
1819from sys import platform
1920from typing import Any , Optional
3334 TensorDictBase ,
3435)
3536from tensordict .nn import TensorDictModuleBase
36- from tensordict .tensorclass import NonTensorStack , TensorClass
37+ from tensordict .tensorclass import NonTensorData , NonTensorStack , TensorClass
3738from tensordict .utils import _unravel_key_to_tuple
3839from torch import nn
3940
@@ -4630,6 +4631,7 @@ def __next__(self):
46304631 else :
46314632 return tensors
46324633
4634+ @pytest .mark .skipif (not _has_transformers , reason = "test requires transformers" )
46334635 @pytest .mark .parametrize (
46344636 "str2str,stack_method" ,
46354637 [
@@ -4674,22 +4676,36 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
46744676 else :
46754677 env .check_env_specs (break_when_any_done = "both" )
46764678
4679+ @pytest .mark .skipif (not _has_transformers , reason = "test requires transformers" )
4680+ @pytest .mark .parametrize ("tokenizer" , [True , False ])
46774681 @pytest .mark .parametrize (
4678- "str2str,stack_method" ,
4682+ "str2str,no_stack, stack_method" ,
46794683 [
4680- [True , None ],
4681- [False , "as_padded_tensor" ],
4682- # TODO: a bit experimental, fails with check_env_specs
4683- # [False, "as_nested_tensor"],
4684- [False , None ],
4684+ [True , True , None ],
4685+ [True , False , None ],
4686+ [False , False , "as_padded_tensor" ],
4687+ [False , False , None ],
46854688 ],
46864689 )
46874690 @pytest .mark .parametrize ("batched" , [True , False ])
46884691 @pytest .mark .parametrize ("device" , [None , "cpu" ])
46894692 @pytest .mark .parametrize ("batch_size" , [0 , 4 ])
46904693 def test_llm_from_dataloader (
4691- self , str2str , batched , stack_method , device , batch_size
4694+ self ,
4695+ str2str ,
4696+ batched ,
4697+ stack_method ,
4698+ device ,
4699+ batch_size ,
4700+ tokenizer ,
4701+ no_stack ,
46924702 ):
4703+ from transformers import AutoTokenizer
4704+
4705+ if tokenizer :
4706+ tokenizer = AutoTokenizer .from_pretrained ("bert-base-uncased" )
4707+ else :
4708+ tokenizer = None
46934709 if str2str :
46944710 kwargs = {
46954711 "dataloader" : self .DummyDataLoader (batch_size = batch_size ),
@@ -4712,7 +4728,8 @@ def test_llm_from_dataloader(
47124728 "str2str" : str2str ,
47134729 "device" : device ,
47144730 "has_attention" : False ,
4715- "no_stack" : False ,
4731+ "no_stack" : no_stack ,
4732+ "tokenizer" : tokenizer ,
47164733 }
47174734 )
47184735 env = LLMEnv .from_dataloader (** kwargs )
@@ -4725,12 +4742,17 @@ def test_llm_from_dataloader(
47254742 if batch_size > 0 :
47264743
47274744 def policy (td ):
4728- if str2str :
4745+ if str2str and tokenizer is None :
47294746 if not td .shape :
4730- td [LLMEnv ._DEFAULT_ACTION_STR_KEY ] = "<nothing>"
4747+ td [LLMEnv ._DEFAULT_ACTION_STR_KEY ] = NonTensorData (
4748+ "<nothing>" , device = device
4749+ )
47314750 else :
47324751 td [LLMEnv ._DEFAULT_ACTION_STR_KEY ] = NonTensorStack (
4733- * ["<nothing>" for _ in range (td .shape [0 ])]
4752+ * [
4753+ NonTensorData ("<nothing>" , device = device )
4754+ for _ in range (td .shape [0 ])
4755+ ]
47344756 )
47354757 else :
47364758 td [LLMEnv ._DEFAULT_ACTION_TOKENS_KEY ] = torch .ones (
@@ -4742,34 +4764,48 @@ def policy(td):
47424764 # Tell the env that we want 3 sub-envs
47434765 r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = [3 ]))
47444766 assert r .ndim == 2
4745- if str2str :
4767+ if str2str and tokenizer is None :
47464768 assert isinstance (r [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ], str )
47474769 assert isinstance (r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ], str )
4748- assert (
4749- r [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4750- == r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4751- : - len (r [0 , 0 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4752- ]
4753- )
4754- assert (
4755- r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4756- == r [0 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4757- : - len (r [0 , 1 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4758- ]
4759- )
4760- assert (
4761- r [- 1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4762- == r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4763- : - len (r [- 1 , 0 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4764- ]
4765- )
4766- assert (
4767- r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4768- == r [- 1 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4769- : - len (r [- 1 , 1 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4770- ]
4771- )
4772- else :
4770+ should_fail = no_stack
4771+ if should_fail :
4772+ ctx = pytest .raises (AssertionError )
4773+ else :
4774+ ctx = nullcontext ()
4775+ with ctx :
4776+ assert (
4777+ r [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4778+ == r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4779+ : - len (r [0 , 0 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4780+ ]
4781+ ), (
4782+ r [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ],
4783+ r [0 , 0 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ],
4784+ r [0 , 0 ]["next" , LLMEnv ._DEFAULT_STR_KEY ],
4785+ r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ],
4786+ )
4787+ with ctx :
4788+ assert (
4789+ r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4790+ == r [0 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4791+ : - len (r [0 , 1 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4792+ ]
4793+ )
4794+ with ctx :
4795+ assert (
4796+ r [- 1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4797+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4798+ : - len (r [- 1 , 0 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4799+ ]
4800+ )
4801+ with ctx :
4802+ assert (
4803+ r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4804+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4805+ : - len (r [- 1 , 1 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4806+ ]
4807+ )
4808+ elif tokenizer is None :
47734809 assert (
47744810 r [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
47754811 == r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
0 commit comments