|
20 | 20 |
|
21 | 21 | import tqdm |
22 | 22 | import typer |
23 | | -from langchain.chains import LLMChain |
24 | | -from langchain.prompts import PromptTemplate |
| 23 | +from langchain_core.prompts import PromptTemplate |
25 | 24 |
|
26 | 25 | from nemoguardrails import LLMRails |
27 | 26 | from nemoguardrails.actions.llm.utils import llm_call |
@@ -94,19 +93,23 @@ def create_negative_samples(self, dataset): |
94 | 93 | template=create_negatives_template, |
95 | 94 | input_variables=["evidence", "answer"], |
96 | 95 | ) |
97 | | - create_negatives_chain = LLMChain(prompt=create_negatives_prompt, llm=self.llm) |
| 96 | + |
| 97 | + # Bind config parameters to the LLM for generating negative samples |
| 98 | + llm_with_config = self.llm.bind(temperature=0.8, max_tokens=300) |
98 | 99 |
|
99 | 100 | print("Creating negative samples...") |
100 | 101 | for data in tqdm.tqdm(dataset): |
101 | 102 | assert "evidence" in data and "question" in data and "answer" in data |
102 | 103 | evidence = data["evidence"] |
103 | 104 | answer = data["answer"] |
104 | | - negative_answer_result = create_negatives_chain.invoke( |
105 | | - {"evidence": evidence, "answer": answer}, |
106 | | - config={"temperature": 0.8, "max_tokens": 300}, |
| 105 | + |
| 106 | + # Format the prompt and invoke the LLM directly |
| 107 | + formatted_prompt = create_negatives_prompt.format( |
| 108 | + evidence=evidence, answer=answer |
107 | 109 | ) |
108 | | - negative_answer = negative_answer_result["text"] |
109 | | - data["incorrect_answer"] = negative_answer.strip() |
| 110 | + negative_answer = llm_with_config.invoke(formatted_prompt) |
| 111 | + negative_answer_content = negative_answer.content |
| 112 | + data["incorrect_answer"] = negative_answer_content.strip() |
110 | 113 |
|
111 | 114 | return dataset |
112 | 115 |
|
@@ -186,14 +189,16 @@ def run(self): |
186 | 189 | split="negative" |
187 | 190 | ) |
188 | 191 |
|
189 | | - print(f"Positive Accuracy: {pos_num_correct/len(self.dataset) * 100}") |
190 | | - print(f"Negative Accuracy: {neg_num_correct/len(self.dataset) * 100}") |
| 192 | + print(f"Positive Accuracy: {pos_num_correct / len(self.dataset) * 100}") |
| 193 | + print(f"Negative Accuracy: {neg_num_correct / len(self.dataset) * 100}") |
191 | 194 | print( |
192 | | - f"Overall Accuracy: {(pos_num_correct + neg_num_correct)/(2*len(self.dataset))* 100}" |
| 195 | + f"Overall Accuracy: {(pos_num_correct + neg_num_correct) / (2 * len(self.dataset)) * 100}" |
193 | 196 | ) |
194 | 197 |
|
195 | 198 | print("---Time taken per sample:---") |
196 | | - print(f"Ask LLM:\t{(pos_time+neg_time)*1000/(2*len(self.dataset)):.1f}ms") |
| 199 | + print( |
| 200 | + f"Ask LLM:\t{(pos_time + neg_time) * 1000 / (2 * len(self.dataset)):.1f}ms" |
| 201 | + ) |
197 | 202 |
|
198 | 203 | if self.write_outputs: |
199 | 204 | dataset_name = os.path.basename(self.dataset_path).split(".")[0] |
|
0 commit comments