
Commit d6d5fda

[Open-source internship] Bigbird pegasus model fine-tuning (#1994)

File tree

3 files changed: +400 -25 lines changed
Lines changed: 60 additions & 25 deletions
@@ -1,25 +1,60 @@
# bigbird_pegasus model fine-tuning comparison

## Train loss

Comparison of the training loss during fine-tuning.

| epoch | mindnlp+mindspore | transformers+torch (4060) | transformers+torch (4060, another run) |
| ----- | ----------------- | ------------------------- | -------------------------------------- |
| 1     | 2.0958            | 8.7301                    | 5.4650                                 |
| 2     | 1.969             | 8.1557                    | 4.6890                                 |
| 3     | 1.8755            | 7.7516                    | 4.2572                                 |
| 4     | 1.8264            | 7.5017                    | 4.0263                                 |
| 5     | 1.7349            | 7.2614                    | 3.9444                                 |
| 6     | 1.678             | 7.0559                    | 3.8428                                 |
| 7     | 1.6937            | 6.8405                    | 3.7187                                 |
| 8     | 1.654             | 6.7297                    | 3.7192                                 |
| 9     | 1.6365            | 6.7136                    | 3.5434                                 |
| 10    | 1.7003            | 6.6279                    | 3.5881                                 |

## Eval loss

Comparison of the evaluation scores.

| epoch | mindnlp+mindspore  | transformers+torch (4060) | transformers+torch (4060, another run) |
| ----- | ------------------ | ------------------------- | -------------------------------------- |
| 1     | 2.1257965564727783 | 6.3235931396484375        | 4.264792442321777                      |

# bigbird_pegasus fine-tuning

This is a fine-tuning experiment for the bigbird_pegasus model on the google/Synthetic-Persona-Chat dataset.

The task is tracked at https://gitee.com/mindspore/community/issues/IAUPBF

The transformers+pytorch+3090 benchmark was written by the author; the repository is at https://github.com/outbreak-sen/bigbird_pegasus_finetune

The changed code lives in llm/finetune/bigbird_prgasus and contains only the mindnlp+mindspore implementation.
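For reference, here is a minimal sketch of the transformers+PyTorch side of the generation check (the actual benchmark script lives in the repository linked above). It loads the same `google/bigbird-pegasus-large-arxiv` checkpoint and asks the same question used in the dialogue test further down; default `generate` settings are assumed.

```python
# Hedged sketch: not the benchmark script itself, only the pre-fine-tuning generation check.
from transformers import AutoTokenizer, BigBirdPegasusForConditionalGeneration

model_name = "google/bigbird-pegasus-large-arxiv"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BigBirdPegasusForConditionalGeneration.from_pretrained(model_name)

question = "Nice to meet you too. What are you interested in?"
inputs = tokenizer([question], return_tensors="pt")
outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```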
The experimental results are as follows.
## Loss Values

| No. | MindNLP | PyTorch |
|-----|---------|---------|
| 1   | 0.1826  | 7.6556  |
| 2   | 0.1614  | 0.5960  |
| 3   | 0.1435  | 0.4145  |
| 4   | 0.1398  | 0.3022  |
| 5   | 0.1344  | 0.2555  |
| 6   | 0.1263  | 0.2357  |
| 7   | 0.1200  | 0.2247  |
| 8   | 0.1147  | 0.2166  |
| 9   | 0.1105  | 0.2107  |
| 10  | 0.1082  | 0.2075  |

## Eval Loss Values

| No. | MindNLP | PyTorch |
|-----|---------|---------|
| 1   | 0.2397  | 0.8738  |
| 2   | 0.2451  | 0.4804  |
| 3   | 0.2530  | 0.3490  |
| 4   | 0.2548  | 0.2861  |
| 5   | 0.2595  | 0.2669  |
| 6   | 0.2663  | 0.2612  |
| 7   | 0.2690  | 0.2545  |
| 8   | 0.2755  | 0.2526  |
| 9   | 0.2791  | 0.2519  |
| 10  | 0.2831  | 0.2510  |
| 11  | 0.2831  | 0.2510  |

## Dialogue test

The question is the first question from the evaluation dataset.

* Input question:

Nice to meet you too. What are you interested in?

* MindNLP answer before fine-tuning:

we present a new method for the detection of rare events, based on the use of time - frequency combs.<n> we show how this technique can be used to detect rare events in a broad range of time - frequency domains.<n> we also show how this technique can be used to study the evolution of the spectrum of rare events. <n> rare events ; amplitude ; phase ; amplitude ; frequency ; time - frequency combs + _ pacs : _<n> 11.30.er, 12.20.fv, 12.20.ds, 12.60.jv, 12.60.jv @xmath0 department of physics and astronomy, iowa state university, ames, ia 50011 + @xmath1 department of physics and astronomy, university of iowa, ames, ia 50011 + @xmath2 department of physics and astronomy, university of iowa, ames, ia 50011 + _ key words : _ rare events ; amplitude ; phase ; frequency ; spectrum ; time - frequency combs + _ pacs : _<n> 11.30.er

* MindNLP answer after fine-tuning:

I'm interested in a lot of things, but I'm especially interested in history and science.

* PyTorch answer before fine-tuning:

we present a new method for the detection of rare events , based on the use of time - frequency combs .<n> we show how this technique can be used to detect rare events in a broad range of time - frequency domains .<n> we also show how this technique can be used to study the evolution of the spectrum of rare events . <n> rare events ; amplitude ; phase ; amplitude ; frequency ; time - frequency combs + _ pacs : _<n> 11.30.er , 12.20.fv , 12.20.ds , 12.60.jv , 12.60.jv @xmath0 department of physics and astronomy , iowa state university , ames , ia 50011 + @xmath1 department of physics and astronomy , university of iowa , ames , ia 50011 + @xmath2 department of physics and astronomy , university of iowa , ames , ia 50011 + _ key words : _ rare events ; amplitude ; phase ; frequency ; spectrum ; time - frequency combs + _ pacs : _<n> 11.30.er

* PyTorch answer after fine-tuning:

how do you like to do for fun?

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
from mindnlp.transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
from mindnlp.engine import Trainer, TrainingArguments
from datasets import load_dataset, load_from_disk
import mindspore as ms
import os

# Set the execution mode and target device
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")
# Point HF_ENDPOINT at the HuggingFace mirror
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Load the model and tokenizer
print("Loading model and tokenizer")
model_name = "google/bigbird-pegasus-large-arxiv"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BigBirdPegasusForConditionalGeneration.from_pretrained(model_name)
print("Model and tokenizer loaded")

# Generation check before fine-tuning
input_text = "Nice to meet you too. What are you interested in?"
print("input question:", input_text)
input_tokens = tokenizer([input_text], return_tensors="ms")
output_tokens = model.generate(**input_tokens)
print("output answer:", tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0])

print("Loading the dataset")
# Path where the preprocessed dataset is cached
dataset_path = "./Persona_valid_preprocessed"
# Reuse the preprocessed dataset if it already exists
if os.path.exists(dataset_path):
    # Load the preprocessed dataset from disk
    dataset_train = load_from_disk("./Persona_train_preprocessed")
    dataset_valid = load_from_disk("./Persona_valid_preprocessed")
else:
    dataset = load_dataset("google/Synthetic-Persona-Chat")
    print("dataset finished")
    print("dataset:", dataset)
    print("dataset['train'][0]:", dataset["train"][0])
    dataset_train = dataset["train"]
    dataset_valid = dataset["validation"]
    print("dataset_train:", dataset_train)
    print("dataset_train['Best Generated Conversation'][0]:\n",
          dataset_train["Best Generated Conversation"][0])
    print("dataset_train['user 1 personas'][0]:",
          dataset_train["user 1 personas"][0])
    print("dataset_train['user 2 personas'][0]:",
          dataset_train["user 2 personas"][0])
    print("dataset_train.column_names:",
          dataset_train.column_names)

    # Preprocessing: format each conversation into context-reply pairs
    def format_dialogue(examples):
        inputs, targets = [], []
        for conversation in examples["Best Generated Conversation"]:
            # Split the conversation into lines
            lines = conversation.split("\n")
            # Split the conversation into context and reply
            for i in range(len(lines) - 1):
                context = "\n".join(lines[:i + 1])  # the context is the current line and all lines before it
                reply = lines[i + 1]  # the next line is the reply
                context = context.replace("User 1: ", "")
                context = context.replace("User 2: ", "")
                if context.strip() and reply.strip():  # keep only non-empty context/reply pairs
                    inputs.append(context.strip())
                    targets.append(reply.strip())
        return {"input": inputs, "target": targets}

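    # Illustrative (hypothetical) example of what format_dialogue produces:
    # the conversation "User 1: Hi!\nUser 2: Hello there.\nUser 1: Nice to meet you too."
    # expands into the pairs
    #   input: "Hi!"                -> target: "User 2: Hello there."
    #   input: "Hi!\nHello there."  -> target: "User 1: Nice to meet you too."
    # (speaker prefixes are stripped from the context but kept on the target reply).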
    # Apply the preprocessing function
    dataset_train = dataset_train.map(format_dialogue, batched=True,
                                      remove_columns=["user 1 personas",
                                                      "user 2 personas",
                                                      "Best Generated Conversation"])
    dataset_valid = dataset_valid.map(format_dialogue, batched=True,
                                      remove_columns=["user 1 personas",
                                                      "user 2 personas",
                                                      "Best Generated Conversation"])
    # Save the preprocessed datasets
    dataset_train.save_to_disk("./Persona_train_preprocessed")
    dataset_valid.save_to_disk("./Persona_valid_preprocessed")
print("Tokenizing the dataset")
# Path where the tokenized dataset is cached
dataset_path = "./PersonaTokenized_train_preprocessed"
# Reuse the tokenized dataset if it already exists
if os.path.exists(dataset_path):
    # Load the tokenized dataset from disk
    dataset_train_tokenized = load_from_disk("./PersonaTokenized_train_preprocessed")
    dataset_valid_tokenized = load_from_disk("./PersonaTokenized_valid_preprocessed")
else:
    # Tokenize the context-reply pairs
    def tokenize_function(examples):
        model_inputs = tokenizer(
            examples["input"],
            max_length=128,
            truncation=True,
            padding="max_length",
        )
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                examples["target"],
                max_length=128,
                truncation=True,
                padding="max_length",
            )
        # Adds "labels" next to "input_ids" and "attention_mask"
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    dataset_train_tokenized = dataset_train.map(tokenize_function)
    dataset_valid_tokenized = dataset_valid.map(tokenize_function)
    dataset_train_tokenized = dataset_train_tokenized.filter(
        lambda example: len(example["input_ids"]) > 0 and len(example["labels"]) > 0)
    dataset_valid_tokenized = dataset_valid_tokenized.filter(
        lambda example: len(example["input_ids"]) > 0 and len(example["labels"]) > 0)
    # Save the tokenized datasets
    dataset_train_tokenized.save_to_disk("./PersonaTokenized_train_preprocessed")
    dataset_valid_tokenized.save_to_disk("./PersonaTokenized_valid_preprocessed")

# Compute one percent of the data volume
train_size = len(dataset_train_tokenized)
valid_size = len(dataset_valid_tokenized)
train_subset_size = train_size // 100
valid_subset_size = valid_size // 100
# Use select to take the first one percent of the examples
dataset_train_tokenized = dataset_train_tokenized.select(range(train_subset_size))
dataset_valid_tokenized = dataset_valid_tokenized.select(range(valid_subset_size))
print("dataset_train_tokenized:", dataset_train_tokenized)
print("dataset_valid_tokenized:", dataset_valid_tokenized)

import numpy as np

def data_generator(dataset):
    for item in dataset:
        yield (
            np.array(item["input_ids"], dtype=np.int32),       # input_ids
            np.array(item["attention_mask"], dtype=np.int32),  # attention_mask
            np.array(item["labels"], dtype=np.int32)           # labels
        )

import mindspore.dataset as ds

# Convert the train and validation sets into MindSpore datasets.
# Note: the column must be named "labels" (not "label") to match the model's forward signature.
def create_mindspore_dataset(dataset, shuffle=True):
    return ds.GeneratorDataset(
        source=lambda: data_generator(dataset),  # wrap the generator in a lambda
        column_names=["input_ids", "attention_mask", "labels"],
        shuffle=shuffle
    )

dataset_train_tokenized = create_mindspore_dataset(dataset_train_tokenized, shuffle=True)
dataset_valid_tokenized = create_mindspore_dataset(dataset_valid_tokenized, shuffle=False)

TOKENS = 20  # defined but not used below
EPOCHS = 10
BATCH_SIZE = 4

training_args = TrainingArguments(
    output_dir='./MindNLP_BigBirdPegasus_persona_finetuned',
    overwrite_output_dir=True,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    save_steps=500,                # save a checkpoint every 500 steps
    save_total_limit=2,            # keep only the last 2 checkpoints
    logging_dir="./MindNLP_logs",  # directory for logs
    logging_steps=100,             # log every 100 steps
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    eval_steps=500,                # evaluation frequency
    warmup_steps=100,
    learning_rate=5e-5,
    weight_decay=0.01,             # weight decay
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_train_tokenized,
    eval_dataset=dataset_valid_tokenized
)
print("Starting training")
# Start training
trainer.train()
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

model.save_pretrained("./MindNLP_BigBirdPegasus_persona_finetuned")
tokenizer.save_pretrained("./MindNLP_BigBirdPegasus_persona_finetuned")
fine_tuned_model = BigBirdPegasusForConditionalGeneration.from_pretrained("./MindNLP_BigBirdPegasus_persona_finetuned")
fine_tuned_tokenizer = AutoTokenizer.from_pretrained("./MindNLP_BigBirdPegasus_persona_finetuned")
# Run the dialogue check again with the fine-tuned model
print("Testing the dialogue again")
input_text = "Nice to meet you too. What are you interested in?"
print("input question:", input_text)
input_tokens = fine_tuned_tokenizer([input_text], return_tensors="ms")
output_tokens = fine_tuned_model.generate(**input_tokens)
print("output answer:", fine_tuned_tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0])
