@@ -129,11 +129,11 @@ def create_chat_completion_message_event(
129129 span_id ,
130130 trace_id ,
131131 response_model ,
132- request_model ,
133132 response_id ,
134133 request_id ,
135134 llm_metadata ,
136135 output_message_list ,
136+ all_token_counts ,
137137):
138138 settings = transaction .settings if transaction .settings is not None else global_settings ()
139139
@@ -153,11 +153,6 @@ def create_chat_completion_message_event(
153153 "request_id" : request_id ,
154154 "span_id" : span_id ,
155155 "trace_id" : trace_id ,
156- "token_count" : (
157- settings .ai_monitoring .llm_token_count_callback (request_model , message_content )
158- if settings .ai_monitoring .llm_token_count_callback
159- else None
160- ),
161156 "role" : message .get ("role" ),
162157 "completion_id" : chat_completion_id ,
163158 "sequence" : index ,
@@ -166,6 +161,9 @@ def create_chat_completion_message_event(
166161 "ingest_source" : "Python" ,
167162 }
168163
164+ if all_token_counts :
165+ chat_completion_input_message_dict ["token_count" ] = 0
166+
169167 if settings .ai_monitoring .record_content .enabled :
170168 chat_completion_input_message_dict ["content" ] = message_content
171169
@@ -193,11 +191,6 @@ def create_chat_completion_message_event(
193191 "request_id" : request_id ,
194192 "span_id" : span_id ,
195193 "trace_id" : trace_id ,
196- "token_count" : (
197- settings .ai_monitoring .llm_token_count_callback (response_model , message_content )
198- if settings .ai_monitoring .llm_token_count_callback
199- else None
200- ),
201194 "role" : message .get ("role" ),
202195 "completion_id" : chat_completion_id ,
203196 "sequence" : index ,
@@ -207,6 +200,9 @@ def create_chat_completion_message_event(
207200 "is_response" : True ,
208201 }
209202
203+ if all_token_counts :
204+ chat_completion_output_message_dict ["token_count" ] = 0
205+
210206 if settings .ai_monitoring .record_content .enabled :
211207 chat_completion_output_message_dict ["content" ] = message_content
212208
@@ -280,15 +276,18 @@ def _record_embedding_success(transaction, embedding_id, linking_metadata, kwarg
280276 else getattr (attribute_response , "organization" , None )
281277 )
282278
279+ response_total_tokens = attribute_response .get ("usage" , {}).get ("total_tokens" ) if response else None
280+
281+ total_tokens = (
282+ settings .ai_monitoring .llm_token_count_callback (response_model , input_ )
283+ if settings .ai_monitoring .llm_token_count_callback and input_
284+ else response_total_tokens
285+ )
286+
283287 full_embedding_response_dict = {
284288 "id" : embedding_id ,
285289 "span_id" : span_id ,
286290 "trace_id" : trace_id ,
287- "token_count" : (
288- settings .ai_monitoring .llm_token_count_callback (response_model , input_ )
289- if settings .ai_monitoring .llm_token_count_callback
290- else None
291- ),
292291 "request.model" : kwargs .get ("model" ) or kwargs .get ("engine" ),
293292 "request_id" : request_id ,
294293 "duration" : ft .duration * 1000 ,
@@ -313,6 +312,7 @@ def _record_embedding_success(transaction, embedding_id, linking_metadata, kwarg
313312 "response.headers.ratelimitRemainingRequests" : check_rate_limit_header (
314313 response_headers , "x-ratelimit-remaining-requests" , True
315314 ),
315+ "response.usage.total_tokens" : total_tokens ,
316316 "vendor" : "openai" ,
317317 "ingest_source" : "Python" ,
318318 }
@@ -475,12 +475,15 @@ def _handle_completion_success(transaction, linking_metadata, completion_id, kwa
475475
476476
477477def _record_completion_success (transaction , linking_metadata , completion_id , kwargs , ft , response_headers , response ):
478+ settings = transaction .settings if transaction .settings is not None else global_settings ()
478479 span_id = linking_metadata .get ("span.id" )
479480 trace_id = linking_metadata .get ("trace.id" )
481+
480482 try :
481483 if response :
482484 response_model = response .get ("model" )
483485 response_id = response .get ("id" )
486+ token_usage = response .get ("usage" ) or {}
484487 output_message_list = []
485488 finish_reason = None
486489 choices = response .get ("choices" ) or []
@@ -494,6 +497,7 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
494497 else :
495498 response_model = kwargs .get ("response.model" )
496499 response_id = kwargs .get ("id" )
500+ token_usage = {}
497501 output_message_list = []
498502 finish_reason = kwargs .get ("finish_reason" )
499503 if "content" in kwargs :
@@ -505,10 +509,44 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
505509 output_message_list = []
506510 request_model = kwargs .get ("model" ) or kwargs .get ("engine" )
507511
508- request_id = response_headers .get ("x-request-id" )
509- organization = response_headers .get ("openai-organization" ) or getattr (response , "organization" , None )
510512 messages = kwargs .get ("messages" ) or [{"content" : kwargs .get ("prompt" ), "role" : "user" }]
511513 input_message_list = list (messages )
514+
515+ # Extract token counts from response object
516+ if token_usage :
517+ response_prompt_tokens = token_usage .get ("prompt_tokens" )
518+ response_completion_tokens = token_usage .get ("completion_tokens" )
519+ response_total_tokens = token_usage .get ("total_tokens" )
520+
521+ else :
522+ response_prompt_tokens = None
523+ response_completion_tokens = None
524+ response_total_tokens = None
525+
526+ # Calculate token counts by checking if a callback is registered and if we have the necessary content to pass
527+ # to it. If not, then we use the token counts provided in the response object
528+ input_message_content = " " .join ([msg .get ("content" , "" ) for msg in input_message_list if msg .get ("content" )])
529+ prompt_tokens = (
530+ settings .ai_monitoring .llm_token_count_callback (request_model , input_message_content )
531+ if settings .ai_monitoring .llm_token_count_callback and input_message_content
532+ else response_prompt_tokens
533+ )
534+ output_message_content = " " .join ([msg .get ("content" , "" ) for msg in output_message_list if msg .get ("content" )])
535+ completion_tokens = (
536+ settings .ai_monitoring .llm_token_count_callback (response_model , output_message_content )
537+ if settings .ai_monitoring .llm_token_count_callback and output_message_content
538+ else response_completion_tokens
539+ )
540+
541+ total_tokens = (
542+ prompt_tokens + completion_tokens if all ([prompt_tokens , completion_tokens ]) else response_total_tokens
543+ )
544+
545+ all_token_counts = bool (prompt_tokens and completion_tokens and total_tokens )
546+
547+ request_id = response_headers .get ("x-request-id" )
548+ organization = response_headers .get ("openai-organization" ) or getattr (response , "organization" , None )
549+
512550 full_chat_completion_summary_dict = {
513551 "id" : completion_id ,
514552 "span_id" : span_id ,
@@ -553,6 +591,12 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
553591 ),
554592 "response.number_of_messages" : len (input_message_list ) + len (output_message_list ),
555593 }
594+
595+ if all_token_counts :
596+ full_chat_completion_summary_dict ["response.usage.prompt_tokens" ] = prompt_tokens
597+ full_chat_completion_summary_dict ["response.usage.completion_tokens" ] = completion_tokens
598+ full_chat_completion_summary_dict ["response.usage.total_tokens" ] = total_tokens
599+
556600 llm_metadata = _get_llm_attributes (transaction )
557601 full_chat_completion_summary_dict .update (llm_metadata )
558602 transaction .record_custom_event ("LlmChatCompletionSummary" , full_chat_completion_summary_dict )
@@ -564,11 +608,11 @@ def _record_completion_success(transaction, linking_metadata, completion_id, kwa
564608 span_id ,
565609 trace_id ,
566610 response_model ,
567- request_model ,
568611 response_id ,
569612 request_id ,
570613 llm_metadata ,
571614 output_message_list ,
615+ all_token_counts ,
572616 )
573617 except Exception :
574618 _logger .warning (RECORD_EVENTS_FAILURE_LOG_MESSAGE , traceback .format_exception (* sys .exc_info ()))
@@ -579,6 +623,7 @@ def _record_completion_error(transaction, linking_metadata, completion_id, kwarg
579623 trace_id = linking_metadata .get ("trace.id" )
580624 request_message_list = kwargs .get ("messages" , None ) or []
581625 notice_error_attributes = {}
626+
582627 try :
583628 if OPENAI_V1 :
584629 response = getattr (exc , "response" , None )
@@ -643,18 +688,20 @@ def _record_completion_error(transaction, linking_metadata, completion_id, kwarg
643688 output_message_list = []
644689 if "content" in kwargs :
645690 output_message_list = [{"content" : kwargs .get ("content" ), "role" : kwargs .get ("role" )}]
691+
646692 create_chat_completion_message_event (
647693 transaction ,
648694 request_message_list ,
649695 completion_id ,
650696 span_id ,
651697 trace_id ,
652698 kwargs .get ("response.model" ),
653- request_model ,
654699 response_id ,
655700 request_id ,
656701 llm_metadata ,
657702 output_message_list ,
703+ # We do not record token counts in error cases, so set all_token_counts to True so the pipeline tokenizer does not run
704+ all_token_counts = True ,
658705 )
659706 except Exception :
660707 _logger .warning (RECORD_EVENTS_FAILURE_LOG_MESSAGE , traceback .format_exception (* sys .exc_info ()))