Skip to content

Commit 6403dff

Browse files
committed
[feat] Handle annotation deltas in agents.stream_agent_response
1 parent 5c528d0 commit 6403dff

File tree

4 files changed

+172
-25
lines changed

4 files changed

+172
-25
lines changed

chatkit/agents.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from collections import defaultdict
23
import json
34
from collections.abc import AsyncIterator
45
from datetime import datetime
@@ -36,6 +37,12 @@
3637
from openai.types.responses.response_output_text import (
3738
Annotation as ResponsesAnnotation,
3839
)
40+
from openai.types.responses.response_output_text import (
41+
AnnotationContainerFileCitation,
42+
AnnotationFileCitation,
43+
AnnotationFilePath,
44+
AnnotationURLCitation,
45+
)
3946
from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
4047
from typing_extensions import assert_never
4148

@@ -45,6 +52,7 @@
4552
Annotation,
4653
AssistantMessageContent,
4754
AssistantMessageContentPartAdded,
55+
AssistantMessageContentPartAnnotationAdded,
4856
AssistantMessageContentPartDone,
4957
AssistantMessageContentPartTextDelta,
5058
AssistantMessageItem,
@@ -207,9 +215,8 @@ def _complete(self):
207215

208216
def _convert_content(content: Content) -> AssistantMessageContent:
209217
if content.type == "output_text":
210-
annotations = []
211-
for annotation in content.annotations:
212-
annotations.extend(_convert_annotation(annotation))
218+
annotations = [_convert_annotation(annotation) for annotation in content.annotations]
219+
annotations = [a for a in annotations if a is not None]
213220
return AssistantMessageContent(
214221
text=content.text,
215222
annotations=annotations,
@@ -222,36 +229,47 @@ def _convert_content(content: Content) -> AssistantMessageContent:
222229

223230

224231
def _convert_annotation(
225-
annotation: ResponsesAnnotation,
226-
) -> list[Annotation]:
232+
raw_annotation: object
233+
) -> Annotation | None:
227234
# There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class
228-
# resulting into annotation being a dict instead of a ResponsesAnnotation
229-
if isinstance(annotation, dict):
230-
annotation = TypeAdapter(ResponsesAnnotation).validate_python(annotation)
235+
# resulting into annotation being a dict.
236+
match raw_annotation:
237+
case AnnotationFileCitation() | AnnotationURLCitation() | AnnotationContainerFileCitation() | AnnotationFilePath():
238+
annotation = raw_annotation
239+
case _:
240+
annotation = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation).validate_python(raw_annotation)
241+
231242

232-
result: list[Annotation] = []
233243
if annotation.type == "file_citation":
234244
filename = annotation.filename
235245
if not filename:
236-
return []
237-
result.append(
238-
Annotation(
246+
return None
247+
248+
return Annotation(
239249
source=FileSource(filename=filename, title=filename),
240250
index=annotation.index,
241251
)
252+
253+
if annotation.type == "url_citation":
254+
return Annotation(
255+
source=URLSource(
256+
url=annotation.url,
257+
title=annotation.title,
258+
),
259+
index=annotation.end_index,
242260
)
243-
elif annotation.type == "url_citation":
244-
result.append(
245-
Annotation(
246-
source=URLSource(
247-
url=annotation.url,
248-
title=annotation.title,
249-
),
261+
262+
if annotation.type == "container_file_citation":
263+
filename = annotation.filename
264+
if not filename:
265+
return None
266+
267+
return Annotation(
268+
source=FileSource(filename=filename, title=filename),
250269
index=annotation.end_index,
251270
)
252-
)
253271

254-
return result
272+
return None
255273

256274

257275
T1 = TypeVar("T1")
@@ -349,6 +367,8 @@ async def stream_agent_response(
349367
queue_iterator = _AsyncQueueIterator(context._events)
350368
produced_items = set()
351369
streaming_thought: None | StreamingThoughtTracker = None
370+
# item_id -> content_index -> annotation count
371+
item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(lambda: defaultdict(int))
352372

353373
# check if the last item in the thread was a workflow or a client tool call
354374
# if it was a client tool call, check if the second last item was a workflow
@@ -462,7 +482,20 @@ def end_workflow(item: WorkflowItem):
462482
),
463483
)
464484
elif event.type == "response.output_text.annotation.added":
465-
# Ignore annotation-added events; annotations are reflected in the final item content.
485+
annotation = _convert_annotation(event.annotation)
486+
if annotation:
487+
# Manually track annotation indices per content part in case we drop an annotation that
488+
# we can't convert to our internal representation (e.g. missing filename).
489+
annotation_index = item_annotation_count[event.item_id][event.content_index]
490+
item_annotation_count[event.item_id][event.content_index] = annotation_index + 1
491+
yield ThreadItemUpdated(
492+
item_id=event.item_id,
493+
update=AssistantMessageContentPartAnnotationAdded(
494+
content_index=event.content_index,
495+
annotation_index=annotation_index,
496+
annotation=annotation,
497+
)
498+
)
466499
continue
467500
elif event.type == "response.output_item.added":
468501
item = event.item

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "openai-chatkit"
3-
version = "1.1.2"
3+
version = "1.2.2"
44
description = "A ChatKit backend SDK."
55
readme = "README.md"
66
requires-python = ">=3.10"

tests/test_agents.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
ResponseContentPartAddedEvent,
3939
)
4040
from openai.types.responses.response_file_search_tool_call import Result
41+
from openai.types.responses.response_output_text import (
42+
AnnotationContainerFileCitation as ResponsesAnnotationContainerFileCitation,
43+
)
4144
from openai.types.responses.response_output_text import (
4245
AnnotationFileCitation as ResponsesAnnotationFileCitation,
4346
)
@@ -64,6 +67,7 @@
6467
Annotation,
6568
AssistantMessageContent,
6669
AssistantMessageContentPartAdded,
70+
AssistantMessageContentPartAnnotationAdded,
6771
AssistantMessageContentPartDone,
6872
AssistantMessageContentPartTextDelta,
6973
AssistantMessageItem,
@@ -790,7 +794,17 @@ async def test_stream_agent_response_maps_events():
790794
sequence_number=3,
791795
),
792796
),
793-
None,
797+
ThreadItemUpdated(
798+
item_id="123",
799+
update=AssistantMessageContentPartAnnotationAdded(
800+
content_index=0,
801+
annotation_index=0,
802+
annotation=Annotation(
803+
source=FileSource(filename="file.txt", title="file.txt"),
804+
index=5,
805+
),
806+
),
807+
),
794808
),
795809
],
796810
)
@@ -810,6 +824,91 @@ async def test_event_mapping(raw_event, expected_event):
810824
assert events == []
811825

812826

827+
async def test_stream_agent_response_emits_annotation_added_events():
828+
context = AgentContext(
829+
previous_response_id=None, thread=thread, store=mock_store, request_context=None
830+
)
831+
result = make_result()
832+
item_id = "item_123"
833+
834+
def add_annotation_event(annotation, sequence_number):
835+
result.add_event(
836+
RawResponsesStreamEvent(
837+
type="raw_response_event",
838+
data=Mock(
839+
type="response.output_text.annotation.added",
840+
annotation=annotation,
841+
content_index=0,
842+
item_id=item_id,
843+
annotation_index=sequence_number,
844+
output_index=0,
845+
sequence_number=sequence_number,
846+
),
847+
)
848+
)
849+
850+
add_annotation_event(
851+
ResponsesAnnotationFileCitation(
852+
type="file_citation",
853+
file_id="file_invalid",
854+
filename="",
855+
index=0,
856+
),
857+
sequence_number=0,
858+
)
859+
add_annotation_event(
860+
ResponsesAnnotationContainerFileCitation(
861+
type="container_file_citation",
862+
container_id="container_1",
863+
file_id="file_123",
864+
filename="container.txt",
865+
start_index=0,
866+
end_index=3,
867+
),
868+
sequence_number=1,
869+
)
870+
add_annotation_event(
871+
ResponsesAnnotationURLCitation(
872+
type="url_citation",
873+
url="https://example.com",
874+
title="Example",
875+
start_index=1,
876+
end_index=5,
877+
),
878+
sequence_number=2,
879+
)
880+
result.done()
881+
882+
events = await all_events(stream_agent_response(context, result))
883+
assert events == [
884+
ThreadItemUpdated(
885+
item_id=item_id,
886+
update=AssistantMessageContentPartAnnotationAdded(
887+
content_index=0,
888+
annotation_index=0,
889+
annotation=Annotation(
890+
source=FileSource(filename="container.txt", title="container.txt"),
891+
index=3,
892+
),
893+
),
894+
),
895+
ThreadItemUpdated(
896+
item_id=item_id,
897+
update=AssistantMessageContentPartAnnotationAdded(
898+
content_index=0,
899+
annotation_index=1,
900+
annotation=Annotation(
901+
source=URLSource(
902+
url="https://example.com",
903+
title="Example",
904+
),
905+
index=5,
906+
),
907+
),
908+
),
909+
]
910+
911+
813912
@pytest.mark.parametrize("throw_guardrail", ["input", "output"])
814913
async def test_stream_agent_response_yields_item_removed_event(throw_guardrail):
815914
context = AgentContext(
@@ -942,6 +1041,14 @@ async def test_stream_agent_response_assistant_message_content_types():
9421041
index=0,
9431042
filename="test.txt",
9441043
),
1044+
ResponsesAnnotationContainerFileCitation(
1045+
type="container_file_citation",
1046+
container_id="container_1",
1047+
file_id="f_456",
1048+
filename="container.txt",
1049+
start_index=0,
1050+
end_index=3,
1051+
),
9451052
ResponsesAnnotationURLCitation(
9461053
type="url_citation",
9471054
url="https://www.google.com",
@@ -994,6 +1101,13 @@ async def test_stream_agent_response_assistant_message_content_types():
9941101
),
9951102
index=0,
9961103
),
1104+
Annotation(
1105+
source=FileSource(
1106+
filename="container.txt",
1107+
title="container.txt",
1108+
),
1109+
index=3,
1110+
),
9971111
Annotation(
9981112
source=URLSource(
9991113
url="https://www.google.com",

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)