|
| 1 | +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Metric evaluation for Fastdeploy + ERNIE-4.5-Turbo""" |
| 16 | +# adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py |
| 17 | +import argparse |
| 18 | +import ast |
| 19 | +import json |
| 20 | +import re |
| 21 | +import time |
| 22 | +from concurrent.futures import ThreadPoolExecutor |
| 23 | + |
| 24 | +import numpy as np |
| 25 | +import requests |
| 26 | +from tqdm import tqdm |
| 27 | + |
| 28 | +INVALID = -9999999 |
| 29 | + |
| 30 | + |
| 31 | +def call_generate(prompt, **kwargs): |
| 32 | + """ |
| 33 | + Generates response based on the input prompt. |
| 34 | +
|
| 35 | + Args: |
| 36 | + prompt (str): The input prompt text. |
| 37 | + **kwargs: Keyword arguments, including server IP address and port number. |
| 38 | +
|
| 39 | + Returns: |
| 40 | + str: The response generated based on the prompt. |
| 41 | +
|
| 42 | + """ |
| 43 | + url = f"http://{kwargs['ip']}:{kwargs['port']}/v1/chat/completions" |
| 44 | + headers = {"Content-Type": "application/json"} |
| 45 | + data = { |
| 46 | + "messages": [ |
| 47 | + { |
| 48 | + "role": "user", |
| 49 | + "content": prompt, |
| 50 | + } |
| 51 | + ], |
| 52 | + "temperature": 0.6, |
| 53 | + "max_tokens": 2047, |
| 54 | + "top_p": 0.95, |
| 55 | + "do_sample": True, |
| 56 | + } |
| 57 | + |
| 58 | + response = requests.post(url, headers=headers, data=json.dumps(data)) |
| 59 | + out = response.json() |
| 60 | + return out["choices"][0]["message"]["content"] |
| 61 | + |
| 62 | + |
| 63 | +def get_one_example(lines, i, include_answer): |
| 64 | + """ |
| 65 | + Retrieves a question-answer example from the given list of text lines. |
| 66 | +
|
| 67 | + Args: |
| 68 | + lines (list of dict): A list of question-answer pairs. |
| 69 | + i (int): The index of the question-answer pair to retrieve from lines. |
| 70 | + include_answer (bool): Whether to include the answer in the returned string. |
| 71 | +
|
| 72 | + Returns: |
| 73 | + str: A formatted question-answer string in the format "Question: <question>\nAnswer: <answer>". |
| 74 | +
|
| 75 | + """ |
| 76 | + ret = "Question: " + lines[i]["question"] + "\nAnswer:" |
| 77 | + if include_answer: |
| 78 | + ret += " " + lines[i]["answer"] |
| 79 | + return ret |
| 80 | + |
| 81 | + |
| 82 | +def get_few_shot_examples(lines, k): |
| 83 | + """ |
| 84 | + Selects k examples from the given list of text lines and concatenates them into a single string. |
| 85 | +
|
| 86 | + Args: |
| 87 | + lines (list): A list containing text lines. |
| 88 | + k (int): The number of examples to select. |
| 89 | +
|
| 90 | + Returns: |
| 91 | + str: A string composed of k examples, separated by two newline characters. |
| 92 | + """ |
| 93 | + ret = "" |
| 94 | + for i in range(k): |
| 95 | + ret += get_one_example(lines, i, True) + "\n\n" |
| 96 | + return ret |
| 97 | + |
| 98 | + |
| 99 | +def get_answer_value(answer_str): |
| 100 | + """ |
| 101 | + Extracts numerical values from an answer string and returns them. |
| 102 | +
|
| 103 | + Args: |
| 104 | + answer_str (str): The string containing the answer. |
| 105 | +
|
| 106 | + Returns: |
| 107 | + The extracted numerical value; returns "INVALID" if extraction fails. |
| 108 | + """ |
| 109 | + answer_str = answer_str.replace(",", "") |
| 110 | + numbers = re.findall(r"\d+", answer_str) |
| 111 | + if len(numbers) < 1: |
| 112 | + return INVALID |
| 113 | + try: |
| 114 | + return ast.literal_eval(numbers[-1]) |
| 115 | + except SyntaxError: |
| 116 | + return INVALID |
| 117 | + |
| 118 | + |
| 119 | +def read_jsonl(filename: str): |
| 120 | + """ |
| 121 | + Reads a JSONL file. |
| 122 | +
|
| 123 | + Args: |
| 124 | + filename (str): Path to the JSONL file. |
| 125 | +
|
| 126 | + Yields: |
| 127 | + dict: A dictionary object corresponding to each line in the JSONL file. |
| 128 | + """ |
| 129 | + with open(filename) as fin: |
| 130 | + for line in fin: |
| 131 | + if line.startswith("#"): |
| 132 | + continue |
| 133 | + yield json.loads(line) |
| 134 | + |
| 135 | + |
| 136 | +def main(args): |
| 137 | + """ |
| 138 | + Process inputs and generate answers by calling the model in parallel using a thread pool. |
| 139 | +
|
| 140 | + Args: |
| 141 | + args (argparse.Namespace): |
| 142 | + - num_questions (int): Number of questions to process. |
| 143 | + - num_shots (int): Number of few-shot learning examples. |
| 144 | + - ip (str): IP address of the model service. |
| 145 | + - port (int): Port number of the model service. |
| 146 | + - parallel (int): Number of questions to process in parallel. |
| 147 | + - result_file (str): File path to store the results. |
| 148 | +
|
| 149 | + Returns: |
| 150 | + None |
| 151 | +
|
| 152 | + """ |
| 153 | + # Read data |
| 154 | + filename = "test.jsonl" |
| 155 | + |
| 156 | + lines = list(read_jsonl(filename)) |
| 157 | + |
| 158 | + # Construct prompts |
| 159 | + num_questions = args.num_questions |
| 160 | + num_shots = args.num_shots |
| 161 | + few_shot_examples = get_few_shot_examples(lines, num_shots) |
| 162 | + |
| 163 | + questions = [] |
| 164 | + labels = [] |
| 165 | + for i in range(len(lines[:num_questions])): |
| 166 | + questions.append(get_one_example(lines, i, False)) |
| 167 | + labels.append(get_answer_value(lines[i]["answer"])) |
| 168 | + assert all(l != INVALID for l in labels) |
| 169 | + |
| 170 | + states = [None] * len(labels) |
| 171 | + |
| 172 | + # Use thread pool |
| 173 | + def get_one_answer(i): |
| 174 | + answer = call_generate( |
| 175 | + prompt=few_shot_examples + questions[i], |
| 176 | + # stop=["Question", "Assistant:", "<|separator|>"], |
| 177 | + ip=args.ip, |
| 178 | + port=args.port, |
| 179 | + ) |
| 180 | + states[i] = answer |
| 181 | + |
| 182 | + tic = time.time() |
| 183 | + if args.parallel == 1: |
| 184 | + for i in tqdm(range(len(questions))): |
| 185 | + get_one_answer(i) |
| 186 | + else: |
| 187 | + with ThreadPoolExecutor(args.parallel) as executor: |
| 188 | + list( |
| 189 | + tqdm( |
| 190 | + executor.map(get_one_answer, list(range(len(questions)))), |
| 191 | + total=len(questions), |
| 192 | + ) |
| 193 | + ) |
| 194 | + |
| 195 | + latency = time.time() - tic |
| 196 | + preds = [] |
| 197 | + |
| 198 | + with open(args.acc_log, "w") as fout: |
| 199 | + for i in range(len(states)): |
| 200 | + preds.append(get_answer_value(states[i])) |
| 201 | + answer = get_answer_value(states[i]) |
| 202 | + fout.write("\n################################################################\n") |
| 203 | + fout.write("-----------prompt--------------\n") |
| 204 | + fout.write(f"{few_shot_examples + questions[i]}\n") |
| 205 | + fout.write("-----------answer--------------\n") |
| 206 | + fout.write(f"answer= {states[i]}\n") |
| 207 | + fout.write("-----------accuracy--------------\n") |
| 208 | + fout.write(f"Correct={answer==labels[i]}, pred={answer}, label={labels[i]} \n") |
| 209 | + |
| 210 | + # Compute accuracy |
| 211 | + acc = np.mean(np.array(preds) == np.array(labels)) |
| 212 | + invalid = np.mean(np.array(preds) == INVALID) |
| 213 | + |
| 214 | + # Print results |
| 215 | + print(f"Accuracy: {acc:.3f}") |
| 216 | + print(f"Invalid: {invalid:.3f}") |
| 217 | + print(f"Latency: {latency:.3f} s") |
| 218 | + |
| 219 | + with open(args.result_file, "a") as fout: |
| 220 | + value = { |
| 221 | + "task": "gsm8k", |
| 222 | + "backend": "paddlepaddle", |
| 223 | + "num_gpus": 1, |
| 224 | + "latency": round(latency, 3), |
| 225 | + "accuracy": round(acc, 3), |
| 226 | + "num_requests": args.num_questions, |
| 227 | + "other": { |
| 228 | + "num_questions": args.num_questions, |
| 229 | + "parallel": args.parallel, |
| 230 | + }, |
| 231 | + } |
| 232 | + fout.write(json.dumps(value) + "\n") |
| 233 | + |
| 234 | + |
| 235 | +if __name__ == "__main__": |
| 236 | + parser = argparse.ArgumentParser() |
| 237 | + parser.add_argument("--ip", type=str, default="127.0.0.1") |
| 238 | + parser.add_argument("--port", type=str, default="8188") |
| 239 | + parser.add_argument("--num-shots", type=int, default=10) |
| 240 | + parser.add_argument("--data-path", type=str, default="test.jsonl") |
| 241 | + parser.add_argument("--num-questions", type=int, default=1319) |
| 242 | + parser.add_argument("--result-file", type=str, default="result.jsonl") |
| 243 | + parser.add_argument("--parallel", type=int, default=1) |
| 244 | + parser.add_argument("--acc-log", type=str, default="accuracy.log") |
| 245 | + args = parser.parse_args() |
| 246 | + main(args) |
0 commit comments