@@ -4861,31 +4861,22 @@ def policy(td):
48614861 r_reset = r [..., ::max_steps ]
48624862 if not batched :
48634863 if str2str :
4864+ all_strings = r_reset .view (- 1 )[LLMEnv ._DEFAULT_STR_KEY ]
4865+ assert sum (s == all_strings [0 ] for s in all_strings ) == repeats
4866+ assert sum (s == all_strings [repeats ] for s in all_strings ) == repeats
48644867 assert (
4865- r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4866- == r_reset [..., 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4868+ sum (s == all_strings [repeats * 2 ] for s in all_strings ) == repeats
48674869 )
4870+ else :
4871+ all_tokens = r_reset .view (- 1 )[LLMEnv ._DEFAULT_TOKEN_KEY ]
4872+ assert sum ((s == all_tokens [0 ]).all () for s in all_tokens ) == repeats
48684873 assert (
4869- r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4870- == r_reset [..., 2 ][LLMEnv ._DEFAULT_STR_KEY ]
4874+ sum ((s == all_tokens [repeats ]).all () for s in all_tokens ) == repeats
48714875 )
48724876 assert (
4873- r_reset [..., 0 ][ LLMEnv . _DEFAULT_STR_KEY ]
4874- != r_reset [..., 3 ][ LLMEnv . _DEFAULT_STR_KEY ]
4877+ sum (( s == all_tokens [ repeats * 2 ]). all () for s in all_tokens )
4878+ == repeats
48754879 )
4876- else :
4877- assert (
4878- r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4879- == r_reset [..., 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4880- ).all ()
4881- assert (
4882- r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4883- == r_reset [..., 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4884- ).all ()
4885- assert (
4886- r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4887- != r_reset [..., 3 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4888- ).any ()
48894880 else :
48904881 # When batched, each block contains the 3 reset packs
48914882 if str2str :
0 commit comments