diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9394330..bf677f8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,10 +31,11 @@ jobs: pip install -e plugins/communication_protocols/http[dev] pip install -e plugins/communication_protocols/mcp[dev] pip install -e plugins/communication_protocols/text[dev] + pip install -e plugins/communication_protocols/socket[dev] - name: Run tests with pytest run: | - pytest core/tests/ plugins/communication_protocols/cli/tests/ plugins/communication_protocols/http/tests/ plugins/communication_protocols/mcp/tests/ plugins/communication_protocols/text/tests/ --doctest-modules --junitxml=junit/test-results.xml --cov=core/src/utcp --cov-report=xml --cov-report=html + pytest core/tests/ plugins/communication_protocols/cli/tests/ plugins/communication_protocols/http/tests/ plugins/communication_protocols/mcp/tests/ plugins/communication_protocols/text/tests/ plugins/communication_protocols/socket/tests/ --doctest-modules --junitxml=junit/test-results.xml --cov=core/src/utcp --cov-report=xml --cov-report=html - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 diff --git a/plugins/communication_protocols/gql/README.md b/plugins/communication_protocols/gql/README.md index 8febb5a..e329da2 100644 --- a/plugins/communication_protocols/gql/README.md +++ b/plugins/communication_protocols/gql/README.md @@ -1 +1,47 @@ -Find the UTCP readme at https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file + +UTCP GraphQL Communication Protocol Plugin + +This plugin integrates GraphQL as a UTCP 1.0 communication protocol and call template. It supports discovery via schema introspection, authenticated calls, and header handling. + +Getting Started + +Installation + +```bash +pip install gql +``` + +Registration + +```python +import utcp_gql +utcp_gql.register() +``` + +How To Use + +- Ensure the plugin is imported and registered: `import utcp_gql; utcp_gql.register()`. +- Add a manual in your client config: + ```json + { + "name": "my_graph", + "call_template_type": "graphql", + "url": "https://your.graphql/endpoint", + "operation_type": "query", + "headers": { "x-client": "utcp" }, + "header_fields": ["x-session-id"] + } + ``` +- Call a tool: + ```python + await client.call_tool("my_graph.someQuery", {"id": "123", "x-session-id": "abc"}) + ``` + +Notes + +- Tool names are prefixed by the manual name (e.g., `my_graph.someQuery`). +- Headers merge static `headers` plus whitelisted dynamic fields from `header_fields`. +- Supported auth: API key, Basic auth, OAuth2 (client-credentials). +- Security: only `https://` or `http://localhost`/`http://127.0.0.1` endpoints. + +For UTCP core docs, see https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/__init__.py b/plugins/communication_protocols/gql/src/utcp_gql/__init__.py index e69de29..7362502 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/__init__.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/__init__.py @@ -0,0 +1,9 @@ +from utcp.plugins.discovery import register_communication_protocol, register_call_template + +from .gql_communication_protocol import GraphQLCommunicationProtocol +from .gql_call_template import GraphQLProvider, GraphQLProviderSerializer + + +def register(): + register_communication_protocol("graphql", GraphQLCommunicationProtocol()) + register_call_template("graphql", GraphQLProviderSerializer()) \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py index dfe5b07..b86c2ad 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py @@ -1,7 +1,10 @@ -from utcp.data.call_template import CallTemplate -from utcp.data.auth import Auth +from utcp.data.call_template import CallTemplate, CallTemplateSerializer +from utcp.data.auth import Auth, AuthSerializer +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback from typing import Dict, List, Optional, Literal -from pydantic import Field +from pydantic import Field, field_serializer, field_validator class GraphQLProvider(CallTemplate): """Provider configuration for GraphQL-based tools. @@ -27,3 +30,31 @@ class GraphQLProvider(CallTemplate): auth: Optional[Auth] = None headers: Optional[Dict[str, str]] = None header_fields: Optional[List[str]] = Field(default=None, description="List of input fields to be sent as request headers for the initial connection.") + + @field_serializer("auth") + def serialize_auth(self, auth: Optional[Auth]): + if auth is None: + return None + return AuthSerializer().to_dict(auth) + + @field_validator("auth", mode="before") + @classmethod + def validate_auth(cls, v: Optional[Auth | dict]): + if v is None: + return None + if isinstance(v, Auth): + return v + return AuthSerializer().validate_dict(v) + + +class GraphQLProviderSerializer(Serializer[GraphQLProvider]): + def to_dict(self, obj: GraphQLProvider) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> GraphQLProvider: + try: + return GraphQLProvider.model_validate(data) + except Exception as e: + raise UtcpSerializerValidationError( + f"Invalid GraphQLProvider: {e}\n{traceback.format_exc()}" + ) diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py index f27f803..d1ab5fd 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py @@ -1,36 +1,54 @@ -import sys -from typing import Dict, Any, List, Optional, Callable +import logging +from typing import Dict, Any, List, Optional, AsyncGenerator, TYPE_CHECKING + import aiohttp -import asyncio -import ssl from gql import Client as GqlClient, gql as gql_query from gql.transport.aiohttp import AIOHTTPTransport -from utcp.client.client_transport_interface import ClientTransportInterface -from utcp.shared.provider import Provider, GraphQLProvider -from utcp.shared.tool import Tool, ToolInputOutputSchema -from utcp.shared.auth import ApiKeyAuth, BasicAuth, OAuth2Auth -import logging + +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp.data.call_template import CallTemplate +from utcp.data.tool import Tool, JsonSchema +from utcp.data.utcp_manual import UtcpManual +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth +from utcp.data.auth_implementations.basic_auth import BasicAuth +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth + +from utcp_gql.gql_call_template import GraphQLProvider + +if TYPE_CHECKING: + from utcp.utcp_client import UtcpClient + logging.basicConfig( level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s" + format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s", ) logger = logging.getLogger(__name__) -class GraphQLClientTransport(ClientTransportInterface): - """ - Simple, robust, production-ready GraphQL transport using gql. - Stateless, per-operation. Supports all GraphQL features. + +class GraphQLCommunicationProtocol(CommunicationProtocol): + """GraphQL protocol implementation for UTCP 1.0. + + - Discovers tools via GraphQL schema introspection. + - Executes per-call sessions using `gql` over HTTP(S). + - Supports `ApiKeyAuth`, `BasicAuth`, and `OAuth2Auth`. + - Enforces HTTPS or localhost for security. """ - def __init__(self): + + def __init__(self) -> None: self._oauth_tokens: Dict[str, Dict[str, Any]] = {} - def _enforce_https_or_localhost(self, url: str): - if not (url.startswith("https://") or url.startswith("http://localhost") or url.startswith("http://127.0.0.1")): + def _enforce_https_or_localhost(self, url: str) -> None: + if not ( + url.startswith("https://") + or url.startswith("http://localhost") + or url.startswith("http://127.0.0.1") + ): raise ValueError( - f"Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. Got: {url}. " - "Non-secure URLs are vulnerable to man-in-the-middle attacks." + "Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. " + f"Got: {url}." ) async def _handle_oauth2(self, auth: OAuth2Auth) -> str: @@ -39,10 +57,10 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: return self._oauth_tokens[client_id]["access_token"] async with aiohttp.ClientSession() as session: data = { - 'grant_type': 'client_credentials', - 'client_id': client_id, - 'client_secret': auth.client_secret, - 'scope': auth.scope + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": auth.client_secret, + "scope": auth.scope, } async with session.post(auth.token_url, data=data) as resp: resp.raise_for_status() @@ -50,87 +68,144 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: self._oauth_tokens[client_id] = token_response return token_response["access_token"] - async def _prepare_headers(self, provider: GraphQLProvider) -> Dict[str, str]: - headers = provider.headers.copy() if provider.headers else {} + async def _prepare_headers( + self, provider: GraphQLProvider, tool_args: Optional[Dict[str, Any]] = None + ) -> Dict[str, str]: + headers: Dict[str, str] = provider.headers.copy() if provider.headers else {} if provider.auth: if isinstance(provider.auth, ApiKeyAuth): - if provider.auth.api_key: - if provider.auth.location == "header": - headers[provider.auth.var_name] = provider.auth.api_key - # (query/cookie not supported for GraphQL by default) + if provider.auth.api_key and provider.auth.location == "header": + headers[provider.auth.var_name] = provider.auth.api_key elif isinstance(provider.auth, BasicAuth): import base64 + userpass = f"{provider.auth.username}:{provider.auth.password}" headers["Authorization"] = "Basic " + base64.b64encode(userpass.encode()).decode() elif isinstance(provider.auth, OAuth2Auth): token = await self._handle_oauth2(provider.auth) headers["Authorization"] = f"Bearer {token}" + + # Map selected tool_args into headers if requested + if tool_args and provider.header_fields: + for field in provider.header_fields: + if field in tool_args and isinstance(tool_args[field], str): + headers[field] = tool_args[field] + return headers - async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: - if not isinstance(manual_provider, GraphQLProvider): - raise ValueError("GraphQLClientTransport can only be used with GraphQLProvider") - self._enforce_https_or_localhost(manual_provider.url) - headers = await self._prepare_headers(manual_provider) - transport = AIOHTTPTransport(url=manual_provider.url, headers=headers) - async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: - schema = session.client.schema - tools = [] - # Queries - if hasattr(schema, 'query_type') and schema.query_type: - for name, field in schema.query_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - # Mutations - if hasattr(schema, 'mutation_type') and schema.mutation_type: - for name, field in schema.mutation_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - # Subscriptions (listed, but not called here) - if hasattr(schema, 'subscription_type') and schema.subscription_type: - for name, field in schema.subscription_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - return tools - - async def deregister_tool_provider(self, manual_provider: Provider) -> None: - # Stateless: nothing to do - pass - - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider, query: Optional[str] = None) -> Any: - if not isinstance(tool_provider, GraphQLProvider): - raise ValueError("GraphQLClientTransport can only be used with GraphQLProvider") - self._enforce_https_or_localhost(tool_provider.url) - headers = await self._prepare_headers(tool_provider) - transport = AIOHTTPTransport(url=tool_provider.url, headers=headers) + async def register_manual( + self, caller: "UtcpClient", manual_call_template: CallTemplate + ) -> RegisterManualResult: + if not isinstance(manual_call_template, GraphQLProvider): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template") + self._enforce_https_or_localhost(manual_call_template.url) + + try: + headers = await self._prepare_headers(manual_call_template) + transport = AIOHTTPTransport(url=manual_call_template.url, headers=headers) + async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: + schema = session.client.schema + tools: List[Tool] = [] + + # Queries + if hasattr(schema, "query_type") and schema.query_type: + for name, field in schema.query_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + # Mutations + if hasattr(schema, "mutation_type") and schema.mutation_type: + for name, field in schema.mutation_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + # Subscriptions (listed for completeness) + if hasattr(schema, "subscription_type") and schema.subscription_type: + for name, field in schema.subscription_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + manual = UtcpManual(tools=tools) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=True, + errors=[], + ) + except Exception as e: + logger.error(f"GraphQL manual registration failed for '{manual_call_template.name}': {e}") + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=UtcpManual(manual_version="0.0.0", tools=[]), + success=False, + errors=[str(e)], + ) + + async def deregister_manual( + self, caller: "UtcpClient", manual_call_template: CallTemplate + ) -> None: + # Stateless: nothing to clean up + return None + + async def call_tool( + self, + caller: "UtcpClient", + tool_name: str, + tool_args: Dict[str, Any], + tool_call_template: CallTemplate, + ) -> Any: + if not isinstance(tool_call_template, GraphQLProvider): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template") + self._enforce_https_or_localhost(tool_call_template.url) + + headers = await self._prepare_headers(tool_call_template, tool_args) + transport = AIOHTTPTransport(url=tool_call_template.url, headers=headers) async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: - if query is not None: - document = gql_query(query) - result = await session.execute(document, variable_values=tool_args) - return result - # If no query provided, build a simple query - # Default to query operation - op_type = getattr(tool_provider, 'operation_type', 'query') - arg_str = ', '.join(f"${k}: String" for k in tool_args.keys()) + op_type = getattr(tool_call_template, "operation_type", "query") + # Strip manual prefix if present (client prefixes at save time) + base_tool_name = tool_name.split(".", 1)[-1] if "." in tool_name else tool_name + + arg_str = ", ".join(f"${k}: String" for k in tool_args.keys()) var_defs = f"({arg_str})" if arg_str else "" - arg_pass = ', '.join(f"{k}: ${k}" for k in tool_args.keys()) + arg_pass = ", ".join(f"{k}: ${k}" for k in tool_args.keys()) arg_pass = f"({arg_pass})" if arg_pass else "" - gql_str = f"{op_type} {var_defs} {{ {tool_name}{arg_pass} }}" + + gql_str = f"{op_type} {var_defs} {{ {base_tool_name}{arg_pass} }}" document = gql_query(gql_str) result = await session.execute(document, variable_values=tool_args) return result + async def call_tool_streaming( + self, + caller: "UtcpClient", + tool_name: str, + tool_args: Dict[str, Any], + tool_call_template: CallTemplate, + ) -> AsyncGenerator[Any, None]: + # Basic implementation: execute non-streaming and yield once + result = await self.call_tool(caller, tool_name, tool_args, tool_call_template) + yield result + async def close(self) -> None: - self._oauth_tokens.clear() + self._oauth_tokens.clear() \ No newline at end of file diff --git a/plugins/communication_protocols/gql/tests/test_graphql_protocol.py b/plugins/communication_protocols/gql/tests/test_graphql_protocol.py new file mode 100644 index 0000000..8080957 --- /dev/null +++ b/plugins/communication_protocols/gql/tests/test_graphql_protocol.py @@ -0,0 +1,111 @@ +import asyncio +import os +import sys +import types +import pytest + + +# Ensure plugin src is importable +PLUGIN_SRC = os.path.join(os.path.dirname(__file__), "..", "src") +PLUGIN_SRC = os.path.abspath(PLUGIN_SRC) +if PLUGIN_SRC not in sys.path: + sys.path.append(PLUGIN_SRC) + +import utcp_gql +from utcp_gql.gql_call_template import GraphQLProvider +from utcp_gql import gql_communication_protocol as gql_module + +from utcp.data.tool import Tool +from utcp.data.utcp_manual import UtcpManual +from utcp.utcp_client import UtcpClient +from utcp.implementations.utcp_client_implementation import UtcpClientImplementation + + +class FakeSchema: + def __init__(self): + # Minimal field objects with descriptions + self.query_type = types.SimpleNamespace( + fields={ + "hello": types.SimpleNamespace(description="Returns greeting"), + } + ) + self.mutation_type = types.SimpleNamespace( + fields={ + "add": types.SimpleNamespace(description="Adds numbers"), + } + ) + self.subscription_type = None + + +class FakeClientObj: + def __init__(self): + self.client = types.SimpleNamespace(schema=FakeSchema()) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def execute(self, document, variable_values=None): + # document is a gql query; we can base behavior on variable_values + variable_values = variable_values or {} + # Determine operation by presence of variables used + if "hello" in str(document): + name = variable_values.get("name", "") + return {"hello": f"Hello {name}"} + if "add" in str(document): + a = int(variable_values.get("a", 0)) + b = int(variable_values.get("b", 0)) + return {"add": a + b} + return {"ok": True} + + +class FakeTransport: + def __init__(self, url: str, headers: dict | None = None): + self.url = url + self.headers = headers or {} + + +@pytest.mark.asyncio +async def test_graphql_register_and_call(monkeypatch): + # Patch gql client/transport used by protocol to avoid needing a real server + monkeypatch.setattr(gql_module, "GqlClient", lambda *args, **kwargs: FakeClientObj()) + monkeypatch.setattr(gql_module, "AIOHTTPTransport", FakeTransport) + # Avoid real GraphQL parsing; pass-through document string to fake execute + monkeypatch.setattr(gql_module, "gql_query", lambda s: s) + + # Register plugin (call_template serializer + protocol) + utcp_gql.register() + + # Create protocol and manual call template + protocol = gql_module.GraphQLCommunicationProtocol() + provider = GraphQLProvider( + name="mock_graph", + call_template_type="graphql", + url="http://localhost/graphql", + operation_type="query", + headers={"x-client": "utcp"}, + header_fields=["x-session-id"], + ) + + # Minimal UTCP client implementation for caller context + client: UtcpClient = await UtcpClientImplementation.create() + client.config.variables = {} + + # Register and discover tools + reg = await protocol.register_manual(client, provider) + assert reg.success is True + assert isinstance(reg.manual, UtcpManual) + tool_names = sorted(t.name for t in reg.manual.tools) + assert "hello" in tool_names + assert "add" in tool_names + + # Call hello + res = await protocol.call_tool(client, "mock_graph.hello", {"name": "UTCP", "x-session-id": "abc"}, provider) + assert res == {"hello": "Hello UTCP"} + + # Call add (mutation) + provider.operation_type = "mutation" + res2 = await protocol.call_tool(client, "mock_graph.add", {"a": 2, "b": 3}, provider) + assert res2 == {"add": 5} \ No newline at end of file diff --git a/plugins/communication_protocols/mcp/pyproject.toml b/plugins/communication_protocols/mcp/pyproject.toml index 36bb48e..9232cd5 100644 --- a/plugins/communication_protocols/mcp/pyproject.toml +++ b/plugins/communication_protocols/mcp/pyproject.toml @@ -15,7 +15,8 @@ dependencies = [ "pydantic>=2.0", "mcp>=1.12", "utcp>=1.0", - "mcp-use>=1.3" + "mcp-use>=1.3", + "langchain==0.3.27", ] classifiers = [ "Development Status :: 4 - Beta", diff --git a/plugins/communication_protocols/socket/README.md b/plugins/communication_protocols/socket/README.md index 8febb5a..3e695c9 100644 --- a/plugins/communication_protocols/socket/README.md +++ b/plugins/communication_protocols/socket/README.md @@ -1 +1,44 @@ -Find the UTCP readme at https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file +# UTCP Socket Plugin (UDP/TCP) + +This plugin adds UDP and TCP communication protocols to UTCP 1.0. + +## Running Tests + +Prerequisites: +- Python 3.10+ +- `pip` +- (Optional) a virtual environment + +1) Install core and the socket plugin in editable mode with dev extras: + +```bash +pip install -e "core[dev]" +pip install -e plugins/communication_protocols/socket[dev] +``` + +2) Run the socket plugin tests: + +```bash +python -m pytest plugins/communication_protocols/socket/tests -v +``` + +3) Run a single test or filter by keyword: + +```bash +# One file +python -m pytest plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py -v + +# Filter by keyword (e.g., delimiter framing) +python -m pytest plugins/communication_protocols/socket/tests -k delimiter -q +``` + +4) Optional end-to-end sanity check (mock UDP/TCP servers): + +```bash +python scripts/socket_sanity.py +``` + +Notes: +- On Windows, your firewall may prompt the first time tests open UDP/TCP sockets; allow access or run as admin if needed. +- Tests use `pytest-asyncio`. The dev extras installed above provide required dependencies. +- Streaming is single-chunk by design, consistent with HTTP/Text transports. Multi-chunk streaming can be added later behind provider configuration. \ No newline at end of file diff --git a/plugins/communication_protocols/socket/pyproject.toml b/plugins/communication_protocols/socket/pyproject.toml index 06f845e..2f232ad 100644 --- a/plugins/communication_protocols/socket/pyproject.toml +++ b/plugins/communication_protocols/socket/pyproject.toml @@ -36,4 +36,7 @@ dev = [ [project.urls] Homepage = "https://utcp.io" Source = "https://github.com/universal-tool-calling-protocol/python-utcp" -Issues = "https://github.com/universal-tool-calling-protocol/python-utcp/issues" \ No newline at end of file +Issues = "https://github.com/universal-tool-calling-protocol/python-utcp/issues" + +[project.entry-points."utcp.plugins"] +socket = "utcp_socket:register" \ No newline at end of file diff --git a/plugins/communication_protocols/socket/src/utcp_socket/__init__.py b/plugins/communication_protocols/socket/src/utcp_socket/__init__.py index e69de29..a0b7f3b 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/__init__.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/__init__.py @@ -0,0 +1,18 @@ +from utcp.plugins.discovery import register_communication_protocol, register_call_template +from utcp_socket.tcp_communication_protocol import TCPTransport +from utcp_socket.udp_communication_protocol import UDPTransport +from utcp_socket.tcp_call_template import TCPProviderSerializer +from utcp_socket.udp_call_template import UDPProviderSerializer + + +def register() -> None: + # Register communication protocols + register_communication_protocol("tcp", TCPTransport()) + register_communication_protocol("udp", UDPTransport()) + + # Register call templates and their serializers + register_call_template("tcp", TCPProviderSerializer()) + register_call_template("udp", UDPProviderSerializer()) + + +__all__ = ["register"] \ No newline at end of file diff --git a/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py b/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py index 157e43c..10fc1d6 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py @@ -1,6 +1,9 @@ from utcp.data.call_template import CallTemplate from typing import Optional, Literal from pydantic import Field +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback class TCPProvider(CallTemplate): """Provider configuration for raw TCP socket tools. @@ -63,7 +66,7 @@ class TCPProvider(CallTemplate): # Delimiter-based framing options message_delimiter: str = Field( default='\x00', - description="Delimiter to detect end of TCP response (e.g., '\\n', '\\r\\n', '\\x00'). Used with 'delimiter' framing." + description="Delimiter to detect end of TCP response (e.g., '\n', '\r\n', '\x00'). Used with 'delimiter' framing." ) # Fixed-length framing options fixed_message_length: Optional[int] = Field( @@ -77,3 +80,16 @@ class TCPProvider(CallTemplate): ) timeout: int = 30000 auth: None = None + + +class TCPProviderSerializer(Serializer[TCPProvider]): + def to_dict(self, obj: TCPProvider) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> TCPProvider: + try: + return TCPProvider.model_validate(data) + except Exception as e: + raise UtcpSerializerValidationError( + f"Invalid TCPProvider: {e}\n{traceback.format_exc()}" + ) diff --git a/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py b/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py index 1b360a8..d5d64ac 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py @@ -10,9 +10,12 @@ import sys from typing import Dict, Any, List, Optional, Callable, Union -from utcp.client.client_transport_interface import ClientTransportInterface -from utcp.shared.provider import Provider, TCPProvider -from utcp.shared.tool import Tool +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp_socket.tcp_call_template import TCPProvider, TCPProviderSerializer +from utcp.data.tool import Tool +from utcp.data.call_template import CallTemplate, CallTemplateSerializer +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.utcp_manual import UtcpManual import logging logging.basicConfig( @@ -22,7 +25,7 @@ logger = logging.getLogger(__name__) -class TCPTransport(ClientTransportInterface): +class TCPTransport(CommunicationProtocol): """Transport implementation for TCP-based tool providers. This transport communicates with tools over TCP sockets. It supports: @@ -85,6 +88,35 @@ def _format_tool_call_message( else: # Default to JSON format return json.dumps(tool_args) + + def _ensure_tool_call_template(self, tool_data: Dict[str, Any], manual_call_template: TCPProvider) -> Dict[str, Any]: + """Normalize tool definition to include a valid 'tool_call_template'. + + - If 'tool_call_template' exists, validate it. + - Else if legacy 'tool_provider' exists, convert using TCPProviderSerializer. + - Else default to the provided manual_call_template. + """ + normalized = dict(tool_data) + try: + if "tool_call_template" in normalized and normalized["tool_call_template"] is not None: + try: + ctpl = CallTemplateSerializer().validate_dict(normalized["tool_call_template"]) # type: ignore + normalized["tool_call_template"] = ctpl + except Exception: + normalized["tool_call_template"] = manual_call_template + elif "tool_provider" in normalized and normalized["tool_provider"] is not None: + try: + ctpl = TCPProviderSerializer().validate_dict(normalized["tool_provider"]) # type: ignore + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = ctpl + except Exception: + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = manual_call_template + else: + normalized["tool_call_template"] = manual_call_template + except Exception: + normalized["tool_call_template"] = manual_call_template + return normalized def _encode_message_with_framing(self, message: str, provider: TCPProvider) -> bytes: """Encode message with appropriate TCP framing. @@ -115,7 +147,7 @@ def _encode_message_with_framing(self, message: str, provider: TCPProvider) -> b elif provider.framing_strategy == "delimiter": # Add delimiter after the message - delimiter = provider.message_delimiter or "\\x00" + delimiter = provider.message_delimiter or "\x00" # Handle escape sequences delimiter = delimiter.encode('utf-8').decode('unicode_escape') return message_bytes + delimiter.encode('utf-8') @@ -170,7 +202,7 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid elif provider.framing_strategy == "delimiter": # Read until delimiter is found - delimiter = provider.message_delimiter or "\\x00" + delimiter = provider.message_delimiter or "\x00" delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('utf-8') response_data = b"" @@ -215,9 +247,6 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid return response_data - else: - raise ValueError(f"Unknown framing strategy: {provider.framing_strategy}") - async def _send_tcp_message( self, host: str, @@ -289,122 +318,91 @@ def _send_and_receive(): self._log_error(f"Error in TCP communication: {e}") raise - async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: - """Register a TCP provider and discover its tools. - - Sends a discovery message to the TCP provider and parses the response. - - Args: - manual_provider: The TCPProvider to register - - Returns: - List of tools discovered from the TCP provider - - Raises: - ValueError: If provider is not a TCPProvider - """ - if not isinstance(manual_provider, TCPProvider): + async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult: + """Register a TCP manual and discover its tools.""" + if not isinstance(manual_call_template, TCPProvider): raise ValueError("TCPTransport can only be used with TCPProvider") - self._log_info(f"Registering TCP provider '{manual_provider.name}'") + self._log_info(f"Registering TCP provider '{manual_call_template.name}'") try: - # Send discovery message - discovery_message = json.dumps({ - "type": "utcp" - }) - + discovery_message = json.dumps({"type": "utcp"}) response = await self._send_tcp_message( - manual_provider.host, - manual_provider.port, + manual_call_template.host, + manual_call_template.port, discovery_message, - manual_provider, - manual_provider.timeout / 1000.0, # Convert ms to seconds - manual_provider.response_byte_format + manual_call_template, + manual_call_template.timeout / 1000.0, + manual_call_template.response_byte_format ) - - # Parse response try: - # Handle bytes response by trying to decode as UTF-8 for JSON parsing - if isinstance(response, bytes): - response_str = response.decode('utf-8') - else: - response_str = response - + response_str = response.decode('utf-8') if isinstance(response, bytes) else response response_data = json.loads(response_str) - - # Check if response contains tools + tools: List[Tool] = [] if isinstance(response_data, dict) and 'tools' in response_data: tools_data = response_data['tools'] - - # Parse tools - tools = [] for tool_data in tools_data: try: - tool = Tool(**tool_data) - tools.append(tool) + normalized = self._ensure_tool_call_template(tool_data, manual_call_template) + tools.append(Tool(**normalized)) except Exception as e: - self._log_error(f"Invalid tool definition in TCP provider '{manual_provider.name}': {e}") + self._log_error(f"Invalid tool definition in TCP provider '{manual_call_template.name}': {e}") continue - - self._log_info(f"Discovered {len(tools)} tools from TCP provider '{manual_provider.name}'") - return tools + self._log_info(f"Discovered {len(tools)} tools from TCP provider '{manual_call_template.name}'") else: - self._log_info(f"No tools found in TCP provider '{manual_provider.name}' response") - return [] - + self._log_info(f"No tools found in TCP provider '{manual_call_template.name}' response") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=tools) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=True, + errors=[] + ) except json.JSONDecodeError as e: - self._log_error(f"Invalid JSON response from TCP provider '{manual_provider.name}': {e}") - return [] - + self._log_error(f"Invalid JSON response from TCP provider '{manual_call_template.name}': {e}") + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]), + success=False, + errors=[str(e)] + ) except Exception as e: - self._log_error(f"Error registering TCP provider '{manual_provider.name}': {e}") - return [] + self._log_error(f"Error registering TCP provider '{manual_call_template.name}': {e}") + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]), + success=False, + errors=[str(e)] + ) - async def deregister_tool_provider(self, manual_provider: Provider) -> None: - """Deregister a TCP provider. - - This is a no-op for TCP providers since connections are created per request. - - Args: - manual_provider: The provider to deregister - """ - if not isinstance(manual_provider, TCPProvider): + async def deregister_manual(self, caller, manual_call_template: CallTemplate) -> None: + """Deregister a TCP provider (no-op).""" + if not isinstance(manual_call_template, TCPProvider): raise ValueError("TCPTransport can only be used with TCPProvider") - - self._log_info(f"Deregistering TCP provider '{manual_provider.name}' (no-op)") + self._log_info(f"Deregistering TCP provider '{manual_call_template.name}' (no-op)") - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider) -> Any: - """Call a TCP tool. - - Sends a tool call message to the TCP provider and returns the response. - - Args: - tool_name: Name of the tool to call - tool_args: Arguments for the tool call - tool_provider: The TCPProvider containing the tool - - Returns: - The response from the TCP tool - - Raises: - ValueError: If provider is not a TCPProvider - """ - if not isinstance(tool_provider, TCPProvider): + async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): + async def _generator(): + yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) + return _generator() + + async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate) -> Any: + """Call a TCP tool.""" + if not isinstance(tool_call_template, TCPProvider): raise ValueError("TCPTransport can only be used with TCPProvider") - self._log_info(f"Calling TCP tool '{tool_name}' on provider '{tool_provider.name}'") + self._log_info(f"Calling TCP tool '{tool_name}' on provider '{tool_call_template.name}'") try: - tool_call_message = self._format_tool_call_message(tool_args, tool_provider) + tool_call_message = self._format_tool_call_message(tool_args, tool_call_template) response = await self._send_tcp_message( - tool_provider.host, - tool_provider.port, + tool_call_template.host, + tool_call_template.port, tool_call_message, - tool_provider, - tool_provider.timeout / 1000.0, # Convert ms to seconds - tool_provider.response_byte_format + tool_call_template, + tool_call_template.timeout / 1000.0, + tool_call_template.response_byte_format ) return response diff --git a/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py b/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py index 4c704da..8c30c86 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py @@ -1,6 +1,9 @@ from utcp.data.call_template import CallTemplate from typing import Optional, Literal from pydantic import Field +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback class UDPProvider(CallTemplate): """Provider configuration for UDP (User Datagram Protocol) socket tools. @@ -38,3 +41,16 @@ class UDPProvider(CallTemplate): response_byte_format: Optional[str] = Field(default="utf-8", description="Encoding to decode response bytes. If None, returns raw bytes.") timeout: int = 30000 auth: None = None + + +class UDPProviderSerializer(Serializer[UDPProvider]): + def to_dict(self, obj: UDPProvider) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> UDPProvider: + try: + return UDPProvider.model_validate(data) + except Exception as e: + raise UtcpSerializerValidationError( + f"Invalid UDPProvider: {e}\n{traceback.format_exc()}" + ) diff --git a/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py b/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py index 8d4d404..b59ef37 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py @@ -9,14 +9,17 @@ import traceback from typing import Dict, Any, List, Optional, Callable, Union -from utcp.client.client_transport_interface import ClientTransportInterface -from utcp.shared.provider import Provider, UDPProvider -from utcp.shared.tool import Tool +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp_socket.udp_call_template import UDPProvider, UDPProviderSerializer +from utcp.data.tool import Tool +from utcp.data.call_template import CallTemplate, CallTemplateSerializer +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.utcp_manual import UtcpManual import logging logger = logging.getLogger(__name__) -class UDPTransport(ClientTransportInterface): +class UDPTransport(CommunicationProtocol): """Transport implementation for UDP-based tool providers. This transport communicates with tools over UDP sockets. It supports: @@ -80,6 +83,38 @@ def _format_tool_call_message( else: # Default to JSON format return json.dumps(tool_args) + + def _ensure_tool_call_template(self, tool_data: Dict[str, Any], manual_call_template: UDPProvider) -> Dict[str, Any]: + """Normalize tool definition to include a valid 'tool_call_template'. + + - If 'tool_call_template' exists, validate it. + - Else if legacy 'tool_provider' exists, convert using UDPProviderSerializer. + - Else default to the provided manual_call_template. + """ + normalized = dict(tool_data) + try: + if "tool_call_template" in normalized and normalized["tool_call_template"] is not None: + # Validate via generic CallTemplate serializer (type-dispatched) + try: + ctpl = CallTemplateSerializer().validate_dict(normalized["tool_call_template"]) # type: ignore + normalized["tool_call_template"] = ctpl + except Exception: + # Fallback to manual template if validation fails + normalized["tool_call_template"] = manual_call_template + elif "tool_provider" in normalized and normalized["tool_provider"] is not None: + # Convert legacy provider -> call template + try: + ctpl = UDPProviderSerializer().validate_dict(normalized["tool_provider"]) # type: ignore + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = ctpl + except Exception: + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = manual_call_template + else: + normalized["tool_call_template"] = manual_call_template + except Exception: + normalized["tool_call_template"] = manual_call_template + return normalized async def _send_udp_message( self, @@ -202,125 +237,89 @@ def _send_only(): self._log_error(f"Error sending UDP message (no response): {traceback.format_exc()}") raise - async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: - """Register a UDP provider and discover its tools. - - Sends a discovery message to the UDP provider and parses the response. - - Args: - manual_provider: The UDPProvider to register - - Returns: - List of tools discovered from the UDP provider - - Raises: - ValueError: If provider is not a UDPProvider - """ - if not isinstance(manual_provider, UDPProvider): + async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult: + """Register a UDP manual and discover its tools.""" + if not isinstance(manual_call_template, UDPProvider): raise ValueError("UDPTransport can only be used with UDPProvider") - self._log_info(f"Registering UDP provider '{manual_provider.name}' at {manual_provider.host}:{manual_provider.port}") + self._log_info(f"Registering UDP provider '{manual_call_template.name}' at {manual_call_template.host}:{manual_call_template.port}") try: - # Send discovery message - discovery_message = json.dumps({ - "type": "utcp" - }) - + discovery_message = json.dumps({"type": "utcp"}) response = await self._send_udp_message( - manual_provider.host, - manual_provider.port, + manual_call_template.host, + manual_call_template.port, discovery_message, - manual_provider.timeout / 1000.0, # Convert ms to seconds - manual_provider.number_of_response_datagrams, - manual_provider.response_byte_format + manual_call_template.timeout / 1000.0, + manual_call_template.number_of_response_datagrams, + manual_call_template.response_byte_format ) - - # Parse response try: - # Handle bytes response by trying to decode as UTF-8 for JSON parsing - if isinstance(response, bytes): - response_str = response.decode('utf-8') - else: - response_str = response - + response_str = response.decode('utf-8') if isinstance(response, bytes) else response response_data = json.loads(response_str) - - # Check if response contains tools + tools: List[Tool] = [] if isinstance(response_data, dict) and 'tools' in response_data: tools_data = response_data['tools'] - - # Parse tools - tools = [] for tool_data in tools_data: try: - tool = Tool(**tool_data) + normalized = self._ensure_tool_call_template(tool_data, manual_call_template) + tool = Tool(**normalized) tools.append(tool) - except Exception as e: - self._log_error(f"Invalid tool definition in UDP provider '{manual_provider.name}': {traceback.format_exc()}") + except Exception: + self._log_error(f"Invalid tool definition in UDP provider '{manual_call_template.name}': {traceback.format_exc()}") continue - - self._log_info(f"Discovered {len(tools)} tools from UDP provider '{manual_provider.name}'") - return tools + self._log_info(f"Discovered {len(tools)} tools from UDP provider '{manual_call_template.name}'") else: - self._log_info(f"No tools found in UDP provider '{manual_provider.name}' response") - return [] - + self._log_info(f"No tools found in UDP provider '{manual_call_template.name}' response") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=tools) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=True, + errors=[] + ) except json.JSONDecodeError as e: - self._log_error(f"Invalid JSON response from UDP provider '{manual_provider.name}': {traceback.format_exc()}") - return [] - + self._log_error(f"Invalid JSON response from UDP provider '{manual_call_template.name}': {traceback.format_exc()}") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=False, + errors=[str(e)] + ) except Exception as e: - self._log_error(f"Error registering UDP provider '{manual_provider.name}': {traceback.format_exc()}") - return [] + self._log_error(f"Error registering UDP provider '{manual_call_template.name}': {traceback.format_exc()}") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=False, + errors=[str(e)] + ) - async def deregister_tool_provider(self, manual_provider: Provider) -> None: - """Deregister a UDP provider. - - This is a no-op for UDP providers since they are stateless. - - Args: - manual_provider: The provider to deregister - """ - if not isinstance(manual_provider, UDPProvider): + async def deregister_manual(self, caller, manual_call_template: CallTemplate) -> None: + if not isinstance(manual_call_template, UDPProvider): raise ValueError("UDPTransport can only be used with UDPProvider") - - self._log_info(f"Deregistering UDP provider '{manual_provider.name}' (no-op)") + self._log_info(f"Deregistering UDP provider '{manual_call_template.name}' (no-op)") - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider) -> Any: - """Call a UDP tool. - - Sends a tool call message to the UDP provider and returns the response. - - Args: - tool_name: Name of the tool to call - arguments: Arguments for the tool call - tool_provider: The UDPProvider containing the tool - - Returns: - The response from the UDP tool - - Raises: - ValueError: If provider is not a UDPProvider - """ - if not isinstance(tool_provider, UDPProvider): + async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate) -> Any: + if not isinstance(tool_call_template, UDPProvider): raise ValueError("UDPTransport can only be used with UDPProvider") - - self._log_info(f"Calling UDP tool '{tool_name}' on provider '{tool_provider.name}'") - + self._log_info(f"Calling UDP tool '{tool_name}' on provider '{tool_call_template.name}'") try: - tool_call_message = self._format_tool_call_message(tool_args, tool_provider) - + tool_call_message = self._format_tool_call_message(tool_args, tool_call_template) response = await self._send_udp_message( - tool_provider.host, - tool_provider.port, + tool_call_template.host, + tool_call_template.port, tool_call_message, - tool_provider.timeout / 1000.0, # Convert ms to seconds - tool_provider.number_of_response_datagrams, - tool_provider.response_byte_format + tool_call_template.timeout / 1000.0, + tool_call_template.number_of_response_datagrams, + tool_call_template.response_byte_format ) return response - except Exception as e: self._log_error(f"Error calling UDP tool '{tool_name}': {traceback.format_exc()}") raise + + async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): + yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) diff --git a/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py b/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py new file mode 100644 index 0000000..1b6ffb2 --- /dev/null +++ b/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py @@ -0,0 +1,178 @@ +import asyncio +import json +import pytest + +from utcp_socket.tcp_communication_protocol import TCPTransport +from utcp_socket.tcp_call_template import TCPProvider + + +async def start_tcp_server(): + """Start a simple TCP server that sends a mutable JSON object then closes.""" + response_container = {"bytes": b""} + + async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + try: + # Read any incoming data to simulate request handling + await reader.read(1024) + except Exception: + pass + # Send response and close connection + writer.write(response_container["bytes"]) + await writer.drain() + try: + writer.close() + await writer.wait_closed() + except Exception: + pass + + server = await asyncio.start_server(handle, host="127.0.0.1", port=0) + port = server.sockets[0].getsockname()[1] + + def set_response(obj): + response_container["bytes"] = json.dumps(obj).encode("utf-8") + + return server, port, set_response + + +@pytest.mark.asyncio +async def test_register_manual_converts_legacy_tool_provider_tcp(): + """When manual returns legacy tool_provider, it is converted to tool_call_template.""" + # Start server and configure response after obtaining port + server, port, set_response = await start_tcp_server() + set_response({ + "tools": [ + { + "name": "tcp_tool", + "description": "Echo over TCP", + "inputs": {}, + "outputs": {}, + "tool_provider": { + "call_template_type": "tcp", + "name": "tcp-executor", + "host": "127.0.0.1", + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "framing_strategy": "stream", + "timeout": 2000 + } + } + ] + }) + + try: + provider = TCPProvider( + name="tcp-provider", + host="127.0.0.1", + port=port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=2000 + ) + transport_client = TCPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert result.manual is not None + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "tcp" + assert isinstance(tool.tool_call_template, TCPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_register_manual_validates_provided_tool_call_template_tcp(): + """When manual provides tool_call_template, it is validated and preserved.""" + server, port, set_response = await start_tcp_server() + set_response({ + "tools": [ + { + "name": "tcp_tool", + "description": "Echo over TCP", + "inputs": {}, + "outputs": {}, + "tool_call_template": { + "call_template_type": "tcp", + "name": "tcp-executor", + "host": "127.0.0.1", + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "framing_strategy": "stream", + "timeout": 2000 + } + } + ] + }) + + try: + provider = TCPProvider( + name="tcp-provider", + host="127.0.0.1", + port=port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=2000 + ) + transport_client = TCPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "tcp" + assert isinstance(tool.tool_call_template, TCPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_register_manual_fallbacks_to_manual_template_tcp(): + """When neither tool_provider nor tool_call_template is provided, fall back to manual template.""" + server, port, set_response = await start_tcp_server() + set_response({ + "tools": [ + { + "name": "tcp_tool", + "description": "Echo over TCP", + "inputs": {}, + "outputs": {} + } + ] + }) + + try: + provider = TCPProvider( + name="tcp-provider", + host="127.0.0.1", + port=port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=2000 + ) + transport_client = TCPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "tcp" + assert isinstance(tool.tool_call_template, TCPProvider) + # Should match manual (discovery) provider values + assert tool.tool_call_template.host == provider.host + assert tool.tool_call_template.port == provider.port + assert tool.tool_call_template.name == provider.name + finally: + server.close() + await server.wait_closed() \ No newline at end of file diff --git a/plugins/communication_protocols/socket/tests/test_udp_communication_protocol.py b/plugins/communication_protocols/socket/tests/test_udp_communication_protocol.py new file mode 100644 index 0000000..d6a770c --- /dev/null +++ b/plugins/communication_protocols/socket/tests/test_udp_communication_protocol.py @@ -0,0 +1,176 @@ +import asyncio +import json +import pytest + +from utcp_socket.udp_communication_protocol import UDPTransport +from utcp_socket.udp_call_template import UDPProvider + + +async def start_udp_server(): + """Start a simple UDP server that replies with a mutable JSON payload.""" + loop = asyncio.get_running_loop() + response_container = {"bytes": b""} + + class _Protocol(asyncio.DatagramProtocol): + def __init__(self, container): + self.container = container + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + # Always respond with the prepared payload + if self.transport: + self.transport.sendto(self.container["bytes"], addr) + + transport, protocol = await loop.create_datagram_endpoint( + lambda: _Protocol(response_container), local_addr=("127.0.0.1", 0) + ) + port = transport.get_extra_info("socket").getsockname()[1] + + def set_response(obj): + response_container["bytes"] = json.dumps(obj).encode("utf-8") + + return transport, port, set_response + + +@pytest.mark.asyncio +async def test_register_manual_converts_legacy_tool_provider_udp(): + """When manual returns legacy tool_provider, it is converted to tool_call_template.""" + # Start server and configure response after obtaining port + transport, port, set_response = await start_udp_server() + set_response({ + "tools": [ + { + "name": "udp_tool", + "description": "Echo over UDP", + "inputs": {}, + "outputs": {}, + "tool_provider": { + "call_template_type": "udp", + "name": "udp-executor", + "host": "127.0.0.1", + "port": port, + "number_of_response_datagrams": 1, + "request_data_format": "json", + "response_byte_format": "utf-8", + "timeout": 2000 + } + } + ] + }) + + try: + provider = UDPProvider( + name="udp-provider", + host="127.0.0.1", + port=port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format="utf-8", + timeout=2000 + ) + transport_client = UDPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert result.manual is not None + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "udp" + assert isinstance(tool.tool_call_template, UDPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + transport.close() + + +@pytest.mark.asyncio +async def test_register_manual_validates_provided_tool_call_template_udp(): + """When manual provides tool_call_template, it is validated and preserved.""" + transport, port, set_response = await start_udp_server() + set_response({ + "tools": [ + { + "name": "udp_tool", + "description": "Echo over UDP", + "inputs": {}, + "outputs": {}, + "tool_call_template": { + "call_template_type": "udp", + "name": "udp-executor", + "host": "127.0.0.1", + "port": port, + "number_of_response_datagrams": 1, + "request_data_format": "json", + "response_byte_format": "utf-8", + "timeout": 2000 + } + } + ] + }) + + try: + provider = UDPProvider( + name="udp-provider", + host="127.0.0.1", + port=port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format="utf-8", + timeout=2000 + ) + transport_client = UDPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "udp" + assert isinstance(tool.tool_call_template, UDPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + transport.close() + + +@pytest.mark.asyncio +async def test_register_manual_fallbacks_to_manual_template_udp(): + """When neither tool_provider nor tool_call_template is provided, fall back to manual template.""" + transport, port, set_response = await start_udp_server() + set_response({ + "tools": [ + { + "name": "udp_tool", + "description": "Echo over UDP", + "inputs": {}, + "outputs": {} + } + ] + }) + + try: + provider = UDPProvider( + name="udp-provider", + host="127.0.0.1", + port=port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format="utf-8", + timeout=2000 + ) + transport_client = UDPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "udp" + assert isinstance(tool.tool_call_template, UDPProvider) + # Should match manual (discovery) provider values + assert tool.tool_call_template.host == provider.host + assert tool.tool_call_template.port == provider.port + assert tool.tool_call_template.name == provider.name + finally: + transport.close() \ No newline at end of file diff --git a/scripts/socket_sanity.py b/scripts/socket_sanity.py new file mode 100644 index 0000000..40b2c16 --- /dev/null +++ b/scripts/socket_sanity.py @@ -0,0 +1,265 @@ +import sys +import os +import json +import time +import socket +import threading +import asyncio +from pathlib import Path + +# Ensure core and socket plugin sources are on sys.path +ROOT = Path(__file__).resolve().parent.parent +CORE_SRC = ROOT / "core" / "src" +SOCKET_SRC = ROOT / "plugins" / "communication_protocols" / "socket" / "src" +for p in [str(CORE_SRC), str(SOCKET_SRC)]: + if p not in sys.path: + sys.path.insert(0, p) + +from utcp_socket.udp_communication_protocol import UDPTransport +from utcp_socket.tcp_communication_protocol import TCPTransport +from utcp_socket.udp_call_template import UDPProvider +from utcp_socket.tcp_call_template import TCPProvider + +# ------------------------------- +# Mock UDP Server +# ------------------------------- + +def start_udp_server(host: str, port: int): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind((host, port)) + + def run(): + while True: + data, addr = sock.recvfrom(65535) + try: + msg = data.decode("utf-8") + except Exception: + msg = "" + # Handle discovery + try: + parsed = json.loads(msg) + except Exception: + parsed = None + if isinstance(parsed, dict) and parsed.get("type") == "utcp": + manual = { + "utcp_version": "1.0", + "manual_version": "1.0", + "tools": [ + { + "name": "udp.echo", + "description": "Echo UDP args as JSON", + "inputs": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "extra": {"type": "number"} + }, + "required": ["text"] + }, + "outputs": { + "type": "object", + "properties": { + "ok": {"type": "boolean"}, + "echo": {"type": "string"}, + "args": {"type": "object"} + } + }, + "tags": ["socket", "udp"], + "average_response_size": 64, + # Return legacy provider to exercise conversion path + "tool_provider": { + "call_template_type": "udp", + "name": "udp", + "host": host, + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "number_of_response_datagrams": 1, + "timeout": 3000 + } + } + ] + } + payload = json.dumps(manual).encode("utf-8") + sock.sendto(payload, addr) + else: + # Tool call: echo JSON payload + try: + args = json.loads(msg) + except Exception: + args = {"raw": msg} + resp = { + "ok": True, + "echo": args.get("text", ""), + "args": args + } + sock.sendto(json.dumps(resp).encode("utf-8"), addr) + t = threading.Thread(target=run, daemon=True) + t.start() + return t + +# ------------------------------- +# Mock TCP Server (delimiter-based) +# ------------------------------- + +def start_tcp_server(host: str, port: int, delimiter: str = "\n"): + delim_bytes = delimiter.encode("utf-8") + + def handle_client(conn: socket.socket, addr): + try: + # Read until delimiter + buf = b"" + while True: + chunk = conn.recv(1) + if not chunk: + break + buf += chunk + if buf.endswith(delim_bytes): + break + msg = buf[:-len(delim_bytes)].decode("utf-8") if buf.endswith(delim_bytes) else buf.decode("utf-8") + # Discovery + parsed = None + try: + parsed = json.loads(msg) + except Exception: + pass + if isinstance(parsed, dict) and parsed.get("type") == "utcp": + manual = { + "utcp_version": "1.0", + "manual_version": "1.0", + "tools": [ + { + "name": "tcp.echo", + "description": "Echo TCP args as JSON", + "inputs": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "extra": {"type": "number"} + }, + "required": ["text"] + }, + "outputs": { + "type": "object", + "properties": { + "ok": {"type": "boolean"}, + "echo": {"type": "string"}, + "args": {"type": "object"} + } + }, + "tags": ["socket", "tcp"], + "average_response_size": 64, + # Legacy provider to exercise conversion + "tool_provider": { + "call_template_type": "tcp", + "name": "tcp", + "host": host, + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "framing_strategy": "delimiter", + "message_delimiter": "\\n", + "timeout": 3000 + } + } + ] + } + payload = json.dumps(manual).encode("utf-8") + delim_bytes + conn.sendall(payload) + else: + # Tool call: echo JSON payload + try: + args = json.loads(msg) + except Exception: + args = {"raw": msg} + resp = { + "ok": True, + "echo": args.get("text", ""), + "args": args + } + conn.sendall(json.dumps(resp).encode("utf-8") + delim_bytes) + finally: + try: + conn.shutdown(socket.SHUT_RDWR) + except Exception: + pass + conn.close() + + def run(): + srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + srv.bind((host, port)) + srv.listen(5) + while True: + conn, addr = srv.accept() + threading.Thread(target=handle_client, args=(conn, addr), daemon=True).start() + + t = threading.Thread(target=run, daemon=True) + t.start() + return t + +# ------------------------------- +# Sanity test runner +# ------------------------------- + +async def run_sanity(): + udp_host, udp_port = "127.0.0.1", 23456 + tcp_host, tcp_port = "127.0.0.1", 23457 + + # Start servers + start_udp_server(udp_host, udp_port) + start_tcp_server(tcp_host, tcp_port, delimiter="\n") + await asyncio.sleep(0.2) # small delay to ensure servers are listening + + # Transports + udp_transport = UDPTransport() + tcp_transport = TCPTransport() + + # Register manuals + udp_manual_template = UDPProvider(name="udp", host=udp_host, port=udp_port, request_data_format="json", response_byte_format="utf-8", number_of_response_datagrams=1, timeout=3000) + tcp_manual_template = TCPProvider(name="tcp", host=tcp_host, port=tcp_port, request_data_format="json", response_byte_format="utf-8", framing_strategy="delimiter", message_delimiter="\n", timeout=3000) + + udp_reg = await udp_transport.register_manual(None, udp_manual_template) + tcp_reg = await tcp_transport.register_manual(None, tcp_manual_template) + + print("UDP register success:", udp_reg.success, "tools:", len(udp_reg.manual.tools)) + print("TCP register success:", tcp_reg.success, "tools:", len(tcp_reg.manual.tools)) + + assert udp_reg.success and len(udp_reg.manual.tools) == 1 + assert tcp_reg.success and len(tcp_reg.manual.tools) == 1 + + # Verify tool_call_template present + assert udp_reg.manual.tools[0].tool_call_template.call_template_type == "udp" + assert tcp_reg.manual.tools[0].tool_call_template.call_template_type == "tcp" + + # Call tools + udp_result = await udp_transport.call_tool(None, "udp.echo", {"text": "hello", "extra": 42}, udp_reg.manual.tools[0].tool_call_template) + tcp_result = await tcp_transport.call_tool(None, "tcp.echo", {"text": "world", "extra": 99}, tcp_reg.manual.tools[0].tool_call_template) + + print("UDP call result:", udp_result) + print("TCP call result:", tcp_result) + + # Basic assertions on response shape + def ensure_dict(s): + if isinstance(s, (bytes, bytearray)): + try: + s = s.decode("utf-8") + except Exception: + return {} + if isinstance(s, str): + try: + return json.loads(s) + except Exception: + return {"raw": s} + return s if isinstance(s, dict) else {} + + udp_resp = ensure_dict(udp_result) + tcp_resp = ensure_dict(tcp_result) + + assert udp_resp.get("ok") is True and udp_resp.get("echo") == "hello" + assert tcp_resp.get("ok") is True and tcp_resp.get("echo") == "world" + + print("Sanity passed: UDP/TCP discovery and calls work with tool_call_template normalization.") + +if __name__ == "__main__": + asyncio.run(run_sanity()) \ No newline at end of file diff --git a/socket_plugin_test.py b/socket_plugin_test.py new file mode 100644 index 0000000..c03d2a9 --- /dev/null +++ b/socket_plugin_test.py @@ -0,0 +1,40 @@ +import asyncio +import sys +from pathlib import Path + +# Add core and plugin src paths so imports work without installing packages +core_src = Path(__file__).parent / "core" / "src" +socket_src = Path(__file__).parent / "plugins" / "communication_protocols" / "socket" / "src" +sys.path.insert(0, str(core_src.resolve())) +sys.path.insert(0, str(socket_src.resolve())) + +from utcp.plugins.plugin_loader import ensure_plugins_initialized +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp.data.call_template import CallTemplateSerializer +from utcp_socket import register as register_socket + +async def main(): + # Manually register the socket plugin + register_socket() + + # Load core plugins (auth, repo, search, post-processors) + ensure_plugins_initialized() + + # 1. Check if communication protocols are registered + registered_protocols = CommunicationProtocol.communication_protocols + print(f"Registered communication protocols: {list(registered_protocols.keys())}") + assert "tcp" in registered_protocols, "TCP communication protocol not registered" + assert "udp" in registered_protocols, "UDP communication protocol not registered" + print("āœ… TCP and UDP communication protocols are registered.") + + # 2. Check if call templates are registered + registered_serializers = CallTemplateSerializer.call_template_serializers + print(f"Registered call template serializers: {list(registered_serializers.keys())}") + assert "tcp" in registered_serializers, "TCP call template serializer not registered" + assert "udp" in registered_serializers, "UDP call template serializer not registered" + print("āœ… TCP and UDP call template serializers are registered.") + + print("\nšŸŽ‰ Socket plugin sanity check passed! šŸŽ‰") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file