992. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
1010"""
1111
12- import asyncio
1312import json
1413import logging
1514import os
1615import random
1716from concurrent .futures import ThreadPoolExecutor
18- from threading import Thread
19- from typing import Any , AsyncIterator , Callable , Dict , List , Mapping , Optional , Type , TypeVar , Union , cast
20- from uuid import uuid4
17+ from typing import Any , AsyncIterator , Callable , Generator , Mapping , Optional , Type , TypeVar , Union , cast
2118
2219from opentelemetry import trace
2320from pydantic import BaseModel
2421
2522from ..event_loop .event_loop import event_loop_cycle
26- from ..handlers .callback_handler import CompositeCallbackHandler , PrintingCallbackHandler , null_callback_handler
23+ from ..handlers .callback_handler import PrintingCallbackHandler , null_callback_handler
2724from ..handlers .tool_handler import AgentToolHandler
2825from ..models .bedrock import BedrockModel
2926from ..telemetry .metrics import EventLoopMetrics
@@ -210,7 +207,7 @@ def __init__(
210207 self ,
211208 model : Union [Model , str , None ] = None ,
212209 messages : Optional [Messages ] = None ,
213- tools : Optional [List [Union [str , Dict [str , str ], Any ]]] = None ,
210+ tools : Optional [list [Union [str , dict [str , str ], Any ]]] = None ,
214211 system_prompt : Optional [str ] = None ,
215212 callback_handler : Optional [
216213 Union [Callable [..., Any ], _DefaultCallbackHandlerSentinel ]
@@ -282,7 +279,7 @@ def __init__(
282279 self .conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager ()
283280
284281 # Process trace attributes to ensure they're of compatible types
285- self .trace_attributes : Dict [str , AttributeValue ] = {}
282+ self .trace_attributes : dict [str , AttributeValue ] = {}
286283 if trace_attributes :
287284 for k , v in trace_attributes .items ():
288285 if isinstance (v , (str , int , float , bool )) or (
@@ -339,7 +336,7 @@ def tool(self) -> ToolCaller:
339336 return self .tool_caller
340337
341338 @property
342- def tool_names (self ) -> List [str ]:
339+ def tool_names (self ) -> list [str ]:
343340 """Get a list of all registered tool names.
344341
345342 Returns:
@@ -384,19 +381,25 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
384381 - metrics: Performance metrics from the event loop
385382 - state: The final state of the event loop
386383 """
384+ callback_handler = kwargs .get ("callback_handler" , self .callback_handler )
385+
387386 self ._start_agent_trace_span (prompt )
388387
389388 try :
390- # Run the event loop and get the result
391- result = self ._run_loop (prompt , kwargs )
389+ events = self ._run_loop (callback_handler , prompt , kwargs )
390+ for event in events :
391+ if "callback" in event :
392+ callback_handler (** event ["callback" ])
393+
394+ stop_reason , message , metrics , state = event ["stop" ]
395+ result = AgentResult (stop_reason , message , metrics , state )
392396
393397 self ._end_agent_trace_span (response = result )
394398
395399 return result
400+
396401 except Exception as e :
397402 self ._end_agent_trace_span (error = e )
398-
399- # Re-raise the exception to preserve original behavior
400403 raise
401404
402405 def structured_output (self , output_model : Type [T ], prompt : Optional [str ] = None ) -> T :
@@ -460,83 +463,56 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
460463 yield event["data"]
461464 ```
462465 """
463- self . _start_agent_trace_span ( prompt )
466+ callback_handler = kwargs . get ( "callback_handler" , self . callback_handler )
464467
465- _stop_event = uuid4 ()
466-
467- queue = asyncio .Queue [Any ]()
468- loop = asyncio .get_event_loop ()
469-
470- def enqueue (an_item : Any ) -> None :
471- nonlocal queue
472- nonlocal loop
473- loop .call_soon_threadsafe (queue .put_nowait , an_item )
474-
475- def queuing_callback_handler (** handler_kwargs : Any ) -> None :
476- enqueue (handler_kwargs .copy ())
468+ self ._start_agent_trace_span (prompt )
477469
478- def target_callback () -> None :
479- nonlocal kwargs
470+ try :
471+ events = self ._run_loop (callback_handler , prompt , kwargs )
472+ for event in events :
473+ if "callback" in event :
474+ callback_handler (** event ["callback" ])
475+ yield event ["callback" ]
480476
481- try :
482- result = self ._run_loop (prompt , kwargs , supplementary_callback_handler = queuing_callback_handler )
483- self ._end_agent_trace_span (response = result )
484- except Exception as e :
485- self ._end_agent_trace_span (error = e )
486- enqueue (e )
487- finally :
488- enqueue (_stop_event )
477+ stop_reason , message , metrics , state = event ["stop" ]
478+ result = AgentResult (stop_reason , message , metrics , state )
489479
490- thread = Thread (target = target_callback , daemon = True )
491- thread .start ()
480+ self ._end_agent_trace_span (response = result )
492481
493- try :
494- while True :
495- item = await queue .get ()
496- if item == _stop_event :
497- break
498- if isinstance (item , Exception ):
499- raise item
500- yield item
501- finally :
502- thread .join ()
482+ except Exception as e :
483+ self ._end_agent_trace_span (error = e )
484+ raise
503485
504486 def _run_loop (
505- self , prompt : str , kwargs : Dict [ str , Any ], supplementary_callback_handler : Optional [ Callable [... , Any ]] = None
506- ) -> AgentResult :
487+ self , callback_handler : Callable [... , Any ], prompt : str , kwargs : dict [ str , Any ]
488+ ) -> Generator [ dict [ str , Any ], None , None ] :
507489 """Execute the agent's event loop with the given prompt and parameters."""
508490 try :
509- # If the call had a callback_handler passed in, then for this event_loop
510- # cycle we call both handlers as the callback_handler
511- invocation_callback_handler = (
512- CompositeCallbackHandler (self .callback_handler , supplementary_callback_handler )
513- if supplementary_callback_handler is not None
514- else self .callback_handler
515- )
516-
517491 # Extract key parameters
518- invocation_callback_handler ( init_event_loop = True , ** kwargs )
492+ yield { "callback" : { " init_event_loop" : True , ** kwargs }}
519493
520494 # Set up the user message with optional knowledge base retrieval
521- message_content : List [ContentBlock ] = [{"text" : prompt }]
495+ message_content : list [ContentBlock ] = [{"text" : prompt }]
522496 new_message : Message = {"role" : "user" , "content" : message_content }
523497 self .messages .append (new_message )
524498
525499 # Execute the event loop cycle with retry logic for context limits
526- return self ._execute_event_loop_cycle (invocation_callback_handler , kwargs )
500+ yield from self ._execute_event_loop_cycle (callback_handler , kwargs )
527501
528502 finally :
529503 self .conversation_manager .apply_management (self )
530504
531- def _execute_event_loop_cycle (self , callback_handler : Callable [..., Any ], kwargs : Dict [str , Any ]) -> AgentResult :
505+ def _execute_event_loop_cycle (
506+ self , callback_handler : Callable [..., Any ], kwargs : dict [str , Any ]
507+ ) -> Generator [dict [str , Any ], None , None ]:
532508 """Execute the event loop cycle with retry logic for context window limits.
533509
534510 This internal method handles the execution of the event loop cycle and implements
535511 retry logic for handling context window overflow exceptions by reducing the
536512 conversation context and retrying.
537513
538- Returns :
539- The result of the event loop cycle.
514+ Yields :
515+ Events of the loop cycle.
540516 """
541517 # Extract parameters with fallbacks to instance values
542518 system_prompt = kwargs .pop ("system_prompt" , self .system_prompt )
@@ -551,7 +527,7 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
551527
552528 try :
553529 # Execute the main event loop cycle
554- events = event_loop_cycle (
530+ yield from event_loop_cycle (
555531 model = model ,
556532 system_prompt = system_prompt ,
557533 messages = messages , # will be modified by event_loop_cycle
@@ -564,26 +540,18 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
564540 event_loop_parent_span = self .trace_span ,
565541 ** kwargs ,
566542 )
567- for event in events :
568- if "callback" in event :
569- callback_handler (** event ["callback" ])
570-
571- stop_reason , message , metrics , state = event ["stop" ]
572-
573- return AgentResult (stop_reason , message , metrics , state )
574543
575544 except ContextWindowOverflowException as e :
576545 # Try reducing the context size and retrying
577-
578546 self .conversation_manager .reduce_context (self , e = e )
579- return self ._execute_event_loop_cycle (callback_handler_override , kwargs )
547+ yield from self ._execute_event_loop_cycle (callback_handler_override , kwargs )
580548
581549 def _record_tool_execution (
582550 self ,
583- tool : Dict [str , Any ],
584- tool_result : Dict [str , Any ],
551+ tool : dict [str , Any ],
552+ tool_result : dict [str , Any ],
585553 user_message_override : Optional [str ],
586- messages : List [ Dict [str , Any ]],
554+ messages : list [ dict [str , Any ]],
587555 ) -> None :
588556 """Record a tool execution in the message history.
589557
@@ -662,7 +630,7 @@ def _end_agent_trace_span(
662630 error: Error to record as a trace attribute.
663631 """
664632 if self .trace_span :
665- trace_attributes : Dict [str , Any ] = {
633+ trace_attributes : dict [str , Any ] = {
666634 "span" : self .trace_span ,
667635 }
668636
0 commit comments