2727from .store import AttachmentStore , Store , StoreItemType , default_generate_id
2828from .types import (
2929 Action ,
30+ AssistantMessageContent ,
31+ AssistantMessageContentPartAdded ,
32+ AssistantMessageContentPartAnnotationAdded ,
33+ AssistantMessageContentPartDone ,
34+ AssistantMessageContentPartTextDelta ,
35+ AssistantMessageItem ,
3036 AttachmentsCreateReq ,
3137 AttachmentsDeleteReq ,
3238 ChatKitReq ,
3945 ItemsListReq ,
4046 NonStreamingReq ,
4147 Page ,
48+ SDKHiddenContextItem ,
4249 StreamingReq ,
50+ StreamOptions ,
51+ StreamOptionsEvent ,
4352 Thread ,
4453 ThreadCreatedEvent ,
4554 ThreadItem ,
6675 WidgetItem ,
6776 WidgetRootUpdated ,
6877 WidgetStreamingTextValueDelta ,
78+ WorkflowItem ,
79+ WorkflowTaskAdded ,
80+ WorkflowTaskUpdated ,
6981 is_streaming_req ,
7082)
7183from .version import __version__
@@ -316,6 +328,56 @@ def action(
316328 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
317329 )
318330
331+ def get_stream_options (
332+ self , thread : ThreadMetadata , context : TContext
333+ ) -> StreamOptions :
334+ """
335+ Return stream-level runtime options. Allows the user to cancel the stream by default.
336+ Override this method to customize behavior.
337+ """
338+ return StreamOptions (allow_cancel = True )
339+
340+ async def handle_stream_cancelled (
341+ self ,
342+ thread : ThreadMetadata ,
343+ pending_items : list [ThreadItem ],
344+ context : TContext ,
345+ ):
346+ """Perform custom cleanup / stop inference when a stream is cancelled.
347+ Updates you make here will not be reflected in the UI until a reload.
348+
349+ The default implementation persists any non-empty pending assistant messages
350+ to the thread but does not auto-save pending widget items or workflow items.
351+
352+ Args:
353+ thread: The thread that was being processed.
354+ pending_items: Items that were not done streaming at cancellation time.
355+ context: Arbitrary per-request context provided by the caller.
356+ """
357+ pending_assistant_message_items : list [AssistantMessageItem ] = [
358+ item for item in pending_items if isinstance (item , AssistantMessageItem )
359+ ]
360+ for item in pending_assistant_message_items :
361+ is_empty = len (item .content ) == 0 or all (
362+ (not content .text .strip ()) for content in item .content
363+ )
364+ if not is_empty :
365+ await self .store .add_thread_item (thread .id , item , context = context )
366+
367+ # Add a hidden context item to the thread to indicate that the stream was cancelled.
368+ # Otherwise, depending on the timing of the cancellation, subsequent responses may
369+ # attempt to continue the cancelled response.
370+ await self .store .add_thread_item (
371+ thread .id ,
372+ SDKHiddenContextItem (
373+ thread_id = thread .id ,
374+ created_at = datetime .now (),
375+ id = self .store .generate_item_id ("sdk_hidden_context" , thread , context ),
376+ content = "The user cancelled the stream. Stop responding to the prior request." ,
377+ ),
378+ context = context ,
379+ )
380+
319381 async def process (
320382 self , request : str | bytes | bytearray , context : TContext
321383 ) -> StreamingResult | NonStreamingResult :
@@ -387,11 +449,11 @@ async def _process_non_streaming(
387449 after = items_list_params .after ,
388450 context = context ,
389451 )
390- # filter out HiddenContextItems
452+ # filter out hidden context items
391453 items .data = [
392454 item
393455 for item in items .data
394- if not isinstance (item , HiddenContextItem )
456+ if not isinstance (item , ( HiddenContextItem , SDKHiddenContextItem ) )
395457 ]
396458 return self ._serialize (items )
397459 case ThreadsUpdateReq ():
@@ -416,6 +478,9 @@ async def _process_streaming(
416478 async for event in self ._process_streaming_impl (request , context ):
417479 b = self ._serialize (event )
418480 yield b"data: " + b + b"\n \n "
481+ except asyncio .CancelledError :
482+ # Let cancellation bubble up without logging as an error.
483+ raise
419484 except Exception :
420485 logger .exception ("Error while generating streamed response" )
421486 raise
@@ -612,29 +677,51 @@ async def _process_events(
612677 ) -> AsyncIterator [ThreadStreamEvent ]:
613678 await asyncio .sleep (0 ) # allow the response to start streaming
614679
680+ # Send initial stream options
681+ yield StreamOptionsEvent (
682+ stream_options = self .get_stream_options (thread , context )
683+ )
684+
615685 last_thread = thread .model_copy (deep = True )
616686
687+ # Keep track of items that were streamed but not yet saved
688+ # so that we can persist them when the stream is cancelled.
689+ pending_items : dict [str , ThreadItem ] = {}
690+
617691 try :
618692 with agents_sdk_user_agent_override ():
619693 async for event in stream ():
694+ if isinstance (event , ThreadItemAddedEvent ):
695+ pending_items [event .item .id ] = event .item
696+
620697 match event :
621698 case ThreadItemDoneEvent ():
622699 await self .store .add_thread_item (
623700 thread .id , event .item , context = context
624701 )
702+ pending_items .pop (event .item .id , None )
625703 case ThreadItemRemovedEvent ():
626704 await self .store .delete_thread_item (
627705 thread .id , event .item_id , context = context
628706 )
707+ pending_items .pop (event .item_id , None )
629708 case ThreadItemReplacedEvent ():
630709 await self .store .save_item (
631710 thread .id , event .item , context = context
632711 )
712+ pending_items .pop (event .item .id , None )
713+ case ThreadItemUpdatedEvent ():
714+ # Keep pending assistant message and workflow items up to date
715+ # so that we have a reference to the latest version of these pending items
716+ # when the stream is cancelled.
717+ self ._update_pending_items (pending_items , event )
633718
634719 # special case - don't send hidden context items back to the client
635720 should_swallow_event = isinstance (
636721 event , ThreadItemDoneEvent
637- ) and isinstance (event .item , HiddenContextItem )
722+ ) and isinstance (
723+ event .item , (HiddenContextItem , SDKHiddenContextItem )
724+ )
638725
639726 if not should_swallow_event :
640727 yield event
@@ -651,6 +738,11 @@ async def _process_events(
651738 last_thread = thread .model_copy (deep = True )
652739 await self .store .save_thread (thread , context = context )
653740 yield ThreadUpdatedEvent (thread = self ._to_thread_response (thread ))
741+ except asyncio .CancelledError :
742+ await self .handle_stream_cancelled (
743+ thread , list (pending_items .values ()), context
744+ )
745+ raise
654746 except CustomStreamError as e :
655747 yield ErrorEvent (
656748 code = "custom" ,
@@ -674,6 +766,69 @@ async def _process_events(
674766 await self .store .save_thread (thread , context = context )
675767 yield ThreadUpdatedEvent (thread = self ._to_thread_response (thread ))
676768
769+ def _apply_assistant_message_update (
770+ self ,
771+ item : AssistantMessageItem ,
772+ update : AssistantMessageContentPartAdded
773+ | AssistantMessageContentPartTextDelta
774+ | AssistantMessageContentPartAnnotationAdded
775+ | AssistantMessageContentPartDone ,
776+ ) -> AssistantMessageItem :
777+ updated = item .model_copy (deep = True )
778+
779+ # Pad the content list so the requested content_index exists before we write into it.
780+ # (Streaming updates can arrive for an index that hasn’t been created yet)
781+ while len (updated .content ) <= update .content_index :
782+ updated .content .append (AssistantMessageContent (text = "" , annotations = []))
783+
784+ match update :
785+ case AssistantMessageContentPartAdded ():
786+ updated .content [update .content_index ] = update .content
787+ case AssistantMessageContentPartTextDelta ():
788+ updated .content [update .content_index ].text += update .delta
789+ case AssistantMessageContentPartAnnotationAdded ():
790+ annotations = updated .content [update .content_index ].annotations
791+ if update .annotation_index <= len (annotations ):
792+ annotations .insert (update .annotation_index , update .annotation )
793+ else :
794+ annotations .append (update .annotation )
795+ case AssistantMessageContentPartDone ():
796+ updated .content [update .content_index ] = update .content
797+ return updated
798+
799+ def _update_pending_items (
800+ self ,
801+ pending_items : dict [str , ThreadItem ],
802+ event : ThreadItemUpdatedEvent ,
803+ ):
804+ updated_item = pending_items .get (event .item_id )
805+ update = event .update
806+ match updated_item :
807+ case AssistantMessageItem ():
808+ if isinstance (
809+ update ,
810+ (
811+ AssistantMessageContentPartAdded ,
812+ AssistantMessageContentPartTextDelta ,
813+ AssistantMessageContentPartAnnotationAdded ,
814+ AssistantMessageContentPartDone ,
815+ ),
816+ ):
817+ pending_items [updated_item .id ] = (
818+ self ._apply_assistant_message_update (updated_item , update )
819+ )
820+ case WorkflowItem ():
821+ if isinstance (update , (WorkflowTaskUpdated , WorkflowTaskAdded )):
822+ match update :
823+ case WorkflowTaskUpdated ():
824+ updated_item .workflow .tasks [update .task_index ] = update .task
825+ case WorkflowTaskAdded ():
826+ updated_item .workflow .tasks .append (update .task )
827+
828+ pending_items [updated_item .id ] = updated_item
829+ case _:
830+ pass
831+
677832 async def _build_user_message_item (
678833 self , input : UserMessageInput , thread : ThreadMetadata , context : TContext
679834 ) -> UserMessageItem :
@@ -722,7 +877,7 @@ def _serialize(self, obj: BaseModel) -> bytes:
722877
723878 def _to_thread_response (self , thread : ThreadMetadata | Thread ) -> Thread :
724879 def is_hidden (item : ThreadItem ) -> bool :
725- return isinstance (item , HiddenContextItem )
880+ return isinstance (item , ( HiddenContextItem , SDKHiddenContextItem ) )
726881
727882 items = thread .items if isinstance (thread , Thread ) else Page ()
728883 items .data = [item for item in items .data if not is_hidden (item )]
0 commit comments