28 changes: 28 additions & 0 deletions lmms_eval/tasks/spatialviz/_default_template_yaml
@@ -0,0 +1,28 @@
dataset_path: PLM-Team/Spatial-Visualization-Benchmark

generation_kwargs:
  max_new_tokens: 8192
  temperature: 0.0
  top_p: 1.0
  num_beams: 1
  do_sample: false

output_type: generate_until
doc_to_visual: !function utils.spatialviz_doc_to_visual
doc_to_text: !function utils.spatialviz_doc_to_text
doc_to_target: utils.spatialviz_doc_to_target
process_results: !function utils.spatialviz_process_results

metric_list:
  - metric: spatialviz_score
    aggregation: !function utils.spatialviz_aggregate_results
    higher_is_better: true

dataset_kwargs:
  token: True
  cache_dir: SpatialViz
  force_download: true

# metadata:
#   strategy: CoT  # ['Direct', 'CoT']
#   use_lmms_judge: False
5 changes: 5 additions & 0 deletions lmms_eval/tasks/spatialviz/spatialviz.yaml
@@ -0,0 +1,5 @@
# The official SpatialViz paper uses this configuration
dataset_name:
test_split: test
task: "spatialviz_full"
include: _default_template_yaml
145 changes: 145 additions & 0 deletions lmms_eval/tasks/spatialviz/utils.py
@@ -0,0 +1,145 @@
import os
import re
from collections import defaultdict
from pathlib import Path

import yaml
from huggingface_hub import snapshot_download
from PIL import Image

with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
    raw_data = f.readlines()
    safe_data = []
    for i, line in enumerate(raw_data):
        # remove function definition since yaml load cannot handle it
        if "!function" not in line:
            safe_data.append(line)
    config = yaml.safe_load("".join(safe_data))

cache_dir = snapshot_download(
    repo_id=config["dataset_path"],
    repo_type="dataset",
    local_dir_use_symlinks=False,
)
Comment on lines +19 to +23 (Collaborator):
May I ask whether this requires the user to unzip the zip file themselves? I saw that on the hub the images are in a zip and video in dataset_kwargs is set to false. I think the best practice for this kind of dataset is to push a converted dataset to the hub. I am not sure whether the cache dir works automatically under the current settings. Thanks!
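
If the images do ship as zip archives inside the dataset repo, one possible answer to this question (not something the PR currently does) is to unpack them right after snapshot_download so that the {cache_dir}/{Category}/{Task}/{Level}/{Image_id}.png paths used below resolve. A minimal sketch, assuming the archives live somewhere under the downloaded repo and expand in place next to themselves; the archive layout is an assumption:

import zipfile
from pathlib import Path

from huggingface_hub import snapshot_download

# Download the raw dataset repo; any zip archives arrive unextracted.
cache_dir = snapshot_download(
    repo_id="PLM-Team/Spatial-Visualization-Benchmark",
    repo_type="dataset",
)

# Unpack each archive next to itself so the per-image paths resolve.
# ASSUMPTION: each zip expands into a folder named after the archive;
# adjust the target if the images are packaged differently on the hub.
for archive in Path(cache_dir).rglob("*.zip"):
    extracted_dir = archive.with_suffix("")
    if not extracted_dir.exists():
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(archive.parent)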



def spatialviz_doc_to_visual(doc):
    visual = []

    category = doc["Category"]
    task = doc["Task"]
    level = doc["Level"]
    image_id = doc["Image_id"]
    image_path = f"{cache_dir}/{category}/{task}/{level}/{image_id}.png"

    if os.path.exists(image_path):
        visual.append(Image.open(image_path).convert("RGB"))
    else:
        raise FileNotFoundError(f"image path: {image_path} does not exist.")
    return visual


def spatialviz_doc_to_text(doc):
    ops = ["A", "B", "C", "D"]
    prompt = "You should first provide a reasoning process, then provide a single option(A, B, C or D) as the final answer. The reasoning process and the answer are enclosed within <think></think> and <answer></answer> tags, respectively, i.e., <think>reasoning process</think>, <answer>answer</answer>.\n"
    question = doc["Question"]
    choices = doc["Choices"]
    choice_text = ""

    for i, choice in enumerate(choices):
        choice_text += ops[i] + ". " + choice + "\n"
    text = prompt + "Question: " + question + "\n" + choice_text
    return text


def spatialviz_process_results(doc, results):
    key_name = "spatialviz_score"
    grounded_output = doc["Answer"]
    response = results[0]

    think_pattern = r"<think>(.*?)</think>"
    answer_pattern = r"<answer>(.*?)</answer>"

    think_match = re.search(think_pattern, response, re.DOTALL)
    answer_match = re.search(answer_pattern, response, re.DOTALL)
    op = []
    if think_match and answer_match:
        final_answer = answer_match.group(1).strip()
        pred_answer = final_answer.split(".")[0]
        op = re.findall(r"[A-D]", pred_answer)

    else:
        # fall back to common answer markers when the <think>/<answer> tags are missing
        print("No match for think/answer \n")
        final_answer_patterns = ["<answer>", "Answer:", "Final answer", "final answer", "Final Answer", "the answer is", "The answer is", "correct answer", "Correct answer", "Correct Answer", "答案", "correct path"]
        if len(response) == 1:
            op = re.findall(r"[A-D]", response)
        else:
            for pattern in final_answer_patterns:
                if pattern in response:
                    response = response.split(pattern)[-1].strip()
                    op = re.findall(r"[A-D]", response.split(".")[0])
                    break

    op = list(set(op))

    if len(op) == 1 and grounded_output == op[0].upper():
        is_correct = True
    else:
        is_correct = False

    query = spatialviz_doc_to_text(doc)
    spatialviz_submission = {"id": doc["Image_id"], "query": query, "gt_content": grounded_output, "pred": response, "category": doc["Category"], "task": doc["Task"], "level": doc["Level"], "is_correct": is_correct}
    return {key_name: spatialviz_submission}
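
To illustrate the parsing path above, here is a small, self-contained check (not part of the PR); the doc fields mirror the ones the code reads, but the concrete values are made up:

# Illustrative only: a minimal doc with the fields used above; values are invented.
fake_doc = {
    "Image_id": "img_001",
    "Category": "MentalRotation",
    "Task": "2DRotation",
    "Level": "Level1",
    "Question": "Which option shows the rotated shape?",
    "Choices": ["shape 1", "shape 2", "shape 3", "shape 4"],
    "Answer": "B",
}
fake_response = "<think>Rotating the figure 90 degrees matches option B.</think><answer>B</answer>"

result = spatialviz_process_results(fake_doc, [fake_response])
print(result["spatialviz_score"]["is_correct"])  # True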


def spatialviz_aggregate_results(results):
    task_to_eval_samples = defaultdict(list)
    category_to_eval_samples = defaultdict(list)
    key_to_eval_samples = defaultdict(list)
    total_samples = len(results)
    total_correct = 0

    for sample in results:
        task = sample["task"]
        category = sample["category"]
        level = sample["level"]
        key = f"{category}-{task}-{level}"
        is_correct = sample["is_correct"]

        if is_correct:
            total_correct += 1
            task_to_eval_samples[task].append(1)
            category_to_eval_samples[category].append(1)
            key_to_eval_samples[key].append(1)
        else:
            task_to_eval_samples[task].append(0)
            category_to_eval_samples[category].append(0)
            key_to_eval_samples[key].append(0)

    accuracy = total_correct / total_samples if total_samples > 0 else 0
    task_accuracies = {task: sum(scores) / len(scores) for task, scores in task_to_eval_samples.items()}
    category_accuracies = {category: sum(scores) / len(scores) for category, scores in category_to_eval_samples.items()}
    key_accuracies = {key: sum(scores) / len(scores) for key, scores in key_to_eval_samples.items()}
    print(f"{'Total Samples':<20}: {total_samples}")
    print(f"{'Total Correct':<20}: {total_correct}")
    print(f"{'Overall Accuracy':<20}: {accuracy:.4f}")
    print()

    print(f"{'Per-Task Accuracy':<40}")
    print("-" * 40)
    for task, acc in task_accuracies.items():
        print(f"{task:<20}: {acc:.4f}")
    print()

    print(f"{'Per-Category Accuracy':<40}")
    print("-" * 40)
    for category, acc in category_accuracies.items():
        print(f"{category:<20}: {acc:.4f}")
    print("=" * 40)

    print(f"{'Per-Key Accuracy':<40}")
    print("-" * 40)
    for key, acc in key_accuracies.items():
        print(f"{key:<20}: {acc:.4f}")
    print()
    return accuracy
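
And a similarly synthetic check of the aggregation (again, not part of the PR; the category, task, and level strings are placeholders): feeding two fake per-sample dicts with the keys spatialviz_process_results emits should report an overall accuracy of 0.5 along with the per-task, per-category, and per-key breakdowns.

# Illustrative only: two synthetic samples, one correct and one incorrect.
fake_results = [
    {"id": "img_001", "query": "...", "gt_content": "A", "pred": "<answer>A</answer>",
     "category": "MentalRotation", "task": "2DRotation", "level": "Level1", "is_correct": True},
    {"id": "img_002", "query": "...", "gt_content": "B", "pred": "<answer>C</answer>",
     "category": "MentalFolding", "task": "PaperFolding", "level": "Level1", "is_correct": False},
]
overall = spatialviz_aggregate_results(fake_results)
print(overall)  # 0.5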