|
import io
import os
from contextlib import redirect_stdout, redirect_stderr

from langchain_nvidia_ai_endpoints import NVIDIARerank

from multiagent import HybridRetriever
from utils import automation

# NOTE(review): the key was previously hard-coded as a string literal in
# source, which leaks secrets into version control. Read it from the
# environment instead; the placeholder default preserves the original
# "unset" behavior when NVIDIA_API_KEY is not defined.
api_key = os.getenv("NVIDIA_API_KEY", "<add your api key>")
| 7 | + |
| 8 | + |
class Nodes:
    """Graph nodes for a retrieval-augmented generation (RAG) workflow.

    Each node is a stateless function over the shared graph ``state`` dict
    (keys: ``question``, ``path``, ``documents``, ``generation``) and returns
    only the keys it updates.
    """

    @staticmethod
    def retrieve(state):
        """Retrieve documents relevant to ``state['question']``.

        Builds a hybrid retriever over the corpus at ``state['path']`` and
        returns ``{"documents", "question"}``.
        """
        print("---RETRIEVE---")
        question = state["question"]
        path = state["path"]
        hybrid_retriever_instance = HybridRetriever(path, api_key)
        hybrid_retriever = hybrid_retriever_instance.get_retriever()
        # Suppress noisy library chatter emitted during retrieval.
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            # `invoke` is the current Runnable API; `get_relevant_documents`
            # is deprecated in modern LangChain releases.
            documents = hybrid_retriever.invoke(question)

        return {"documents": documents, "question": question}

    @staticmethod
    def rerank(state):
        """Re-order retrieved documents by relevance using NVIDIA's reranker."""
        print("NVIDIA--RERANKER")
        question = state["question"]
        documents = state["documents"]
        reranker = NVIDIARerank(model="nvidia/llama-3.2-nv-rerankqa-1b-v2", api_key=api_key)
        documents = reranker.compress_documents(query=question, documents=documents)
        return {"documents": documents, "question": question}

    @staticmethod
    def generate(state):
        """Generate an answer from the retrieved documents via the RAG chain."""
        print("GENERATE USING LLM")
        question = state["question"]
        documents = state["documents"]

        generation = automation.rag_chain.invoke({"context": documents, "question": question})
        return {"documents": documents, "question": question, "generation": generation}

    @staticmethod
    def grade_documents(state):
        """Filter out documents the grader scores as irrelevant to the question.

        Each document is graded independently; only those graded "yes" are
        kept in the returned ``documents`` list.
        """
        print("CHECKING DOCUMENT RELEVANCE TO QUESTION")
        question = state["question"]
        ret_documents = state["documents"]

        filtered_docs = []
        for doc in ret_documents:
            score = automation.retrieval_grader.invoke(
                {"question": question, "document": doc.page_content}
            )
            grade = score.binary_score
            if grade == "yes":
                print("---GRADE: DOCUMENT RELEVANT---")
                filtered_docs.append(doc)
            else:
                print("---GRADE: DOCUMENT NOT RELEVANT---")
        return {"documents": filtered_docs, "question": question}

    @staticmethod
    def transform_query(state):
        """Rewrite the user question into a better-formed retrieval query."""
        print("REWRITE PROMPT")
        question = state["question"]
        documents = state["documents"]

        better_question = automation.question_rewriter.invoke({"question": question})
        print(f"actual query : {question} \n Transformed query:{better_question}")
        return {"documents": documents, "question": better_question}
0 commit comments