Skip to content

Commit 7641d37

Browse files
committed
Revert back changes to intrinsic metrics
1 parent 5799ac9 commit 7641d37

File tree

1 file changed

+60
-14
lines changed

1 file changed

+60
-14
lines changed

examples/qa/hotpot/hotpotqa_with_assertions.ipynb

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
},
5353
{
5454
"cell_type": "code",
55-
"execution_count": null,
55+
"execution_count": 5,
5656
"metadata": {},
5757
"outputs": [],
5858
"source": [
@@ -79,7 +79,7 @@
7979
},
8080
{
8181
"cell_type": "code",
82-
"execution_count": null,
82+
"execution_count": 6,
8383
"metadata": {},
8484
"outputs": [],
8585
"source": [
@@ -94,7 +94,7 @@
9494
},
9595
{
9696
"cell_type": "code",
97-
"execution_count": null,
97+
"execution_count": 7,
9898
"metadata": {},
9999
"outputs": [],
100100
"source": [
@@ -117,7 +117,22 @@
117117
},
118118
{
119119
"cell_type": "code",
120-
"execution_count": null,
120+
"execution_count": 8,
121+
"metadata": {},
122+
"outputs": [],
123+
"source": [
124+
"def all_queries_distinct(prev_queries):\n",
125+
" query_distinct = True\n",
126+
" for i, query in enumerate(prev_queries):\n",
127+
" if validate_query_distinction_local(prev_queries[:i], query) == False:\n",
128+
" query_distinct = False\n",
129+
" break\n",
130+
" return query_distinct"
131+
]
132+
},
133+
{
134+
"cell_type": "code",
135+
"execution_count": 9,
121136
"metadata": {},
122137
"outputs": [],
123138
"source": [
@@ -130,20 +145,30 @@
130145
" self.generate_answer = dspy.ChainOfThought(GenerateAnswer)\n",
131146
" self.max_hops = max_hops\n",
132147
"\n",
148+
" # for evaluating assertions only\n",
149+
" self.passed_suggestions = 0\n",
150+
"\n",
133151
" def forward(self, question):\n",
134152
" context = []\n",
153+
" prev_queries = [question]\n",
154+
"\n",
135155
" for hop in range(self.max_hops):\n",
136156
" query = self.generate_query[hop](context=context, question=question).query\n",
157+
" prev_queries.append(query)\n",
137158
" passages = self.retrieve(query).passages\n",
138159
" context = deduplicate(context + passages)\n",
160+
" \n",
161+
" if all_queries_distinct(prev_queries):\n",
162+
" self.passed_suggestions += 1\n",
163+
" \n",
139164
" pred = self.generate_answer(context=context, question=question)\n",
140165
" pred = dspy.Prediction(context=context, answer=pred.answer)\n",
141166
" return pred"
142167
]
143168
},
144169
{
145170
"cell_type": "code",
146-
"execution_count": null,
171+
"execution_count": 10,
147172
"metadata": {},
148173
"outputs": [],
149174
"source": [
@@ -155,6 +180,9 @@
155180
" self.generate_answer = dspy.ChainOfThought(GenerateAnswer)\n",
156181
" self.max_hops = max_hops\n",
157182
"\n",
183+
" # for evaluating assertions only\n",
184+
" self.passed_suggestions = 0\n",
185+
"\n",
158186
" def forward(self, question):\n",
159187
" context = []\n",
160188
" prev_queries = [question]\n",
@@ -176,6 +204,9 @@
176204
" prev_queries.append(query)\n",
177205
" passages = self.retrieve(query).passages\n",
178206
" context = deduplicate(context + passages)\n",
207+
" \n",
208+
" if all_queries_distinct(prev_queries):\n",
209+
" self.passed_suggestions += 1\n",
179210
"\n",
180211
" pred = self.generate_answer(context=context, question=question)\n",
181212
" pred = dspy.Prediction(context=context, answer=pred.answer)\n",
@@ -184,28 +215,34 @@
184215
},
185216
{
186217
"cell_type": "code",
187-
"execution_count": null,
218+
"execution_count": 11,
188219
"metadata": {},
189220
"outputs": [],
190221
"source": [
191-
"evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=25, display_progress=True, display_table=False)"
222+
"evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=10, display_progress=True, display_table=False)"
192223
]
193224
},
194225
{
195226
"cell_type": "code",
196-
"execution_count": null,
227+
"execution_count": 12,
197228
"metadata": {},
198229
"outputs": [],
199230
"source": [
200231
"def evaluate(module):\n",
232+
" module.passed_suggestions = 0\n",
233+
"\n",
201234
" retrieval_score = evaluate_on_hotpotqa(\n",
202235
" module, metric=gold_passages_retrieved\n",
203236
" )\n",
204237
" \n",
238+
" suggestions_score = module.passed_suggestions / len(devset) * 100\n",
239+
"\n",
205240
" accuracy_score = evaluate_on_hotpotqa(\n",
206241
" module, metric=dspy.evaluate.answer_exact_match\n",
207242
" )\n",
208243
"\n",
244+
" print(f\"## Suggestions Score: {suggestions_score}\")\n",
245+
"\n",
209246
" print(f\"## Retrieval Score: {retrieval_score}\")\n",
210247
" print(f\"## Accuracy Score: {accuracy_score}\")"
211248
]
@@ -232,6 +269,15 @@
232269
"evaluate(baleen_with_assertions)"
233270
]
234271
},
272+
{
273+
"cell_type": "code",
274+
"execution_count": null,
275+
"metadata": {},
276+
"outputs": [],
277+
"source": [
278+
"max_bootstrapped_demos = 2"
279+
]
280+
},
235281
{
236282
"cell_type": "code",
237283
"execution_count": null,
@@ -242,11 +288,11 @@
242288
"baleen = SimplifiedBaleen()\n",
243289
"teleprompter = BootstrapFewShotWithRandomSearch(\n",
244290
" metric=validate_context_and_answer_and_hops,\n",
245-
" max_bootstrapped_demos=2,\n",
291+
" max_bootstrapped_demos=max_bootstrapped_demos,\n",
246292
" num_candidate_programs=6,\n",
247293
")\n",
248294
"\n",
249-
"compiled_baleen = teleprompter.compile(student = baleen, teacher = baleen, trainset = trainset, valset = devset)\n",
295+
"compiled_baleen = teleprompter.compile(student = SimplifiedBaleen(), teacher = SimplifiedBaleen(), trainset = trainset, valset = devset)\n",
250296
"evaluate(compiled_baleen)"
251297
]
252298
},
@@ -260,7 +306,7 @@
260306
"baleen = SimplifiedBaleen()\n",
261307
"teleprompter = BootstrapFewShotWithRandomSearch(\n",
262308
" metric=validate_context_and_answer_and_hops,\n",
263-
" max_bootstrapped_demos=2,\n",
309+
" max_bootstrapped_demos=max_bootstrapped_demos,\n",
264310
" num_candidate_programs=6,\n",
265311
")\n",
266312
"compiled_baleen = teleprompter.compile(\n",
@@ -270,9 +316,9 @@
270316
" ),\n",
271317
" teacher=baleen,\n",
272318
" trainset=trainset,\n",
273-
" valset=devset[:100]\n",
319+
" valset=devset\n",
274320
")\n",
275-
"evaluate(compiled_baleen)"
321+
"evaluate(compiled_baleen)\n"
276322
]
277323
}
278324
],
@@ -292,7 +338,7 @@
292338
"name": "python",
293339
"nbconvert_exporter": "python",
294340
"pygments_lexer": "ipython3",
295-
"version": "3.10.13"
341+
"version": "3.11.5"
296342
}
297343
},
298344
"nbformat": 4,

0 commit comments

Comments
 (0)