246 changes: 246 additions & 0 deletions examples/intel_hpu/bench_gsm8k.py
@@ -0,0 +1,246 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Metric evaluation for Fastdeploy + ERNIE-4.5-Turbo"""
# adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py
import argparse
import ast
import json
import re
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import requests
from tqdm import tqdm

INVALID = -9999999


def call_generate(prompt, **kwargs):
"""
Generates a response for the given prompt via the FastDeploy chat completions endpoint.

Args:
prompt (str): The input prompt text.
**kwargs: Keyword arguments, including server IP address and port number.

Returns:
str: The response generated based on the prompt.

"""
url = f"http://{kwargs['ip']}:{kwargs['port']}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"messages": [
{
"role": "user",
"content": prompt,
}
],
"temperature": 0.6,
"max_tokens": 2047,
"top_p": 0.95,
"do_sample": True,
}

response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
out = response.json()
return out["choices"][0]["message"]["content"]
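# Example usage (illustrative; assumes a FastDeploy server is listening on
# 127.0.0.1:8188):
#
#     answer = call_generate(
#         "Question: Tom has 3 apples and buys 2 more. How many apples does he have?\nAnswer:",
#         ip="127.0.0.1",
#         port="8188",
#     )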


def get_one_example(lines, i, include_answer):
"""
Retrieves a question-answer example from the given list of text lines.

Args:
lines (list of dict): A list of question-answer pairs.
i (int): The index of the question-answer pair to retrieve from lines.
include_answer (bool): Whether to include the answer in the returned string.

Returns:
str: A formatted question-answer string in the format "Question: <question>\nAnswer: <answer>".

"""
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret


def get_few_shot_examples(lines, k):
"""
Selects k examples from the given list of text lines and concatenates them into a single string.

Args:
lines (list): A list containing text lines.
k (int): The number of examples to select.

Returns:
str: A string composed of k examples, separated by two newline characters.
"""
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret


def get_answer_value(answer_str):
"""
Extracts the final numerical value from an answer string.

Args:
answer_str (str): The string containing the answer.

Returns:
int: The last integer found in the answer string, or the INVALID sentinel (-9999999) if no number can be extracted.
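
Example (illustrative): "The total is 1,234." yields 1234, since commas
are removed and the last run of digits is taken; a string containing no
digits yields INVALID.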
"""
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if not numbers:
return INVALID
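# ast.literal_eval turns the digit run into an int; it raises SyntaxError
# for digit runs with leading zeros (e.g. "007"), which we treat as invalid.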
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID


def read_jsonl(filename: str):
"""
Reads a JSONL file.

Args:
filename (str): Path to the JSONL file.

Yields:
dict: A dictionary object corresponding to each line in the JSONL file.
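
Example line (GSM8K test set format, abridged):
{"question": "Natalia sold clips to 48 of her friends ...", "answer": "... #### 72"}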
"""
with open(filename) as fin:
for line in fin:
if line.startswith("#"):
continue
yield json.loads(line)


def main(args):
"""
Process inputs and generate answers by calling the model in parallel using a thread pool.

Args:
args (argparse.Namespace):
- num_questions (int): Number of questions to evaluate.
- num_shots (int): Number of few-shot examples prepended to each prompt.
- ip (str): IP address of the model service.
- port (str): Port number of the model service.
- data_path (str): Path to the GSM8K test set in JSONL format.
- parallel (int): Number of requests issued concurrently.
- acc_log (str): Path of the per-question accuracy log.
- result_file (str): Path of the JSONL file the summary is appended to.

Returns:
None

"""
# Read data
lines = list(read_jsonl(args.data_path))

# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)

questions = []
labels = []
for i in range(min(num_questions, len(lines))):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(label != INVALID for label in labels)

states = [None] * len(labels)

# Use thread pool
def get_one_answer(i):
answer = call_generate(
prompt=few_shot_examples + questions[i],
# stop=["Question", "Assistant:", "<|separator|>"],
ip=args.ip,
port=args.port,
)
states[i] = answer

tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)

latency = time.time() - tic
preds = []

with open(args.acc_log, "w") as fout:
for i in range(len(states)):
answer = get_answer_value(states[i])
preds.append(answer)
fout.write("\n################################################################\n")
fout.write("-----------prompt--------------\n")
fout.write(f"{few_shot_examples + questions[i]}\n")
fout.write("-----------answer--------------\n")
fout.write(f"answer= {states[i]}\n")
fout.write("-----------accuracy--------------\n")
fout.write(f"Correct={answer==labels[i]}, pred={answer}, label={labels[i]} \n")

# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)

# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")

with open(args.result_file, "a") as fout:
value = {
"task": "gsm8k",
"backend": "paddlepaddle",
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="127.0.0.1")
parser.add_argument("--port", type=str, default="8188")
parser.add_argument("--num-shots", type=int, default=10)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=1319)
parser.add_argument("--result-file", type=str, default="result.jsonl")
parser.add_argument("--parallel", type=int, default=1)
parser.add_argument("--acc-log", type=str, default="accuracy.log")
args = parser.parse_args()
main(args)
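
A typical invocation against a locally running FastDeploy server (address, port, and dataset path are illustrative and must match your deployment; the GSM8K test set is expected at --data-path):

python bench_gsm8k.py --ip 127.0.0.1 --port 8188 --data-path test.jsonl --num-questions 1319 --parallel 8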
72 changes: 72 additions & 0 deletions examples/intel_hpu/benchmark_paddle_hpu_cli.sh
@@ -0,0 +1,72 @@
#!/bin/bash

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# set -x

model="ERNIE-4.5-21B-A3B-Paddle"
model_log_name="ERNIE-4.5-21B-A3B-Paddle"
model_yaml="yaml/eb45-21b-a3b-32k-bf16.yaml"
# model="ERNIE-4.5-300B-A47B-Paddle"
# model_log_name="ERNIE-4.5-300B-A47B-Paddle"
# model_yaml="yaml/eb45-300b-a47b-32k-bf16.yaml"

export SERVER_PORT=8188
export no_proxy=localhost,127.0.0.1,0.0.0.0,10.0.0.0/8,192.168.1.0/24

input_lengths=(1024 2048)
output_lengths=(1024)
batch_sizes=(1 2 4 8 16 32 64 128)

workspace=$(pwd)
cd $workspace
log_home=$workspace/benchmark_fastdeploy_logs/$(TZ='Asia/Shanghai' date '+WW%V')_$(TZ='Asia/Shanghai' date +%F-%H-%M-%S)_${model_log_name}_FixedLen

mkdir -p ${log_home}

for input_length in "${input_lengths[@]}"
do
for output_length in "${output_lengths[@]}"
do
for batch_size in "${batch_sizes[@]}"
do
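# Truncate the HPU profiling log before each run; the serving process is
# assumed to write it to log/ under the current directory.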
> log/hpu_model_runner_profile.log
num_prompts=$(( batch_size * 3))
log_name_prefix="benchmarkdata_${model_log_name}_inputlength_${input_length}_outputlength_${output_length}_batchsize_${batch_size}_numprompts_${num_prompts}"
log_name=${log_name_prefix}_$(TZ='Asia/Shanghai' date +%F-%H-%M-%S)
echo "running benchmark with input length ${input_length}, output length ${output_length}, batch size ${batch_size}, log name ${log_name}"
cmd="python ../../benchmarks/benchmark_serving.py \
--backend openai-chat \
--model $model \
--endpoint /v1/chat/completions \
--host 0.0.0.0 \
--port ${SERVER_PORT} \
--dataset-name random \
--random-input-len ${input_length} \
--random-output-len ${output_length} \
--random-range-ratio 0 \
--hyperparameter-path ../../benchmarks/${model_yaml} \
--percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \
--metric-percentiles 80,95,99,99.9,99.95,99.99 \
--num-prompts ${num_prompts} \
--max-concurrency ${batch_size} \
--ignore-eos"
echo "$cmd" | tee -a ${log_home}/${log_name}.log
eval $cmd >> ${log_home}/${log_name}.log 2>&1

cp log/hpu_model_runner_profile.log ${log_home}/${log_name}_profile.log
done
done
done
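
A typical run, assuming a FastDeploy server is already serving the selected model on SERVER_PORT, and that benchmark_serving.py plus the YAML hyperparameter file exist at the ../../benchmarks/ paths referenced above:

bash benchmark_paddle_hpu_cli.sh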
64 changes: 64 additions & 0 deletions examples/intel_hpu/benchmark_paddle_hpu_cli_sharegpt.sh
@@ -0,0 +1,64 @@
#!/bin/bash

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# set -x

model="ERNIE-4.5-21B-A3B-Paddle"
model_log_name="ERNIE-4.5-21B-A3B-Paddle"
model_yaml="yaml/eb45-21b-a3b-32k-bf16.yaml"
# model="ERNIE-4.5-300B-A47B-Paddle"
# model_log_name="ERNIE-4.5-300B-A47B-Paddle"
# model_yaml="yaml/eb45-300b-a47b-32k-bf16.yaml"
export SERVER_PORT=8188
export no_proxy=.intel.com,intel.com,localhost,127.0.0.1,0.0.0.0,10.0.0.0/8,192.168.1.0/24
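# Usage: bash benchmark_paddle_hpu_cli_sharegpt.sh <CARD_NUM>
# CARD_NUM picks the request concurrency: 128 on a single card, 64 otherwise.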

CARD_NUM=$1

if [[ "$CARD_NUM" == "1" ]]; then
batch_size=128
else
batch_size=64
fi

num_prompts=2000

workspace=$(pwd)
cd $workspace
log_home=$workspace/benchmark_fastdeploy_logs/$(TZ='Asia/Shanghai' date '+WW%V')_$(TZ='Asia/Shanghai' date +%F-%H-%M-%S)_${model_log_name}

mkdir -p ${log_home}

log_name_prefix="benchmarkdata_${model_log_name}_sharegpt"
log_name=${log_name_prefix}_$(TZ='Asia/Shanghai' date +%F-%H-%M-%S)
echo "running benchmark with sharegpt log name ${log_name}"
cmd="python ../../benchmarks/benchmark_serving.py \
--backend openai-chat \
--model $model \
--endpoint /v1/chat/completions \
--host 0.0.0.0 \
--port ${SERVER_PORT} \
--dataset-name EBChat \
--dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \
--hyperparameter-path ../../benchmarks/${model_yaml} \
--percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \
--metric-percentiles 80,95,99,99.9,99.95,99.99 \
--max-concurrency ${batch_size} \
--num-prompts ${num_prompts} \
--sharegpt-output-len 4096 \
--save-result "
echo "$cmd" | tee -a ${log_home}/${log_name}.log
eval $cmd >> ${log_home}/${log_name}.log 2>&1
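# Copy the HPU profiling log (assumed to be written by the serving process
# into log/) next to the benchmark log.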
cp log/hpu_model_runner_profile.log ${log_home}/${log_name}_profile.log
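
To launch the ShareGPT benchmark on a single card, for example (assumes a running FastDeploy server on SERVER_PORT and the dataset file at the path above):

bash benchmark_paddle_hpu_cli_sharegpt.sh 1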