
Commit 52579cb

Simplify streaming (#124)
1 parent 0843a2c commit 52579cb

File tree

5 files changed: +74 −145 lines changed


example-apps/chatbot-rag-app/api/app.py

Lines changed: 2 additions & 14 deletions

@@ -1,9 +1,7 @@
 from flask import Flask, jsonify, request, Response
 from flask_cors import CORS
-from queue import Queue
 from uuid import uuid4
-from chat import ask_question, parse_stream_message
-import threading
+from chat import ask_question
 import os
 import sys
 
@@ -23,18 +21,8 @@ def api_chat():
     if question is None:
         return jsonify({"msg": "Missing question from request JSON"}), 400
 
-    stream_queue = Queue()
     session_id = request.args.get("session_id", str(uuid4()))
-
-    print("Chat session ID: ", session_id)
-
-    threading.Thread(
-        target=ask_question, args=(question, stream_queue, session_id)
-    ).start()
-
-    return Response(
-        parse_stream_message(session_id, stream_queue), mimetype="text/event-stream"
-    )
+    return Response(ask_question(question, session_id), mimetype="text/event-stream")
 
 
 @app.cli.command()
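
The endpoint now returns the ask_question() generator directly as a text/event-stream response. A minimal sketch of consuming that stream from Python, assuming the Flask route is POST /api/chat on localhost:4000 (neither the route decorator nor the port is part of this diff, so both are assumptions):

    import requests

    resp = requests.post(
        "http://localhost:4000/api/chat",   # route and port are assumptions
        json={"question": "What is the meaning of life?"},
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        # each SSE event arrives as a "data: ..." line; blank keep-alive lines are skipped
        if line and line.startswith("data: "):
            payload = line[len("data: "):]
            if payload == "[DONE]":
                break
            print(payload)
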
example-apps/chatbot-rag-app/api/chat.py

Lines changed: 28 additions & 121 deletions

@@ -1,18 +1,10 @@
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.chains import ConversationalRetrievalChain
-from langchain.prompts.chat import (
-    HumanMessagePromptTemplate,
-    SystemMessagePromptTemplate,
-    ChatPromptTemplate,
-)
-from langchain.prompts.prompt import PromptTemplate
 from langchain.vectorstores import ElasticsearchStore
-from queue import Queue
 from llm_integrations import get_llm
 from elasticsearch_client import (
     elasticsearch_client,
     get_elasticsearch_chat_message_history,
 )
+from flask import render_template, stream_with_context, current_app
 import json
 import os
 
@@ -21,135 +13,50 @@
     "ES_INDEX_CHAT_HISTORY", "workplace-app-docs-chat-history"
 )
 ELSER_MODEL = os.getenv("ELSER_MODEL", ".elser_model_2")
-POISON_MESSAGE = "~~~END~~~"
 SESSION_ID_TAG = "[SESSION_ID]"
 SOURCE_TAG = "[SOURCE]"
 DONE_TAG = "[DONE]"
 
-
-class QueueCallbackHandler(BaseCallbackHandler):
-    def __init__(
-        self,
-        queue: Queue,
-    ):
-        self.queue = queue
-        self.in_human_prompt = True
-
-    def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs):
-        if len(documents) > 0:
-            for doc in documents:
-                source = {
-                    "name": doc.metadata["name"],
-                    "page_content": doc.page_content,
-                    "url": doc.metadata["url"],
-                    "icon": doc.metadata["category"],
-                    "updated_at": doc.metadata.get("updated_at", None),
-                }
-                self.queue.put(f"{SOURCE_TAG} {json.dumps(source)}")
-
-    def on_llm_new_token(self, token, **kwargs):
-        if not self.in_human_prompt:
-            self.queue.put(token)
-
-    def on_llm_start(
-        self,
-        serialized,
-        prompts,
-        *,
-        run_id,
-        parent_run_id=None,
-        tags=None,
-        metadata=None,
-        **kwargs,
-    ):
-        self.in_human_prompt = prompts[0].startswith("Human:")
-
-    def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs):
-        if not self.in_human_prompt:
-            self.queue.put(POISON_MESSAGE)
-
-
 store = ElasticsearchStore(
     es_connection=elasticsearch_client,
     index_name=INDEX,
     strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(model_id=ELSER_MODEL),
 )
 
-general_system_template = """
-Human: Use the following passages to answer the user's question.
-Each passage has a SOURCE which is the title of the document. When answering, give the source name of the passages you are answering from, put them in a comma seperated list, prefixed at the start with SOURCES: $sources then print an empty line.
-
-Example:
-
-Question: What is the meaning of life?
-Response:
-The meaning of life is 42. \n
 
-SOURCES: Hitchhiker's Guide to the Galaxy \n
-
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
-----
-{context}
-----
+@stream_with_context
+def ask_question(question, session_id):
+    yield f"data: {SESSION_ID_TAG} {session_id}\n\n"
+    current_app.logger.debug("Chat session ID: %s", session_id)
 
-"""
-general_user_template = "Question: {question}"
-qa_prompt = ChatPromptTemplate.from_messages(
-    [
-        SystemMessagePromptTemplate.from_template(general_system_template),
-        HumanMessagePromptTemplate.from_template(general_user_template),
-    ]
-)
+    chat_history = get_elasticsearch_chat_message_history(
+        INDEX_CHAT_HISTORY, session_id
+    )
 
-document_prompt = PromptTemplate(
-    input_variables=["page_content", "name"],
-    template="""
----
-NAME: "{name}"
-PASSAGE:
-{page_content}
----
-""",
-)
+    if len(chat_history.messages) > 0:
+        # create a condensed question
+        condense_question_prompt = render_template(
+            'condense_question_prompt.txt', question=question,
+            chat_history=chat_history.messages)
+        question = get_llm().invoke(condense_question_prompt).content
 
-retriever = store.as_retriever()
-llm = get_llm()
-chat = ConversationalRetrievalChain.from_llm(
-    llm=llm,
-    retriever=store.as_retriever(),
-    return_source_documents=True,
-    combine_docs_chain_kwargs={"prompt": qa_prompt, "document_prompt": document_prompt},
-    verbose=True,
-)
+    current_app.logger.debug('Question: %s', question)
 
+    docs = store.as_retriever().invoke(question)
+    for doc in docs:
+        doc_source = {**doc.metadata, 'page_content': doc.page_content}
+        current_app.logger.debug('Retrieved document passage from: %s', doc.metadata['name'])
+        yield f'data: {SOURCE_TAG} {json.dumps(doc_source)}\n\n'
 
-def parse_stream_message(session_id, queue: Queue):
-    yield f"data: {SESSION_ID_TAG} {session_id}\n\n"
+    qa_prompt = render_template('rag_prompt.txt', question=question, docs=docs)
 
-    message = None
-    break_out_flag = False
-    while True:
-        message = queue.get()
-        for line in message.splitlines():
-            if line == POISON_MESSAGE:
-                break_out_flag = True
-                break
-            yield f"data: {line}\n\n"
-        if break_out_flag:
-            break
+    answer = ''
+    for chunk in get_llm().stream(qa_prompt):
+        yield f'data: {chunk.content}\n\n'
+        answer += chunk.content
 
     yield f"data: {DONE_TAG}\n\n"
+    current_app.logger.debug('Answer: %s', answer)
 
-
-def ask_question(question, queue, session_id):
-    chat_history = get_elasticsearch_chat_message_history(
-        INDEX_CHAT_HISTORY, session_id
-    )
-    result = chat(
-        {"question": question, "chat_history": chat_history.messages},
-        callbacks=[QueueCallbackHandler(queue)],
-    )
-
-    chat_history.add_user_message(result["question"])
-    chat_history.add_ai_message(result["answer"])
+    chat_history.add_user_message(question)
+    chat_history.add_ai_message(answer)
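
With the queue and poison message gone, the client contract is just the three tags yielded above. A minimal sketch of how a consumer might interpret each SSE data payload; the handle_event helper is hypothetical, only the tag values come from this diff:

    import json

    SESSION_ID_TAG = "[SESSION_ID]"
    SOURCE_TAG = "[SOURCE]"
    DONE_TAG = "[DONE]"

    def handle_event(data, state):
        """Update state from one SSE payload; return False once the stream is done."""
        if data.startswith(SESSION_ID_TAG):
            state["session_id"] = data[len(SESSION_ID_TAG):].strip()
        elif data.startswith(SOURCE_TAG):
            # source documents arrive as JSON after the [SOURCE] tag
            state.setdefault("sources", []).append(json.loads(data[len(SOURCE_TAG):].strip()))
        elif data == DONE_TAG:
            return False
        else:
            # everything else is a streamed answer chunk
            state["answer"] = state.get("answer", "") + data
        return True
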

example-apps/chatbot-rag-app/api/llm_integrations.py

Lines changed: 10 additions & 10 deletions

@@ -5,15 +5,15 @@
 
 LLM_TYPE = os.getenv("LLM_TYPE", "openai")
 
-def init_openai_chat():
+def init_openai_chat(temperature):
     OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-    return ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True, temperature=0.2)
-def init_vertex_chat():
+    return ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True, temperature=temperature)
+def init_vertex_chat(temperature):
     VERTEX_PROJECT_ID = os.getenv("VERTEX_PROJECT_ID")
     VERTEX_REGION = os.getenv("VERTEX_REGION", "us-central1")
     vertexai.init(project=VERTEX_PROJECT_ID, location=VERTEX_REGION)
-    return ChatVertexAI(streaming=True, temperature=0.2)
-def init_azure_chat():
+    return ChatVertexAI(streaming=True, temperature=temperature)
+def init_azure_chat(temperature):
     OPENAI_VERSION=os.getenv("OPENAI_VERSION", "2023-05-15")
     BASE_URL=os.getenv("OPENAI_BASE_URL")
     OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
@@ -24,8 +24,8 @@ def init_azure_chat():
         openai_api_version=OPENAI_VERSION,
         openai_api_key=OPENAI_API_KEY,
         streaming=True,
-        temperature=0.2)
-def init_bedrock():
+        temperature=temperature)
+def init_bedrock(temperature):
     AWS_ACCESS_KEY=os.getenv("AWS_ACCESS_KEY")
     AWS_SECRET_KEY=os.getenv("AWS_SECRET_KEY")
     AWS_REGION=os.getenv("AWS_REGION")
@@ -35,7 +35,7 @@ def init_bedrock():
         client=BEDROCK_CLIENT,
         model_id=AWS_MODEL_ID,
         streaming=True,
-        model_kwargs={"temperature":0.2})
+        model_kwargs={"temperature":temperature})
 
 MAP_LLM_TYPE_TO_CHAT_MODEL = {
     "azure": init_azure_chat,
@@ -44,8 +44,8 @@ def init_bedrock():
     "vertex": init_vertex_chat,
 }
 
-def get_llm():
+def get_llm(temperature=0.2):
     if not LLM_TYPE in MAP_LLM_TYPE_TO_CHAT_MODEL:
         raise Exception("LLM type not found. Please set LLM_TYPE to one of: " + ", ".join(MAP_LLM_TYPE_TO_CHAT_MODEL.keys()) + ".")
 
-    return MAP_LLM_TYPE_TO_CHAT_MODEL[LLM_TYPE]()
+    return MAP_LLM_TYPE_TO_CHAT_MODEL[LLM_TYPE](temperature=temperature)
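
get_llm() keeps its previous default of 0.2, so existing call sites behave the same, while callers can now override the temperature per call. A quick usage sketch (LLM_TYPE and the provider credentials still have to be configured via environment variables):

    from llm_integrations import get_llm

    default_llm = get_llm()                  # temperature defaults to 0.2
    creative_llm = get_llm(temperature=0.9)  # per-call override
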
condense_question_prompt.txt (new file)

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
+
+Chat history:
+{% for dialogue_turn in chat_history -%}
+{% if dialogue_turn.type == 'human' %}Question: {{ dialogue_turn.content }}{% elif dialogue_turn.type == 'ai' %}Response: {{ dialogue_turn.content }}{% endif %}
+{% endfor -%}
+Follow Up Question: {{ question }}
+Standalone question:
rag_prompt.txt (new file)

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+Use the following passages to answer the user's question.
+Each passage has a NAME which is the title of the document. When answering, give the source name of the passages you are answering from at the end. Put them in a comma separated list, prefixed with SOURCES:.
+
+Example:
+
+Question: What is the meaning of life?
+Response:
+The meaning of life is 42.
+
+SOURCES: Hitchhiker's Guide to the Galaxy
+
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+----
+
+{% for doc in docs -%}
+---
+NAME: {{ doc.metadata.name }}
+PASSAGE:
+{{ doc.page_content }}
+---
+
+{% endfor -%}
+----
+Question: {{ question }}
+Response:
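
Both prompt files are Jinja templates rendered with Flask's render_template() in chat.py. A minimal sketch of rendering rag_prompt.txt outside a request context, e.g. to eyeball the final prompt; the templates path and the SimpleNamespace stand-in for a retrieved document are assumptions, not part of this commit:

    from types import SimpleNamespace
    from jinja2 import Environment, FileSystemLoader

    # path is an assumption; point it at wherever the templates live
    env = Environment(loader=FileSystemLoader("example-apps/chatbot-rag-app/api/templates"))
    doc = SimpleNamespace(
        metadata={"name": "Hitchhiker's Guide to the Galaxy"},
        page_content="The answer to the ultimate question is 42.",
    )
    prompt = env.get_template("rag_prompt.txt").render(
        question="What is the meaning of life?",
        docs=[doc],
    )
    print(prompt)
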
