Skip to content

Commit a4b475b

Browse files
author
JONEMI19
committed
add a context manager to handle sigint in evaluate
1 parent 521d130 commit a4b475b

File tree

2 files changed

+124
-33
lines changed

2 files changed

+124
-33
lines changed

dspy/evaluate/evaluate.py

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import contextlib
2+
import signal
13
import sys
24
import threading
35
import types
@@ -14,18 +16,14 @@
1416
except ImportError:
1517
ipython_display = print
1618

17-
def HTML(x):
19+
def HTML(x) -> str: # noqa: N802
1820
return x
1921

2022

2123
from concurrent.futures import ThreadPoolExecutor, as_completed
2224

23-
from dsp.evaluation.utils import *
24-
25-
"""
26-
TODO: Counting failures and having a max_failure count. When that is exceeded (also just at the end),
27-
we print the number of failures, the first N examples that failed, and the first N exceptions raised.
28-
"""
25+
# TODO: Counting failures and having a max_failure count. When that is exceeded (also just at the end),
26+
# we print the number of failures, the first N examples that failed, and the first N exceptions raised.
2927

3028

3129
class Evaluate:
@@ -49,11 +47,13 @@ def __init__(
4947
self.max_errors = max_errors
5048
self.error_count = 0
5149
self.error_lock = threading.Lock()
50+
self.cancel_jobs = threading.Event()
5251
self.return_outputs = return_outputs
5352

5453
if "display" in _kwargs:
5554
dspy.logger.warning(
56-
"DeprecationWarning: 'display' has been deprecated. To see all information for debugging, use 'dspy.set_log_level('debug')'. In the future this will raise an error.",
55+
"DeprecationWarning: 'display' has been deprecated. To see all information for debugging,"
56+
" use 'dspy.set_log_level('debug')'. In the future this will raise an error.",
5757
)
5858

5959
def _execute_single_thread(self, wrapped_program, devset, display_progress):
@@ -78,19 +78,52 @@ def _execute_multi_thread(self, wrapped_program, devset, num_threads, display_pr
7878
ncorrect = 0
7979
ntotal = 0
8080
reordered_devset = []
81-
82-
with ThreadPoolExecutor(max_workers=num_threads) as executor:
83-
futures = {executor.submit(wrapped_program, idx, arg) for idx, arg in devset}
81+
job_cancelled = "cancelled"
82+
83+
# context manager to handle SIGINT
84+
@contextlib.contextmanager
85+
def interrupt_handler_manager():
86+
"""Sets the cancel_jobs event when a SIGINT is received."""
87+
default_handler = signal.getsignal(signal.SIGINT)
88+
89+
def interrupt_handler(sig, frame):
90+
self.cancel_jobs.set()
91+
dspy.logger.warning("Received SIGINT. Cancelling evaluation.")
92+
default_handler(sig, frame)
93+
94+
signal.signal(signal.SIGINT, interrupt_handler)
95+
yield
96+
# reset to the default handler
97+
signal.signal(signal.SIGINT, default_handler)
98+
99+
def cancellable_wrapped_program(idx, arg):
100+
# If the cancel_jobs event is set, return the job_cancelled sentinel
101+
if self.cancel_jobs.is_set():
102+
return None, None, job_cancelled, None
103+
return wrapped_program(idx, arg)
104+
105+
with ThreadPoolExecutor(max_workers=num_threads) as executor, interrupt_handler_manager():
106+
futures = {executor.submit(cancellable_wrapped_program, idx, arg) for idx, arg in devset}
84107
pbar = tqdm.tqdm(total=len(devset), dynamic_ncols=True, disable=not display_progress)
85108

86109
for future in as_completed(futures):
87110
example_idx, example, prediction, score = future.result()
111+
112+
# use the job_cancelled sentinel to check if the job was cancelled - use "is" not "=="
113+
# in case the prediction is "cancelled" for some reason.
114+
if prediction is job_cancelled:
115+
continue
116+
88117
reordered_devset.append((example_idx, example, prediction, score))
89118
ncorrect += score
90119
ntotal += 1
91120
self._update_progress(pbar, ncorrect, ntotal)
92121
pbar.close()
93122

123+
if self.cancel_jobs.is_set():
124+
dspy.logger.warning("Evaluation was cancelled. The results may be incomplete.")
125+
raise KeyboardInterrupt
126+
94127
return reordered_devset, ncorrect, ntotal
95128

96129
def _update_progress(self, pbar, ncorrect, ntotal):
@@ -175,24 +208,21 @@ def wrapped_program(example_idx, example):
175208
merge_dicts(example, prediction) | {"correct": score} for _, example, prediction, score in predicted_devset
176209
]
177210

178-
df = pd.DataFrame(data)
211+
result_df = pd.DataFrame(data)
179212

180-
# Truncate every cell in the DataFrame
181-
if hasattr(df, "map"): # DataFrame.applymap was renamed to DataFrame.map in Pandas 2.1.0
182-
df = df.map(truncate_cell)
183-
else:
184-
df = df.applymap(truncate_cell)
213+
# Truncate every cell in the DataFrame (DataFrame.applymap was renamed to DataFrame.map in Pandas 2.1.0)
214+
result_df = result_df.map(truncate_cell) if hasattr(result_df, "map") else result_df.applymap(truncate_cell)
185215

186216
# Rename the 'correct' column to the name of the metric object
187217
metric_name = metric.__name__ if isinstance(metric, types.FunctionType) else metric.__class__.__name__
188-
df.rename(columns={"correct": metric_name}, inplace=True)
218+
result_df = result_df.rename(columns={"correct": metric_name})
189219

190220
if display_table:
191221
if isinstance(display_table, int):
192-
df_to_display = df.head(display_table).copy()
193-
truncated_rows = len(df) - display_table
222+
df_to_display = result_df.head(display_table).copy()
223+
truncated_rows = len(result_df) - display_table
194224
else:
195-
df_to_display = df.copy()
225+
df_to_display = result_df.copy()
196226
truncated_rows = 0
197227

198228
styled_df = configure_dataframe_display(df_to_display, metric_name)
@@ -215,15 +245,15 @@ def wrapped_program(example_idx, example):
215245

216246
if return_all_scores and return_outputs:
217247
return round(100 * ncorrect / ntotal, 2), results, [score for *_, score in predicted_devset]
218-
elif return_all_scores:
248+
if return_all_scores:
219249
return round(100 * ncorrect / ntotal, 2), [score for *_, score in predicted_devset]
220-
elif return_outputs:
250+
if return_outputs:
221251
return round(100 * ncorrect / ntotal, 2), results
222252

223253
return round(100 * ncorrect / ntotal, 2)
224254

225255

226-
def merge_dicts(d1, d2):
256+
def merge_dicts(d1, d2) -> dict:
227257
merged = {}
228258
for k, v in d1.items():
229259
if k in d2:
@@ -240,25 +270,22 @@ def merge_dicts(d1, d2):
240270
return merged
241271

242272

def truncate_cell(content):
    """Truncate the content of a table cell to at most 25 words.

    The value is stringified and split on whitespace; if it exceeds 25
    words, the first 25 are rejoined and suffixed with "...". Shorter
    values are returned *unchanged* — including non-string values such
    as ints or bools — so downstream formatting (e.g. the metric-column
    styling) still sees the original object. For that reason this
    function deliberately carries no ``-> str`` annotation: the return
    type is ``str`` only on the truncation path.
    """
    words = str(content).split()
    if len(words) > 25:
        return " ".join(words[:25]) + "..."
    return content
249279

250280

251-
def configure_dataframe_display(df, metric_name):
281+
def configure_dataframe_display(df, metric_name) -> pd.DataFrame:
252282
"""Set various pandas display options for DataFrame."""
253283
pd.options.display.max_colwidth = None
254284
pd.set_option("display.max_colwidth", 20) # Adjust the number as needed
255285
pd.set_option("display.width", 400) # Adjust
256286

257-
# df[metric_name] = df[metric_name].apply(lambda x: f'✔️ [{x}]' if x is True else f'❌ [{x}]')
258-
# df.loc[:, metric_name] = df[metric_name].apply(lambda x: f"✔️ [{x}]" if x is True else f"{x}")
259287
df[metric_name] = df[metric_name].apply(lambda x: f"✔️ [{x}]" if x else str(x))
260288

261-
262289
# Return styled DataFrame
263290
return df.style.set_table_styles(
264291
[
@@ -276,4 +303,5 @@ def configure_dataframe_display(df, metric_name):
276303

277304

278305
# FIXME: TODO: The merge_dicts stuff above is way too quick and dirty.
279-
# TODO: the display_table can't handle False but can handle 0! Not sure how it works with True exactly, probably fails too.
306+
# TODO: the display_table can't handle False but can handle 0!
307+
# Not sure how it works with True exactly, probably fails too.

tests/evaluate/test_evaluate.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
1-
import dsp, dspy
1+
import signal
2+
import threading
3+
4+
import pytest
5+
6+
import dsp
7+
import dspy
28
from dspy.evaluate.evaluate import Evaluate
39
from dspy.evaluate.metrics import answer_exact_match
410
from dspy.predict import Predict
511
from dspy.utils.dummies import DummyLM
612

13+
714
def new_example(question, answer):
    """Build a dspy.Example with *question* marked as its input field."""
    example = dspy.Example(question=question, answer=answer)
    return example.with_inputs("question")
20+
1321

1422
def test_evaluate_initialization():
1523
devset = [new_example("What is 1+1?", "2")]
@@ -23,6 +31,7 @@ def test_evaluate_initialization():
2331
assert ev.num_threads == len(devset)
2432
assert ev.display_progress == False
2533

34+
2635
def test_evaluate_call():
2736
dspy.settings.configure(lm=DummyLM({"What is 1+1?": "2", "What is 2+2?": "4"}))
2837
devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
@@ -36,6 +45,60 @@ def test_evaluate_call():
3645
score = ev(program)
3746
assert score == 100.0
3847

48+
49+
def test_multithread_evaluate_call():
    """Evaluation with two worker threads scores 100% on two correct answers."""
    dspy.settings.configure(lm=DummyLM({"What is 1+1?": "2", "What is 2+2?": "4"}))
    devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]

    program = Predict("question -> answer")
    # Sanity-check the program before evaluating it.
    assert program(question="What is 1+1?").answer == "2"

    evaluator = Evaluate(
        devset=devset,
        metric=answer_exact_match,
        display_progress=False,
        num_threads=2,
    )
    assert evaluator(program) == 100.0
62+
63+
64+
def test_multi_thread_evaluate_call_cancelled(monkeypatch):
    """A SIGINT delivered mid-evaluation must surface as KeyboardInterrupt."""
    import os
    import time

    class SlowLM(DummyLM):
        # Sleep before answering so jobs are still in flight when SIGINT lands.
        def __call__(self, prompt, **kwargs):
            time.sleep(1)
            return super().__call__(prompt, **kwargs)

    dspy.settings.configure(lm=SlowLM({"What is 1+1?": "2", "What is 2+2?": "4"}))

    devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
    program = Predict("question -> answer")
    assert program(question="What is 1+1?").answer == "2"

    # Background thread: wait briefly, then deliver SIGINT to this process.
    def sleep_then_interrupt():
        time.sleep(0.1)
        os.kill(os.getpid(), signal.SIGINT)

    interrupter = threading.Thread(target=sleep_then_interrupt)
    interrupter.start()

    with pytest.raises(KeyboardInterrupt):
        evaluator = Evaluate(
            devset=devset,
            metric=answer_exact_match,
            display_progress=False,
            num_threads=2,
        )
        score = evaluator(program)
        assert score == 100.0
100+
101+
39102
def test_evaluate_call_bad():
40103
dspy.settings.configure(lm=DummyLM({"What is 1+1?": "0", "What is 2+2?": "0"}))
41104
devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
@@ -48,6 +111,7 @@ def test_evaluate_call_bad():
48111
score = ev(program)
49112
assert score == 0.0
50113

114+
51115
def test_evaluate_display_table():
52116
devset = [new_example("What is 1+1?", "2")]
53117
ev = Evaluate(
@@ -56,4 +120,3 @@ def test_evaluate_display_table():
56120
display_table=True,
57121
)
58122
assert ev.display_table == True
59-

0 commit comments

Comments
 (0)