@@ -165,12 +165,12 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
165165 else :
166166 output_data = None
167167 cost = self .get_cost_estimate (
168- model = kwargs . get ( " model" ) ,
168+ model = response . model ,
169169 num_input_tokens = response .usage .prompt_tokens ,
170170 num_output_tokens = response .usage .completion_tokens ,
171171 )
172172
173- tracer . add_openai_chat_completion_step_to_trace (
173+ self . _add_to_trace (
174174 end_time = end_time ,
175175 inputs = {
176176 "prompt" : kwargs ["messages" ],
@@ -181,10 +181,9 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
181181 cost = cost ,
182182 prompt_tokens = response .usage .prompt_tokens ,
183183 completion_tokens = response .usage .completion_tokens ,
184- model = kwargs . get ( " model" ) ,
184+ model = response . model ,
185185 model_parameters = kwargs .get ("model_parameters" ),
186186 raw_output = response .model_dump (),
187- provider = "OpenAI" ,
188187 )
189188 # pylint: disable=broad-except
190189 except Exception as e :
@@ -269,7 +268,7 @@ def stream_chunks():
269268 ),
270269 )
271270
272- tracer . add_openai_chat_completion_step_to_trace (
271+ self . _add_to_trace (
273272 end_time = end_time ,
274273 inputs = {
275274 "prompt" : kwargs ["messages" ],
@@ -290,7 +289,6 @@ def stream_chunks():
290289 else None
291290 )
292291 },
293- provider = "OpenAI" ,
294292 )
295293 # pylint: disable=broad-except
296294 except Exception as e :
@@ -318,12 +316,12 @@ def modified_create_completion(*args, **kwargs):
318316 output_data = choices [0 ].text .strip ()
319317 num_of_tokens = int (response .usage .total_tokens / len (prompts ))
320318 cost = self .get_cost_estimate (
321- model = kwargs . get ( " model" ) ,
319+ model = response . model ,
322320 num_input_tokens = response .usage .prompt_tokens ,
323321 num_output_tokens = response .usage .completion_tokens ,
324322 )
325323
326- tracer . add_openai_chat_completion_step_to_trace (
324+ self . _add_to_trace (
327325 end_time = end_time ,
328326 inputs = {
329327 "prompt" : [{"role" : "user" , "content" : input_data }],
@@ -334,10 +332,9 @@ def modified_create_completion(*args, **kwargs):
334332 cost = cost ,
335333 prompt_tokens = response .usage .prompt_tokens ,
336334 completion_tokens = response .usage .completion_tokens ,
337- model = kwargs . get ( " model" ) ,
335+ model = response . model ,
338336 model_parameters = kwargs .get ("model_parameters" ),
339337 raw_output = response .model_dump (),
340- provider = "OpenAI" ,
341338 )
342339 # pylint: disable=broad-except
343340 except Exception as e :
@@ -347,6 +344,13 @@ def modified_create_completion(*args, **kwargs):
347344
348345 return modified_create_completion
349346
347+ def _add_to_trace (self , ** kwargs ) -> None :
348+ """Add a step to the trace."""
349+ tracer .add_openai_chat_completion_step_to_trace (
350+ ** kwargs ,
351+ provider = "OpenAI" ,
352+ )
353+
350354 @staticmethod
351355 def _split_list (lst : List , n_parts : int ) -> List [List ]:
352356 """Split a list into n_parts."""
@@ -486,3 +490,32 @@ def thread_messages_to_prompt(
486490 }
487491 )
488492 return prompt
493+
494+
class AzureOpenAIMonitor(OpenAIMonitor):
    """Monitor for Azure OpenAI calls.

    Behaves like ``OpenAIMonitor`` except that cost estimates use the
    Azure-specific per-token price table and trace steps are tagged with
    the Azure OpenAI provider name.
    """

    def __init__(
        self,
        client=None,
    ) -> None:
        super().__init__(client)

    @staticmethod
    def get_cost_estimate(
        num_input_tokens: int, num_output_tokens: int, model: str
    ) -> "float | None":
        """Return the estimated cost (USD) for a call to *model*.

        Returns None when the model has no entry in the Azure price table.
        (String annotation: the original's ``-> float`` was wrong — this
        can return None — and the string form needs no typing import.)
        """
        # Single lookup instead of `in` + subscript (EAFP-friendly .get).
        cost_per_token = constants.AZURE_OPENAI_COST_PER_TOKEN.get(model)
        if cost_per_token is None:
            return None
        return (
            cost_per_token["input"] * num_input_tokens
            + cost_per_token["output"] * num_output_tokens
        )

    def _add_to_trace(self, **kwargs) -> None:
        """Add an Azure OpenAI chat completion step to the trace."""
        tracer.add_openai_chat_completion_step_to_trace(
            **kwargs,
            name="Azure OpenAI Chat Completion",
            provider="Azure OpenAI",
        )
0 commit comments