
Commit 7d5adac

Add response token count logic to OpenAI instrumentation. (#1498)
* Add OpenAI token counts.
* Add token counts to langchain + openai tests.
* Remove unused expected events.
* Linting
* [MegaLinter] Apply linters fixes

Co-authored-by: Tim Pansino <timpansino@gmail.com>
1 parent 19a7291 commit 7d5adac

14 files changed: +241 -500 lines changed
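Context for the diff below: the agent lets applications register a token-counting callback, and this commit makes the instrumentation prefer that callback's counts while falling back to the usage figures in the OpenAI response. A minimal registration sketch, assuming tiktoken as the tokenizer (tiktoken and the count_tokens name are illustrative, not part of this commit):

import tiktoken

import newrelic.agent


def count_tokens(model, content):
    # Return an integer token count for a model/content pair. Any counting
    # strategy works; tiktoken is just one plausible choice (assumption).
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(content))


# Counts from this callback take precedence over the token usage reported
# in the OpenAI response (see the hunks below).
newrelic.agent.set_llm_token_count_callback(count_tokens)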

newrelic/hooks/mlmodel_openai.py

Lines changed: 67 additions & 20 deletions
@@ -129,11 +129,11 @@ def create_chat_completion_message_event(
     span_id,
     trace_id,
     response_model,
-    request_model,
     response_id,
     request_id,
     llm_metadata,
     output_message_list,
+    all_token_counts,
 ):
     settings = transaction.settings if transaction.settings is not None else global_settings()

@@ -153,11 +153,6 @@ def create_chat_completion_message_event(
             "request_id": request_id,
             "span_id": span_id,
             "trace_id": trace_id,
-            "token_count": (
-                settings.ai_monitoring.llm_token_count_callback(request_model, message_content)
-                if settings.ai_monitoring.llm_token_count_callback
-                else None
-            ),
             "role": message.get("role"),
             "completion_id": chat_completion_id,
             "sequence": index,
@@ -166,6 +161,9 @@ def create_chat_completion_message_event(
             "ingest_source": "Python",
         }

+        if all_token_counts:
+            chat_completion_input_message_dict["token_count"] = 0
+
         if settings.ai_monitoring.record_content.enabled:
             chat_completion_input_message_dict["content"] = message_content

@@ -193,11 +191,6 @@ def create_chat_completion_message_event(
             "request_id": request_id,
             "span_id": span_id,
             "trace_id": trace_id,
-            "token_count": (
-                settings.ai_monitoring.llm_token_count_callback(response_model, message_content)
-                if settings.ai_monitoring.llm_token_count_callback
-                else None
-            ),
             "role": message.get("role"),
             "completion_id": chat_completion_id,
             "sequence": index,
@@ -207,6 +200,9 @@ def create_chat_completion_message_event(
             "is_response": True,
         }

+        if all_token_counts:
+            chat_completion_output_message_dict["token_count"] = 0
+
         if settings.ai_monitoring.record_content.enabled:
             chat_completion_output_message_dict["content"] = message_content

@@ -280,15 +276,18 @@ def _record_embedding_success(transaction, embedding_id, linking_metadata, kwarg
             else getattr(attribute_response, "organization", None)
         )

+        response_total_tokens = attribute_response.get("usage", {}).get("total_tokens") if response else None
+
+        total_tokens = (
+            settings.ai_monitoring.llm_token_count_callback(response_model, input_)
+            if settings.ai_monitoring.llm_token_count_callback and input_
+            else response_total_tokens
+        )
+
         full_embedding_response_dict = {
             "id": embedding_id,
             "span_id": span_id,
             "trace_id": trace_id,
-            "token_count": (
-                settings.ai_monitoring.llm_token_count_callback(response_model, input_)
-                if settings.ai_monitoring.llm_token_count_callback
-                else None
-            ),
             "request.model": kwargs.get("model") or kwargs.get("engine"),
             "request_id": request_id,
             "duration": ft.duration * 1000,
@@ -313,6 +312,7 @@ def _record_embedding_success(transaction, embedding_id, linking_metadata, kwarg
             "response.headers.ratelimitRemainingRequests": check_rate_limit_header(
                 response_headers, "x-ratelimit-remaining-requests", True
             ),
+            "response.usage.total_tokens": total_tokens,
             "vendor": "openai",
             "ingest_source": "Python",
         }
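Taken together, the two hunks above give embeddings a simple precedence rule: use the registered callback when both it and the input text are available, otherwise report the response's usage.total_tokens. A standalone sketch of that rule (the helper name and shape are hypothetical):

def resolve_embedding_total_tokens(callback, response_model, input_, usage):
    # Hypothetical mirror of the logic above: callback first, response second.
    response_total_tokens = usage.get("total_tokens") if usage else None
    if callback and input_:
        return callback(response_model, input_)
    return response_total_tokens


# With no callback registered, the API-reported total passes through.
assert resolve_embedding_total_tokens(None, "text-embedding-ada-002", "2 + 4", {"total_tokens": 8}) == 8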
@@ -475,12 +475,15 @@ def _handle_completion_success(transaction, linking_metadata, completion_id, kwa


 def _record_completion_success(transaction, linking_metadata, completion_id, kwargs, ft, response_headers, response):
+    settings = transaction.settings if transaction.settings is not None else global_settings()
     span_id = linking_metadata.get("span.id")
     trace_id = linking_metadata.get("trace.id")
+
     try:
         if response:
             response_model = response.get("model")
             response_id = response.get("id")
+            token_usage = response.get("usage") or {}
             output_message_list = []
             finish_reason = None
             choices = response.get("choices") or []
@@ -494,6 +497,7 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
         else:
             response_model = kwargs.get("response.model")
             response_id = kwargs.get("id")
+            token_usage = {}
             output_message_list = []
             finish_reason = kwargs.get("finish_reason")
             if "content" in kwargs:
@@ -505,10 +509,44 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
                 output_message_list = []
         request_model = kwargs.get("model") or kwargs.get("engine")

-        request_id = response_headers.get("x-request-id")
-        organization = response_headers.get("openai-organization") or getattr(response, "organization", None)
         messages = kwargs.get("messages") or [{"content": kwargs.get("prompt"), "role": "user"}]
         input_message_list = list(messages)
+
+        # Extract token counts from response object
+        if token_usage:
+            response_prompt_tokens = token_usage.get("prompt_tokens")
+            response_completion_tokens = token_usage.get("completion_tokens")
+            response_total_tokens = token_usage.get("total_tokens")
+
+        else:
+            response_prompt_tokens = None
+            response_completion_tokens = None
+            response_total_tokens = None
+
+        # Calculate token counts by checking if a callback is registered and if we have the necessary content to pass
+        # to it. If not, then we use the token counts provided in the response object
+        input_message_content = " ".join([msg.get("content", "") for msg in input_message_list if msg.get("content")])
+        prompt_tokens = (
+            settings.ai_monitoring.llm_token_count_callback(request_model, input_message_content)
+            if settings.ai_monitoring.llm_token_count_callback and input_message_content
+            else response_prompt_tokens
+        )
+        output_message_content = " ".join([msg.get("content", "") for msg in output_message_list if msg.get("content")])
+        completion_tokens = (
+            settings.ai_monitoring.llm_token_count_callback(response_model, output_message_content)
+            if settings.ai_monitoring.llm_token_count_callback and output_message_content
+            else response_completion_tokens
+        )
+
+        total_tokens = (
+            prompt_tokens + completion_tokens if all([prompt_tokens, completion_tokens]) else response_total_tokens
+        )
+
+        all_token_counts = bool(prompt_tokens and completion_tokens and total_tokens)
+
+        request_id = response_headers.get("x-request-id")
+        organization = response_headers.get("openai-organization") or getattr(response, "organization", None)
+
         full_chat_completion_summary_dict = {
             "id": completion_id,
             "span_id": span_id,
@@ -553,6 +591,12 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
             ),
             "response.number_of_messages": len(input_message_list) + len(output_message_list),
         }
+
+        if all_token_counts:
+            full_chat_completion_summary_dict["response.usage.prompt_tokens"] = prompt_tokens
+            full_chat_completion_summary_dict["response.usage.completion_tokens"] = completion_tokens
+            full_chat_completion_summary_dict["response.usage.total_tokens"] = total_tokens
+
         llm_metadata = _get_llm_attributes(transaction)
         full_chat_completion_summary_dict.update(llm_metadata)
         transaction.record_custom_event("LlmChatCompletionSummary", full_chat_completion_summary_dict)
@@ -564,11 +608,11 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
             span_id,
             trace_id,
             response_model,
-            request_model,
             response_id,
             request_id,
             llm_metadata,
             output_message_list,
+            all_token_counts,
         )
     except Exception:
         _logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, traceback.format_exception(*sys.exc_info()))
@@ -579,6 +623,7 @@ def _record_completion_error(transaction, linking_metadata, completion_id, kwarg
     trace_id = linking_metadata.get("trace.id")
     request_message_list = kwargs.get("messages", None) or []
     notice_error_attributes = {}
+
     try:
         if OPENAI_V1:
             response = getattr(exc, "response", None)
@@ -643,18 +688,20 @@ def _record_completion_error(transaction, linking_metadata, completion_id, kwarg
         output_message_list = []
         if "content" in kwargs:
             output_message_list = [{"content": kwargs.get("content"), "role": kwargs.get("role")}]
+
         create_chat_completion_message_event(
             transaction,
             request_message_list,
             completion_id,
             span_id,
             trace_id,
             kwargs.get("response.model"),
-            request_model,
             response_id,
             request_id,
             llm_metadata,
             output_message_list,
+            # We do not record token counts in error cases, so set all_token_counts to True so the pipeline tokenizer does not run
+            all_token_counts=True,
         )
     except Exception:
         _logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, traceback.format_exception(*sys.exc_info()))

tests/mlmodel_langchain/test_chain.py

Lines changed: 8 additions & 0 deletions
@@ -359,6 +359,7 @@
             "response.headers.ratelimitResetRequests": "20ms",
             "response.headers.ratelimitRemainingTokens": 999992,
             "response.headers.ratelimitRemainingRequests": 2999,
+            "response.usage.total_tokens": 8,
             "vendor": "openai",
             "ingest_source": "Python",
             "input": "[[3923, 374, 220, 17, 489, 220, 19, 30]]",
@@ -382,6 +383,7 @@
             "response.headers.ratelimitResetRequests": "20ms",
             "response.headers.ratelimitRemainingTokens": 999998,
             "response.headers.ratelimitRemainingRequests": 2999,
+            "response.usage.total_tokens": 1,
             "vendor": "openai",
             "ingest_source": "Python",
             "input": "[[10590]]",
@@ -452,6 +454,9 @@
             "response.headers.ratelimitResetRequests": "8.64s",
             "response.headers.ratelimitRemainingTokens": 199912,
             "response.headers.ratelimitRemainingRequests": 9999,
+            "response.usage.prompt_tokens": 73,
+            "response.usage.completion_tokens": 375,
+            "response.usage.total_tokens": 448,
             "response.number_of_messages": 3,
         },
     ],
@@ -467,6 +472,7 @@
             "sequence": 0,
             "response.model": "gpt-3.5-turbo-0125",
             "vendor": "openai",
+            "token_count": 0,
             "ingest_source": "Python",
             "content": "You are a generator of quiz questions for a seminar. Use the following pieces of retrieved context to generate 5 multiple choice questions (A,B,C,D) on the subject matter. Use a three sentence maximum and keep the answer concise. Render the output as HTML\n\nWhat is 2 + 4?",
         },
@@ -483,6 +489,7 @@
             "sequence": 1,
             "response.model": "gpt-3.5-turbo-0125",
             "vendor": "openai",
+            "token_count": 0,
             "ingest_source": "Python",
             "content": "math",
         },
@@ -499,6 +506,7 @@
             "sequence": 2,
             "response.model": "gpt-3.5-turbo-0125",
             "vendor": "openai",
+            "token_count": 0,
             "ingest_source": "Python",
             "is_response": True,
             "content": "```html\n<!DOCTYPE html>\n<html>\n<head>\n <title>Math Quiz</title>\n</head>\n<body>\n <h2>Math Quiz Questions</h2>\n <ol>\n <li>What is the result of 5 + 3?</li>\n <ul>\n <li>A) 7</li>\n <li>B) 8</li>\n <li>C) 9</li>\n <li>D) 10</li>\n </ul>\n <li>What is the product of 6 x 7?</li>\n <ul>\n <li>A) 36</li>\n <li>B) 42</li>\n <li>C) 48</li>\n <li>D) 56</li>\n </ul>\n <li>What is the square root of 64?</li>\n <ul>\n <li>A) 6</li>\n <li>B) 7</li>\n <li>C) 8</li>\n <li>D) 9</li>\n </ul>\n <li>What is the result of 12 / 4?</li>\n <ul>\n <li>A) 2</li>\n <li>B) 3</li>\n <li>C) 4</li>\n <li>D) 5</li>\n </ul>\n <li>What is the sum of 15 + 9?</li>\n <ul>\n <li>A) 22</li>\n <li>B) 23</li>\n <li>C) 24</li>\n <li>D) 25</li>\n </ul>\n </ol>\n</body>\n</html>\n```",

tests/mlmodel_openai/test_chat_completion.py

Lines changed: 9 additions & 3 deletions
@@ -15,7 +15,7 @@
 import openai
 from testing_support.fixtures import override_llm_token_callback_settings, reset_core_stats_engine, validate_attributes
 from testing_support.ml_testing_utils import (
-    add_token_count_to_events,
+    add_token_counts_to_chat_events,
     disabled_ai_monitoring_record_content_settings,
     disabled_ai_monitoring_settings,
     disabled_ai_monitoring_streaming_settings,
@@ -55,6 +55,9 @@
             "response.organization": "new-relic-nkmd8b",
             "request.temperature": 0.7,
             "request.max_tokens": 100,
+            "response.usage.completion_tokens": 11,
+            "response.usage.total_tokens": 64,
+            "response.usage.prompt_tokens": 53,
             "response.choices.finish_reason": "stop",
             "response.headers.llmVersion": "2020-10-01",
             "response.headers.ratelimitLimitRequests": 200,
@@ -81,6 +84,7 @@
             "role": "system",
             "completion_id": None,
             "sequence": 0,
+            "token_count": 0,
             "response.model": "gpt-3.5-turbo-0613",
             "vendor": "openai",
             "ingest_source": "Python",
@@ -99,6 +103,7 @@
             "role": "user",
             "completion_id": None,
             "sequence": 1,
+            "token_count": 0,
             "response.model": "gpt-3.5-turbo-0613",
             "vendor": "openai",
             "ingest_source": "Python",
@@ -117,6 +122,7 @@
             "role": "assistant",
             "completion_id": None,
             "sequence": 2,
+            "token_count": 0,
             "response.model": "gpt-3.5-turbo-0613",
             "vendor": "openai",
             "is_response": True,
@@ -172,7 +178,7 @@ def test_openai_chat_completion_sync_no_content(set_trace_info):

 @reset_core_stats_engine()
 @override_llm_token_callback_settings(llm_token_count_callback)
-@validate_custom_events(add_token_count_to_events(chat_completion_recorded_events))
+@validate_custom_events(add_token_counts_to_chat_events(chat_completion_recorded_events))
 # One summary event, one system message, one user message, and one response message from the assistant
 @validate_custom_event_count(count=4)
 @validate_transaction_metrics(
@@ -343,7 +349,7 @@ def test_openai_chat_completion_async_no_content(loop, set_trace_info):

 @reset_core_stats_engine()
 @override_llm_token_callback_settings(llm_token_count_callback)
-@validate_custom_events(add_token_count_to_events(chat_completion_recorded_events))
+@validate_custom_events(add_token_counts_to_chat_events(chat_completion_recorded_events))
 # One summary event, one system message, one user message, and one response message from the assistant
 @validate_custom_event_count(count=4)
 @validate_transaction_metrics(
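These tests swap add_token_count_to_events for add_token_counts_to_chat_events, whose job is to rewrite the expected events for runs where a token callback is registered. A plausible sketch of such a helper, assuming the (event_info, event_attrs) tuple shape used in these fixtures (the real implementation lives in testing_support/ml_testing_utils.py and may differ):

import copy


def add_token_counts_to_chat_events(expected_events, token_count=105):
    # Hypothetical sketch: when a callback is registered, summary events carry
    # callback-derived usage values (total = prompt + completion).
    events = copy.deepcopy(expected_events)
    for event_info, event_attrs in events:
        if event_info["type"] == "LlmChatCompletionSummary":
            event_attrs["response.usage.prompt_tokens"] = token_count
            event_attrs["response.usage.completion_tokens"] = token_count
            event_attrs["response.usage.total_tokens"] = token_count * 2
    return events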
