Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
01b421e
fix: change "client/test_client.py" to "client/test_client_factory.py…
sokoliva Oct 23, 2025
697438f
feat: Add client-side extension support
sokoliva Oct 28, 2025
17d30a4
Merge branch 'main' into Extension-support-for-Client
sokoliva Oct 28, 2025
860f2d5
refactor: remove redundant tests for send_message without extensions …
sokoliva Oct 29, 2025
511de38
refactor: reorder parameters in JsonRpcTransport and RestTransport co…
sokoliva Oct 29, 2025
6e80123
refactor: reorder parameters in JsonRpcTransport and RestTransport co…
sokoliva Oct 29, 2025
fd5986a
Merge branch 'Extension-support-for-Client' of https://github.com/sok…
sokoliva Oct 29, 2025
31a4581
Fix Parsing Bug in _update_extension_header method
sokoliva Oct 29, 2025
5fc530e
Fix Parsing Bug in _update_extension_header method
sokoliva Oct 29, 2025
3144f43
Merge branch 'Extension-support-for-Client' of https://github.com/sok…
sokoliva Oct 29, 2025
97eec52
refactor: streamline extension header handling in JsonRpcTransport an…
sokoliva Oct 29, 2025
caba0a2
refactor: rename client_extensions to extensions in JsonRpcTransport …
sokoliva Oct 30, 2025
28b1d53
feat: move common functions for managing HTTP extension headers to ut…
sokoliva Nov 3, 2025
270d6e7
Remove extensions from grpc methog get_card
sokoliva Nov 3, 2025
a9aa9ee
feat: add support for extensions in Client and BaseClient, update tra…
sokoliva Nov 3, 2025
948d3f3
fix: correct order of extension header updates in update_extension_he…
sokoliva Nov 3, 2025
4073c0b
refactor: streamline extension handling in BaseClient and GrpcTranspo…
sokoliva Nov 4, 2025
c5cea2c
Move transport tests from tests/client to tests/client/transport. Add…
sokoliva Nov 5, 2025
6e856d5
feat: enhance GrpcTransport to manage extensions in metadata and upda…
sokoliva Nov 6, 2025
edd7982
refactor: remove unused __merge_extensions function from utils.py
sokoliva Nov 6, 2025
48ea2ae
feat: update extension handling in transports and tests, migrate util…
sokoliva Nov 12, 2025
ffc0279
Merge remote-tracking branch 'origin/main' into Extension-support-for…
sokoliva Nov 12, 2025
5b47562
fix(client): clarify the purpose of the extensions parameter in Clien…
sokoliva Nov 12, 2025
a2eeb7b
feat: enhance extension handling across client and transport layers
sokoliva Nov 13, 2025
0746541
feat: add extensions parameter documentation in ClientFactory and upd…
sokoliva Nov 13, 2025
f5443d6
Merge branch 'main' into Extension-support-for-Client
sokoliva Nov 13, 2025
1337dcf
refactor: streamline extension handling in transport classes and upda…
sokoliva Nov 14, 2025
7f4ba58
Merge remote-tracking branch 'refs/remotes/upstream/Extension-support…
sokoliva Nov 14, 2025
a97c5b3
Merge branch 'main' into Extension-support-for-Client
sokoliva Nov 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ async def send_message(
*,
context: ClientCallContext | None = None,
request_metadata: dict[str, Any] | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent | Message]:
"""Sends a message to the agent.

Expand All @@ -60,6 +61,7 @@ async def send_message(
request: The message to send to the agent.
context: The client call context.
request_metadata: Extensions Metadata attached to the request.
extensions: List of extensions to be activated.

Yields:
An async iterator of `ClientEvent` or a final `Message` response.
Expand All @@ -79,7 +81,7 @@ async def send_message(

if not self._config.streaming or not self._card.capabilities.streaming:
response = await self._transport.send_message(
params, context=context
params, context=context, extensions=extensions
)
result = (
(response, None) if isinstance(response, Task) else response
Expand All @@ -89,7 +91,9 @@ async def send_message(
return

tracker = ClientTaskManager()
stream = self._transport.send_message_streaming(params, context=context)
stream = self._transport.send_message_streaming(
params, context=context, extensions=extensions
)

first_event = await anext(stream)
# The response from a server may be either exactly one Message or a
Expand Down Expand Up @@ -126,74 +130,91 @@ async def get_task(
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task.

Args:
request: The `TaskQueryParams` object specifying the task ID.
context: The client call context.
extensions: List of extensions to be activated.

Returns:
A `Task` object representing the current state of the task.
"""
return await self._transport.get_task(request, context=context)
return await self._transport.get_task(
request, context=context, extensions=extensions
)

async def cancel_task(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task.

Args:
request: The `TaskIdParams` object specifying the task ID.
context: The client call context.
extensions: List of extensions to be activated.

Returns:
A `Task` object containing the updated task status.
"""
return await self._transport.cancel_task(request, context=context)
return await self._transport.cancel_task(
request, context=context, extensions=extensions
)

async def set_task_callback(
self,
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task.

Args:
request: The `TaskPushNotificationConfig` object with the new configuration.
context: The client call context.
extensions: List of extensions to be activated.

Returns:
The created or updated `TaskPushNotificationConfig` object.
"""
return await self._transport.set_task_callback(request, context=context)
return await self._transport.set_task_callback(
request, context=context, extensions=extensions
)

async def get_task_callback(
self,
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task.

Args:
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
context: The client call context.
extensions: List of extensions to be activated.

Returns:
A `TaskPushNotificationConfig` object containing the configuration.
"""
return await self._transport.get_task_callback(request, context=context)
return await self._transport.get_task_callback(
request, context=context, extensions=extensions
)

async def resubscribe(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream.

Expand All @@ -202,6 +223,7 @@ async def resubscribe(
Args:
request: Parameters to identify the task to resubscribe to.
context: The client call context.
extensions: List of extensions to be activated.

Yields:
An async iterator of `ClientEvent` objects.
Expand All @@ -219,12 +241,15 @@ async def resubscribe(
# we should never see Message updates, despite the typing of the service
# definition indicating it may be possible.
async for event in self._transport.resubscribe(
request, context=context
request, context=context, extensions=extensions
):
yield await self._process_response(tracker, event)

async def get_card(
self, *, context: ClientCallContext | None = None
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AgentCard:
"""Retrieves the agent's card.

Expand All @@ -233,11 +258,14 @@ async def get_card(

Args:
context: The client call context.
extensions: List of extensions to be activated.

Returns:
The `AgentCard` for the agent.
"""
card = await self._transport.get_card(context=context)
card = await self._transport.get_card(
context=context, extensions=extensions
)
self._card = card
return card

Expand Down
14 changes: 13 additions & 1 deletion src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class ClientConfig:
)
"""Push notification callbacks to use for every request."""

extensions: list[str] = dataclasses.field(default_factory=list)
"""A list of extension URIs the client supports."""


UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None
# Alias for emitted events from client
Expand Down Expand Up @@ -111,6 +114,7 @@ async def send_message(
*,
context: ClientCallContext | None = None,
request_metadata: dict[str, Any] | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent | Message]:
"""Sends a message to the server.

Expand All @@ -129,6 +133,7 @@ async def get_task(
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""

Expand All @@ -138,6 +143,7 @@ async def cancel_task(
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""

Expand All @@ -147,6 +153,7 @@ async def set_task_callback(
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""

Expand All @@ -156,6 +163,7 @@ async def get_task_callback(
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""

Expand All @@ -165,14 +173,18 @@ async def resubscribe(
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream."""
return
yield

@abstractmethod
async def get_card(
self, *, context: ClientCallContext | None = None
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""

Expand Down
19 changes: 17 additions & 2 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _register_defaults(
card,
url,
interceptors,
config.extensions or None,
),
)
if TransportProtocol.http_json in supported:
Expand All @@ -90,6 +91,7 @@ def _register_defaults(
card,
url,
interceptors,
config.extensions or None,
),
)
if TransportProtocol.grpc in supported:
Expand All @@ -113,6 +115,7 @@ async def connect( # noqa: PLR0913
relative_card_path: str | None = None,
resolver_http_kwargs: dict[str, Any] | None = None,
extra_transports: dict[str, TransportProducer] | None = None,
extensions: list[str] | None = None,
) -> Client:
"""Convenience method for constructing a client.

Expand Down Expand Up @@ -142,6 +145,7 @@ async def connect( # noqa: PLR0913
A2AAgentCardResolver.get_agent_card as the http_kwargs parameter.
extra_transports: Additional transport protocols to enable when
constructing the client.
extensions: List of extensions to be activated.

Returns:
A `Client` object.
Expand All @@ -166,7 +170,7 @@ async def connect( # noqa: PLR0913
factory = cls(client_config)
for label, generator in (extra_transports or {}).items():
factory.register(label, generator)
return factory.create(card, consumers, interceptors)
return factory.create(card, consumers, interceptors, extensions)

def register(self, label: str, generator: TransportProducer) -> None:
"""Register a new transport producer for a given transport label."""
Expand All @@ -177,6 +181,7 @@ def create(
card: AgentCard,
consumers: list[Consumer] | None = None,
interceptors: list[ClientCallInterceptor] | None = None,
extensions: list[str] | None = None,
) -> Client:
"""Create a new `Client` for the provided `AgentCard`.

Expand All @@ -186,6 +191,7 @@ def create(
interceptors: A list of interceptors to use for each request. These
are used for things like attaching credentials or http headers
to all outbound requests.
extensions: List of extensions to be activated.

Returns:
A `Client` object.
Expand Down Expand Up @@ -226,12 +232,21 @@ def create(
if consumers:
all_consumers.extend(consumers)

all_extensions = self._config.extensions.copy()
if extensions:
all_extensions.extend(extensions)
self._config.extensions = all_extensions

transport = self._registry[transport_protocol](
card, transport_url, self._config, interceptors or []
)

return BaseClient(
card, self._config, transport, all_consumers, interceptors or []
card,
self._config,
transport,
all_consumers,
interceptors or [],
)


Expand Down
8 changes: 8 additions & 0 deletions src/a2a/client/transports/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def send_message(
request: MessageSendParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task | Message:
"""Sends a non-streaming message request to the agent."""

Expand All @@ -34,6 +35,7 @@ async def send_message_streaming(
request: MessageSendParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncGenerator[
Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
]:
Expand All @@ -47,6 +49,7 @@ async def get_task(
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""

Expand All @@ -56,6 +59,7 @@ async def cancel_task(
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""

Expand All @@ -65,6 +69,7 @@ async def set_task_callback(
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""

Expand All @@ -74,6 +79,7 @@ async def get_task_callback(
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""

Expand All @@ -83,6 +89,7 @@ async def resubscribe(
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncGenerator[
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
]:
Expand All @@ -95,6 +102,7 @@ async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AgentCard:
"""Retrieves the AgentCard."""

Expand Down
Loading
Loading