|
1 | 1 | """Tests for the ClientFactory.""" |
2 | 2 |
|
| 3 | +from unittest.mock import AsyncMock, MagicMock, patch |
| 4 | + |
3 | 5 | import httpx |
4 | 6 | import pytest |
5 | 7 |
|
@@ -103,3 +105,158 @@ def test_client_factory_no_compatible_transport(base_agent_card: AgentCard): |
103 | 105 | factory = ClientFactory(config) |
104 | 106 | with pytest.raises(ValueError, match='no compatible transports found'): |
105 | 107 | factory.create(base_agent_card) |
| 108 | + |
| 109 | + |
| 110 | +@pytest.mark.asyncio |
| 111 | +async def test_client_factory_connect_with_agent_card( |
| 112 | + base_agent_card: AgentCard, |
| 113 | +): |
| 114 | + """Verify that connect works correctly when provided with an AgentCard.""" |
| 115 | + client = await ClientFactory.connect(base_agent_card) |
| 116 | + assert isinstance(client._transport, JsonRpcTransport) |
| 117 | + assert client._transport.url == 'http://primary-url.com' |
| 118 | + |
| 119 | + |
| 120 | +@pytest.mark.asyncio |
| 121 | +async def test_client_factory_connect_with_url(base_agent_card: AgentCard): |
| 122 | + """Verify that connect works correctly when provided with a URL.""" |
| 123 | + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: |
| 124 | + mock_resolver.return_value.get_agent_card = AsyncMock( |
| 125 | + return_value=base_agent_card |
| 126 | + ) |
| 127 | + |
| 128 | + agent_url = 'http://example.com' |
| 129 | + client = await ClientFactory.connect(agent_url) |
| 130 | + |
| 131 | + mock_resolver.assert_called_once() |
| 132 | + assert mock_resolver.call_args[0][1] == agent_url |
| 133 | + mock_resolver.return_value.get_agent_card.assert_awaited_once() |
| 134 | + |
| 135 | + assert isinstance(client._transport, JsonRpcTransport) |
| 136 | + assert client._transport.url == 'http://primary-url.com' |
| 137 | + |
| 138 | + |
| 139 | +@pytest.mark.asyncio |
| 140 | +async def test_client_factory_connect_with_url_and_client_config( |
| 141 | + base_agent_card: AgentCard, |
| 142 | +): |
| 143 | + """Verify connect with a URL and a pre-configured httpx client.""" |
| 144 | + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: |
| 145 | + mock_resolver.return_value.get_agent_card = AsyncMock( |
| 146 | + return_value=base_agent_card |
| 147 | + ) |
| 148 | + |
| 149 | + agent_url = 'http://example.com' |
| 150 | + mock_httpx_client = httpx.AsyncClient() |
| 151 | + config = ClientConfig(httpx_client=mock_httpx_client) |
| 152 | + |
| 153 | + client = await ClientFactory.connect(agent_url, client_config=config) |
| 154 | + |
| 155 | + mock_resolver.assert_called_once_with(mock_httpx_client, agent_url) |
| 156 | + mock_resolver.return_value.get_agent_card.assert_awaited_once() |
| 157 | + |
| 158 | + assert isinstance(client._transport, JsonRpcTransport) |
| 159 | + assert client._transport.url == 'http://primary-url.com' |
| 160 | + |
| 161 | + |
| 162 | +@pytest.mark.asyncio |
| 163 | +async def test_client_factory_connect_with_resolver_args( |
| 164 | + base_agent_card: AgentCard, |
| 165 | +): |
| 166 | + """Verify connect passes resolver arguments correctly.""" |
| 167 | + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: |
| 168 | + mock_resolver.return_value.get_agent_card = AsyncMock( |
| 169 | + return_value=base_agent_card |
| 170 | + ) |
| 171 | + |
| 172 | + agent_url = 'http://example.com' |
| 173 | + relative_path = '/card' |
| 174 | + http_kwargs = {'headers': {'X-Test': 'true'}} |
| 175 | + |
| 176 | + # The resolver args are only passed if an httpx_client is provided in config |
| 177 | + config = ClientConfig(httpx_client=httpx.AsyncClient()) |
| 178 | + |
| 179 | + await ClientFactory.connect( |
| 180 | + agent_url, |
| 181 | + client_config=config, |
| 182 | + relative_card_path=relative_path, |
| 183 | + resolver_http_kwargs=http_kwargs, |
| 184 | + ) |
| 185 | + |
| 186 | + mock_resolver.return_value.get_agent_card.assert_awaited_once_with( |
| 187 | + relative_card_path=relative_path, |
| 188 | + http_kwargs=http_kwargs, |
| 189 | + ) |
| 190 | + |
| 191 | + |
| 192 | +@pytest.mark.asyncio |
| 193 | +async def test_client_factory_connect_resolver_args_without_client( |
| 194 | + base_agent_card: AgentCard, |
| 195 | +): |
| 196 | + """Verify resolver args are ignored if no httpx_client is provided.""" |
| 197 | + with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver: |
| 198 | + mock_resolver.return_value.get_agent_card = AsyncMock( |
| 199 | + return_value=base_agent_card |
| 200 | + ) |
| 201 | + |
| 202 | + agent_url = 'http://example.com' |
| 203 | + relative_path = '/card' |
| 204 | + http_kwargs = {'headers': {'X-Test': 'true'}} |
| 205 | + |
| 206 | + await ClientFactory.connect( |
| 207 | + agent_url, |
| 208 | + relative_card_path=relative_path, |
| 209 | + resolver_http_kwargs=http_kwargs, |
| 210 | + ) |
| 211 | + |
| 212 | + mock_resolver.return_value.get_agent_card.assert_awaited_once_with( |
| 213 | + relative_card_path=relative_path, |
| 214 | + http_kwargs=http_kwargs, |
| 215 | + ) |
| 216 | + |
| 217 | + |
| 218 | +@pytest.mark.asyncio |
| 219 | +async def test_client_factory_connect_with_extra_transports( |
| 220 | + base_agent_card: AgentCard, |
| 221 | +): |
| 222 | + """Verify that connect can register and use extra transports.""" |
| 223 | + |
| 224 | + class CustomTransport: |
| 225 | + pass |
| 226 | + |
| 227 | + def custom_transport_producer(*args, **kwargs): |
| 228 | + return CustomTransport() |
| 229 | + |
| 230 | + base_agent_card.preferred_transport = 'custom' |
| 231 | + base_agent_card.url = 'custom://foo' |
| 232 | + |
| 233 | + config = ClientConfig(supported_transports=['custom']) |
| 234 | + |
| 235 | + client = await ClientFactory.connect( |
| 236 | + base_agent_card, |
| 237 | + client_config=config, |
| 238 | + extra_transports={'custom': custom_transport_producer}, |
| 239 | + ) |
| 240 | + |
| 241 | + assert isinstance(client._transport, CustomTransport) |
| 242 | + |
| 243 | + |
| 244 | +@pytest.mark.asyncio |
| 245 | +async def test_client_factory_connect_with_consumers_and_interceptors( |
| 246 | + base_agent_card: AgentCard, |
| 247 | +): |
| 248 | + """Verify consumers and interceptors are passed through correctly.""" |
| 249 | + consumer1 = MagicMock() |
| 250 | + interceptor1 = MagicMock() |
| 251 | + |
| 252 | + with patch('a2a.client.client_factory.BaseClient') as mock_base_client: |
| 253 | + await ClientFactory.connect( |
| 254 | + base_agent_card, |
| 255 | + consumers=[consumer1], |
| 256 | + interceptors=[interceptor1], |
| 257 | + ) |
| 258 | + |
| 259 | + mock_base_client.assert_called_once() |
| 260 | + call_args = mock_base_client.call_args[0] |
| 261 | + assert call_args[3] == [consumer1] |
| 262 | + assert call_args[4] == [interceptor1] |
0 commit comments