|
52 | 52 | }, |
53 | 53 | { |
54 | 54 | "cell_type": "code", |
55 | | - "execution_count": null, |
| 55 | + "execution_count": 5, |
56 | 56 | "metadata": {}, |
57 | 57 | "outputs": [], |
58 | 58 | "source": [ |
|
79 | 79 | }, |
80 | 80 | { |
81 | 81 | "cell_type": "code", |
82 | | - "execution_count": null, |
| 82 | + "execution_count": 6, |
83 | 83 | "metadata": {}, |
84 | 84 | "outputs": [], |
85 | 85 | "source": [ |
|
94 | 94 | }, |
95 | 95 | { |
96 | 96 | "cell_type": "code", |
97 | | - "execution_count": null, |
| 97 | + "execution_count": 7, |
98 | 98 | "metadata": {}, |
99 | 99 | "outputs": [], |
100 | 100 | "source": [ |
|
117 | 117 | }, |
118 | 118 | { |
119 | 119 | "cell_type": "code", |
120 | | - "execution_count": null, |
| 120 | + "execution_count": 8, |
| 121 | + "metadata": {}, |
| 122 | + "outputs": [], |
| 123 | + "source": [ |
| 124 | + "def all_queries_distinct(prev_queries):\n", |
| 125 | + " query_distinct = True\n", |
| 126 | + " for i, query in enumerate(prev_queries):\n", |
| 127 | + " if validate_query_distinction_local(prev_queries[:i], query) == False:\n", |
| 128 | + " query_distinct = False\n", |
| 129 | + " break\n", |
| 130 | + " return query_distinct" |
| 131 | + ] |
| 132 | + }, |
| 133 | + { |
| 134 | + "cell_type": "code", |
| 135 | + "execution_count": 9, |
121 | 136 | "metadata": {}, |
122 | 137 | "outputs": [], |
123 | 138 | "source": [ |
|
130 | 145 | " self.generate_answer = dspy.ChainOfThought(GenerateAnswer)\n", |
131 | 146 | " self.max_hops = max_hops\n", |
132 | 147 | "\n", |
| 148 | + " # for evaluating assertions only\n", |
| 149 | + " self.passed_suggestions = 0\n", |
| 150 | + "\n", |
133 | 151 | " def forward(self, question):\n", |
134 | 152 | " context = []\n", |
| 153 | + " prev_queries = [question]\n", |
| 154 | + "\n", |
135 | 155 | " for hop in range(self.max_hops):\n", |
136 | 156 | " query = self.generate_query[hop](context=context, question=question).query\n", |
| 157 | + " prev_queries.append(query)\n", |
137 | 158 | " passages = self.retrieve(query).passages\n", |
138 | 159 | " context = deduplicate(context + passages)\n", |
| 160 | + " \n", |
| 161 | + " if all_queries_distinct(prev_queries):\n", |
| 162 | + " self.passed_suggestions += 1\n", |
| 163 | + " \n", |
139 | 164 | " pred = self.generate_answer(context=context, question=question)\n", |
140 | 165 | " pred = dspy.Prediction(context=context, answer=pred.answer)\n", |
141 | 166 | " return pred" |
142 | 167 | ] |
143 | 168 | }, |
144 | 169 | { |
145 | 170 | "cell_type": "code", |
146 | | - "execution_count": null, |
| 171 | + "execution_count": 10, |
147 | 172 | "metadata": {}, |
148 | 173 | "outputs": [], |
149 | 174 | "source": [ |
|
155 | 180 | " self.generate_answer = dspy.ChainOfThought(GenerateAnswer)\n", |
156 | 181 | " self.max_hops = max_hops\n", |
157 | 182 | "\n", |
| 183 | + " # for evaluating assertions only\n", |
| 184 | + " self.passed_suggestions = 0\n", |
| 185 | + "\n", |
158 | 186 | " def forward(self, question):\n", |
159 | 187 | " context = []\n", |
160 | 188 | " prev_queries = [question]\n", |
|
176 | 204 | " prev_queries.append(query)\n", |
177 | 205 | " passages = self.retrieve(query).passages\n", |
178 | 206 | " context = deduplicate(context + passages)\n", |
| 207 | + " \n", |
| 208 | + " if all_queries_distinct(prev_queries):\n", |
| 209 | + " self.passed_suggestions += 1\n", |
179 | 210 | "\n", |
180 | 211 | " pred = self.generate_answer(context=context, question=question)\n", |
181 | 212 | " pred = dspy.Prediction(context=context, answer=pred.answer)\n", |
|
184 | 215 | }, |
185 | 216 | { |
186 | 217 | "cell_type": "code", |
187 | | - "execution_count": null, |
| 218 | + "execution_count": 11, |
188 | 219 | "metadata": {}, |
189 | 220 | "outputs": [], |
190 | 221 | "source": [ |
191 | | - "evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=25, display_progress=True, display_table=False)" |
| 222 | + "evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=10, display_progress=True, display_table=False)" |
192 | 223 | ] |
193 | 224 | }, |
194 | 225 | { |
195 | 226 | "cell_type": "code", |
196 | | - "execution_count": null, |
| 227 | + "execution_count": 12, |
197 | 228 | "metadata": {}, |
198 | 229 | "outputs": [], |
199 | 230 | "source": [ |
200 | 231 | "def evaluate(module):\n", |
| 232 | + " module.passed_suggestions = 0\n", |
| 233 | + "\n", |
201 | 234 | " retrieval_score = evaluate_on_hotpotqa(\n", |
202 | 235 | " module, metric=gold_passages_retrieved\n", |
203 | 236 | " )\n", |
204 | 237 | " \n", |
| 238 | + " suggestions_score = module.passed_suggestions / len(devset) * 100\n", |
| 239 | + "\n", |
205 | 240 | " accuracy_score = evaluate_on_hotpotqa(\n", |
206 | 241 | " module, metric=dspy.evaluate.answer_exact_match\n", |
207 | 242 | " )\n", |
208 | 243 | "\n", |
| 244 | + " print(f\"## Suggestions Score: {suggestions_score}\")\n", |
| 245 | + "\n", |
209 | 246 | " print(f\"## Retrieval Score: {retrieval_score}\")\n", |
210 | 247 | " print(f\"## Accuracy Score: {accuracy_score}\")" |
211 | 248 | ] |
|
232 | 269 | "evaluate(baleen_with_assertions)" |
233 | 270 | ] |
234 | 271 | }, |
| 272 | + { |
| 273 | + "cell_type": "code", |
| 274 | + "execution_count": null, |
| 275 | + "metadata": {}, |
| 276 | + "outputs": [], |
| 277 | + "source": [ |
| 278 | + "max_bootstrapped_demos = 2" |
| 279 | + ] |
| 280 | + }, |
235 | 281 | { |
236 | 282 | "cell_type": "code", |
237 | 283 | "execution_count": null, |
|
242 | 288 | "baleen = SimplifiedBaleen()\n", |
243 | 289 | "teleprompter = BootstrapFewShotWithRandomSearch(\n", |
244 | 290 | " metric=validate_context_and_answer_and_hops,\n", |
245 | | - " max_bootstrapped_demos=2,\n", |
| 291 | + " max_bootstrapped_demos=max_bootstrapped_demos,\n", |
246 | 292 | " num_candidate_programs=6,\n", |
247 | 293 | ")\n", |
248 | 294 | "\n", |
249 | | - "compiled_baleen = teleprompter.compile(student = baleen, teacher = baleen, trainset = trainset, valset = devset)\n", |
| 295 | + "compiled_baleen = teleprompter.compile(student = SimplifiedBaleen(), teacher = SimplifiedBaleen(), trainset = trainset, valset = devset)\n", |
250 | 296 | "evaluate(compiled_baleen)" |
251 | 297 | ] |
252 | 298 | }, |
|
260 | 306 | "baleen = SimplifiedBaleen()\n", |
261 | 307 | "teleprompter = BootstrapFewShotWithRandomSearch(\n", |
262 | 308 | " metric=validate_context_and_answer_and_hops,\n", |
263 | | - " max_bootstrapped_demos=2,\n", |
| 309 | + " max_bootstrapped_demos=max_bootstrapped_demos,\n", |
264 | 310 | " num_candidate_programs=6,\n", |
265 | 311 | ")\n", |
266 | 312 | "compiled_baleen = teleprompter.compile(\n", |
|
270 | 316 | " ),\n", |
271 | 317 | " teacher=baleen,\n", |
272 | 318 | " trainset=trainset,\n", |
273 | | - " valset=devset[:100]\n", |
| 319 | + " valset=devset\n", |
274 | 320 | ")\n", |
275 | | - "evaluate(compiled_baleen)" |
| 321 | + "evaluate(compiled_baleen)\n" |
276 | 322 | ] |
277 | 323 | } |
278 | 324 | ], |
|
292 | 338 | "name": "python", |
293 | 339 | "nbconvert_exporter": "python", |
294 | 340 | "pygments_lexer": "ipython3", |
295 | | - "version": "3.10.13" |
| 341 | + "version": "3.11.5" |
296 | 342 | } |
297 | 343 | }, |
298 | 344 | "nbformat": 4, |
|
0 commit comments