@@ -136,6 +136,9 @@ def _get_modified_create_chat_completion(self) -> callable:
         def modified_create_chat_completion(*args, **kwargs) -> str:
             stream = kwargs.get("stream", False)

+            # Pop the reserved Openlayer kwargs
+            inference_id = kwargs.pop("inference_id", None)
+
             if not stream:
                 start_time = time.time()
                 response = self.create_chat_completion(*args, **kwargs)
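
The hunk above pops the reserved `inference_id` kwarg before the request is forwarded, since the underlying OpenAI client rejects keyword arguments it does not recognize. A minimal, self-contained sketch of the pop-before-forward pattern (the names here are illustrative stand-ins, not the library's API):

```python
import time


def create_chat_completion(**kwargs):
    """Stand-in for the real OpenAI call, which fails on unknown kwargs."""
    assert "inference_id" not in kwargs, "reserved kwarg must be popped first"
    return {"model": kwargs.get("model"), "choices": []}


def monitored_create(**kwargs):
    # Pop the reserved Openlayer kwarg so it never reaches the OpenAI client.
    inference_id = kwargs.pop("inference_id", None)
    start_time = time.time()
    response = create_chat_completion(**kwargs)
    latency = (time.time() - start_time) * 1000
    return response, inference_id, latency


response, inference_id, latency = monitored_create(
    model="gpt-3.5-turbo", inference_id="row-42"
)
print(inference_id)  # "row-42" -- kept for tracing, never sent to OpenAI
```
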
@@ -169,21 +172,26 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
                         num_input_tokens=response.usage.prompt_tokens,
                         num_output_tokens=response.usage.completion_tokens,
                     )
-
-                    self._add_to_trace(
-                        end_time=end_time,
-                        inputs={
+                    trace_args = {
+                        "end_time": end_time,
+                        "inputs": {
                             "prompt": kwargs["messages"],
                         },
-                        output=output_data,
-                        latency=(end_time - start_time) * 1000,
-                        tokens=response.usage.total_tokens,
-                        cost=cost,
-                        prompt_tokens=response.usage.prompt_tokens,
-                        completion_tokens=response.usage.completion_tokens,
-                        model=response.model,
-                        model_parameters=kwargs.get("model_parameters"),
-                        raw_output=response.model_dump(),
+                        "output": output_data,
+                        "latency": (end_time - start_time) * 1000,
+                        "tokens": response.usage.total_tokens,
+                        "cost": cost,
+                        "prompt_tokens": response.usage.prompt_tokens,
+                        "completion_tokens": response.usage.completion_tokens,
+                        "model": response.model,
+                        "model_parameters": kwargs.get("model_parameters"),
+                        "raw_output": response.model_dump(),
+                    }
+                    if inference_id:
+                        trace_args["id"] = str(inference_id)
+
+                    self._add_to_trace(
+                        **trace_args,
                     )
                 # pylint: disable=broad-except
                 except Exception as e:
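
Switching from keyword arguments to a `trace_args` dict is what makes the `id` field optional: it is attached only when the caller supplied an `inference_id`, and omitted otherwise so the backend can generate its own. A toy sketch of the conditional-kwargs pattern, with a hypothetical `_add_to_trace` stub standing in for the real method:

```python
def _add_to_trace(**kwargs):
    # Hypothetical stand-in: the real method records a step in the trace.
    print(sorted(kwargs))


def record(output, tokens, inference_id=None):
    trace_args = {"output": output, "tokens": tokens}
    if inference_id:
        # Only attach "id" when the caller provided one; otherwise the
        # tracing backend falls back to generating its own.
        trace_args["id"] = str(inference_id)
    _add_to_trace(**trace_args)


record("hi", 12)                         # ['output', 'tokens']
record("hi", 12, inference_id="row-42")  # ['id', 'output', 'tokens']
```
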
@@ -267,28 +275,33 @@ def stream_chunks():
                                     else 0
                                 ),
                             )
-
-                            self._add_to_trace(
-                                end_time=end_time,
-                                inputs={
+                            trace_args = {
+                                "end_time": end_time,
+                                "inputs": {
                                     "prompt": kwargs["messages"],
                                 },
-                                output=output_data,
-                                latency=latency,
-                                tokens=num_of_completion_tokens,
-                                cost=completion_cost,
-                                prompt_tokens=None,
-                                completion_tokens=num_of_completion_tokens,
-                                model=kwargs.get("model"),
-                                model_parameters=kwargs.get("model_parameters"),
-                                raw_output=raw_outputs,
-                                metadata={
+                                "output": output_data,
+                                "latency": latency,
+                                "tokens": num_of_completion_tokens,
+                                "cost": completion_cost,
+                                "prompt_tokens": None,
+                                "completion_tokens": num_of_completion_tokens,
+                                "model": kwargs.get("model"),
+                                "model_parameters": kwargs.get("model_parameters"),
+                                "raw_output": raw_outputs,
+                                "metadata": {
                                     "timeToFirstToken": (
                                         (first_token_time - start_time) * 1000
                                         if first_token_time
                                         else None
                                     )
                                 },
+                            }
+                            if inference_id:
+                                trace_args["id"] = str(inference_id)
+
+                            self._add_to_trace(
+                                **trace_args,
                             )
                         # pylint: disable=broad-except
                         except Exception as e:
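
With both the streaming and non-streaming paths updated, a caller can tag a monitored completion with their own id. A hedged usage sketch (it assumes the Openlayer monitor has already patched the chat-completion method; `client` and the setup step are not part of this diff):

```python
import openai

client = openai.OpenAI()  # assumed to be patched by the Openlayer monitor

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello!"}],
    inference_id="user-123-request-456",  # popped before the API call,
)                                         # stored as the trace row id
```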