 
 def trace_oci_genai(
     client: "GenerativeAiInferenceClient",
+    estimate_tokens: bool = True,
 ) -> "GenerativeAiInferenceClient":
     """Patch the OCI Generative AI client to trace chat completions.
 
@@ -47,6 +48,9 @@ def trace_oci_genai(
     ----------
     client : GenerativeAiInferenceClient
         The OCI Generative AI client to patch.
+    estimate_tokens : bool, optional
+        Whether to estimate token counts when not provided by the OCI response.
+        Defaults to True. When False, token fields will be None if not available.
 
     Returns
     -------
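
For orientation, a minimal usage sketch of the new flag. The tracer's import path and the service endpoint below are assumptions for illustration, not something this diff confirms:

```python
# Hedged usage sketch: the tracer import path is an assumption.
import oci
from openlayer.lib.integrations.oci_tracer import trace_oci_genai  # assumed path

config = oci.config.from_file()  # standard OCI SDK config loading
client = oci.generative_ai_inference.GenerativeAiInferenceClient(
    config=config,
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
)

# With estimate_tokens=False, token fields stay None when OCI omits usage
# info, instead of falling back to character-based estimates.
traced_client = trace_oci_genai(client, estimate_tokens=False)
```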
@@ -84,6 +88,7 @@ def traced_chat_func(*args, **kwargs):
                 kwargs=kwargs,
                 start_time=start_time,
                 end_time=end_time,
+                estimate_tokens=estimate_tokens,
             )
         else:
             return handle_non_streaming_chat(
@@ -92,6 +97,7 @@ def traced_chat_func(*args, **kwargs):
                 kwargs=kwargs,
                 start_time=start_time,
                 end_time=end_time,
+                estimate_tokens=estimate_tokens,
             )
 
     client.chat = traced_chat_func
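
The wrapping pattern above, reduced to its skeleton (a simplified sketch, with the handler dispatch elided). `estimate_tokens` reaches both handlers through the closure, so no per-call plumbing is needed:

```python
# Skeleton of the monkey-patching pattern used by trace_oci_genai (sketch only).
def _patch_chat_sketch(client, estimate_tokens: bool = True):
    original_chat = client.chat  # keep a handle on the real method

    def traced_chat(*args, **kwargs):
        # Call through to the real client, then dispatch to the streaming
        # or non-streaming handler; estimate_tokens rides along via closure.
        response = original_chat(*args, **kwargs)
        return response  # the real code wraps/annotates this via the handlers

    client.chat = traced_chat
    return client
```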
@@ -104,6 +110,7 @@ def handle_streaming_chat(
     kwargs: Dict[str, Any],
     start_time: float,
     end_time: float,
+    estimate_tokens: bool = True,
 ) -> Iterator[Any]:
     """Handles the chat method when streaming is enabled.
 
@@ -127,6 +134,7 @@ def handle_streaming_chat(
         kwargs=kwargs,
         start_time=start_time,
         end_time=end_time,
+        estimate_tokens=estimate_tokens,
     )
 
 
@@ -136,6 +144,7 @@ def stream_chunks(
     kwargs: Dict[str, Any],
     start_time: float,
     end_time: float,
+    estimate_tokens: bool = True,
 ):
     """Streams the chunks of the completion and traces the completion."""
     collected_output_data = []
@@ -164,15 +173,18 @@ def stream_chunks(
                 usage = chunk.data.usage
                 num_of_prompt_tokens = getattr(usage, "prompt_tokens", 0)
             else:
-                # OCI doesn't provide usage info, estimate from chat_details
-                num_of_prompt_tokens = estimate_prompt_tokens_from_chat_details(chat_details)
+                # OCI doesn't provide usage info, estimate from chat_details if enabled
+                if estimate_tokens:
+                    num_of_prompt_tokens = estimate_prompt_tokens_from_chat_details(chat_details)
+                else:
+                    num_of_prompt_tokens = None
 
             # Store first chunk sample (only for debugging)
             if hasattr(chunk, "data"):
                 chunk_samples.append({"index": 0, "type": "first"})
 
-        # Update completion tokens count
-        if i > 0:
+        # Update completion tokens count (estimation based)
+        if i > 0 and estimate_tokens:
             num_of_completion_tokens = i + 1
 
         # Fast content extraction - optimized for performance
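
Worth calling out: with usage data absent, the streaming path approximates completion tokens as one per streamed chunk (`i + 1` at chunk index `i`). A condensed, runnable illustration with stand-in chunks:

```python
# Condensed view of the loop's completion-token accounting
# (illustrative stand-in data, not real SSE events).
chunks = ["Hel", "lo", " world"]
estimate_tokens = True

num_of_completion_tokens = None
for i, _chunk in enumerate(chunks):
    if i > 0 and estimate_tokens:
        num_of_completion_tokens = i + 1  # ~one token per streamed chunk

print(num_of_completion_tokens)  # 3: the chunk count, not true token usage
```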
@@ -208,8 +220,11 @@ def stream_chunks(
     # chat_details is passed directly as parameter
     model_id = extract_model_id(chat_details)
 
-    # Calculate total tokens
-    total_tokens = (num_of_prompt_tokens or 0) + (num_of_completion_tokens or 0)
+    # Calculate total tokens - handle None values properly
+    if estimate_tokens:
+        total_tokens = (num_of_prompt_tokens or 0) + (num_of_completion_tokens or 0)
+    elif num_of_prompt_tokens is None and num_of_completion_tokens is None:
+        total_tokens = None
+    else:
+        total_tokens = (num_of_prompt_tokens or 0) + (num_of_completion_tokens or 0)
 
     # Simplified metadata - only essential timing info
     metadata = {
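
The None-aware total above reads more easily as a standalone helper; this sketch mirrors the same logic:

```python
from typing import Optional

def none_aware_total(prompt: Optional[int], completion: Optional[int],
                     estimate_tokens: bool) -> Optional[int]:
    """Mirror of the total-token logic above: report None only when
    estimation is off and neither count is known."""
    if estimate_tokens:
        return (prompt or 0) + (completion or 0)
    if prompt is None and completion is None:
        return None
    return (prompt or 0) + (completion or 0)

assert none_aware_total(None, None, estimate_tokens=False) is None
assert none_aware_total(12, None, estimate_tokens=False) == 12
assert none_aware_total(None, None, estimate_tokens=True) == 0
```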
@@ -222,8 +237,8 @@ def stream_chunks(
         output=output_data,
         latency=latency,
         tokens=total_tokens,
-        prompt_tokens=num_of_prompt_tokens or 0,
-        completion_tokens=num_of_completion_tokens or 0,
+        prompt_tokens=num_of_prompt_tokens,
+        completion_tokens=num_of_completion_tokens,
         model=model_id,
         model_parameters=get_model_parameters(chat_details),
         raw_output={
@@ -251,6 +266,7 @@ def handle_non_streaming_chat(
     kwargs: Dict[str, Any],
     start_time: float,
     end_time: float,
+    estimate_tokens: bool = True,
 ) -> Any:
     """Handles the chat method when streaming is disabled.
 
@@ -274,7 +290,7 @@ def handle_non_streaming_chat(
     try:
         # Parse response and extract data
         output_data = parse_non_streaming_output_data(response)
-        tokens_info = extract_tokens_info(response, chat_details)
+        tokens_info = extract_tokens_info(response, chat_details, estimate_tokens)
         model_id = extract_model_id(chat_details)
 
         latency = (end_time - start_time) * 1000
@@ -287,9 +303,9 @@ def handle_non_streaming_chat(
             inputs=extract_inputs_from_chat_details(chat_details),
             output=output_data,
             latency=latency,
-            tokens=tokens_info.get("total_tokens", 0),
-            prompt_tokens=tokens_info.get("input_tokens", 0),
-            completion_tokens=tokens_info.get("output_tokens", 0),
+            tokens=tokens_info.get("total_tokens"),
+            prompt_tokens=tokens_info.get("input_tokens"),
+            completion_tokens=tokens_info.get("output_tokens"),
             model=model_id,
             model_parameters=get_model_parameters(chat_details),
             raw_output=response.data.__dict__ if hasattr(response, "data") else response.__dict__,
@@ -472,10 +488,10 @@ def parse_non_streaming_output_data(response) -> Union[str, Dict[str, Any], None]:
         return str(data)
 
 
-def estimate_prompt_tokens_from_chat_details(chat_details) -> int:
+def estimate_prompt_tokens_from_chat_details(chat_details) -> Optional[int]:
     """Estimate prompt tokens from chat details when OCI doesn't provide usage info."""
     if not chat_details:
-        return 10  # Fallback estimate
+        return None
 
     try:
         input_text = ""
@@ -491,72 +507,107 @@ def estimate_prompt_tokens_from_chat_details(chat_details) -> int:
         return estimated_tokens
     except Exception as e:
         logger.debug("Error estimating prompt tokens: %s", e)
-        return 10  # Fallback estimate
+        return None
 
 
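Both estimators rely on the same rough heuristic of ~4 characters per token, floored at 1. As a standalone sketch:

```python
def rough_token_estimate(text: str) -> int:
    """The ~4-characters-per-token heuristic used above, floored at 1."""
    return max(1, len(text) // 4)

assert rough_token_estimate("") == 1               # floor keeps estimates nonzero
assert rough_token_estimate("Hello, world!") == 3  # 13 chars // 4
```
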
-def extract_tokens_info(response, chat_details=None) -> Dict[str, int]:
-    """Extract token usage information from the response."""
-    tokens_info = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
+def extract_tokens_info(response, chat_details=None, estimate_tokens: bool = True) -> Dict[str, Optional[int]]:
+    """Extract token usage information from the response.
+
+    Handles both CohereChatResponse and GenericChatResponse types from OCI.
+
+    Parameters
+    ----------
+    response : Any
+        The OCI chat response object (CohereChatResponse or GenericChatResponse).
+    chat_details : Any, optional
+        The chat details for token estimation, if needed.
+    estimate_tokens : bool, optional
+        Whether to estimate tokens when not available in the response. Defaults to True.
+
+    Returns
+    -------
+    Dict[str, Optional[int]]
+        Dictionary with token counts. Values can be None if unavailable and estimation is disabled.
+    """
+    tokens_info = {"input_tokens": None, "output_tokens": None, "total_tokens": None}
 
     try:
-        # First, try the standard locations for token usage
+        # Extract token usage from the OCI response (handles both CohereChatResponse and GenericChatResponse)
         if hasattr(response, "data"):
-            # Check multiple possible locations for usage info
-            usage_locations = [
-                getattr(response.data, "usage", None),
-                getattr(getattr(response.data, "chat_response", None), "usage", None),
-            ]
-
-            for usage in usage_locations:
-                if usage is not None:
-                    tokens_info["input_tokens"] = getattr(usage, "prompt_tokens", 0)
-                    tokens_info["output_tokens"] = getattr(usage, "completion_tokens", 0)
-                    tokens_info["total_tokens"] = tokens_info["input_tokens"] + tokens_info["output_tokens"]
-                    logger.debug("Found token usage info: %s", tokens_info)
-                    return tokens_info
-
-            # If no usage info found, estimate based on text length
-            # This is common for OCI which doesn't return token counts
-            logger.debug("No token usage found in response, estimating from text length")
+            usage = None
+
+            # For CohereChatResponse: response.data.usage
+            if hasattr(response.data, "usage"):
+                usage = response.data.usage
+            # For GenericChatResponse: response.data.chat_response.usage
+            elif hasattr(response.data, "chat_response") and hasattr(response.data.chat_response, "usage"):
+                usage = response.data.chat_response.usage
+
+            if usage is not None:
+                # Extract tokens from the usage object
+                prompt_tokens = getattr(usage, "prompt_tokens", None)
+                completion_tokens = getattr(usage, "completion_tokens", None)
+                total_tokens = getattr(usage, "total_tokens", None)
+
+                tokens_info["input_tokens"] = prompt_tokens
+                tokens_info["output_tokens"] = completion_tokens
+                tokens_info["total_tokens"] = total_tokens or (
+                    (prompt_tokens + completion_tokens)
+                    if prompt_tokens is not None and completion_tokens is not None
+                    else None
+                )
+                logger.debug("Found token usage info: %s", tokens_info)
+                return tokens_info
 
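The two lookup paths correspond to the two response payload shapes OCI can return. A sketch with stub objects (stand-ins built from SimpleNamespace, not the real SDK classes):

```python
from types import SimpleNamespace

# Stubs mimicking the two payload shapes handled above.
cohere_like = SimpleNamespace(
    data=SimpleNamespace(
        usage=SimpleNamespace(prompt_tokens=7, completion_tokens=3, total_tokens=10)
    )
)
generic_like = SimpleNamespace(
    data=SimpleNamespace(
        chat_response=SimpleNamespace(
            usage=SimpleNamespace(prompt_tokens=7, completion_tokens=3, total_tokens=10)
        )
    )
)

def find_usage(response):
    """Same lookup order as above: data.usage first, then data.chat_response.usage."""
    if hasattr(response.data, "usage"):
        return response.data.usage
    if hasattr(response.data, "chat_response") and hasattr(
        response.data.chat_response, "usage"
    ):
        return response.data.chat_response.usage
    return None

assert find_usage(cohere_like).prompt_tokens == 7
assert find_usage(generic_like).completion_tokens == 3
```
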
-        # Estimate input tokens from chat_details
-        if chat_details:
+        # If no usage info found, estimate based on text length, only if estimation is enabled
+        if estimate_tokens:
+            logger.debug("No token usage found in response, estimating from text length")
+
+            # Estimate input tokens from chat_details
+            if chat_details:
+                try:
+                    input_text = ""
+                    if hasattr(chat_details, "chat_request") and hasattr(chat_details.chat_request, "messages"):
+                        for msg in chat_details.chat_request.messages:
+                            if hasattr(msg, "content") and msg.content:
+                                for content_item in msg.content:
+                                    if hasattr(content_item, "text"):
+                                        input_text += content_item.text + " "
+
+                    # Rough estimation: ~4 characters per token
+                    estimated_input_tokens = max(1, len(input_text) // 4)
+                    tokens_info["input_tokens"] = estimated_input_tokens
+                except Exception as e:
+                    logger.debug("Error estimating input tokens: %s", e)
+                    tokens_info["input_tokens"] = None
+
+            # Estimate output tokens from response
             try:
-                input_text = ""
-                if hasattr(chat_details, "chat_request") and hasattr(chat_details.chat_request, "messages"):
-                    for msg in chat_details.chat_request.messages:
-                        if hasattr(msg, "content") and msg.content:
-                            for content_item in msg.content:
-                                if hasattr(content_item, "text"):
-                                    input_text += content_item.text + " "
-
-                # Rough estimation: ~4 characters per token
-                estimated_input_tokens = max(1, len(input_text) // 4)
-                tokens_info["input_tokens"] = estimated_input_tokens
+                output_text = parse_non_streaming_output_data(response)
+                if isinstance(output_text, str):
+                    # Rough estimation: ~4 characters per token
+                    estimated_output_tokens = max(1, len(output_text) // 4)
+                    tokens_info["output_tokens"] = estimated_output_tokens
+                else:
+                    tokens_info["output_tokens"] = None
             except Exception as e:
-                logger.debug("Error estimating input tokens: %s", e)
-                tokens_info["input_tokens"] = 10  # Fallback estimate
+                logger.debug("Error estimating output tokens: %s", e)
+                tokens_info["output_tokens"] = None
 
-        # Estimate output tokens from response
-        try:
-            output_text = parse_non_streaming_output_data(response)
-            if isinstance(output_text, str):
-                # Rough estimation: ~4 characters per token
-                estimated_output_tokens = max(1, len(output_text) // 4)
-                tokens_info["output_tokens"] = estimated_output_tokens
+            # Calculate total tokens only if we have estimates
+            if tokens_info["input_tokens"] is not None and tokens_info["output_tokens"] is not None:
+                tokens_info["total_tokens"] = tokens_info["input_tokens"] + tokens_info["output_tokens"]
+            elif tokens_info["input_tokens"] is not None or tokens_info["output_tokens"] is not None:
+                tokens_info["total_tokens"] = (tokens_info["input_tokens"] or 0) + (tokens_info["output_tokens"] or 0)
             else:
-                tokens_info["output_tokens"] = 5  # Fallback estimate
-        except Exception as e:
-            logger.debug("Error estimating output tokens: %s", e)
-            tokens_info["output_tokens"] = 5  # Fallback estimate
-
-        tokens_info["total_tokens"] = tokens_info["input_tokens"] + tokens_info["output_tokens"]
-        logger.debug("Estimated token usage: %s", tokens_info)
+                tokens_info["total_tokens"] = None
+
+            logger.debug("Estimated token usage: %s", tokens_info)
+        else:
+            logger.debug("No token usage found in response and estimation disabled, returning None values")
 
     except Exception as e:
         logger.debug("Error extracting/estimating token info: %s", e)
-        # Provide minimal fallback estimates
-        tokens_info = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
+        # Always return None values on exceptions (no more fallback values)
+        tokens_info = {"input_tokens": None, "output_tokens": None, "total_tokens": None}
 
     return tokens_info
 
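A quick behavioral spot-check of the new contract, using a hypothetical stub response with no usage data anywhere:

```python
from types import SimpleNamespace

# Hypothetical harness: with estimation disabled and no usage info in the
# response, every token field should come back as None.
empty_response = SimpleNamespace(data=SimpleNamespace())

info = extract_tokens_info(empty_response, chat_details=None, estimate_tokens=False)
assert info == {"input_tokens": None, "output_tokens": None, "total_tokens": None}
```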