Skip to content

Commit 294bc0b

Browse files
Merge pull request #88 from seungwon2/master
[Fix] : fix HyDE RAG fusion error
2 parents 9585cd5 + 5baea2c commit 294bc0b

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

genai/aws-gen-ai-kr/utils/rag_summit.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from langchain.text_splitter import RecursiveCharacterTextSplitter
3636
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
3737
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
38+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
3839
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
3940

4041
import threading
@@ -313,7 +314,19 @@ class retriever_utils():
313314
length_function=len,
314315
)
315316
token_limit = 300
317+
318+
@classmethod
def control_streaming_mode(cls, llm, stream=True):
    """Toggle token streaming on an LLM client, in place.

    Sets ``llm.streaming`` to match *stream* and attaches a
    stdout streaming callback handler when enabling (or clears
    the callbacks when disabling), then returns the same object
    so calls can be chained/reassigned.

    Args:
        llm: the LLM client object to reconfigure (mutated in place).
        stream: True to enable streaming output, False to disable.

    Returns:
        The same ``llm`` object, reconfigured.
    """
    llm.streaming = bool(stream)
    # Only build the handler when streaming is on; otherwise drop callbacks.
    llm.callbacks = [StreamingStdOutCallbackHandler()] if stream else None
    return llm
317330
@classmethod
318331
# semantic search based
319332
def get_semantic_similar_docs_by_langchain(cls, **kwargs):
@@ -476,9 +489,10 @@ def get_rag_fusion_similar_docs(cls, **kwargs):
476489
llm_text = kwargs["llm_text"]
477490
query_augmentation_size = kwargs["query_augmentation_size"]
478491
query_transformation_prompt = kwargs["query_transformation_prompt"]
479-
492+
493+
llm_text = cls.control_streaming_mode(llm_text, stream=False) ## turn off llm streaming
480494
generate_queries = query_transformation_prompt | llm_text | StrOutputParser() | (lambda x: x.split("\n"))
481-
495+
482496
rag_fusion_query = generate_queries.invoke(
483497
{
484498
"query": kwargs["query"],
@@ -495,6 +509,9 @@ def get_rag_fusion_similar_docs(cls, **kwargs):
495509
print("\n")
496510
print("===== RAG-Fusion Queries =====")
497511
print(rag_fusion_query)
512+
513+
llm_text = cls.control_streaming_mode(llm_text, stream=True) ## turn on llm streaming
514+
498515

499516
tasks = []
500517
for query in rag_fusion_query:
@@ -547,6 +564,7 @@ def _get_hyde_response(query, prompt, llm_text):
547564
hyde_query = kwargs["hyde_query"]
548565

549566
tasks = []
567+
llm_text = cls.control_streaming_mode(llm_text, stream=False) ## turn off llm streaming
550568
for template_type in hyde_query:
551569
hyde_response = partial(
552570
_get_hyde_response,
@@ -557,10 +575,9 @@ def _get_hyde_response(query, prompt, llm_text):
557575
tasks.append(cls.hyde_pool.apply_async(hyde_response,))
558576
hyde_answers = [task.get() for task in tasks]
559577
hyde_answers.insert(0, query)
560-
augmentation = hyde_answers
561-
562578

563579
tasks = []
580+
llm_text = cls.control_streaming_mode(llm_text, stream=True) ## turn on llm streaming
564581
for hyde_answer in hyde_answers:
565582
semantic_search = partial(
566583
cls.get_semantic_similar_docs,
@@ -583,10 +600,13 @@ def _get_hyde_response(query, prompt, llm_text):
583600
c=60,
584601
k=kwargs["k"],
585602
)
603+
586604
if kwargs["verbose"]:
587605
print("\n")
588606
print("===== HyDE Answers =====")
589607
print(hyde_answers)
608+
609+
augmentation = hyde_answers[1]
590610

591611
return similar_docs
592612

0 commit comments

Comments
 (0)