diff --git a/lmms_eval/tasks/spatialviz/_default_template_yaml b/lmms_eval/tasks/spatialviz/_default_template_yaml
new file mode 100644
index 000000000..d59d3e396
--- /dev/null
+++ b/lmms_eval/tasks/spatialviz/_default_template_yaml
@@ -0,0 +1,28 @@
+dataset_path: PLM-Team/Spatial-Visualization-Benchmark
+
+generation_kwargs:
+  max_new_tokens: 8192
+  temperature: 0.0
+  top_p: 1.0
+  num_beams: 1
+  do_sample: false
+
+output_type: generate_until
+doc_to_visual: !function utils.spatialviz_doc_to_visual
+doc_to_text: !function utils.spatialviz_doc_to_text
+doc_to_target: utils.spatialviz_doc_to_target
+process_results: !function utils.spatialviz_process_results
+
+metric_list:
+  - metric: spatialviz_score
+    aggregation: !function utils.spatialviz_aggregate_results
+    higher_is_better: true
+
+dataset_kwargs:
+  token: True
+  cache_dir: SpatialViz
+  force_download: true
+
+# metadata:
+#   strategy: CoT  # ['Direct', 'CoT']
+#   use_lmms_judge: False
\ No newline at end of file
diff --git a/lmms_eval/tasks/spatialviz/spatialviz.yaml b/lmms_eval/tasks/spatialviz/spatialviz.yaml
new file mode 100644
index 000000000..4ad96ee07
--- /dev/null
+++ b/lmms_eval/tasks/spatialviz/spatialviz.yaml
@@ -0,0 +1,5 @@
+# Official SpatialViz paper uses this configuration
+dataset_name:
+test_split: test
+task: "spatialviz_full"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/spatialviz/utils.py b/lmms_eval/tasks/spatialviz/utils.py
new file mode 100644
index 000000000..fa248f451
--- /dev/null
+++ b/lmms_eval/tasks/spatialviz/utils.py
@@ -0,0 +1,145 @@
+import os
+import re
+from collections import defaultdict
+from pathlib import Path
+
+import yaml
+from huggingface_hub import snapshot_download
+from PIL import Image
+
+with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
+    raw_data = f.readlines()
+    safe_data = []
+    for i, line in enumerate(raw_data):
+        # remove function definition since yaml load cannot handle it
+        if "!function" not in line:
+            safe_data.append(line)
+    config = yaml.safe_load("".join(safe_data))
+
+cache_dir = snapshot_download(
+    repo_id=config["dataset_path"],
+    repo_type="dataset",
+    local_dir_use_symlinks=False,
+)
+
+
+def spatialviz_doc_to_visual(doc):
+    visual = []
+
+    category = doc["Category"]
+    task = doc["Task"]
+    level = doc["Level"]
+    image_id = doc["Image_id"]
+    image_path = f"{cache_dir}/{category}/{task}/{level}/{image_id}.png"
+
+    if os.path.exists(image_path):
+        visual.append(Image.open(image_path).convert("RGB"))
+    else:
+        raise FileNotFoundError(f"image path: {image_path} does not exist.")
+    return visual
+
+
+def spatialviz_doc_to_text(doc):
+    ops = ["A", "B", "C", "D"]
+    prompt = "You should first provide a reasoning process, then provide a single option(A, B, C or D) as the final answer. The reasoning process and the answer are enclosed within <think></think> and <answer></answer> tags, respectively, i.e., <think>reasoning process</think>, <answer>answer</answer>.\n"
+    question = doc["Question"]
+    choices = doc["Choices"]
+    choice_text = ""
+
+    for i, choice in enumerate(choices):
+        choice_text += ops[i] + ". " + choice + "\n"
+    text = prompt + "Question: " + question + "\n" + choice_text
+    return text
+
+
+def spatialviz_process_results(doc, results):
+    key_name = "spatialviz_score"
+    grounded_output = doc["Answer"]
+    response = results[0]
+
+    think_pattern = r"<think>(.*?)</think>"
+    answer_pattern = r"<answer>(.*?)</answer>"
+
+    think_match = re.search(think_pattern, response, re.DOTALL)
+    answer_match = re.search(answer_pattern, response, re.DOTALL)
+    op = []  # option letters extracted from the response; stays empty if nothing matches
+    if think_match and answer_match:
+        final_answer = answer_match.group(1).strip()
+        pred_answer = final_answer.split(".")[0]
+        op = re.findall(r"[A-D]", pred_answer)
+
+    else:
+        print("No match for <think>/<answer> tags\n")
+        final_answer_patterns = ["</think>", "Answer:", "Final answer", "final answer", "Final Answer", "the answer is", "The answer is", "correct answer", "Correct answer", "Correct Answer", "答案", "correct path"]
+        if len(response) == 1:
+            op = re.findall(r"[A-D]", response)
+        else:
+            for pattern in final_answer_patterns:
+                if pattern in response:
+                    response = response.split(pattern)[-1].strip()
+                    op = re.findall(r"[A-D]", response.split(".")[0])
+                    break
+
+    op = list(set(op))
+
+    if len(op) == 1 and grounded_output == op[0].upper():
+        is_correct = True
+    else:
+        is_correct = False
+
+    query = spatialviz_doc_to_text(doc)
+    spatialviz_submission = {"id": doc["Image_id"], "query": query, "gt_content": grounded_output, "pred": response, "category": doc["Category"], "task": doc["Task"], "level": doc["Level"], "is_correct": is_correct}
+    return {key_name: spatialviz_submission}
+
+
+def spatialviz_aggregate_results(results):
+    task_to_eval_samples = defaultdict(list)
+    category_to_eval_samples = defaultdict(list)
+    key_to_eval_samples = defaultdict(list)
+    total_samples = len(results)
+    total_correct = 0
+
+    for sample in results:
+        task = sample["task"]
+        category = sample["category"]
+        level = sample["level"]
+        key = f"{category}-{task}-{level}"
+        is_correct = sample["is_correct"]
+
+        if is_correct:
+            total_correct += 1
+            task_to_eval_samples[task].append(1)
+            category_to_eval_samples[category].append(1)
+            key_to_eval_samples[key].append(1)
+        else:
+            task_to_eval_samples[task].append(0)
+            category_to_eval_samples[category].append(0)
+            key_to_eval_samples[key].append(0)
+
+    accuracy = total_correct / total_samples if total_samples > 0 else 0
+    task_accuracies = {task: sum(scores) / len(scores) for task, scores in task_to_eval_samples.items()}
+    category_accuracies = {category: sum(scores) / len(scores) for category, scores in category_to_eval_samples.items()}
+    key_accuracies = {key: sum(scores) / len(scores) for key, scores in key_to_eval_samples.items()}
+    print(f"{'Total Samples':<20}: {total_samples}")
+    print(f"{'Total Correct':<20}: {total_correct}")
+    print(f"{'Overall Accuracy':<20}: {accuracy:.4f}")
+    print()
+
+    print(f"{'Per-Task Accuracy':<40}")
+    print("-" * 40)
+    for task, acc in task_accuracies.items():
+        print(f"{task:<20}: {acc:.4f}")
+    print()
+
+    print(f"{'Per-Category Accuracy':<40}")
+    print("-" * 40)
+    for category, acc in category_accuracies.items():
+        print(f"{category:<20}: {acc:.4f}")
+    print("=" * 40)
+
+    print(f"{'Per-Key Accuracy':<40}")
+    print("-" * 40)
+    for key, acc in key_accuracies.items():
+        print(f"{key:<20}: {acc:.4f}")
+    print()
+    return accuracy