Skip to content

Commit 9a02451

Browse files
LLMs for chat (#500)
* Added token usage * Modified token usage
1 parent b3764e3 commit 9a02451

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

backend/src/QA_integration_new.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ def summarize_messages(llm,history,stored_messages):
172172
return True
173173

174174

175+
def get_total_tokens(model, ai_response):
    """Return the total token count (prompt + completion) for an LLM response.

    Providers report token usage under different keys inside
    ``ai_response.response_metadata``, so we dispatch on the model name.

    Args:
        model: Model identifier string (e.g. contains "gemini", "bedrock",
            or "anthropic-claude"); anything else falls through to the
            OpenAI-style ``token_usage`` layout.
        ai_response: LangChain AIMessage-like object exposing a
            ``response_metadata`` dict.

    Returns:
        int: total tokens consumed by the request and response.

    Raises:
        KeyError: if the expected usage keys are absent from
            ``response_metadata`` for the matched provider.
    """
    if "gemini" in model:
        # BUGFIX: original read 'prompt_token_count', which counts only the
        # prompt and excludes completion tokens; 'total_token_count' is the
        # provider-reported total, matching what every other branch returns.
        total_tokens = ai_response.response_metadata['usage_metadata']['total_token_count']
    elif "bedrock" in model:
        total_tokens = ai_response.response_metadata['usage']['total_tokens']
    elif "anthropic-claude" in model:
        # Anthropic reports input/output separately; coerce to int before
        # summing in case the values arrive as strings.
        input_tokens = int(ai_response.response_metadata['usage']['input_tokens'])
        output_tokens = int(ai_response.response_metadata['usage']['output_tokens'])
        total_tokens = input_tokens + output_tokens
    else:
        # OpenAI-compatible layout.
        total_tokens = ai_response.response_metadata['token_usage']['total_tokens']
    return total_tokens
187+
188+
175189
def clear_chat_history(graph,session_id):
176190
history = Neo4jChatMessageHistory(
177191
graph=graph,
@@ -186,7 +200,7 @@ def clear_chat_history(graph,session_id):
186200

187201
def setup_chat(model, graph, session_id, retrieval_query):
188202
start_time = time.time()
189-
if model in ["diffbot","LLM_MODEL_CONFIG_ollama_llama3","LLM_MODEL_CONFIG_anthropic-claude-3-5-sonnet","LLM_MODEL_CONFIG_bedrock-claude-3-5-sonnet"]:
203+
if model in ["diffbot", "LLM_MODEL_CONFIG_ollama_llama3"]:
190204
model = "openai-gpt-4o"
191205
llm,model_name = get_llm(model)
192206
logging.info(f"Model called in chat {model} and model version is {model_name}")
@@ -216,11 +230,8 @@ def process_documents(docs, question, messages, llm,model):
216230
})
217231
result = get_sources_and_chunks(sources, docs)
218232
content = ai_response.content
219-
220-
if "gemini" in model:
221-
total_tokens = ai_response.response_metadata['usage_metadata']['prompt_token_count']
222-
else:
223-
total_tokens = ai_response.response_metadata['token_usage']['total_tokens']
233+
total_tokens = get_total_tokens(model, ai_response)
234+
224235

225236
predict_time = time.time() - start_time
226237
logging.info(f"Final Response predicted in {predict_time:.2f} seconds")

0 commit comments

Comments
 (0)