import argparse

from torch.utils.data import DataLoader

from .model import BERT
from .trainer import BERTTrainer
from .dataset import BERTDataset, WordVocab

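# Example invocation (package/module name and file paths are assumed here for
# illustration; adjust them to the real entry point and data layout):
#   python -m <package> -d corpus.train -v vocab.pkl -o output/bert.model
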
def train():
    parser = argparse.ArgumentParser()

    # dataset, vocab and output paths
    parser.add_argument("-d", "--train_dataset", required=True, type=str)
    parser.add_argument("-t", "--test_dataset", type=str, default=None)
    parser.add_argument("-v", "--vocab_path", required=True, type=str)
    parser.add_argument("-o", "--output_path", required=True, type=str)

    # model hyperparameters
    parser.add_argument("-hs", "--hidden", type=int, default=256)
    parser.add_argument("-l", "--layers", type=int, default=8)
    parser.add_argument("-a", "--attn_heads", type=int, default=8)
    parser.add_argument("-s", "--seq_len", type=int, default=20)

    # training setup
    parser.add_argument("-b", "--batch_size", type=int, default=64)
    parser.add_argument("-e", "--epochs", type=int, default=10)
    parser.add_argument("-w", "--num_workers", type=int, default=5)
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so expose an explicit on/off switch instead
    parser.add_argument("--with_cuda", dest="with_cuda", action="store_true")
    parser.add_argument("--no_cuda", dest="with_cuda", action="store_false")
    parser.set_defaults(with_cuda=True)
    parser.add_argument("--log_freq", type=int, default=10)
    parser.add_argument("--corpus_lines", type=int, default=None)

    # Adam optimizer settings
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--adam_weight_decay", type=float, default=0.01)
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    args = parser.parse_args()

    print("Loading Vocab", args.vocab_path)
    vocab = WordVocab.load_vocab(args.vocab_path)
    print("Vocab Size:", len(vocab))

    print("Loading Train Dataset", args.train_dataset)
    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
                                corpus_lines=args.corpus_lines)

    # the test set is optional; skip loading (and the log line) when absent
    test_dataset = None
    if args.test_dataset is not None:
        print("Loading Test Dataset", args.test_dataset)
        test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len)

    print("Creating Dataloader")
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                                   num_workers=args.num_workers)
    test_data_loader = None
    if test_dataset is not None:
        test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                                      num_workers=args.num_workers)

    print("Building BERT model")
    bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers,
                attn_heads=args.attn_heads)

    print("Creating BERT Trainer")
    trainer = BERTTrainer(bert, len(vocab),
                          train_dataloader=train_data_loader,
                          test_dataloader=test_data_loader,
                          lr=args.lr, betas=(args.adam_beta1, args.adam_beta2),
                          weight_decay=args.adam_weight_decay,
                          with_cuda=args.with_cuda, log_freq=args.log_freq)

    print("Training Start")
    for epoch in range(args.epochs):
        trainer.train(epoch)
        # checkpoint the model after every epoch
        trainer.save(epoch, args.output_path)

        # run evaluation only when a test set was provided
        if test_data_loader is not None:
            trainer.test(epoch)
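
# Assumed direct-execution hook (not part of the original diff); the package
# may instead register a console-script entry point that calls train().
if __name__ == "__main__":
    train()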