77import argparse
88import contextlib
99import importlib .util
10+ import random
1011import re
12+ import time
1113
1214import pytest
1315import torch
@@ -434,7 +436,7 @@ def test_chat_env(slef, tokenizer):
434436 )
435437 )
436438 # Check history after reset
437- torchrl_logger .info (' td_reset["history"].content' , td_reset [ "history" ]. content )
439+ torchrl_logger .info (f' { td_reset ["history" ].content = } ' )
438440 assert len (td_reset ["history" ][0 ].content ) == 2
439441 assert td_reset ["history" ][0 , 0 ].content == "I'm system, do what I want."
440442 assert td_reset ["history" ][0 , 1 ].content .startswith ("I'm the user." )
@@ -593,9 +595,10 @@ def test_ifeval(self):
593595 env = IFEvalEnv (apply_template = True , tokenizer = tokenizer )
594596 torchrl_logger .info (env .reset ())
595597 r = env .reset ()
596- r [0 ][
597- "text_response"
598- ] = """<think>
598+ r .set (
599+ "text_response" ,
600+ [
601+ """<think>
599602The task requires crafting a riddle about a 'house' that's not traditionally considered one. The answer must be included, and the response should be at least 400 words with a title wrapped in double angular brackets. Let's start by brainstorming what could be considered a 'house' in a non-traditional sense. Ideas include natural shelters, abstract concepts, or objects that serve a similar purpose to a house.
600603One potential concept is a "womb," as it provides shelter and housing for a developing being. However, we need to ensure our riddle is engaging, meets the word count requirement, and includes the necessary elements like a title.
601604Let's construct a narrative around the chosen concept, ensuring it's detailed and follows the required structure.
@@ -637,6 +640,8 @@ def test_ifeval(self):
637640By embracing such metaphors, we're encouraged to look beyond the obvious and appreciate the myriad ways 'shelter' manifests in our lives. And so, the riddle serves not just as a puzzle to be solved but as a reflection on the profound connections that bind us to the very essence of existence.
638641</answer><|im_end|>
639642"""
643+ ],
644+ )
640645 td = env .step (r )
641646 assert td ["next" , "ifeval_score" ].all ()
642647 assert td .get (("next" , "reward" )) is not None
@@ -881,7 +886,7 @@ def test_python_interpreter_persistent_reset(self):
881886 r ["text_response" ] = [
882887 """Here is a python code to execute:
883888```python
884- # check if a is still defined
889+ # check if a is still defined
885890if "a" in globals():
886891 raise RuntimeError("a is still defined")
887892else:
@@ -899,7 +904,7 @@ def test_python_interpreter_persistent_reset(self):
899904 "<|im_start|>assistant\n "
900905 "Here is a python code to execute:\n "
901906 "```python\n "
902- "#\xa0 check if a is still defined\n "
907+ "# check if a is still defined\n "
903908 'if "a" in globals():\n '
904909 ' raise RuntimeError("a is still defined")\n '
905910 "else:\n "
@@ -914,6 +919,216 @@ def test_python_interpreter_persistent_reset(self):
914919 "<|im_start|>assistant\n " ,
915920 )
916921
922+ @pytest .mark .skipif (not _has_transformers , reason = "requires transformers" )
923+ def test_mcp_tool_transform (self ):
924+ """Test the MCPToolTransform with a simple calculator tool."""
925+ from torchrl .envs .llm import ChatEnv
926+ from torchrl .envs .llm .transforms .tools import MCPToolTransform
927+ from transformers import AutoTokenizer
928+
929+ # Define a simple calculator tool
930+ def calculator (operation : str , a : float , b : float ) -> dict :
931+ if operation == "add" :
932+ return {"result" : a + b }
933+ elif operation == "multiply" :
934+ return {"result" : a * b }
935+ else :
936+ raise ValueError (f"Unknown operation: { operation } " )
937+
938+ # Define the tool schema
939+ calculator_schema = {
940+ "name" : "calculator" ,
941+ "description" : "A simple calculator that can add or multiply two numbers" ,
942+ "parameters" : {
943+ "type" : "object" ,
944+ "properties" : {
945+ "operation" : {"type" : "string" , "enum" : ["add" , "multiply" ]},
946+ "a" : {"type" : "number" },
947+ "b" : {"type" : "number" },
948+ },
949+ "required" : ["operation" , "a" , "b" ],
950+ },
951+ }
952+
953+ # Create tools dictionary
954+ tools = {"calculator" : calculator }
955+ schemas = {"calculator" : calculator_schema }
956+
957+ # Create environment and transform
958+ tokenizer = AutoTokenizer .from_pretrained ("Qwen/Qwen2.5-3B" )
959+ env = ChatEnv (
960+ batch_size = (1 ,),
961+ system_prompt = "You are a helpful assistant that uses a calculator." ,
962+ apply_template = True ,
963+ tokenizer = tokenizer ,
964+ )
965+ transform = MCPToolTransform (tools , schemas )
966+ env = env .append_transform (transform )
967+
968+ # Test single tool call
969+ td = TensorDict ({"text" : ["Let me calculate 2 + 3" ]}, batch_size = (1 ,))
970+ td = env .reset (td )
971+ td ["text_response" ] = [
972+ 'I will help you calculate 2 + 3:\n <tool>calculator\n {"operation": "add", "a": 2, "b": 3}</tool><|im_end|>'
973+ ]
974+ result = env .step (td )
975+
976+ # Check that the tool was executed and returned correct result
977+ history = result ["next" , "history" ]
978+ assert len (history [0 ]) == 4 # system, user, assistant, tool response
979+ assert history [0 , - 1 ].role == "tool"
980+ assert "result': 5" in history [0 , - 1 ].content
981+
982+ # Test multiple tool calls in one response
983+ td = TensorDict ({"text" : ["Calculate 2 + 3 and 4 * 5" ]}, batch_size = (1 ,))
984+ td = env .reset (td )
985+ td ["text_response" ] = [
986+ "I will help you calculate both:\n "
987+ '<tool>calculator\n {"operation": "add", "a": 2, "b": 3}</tool>\n '
988+ '<tool>calculator\n {"operation": "multiply", "a": 4, "b": 5}</tool><|im_end|>'
989+ ]
990+ result = env .step (td )
991+
992+ # Check that both tools were executed and returned correct results
993+ history = result ["next" , "history" ]
994+ assert (
995+ len (history [0 ]) == 5
996+ ) # system, user, assistant, tool response 1, tool response 2
997+ assert history [0 , - 2 ].role == "tool"
998+ assert history [0 , - 1 ].role == "tool"
999+ assert "result': 5" in history [0 , - 2 ].content # 2 + 3 = 5
1000+ assert "result': 20" in history [0 , - 1 ].content # 4 * 5 = 20
1001+
1002+ # Test error handling
1003+ td = TensorDict ({"text" : ["Calculate 2 ? 3" ]}, batch_size = (1 ,))
1004+ td = env .reset (td )
1005+ td ["text_response" ] = [
1006+ 'I will try to calculate:\n <tool>calculator\n {"operation": "invalid", "a": 2, "b": 3}</tool><|im_end|>'
1007+ ]
1008+ result = env .step (td )
1009+
1010+ # Check that error was handled gracefully
1011+ history = result ["next" , "history" ]
1012+ assert len (history [0 ]) == 4
1013+ assert history [0 , - 1 ].role == "tool"
1014+ assert "failed" in history [0 , - 1 ].content
1015+ assert "Unknown operation: invalid" in history [0 , - 1 ].content
1016+
1017+ # Test invalid JSON
1018+ td = TensorDict ({"text" : ["Calculate something" ]}, batch_size = (1 ,))
1019+ td = env .reset (td )
1020+ td ["text_response" ] = [
1021+ "Let me calculate:\n <tool>calculator\n invalid json</tool><|im_end|>"
1022+ ]
1023+ result = env .step (td )
1024+
1025+ # Check that JSON error was handled gracefully
1026+ history = result ["next" , "history" ]
1027+ assert len (history [0 ]) == 4
1028+ assert history [0 , - 1 ].role == "tool"
1029+ assert "failed" in history [0 , - 1 ].content
1030+ assert "Failed to parse tool arguments" in history [0 , - 1 ].content
1031+
1032+ # Define a tool that waits for a random amount of time
1033+ @classmethod
1034+ def delayed_calculator (cls , operation : str , a : float , b : float ) -> dict :
1035+ # Random delay between 100ms and 300ms
1036+ delay = random .uniform (0.1 , 0.3 )
1037+ time .sleep (delay )
1038+ if operation == "add" :
1039+ return {"result" : a + b , "delay" : delay }
1040+ elif operation == "multiply" :
1041+ return {"result" : a * b , "delay" : delay }
1042+ else :
1043+ raise ValueError (f"Unknown operation: { operation } " )
1044+
1045+ # Define the tool schema
1046+ calculator_schema = {
1047+ "name" : "delayed_calculator" ,
1048+ "description" : "A calculator that introduces random delays" ,
1049+ "parameters" : {
1050+ "type" : "object" ,
1051+ "properties" : {
1052+ "operation" : {"type" : "string" , "enum" : ["add" , "multiply" ]},
1053+ "a" : {"type" : "number" },
1054+ "b" : {"type" : "number" },
1055+ },
1056+ "required" : ["operation" , "a" , "b" ],
1057+ },
1058+ }
1059+
1060+ # Create environment factory
1061+ @classmethod
1062+ def make_env (cls ):
1063+ from torchrl .envs .llm .transforms .tools import MCPToolTransform
1064+
1065+ tokenizer = AutoTokenizer .from_pretrained ("Qwen/Qwen2.5-3B" )
1066+ env = ChatEnv (
1067+ batch_size = (1 ,),
1068+ system_prompt = "I'm a calculator assistant" ,
1069+ apply_template = True ,
1070+ tokenizer = tokenizer ,
1071+ )
1072+ tools = {"calculator" : cls .delayed_calculator }
1073+ schemas = {"calculator" : cls .calculator_schema }
1074+ return env .append_transform (MCPToolTransform (tools , schemas ))
1075+
1076+ @pytest .mark .skipif (not _has_transformers , reason = "requires transformers" )
1077+ def test_async_mcp_tools (self ):
1078+ """Test async execution of MCP tools in an AsyncEnvPool."""
1079+ from tensordict import TensorDict
1080+ from torchrl .envs import AsyncEnvPool
1081+
1082+ # Create async env pool with 2 environments
1083+ env_pool = AsyncEnvPool (
1084+ [self .make_env , self .make_env ], backend = "multiprocessing"
1085+ )
1086+ try :
1087+ # Reset both environments
1088+ tdreset = TensorDict (
1089+ text = [["Let me calculate 2 + 3" ], ["Let me calculate 4 * 5" ]],
1090+ batch_size = (2 , 1 ),
1091+ )
1092+ td = env_pool .reset (tdreset )
1093+
1094+ # Send async steps to both environments
1095+ td ["text_response" ] = [
1096+ [
1097+ 'Let me calculate 2 + 3:\n <tool>calculator\n {"operation": "add", "a": 2, "b": 3}</tool><|im_end|>'
1098+ ],
1099+ [
1100+ 'Let me calculate 4 * 5:\n <tool>calculator\n {"operation": "multiply", "a": 4, "b": 5}</tool><|im_end|>'
1101+ ],
1102+ ]
1103+ env_pool .async_step_send (td )
1104+
1105+ # Get results as they complete
1106+ results = env_pool .async_step_recv (min_get = 1 ) # Get at least one result
1107+ assert len (results ) >= 1 # We should get at least one result
1108+
1109+ # Get remaining results
1110+ if len (results ) < 2 :
1111+ remaining = env_pool .async_step_recv ()
1112+ else :
1113+ remaining = []
1114+
1115+ # Combine results
1116+ all_results = torch .stack (list (results ) + list (remaining ))
1117+
1118+ # Verify results
1119+ history = all_results ["next" , "history" ]
1120+ assert len (history [0 , 0 ]) == 4 # system, user, assistant, tool response
1121+ assert history [0 , 0 , - 1 ].role == "tool"
1122+ assert any (
1123+ "result': 5" in c for c in history [:, 0 , - 1 ].content
1124+ ) # 2 + 3 = 5
1125+ assert any (
1126+ "result': 20" in c for c in history [:, 0 , - 1 ].content
1127+ ) # 4 * 5 = 20
1128+
1129+ finally :
1130+ env_pool .close ()
1131+
9171132
9181133if __name__ == "__main__" :
9191134 args , unknown = argparse .ArgumentParser ().parse_known_args ()
0 commit comments