Skip to content

Commit ced6231

Browse files
authored
Add missing eval_mm_quality.sh (#3204)
Summary: This file was accidentally left out of #3133; adding it back. Test Plan: tested locally with the same commands ``` sh eval.sh --eval_type mm_quality --model_ids google/gemma-3-12b-it --mm_tasks chartqa --model_type gemma3 --mm_eval_batch_size 32 sh eval.sh --eval_type mm_quality --model_ids google/gemma-3-12b-it --mm_tasks chartqa --model_type gemma3 --mm_eval_batch_size 32 --use_cache sh summarize_results.sh --model_ids google/gemma-3-12b-it ``` Reviewers: Subscribers: Tasks: Tags:
1 parent 2d8a4c1 commit ced6231

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Multi-modal model quality evaluation driver: for every requested model and
# task it launches `lmms_eval` through `accelerate launch` and stores the
# output in a per-model/per-task log file.
#
# The shebang must be the very first line of the file; placed after the
# license header it is only a comment and does not select the interpreter.
# This script relies on bash-only features ([[ ]], arrays, =~).

# Abort as soon as any command fails.
set -e

# Environment checks live in a sibling file that is looked up by `source`
# (not shown here); presumably the script is meant to be run from the
# directory containing eval_env_checks.sh -- TODO confirm.
source eval_env_checks.sh
# check_lmms_eval is expected to come from eval_env_checks.sh; presumably it
# verifies that lmms_eval is available -- confirm against that file.
check_lmms_eval
11+
12+
#######################################
# Print the command-line synopsis and terminate the script.
# The synopsis mirrors the flags accepted by the argument parser:
#   --model_ids   one or more space-separated model ids (required)
#   --model_type  model type passed to lmms_eval via --model (required)
#   --tasks       one or more space-separated task names (default: chartqa)
#   --batch_size  evaluation batch size (default: 1)
#   --use_cache   enable the lmms_eval cache
# Arguments: none
# Outputs:   writes the usage line to stdout
# Returns:   never returns; exits with status 1
#######################################
usage() {
  echo "Usage: $0 --model_ids <model_id> [<model_id> ...] --model_type <model_type> [--tasks <task> [<task> ...] (space-separated, e.g. chartqa, default chartqa)] [--batch_size <batch_size> (default 1)] [--use_cache]"
  exit 1
}
16+
17+
# Option defaults. Every value below can be replaced by the matching
# command-line flag during argument parsing.

# List-valued options.
MODEL_ID_ARRAY=()       # filled by --model_ids; empty means "not provided"
TASK_ARRAY=(chartqa)    # replaced wholesale when --tasks is given

# Scalar options.
MODEL_TYPE=''           # set by --model_type; empty means "not provided"
BATCH_SIZE='1'          # set by --batch_size
USE_CACHE='false'       # switched to true by --use_cache
22+
#######################################
# Ensure the flag in $1 is followed by a value.
# Without this check `shift 2` fails when the flag is the last argument and
# the script dies silently under `set -e`.
# Arguments: the remaining command-line words, flag first
# Outputs:   error message and usage text on failure
# Returns:   0 when a value is present; otherwise exits with status 1
#######################################
_require_value() {
  if [[ $# -lt 2 ]]; then
    echo "Error: $1 requires a value"
    usage
    exit 1
  fi
}

#######################################
# Parse command-line flags into the option globals.
# Globals:   MODEL_ID_ARRAY, TASK_ARRAY (appended / replaced),
#            MODEL_TYPE, BATCH_SIZE, USE_CACHE (written)
# Arguments: the script's command-line arguments
# Outputs:   error message and usage text on an unknown flag or missing value
# Returns:   0 on success; exits with status 1 on invalid input
#######################################
parse_args() {
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --model_ids)
        shift
        # Consume every following word up to the next "--flag".
        while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
          MODEL_ID_ARRAY+=("$1")
          shift
        done
        ;;
      --model_type)
        _require_value "$@"
        MODEL_TYPE="$2"
        shift 2
        ;;
      --batch_size)
        _require_value "$@"
        BATCH_SIZE="$2"
        shift 2
        ;;
      --tasks)
        shift
        # An explicit --tasks replaces the default task list entirely.
        TASK_ARRAY=()
        while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
          TASK_ARRAY+=("$1")
          shift
        done
        ;;
      --use_cache)
        USE_CACHE=true
        shift
        ;;
      *)
        echo "Unknown argument: $1"
        usage
        exit 1
        ;;
    esac
  done
}

parse_args "$@"
59+
# Required-option guards: stop with an error plus the usage text when a
# mandatory flag was not supplied. The model-id check runs first.

# At least one model id must have been collected from --model_ids.
[[ ${#MODEL_ID_ARRAY[@]} -gt 0 ]] || {
  echo "Error: --model_ids is required"
  usage
  exit 1
}

# --model_type must have been given a non-empty value.
[[ -n "$MODEL_TYPE" ]] || {
  echo "Error: --model_type is required"
  usage
  exit 1
}
69+
# Directory handed to lmms_eval via --output_path (relative to where the
# script is started).
RESULTS_DIR="$(pwd)/mm_quality_eval_results"
# Port passed to `accelerate launch --main_process_port`; the same value is
# used for every run, so it is set once here.
MAIN_PORT=12356

#######################################
# Run lmms_eval for every (model, task) combination.
# Globals:   MODEL_ID_ARRAY, TASK_ARRAY, MODEL_TYPE, BATCH_SIZE, USE_CACHE,
#            RESULTS_DIR, MAIN_PORT (all read)
# Arguments: none
# Outputs:   progress banners on stdout; the evaluation's stdout and stderr
#            go to <cwd>/<safe_model_id>_mm_quality_<task>.log
# Returns:   0 on success; the status of the failing command otherwise
#######################################
run_mm_quality_eval() {
  local model_id safe_model_id task output_file cache_prefix
  local -a eval_cmd

  for model_id in "${MODEL_ID_ARRAY[@]}"; do
    # Replace all '/' with '_' so the id is usable inside a file name.
    safe_model_id="${model_id//\//_}"
    echo "======================== Eval Multi-modal Model Quality $model_id ======================"
    for task in "${TASK_ARRAY[@]}"; do
      output_file="$(pwd)/${safe_model_id}_mm_quality_${task}.log"
      # Kept at a fixed, name-derived location so that a later run with
      # --use_cache passes the same path again.
      cache_prefix="/tmp/${safe_model_id}_mm_quality_${task}"
      mkdir -p "${cache_prefix}"
      echo "Running multi-modal model quality (accuracy) evaluation for model $model_id on task $task"

      # Build the command as an array (not a string run through `eval`) so
      # user-supplied values are passed verbatim as single arguments and are
      # never re-parsed by the shell.
      eval_cmd=(
        accelerate launch
        --main_process_port "$MAIN_PORT"
        -m lmms_eval
        --model "$MODEL_TYPE"
        --model_args "pretrained=$model_id"
        --tasks "$task"
        --batch_size "$BATCH_SIZE"
        --output_path "$RESULTS_DIR"
      )

      if [[ "$USE_CACHE" == true ]]; then
        eval_cmd+=(--use_cache "$cache_prefix")
      fi

      "${eval_cmd[@]}" > "$output_file" 2>&1
      echo "Quality eval output for task '$task' saved to $output_file"
    done
    echo "======================== Eval Model Quality $model_id End =================="
  done
}

run_mm_quality_eval

0 commit comments

Comments
 (0)