@@ -71,20 +71,26 @@ def traced_chat_func(*args, **kwargs):
         chat_request = chat_details.chat_request
         stream = getattr(chat_request, 'is_stream', False)
 
-        # Call the original OCI client chat method
+        # Measure timing around the actual OCI call
+        start_time = time.time()
         response = chat_func(*args, **kwargs)
+        end_time = time.time()
 
         if stream:
             return handle_streaming_chat(
                 response=response,
                 chat_details=chat_details,
                 kwargs=kwargs,
+                start_time=start_time,
+                end_time=end_time,
             )
         else:
             return handle_non_streaming_chat(
                 response=response,
                 chat_details=chat_details,
                 kwargs=kwargs,
+                start_time=start_time,
+                end_time=end_time,
             )
 
     client.chat = traced_chat_func
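An aside on the pattern in this hunk: timing now brackets only the underlying `chat` call before dispatching to a handler. Below is a minimal, self-contained sketch of that wrap-and-time approach; `FakeClient` and `wrap_with_timing` are hypothetical names for illustration, not the tracer's actual module.

```python
import time
from functools import wraps


def wrap_with_timing(func, on_complete):
    """Wrap func so start/end timestamps bracket only the real call."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        response = func(*args, **kwargs)
        end_time = time.time()
        # Hand both timestamps to a handler, as the diff above does.
        on_complete(response, start_time, end_time)
        return response

    return wrapper


class FakeClient:  # hypothetical stand-in for the OCI client
    def chat(self, chat_details):
        time.sleep(0.05)  # simulate the network round trip
        return {"text": "ok"}


client = FakeClient()
client.chat = wrap_with_timing(
    client.chat,
    lambda resp, t0, t1: print(f"latency: {(t1 - t0) * 1000:.1f} ms"),
)
client.chat({"messages": []})
```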
@@ -95,6 +101,8 @@ def handle_streaming_chat(
     response: Iterator[Any],
     chat_details: Any,
     kwargs: Dict[str, Any],
+    start_time: float,
+    end_time: float,
 ) -> Iterator[Any]:
     """Handles the chat method when streaming is enabled.
 
@@ -116,19 +124,24 @@ def handle_streaming_chat(
         chunks=response.data.events(),
         chat_details=chat_details,
         kwargs=kwargs,
+        start_time=start_time,
+        end_time=end_time,
     )
 
 
 def stream_chunks(
     chunks: Iterator[Any],
     chat_details: Any,
     kwargs: Dict[str, Any],
+    start_time: float,
+    end_time: float,
 ):
     """Streams the chunks of the completion and traces the completion."""
     collected_output_data = []
     collected_function_calls = []
     raw_outputs = []
-    start_time = time.time()
+    # Use the timing from the actual OCI call (passed as parameter)
+    # start_time is already provided
 
     # For grouping raw outputs into a more organized structure
     streaming_stats = {
@@ -187,6 +200,9 @@ def stream_chunks(
         if hasattr(chunk, 'data') and hasattr(chunk.data, 'usage'):
             usage = chunk.data.usage
            num_of_prompt_tokens = getattr(usage, 'prompt_tokens', 0)
+        else:
+            # OCI doesn't provide usage info, estimate from chat_details
+            num_of_prompt_tokens = estimate_prompt_tokens_from_chat_details(chat_details)
 
         if i > 0:
             num_of_completion_tokens = i + 1
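The `else` branch added here only fires when a streamed chunk carries no usage block. A small sketch of that hasattr-guarded fallback, using stand-in chunk objects built with `SimpleNamespace` (hypothetical names, for illustration only):

```python
from types import SimpleNamespace


def prompt_tokens_or_estimate(chunk, estimate):
    """Prefer server-reported usage; otherwise fall back to an estimate."""
    if hasattr(chunk, 'data') and hasattr(chunk.data, 'usage'):
        return getattr(chunk.data.usage, 'prompt_tokens', 0)
    return estimate  # chunk carries no usage info


with_usage = SimpleNamespace(
    data=SimpleNamespace(usage=SimpleNamespace(prompt_tokens=42))
)
without_usage = SimpleNamespace(data=SimpleNamespace())

print(prompt_tokens_or_estimate(with_usage, estimate=10))     # 42
print(prompt_tokens_or_estimate(without_usage, estimate=10))  # 10
```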
@@ -343,6 +359,8 @@ def handle_non_streaming_chat(
     response: Any,
     chat_details: Any,
     kwargs: Dict[str, Any],
+    start_time: float,
+    end_time: float,
 ) -> Any:
     """Handles the chat method when streaming is disabled.
 
@@ -360,17 +378,15 @@ def handle_non_streaming_chat(
     Any
         The chat completion response.
     """
-    start_time = time.time()
-    # The response is now passed directly, no need to call chat_func here
-    end_time = time.time()  # This will be adjusted after processing
+    # Use the timing from the actual OCI call (passed as parameters)
+    # start_time and end_time are already provided
 
     try:
         # Parse response and extract data
         output_data = parse_non_streaming_output_data(response)
         tokens_info = extract_tokens_info(response, chat_details)
         model_id = extract_model_id(chat_details)
 
-        end_time = time.time()
         latency = (end_time - start_time) * 1000
 
         # Extract additional metadata
@@ -569,6 +585,28 @@ def parse_non_streaming_output_data(response) -> Union[str, Dict[str, Any], None
     return str(data)
 
 
+def estimate_prompt_tokens_from_chat_details(chat_details) -> int:
+    """Estimate prompt tokens from chat details when OCI doesn't provide usage info."""
+    if not chat_details:
+        return 10  # Fallback estimate
+
+    try:
+        input_text = ""
+        if hasattr(chat_details, 'chat_request') and hasattr(chat_details.chat_request, 'messages'):
+            for msg in chat_details.chat_request.messages:
+                if hasattr(msg, 'content') and msg.content:
+                    for content_item in msg.content:
+                        if hasattr(content_item, 'text'):
+                            input_text += content_item.text + " "
+
+        # Rough estimation: ~4 characters per token
+        estimated_tokens = max(1, len(input_text) // 4)
+        return estimated_tokens
+    except Exception as e:
+        logger.debug("Error estimating prompt tokens: %s", e)
+        return 10  # Fallback estimate
+
+
 def extract_tokens_info(response, chat_details=None) -> Dict[str, int]:
     """Extract token usage information from the response."""
     tokens_info = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
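As a quick sanity check on the ~4 characters per token heuristic in `estimate_prompt_tokens_from_chat_details`, here is the same arithmetic run against a stand-in `chat_details` object (hypothetical, shaped with `SimpleNamespace` to satisfy the attribute checks above):

```python
from types import SimpleNamespace

# Shape a fake chat_details the way the hasattr checks above expect.
text = "What is the capital of France?"  # 30 characters
message = SimpleNamespace(content=[SimpleNamespace(text=text)])
chat_details = SimpleNamespace(
    chat_request=SimpleNamespace(messages=[message])
)

input_text = ""
for msg in chat_details.chat_request.messages:
    for item in msg.content:
        input_text += item.text + " "

# 31 characters (trailing space included) // 4 -> 7 estimated tokens
print(max(1, len(input_text) // 4))  # 7
```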