@@ -172,6 +172,20 @@ def summarize_messages(llm,history,stored_messages):
172172 return True
173173
174174
def get_total_tokens(model, ai_response):
    """Return the token count reported in *ai_response*'s metadata.

    The location of the usage figures differs per model family, so the
    lookup is dispatched on substrings of *model*.

    Args:
        model: Model identifier string (e.g. contains "gemini", "bedrock",
            "anthropic-claude", or anything else for the default path).
        ai_response: LLM response object exposing a ``response_metadata`` dict.

    Returns:
        The token count as reported by the provider's metadata.
    """
    metadata = ai_response.response_metadata
    if "gemini" in model:
        # NOTE(review): this reads prompt_token_count (input side only),
        # not a combined total — confirm that is the intended figure here.
        return metadata['usage_metadata']['prompt_token_count']
    if "bedrock" in model:
        return metadata['usage']['total_tokens']
    if "anthropic-claude" in model:
        # Anthropic reports input/output separately; sum them for a total.
        usage = metadata['usage']
        return int(usage['input_tokens']) + int(usage['output_tokens'])
    # Default (OpenAI-style) layout.
    return metadata['token_usage']['total_tokens']
187+
188+
175189def clear_chat_history (graph ,session_id ):
176190 history = Neo4jChatMessageHistory (
177191 graph = graph ,
@@ -186,7 +200,7 @@ def clear_chat_history(graph,session_id):
186200
187201def setup_chat (model , graph , session_id , retrieval_query ):
188202 start_time = time .time ()
189- if model in ["diffbot" ,"LLM_MODEL_CONFIG_ollama_llama3" , "LLM_MODEL_CONFIG_anthropic-claude-3-5-sonnet" , "LLM_MODEL_CONFIG_bedrock-claude-3-5-sonnet " ]:
203+ if model in ["diffbot" , "LLM_MODEL_CONFIG_ollama_llama3" ]:
190204 model = "openai-gpt-4o"
191205 llm ,model_name = get_llm (model )
192206 logging .info (f"Model called in chat { model } and model version is { model_name } " )
@@ -216,11 +230,8 @@ def process_documents(docs, question, messages, llm,model):
216230 })
217231 result = get_sources_and_chunks (sources , docs )
218232 content = ai_response .content
219-
220- if "gemini" in model :
221- total_tokens = ai_response .response_metadata ['usage_metadata' ]['prompt_token_count' ]
222- else :
223- total_tokens = ai_response .response_metadata ['token_usage' ]['total_tokens' ]
233+ total_tokens = get_total_tokens (model , ai_response )
234+
224235
225236 predict_time = time .time () - start_time
226237 logging .info (f"Final Response predicted in { predict_time :.2f} seconds" )
0 commit comments