1616
1717from abc import ABC
1818import asyncio
19+ import inspect
1920import logging
2021from typing import AsyncGenerator
2122from typing import cast
@@ -192,14 +193,15 @@ async def _receive_from_model(
192193 """Receive data from model and process events using BaseLlmConnection."""
193194 def get_author (llm_response ):
194195 """Get the author of the event.
195-
196- When the model returns transcription, the author is "user". Otherwise, the author is the agent.
196+
197+ When the model returns transcription, the author is "user". Otherwise, the
198+ author is the agent.
197199 """
198200 if llm_response and llm_response .content and llm_response .content .role == "user" :
199201 return "user"
200202 else :
201203 return invocation_context .agent .name
202-
204+
203205 assert invocation_context .live_request_queue
204206 try :
205207 while True :
@@ -447,7 +449,7 @@ async def _call_llm_async(
447449 model_response_event : Event ,
448450 ) -> AsyncGenerator [LlmResponse , None ]:
449451 # Runs before_model_callback if it exists.
450- if response := self ._handle_before_model_callback (
452+ if response := await self ._handle_before_model_callback (
451453 invocation_context , llm_request , model_response_event
452454 ):
453455 yield response
@@ -460,7 +462,7 @@ async def _call_llm_async(
460462 invocation_context .live_request_queue = LiveRequestQueue ()
461463 async for llm_response in self .run_live (invocation_context ):
462464 # Runs after_model_callback if it exists.
463- if altered_llm_response := self ._handle_after_model_callback (
465+ if altered_llm_response := await self ._handle_after_model_callback (
464466 invocation_context , llm_response , model_response_event
465467 ):
466468 llm_response = altered_llm_response
@@ -489,14 +491,14 @@ async def _call_llm_async(
489491 llm_response ,
490492 )
491493 # Runs after_model_callback if it exists.
492- if altered_llm_response := self ._handle_after_model_callback (
494+ if altered_llm_response := await self ._handle_after_model_callback (
493495 invocation_context , llm_response , model_response_event
494496 ):
495497 llm_response = altered_llm_response
496498
497499 yield llm_response
498500
499- def _handle_before_model_callback (
501+ async def _handle_before_model_callback (
500502 self ,
501503 invocation_context : InvocationContext ,
502504 llm_request : LlmRequest ,
@@ -508,17 +510,23 @@ def _handle_before_model_callback(
508510 if not isinstance (agent , LlmAgent ):
509511 return
510512
511- if not agent .before_model_callback :
513+ if not agent .canonical_before_model_callbacks :
512514 return
513515
514516 callback_context = CallbackContext (
515517 invocation_context , event_actions = model_response_event .actions
516518 )
517- return agent .before_model_callback (
518- callback_context = callback_context , llm_request = llm_request
519- )
520519
521- def _handle_after_model_callback (
520+ for callback in agent .canonical_before_model_callbacks :
521+ before_model_callback_content = callback (
522+ callback_context = callback_context , llm_request = llm_request
523+ )
524+ if inspect .isawaitable (before_model_callback_content ):
525+ before_model_callback_content = await before_model_callback_content
526+ if before_model_callback_content :
527+ return before_model_callback_content
528+
529+ async def _handle_after_model_callback (
522530 self ,
523531 invocation_context : InvocationContext ,
524532 llm_response : LlmResponse ,
@@ -530,15 +538,21 @@ def _handle_after_model_callback(
530538 if not isinstance (agent , LlmAgent ):
531539 return
532540
533- if not agent .after_model_callback :
541+ if not agent .canonical_after_model_callbacks :
534542 return
535543
536544 callback_context = CallbackContext (
537545 invocation_context , event_actions = model_response_event .actions
538546 )
539- return agent .after_model_callback (
540- callback_context = callback_context , llm_response = llm_response
541- )
547+
548+ for callback in agent .canonical_after_model_callbacks :
549+ after_model_callback_content = callback (
550+ callback_context = callback_context , llm_response = llm_response
551+ )
552+ if inspect .isawaitable (after_model_callback_content ):
553+ after_model_callback_content = await after_model_callback_content
554+ if after_model_callback_content :
555+ return after_model_callback_content
542556
543557 def _finalize_model_response_event (
544558 self ,
0 commit comments