3535from langchain .text_splitter import RecursiveCharacterTextSplitter
3636from langchain .callbacks .manager import CallbackManagerForRetrieverRun
3737from langchain .embeddings .sagemaker_endpoint import EmbeddingsContentHandler
38+ from langchain .callbacks .streaming_stdout import StreamingStdOutCallbackHandler
3839from langchain_core .prompts import ChatPromptTemplate , HumanMessagePromptTemplate , SystemMessagePromptTemplate
3940
4041import threading
@@ -313,7 +314,19 @@ class retriever_utils():
313314 length_function = len ,
314315 )
315316 token_limit = 300
317+
318+ @classmethod
319+ def control_streaming_mode (cls , llm , stream = True ):
320+
321+ if stream :
322+ llm .streaming = True
323+ llm .callbacks = [StreamingStdOutCallbackHandler ()]
324+ else :
325+ llm .streaming = False
326+ llm .callbacks = None
316327
328+ return llm
329+
317330 @classmethod
318331 # semantic search based
319332 def get_semantic_similar_docs_by_langchain (cls , ** kwargs ):
@@ -476,9 +489,10 @@ def get_rag_fusion_similar_docs(cls, **kwargs):
476489 llm_text = kwargs ["llm_text" ]
477490 query_augmentation_size = kwargs ["query_augmentation_size" ]
478491 query_transformation_prompt = kwargs ["query_transformation_prompt" ]
479-
492+
493+ llm_text = cls .control_streaming_mode (llm_text , stream = False ) ## turn off llm streaming
480494 generate_queries = query_transformation_prompt | llm_text | StrOutputParser () | (lambda x : x .split ("\n " ))
481-
495+
482496 rag_fusion_query = generate_queries .invoke (
483497 {
484498 "query" : kwargs ["query" ],
@@ -495,6 +509,9 @@ def get_rag_fusion_similar_docs(cls, **kwargs):
495509 print ("\n " )
496510 print ("===== RAG-Fusion Queries =====" )
497511 print (rag_fusion_query )
512+
513+ llm_text = cls .control_streaming_mode (llm_text , stream = True )## turn on llm streaming
514+
498515
499516 tasks = []
500517 for query in rag_fusion_query :
@@ -547,6 +564,7 @@ def _get_hyde_response(query, prompt, llm_text):
547564 hyde_query = kwargs ["hyde_query" ]
548565
549566 tasks = []
567+ llm_text = cls .control_streaming_mode (llm_text , stream = False ) ## turn off llm streaming
550568 for template_type in hyde_query :
551569 hyde_response = partial (
552570 _get_hyde_response ,
@@ -557,10 +575,9 @@ def _get_hyde_response(query, prompt, llm_text):
557575 tasks .append (cls .hyde_pool .apply_async (hyde_response ,))
558576 hyde_answers = [task .get () for task in tasks ]
559577 hyde_answers .insert (0 , query )
560- augmentation = hyde_answers
561-
562578
563579 tasks = []
580+ llm_text = cls .control_streaming_mode (llm_text , stream = True ) ## turn on llm streaming
564581 for hyde_answer in hyde_answers :
565582 semantic_search = partial (
566583 cls .get_semantic_similar_docs ,
@@ -583,10 +600,13 @@ def _get_hyde_response(query, prompt, llm_text):
583600 c = 60 ,
584601 k = kwargs ["k" ],
585602 )
603+
586604 if kwargs ["verbose" ]:
587605 print ("\n " )
588606 print ("===== HyDE Answers =====" )
589607 print (hyde_answers )
608+
609+ augmentation = hyde_answers [1 ]
590610
591611 return similar_docs
592612
0 commit comments