1414from utcp .data .auth_implementations .basic_auth import BasicAuth
1515from utcp .data .auth_implementations .oauth2_auth import OAuth2Auth
1616
17- from utcp_gql .gql_call_template import GraphQLProvider
17+ from utcp_gql .gql_call_template import GraphQLCallTemplate
1818
1919if TYPE_CHECKING :
2020 from utcp .utcp_client import UtcpClient
@@ -70,25 +70,25 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str:
7070 return token_response ["access_token" ]
7171
7272 async def _prepare_headers (
73- self , provider : GraphQLProvider , tool_args : Optional [Dict [str , Any ]] = None
73+ self , call_template : GraphQLCallTemplate , tool_args : Optional [Dict [str , Any ]] = None
7474 ) -> Dict [str , str ]:
75- headers : Dict [str , str ] = provider .headers .copy () if provider .headers else {}
76- if provider .auth :
77- if isinstance (provider .auth , ApiKeyAuth ):
78- if provider .auth .api_key and provider .auth .location == "header" :
79- headers [provider .auth .var_name ] = provider .auth .api_key
80- elif isinstance (provider .auth , BasicAuth ):
75+ headers : Dict [str , str ] = call_template .headers .copy () if call_template .headers else {}
76+ if call_template .auth :
77+ if isinstance (call_template .auth , ApiKeyAuth ):
78+ if call_template .auth .api_key and call_template .auth .location == "header" :
79+ headers [call_template .auth .var_name ] = call_template .auth .api_key
80+ elif isinstance (call_template .auth , BasicAuth ):
8181 import base64
8282
83- userpass = f"{ provider .auth .username } :{ provider .auth .password } "
83+ userpass = f"{ call_template .auth .username } :{ call_template .auth .password } "
8484 headers ["Authorization" ] = "Basic " + base64 .b64encode (userpass .encode ()).decode ()
85- elif isinstance (provider .auth , OAuth2Auth ):
86- token = await self ._handle_oauth2 (provider .auth )
85+ elif isinstance (call_template .auth , OAuth2Auth ):
86+ token = await self ._handle_oauth2 (call_template .auth )
8787 headers ["Authorization" ] = f"Bearer { token } "
8888
8989 # Map selected tool_args into headers if requested
90- if tool_args and provider .header_fields :
91- for field in provider .header_fields :
90+ if tool_args and call_template .header_fields :
91+ for field in call_template .header_fields :
9292 if field in tool_args and isinstance (tool_args [field ], str ):
9393 headers [field ] = tool_args [field ]
9494
@@ -97,8 +97,8 @@ async def _prepare_headers(
9797 async def register_manual (
9898 self , caller : "UtcpClient" , manual_call_template : CallTemplate
9999 ) -> RegisterManualResult :
100- if not isinstance (manual_call_template , GraphQLProvider ):
101- raise ValueError ("GraphQLCommunicationProtocol requires a GraphQLProvider call template" )
100+ if not isinstance (manual_call_template , GraphQLCallTemplate ):
101+ raise ValueError ("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template" )
102102 self ._enforce_https_or_localhost (manual_call_template .url )
103103
104104 try :
@@ -176,26 +176,40 @@ async def call_tool(
176176 tool_args : Dict [str , Any ],
177177 tool_call_template : CallTemplate ,
178178 ) -> Any :
179- if not isinstance (tool_call_template , GraphQLProvider ):
180- raise ValueError ("GraphQLCommunicationProtocol requires a GraphQLProvider call template" )
179+ if not isinstance (tool_call_template , GraphQLCallTemplate ):
180+ raise ValueError ("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template" )
181181 self ._enforce_https_or_localhost (tool_call_template .url )
182182
183183 headers = await self ._prepare_headers (tool_call_template , tool_args )
184184 transport = AIOHTTPTransport (url = tool_call_template .url , headers = headers )
185185 async with GqlClient (transport = transport , fetch_schema_from_transport = True ) as session :
186- op_type = getattr (tool_call_template , "operation_type" , "query" )
187- # Strip manual prefix if present (client prefixes at save time)
188- base_tool_name = tool_name .split ("." , 1 )[- 1 ] if "." in tool_name else tool_name
189186 # Filter out header fields from GraphQL variables; these are sent via HTTP headers
190187 header_fields = tool_call_template .header_fields or []
191188 filtered_args = {k : v for k , v in tool_args .items () if k not in header_fields }
192189
193- arg_str = ", " .join (f"${ k } : String" for k in filtered_args .keys ())
194- var_defs = f"({ arg_str } )" if arg_str else ""
195- arg_pass = ", " .join (f"{ k } : ${ k } " for k in filtered_args .keys ())
196- arg_pass = f"({ arg_pass } )" if arg_pass else ""
190+ # Use custom query if provided (highest flexibility for agents)
191+ if tool_call_template .query :
192+ gql_str = tool_call_template .query
193+ else :
194+ # Auto-generate query - use variable_types for proper typing
195+ op_type = getattr (tool_call_template , "operation_type" , "query" )
196+ base_tool_name = tool_name .split ("." , 1 )[- 1 ] if "." in tool_name else tool_name
197+ variable_types = tool_call_template .variable_types or {}
198+
199+ # Build variable definitions with proper types (default to String)
200+ arg_str = ", " .join (
201+ f"${ k } : { variable_types .get (k , 'String' )} "
202+ for k in filtered_args .keys ()
203+ )
204+ var_defs = f"({ arg_str } )" if arg_str else ""
205+ arg_pass = ", " .join (f"{ k } : ${ k } " for k in filtered_args .keys ())
206+ arg_pass = f"({ arg_pass } )" if arg_pass else ""
207+
208+ # Note: Auto-generated queries for object-returning fields will still fail
209+ # without a selection set. Use the `query` field for full control.
210+ gql_str = f"{ op_type } { var_defs } {{ { base_tool_name } { arg_pass } }}"
211+ logger .debug (f"Auto-generated GraphQL: { gql_str } " )
197212
198- gql_str = f"{ op_type } { var_defs } {{ { base_tool_name } { arg_pass } }}"
199213 document = gql_query (gql_str )
200214 result = await session .execute (document , variable_values = filtered_args )
201215 return result
0 commit comments