1 | | -from langchain.callbacks.base import BaseCallbackHandler |
2 | | -from langchain.chains import ConversationalRetrievalChain |
3 | | -from langchain.prompts.chat import ( |
4 | | - HumanMessagePromptTemplate, |
5 | | - SystemMessagePromptTemplate, |
6 | | - ChatPromptTemplate, |
7 | | -) |
8 | | -from langchain.prompts.prompt import PromptTemplate |
9 | 1 | from langchain.vectorstores import ElasticsearchStore |
10 | | -from queue import Queue |
11 | 2 | from llm_integrations import get_llm |
12 | 3 | from elasticsearch_client import ( |
13 | 4 | elasticsearch_client, |
14 | 5 | get_elasticsearch_chat_message_history, |
15 | 6 | ) |
| 7 | +from flask import render_template, stream_with_context, current_app |
16 | 8 | import json |
17 | 9 | import os |
18 | 10 |
19 | 11 | INDEX = os.getenv("ES_INDEX", "workplace-app-docs") |
20 | 12 | INDEX_CHAT_HISTORY = os.getenv( |
21 | 13 | "ES_INDEX_CHAT_HISTORY", "workplace-app-docs-chat-history" |
22 | 14 | ) |
23 | 15 | ELSER_MODEL = os.getenv("ELSER_MODEL", ".elser_model_2") |
24 | | -POISON_MESSAGE = "~~~END~~~" |
25 | 16 | SESSION_ID_TAG = "[SESSION_ID]" |
26 | 17 | SOURCE_TAG = "[SOURCE]" |
27 | 18 | DONE_TAG = "[DONE]" |
28 | 19 |
29 | | - |
30 | | -class QueueCallbackHandler(BaseCallbackHandler): |
31 | | - def __init__( |
32 | | - self, |
33 | | - queue: Queue, |
34 | | - ): |
35 | | - self.queue = queue |
36 | | - self.in_human_prompt = True |
37 | | - |
38 | | - def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs): |
39 | | - if len(documents) > 0: |
40 | | - for doc in documents: |
41 | | - source = { |
42 | | - "name": doc.metadata["name"], |
43 | | - "page_content": doc.page_content, |
44 | | - "url": doc.metadata["url"], |
45 | | - "icon": doc.metadata["category"], |
46 | | - "updated_at": doc.metadata.get("updated_at", None), |
47 | | - } |
48 | | - self.queue.put(f"{SOURCE_TAG} {json.dumps(source)}") |
49 | | - |
50 | | - def on_llm_new_token(self, token, **kwargs): |
51 | | - if not self.in_human_prompt: |
52 | | - self.queue.put(token) |
53 | | - |
54 | | - def on_llm_start( |
55 | | - self, |
56 | | - serialized, |
57 | | - prompts, |
58 | | - *, |
59 | | - run_id, |
60 | | - parent_run_id=None, |
61 | | - tags=None, |
62 | | - metadata=None, |
63 | | - **kwargs, |
64 | | - ): |
65 | | - self.in_human_prompt = prompts[0].startswith("Human:") |
66 | | - |
67 | | - def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs): |
68 | | - if not self.in_human_prompt: |
69 | | - self.queue.put(POISON_MESSAGE) |
70 | | - |
71 | | - |
72 | 20 | store = ElasticsearchStore( |
73 | 21 | es_connection=elasticsearch_client, |
74 | 22 | index_name=INDEX, |
75 | 23 | strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(model_id=ELSER_MODEL), |
76 | 24 | ) |
77 | 25 |
78 | | -general_system_template = """ |
79 | | -Human: Use the following passages to answer the user's question. |
80 | | -Each passage has a SOURCE which is the title of the document. When answering, give the source name of the passages you are answering from, put them in a comma seperated list, prefixed at the start with SOURCES: $sources then print an empty line. |
81 | | -
82 | | -Example: |
83 | | -
84 | | -Question: What is the meaning of life? |
85 | | -Response: |
86 | | -The meaning of life is 42. \n |
87 | 26 |
88 | | -SOURCES: Hitchhiker's Guide to the Galaxy \n |
89 | | -
90 | | -If you don't know the answer, just say that you don't know, don't try to make up an answer. |
91 | | -
92 | | ----- |
93 | | -{context} |
94 | | ----- |
| 27 | +@stream_with_context |
| 28 | +def ask_question(question, session_id): |
| 29 | + yield f"data: {SESSION_ID_TAG} {session_id}\n\n" |
| 30 | + current_app.logger.debug("Chat session ID: %s", session_id) |
95 | 31 |
96 | | -""" |
97 | | -general_user_template = "Question: {question}" |
98 | | -qa_prompt = ChatPromptTemplate.from_messages( |
99 | | - [ |
100 | | - SystemMessagePromptTemplate.from_template(general_system_template), |
101 | | - HumanMessagePromptTemplate.from_template(general_user_template), |
102 | | - ] |
103 | | -) |
| 32 | + chat_history = get_elasticsearch_chat_message_history( |
| 33 | + INDEX_CHAT_HISTORY, session_id |
| 34 | + ) |
104 | 35 |
105 | | -document_prompt = PromptTemplate( |
106 | | - input_variables=["page_content", "name"], |
107 | | - template=""" |
108 | | ---- |
109 | | -NAME: "{name}" |
110 | | -PASSAGE: |
111 | | -{page_content} |
112 | | ---- |
113 | | -""", |
114 | | -) |
| 36 | + if len(chat_history.messages) > 0: |
| 37 | + # create a condensed question |
| 38 | + condense_question_prompt = render_template( |
| 39 | + 'condense_question_prompt.txt', question=question, |
| 40 | + chat_history=chat_history.messages) |
| 41 | + question = get_llm().invoke(condense_question_prompt).content |
115 | 42 |
116 | | -retriever = store.as_retriever() |
117 | | -llm = get_llm() |
118 | | -chat = ConversationalRetrievalChain.from_llm( |
119 | | - llm=llm, |
120 | | - retriever=store.as_retriever(), |
121 | | - return_source_documents=True, |
122 | | - combine_docs_chain_kwargs={"prompt": qa_prompt, "document_prompt": document_prompt}, |
123 | | - verbose=True, |
124 | | -) |
| 43 | + current_app.logger.debug('Question: %s', question) |
125 | 44 |
| 45 | + docs = store.as_retriever().invoke(question) |
| 46 | + for doc in docs: |
| 47 | + doc_source = {**doc.metadata, 'page_content': doc.page_content} |
| 48 | + current_app.logger.debug('Retrieved document passage from: %s', doc.metadata['name']) |
| 49 | + yield f'data: {SOURCE_TAG} {json.dumps(doc_source)}\n\n' |
126 | 50 |
127 | | -def parse_stream_message(session_id, queue: Queue): |
128 | | - yield f"data: {SESSION_ID_TAG} {session_id}\n\n" |
| 51 | + qa_prompt = render_template('rag_prompt.txt', question=question, docs=docs) |
129 | 52 |
130 | | - message = None |
131 | | - break_out_flag = False |
132 | | - while True: |
133 | | - message = queue.get() |
134 | | - for line in message.splitlines(): |
135 | | - if line == POISON_MESSAGE: |
136 | | - break_out_flag = True |
137 | | - break |
138 | | - yield f"data: {line}\n\n" |
139 | | - if break_out_flag: |
140 | | - break |
| 53 | + answer = '' |
| 54 | + for chunk in get_llm().stream(qa_prompt): |
| 55 | + yield f'data: {chunk.content}\n\n' |
| 56 | + answer += chunk.content |
141 | 57 |
142 | 58 | yield f"data: {DONE_TAG}\n\n" |
| 59 | + current_app.logger.debug('Answer: %s', answer) |
143 | 60 |
144 | | - |
145 | | -def ask_question(question, queue, session_id): |
146 | | - chat_history = get_elasticsearch_chat_message_history( |
147 | | - INDEX_CHAT_HISTORY, session_id |
148 | | - ) |
149 | | - result = chat( |
150 | | - {"question": question, "chat_history": chat_history.messages}, |
151 | | - callbacks=[QueueCallbackHandler(queue)], |
152 | | - ) |
153 | | - |
154 | | - chat_history.add_user_message(result["question"]) |
155 | | - chat_history.add_ai_message(result["answer"]) |
| 61 | + chat_history.add_user_message(question) |
| 62 | + chat_history.add_ai_message(answer) |
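
For context, a minimal sketch of how a Flask route could consume this generator as a server-sent-events stream (not part of the diff; the `chat` module name, the `/api/chat` route, and the request payload fields are assumptions):

```python
# Hypothetical usage sketch: stream ask_question() to the client as
# server-sent events. Module name, route, and payload fields are assumed.
from uuid import uuid4

from flask import Flask, Response, request

from chat import ask_question  # assumed module name for the file above

app = Flask(__name__)


@app.route("/api/chat", methods=["POST"])
def api_chat():
    question = request.json["question"]
    session_id = request.json.get("session_id") or str(uuid4())
    # ask_question is already wrapped with @stream_with_context, so the
    # generator it returns can be handed straight to Response.
    return Response(ask_question(question, session_id), mimetype="text/event-stream")
```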