 import pytest
 import torch
-from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
+from tensordict import (
+    lazy_stack,
+    LazyStackedTensorDict,
+    NonTensorStack,
+    set_list_to_stack,
+    TensorDict,
+)
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor

@@ -937,6 +943,38 @@ def vllm_instance(self):
         tokenizer.pad_token = tokenizer.eos_token
         return llm_model

+    @pytest.fixture(scope="module")
+    def transformers_instance(self):
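+        # Randomly initialized (untrained) GPT-2, built once per module: enough
+        # to exercise the wrapper plumbing without downloading model weights.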
+        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        model = GPT2LMHeadModel(GPT2Config()).eval()
+        # Alternative: facebook/opt-125m via OPTForCausalLM(OPTConfig())
+
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        return model, tokenizer
+
+    @pytest.fixture(scope="module")
+    def transformers_instance_pretrained(self):
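+        # Pretrained OPT-125m, for tests that need meaningful generations
+        # (e.g. the long-sequence test below).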
+        from transformers import AutoTokenizer, OPTForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
+        # (randomly initialized variants are covered by transformers_instance)
+
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        return model, tokenizer
+
     @pytest.mark.parametrize(
         "from_text, generate, return_log_probs, tokens, attention_mask",
         [
@@ -961,22 +999,18 @@ def vllm_instance(self):
             (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_TransformersWrapper(
-        self, from_text, generate, return_log_probs, tokens, attention_mask
+    def test_transformers_wrapper(
+        self,
+        from_text,
+        generate,
+        return_log_probs,
+        tokens,
+        attention_mask,
+        transformers_instance,
     ):
         torch.manual_seed(0)
-        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
-
-        # model_name = "distilbert-base-uncased"  # or "minilm" or "albert-tiny"
-        # Load the model and tokenizer
-        # model = AutoModel.from_pretrained(model_name)
-        # tokenizer = AutoTokenizer.from_pretrained(model_name)

-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        model = GPT2LMHeadModel(GPT2Config())
-
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.padding_side = "left"
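+        # Reuse the module-scoped fixture instead of rebuilding the model for
+        # every parametrized case.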
+        model, tokenizer = transformers_instance

         m = TransformersWrapper(
             model,
@@ -1019,7 +1053,7 @@ def test_TransformersWrapper(
             (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm(
+    def test_vllm_wrapper(
         self,
         from_text,
         generate,
@@ -1163,15 +1197,11 @@ def _run_check(
             (True, None, None),
         ],
     )
-    def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
+    def test_transformers_logprobs(
+        self, from_text, tokens, attention_mask, transformers_instance
+    ):
         torch.manual_seed(0)
-        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
-
-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        model = GPT2LMHeadModel(GPT2Config()).eval()
-
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.padding_side = "left"
+        model, tokenizer = transformers_instance

         m_generate = TransformersWrapper(
             model,
@@ -1201,7 +1231,7 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
             (True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm_logprobs(
+    def test_vllm_logprobs(
         self, from_text, tokens, attention_mask, pad_output, vllm_instance
     ):
         torch.manual_seed(0)
@@ -1254,6 +1284,7 @@ def _check_lps(
         )
         td_logprobs = model_logprobs(tdin_logprobs)
         assert td_generate.log_probs.shape == td_generate.tokens_response.shape
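+        # New check: the log-probs pass must be self-consistent before its
+        # shape is compared against the generation pass below.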
+        assert td_logprobs.log_probs.shape == td_logprobs.tokens_response.shape
         assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape
         torch.testing.assert_close(
             td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol
@@ -1374,7 +1405,7 @@ def _run_check_collector(self, policy):
         assert "tokens" in data
         # assert ("next", "tokens") in data

-    def test_generate_multiple_trajs_vllm(self, vllm_instance):
+    def test_vllm_generate_multiple_trajs(self, vllm_instance):
         policy = vLLMWrapper(
             vllm_instance,
             return_log_probs=True,
@@ -1386,6 +1417,63 @@ def test_generate_multiple_trajs_vllm(self, vllm_instance):
         )
         data = policy(data)

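+    # set_list_to_stack(True) makes the list assignment below
+    # (data["text"] = prompts) store a stacked, possibly ragged, entry rather
+    # than a single non-tensor object.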
+    @set_list_to_stack(True)
+    @pytest.mark.parametrize("from_text", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    def test_transformers_long_sequences(
+        self, from_text, generate, transformers_instance_pretrained
+    ):
+        torch.manual_seed(42)
+        model, tokenizer = transformers_instance_pretrained
+        prompts = [
+            "The quick brown fox jumps over the lazy dog.",  # Likely to finish soon
+            "Once upon a time in a land far, far away, there was a",  # Likely to continue longer
+            "In the beginning, the universe was created. This has made a lot of people very angry and been widely regarded as a bad move.",
+        ]
+        data = lazy_stack([TensorDict() for _ in range(len(prompts))])
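+        # A lazy stack keeps each sample as its own TensorDict, so token
+        # entries may have different lengths (ragged batch).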
+        data["text"] = prompts
+        eos_token_id = tokenizer.convert_tokens_to_ids(",")
+        if not from_text:
+            data["tokens"] = tokenizer(data["text"])["input_ids"]
+            data["attention_mask"] = (
+                0 * data.get("tokens", as_nested_tensor=True, layout=torch.strided) + 1
+            )
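+            # 0 * tokens + 1 builds an all-ones attention mask that shares the
+            # nested (ragged) layout of the token tensor.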
+        if not generate:
+            # generate=False requires pre-existing responses
+            responses = prompts[1:] + [" et dolore magna aliqua."]
+            data["text_response"] = responses
+            if not from_text:
+                data["tokens_response"] = tokenizer(data["text_response"])["input_ids"]
+        # make sure dimensions are ragged for tokens entries
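+        # (get_item_shape reports -1 for a ragged/heterogeneous dimension)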
+        if "tokens" in data:
+            assert data.get_item_shape("tokens")[-1] == -1
+        if "tokens_response" in data:
+            assert data.get_item_shape("tokens_response")[-1] == -1
+        generate_kwargs = {}
+        if generate:
+            generate_kwargs = {
+                "max_new_tokens": 128,  # Set a reasonable number of new tokens to generate
+                "min_length": 20,  # Ensure a minimum length for the generated sequence
+                "pad_token_id": tokenizer.pad_token_id,  # Use the tokenizer's pad token
+                "forced_eos_token_id": eos_token_id,  # Use comma as an EOS token
+            }
+        policy = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=True,
+            # TODO: use n trajs
+            generate_kwargs=generate_kwargs,
+        )
+        data_policy = policy(data)
+        if "tokens" in data_policy:
+            assert data_policy.get_item_shape("tokens")[-1] == -1
+        if "tokens_response" in data_policy:
+            # TODO: this fails
+            assert data_policy.get_item_shape("tokens_response")[-1] == -1
+

 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()