Commit da60a87

Change pre-generated randomly-sampled dataset
1 parent b5ec970 commit da60a87

File tree: 8 files changed (+132, −202 lines)


bert_pytorch/__main__.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+import argparse
+
+from torch.utils.data import DataLoader
+
+from .model import BERT
+from .trainer import BERTTrainer
+from .dataset import BERTDataset, WordVocab
+
+
+def train():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("-d", "--train_dataset", required=True, type=str)
+    parser.add_argument("-t", "--test_dataset", type=str, default=None)
+    parser.add_argument("-v", "--vocab_path", required=True, type=str)
+    parser.add_argument("-o", "--output_path", required=True, type=str)
+
+    parser.add_argument("-hs", "--hidden", type=int, default=256)
+    parser.add_argument("-l", "--layers", type=int, default=8)
+    parser.add_argument("-a", "--attn_heads", type=int, default=8)
+    parser.add_argument("-s", "--seq_len", type=int, default=20)
+
+    parser.add_argument("-b", "--batch_size", type=int, default=64)
+    parser.add_argument("-e", "--epochs", type=int, default=10)
+    parser.add_argument("-w", "--num_workers", type=int, default=5)
+    parser.add_argument("-c", "--with_cuda", type=bool, default=True)
+    parser.add_argument("--log_freq", type=int, default=10)
+    parser.add_argument("--corpus_lines", type=int, default=None)
+
+    parser.add_argument("--lr", type=float, default=1e-3)
+    parser.add_argument("--adam_weight_decay", type=float, default=0.01)
+    parser.add_argument("--adam_beta1", type=float, default=0.9)
+    parser.add_argument("--adam_beta2", type=float, default=0.999)
+
+    args = parser.parse_args()
+
+    print("Loading Vocab", args.vocab_path)
+    vocab = WordVocab.load_vocab(args.vocab_path)
+    print("Vocab Size: ", len(vocab))
+
+    print("Loading Train Dataset", args.train_dataset)
+    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, corpus_lines=args.corpus_lines)
+
+    print("Loading Test Dataset", args.test_dataset)
+    test_dataset = BERTDataset(args.test_dataset, vocab,
+                               seq_len=args.seq_len) if args.test_dataset is not None else None
+
+    print("Creating Dataloader")
+    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
+    test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
+        if test_dataset is not None else None
+
+    print("Building BERT model")
+    bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
+
+    print("Creating BERT Trainer")
+    trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
+                          lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
+                          with_cuda=args.with_cuda, log_freq=args.log_freq)
+
+    print("Training Start")
+    for epoch in range(args.epochs):
+        trainer.train(epoch)
+        trainer.save(epoch, args.output_path)
+
+        if test_data_loader is not None:
+            trainer.test(epoch)
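
For orientation, a minimal, hypothetical way to exercise this entry point from Python (the dataset, vocab, and output paths below are placeholders; -d, -v, and -o are the required flags defined above):

import sys
from bert_pytorch.__main__ import train

# Placeholder paths; only --train_dataset, --vocab_path and --output_path
# are required by the parser above, everything else falls back to defaults.
sys.argv = ["bert",
            "-d", "data/corpus.train.txt",
            "-v", "data/vocab.pkl",
            "-o", "output/bert.model"]
train()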

bert_pytorch/build_dataset.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

bert_pytorch/build_vocab.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

bert_pytorch/dataset/creator.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

bert_pytorch/dataset/dataset.py

Lines changed: 45 additions & 11 deletions
@@ -1,32 +1,32 @@
 from torch.utils.data import Dataset
 import tqdm
 import torch
+import random
 
 
 class BERTDataset(Dataset):
     def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None):
         self.vocab = vocab
         self.seq_len = seq_len
 
-        self.datas = []
         with open(corpus_path, "r", encoding=encoding) as f:
-            for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):
-                t1, t2, t1_l, t2_l, is_next = line[:-1].split("\t")
-                t1, t2 = [[int(token) for token in t.split(",")] for t in [t1, t2]]
-                t1_l, t2_l = [[int(token) for token in label.split(",")] for label in [t1_l, t2_l]]
-                is_next = int(is_next)
-                self.datas.append({"t1": t1, "t2": t2, "t1_label": t1_l, "t2_label": t2_l, "is_next": is_next})
+            self.datas = [line[:-1].split("\t")
+                          for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
 
     def __len__(self):
         return len(self.datas)
 
     def __getitem__(self, item):
+        t1, (t2, is_next_label) = self.datas[item][0], self.random_sent(item)
+        t1_random, t1_label = self.random_word(t1)
+        t2_random, t2_label = self.random_word(t2)
+
         # [CLS] tag = SOS tag, [SEP] tag = EOS tag
-        t1 = [self.vocab.sos_index] + self.datas[item]["t1"] + [self.vocab.eos_index]
-        t2 = self.datas[item]["t2"] + [self.vocab.eos_index]
+        t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
+        t2 = t2_random + [self.vocab.eos_index]
 
-        t1_label = [0] + self.datas[item]["t1_label"] + [0]
-        t2_label = self.datas[item]["t2_label"] + [0]
+        t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
+        t2_label = t2_label + [self.vocab.pad_index]
 
         segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
         bert_input = (t1 + t2)[:self.seq_len]
@@ -41,3 +41,37 @@ def __getitem__(self, item):
-                  "is_next": self.datas[item]["is_next"]}
+                  "is_next": is_next_label}
 
         return {key: torch.tensor(value) for key, value in output.items()}
+
+    def random_word(self, sentence):
+        tokens = sentence.split()
+        output_label = []
+
+        for i, token in enumerate(tokens):
+            prob = random.random()
+            if prob < 0.15:
+                # 80% of the selected 15%: replace the token with the mask token
+                if prob < 0.15 * 0.8:
+                    tokens[i] = self.vocab.mask_index
+
+                # 10%: replace the token with a random token
+                elif prob < 0.15 * 0.9:
+                    tokens[i] = random.randrange(len(self.vocab))
+
+                # 10%: keep the current token
+                else:
+                    tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
+
+                output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))
+
+            else:
+                tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
+                output_label.append(0)
+
+        return tokens, output_label
+
+    def random_sent(self, index):
+        # output_text, label (isNotNext: 0, isNext: 1)
+        if random.random() > 0.5:
+            return self.datas[index][1], 1
+        else:
+            return self.datas[random.randrange(len(self.datas))][1], 0
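
The corpus format implied by __init__ and random_sent is one tab-separated sentence pair per line (t1, then a tab, then t2). As a self-contained sketch of the masking rule above, assuming the 80/10/10 reading of the comments (the function name and return strings here are illustrative, not from the repository):

import random

def masking_decision(p_select=0.15):
    # Mirrors what random_word does to a single token.
    prob = random.random()
    if prob >= p_select:
        return "keep token, label 0 (not predicted)"
    prob /= p_select  # rescale the draw to [0, 1) within the selected 15%
    if prob < 0.8:
        return "replace with mask token, label = original token id"
    elif prob < 0.9:
        return "replace with random token, label = original token id"
    else:
        return "keep token, label = original token id"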

bert_pytorch/dataset/vocab.py

Lines changed: 18 additions & 0 deletions
@@ -165,3 +165,21 @@ def from_seq(self, seq, join=False, with_pad=False):
     def load_vocab(vocab_path: str) -> 'WordVocab':
         with open(vocab_path, "rb") as f:
             return pickle.load(f)
+
+
+def build():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-c", "--corpus_path", required=True, type=str)
+    parser.add_argument("-o", "--output_path", required=True, type=str)
+    parser.add_argument("-s", "--vocab_size", type=int, default=None)
+    parser.add_argument("-e", "--encoding", type=str, default="utf-8")
+    parser.add_argument("-m", "--min_freq", type=int, default=1)
+    args = parser.parse_args()
+
+    with open(args.corpus_path, "r", encoding=args.encoding) as f:
+        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)
+
+    print("VOCAB SIZE:", len(vocab))
+    vocab.save_vocab(args.output_path)
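
A hypothetical round trip through the code above (corpus.txt and vocab.pkl are placeholder paths; the WordVocab constructor, save_vocab, and load_vocab all appear in this file):

from bert_pytorch.dataset.vocab import WordVocab

with open("corpus.txt", "r", encoding="utf-8") as f:
    vocab = WordVocab(f, max_size=None, min_freq=1)  # same call build() makes
vocab.save_vocab("vocab.pkl")  # presumably pickled; load_vocab uses pickle.load

vocab = WordVocab.load_vocab("vocab.pkl")
print("VOCAB SIZE:", len(vocab))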

bert_pytorch/train.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

setup.py

Lines changed: 2 additions & 3 deletions
@@ -44,9 +44,8 @@ def run(self):
     ],
     entry_points={
         'console_scripts': [
-            'bert = bert_pytorch.train:train',
-            'bert-dataset = bert_pytorch.build_dataset:build',
-            'bert-vocab = bert_pytorch.build_vocab:build',
+            'bert = bert_pytorch.__main__:train',
+            'bert-vocab = bert_pytorch.dataset.vocab:build',
         ]
     },
     cmdclass={
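
Each console_scripts entry maps a command name to "module:function"; on install, pip generates an executable that imports the module and calls that function with no arguments. A sketch of how the 'bert' entry resolves (illustrative only, using just the standard library):

from importlib import import_module

entry = "bert_pytorch.__main__:train"  # right-hand side of 'bert = ...'
module_name, func_name = entry.split(":")
train = getattr(import_module(module_name), func_name)
train()  # equivalent to running `bert` on the command line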
