Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
01b421e
fix: change "client/test_client.py" to "client/test_client_factory.py…
sokoliva Oct 23, 2025
697438f
feat: Add client-side extension support
sokoliva Oct 28, 2025
17d30a4
Merge branch 'main' into Extension-support-for-Client
sokoliva Oct 28, 2025
860f2d5
refactor: remove redundant tests for send_message without extensions …
sokoliva Oct 29, 2025
511de38
refactor: reorder parameters in JsonRpcTransport and RestTransport co…
sokoliva Oct 29, 2025
6e80123
refactor: reorder parameters in JsonRpcTransport and RestTransport co…
sokoliva Oct 29, 2025
fd5986a
Merge branch 'Extension-support-for-Client' of https://github.com/sok…
sokoliva Oct 29, 2025
31a4581
Fix Parsing Bug in _update_extension_header method
sokoliva Oct 29, 2025
5fc530e
Fix Parsing Bug in _update_extension_header method
sokoliva Oct 29, 2025
3144f43
Merge branch 'Extension-support-for-Client' of https://github.com/sok…
sokoliva Oct 29, 2025
97eec52
refactor: streamline extension header handling in JsonRpcTransport an…
sokoliva Oct 29, 2025
caba0a2
refactor: rename client_extensions to extensions in JsonRpcTransport …
sokoliva Oct 30, 2025
28b1d53
feat: move common functions for managing HTTP extension headers to ut…
sokoliva Nov 3, 2025
270d6e7
Remove extensions from grpc methog get_card
sokoliva Nov 3, 2025
a9aa9ee
feat: add support for extensions in Client and BaseClient, update tra…
sokoliva Nov 3, 2025
948d3f3
fix: correct order of extension header updates in update_extension_he…
sokoliva Nov 3, 2025
4073c0b
refactor: streamline extension handling in BaseClient and GrpcTranspo…
sokoliva Nov 4, 2025
c5cea2c
Move transport tests from tests/client to tests/client/transport. Add…
sokoliva Nov 5, 2025
6e856d5
feat: enhance GrpcTransport to manage extensions in metadata and upda…
sokoliva Nov 6, 2025
edd7982
refactor: remove unused __merge_extensions function from utils.py
sokoliva Nov 6, 2025
48ea2ae
feat: update extension handling in transports and tests, migrate util…
sokoliva Nov 12, 2025
ffc0279
Merge remote-tracking branch 'origin/main' into Extension-support-for…
sokoliva Nov 12, 2025
5b47562
fix(client): clarify the purpose of the extensions parameter in Clien…
sokoliva Nov 12, 2025
a2eeb7b
feat: enhance extension handling across client and transport layers
sokoliva Nov 13, 2025
0746541
feat: add extensions parameter documentation in ClientFactory and upd…
sokoliva Nov 13, 2025
f5443d6
Merge branch 'main' into Extension-support-for-Client
sokoliva Nov 13, 2025
1337dcf
refactor: streamline extension handling in transport classes and upda…
sokoliva Nov 14, 2025
7f4ba58
Merge remote-tracking branch 'refs/remotes/upstream/Extension-support…
sokoliva Nov 14, 2025
a97c5b3
Merge branch 'main' into Extension-support-for-Client
sokoliva Nov 14, 2025
16ee453
add integration test for extensions. Add a test case to test_common.p…
sokoliva Nov 17, 2025
674e840
Merge branch 'main' into Extension-support-for-Client
sokoliva Nov 17, 2025
7fb55d0
Merge remote-tracking branch 'refs/remotes/upstream/Extension-support…
sokoliva Nov 17, 2025
80be4bf
change test case name in tests/extensions/test_common.py
sokoliva Nov 17, 2025
4a423ef
Change the order of update_extension_header and _apply_interceptors f…
sokoliva Nov 18, 2025
9a5b1d6
Merge branch 'main' into Extension-support-for-Client
sokoliva Nov 18, 2025
125406d
Change assertion in test_client_server_integration
sokoliva Nov 18, 2025
f581c27
Merge remote-tracking branch 'refs/remotes/upstream/Extension-support…
sokoliva Nov 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ def __init__(
transport: ClientTransport,
consumers: list[Consumer],
middleware: list[ClientCallInterceptor],
extensions: list[str],
):
super().__init__(consumers, middleware)
super().__init__(consumers, middleware, extensions)
self._card = card
config.extensions = extensions
self._config = config
self._transport = transport

Expand Down
7 changes: 7 additions & 0 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class ClientConfig:
)
"""Push notification callbacks to use for every request."""

extensions: list[str] = dataclasses.field(default_factory=list)
"""A list of extension URIs the client supports."""


UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None
# Alias for emitted events from client
Expand All @@ -90,6 +93,7 @@ def __init__(
self,
consumers: list[Consumer] | None = None,
middleware: list[ClientCallInterceptor] | None = None,
extensions: list[str] | None = None,
):
"""Initializes the client with consumers and middleware.

Expand All @@ -101,8 +105,11 @@ def __init__(
middleware = []
if consumers is None:
consumers = []
if extensions is None:
extensions = []
self._consumers = consumers
self._middleware = middleware
self._extensions = extensions

@abstractmethod
async def send_message(
Expand Down
18 changes: 16 additions & 2 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _register_defaults(
card,
url,
interceptors,
config.extensions or None,
),
)
if TransportProtocol.http_json in supported:
Expand All @@ -90,6 +91,7 @@ def _register_defaults(
card,
url,
interceptors,
config.extensions or None,
),
)
if TransportProtocol.grpc in supported:
Expand All @@ -113,6 +115,7 @@ async def connect( # noqa: PLR0913
relative_card_path: str | None = None,
resolver_http_kwargs: dict[str, Any] | None = None,
extra_transports: dict[str, TransportProducer] | None = None,
extensions: list[str] | None = None,
) -> Client:
"""Convenience method for constructing a client.

Expand Down Expand Up @@ -166,7 +169,7 @@ async def connect( # noqa: PLR0913
factory = cls(client_config)
for label, generator in (extra_transports or {}).items():
factory.register(label, generator)
return factory.create(card, consumers, interceptors)
return factory.create(card, consumers, interceptors, extensions)

def register(self, label: str, generator: TransportProducer) -> None:
"""Register a new transport producer for a given transport label."""
Expand All @@ -177,6 +180,7 @@ def create(
card: AgentCard,
consumers: list[Consumer] | None = None,
interceptors: list[ClientCallInterceptor] | None = None,
extensions: list[str] | None = None,
) -> Client:
"""Create a new `Client` for the provided `AgentCard`.

Expand Down Expand Up @@ -226,12 +230,22 @@ def create(
if consumers:
all_consumers.extend(consumers)

all_extensions = self._config.extensions.copy()
if extensions:
all_extensions.extend(extensions)
self._config.extensions = all_extensions

transport = self._registry[transport_protocol](
card, transport_url, self._config, interceptors or []
)

return BaseClient(
card, self._config, transport, all_consumers, interceptors or []
card,
self._config,
transport,
all_consumers,
interceptors or [],
all_extensions,
)


Expand Down
43 changes: 30 additions & 13 deletions src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
"'pip install a2a-sdk[grpc]'"
) from e


from a2a.client.client import ClientConfig
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.optionals import Channel
from a2a.client.transports.base import ClientTransport
from a2a.client.transports.utils import update_extension_metadata
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
from a2a.types import (
AgentCard,
Expand Down Expand Up @@ -44,6 +46,7 @@ def __init__(
self,
channel: Channel,
agent_card: AgentCard | None,
extensions: list[str] | None = None,
):
"""Initializes the GrpcTransport."""
self.agent_card = agent_card
Expand All @@ -54,6 +57,7 @@ def __init__(
if agent_card
else True
)
self.extensions = extensions

@classmethod
def create(
Expand All @@ -66,10 +70,7 @@ def create(
"""Creates a gRPC transport for the A2A client."""
if config.grpc_channel_factory is None:
raise ValueError('grpc_channel_factory is required when using gRPC')
return cls(
config.grpc_channel_factory(url),
card,
)
return cls(config.grpc_channel_factory(url), card, config.extensions)

async def send_message(
self,
Expand All @@ -84,8 +85,10 @@ async def send_message(
configuration=proto_utils.ToProto.message_send_configuration(
request.configuration
),
metadata=proto_utils.ToProto.metadata(request.metadata),
)
metadata=update_extension_metadata(
request.metadata, self.extensions
),
),
)
if response.HasField('task'):
return proto_utils.FromProto.task(response.task)
Expand All @@ -106,8 +109,10 @@ async def send_message_streaming(
configuration=proto_utils.ToProto.message_send_configuration(
request.configuration
),
metadata=proto_utils.ToProto.metadata(request.metadata),
)
metadata=update_extension_metadata(
request.metadata, self.extensions
),
),
)
while True:
response = await stream.read()
Expand All @@ -122,7 +127,10 @@ async def resubscribe(
]:
"""Reconnects to get task updates."""
stream = self.stub.TaskSubscription(
a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}')
a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'),
metadata=update_extension_metadata(
request.metadata, self.extensions
),
)
while True:
response = await stream.read()
Expand All @@ -141,7 +149,10 @@ async def get_task(
a2a_pb2.GetTaskRequest(
name=f'tasks/{request.id}',
history_length=request.history_length,
)
),
metadata=update_extension_metadata(
request.metadata, self.extensions
),
)
return proto_utils.FromProto.task(task)

Expand All @@ -153,7 +164,7 @@ async def cancel_task(
) -> Task:
"""Requests the agent to cancel a specific task."""
task = await self.stub.CancelTask(
a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}')
a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'),
)
return proto_utils.FromProto.task(task)

Expand All @@ -171,7 +182,10 @@ async def set_task_callback(
config=proto_utils.ToProto.task_push_notification_config(
request
),
)
),
metadata=update_extension_metadata(
request.metadata, self.extensions
),
)
return proto_utils.FromProto.task_push_notification_config(config)

Expand All @@ -185,7 +199,10 @@ async def get_task_callback(
config = await self.stub.GetTaskPushNotificationConfig(
a2a_pb2.GetTaskPushNotificationConfigRequest(
name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
)
),
metadata=update_extension_metadata(
request.metadata, self.extensions
),
)
return proto_utils.FromProto.task_push_notification_config(config)

Expand Down
51 changes: 36 additions & 15 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.client.transports.utils import get_http_args, update_extension_header
from a2a.types import (
AgentCard,
CancelTaskRequest,
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(
agent_card: AgentCard | None = None,
url: str | None = None,
interceptors: list[ClientCallInterceptor] | None = None,
extensions: list[str] | None = None,
):
"""Initializes the JsonRpcTransport."""
if url:
Expand All @@ -79,6 +81,7 @@ def __init__(
if agent_card
else True
)
self.extensions = extensions

async def _apply_interceptors(
self,
Expand All @@ -103,11 +106,6 @@ async def _apply_interceptors(
)
return final_request_payload, final_http_kwargs

def _get_http_args(
self, context: ClientCallContext | None
) -> dict[str, Any] | None:
return context.state.get('http_kwargs') if context else None

async def send_message(
self,
request: MessageSendParams,
Expand All @@ -119,9 +117,12 @@ async def send_message(
payload, modified_kwargs = await self._apply_interceptors(
'message/send',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
get_http_args(context),
context,
)
modified_kwargs = update_extension_header(
modified_kwargs, self.extensions
)
response_data = await self._send_request(payload, modified_kwargs)
response = SendMessageResponse.model_validate(response_data)
if isinstance(response.root, JSONRPCErrorResponse):
Expand All @@ -143,10 +144,13 @@ async def send_message_streaming(
payload, modified_kwargs = await self._apply_interceptors(
'message/stream',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
get_http_args(context),
context,
)

modified_kwargs = update_extension_header(
modified_kwargs, self.extensions
)
modified_kwargs.setdefault(
'timeout', self.httpx_client.timeout.as_dict().get('read', None)
)
Expand Down Expand Up @@ -213,9 +217,12 @@ async def get_task(
payload, modified_kwargs = await self._apply_interceptors(
'tasks/get',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
get_http_args(context),
context,
)
modified_kwargs = update_extension_header(
modified_kwargs, self.extensions
)
response_data = await self._send_request(payload, modified_kwargs)
response = GetTaskResponse.model_validate(response_data)
if isinstance(response.root, JSONRPCErrorResponse):
Expand All @@ -233,9 +240,12 @@ async def cancel_task(
payload, modified_kwargs = await self._apply_interceptors(
'tasks/cancel',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
get_http_args(context),
context,
)
modified_kwargs = update_extension_header(
modified_kwargs, self.extensions
)
response_data = await self._send_request(payload, modified_kwargs)
response = CancelTaskResponse.model_validate(response_data)
if isinstance(response.root, JSONRPCErrorResponse):
Expand All @@ -255,9 +265,12 @@ async def set_task_callback(
payload, modified_kwargs = await self._apply_interceptors(
'tasks/pushNotificationConfig/set',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
get_http_args(context),
context,
)
modified_kwargs = update_extension_header(
modified_kwargs, self.extensions
)
response_data = await self._send_request(payload, modified_kwargs)
response = SetTaskPushNotificationConfigResponse.model_validate(
response_data
Expand All @@ -279,9 +292,12 @@ async def get_task_callback(
payload, modified_kwargs = await self._apply_interceptors(
'tasks/pushNotificationConfig/get',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
get_http_args(context),
context,
)
modified_kwargs = update_extension_header(
modified_kwargs, self.extensions
)
response_data = await self._send_request(payload, modified_kwargs)
response = GetTaskPushNotificationConfigResponse.model_validate(
response_data
Expand All @@ -303,10 +319,12 @@ async def resubscribe(
payload, modified_kwargs = await self._apply_interceptors(
'tasks/resubscribe',
rpc_request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
get_http_args(context),
context,
)

modified_kwargs = update_extension_header(
modified_kwargs, self.extensions
)
modified_kwargs.setdefault('timeout', None)

async with aconnect_sse(
Expand Down Expand Up @@ -345,7 +363,7 @@ async def get_card(
if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(
http_kwargs=self._get_http_args(context)
http_kwargs=get_http_args(context)
)
self._needs_extended_card = (
card.supports_authenticated_extended_card
Expand All @@ -359,9 +377,12 @@ async def get_card(
payload, modified_kwargs = await self._apply_interceptors(
request.method,
request.model_dump(mode='json', exclude_none=True),
self._get_http_args(context),
get_http_args(context),
context,
)
modified_kwargs = update_extension_header(
modified_kwargs, self.extensions
)

response_data = await self._send_request(
payload,
Expand Down
Loading