@@ -1130,6 +1130,128 @@ def test_async_mcp_tools(self):
11301130 env_pool .close ()
11311131
11321132
class TestThinkingPrompt:
    """Tests for the ``AddThinkingPrompt`` transform applied to ``GSM8KEnv``.

    The transform is expected to append (or edit in place) a "think harder"
    prompt whenever its ``cond`` predicate fires on the step reward, and to
    optionally zero the reward and/or undo the done flag.
    """

    @pytest.fixture(autouse=True, scope="class")
    def base_env(self):
        # Class-scoped: the tokenizer/env pair is expensive to build, so it is
        # shared across all parametrizations and cleaned up per-test below.
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
        env = GSM8KEnv(shuffle=False, tokenizer=tokenizer, max_steps=10)
        return env

    @pytest.mark.skipif(not _has_transformers, reason="requires transformers")
    @pytest.mark.skipif(not _has_datasets, reason="requires gsm8k")
    @pytest.mark.parametrize(
        "role,edit_last_turn",
        [("assistant", True), ("assistant", False), ("user", False)],
    )
    @pytest.mark.parametrize("zero_reward", [True, False])
    @pytest.mark.parametrize("undo_done", [True, False])
    @pytest.mark.parametrize("random_prompt", [True, False])
    def test_thinking_prompt_wrong_answer(
        self,
        role,
        edit_last_turn,
        zero_reward,
        undo_done,
        random_prompt,
        tmp_path,
        base_env,
    ):
        """A wrong answer triggers the transform: reward/done/history change per flags."""
        from torchrl.envs.llm.transforms import AddThinkingPrompt

        # The fixture is shared across parametrizations: drop any
        # AddThinkingPrompt left over from a previous run before appending
        # a fresh one with this test's flags.
        if isinstance(base_env.transform[-1], AddThinkingPrompt):
            base_env.transform.pop()
        env = base_env.reset_dataloader()
        env = base_env.append_transform(
            AddThinkingPrompt(
                cond=lambda td: td["reward"] < 50,
                role=role,
                edit_last_turn=edit_last_turn,
                zero_reward=zero_reward,
                undo_done=undo_done,
                random_prompt=random_prompt,
            )
        )
        reset = env.reset()
        # With shuffle=False the first GSM8K question is deterministic.
        assert reset[0]["history"][-1].content.startswith(
            "Natalia sold clips to 48 of her friends in April"
        )
        # Correct reasoning but a deliberately wrong final answer (322 != 72),
        # so cond (reward < 50) fires and the transform kicks in.
        policy_answer = (
            "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
            "To find the total, I need to add April and May: 48 + 24 = 72. Therefore, Natalia sold 72 clips altogether in April and May.</think>\n <answer>322 clips</answer><|im_end|>"
        )
        reset["text_response"] = [policy_answer]
        s = env.step(reset)
        if zero_reward:
            assert (s["next", "reward"] == 0).all()
        else:
            assert (s["next", "reward"] != 0).all()
        if undo_done:
            assert (s["next", "done"] == 0).all()
        else:
            assert (s["next", "done"] != 0).all()
        # edit_last_turn rewrites the assistant turn in place (3 entries);
        # otherwise the prompt is appended as a new turn (4 entries).
        if edit_last_turn:
            assert s["next", "history"].shape == (1, 3)
        else:
            assert s["next", "history"].shape == (1, 4)
        if role == "assistant":
            assert s[0]["next", "history", "role"][-1] == "assistant"
        else:
            assert s[0]["next", "history", "role"][-1] == "user"

    @pytest.mark.skipif(not _has_transformers, reason="requires transformers")
    @pytest.mark.skipif(not _has_datasets, reason="requires gsm8k")
    @pytest.mark.parametrize(
        "role,edit_last_turn",
        [("assistant", True), ("assistant", False), ("user", False)],
    )
    @pytest.mark.parametrize("zero_reward", [True, False])
    @pytest.mark.parametrize("undo_done", [True, False])
    @pytest.mark.parametrize("random_prompt", [True, False])
    def test_thinking_prompt_correct_answer(
        self,
        role,
        edit_last_turn,
        zero_reward,
        undo_done,
        random_prompt,
        tmp_path,
        base_env,
    ):
        """If ``cond`` returns False (correct answer), the transform changes nothing."""
        from torchrl.envs.llm.transforms import AddThinkingPrompt

        # See test_thinking_prompt_wrong_answer: clean up the shared env
        # before appending this parametrization's transform.
        if isinstance(base_env.transform[-1], AddThinkingPrompt):
            base_env.transform.pop()
        env = base_env
        env = env.reset_dataloader()
        env = env.append_transform(
            AddThinkingPrompt(
                cond=lambda td: td["reward"] < 50,
                role=role,
                edit_last_turn=edit_last_turn,
                zero_reward=zero_reward,
                undo_done=undo_done,
                random_prompt=random_prompt,
            )
        )
        reset = env.reset()
        assert reset[0]["history"][-1].content.startswith(
            "Natalia sold clips to 48 of her friends in April"
        )
        # Same reasoning, but with the correct final answer (72) so the
        # reward is high, cond is False, and the episode terminates untouched.
        policy_answer = (
            "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
            "To find the total, I need to add April and May: 48 + 24 = 72. Therefore, Natalia sold 72 clips altogether in April and May.</think>\n <answer>72</answer><|im_end|>"
        )
        reset["text_response"] = [policy_answer]
        s = env.step(reset)
        assert (s["next", "reward"] != 0).all(), s["next", "reward"]
        assert s[0]["next", "history", "role"][-1] == "assistant"
        assert s["next", "done"].all()
        # system + user + assistant: no thinking prompt was appended.
        assert len(s[0]["next", "history", "content"]) == 3
1254+
if __name__ == "__main__":
    # Run this test module directly; any CLI arguments argparse does not
    # recognize are forwarded verbatim to pytest.
    _, extra_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *extra_args])