 import rai_bench.manipulation_o3de as manipulation_o3de
 import rai_bench.tool_calling_agent as tool_calling_agent
 import rai_bench.vlm_benchmark as vlm_benchmark
-from rai_bench.base_benchmark import ModelSummary, RunSummary
-from rai_bench.results_processing.data_loading import SUMMARY_FILE_NAME
+from rai_bench.base_benchmark import ModelSummary, RunSummary, TasksSummary
+from rai_bench.results_processing.data_loading import (
+    DETAILED_FILE_NAME,
+    SUMMARY_FILE_NAME,
+)
 from rai_bench.utils import (
     define_benchmark_logger,
     get_llm_for_benchmark,
     get_llm_model_name,
 )
 
 REPEATS_SUMMARY_FILE_NAME = "repeats_summary.csv"
+TASKS_SUMMARY_FILE_NAME = "tasks_summary.csv"
 BENCHMARK_SUMMARY = "benchmark_summary.csv"
 
 
@@ -151,7 +155,7 @@ def merge_model_repeats_summary(
 
     merged_file = model_dir / REPEATS_SUMMARY_FILE_NAME
     with open(merged_file, "w", newline="") as f:
-        writer = csv.DictWriter(f, fieldnames=RunSummary.model_fields.keys())
+        writer = csv.DictWriter(f, fieldnames=ModelSummary.model_fields.keys())
         writer.writeheader()
         writer.writerow(merged_summary.model_dump())
 
@@ -174,7 +178,7 @@ def merge_benchmark_summary(
     if not bench_dir.exists():
         return
 
-    all_summaries: List[RunSummary] = []
+    all_summaries: List[ModelSummary] = []
     for model_name in model_names:
         model_dir = bench_dir / model_name
         merged_file = model_dir / REPEATS_SUMMARY_FILE_NAME
@@ -183,19 +187,89 @@ def merge_benchmark_summary(
         with open(merged_file, "r") as f:
             reader = csv.DictReader(f)
             for row in reader:
-                all_summaries.append(RunSummary.model_validate(row))
+                all_summaries.append(ModelSummary.model_validate(row))
 
     if not all_summaries:
         return
 
     benchmark_summary_file = bench_dir / BENCHMARK_SUMMARY
     with open(benchmark_summary_file, "w", newline="") as f:
-        writer = csv.DictWriter(f, fieldnames=RunSummary.model_fields.keys())
+        writer = csv.DictWriter(f, fieldnames=ModelSummary.model_fields.keys())
         writer.writeheader()
         for summary in all_summaries:
             writer.writerow(summary.model_dump())
 
 
+def merge_tasks_summary(bench_name: str, model_name: str, run_dir: Path) -> None:
+    """Merge task results across all repeats for a single model, aggregating by task.
+
+    Parameters
+    ----------
+    bench_name : str
+        Name of the benchmark
+    model_name : str
+        Name of the model
+    run_dir : Path
+        Directory containing the benchmark run results
+    """
+    model_dir = run_dir / bench_name / model_name
+    if not model_dir.exists():
+        return
+
+    # Collect all task results from all repeats
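+    # Maps task prompt -> {"scores": [...], "times": [...]} gathered across repeats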
+    task_data_by_prompt: Dict[str, Dict[str, List[float]]] = {}
+
+    for repeat_dir in model_dir.iterdir():
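+        # Only digit-named subdirectories count as repeat runs (e.g. "0", "1", ...)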
+        if repeat_dir.is_dir() and repeat_dir.name.isdigit():
+            results_file = repeat_dir / DETAILED_FILE_NAME
+            if results_file.exists():
+                # Read detailed results from this repeat
+                with open(results_file, "r") as f:
+                    reader = csv.DictReader(f)
+                    for row in reader:
+                        task_prompt = row["task_prompt"]
+                        score = float(row["score"])
+                        total_time = float(row["total_time"])
+
+                        if task_prompt not in task_data_by_prompt:
+                            task_data_by_prompt[task_prompt] = {
+                                "scores": [],
+                                "times": [],
+                            }
+
+                        task_data_by_prompt[task_prompt]["scores"].append(score)
+                        task_data_by_prompt[task_prompt]["times"].append(total_time)
+
+    if not task_data_by_prompt:
+        return
+
+    # Calculate statistics for each task
+    task_summaries: List[TasksSummary] = []
+    for task_prompt, data in task_data_by_prompt.items():
+        scores = np.array(data["scores"])
+        times = np.array(data["times"])
+
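+        # np.ndarray.std() defaults to the population standard deviation (ddof=0)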
+        task_summary = TasksSummary(
+            model_name=model_name,
+            task_prompt=task_prompt,
+            avg_success_rate=round(float(scores.mean()), 3),
+            std_success_rate=round(float(scores.std()), 3),
+            avg_time=round(float(times.mean()), 3),
+            std_time=round(float(times.std()), 3),
+            repeats=len(scores),  # TODO (mkotynia): extract repeats in another way
+        )
+        task_summaries.append(task_summary)
+
+    # Save task summaries to CSV
+    tasks_summary_file = model_dir / TASKS_SUMMARY_FILE_NAME
+    with open(tasks_summary_file, "w", newline="") as f:
+        if task_summaries:
+            writer = csv.DictWriter(f, fieldnames=TasksSummary.model_fields.keys())
+            writer.writeheader()
+            for task_summary in task_summaries:
+                writer.writerow(task_summary.model_dump())
+
+
 def test_dual_agents(
     multimodal_llms: List[BaseChatModel],
     tool_calling_models: List[BaseChatModel],
@@ -351,6 +425,7 @@ def test_models(
 
     for model_name in model_names:
         merge_model_repeats_summary(bench_conf.name, model_name, run_dir)
+        merge_tasks_summary(bench_conf.name, model_name, run_dir)
 
     merge_benchmark_summary(bench_conf.name, run_dir, model_names)
 