Commit c3debdf

feat: add usage to streaming response
1 parent f31ac2e commit c3debdf

File tree

3 files changed: +140 -82 lines changed

llama_cpp/llama.py

Lines changed: 110 additions & 77 deletions
@@ -1054,6 +1054,50 @@ def decode_batch(seq_sizes: List[int]):
         else:
             return output
 
+    def _create_chunk(
+        self,
+        completion_id: str,
+        created: int,
+        model_name: str,
+        text: str,
+        logprobs_or_none: Union[Optional[CompletionLogprobs], None],
+        index: int,
+        finish_reason: Union[str, None],
+        usage: Optional[Dict[str, Any]] = None,
+    ) -> CreateCompletionStreamResponse:
+        """Create chunks for streaming API, depending on whether usage is requested or not."""
+        if usage is not None:
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+                "usage": usage,
+            }
+        else:
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+
     def _create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1380,24 +1424,20 @@ def logit_bias_processor(
                         "top_logprobs": [top_logprob],
                     }
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize(
-                                [token],
-                                prev_tokens=prompt_tokens
-                                + completion_tokens[:returned_tokens],
-                            ).decode("utf-8", errors="ignore"),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=self.detokenize(
+                        [token],
+                        prev_tokens=prompt_tokens
+                        + completion_tokens[:returned_tokens],
+                    ).decode("utf-8", errors="ignore"),
+                    logprobs_or_none=logprobs_or_none,
+                    index=0,
+                    finish_reason=None,
+                    usage=None,
+                )
             else:
                 while len(remaining_tokens) > 0:
                     decode_success = False
@@ -1426,20 +1466,16 @@ def logit_bias_processor(
                     remaining_tokens = remaining_tokens[i:]
                     returned_tokens += i
 
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": ts,
-                                "index": 0,
-                                "logprobs": None,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=ts,
+                        logprobs_or_none=None,
+                        index=0,
+                        finish_reason=None,
+                        usage=None,
+                    )
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -1518,54 +1554,51 @@ def logit_bias_processor(
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=last_text[
+                            : len(last_text) - (token_end_position - end)
+                        ].decode("utf-8", errors="ignore"),
+                        logprobs_or_none=logprobs_or_none,
+                        index=0,
+                        finish_reason=None,
+                        usage=None,
+                    )
                     break
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=self.detokenize([token]).decode(
+                        "utf-8", errors="ignore"
+                    ),
+                    logprobs_or_none=logprobs_or_none,
+                    index=0,
+                    finish_reason=None,
+                    usage=None,
+                )
+
+            # Final streaming chunk with both finish_reason and usage
+            usage = {
+                "prompt_tokens": len(prompt_tokens),
+                "completion_tokens": returned_tokens,
+                "total_tokens": len(prompt_tokens) + returned_tokens,
             }
+
+            yield self._create_chunk(
+                completion_id=completion_id,
+                created=created,
+                model_name=model_name,
+                text="",
+                logprobs_or_none=None,
+                index=0,
+                finish_reason=finish_reason,
+                usage=usage,
+            )
+
         if self.cache:
             if self.verbose:
                 print("Llama._create_completion: cache save", file=sys.stderr)

llama_cpp/llama_chat_format.py

Lines changed: 29 additions & 4 deletions
@@ -350,6 +350,7 @@ def _convert_text_completion_chunks_to_chat(
                 "finish_reason": chunk["choices"][0]["finish_reason"],
             }
         ],
+        "usage": chunk.get("usage") if "usage" in chunk else None,
     }
 
 
@@ -434,7 +435,7 @@ def _stream_response_to_function_stream(
                 created = chunk["created"]
                 model = chunk["model"]
                 tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"]
-                yield {
+                response = {
                     "id": id_,
                     "object": "chat.completion.chunk",
                     "created": created,
@@ -453,7 +454,11 @@ def _stream_response_to_function_stream(
                         }
                     ],
                 }
-                yield {
+                if "usage" in chunk:
+                    response["usage"] = chunk["usage"]
+                yield response
+
+                response = {
                     "id": "chat" + chunk["id"],
                     "object": "chat.completion.chunk",
                     "created": chunk["created"],
@@ -487,10 +492,14 @@ def _stream_response_to_function_stream(
                        }
                    ],
                }
+                if "usage" in chunk:
+                    response["usage"] = chunk["usage"]
+                yield response
                first = False
                continue
+
            assert tool_id is not None
-            yield {
+            response = {
                "id": "chat" + chunk["id"],
                "object": "chat.completion.chunk",
                "created": chunk["created"],
@@ -522,9 +531,12 @@ def _stream_response_to_function_stream(
                    }
                ],
            }
+            if "usage" in chunk:
+                response["usage"] = chunk["usage"]
+            yield response
 
        if id_ is not None and created is not None and model is not None:
-            yield {
+            response = {
                "id": id_,
                "object": "chat.completion.chunk",
                "created": created,
@@ -543,6 +555,9 @@ def _stream_response_to_function_stream(
                    }
                ],
            }
+            if "usage" in chunk:
+                response["usage"] = chunk["usage"]
+            yield response
 
        return _stream_response_to_function_stream(chunks)
 
@@ -2123,6 +2138,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            first = False
            if tools is not None:
@@ -2163,6 +2179,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            # Yield tool_call/function_call stop message
            yield llama_types.CreateChatCompletionStreamResponse(
@@ -2185,6 +2202,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
        # If "auto" or no tool_choice/function_call
        elif isinstance(function_call, str) and function_call == "auto":
@@ -2220,6 +2238,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        "finish_reason": None,
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
        else:
            prompt += f"{function_name}\n<|content|>"
@@ -2265,6 +2284,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            # Generate content
            stops = [RECIPIENT_TOKEN, STOP_TOKEN]
@@ -2302,6 +2322,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            is_end = False
        elif chunk["choices"][0]["text"] == "\n":
@@ -2331,6 +2352,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            # Check whether the model wants to generate another turn
            if (
@@ -2363,6 +2385,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        "finish_reason": "stop",
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            break
        else:
@@ -2412,6 +2435,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            prompt += completion_text.strip()
            grammar = None
@@ -2451,6 +2475,7 @@ def generate_streaming(tools, functions, function_call, prompt):
                        },
                    }
                ],
+                usage=chunk["usage"] if "usage" in chunk else None,
            )
            break
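
The chat-format converters above all follow the same pattern: build the chunk dict, attach "usage" only when the underlying completion chunk carries it, then yield. An illustrative standalone sketch of that pattern (the function name is invented for this example and is not the library's API; field names mirror the diff):

from typing import Any, Dict

def completion_chunk_to_chat_chunk(chunk: Dict[str, Any]) -> Dict[str, Any]:
    """Re-wrap a streamed text-completion chunk as a chat.completion.chunk dict."""
    chat_chunk: Dict[str, Any] = {
        "id": "chat" + chunk["id"],
        "object": "chat.completion.chunk",
        "created": chunk["created"],
        "model": chunk["model"],
        "choices": [
            {
                "index": 0,
                "delta": {"content": chunk["choices"][0]["text"]},
                "finish_reason": chunk["choices"][0]["finish_reason"],
            }
        ],
    }
    # Usage rides only on the final completion chunk, so attach it conditionally,
    # mirroring the conditional assignments added in this commit.
    if "usage" in chunk:
        chat_chunk["usage"] = chunk["usage"]
    return chat_chunk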

llama_cpp/llama_types.py

Lines changed: 1 addition & 1 deletion
@@ -154,13 +154,13 @@ class ChatCompletionStreamResponseChoice(TypedDict):
     finish_reason: Optional[Literal["stop", "length", "tool_calls", "function_call"]]
     logprobs: NotRequired[Optional[ChatCompletionLogprobs]]
 
-
 class CreateChatCompletionStreamResponse(TypedDict):
     id: str
     model: str
     object: Literal["chat.completion.chunk"]
     created: int
     choices: List[ChatCompletionStreamResponseChoice]
+    usage: NotRequired[CompletionUsage]
 
 
 class ChatCompletionFunctions(TypedDict):
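
Because the new usage field is declared NotRequired, chat chunks with and without it satisfy the same TypedDict. A type-level sketch (an assumption: simplified field sets that mirror llama_cpp/llama_types.py, with delta reduced to a plain dict):

from typing import List, Literal, Optional
from typing_extensions import NotRequired, TypedDict

class CompletionUsage(TypedDict):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class ChatCompletionStreamResponseChoice(TypedDict):
    index: int
    delta: dict
    finish_reason: Optional[Literal["stop", "length", "tool_calls", "function_call"]]

class CreateChatCompletionStreamResponse(TypedDict):
    id: str
    model: str
    object: Literal["chat.completion.chunk"]
    created: int
    choices: List[ChatCompletionStreamResponseChoice]
    usage: NotRequired[CompletionUsage]  # present only on the final streamed chunk

# Both shapes type-check: a mid-stream chunk without usage...
mid_chunk: CreateChatCompletionStreamResponse = {
    "id": "chatcmpl-1", "model": "m", "object": "chat.completion.chunk",
    "created": 0, "choices": [],
}
# ...and a final chunk that carries the token counts.
final_chunk: CreateChatCompletionStreamResponse = {
    "id": "chatcmpl-1", "model": "m", "object": "chat.completion.chunk",
    "created": 0, "choices": [],
    "usage": {"prompt_tokens": 7, "completion_tokens": 5, "total_tokens": 12},
}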
