1+ import contextlib
2+ import signal
13import sys
24import threading
35import types
1416except ImportError :
1517 ipython_display = print
1618
17- def HTML (x ):
19+ def HTML (x ) -> str : # noqa: N802
1820 return x
1921
2022
2123from concurrent .futures import ThreadPoolExecutor , as_completed
2224
23- from dsp .evaluation .utils import *
24-
25- """
26- TODO: Counting failures and having a max_failure count. When that is exceeded (also just at the end),
27- we print the number of failures, the first N examples that failed, and the first N exceptions raised.
28- """
25+ # TODO: Counting failures and having a max_failure count. When that is exceeded (also just at the end),
26+ # we print the number of failures, the first N examples that failed, and the first N exceptions raised.
2927
3028
3129class Evaluate :
@@ -49,11 +47,13 @@ def __init__(
4947 self .max_errors = max_errors
5048 self .error_count = 0
5149 self .error_lock = threading .Lock ()
50+ self .cancel_jobs = threading .Event ()
5251 self .return_outputs = return_outputs
5352
5453 if "display" in _kwargs :
5554 dspy .logger .warning (
56- "DeprecationWarning: 'display' has been deprecated. To see all information for debugging, use 'dspy.set_log_level('debug')'. In the future this will raise an error." ,
55+ "DeprecationWarning: 'display' has been deprecated. To see all information for debugging,"
56+ " use 'dspy.set_log_level('debug')'. In the future this will raise an error." ,
5757 )
5858
5959 def _execute_single_thread (self , wrapped_program , devset , display_progress ):
@@ -78,19 +78,52 @@ def _execute_multi_thread(self, wrapped_program, devset, num_threads, display_pr
7878 ncorrect = 0
7979 ntotal = 0
8080 reordered_devset = []
81-
82- with ThreadPoolExecutor (max_workers = num_threads ) as executor :
83- futures = {executor .submit (wrapped_program , idx , arg ) for idx , arg in devset }
81+ job_cancelled = "cancelled"
82+
83+ # context manger to handle sigint
84+ @contextlib .contextmanager
85+ def interrupt_handler_manager ():
86+ """Sets the cancel_jobs event when a SIGINT is received."""
87+ default_handler = signal .getsignal (signal .SIGINT )
88+
89+ def interrupt_handler (sig , frame ):
90+ self .cancel_jobs .set ()
91+ dspy .logger .warning ("Received SIGINT. Cancelling evaluation." )
92+ default_handler (sig , frame )
93+
94+ signal .signal (signal .SIGINT , interrupt_handler )
95+ yield
96+ # reset to the default handler
97+ signal .signal (signal .SIGINT , default_handler )
98+
99+ def cancellable_wrapped_program (idx , arg ):
100+ # If the cancel_jobs event is set, return the cancelled_job literal
101+ if self .cancel_jobs .is_set ():
102+ return None , None , job_cancelled , None
103+ return wrapped_program (idx , arg )
104+
105+ with ThreadPoolExecutor (max_workers = num_threads ) as executor , interrupt_handler_manager ():
106+ futures = {executor .submit (cancellable_wrapped_program , idx , arg ) for idx , arg in devset }
84107 pbar = tqdm .tqdm (total = len (devset ), dynamic_ncols = True , disable = not display_progress )
85108
86109 for future in as_completed (futures ):
87110 example_idx , example , prediction , score = future .result ()
111+
112+ # use the cancelled_job literal to check if the job was cancelled - use "is" not "=="
113+ # in case the prediction is "cancelled" for some reason.
114+ if prediction is job_cancelled :
115+ continue
116+
88117 reordered_devset .append ((example_idx , example , prediction , score ))
89118 ncorrect += score
90119 ntotal += 1
91120 self ._update_progress (pbar , ncorrect , ntotal )
92121 pbar .close ()
93122
123+ if self .cancel_jobs .is_set ():
124+ dspy .logger .warning ("Evaluation was cancelled. The results may be incomplete." )
125+ raise KeyboardInterrupt
126+
94127 return reordered_devset , ncorrect , ntotal
95128
96129 def _update_progress (self , pbar , ncorrect , ntotal ):
@@ -175,24 +208,21 @@ def wrapped_program(example_idx, example):
175208 merge_dicts (example , prediction ) | {"correct" : score } for _ , example , prediction , score in predicted_devset
176209 ]
177210
178- df = pd .DataFrame (data )
211+ result_df = pd .DataFrame (data )
179212
180- # Truncate every cell in the DataFrame
181- if hasattr (df , "map" ): # DataFrame.applymap was renamed to DataFrame.map in Pandas 2.1.0
182- df = df .map (truncate_cell )
183- else :
184- df = df .applymap (truncate_cell )
213+ # Truncate every cell in the DataFrame (DataFrame.applymap was renamed to DataFrame.map in Pandas 2.1.0)
214+ result_df = result_df .map (truncate_cell ) if hasattr (result_df , "map" ) else result_df .applymap (truncate_cell )
185215
186216 # Rename the 'correct' column to the name of the metric object
187217 metric_name = metric .__name__ if isinstance (metric , types .FunctionType ) else metric .__class__ .__name__
188- df .rename (columns = {"correct" : metric_name }, inplace = True )
218+ result_df = result_df .rename (columns = {"correct" : metric_name })
189219
190220 if display_table :
191221 if isinstance (display_table , int ):
192- df_to_display = df .head (display_table ).copy ()
193- truncated_rows = len (df ) - display_table
222+ df_to_display = result_df .head (display_table ).copy ()
223+ truncated_rows = len (result_df ) - display_table
194224 else :
195- df_to_display = df .copy ()
225+ df_to_display = result_df .copy ()
196226 truncated_rows = 0
197227
198228 styled_df = configure_dataframe_display (df_to_display , metric_name )
@@ -215,15 +245,15 @@ def wrapped_program(example_idx, example):
215245
216246 if return_all_scores and return_outputs :
217247 return round (100 * ncorrect / ntotal , 2 ), results , [score for * _ , score in predicted_devset ]
218- elif return_all_scores :
248+ if return_all_scores :
219249 return round (100 * ncorrect / ntotal , 2 ), [score for * _ , score in predicted_devset ]
220- elif return_outputs :
250+ if return_outputs :
221251 return round (100 * ncorrect / ntotal , 2 ), results
222252
223253 return round (100 * ncorrect / ntotal , 2 )
224254
225255
226- def merge_dicts (d1 , d2 ):
256+ def merge_dicts (d1 , d2 ) -> dict :
227257 merged = {}
228258 for k , v in d1 .items ():
229259 if k in d2 :
@@ -240,25 +270,22 @@ def merge_dicts(d1, d2):
240270 return merged
241271
242272
243- def truncate_cell (content ):
273+ def truncate_cell (content ) -> str :
244274 """Truncate content of a cell to 25 words."""
245275 words = str (content ).split ()
246276 if len (words ) > 25 :
247277 return " " .join (words [:25 ]) + "..."
248278 return content
249279
250280
251- def configure_dataframe_display (df , metric_name ):
281+ def configure_dataframe_display (df , metric_name ) -> pd . DataFrame :
252282 """Set various pandas display options for DataFrame."""
253283 pd .options .display .max_colwidth = None
254284 pd .set_option ("display.max_colwidth" , 20 ) # Adjust the number as needed
255285 pd .set_option ("display.width" , 400 ) # Adjust
256286
257- # df[metric_name] = df[metric_name].apply(lambda x: f'✔️ [{x}]' if x is True else f'❌ [{x}]')
258- # df.loc[:, metric_name] = df[metric_name].apply(lambda x: f"✔️ [{x}]" if x is True else f"{x}")
259287 df [metric_name ] = df [metric_name ].apply (lambda x : f"✔️ [{ x } ]" if x else str (x ))
260288
261-
262289 # Return styled DataFrame
263290 return df .style .set_table_styles (
264291 [
@@ -276,4 +303,5 @@ def configure_dataframe_display(df, metric_name):
276303
277304
278305# FIXME: TODO: The merge_dicts stuff above is way too quick and dirty.
279- # TODO: the display_table can't handle False but can handle 0! Not sure how it works with True exactly, probably fails too.
306+ # TODO: the display_table can't handle False but can handle 0!
307+ # Not sure how it works with True exactly, probably fails too.
0 commit comments