diff --git a/rag/generate.py b/rag/generate.py index 331540e..026c42e 100644 --- a/rag/generate.py +++ b/rag/generate.py @@ -135,6 +135,7 @@ def __call__( context_results[lexical_search_k:lexical_search_k] = lexical_context # Rerank + predicted_tag = None if self.reranker: predicted_tag = custom_predict( inputs=[query], classifier=self.reranker, threshold=rerank_threshold @@ -165,6 +166,7 @@ def __call__( "question": query, "sources": sources, "document_ids": document_ids, + "predicted_tag": predicted_tag, "answer": answer, "llm": self.llm, } diff --git a/rag/serve.py b/rag/serve.py index d19d06e..5958947 100644 --- a/rag/serve.py +++ b/rag/serve.py @@ -209,6 +209,7 @@ def produce_streaming_answer(self, query, result): "finished streaming query", query=query, document_ids=result["document_ids"], + predicted_tag=result["predicted_tag"], llm=result["llm"], answer="".join(answer), ) @@ -231,7 +232,7 @@ def stream(self, query: Query) -> StreamingResponse: use_lexical_search=False, lexical_search_k=0, use_reranking=True, - rerank_threshold=0.9, + rerank_threshold=0.5, rerank_k=9, llm="mistralai/Mixtral-8x7B-Instruct-v0.1", sql_dump_fp=Path(os.environ["RAY_ASSISTANT_INDEX"]),