Skip to content

Commit 3c2624e

Browse files
authored
Fix Anthropic streaming usage counting (#2771)
1 parent 143b735 commit 3c2624e

File tree

2 files changed

+8
-16
lines changed

2 files changed

+8
-16
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -536,20 +536,15 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
536536
}
537537

538538

539-
def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage:
539+
def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
540540
if isinstance(message, BetaMessage):
541541
response_usage = message.usage
542542
elif isinstance(message, BetaRawMessageStartEvent):
543543
response_usage = message.message.usage
544544
elif isinstance(message, BetaRawMessageDeltaEvent):
545545
response_usage = message.usage
546546
else:
547-
# No usage information provided in:
548-
# - RawMessageStopEvent
549-
# - RawContentBlockStartEvent
550-
# - RawContentBlockDeltaEvent
551-
# - RawContentBlockStopEvent
552-
return usage.RequestUsage()
547+
assert_never(message)
553548

554549
# Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
555550
# `response_tokens`
@@ -586,10 +581,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
586581
current_block: BetaContentBlock | None = None
587582

588583
async for event in self._response:
589-
self._usage += _map_usage(event)
590-
591584
if isinstance(event, BetaRawMessageStartEvent):
592-
pass
585+
self._usage = _map_usage(event)
593586

594587
elif isinstance(event, BetaRawContentBlockStartEvent):
595588
current_block = event.content_block
@@ -652,7 +645,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
652645
pass
653646

654647
elif isinstance(event, BetaRawMessageDeltaEvent):
655-
pass
648+
self._usage = _map_usage(event)
656649

657650
elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch
658651
current_block = None

tests/models/test_anthropic.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ async def test_stream_structured(allow_model_requests: None):
608608
BetaRawMessageDeltaEvent(
609609
type='message_delta',
610610
delta=Delta(stop_reason='end_turn'),
611-
usage=BetaMessageDeltaUsage(output_tokens=5),
611+
usage=BetaMessageDeltaUsage(input_tokens=20, output_tokens=5),
612612
),
613613
# Mark message as complete
614614
BetaRawMessageStopEvent(type='message_stop'),
@@ -1291,12 +1291,11 @@ def anth_msg(usage: BetaUsage) -> BetaMessage:
12911291
snapshot(RequestUsage(output_tokens=5, details={'output_tokens': 5})),
12921292
id='RawMessageDeltaEvent',
12931293
),
1294-
pytest.param(
1295-
lambda: BetaRawMessageStopEvent(type='message_stop'), snapshot(RequestUsage()), id='RawMessageStopEvent'
1296-
),
12971294
],
12981295
)
1299-
def test_usage(message_callback: Callable[[], BetaMessage | BetaRawMessageStreamEvent], usage: RunUsage):
1296+
def test_usage(
1297+
message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage
1298+
):
13001299
assert _map_usage(message_callback()) == usage
13011300

13021301

0 commit comments

Comments
 (0)