Commit 19a7291
Add response token count logic to Gemini instrumentation. (#1486)
* Add response token count logic to Gemini instrumentation.
* Update token counting util functions.
* Linting
* Add response token count logic to Gemini instrumentation.
* Update token counting util functions.
* [MegaLinter] Apply linters fixes
* Bump tests.

Co-authored-by: Tim Pansino <timpansino@gmail.com>
1 parent 27f357c commit 19a7291
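The change revolves around the agent's optional token-count callback. For background, below is a minimal sketch of how an application registers such a callback with the Python agent; the `count_tokens` helper and its rough length-based heuristic are illustrative assumptions and not part of this commit, while `set_llm_token_count_callback` is the agent API that populates the `settings.ai_monitoring.llm_token_count_callback` setting referenced throughout the diff.

import newrelic.agent


def count_tokens(model, content):
    # Illustrative heuristic only; a real callback would run a proper tokenizer
    # for the given model and return an integer token count.
    if not content:
        return 0
    return max(1, len(content) // 4)


# Register the callback so the Gemini hooks can call
# settings.ai_monitoring.llm_token_count_callback(model, content).
newrelic.agent.set_llm_token_count_callback(count_tokens)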

File tree: 6 files changed, +139 / -193 lines changed

newrelic/hooks/mlmodel_gemini.py

Lines changed: 101 additions & 51 deletions
@@ -175,20 +175,24 @@ def _record_embedding_success(transaction, embedding_id, linking_metadata, kwarg
         embedding_content = str(embedding_content)
         request_model = kwargs.get("model")
 
+        embedding_token_count = (
+            settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
+            if settings.ai_monitoring.llm_token_count_callback
+            else None
+        )
+
         full_embedding_response_dict = {
             "id": embedding_id,
             "span_id": span_id,
             "trace_id": trace_id,
-            "token_count": (
-                settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
-                if settings.ai_monitoring.llm_token_count_callback
-                else None
-            ),
             "request.model": request_model,
             "duration": ft.duration * 1000,
             "vendor": "gemini",
             "ingest_source": "Python",
         }
+        if embedding_token_count:
+            full_embedding_response_dict["response.usage.total_tokens"] = embedding_token_count
+
         if settings.ai_monitoring.record_content.enabled:
             full_embedding_response_dict["input"] = embedding_content
 
@@ -300,15 +304,13 @@ def _record_generation_error(transaction, linking_metadata, completion_id, kwarg
                 "Unable to parse input message to Gemini LLM. Message content and role will be omitted from "
                 "corresponding LlmChatCompletionMessage event. "
             )
+        # Extract the input message content and role from the input message if it exists
+        input_message_content, input_role = _parse_input_message(input_message) if input_message else (None, None)
 
-        generation_config = kwargs.get("config")
-        if generation_config:
-            request_temperature = getattr(generation_config, "temperature", None)
-            request_max_tokens = getattr(generation_config, "max_output_tokens", None)
-        else:
-            request_temperature = None
-            request_max_tokens = None
+        # Extract data from generation config object
+        request_temperature, request_max_tokens = _extract_generation_config(kwargs)
 
+        # Prepare error attributes
         notice_error_attributes = {
             "http.statusCode": getattr(exc, "code", None),
             "error.message": getattr(exc, "message", None),
@@ -348,15 +350,17 @@ def _record_generation_error(transaction, linking_metadata, completion_id, kwarg
 
             create_chat_completion_message_event(
                 transaction,
-                input_message,
+                input_message_content,
+                input_role,
                 completion_id,
                 span_id,
                 trace_id,
                 # Passing the request model as the response model here since we do not have access to a response model
                 request_model,
-                request_model,
                 llm_metadata,
                 output_message_list,
+                # We do not record token counts in error cases, so set all_token_counts to True so the pipeline tokenizer does not run
+                all_token_counts=True,
             )
     except Exception:
         _logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True)
@@ -377,6 +381,7 @@ def _handle_generation_success(transaction, linking_metadata, completion_id, kwa
 
 
 def _record_generation_success(transaction, linking_metadata, completion_id, kwargs, ft, response):
+    settings = transaction.settings or global_settings()
     span_id = linking_metadata.get("span.id")
     trace_id = linking_metadata.get("trace.id")
     try:
@@ -385,12 +390,14 @@ def _record_generation_success(transaction, linking_metadata, completion_id, kwa
             # finish_reason is an enum, so grab just the stringified value from it to report
             finish_reason = response.get("candidates")[0].get("finish_reason").value
             output_message_list = [response.get("candidates")[0].get("content")]
+            token_usage = response.get("usage_metadata") or {}
         else:
             # Set all values to NoneTypes since we cannot access them through kwargs or another method that doesn't
             # require the response object
             response_model = None
             output_message_list = []
             finish_reason = None
+            token_usage = {}
 
         request_model = kwargs.get("model")
 
@@ -412,13 +419,44 @@ def _record_generation_success(transaction, linking_metadata, completion_id, kwa
                 "corresponding LlmChatCompletionMessage event. "
             )
 
-        generation_config = kwargs.get("config")
-        if generation_config:
-            request_temperature = getattr(generation_config, "temperature", None)
-            request_max_tokens = getattr(generation_config, "max_output_tokens", None)
+        input_message_content, input_role = _parse_input_message(input_message) if input_message else (None, None)
+
+        # Parse output message content
+        # This list should have a length of 1 to represent the output message
+        # Parse the message text out to pass to any registered token counting callback
+        output_message_content = output_message_list[0].get("parts")[0].get("text") if output_message_list else None
+
+        # Extract token counts from response object
+        if token_usage:
+            response_prompt_tokens = token_usage.get("prompt_token_count")
+            response_completion_tokens = token_usage.get("candidates_token_count")
+            response_total_tokens = token_usage.get("total_token_count")
+
         else:
-            request_temperature = None
-            request_max_tokens = None
+            response_prompt_tokens = None
+            response_completion_tokens = None
+            response_total_tokens = None
+
+        # Calculate token counts by checking if a callback is registered and if we have the necessary content to pass
+        # to it. If not, then we use the token counts provided in the response object
+        prompt_tokens = (
+            settings.ai_monitoring.llm_token_count_callback(request_model, input_message_content)
+            if settings.ai_monitoring.llm_token_count_callback and input_message_content
+            else response_prompt_tokens
+        )
+        completion_tokens = (
+            settings.ai_monitoring.llm_token_count_callback(response_model, output_message_content)
+            if settings.ai_monitoring.llm_token_count_callback and output_message_content
+            else response_completion_tokens
+        )
+        total_tokens = (
+            prompt_tokens + completion_tokens if all([prompt_tokens, completion_tokens]) else response_total_tokens
+        )
+
+        all_token_counts = bool(prompt_tokens and completion_tokens and total_tokens)
+
+        # Extract generation config
+        request_temperature, request_max_tokens = _extract_generation_config(kwargs)
 
         full_chat_completion_summary_dict = {
             "id": completion_id,
@@ -438,66 +476,78 @@ def _record_generation_success(transaction, linking_metadata, completion_id, kwa
             "response.number_of_messages": 1 + len(output_message_list),
         }
 
+        if all_token_counts:
+            full_chat_completion_summary_dict["response.usage.prompt_tokens"] = prompt_tokens
+            full_chat_completion_summary_dict["response.usage.completion_tokens"] = completion_tokens
+            full_chat_completion_summary_dict["response.usage.total_tokens"] = total_tokens
+
         llm_metadata = _get_llm_attributes(transaction)
         full_chat_completion_summary_dict.update(llm_metadata)
         transaction.record_custom_event("LlmChatCompletionSummary", full_chat_completion_summary_dict)
 
         create_chat_completion_message_event(
             transaction,
-            input_message,
+            input_message_content,
+            input_role,
             completion_id,
             span_id,
             trace_id,
             response_model,
-            request_model,
             llm_metadata,
             output_message_list,
+            all_token_counts,
         )
     except Exception:
         _logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True)
 
 
+def _parse_input_message(input_message):
+    # The input_message will be a string if generate_content was called directly. In this case, we don't have
+    # access to the role, so we default to user since this was an input message
+    if isinstance(input_message, str):
+        return input_message, "user"
+    # The input_message will be a Google Content type if send_message was called, so we parse out the message
+    # text and role (which should be "user")
+    elif isinstance(input_message, google.genai.types.Content):
+        return input_message.parts[0].text, input_message.role
+    else:
+        return None, None
+
+
+def _extract_generation_config(kwargs):
+    generation_config = kwargs.get("config")
+    if generation_config:
+        request_temperature = getattr(generation_config, "temperature", None)
+        request_max_tokens = getattr(generation_config, "max_output_tokens", None)
+    else:
+        request_temperature = None
+        request_max_tokens = None
+
+    return request_temperature, request_max_tokens
+
+
 def create_chat_completion_message_event(
     transaction,
-    input_message,
+    input_message_content,
+    input_role,
     chat_completion_id,
     span_id,
     trace_id,
     response_model,
-    request_model,
     llm_metadata,
     output_message_list,
+    all_token_counts,
 ):
     try:
         settings = transaction.settings or global_settings()
 
-        if input_message:
-            # The input_message will be a string if generate_content was called directly. In this case, we don't have
-            # access to the role, so we default to user since this was an input message
-            if isinstance(input_message, str):
-                input_message_content = input_message
-                input_role = "user"
-            # The input_message will be a Google Content type if send_message was called, so we parse out the message
-            # text and role (which should be "user")
-            elif isinstance(input_message, google.genai.types.Content):
-                input_message_content = input_message.parts[0].text
-                input_role = input_message.role
-            # Set input data to NoneTypes to ensure token_count callback is not called
-            else:
-                input_message_content = None
-                input_role = None
-
+        if input_message_content:
             message_id = str(uuid.uuid4())
 
             chat_completion_input_message_dict = {
                 "id": message_id,
                 "span_id": span_id,
                 "trace_id": trace_id,
-                "token_count": (
-                    settings.ai_monitoring.llm_token_count_callback(request_model, input_message_content)
-                    if settings.ai_monitoring.llm_token_count_callback and input_message_content
-                    else None
-                ),
                 "role": input_role,
                 "completion_id": chat_completion_id,
                 # The input message will always be the first message in our request/ response sequence so this will
@@ -507,6 +557,8 @@ def create_chat_completion_message_event(
                 "vendor": "gemini",
                 "ingest_source": "Python",
             }
+            if all_token_counts:
+                chat_completion_input_message_dict["token_count"] = 0
 
             if settings.ai_monitoring.record_content.enabled:
                 chat_completion_input_message_dict["content"] = input_message_content
@@ -523,7 +575,7 @@ def create_chat_completion_message_event(
 
             # Add one to the index to account for the single input message so our sequence value is accurate for
             # the output message
-            if input_message:
+            if input_message_content:
                 index += 1
 
             message_id = str(uuid.uuid4())
@@ -532,11 +584,6 @@ def create_chat_completion_message_event(
                 "id": message_id,
                 "span_id": span_id,
                 "trace_id": trace_id,
-                "token_count": (
-                    settings.ai_monitoring.llm_token_count_callback(response_model, message_content)
-                    if settings.ai_monitoring.llm_token_count_callback
-                    else None
-                ),
                 "role": message.get("role"),
                 "completion_id": chat_completion_id,
                 "sequence": index,
@@ -546,6 +593,9 @@ def create_chat_completion_message_event(
                 "is_response": True,
             }
 
+            if all_token_counts:
+                chat_completion_output_message_dict["token_count"] = 0
+
             if settings.ai_monitoring.record_content.enabled:
                 chat_completion_output_message_dict["content"] = message_content

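In short, the success path above prefers a registered callback whenever message content is available and otherwise falls back to the counts Gemini reports in usage_metadata; when all three counts resolve, they are recorded on the summary event as response.usage.* attributes and each message's token_count is set to 0 so no downstream tokenizer runs. A standalone sketch of that fallback decision, with illustrative names (the real logic lives inline in _record_generation_success):

def resolve_token_count(callback, model, content, reported_count):
    # "callback" stands in for settings.ai_monitoring.llm_token_count_callback and
    # "reported_count" for the matching value from the response's usage_metadata.
    if callback and content:
        return callback(model, content)
    return reported_count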
tests/mlmodel_gemini/test_embeddings.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 import google.genai
 from testing_support.fixtures import override_llm_token_callback_settings, reset_core_stats_engine, validate_attributes
 from testing_support.ml_testing_utils import (
-    add_token_count_to_events,
+    add_token_count_to_embedding_events,
     disabled_ai_monitoring_record_content_settings,
     disabled_ai_monitoring_settings,
     events_sans_content,
@@ -93,7 +93,7 @@ def test_gemini_embedding_sync_no_content(gemini_dev_client, set_trace_info):
 
 @reset_core_stats_engine()
 @override_llm_token_callback_settings(llm_token_count_callback)
-@validate_custom_events(add_token_count_to_events(embedding_recorded_events))
+@validate_custom_events(add_token_count_to_embedding_events(embedding_recorded_events))
 @validate_custom_event_count(count=1)
 @validate_transaction_metrics(
     name="test_embeddings:test_gemini_embedding_sync_with_token_count",
@@ -177,7 +177,7 @@ def test_gemini_embedding_async_no_content(gemini_dev_client, loop, set_trace_in
 
 @reset_core_stats_engine()
 @override_llm_token_callback_settings(llm_token_count_callback)
-@validate_custom_events(add_token_count_to_events(embedding_recorded_events))
+@validate_custom_events(add_token_count_to_embedding_events(embedding_recorded_events))
 @validate_custom_event_count(count=1)
 @validate_transaction_metrics(
     name="test_embeddings:test_gemini_embedding_async_with_token_count",
