Skip to content

Commit 74642ff

Browse files
authored
Merge branch 'main' into raflFaisal/generaliseContainerEnvForAgent
2 parents cdf7511 + 180c2a9 commit 74642ff

36 files changed

+1196
-160
lines changed

.github/workflows/python-unit-tests.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,5 @@ jobs:
3636
source .venv/bin/activate
3737
pytest tests/unittests \
3838
--ignore=tests/unittests/artifacts/test_artifact_service.py \
39-
--ignore=tests/unittests/tools/application_integration_tool/clients/test_connections_client.py \
4039
--ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py
4140

src/google/adk/agents/base_agent.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any
17+
import inspect
18+
from typing import Any, Awaitable, Union
1819
from typing import AsyncGenerator
1920
from typing import Callable
2021
from typing import final
@@ -37,10 +38,15 @@
3738

3839
tracer = trace.get_tracer('gcp.vertex.agent')
3940

40-
BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
41+
BeforeAgentCallback = Callable[
42+
[CallbackContext],
43+
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
44+
]
4145

42-
43-
AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
46+
AfterAgentCallback = Callable[
47+
[CallbackContext],
48+
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
49+
]
4450

4551

4652
class BaseAgent(BaseModel):
@@ -119,7 +125,7 @@ async def run_async(
119125
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
120126
ctx = self._create_invocation_context(parent_context)
121127

122-
if event := self.__handle_before_agent_callback(ctx):
128+
if event := await self.__handle_before_agent_callback(ctx):
123129
yield event
124130
if ctx.end_invocation:
125131
return
@@ -130,7 +136,7 @@ async def run_async(
130136
if ctx.end_invocation:
131137
return
132138

133-
if event := self.__handle_after_agent_callback(ctx):
139+
if event := await self.__handle_after_agent_callback(ctx):
134140
yield event
135141

136142
@final
@@ -230,7 +236,7 @@ def _create_invocation_context(
230236
invocation_context.branch = f'{parent_context.branch}.{self.name}'
231237
return invocation_context
232238

233-
def __handle_before_agent_callback(
239+
async def __handle_before_agent_callback(
234240
self, ctx: InvocationContext
235241
) -> Optional[Event]:
236242
"""Runs the before_agent_callback if it exists.
@@ -248,6 +254,9 @@ def __handle_before_agent_callback(
248254
callback_context=callback_context
249255
)
250256

257+
if inspect.isawaitable(before_agent_callback_content):
258+
before_agent_callback_content = await before_agent_callback_content
259+
251260
if before_agent_callback_content:
252261
ret_event = Event(
253262
invocation_id=ctx.invocation_id,
@@ -269,7 +278,7 @@ def __handle_before_agent_callback(
269278

270279
return ret_event
271280

272-
def __handle_after_agent_callback(
281+
async def __handle_after_agent_callback(
273282
self, invocation_context: InvocationContext
274283
) -> Optional[Event]:
275284
"""Runs the after_agent_callback if it exists.
@@ -287,6 +296,9 @@ def __handle_after_agent_callback(
287296
callback_context=callback_context
288297
)
289298

299+
if inspect.isawaitable(after_agent_callback_content):
300+
after_agent_callback_content = await after_agent_callback_content
301+
290302
if after_agent_callback_content or callback_context.state.has_delta():
291303
ret_event = Event(
292304
invocation_id=invocation_context.invocation_id,

src/google/adk/agents/llm_agent.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,26 @@
4747

4848
logger = logging.getLogger(__name__)
4949

50+
_SingleBeforeModelCallback: TypeAlias = Callable[
51+
[CallbackContext, LlmRequest],
52+
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
53+
]
5054

51-
BeforeModelCallback: TypeAlias = Callable[
52-
[CallbackContext, LlmRequest], Optional[LlmResponse]
55+
BeforeModelCallback: TypeAlias = Union[
56+
_SingleBeforeModelCallback,
57+
list[_SingleBeforeModelCallback],
5358
]
54-
AfterModelCallback: TypeAlias = Callable[
59+
60+
_SingleAfterModelCallback: TypeAlias = Callable[
5561
[CallbackContext, LlmResponse],
56-
Optional[LlmResponse],
62+
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
5763
]
64+
65+
AfterModelCallback: TypeAlias = Union[
66+
_SingleAfterModelCallback,
67+
list[_SingleAfterModelCallback],
68+
]
69+
5870
BeforeToolCallback: TypeAlias = Callable[
5971
[BaseTool, dict[str, Any], ToolContext],
6072
Union[Awaitable[Optional[dict]], Optional[dict]],
@@ -173,7 +185,11 @@ class LlmAgent(BaseAgent):
173185

174186
# Callbacks - Start
175187
before_model_callback: Optional[BeforeModelCallback] = None
176-
"""Called before calling the LLM.
188+
"""Callback or list of callbacks to be called before calling the LLM.
189+
190+
When a list of callbacks is provided, the callbacks will be called in the
191+
order they are listed until a callback does not return None.
192+
177193
Args:
178194
callback_context: CallbackContext,
179195
llm_request: LlmRequest, The raw model request. Callback can mutate the
@@ -184,7 +200,10 @@ class LlmAgent(BaseAgent):
184200
skipped and the provided content will be returned to user.
185201
"""
186202
after_model_callback: Optional[AfterModelCallback] = None
187-
"""Called after calling LLM.
203+
"""Callback or list of callbacks to be called after calling the LLM.
204+
205+
When a list of callbacks is provided, the callbacks will be called in the
206+
order they are listed until a callback does not return None.
188207
189208
Args:
190209
callback_context: CallbackContext,
@@ -284,6 +303,32 @@ def canonical_tools(self) -> list[BaseTool]:
284303
"""
285304
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
286305

306+
@property
307+
def canonical_before_model_callbacks(
308+
self,
309+
) -> list[_SingleBeforeModelCallback]:
310+
"""The resolved self.before_model_callback field as a list of _SingleBeforeModelCallback.
311+
312+
This method is only for use by Agent Development Kit.
313+
"""
314+
if not self.before_model_callback:
315+
return []
316+
if isinstance(self.before_model_callback, list):
317+
return self.before_model_callback
318+
return [self.before_model_callback]
319+
320+
@property
321+
def canonical_after_model_callbacks(self) -> list[_SingleAfterModelCallback]:
322+
"""The resolved self.after_model_callback field as a list of _SingleAfterModelCallback.
323+
324+
This method is only for use by Agent Development Kit.
325+
"""
326+
if not self.after_model_callback:
327+
return []
328+
if isinstance(self.after_model_callback, list):
329+
return self.after_model_callback
330+
return [self.after_model_callback]
331+
287332
@property
288333
def _llm_flow(self) -> BaseLlmFlow:
289334
if (

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from abc import ABC
1818
import asyncio
19+
import inspect
1920
import logging
2021
from typing import AsyncGenerator
2122
from typing import cast
@@ -192,14 +193,15 @@ async def _receive_from_model(
192193
"""Receive data from model and process events using BaseLlmConnection."""
193194
def get_author(llm_response):
194195
"""Get the author of the event.
195-
196-
When the model returns transcription, the author is "user". Otherwise, the author is the agent.
196+
197+
When the model returns transcription, the author is "user". Otherwise, the
198+
author is the agent.
197199
"""
198200
if llm_response and llm_response.content and llm_response.content.role == "user":
199201
return "user"
200202
else:
201203
return invocation_context.agent.name
202-
204+
203205
assert invocation_context.live_request_queue
204206
try:
205207
while True:
@@ -447,7 +449,7 @@ async def _call_llm_async(
447449
model_response_event: Event,
448450
) -> AsyncGenerator[LlmResponse, None]:
449451
# Runs before_model_callback if it exists.
450-
if response := self._handle_before_model_callback(
452+
if response := await self._handle_before_model_callback(
451453
invocation_context, llm_request, model_response_event
452454
):
453455
yield response
@@ -460,7 +462,7 @@ async def _call_llm_async(
460462
invocation_context.live_request_queue = LiveRequestQueue()
461463
async for llm_response in self.run_live(invocation_context):
462464
# Runs after_model_callback if it exists.
463-
if altered_llm_response := self._handle_after_model_callback(
465+
if altered_llm_response := await self._handle_after_model_callback(
464466
invocation_context, llm_response, model_response_event
465467
):
466468
llm_response = altered_llm_response
@@ -489,14 +491,14 @@ async def _call_llm_async(
489491
llm_response,
490492
)
491493
# Runs after_model_callback if it exists.
492-
if altered_llm_response := self._handle_after_model_callback(
494+
if altered_llm_response := await self._handle_after_model_callback(
493495
invocation_context, llm_response, model_response_event
494496
):
495497
llm_response = altered_llm_response
496498

497499
yield llm_response
498500

499-
def _handle_before_model_callback(
501+
async def _handle_before_model_callback(
500502
self,
501503
invocation_context: InvocationContext,
502504
llm_request: LlmRequest,
@@ -508,17 +510,23 @@ def _handle_before_model_callback(
508510
if not isinstance(agent, LlmAgent):
509511
return
510512

511-
if not agent.before_model_callback:
513+
if not agent.canonical_before_model_callbacks:
512514
return
513515

514516
callback_context = CallbackContext(
515517
invocation_context, event_actions=model_response_event.actions
516518
)
517-
return agent.before_model_callback(
518-
callback_context=callback_context, llm_request=llm_request
519-
)
520519

521-
def _handle_after_model_callback(
520+
for callback in agent.canonical_before_model_callbacks:
521+
before_model_callback_content = callback(
522+
callback_context=callback_context, llm_request=llm_request
523+
)
524+
if inspect.isawaitable(before_model_callback_content):
525+
before_model_callback_content = await before_model_callback_content
526+
if before_model_callback_content:
527+
return before_model_callback_content
528+
529+
async def _handle_after_model_callback(
522530
self,
523531
invocation_context: InvocationContext,
524532
llm_response: LlmResponse,
@@ -530,15 +538,21 @@ def _handle_after_model_callback(
530538
if not isinstance(agent, LlmAgent):
531539
return
532540

533-
if not agent.after_model_callback:
541+
if not agent.canonical_after_model_callbacks:
534542
return
535543

536544
callback_context = CallbackContext(
537545
invocation_context, event_actions=model_response_event.actions
538546
)
539-
return agent.after_model_callback(
540-
callback_context=callback_context, llm_response=llm_response
541-
)
547+
548+
for callback in agent.canonical_after_model_callbacks:
549+
after_model_callback_content = callback(
550+
callback_context=callback_context, llm_response=llm_response
551+
)
552+
if inspect.isawaitable(after_model_callback_content):
553+
after_model_callback_content = await after_model_callback_content
554+
if after_model_callback_content:
555+
return after_model_callback_content
542556

543557
def _finalize_model_response_event(
544558
self,

src/google/adk/models/lite_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ async def generate_content_async(
611611
"""
612612

613613
self._maybe_append_user_content(llm_request)
614-
logger.info(_build_request_log(llm_request))
614+
logger.debug(_build_request_log(llm_request))
615615

616616
messages, tools = _get_completion_inputs(llm_request)
617617

src/google/adk/tools/agent_tool.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,12 @@ async def run_async(
146146

147147
if runner.artifact_service:
148148
# Forward all artifacts to parent session.
149-
async for artifact_name in runner.artifact_service.list_artifact_keys(
149+
artifact_names = await runner.artifact_service.list_artifact_keys(
150150
app_name=session.app_name,
151151
user_id=session.user_id,
152152
session_id=session.id,
153-
):
153+
)
154+
for artifact_name in artifact_names:
154155
if artifact := await runner.artifact_service.load_artifact(
155156
app_name=session.app_name,
156157
user_id=session.user_id,

src/google/adk/tools/application_integration_tool/application_integration_toolset.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(
7676
project: str,
7777
location: str,
7878
integration: Optional[str] = None,
79-
trigger: Optional[str] = None,
79+
triggers: Optional[List[str]] = None,
8080
connection: Optional[str] = None,
8181
entity_operations: Optional[str] = None,
8282
actions: Optional[str] = None,
@@ -98,7 +98,7 @@ def __init__(
9898
project="test-project",
9999
location="us-central1"
100100
integration="test-integration",
101-
trigger="api_trigger/test_trigger",
101+
triggers=["api_trigger/test_trigger"],
102102
service_account_credentials={...},
103103
)
104104
@@ -130,7 +130,7 @@ def __init__(
130130
project: The GCP project ID.
131131
location: The GCP location.
132132
integration: The integration name.
133-
trigger: The trigger name.
133+
triggers: The list of trigger names in the integration.
134134
connection: The connection name.
135135
entity_operations: The entity operations supported by the connection.
136136
actions: The actions supported by the connection.
@@ -149,7 +149,7 @@ def __init__(
149149
self.project = project
150150
self.location = location
151151
self.integration = integration
152-
self.trigger = trigger
152+
self.triggers = triggers
153153
self.connection = connection
154154
self.entity_operations = entity_operations
155155
self.actions = actions
@@ -162,14 +162,14 @@ def __init__(
162162
project,
163163
location,
164164
integration,
165-
trigger,
165+
triggers,
166166
connection,
167167
entity_operations,
168168
actions,
169169
service_account_json,
170170
)
171171
connection_details = {}
172-
if integration and trigger:
172+
if integration:
173173
spec = integration_client.get_openapi_spec_for_integration()
174174
elif connection and (entity_operations or actions):
175175
connections_client = ConnectionsClient(
@@ -210,7 +210,7 @@ def _parse_spec_to_tools(self, spec_dict, connection_details):
210210
)
211211
auth_scheme = HTTPBearer(bearerFormat="JWT")
212212

213-
if self.integration and self.trigger:
213+
if self.integration:
214214
tools = OpenAPIToolset(
215215
spec_dict=spec_dict,
216216
auth_credential=auth_credential,

0 commit comments

Comments
 (0)