Skip to content

Commit 675ded1

Browse files
authored
Merge pull request #65 from openai/threads-stop
Handle cancelled stream
2 parents 31f0778 + 7ae5852 commit 675ded1

File tree

5 files changed

+398
-32
lines changed

5 files changed

+398
-32
lines changed

chatkit/agents.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
EndOfTurnItem,
5757
FileSource,
5858
HiddenContextItem,
59+
SDKHiddenContextItem,
5960
Task,
6061
TaskItem,
6162
ThoughtTask,
@@ -712,6 +713,30 @@ async def hidden_context_to_input(
712713
role="user",
713714
)
714715

716+
async def sdk_hidden_context_to_input(
717+
self, item: SDKHiddenContextItem
718+
) -> TResponseInputItem | list[TResponseInputItem] | None:
719+
"""
720+
Convert a SDKHiddenContextItem into input item to send to the model.
721+
This is used by the ChatKit Python SDK for storing additional context
722+
for internal operations.
723+
Override if you want to wrap the content in a different format.
724+
"""
725+
text = (
726+
"Hidden context for the agent (not shown to the user):\n"
727+
f"<HiddenContext>\n{item.content}\n</HiddenContext>"
728+
)
729+
return Message(
730+
type="message",
731+
content=[
732+
ResponseInputTextParam(
733+
type="input_text",
734+
text=text,
735+
)
736+
],
737+
role="user",
738+
)
739+
715740
async def task_to_input(
716741
self, item: TaskItem
717742
) -> TResponseInputItem | list[TResponseInputItem] | None:
@@ -948,6 +973,9 @@ async def _thread_item_to_input_item(
948973
case HiddenContextItem():
949974
out = await self.hidden_context_to_input(item) or []
950975
return out if isinstance(out, list) else [out]
976+
case SDKHiddenContextItem():
977+
out = await self.sdk_hidden_context_to_input(item) or []
978+
return out if isinstance(out, list) else [out]
951979
case _:
952980
assert_never(item)
953981

chatkit/server.py

Lines changed: 159 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
from .store import AttachmentStore, Store, StoreItemType, default_generate_id
2828
from .types import (
2929
Action,
30+
AssistantMessageContent,
31+
AssistantMessageContentPartAdded,
32+
AssistantMessageContentPartAnnotationAdded,
33+
AssistantMessageContentPartDone,
34+
AssistantMessageContentPartTextDelta,
35+
AssistantMessageItem,
3036
AttachmentsCreateReq,
3137
AttachmentsDeleteReq,
3238
ChatKitReq,
@@ -39,7 +45,10 @@
3945
ItemsListReq,
4046
NonStreamingReq,
4147
Page,
48+
SDKHiddenContextItem,
4249
StreamingReq,
50+
StreamOptions,
51+
StreamOptionsEvent,
4352
Thread,
4453
ThreadCreatedEvent,
4554
ThreadItem,
@@ -66,6 +75,9 @@
6675
WidgetItem,
6776
WidgetRootUpdated,
6877
WidgetStreamingTextValueDelta,
78+
WorkflowItem,
79+
WorkflowTaskAdded,
80+
WorkflowTaskUpdated,
6981
is_streaming_req,
7082
)
7183
from .version import __version__
@@ -316,6 +328,56 @@ def action(
316328
"See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
317329
)
318330

331+
def get_stream_options(
332+
self, thread: ThreadMetadata, context: TContext
333+
) -> StreamOptions:
334+
"""
335+
Return stream-level runtime options. Allows the user to cancel the stream by default.
336+
Override this method to customize behavior.
337+
"""
338+
return StreamOptions(allow_cancel=True)
339+
340+
async def handle_stream_cancelled(
341+
self,
342+
thread: ThreadMetadata,
343+
pending_items: list[ThreadItem],
344+
context: TContext,
345+
):
346+
"""Perform custom cleanup / stop inference when a stream is cancelled.
347+
Updates you make here will not be reflected in the UI until a reload.
348+
349+
The default implementation persists any non-empty pending assistant messages
350+
to the thread but does not auto-save pending widget items or workflow items.
351+
352+
Args:
353+
thread: The thread that was being processed.
354+
pending_items: Items that were not done streaming at cancellation time.
355+
context: Arbitrary per-request context provided by the caller.
356+
"""
357+
pending_assistant_message_items: list[AssistantMessageItem] = [
358+
item for item in pending_items if isinstance(item, AssistantMessageItem)
359+
]
360+
for item in pending_assistant_message_items:
361+
is_empty = len(item.content) == 0 or all(
362+
(not content.text.strip()) for content in item.content
363+
)
364+
if not is_empty:
365+
await self.store.add_thread_item(thread.id, item, context=context)
366+
367+
# Add a hidden context item to the thread to indicate that the stream was cancelled.
368+
# Otherwise, depending on the timing of the cancellation, subsequent responses may
369+
# attempt to continue the cancelled response.
370+
await self.store.add_thread_item(
371+
thread.id,
372+
SDKHiddenContextItem(
373+
thread_id=thread.id,
374+
created_at=datetime.now(),
375+
id=self.store.generate_item_id("sdk_hidden_context", thread, context),
376+
content="The user cancelled the stream. Stop responding to the prior request.",
377+
),
378+
context=context,
379+
)
380+
319381
async def process(
320382
self, request: str | bytes | bytearray, context: TContext
321383
) -> StreamingResult | NonStreamingResult:
@@ -387,11 +449,11 @@ async def _process_non_streaming(
387449
after=items_list_params.after,
388450
context=context,
389451
)
390-
# filter out HiddenContextItems
452+
# filter out hidden context items
391453
items.data = [
392454
item
393455
for item in items.data
394-
if not isinstance(item, HiddenContextItem)
456+
if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
395457
]
396458
return self._serialize(items)
397459
case ThreadsUpdateReq():
@@ -416,6 +478,9 @@ async def _process_streaming(
416478
async for event in self._process_streaming_impl(request, context):
417479
b = self._serialize(event)
418480
yield b"data: " + b + b"\n\n"
481+
except asyncio.CancelledError:
482+
# Let cancellation bubble up without logging as an error.
483+
raise
419484
except Exception:
420485
logger.exception("Error while generating streamed response")
421486
raise
@@ -612,29 +677,51 @@ async def _process_events(
612677
) -> AsyncIterator[ThreadStreamEvent]:
613678
await asyncio.sleep(0) # allow the response to start streaming
614679

680+
# Send initial stream options
681+
yield StreamOptionsEvent(
682+
stream_options=self.get_stream_options(thread, context)
683+
)
684+
615685
last_thread = thread.model_copy(deep=True)
616686

687+
# Keep track of items that were streamed but not yet saved
688+
# so that we can persist them when the stream is cancelled.
689+
pending_items: dict[str, ThreadItem] = {}
690+
617691
try:
618692
with agents_sdk_user_agent_override():
619693
async for event in stream():
694+
if isinstance(event, ThreadItemAddedEvent):
695+
pending_items[event.item.id] = event.item
696+
620697
match event:
621698
case ThreadItemDoneEvent():
622699
await self.store.add_thread_item(
623700
thread.id, event.item, context=context
624701
)
702+
pending_items.pop(event.item.id, None)
625703
case ThreadItemRemovedEvent():
626704
await self.store.delete_thread_item(
627705
thread.id, event.item_id, context=context
628706
)
707+
pending_items.pop(event.item_id, None)
629708
case ThreadItemReplacedEvent():
630709
await self.store.save_item(
631710
thread.id, event.item, context=context
632711
)
712+
pending_items.pop(event.item.id, None)
713+
case ThreadItemUpdatedEvent():
714+
# Keep pending assistant message and workflow items up to date
715+
# so that we have a reference to the latest version of these pending items
716+
# when the stream is cancelled.
717+
self._update_pending_items(pending_items, event)
633718

634719
# special case - don't send hidden context items back to the client
635720
should_swallow_event = isinstance(
636721
event, ThreadItemDoneEvent
637-
) and isinstance(event.item, HiddenContextItem)
722+
) and isinstance(
723+
event.item, (HiddenContextItem, SDKHiddenContextItem)
724+
)
638725

639726
if not should_swallow_event:
640727
yield event
@@ -651,6 +738,11 @@ async def _process_events(
651738
last_thread = thread.model_copy(deep=True)
652739
await self.store.save_thread(thread, context=context)
653740
yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
741+
except asyncio.CancelledError:
742+
await self.handle_stream_cancelled(
743+
thread, list(pending_items.values()), context
744+
)
745+
raise
654746
except CustomStreamError as e:
655747
yield ErrorEvent(
656748
code="custom",
@@ -674,6 +766,69 @@ async def _process_events(
674766
await self.store.save_thread(thread, context=context)
675767
yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
676768

769+
def _apply_assistant_message_update(
770+
self,
771+
item: AssistantMessageItem,
772+
update: AssistantMessageContentPartAdded
773+
| AssistantMessageContentPartTextDelta
774+
| AssistantMessageContentPartAnnotationAdded
775+
| AssistantMessageContentPartDone,
776+
) -> AssistantMessageItem:
777+
updated = item.model_copy(deep=True)
778+
779+
# Pad the content list so the requested content_index exists before we write into it.
780+
# (Streaming updates can arrive for an index that hasn’t been created yet)
781+
while len(updated.content) <= update.content_index:
782+
updated.content.append(AssistantMessageContent(text="", annotations=[]))
783+
784+
match update:
785+
case AssistantMessageContentPartAdded():
786+
updated.content[update.content_index] = update.content
787+
case AssistantMessageContentPartTextDelta():
788+
updated.content[update.content_index].text += update.delta
789+
case AssistantMessageContentPartAnnotationAdded():
790+
annotations = updated.content[update.content_index].annotations
791+
if update.annotation_index <= len(annotations):
792+
annotations.insert(update.annotation_index, update.annotation)
793+
else:
794+
annotations.append(update.annotation)
795+
case AssistantMessageContentPartDone():
796+
updated.content[update.content_index] = update.content
797+
return updated
798+
799+
def _update_pending_items(
800+
self,
801+
pending_items: dict[str, ThreadItem],
802+
event: ThreadItemUpdatedEvent,
803+
):
804+
updated_item = pending_items.get(event.item_id)
805+
update = event.update
806+
match updated_item:
807+
case AssistantMessageItem():
808+
if isinstance(
809+
update,
810+
(
811+
AssistantMessageContentPartAdded,
812+
AssistantMessageContentPartTextDelta,
813+
AssistantMessageContentPartAnnotationAdded,
814+
AssistantMessageContentPartDone,
815+
),
816+
):
817+
pending_items[updated_item.id] = (
818+
self._apply_assistant_message_update(updated_item, update)
819+
)
820+
case WorkflowItem():
821+
if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
822+
match update:
823+
case WorkflowTaskUpdated():
824+
updated_item.workflow.tasks[update.task_index] = update.task
825+
case WorkflowTaskAdded():
826+
updated_item.workflow.tasks.append(update.task)
827+
828+
pending_items[updated_item.id] = updated_item
829+
case _:
830+
pass
831+
677832
async def _build_user_message_item(
678833
self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
679834
) -> UserMessageItem:
@@ -722,7 +877,7 @@ def _serialize(self, obj: BaseModel) -> bytes:
722877

723878
def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
724879
def is_hidden(item: ThreadItem) -> bool:
725-
return isinstance(item, HiddenContextItem)
880+
return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
726881

727882
items = thread.items if isinstance(thread, Thread) else Page()
728883
items.data = [item for item in items.data if not is_hidden(item)]

chatkit/store.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
TContext = TypeVar("TContext", default=Any)
1616

1717
StoreItemType = Literal[
18-
"thread", "message", "tool_call", "task", "workflow", "attachment"
18+
"thread",
19+
"message",
20+
"tool_call",
21+
"task",
22+
"workflow",
23+
"attachment",
24+
"sdk_hidden_context",
1925
]
2026

2127

@@ -26,6 +32,7 @@
2632
"workflow": "wf",
2733
"task": "tsk",
2834
"attachment": "atc",
35+
"sdk_hidden_context": "shcx",
2936
}
3037

3138

chatkit/types.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,20 @@ class ThreadItemReplacedEvent(BaseModel):
317317
item: ThreadItem
318318

319319

320+
class StreamOptions(BaseModel):
321+
"""Settings that control runtime stream behavior."""
322+
323+
allow_cancel: bool
324+
"""Allow the client to request cancellation mid-stream."""
325+
326+
327+
class StreamOptionsEvent(BaseModel):
328+
"""Event emitted to set stream options at runtime."""
329+
330+
type: Literal["stream_options"] = "stream_options"
331+
stream_options: StreamOptions
332+
333+
320334
class ProgressUpdateEvent(BaseModel):
321335
"""Event providing incremental progress from the assistant."""
322336

@@ -354,6 +368,7 @@ class NoticeEvent(BaseModel):
354368
| ThreadItemUpdated
355369
| ThreadItemRemovedEvent
356370
| ThreadItemReplacedEvent
371+
| StreamOptionsEvent
357372
| ProgressUpdateEvent
358373
| ErrorEvent
359374
| NoticeEvent,
@@ -576,12 +591,25 @@ class EndOfTurnItem(ThreadItemBase):
576591

577592

578593
class HiddenContextItem(ThreadItemBase):
579-
"""HiddenContext is never sent to the client. It's not officially part of ChatKit. It is only used internally to store additional context in a specific place in the thread."""
594+
"""
595+
HiddenContext is never sent to the client. It's not officially part of ChatKit.js.
596+
It is only used internally to store additional context in a specific place in the thread.
597+
"""
580598

581599
type: Literal["hidden_context_item"] = "hidden_context_item"
582600
content: Any
583601

584602

603+
class SDKHiddenContextItem(ThreadItemBase):
604+
"""
605+
Hidden context that is used by the ChatKit Python SDK for storing additional context
606+
for internal operations.
607+
"""
608+
609+
type: Literal["sdk_hidden_context"] = "sdk_hidden_context"
610+
content: str
611+
612+
585613
ThreadItem = Annotated[
586614
UserMessageItem
587615
| AssistantMessageItem
@@ -590,6 +618,7 @@ class HiddenContextItem(ThreadItemBase):
590618
| WorkflowItem
591619
| TaskItem
592620
| HiddenContextItem
621+
| SDKHiddenContextItem
593622
| EndOfTurnItem,
594623
Field(discriminator="type"),
595624
]

0 commit comments

Comments
 (0)