1010def train ():
1111 parser = argparse .ArgumentParser ()
1212
13- parser .add_argument ("-c" , "--train_dataset" , required = True , type = str )
14- parser .add_argument ("-t" , "--test_dataset" , type = str , default = None )
15- parser .add_argument ("-v" , "--vocab_path" , required = True , type = str )
16- parser .add_argument ("-o" , "--output_path" , required = True , type = str )
17-
18- parser .add_argument ("-hs" , "--hidden" , type = int , default = 256 )
19- parser .add_argument ("-l" , "--layers" , type = int , default = 8 )
20- parser .add_argument ("-a" , "--attn_heads" , type = int , default = 8 )
21- parser .add_argument ("-s" , "--seq_len" , type = int , default = 20 )
22-
23- parser .add_argument ("-b" , "--batch_size" , type = int , default = 64 )
24- parser .add_argument ("-e" , "--epochs" , type = int , default = 10 )
25- parser .add_argument ("-w" , "--num_workers" , type = int , default = 5 )
26- parser .add_argument ("--with_cuda" , type = bool , default = True )
27- parser .add_argument ("--log_freq" , type = int , default = 10 )
28- parser .add_argument ("--corpus_lines" , type = int , default = None )
29-
30- parser .add_argument ("--lr" , type = float , default = 1e-3 )
31- parser .add_argument ("--adam_weight_decay" , type = float , default = 0.01 )
32- parser .add_argument ("--adam_beta1" , type = float , default = 0.9 )
33- parser .add_argument ("--adam_beta2" , type = float , default = 0.999 )
13+ parser .add_argument ("-c" , "--train_dataset" , required = True , type = str , help = "train dataset for train bert" )
14+ parser .add_argument ("-t" , "--test_dataset" , type = str , default = None , help = "test set for evaluate train set" )
15+ parser .add_argument ("-v" , "--vocab_path" , required = True , type = str , help = "built vocab model path with bert-vocab" )
16+ parser .add_argument ("-o" , "--output_path" , required = True , type = str , help = "ex)output/bert.model" )
17+
18+ parser .add_argument ("-hs" , "--hidden" , type = int , default = 256 , help = "hidden size of transformer model" )
19+ parser .add_argument ("-l" , "--layers" , type = int , default = 8 , help = "number of layers" )
20+ parser .add_argument ("-a" , "--attn_heads" , type = int , default = 8 , help = "number of attention heads" )
21+ parser .add_argument ("-s" , "--seq_len" , type = int , default = 20 , help = "maximum sequence len" )
22+
23+ parser .add_argument ("-b" , "--batch_size" , type = int , default = 64 , help = "number of batch_size" )
24+ parser .add_argument ("-e" , "--epochs" , type = int , default = 10 , help = "number of epochs" )
25+ parser .add_argument ("-w" , "--num_workers" , type = int , default = 5 , help = "dataloader worker size" )
26+
27+ parser .add_argument ("--with_cuda" , type = bool , default = True , help = "training with CUDA: true, or false" )
28+ parser .add_argument ("--log_freq" , type = int , default = 10 , help = "printing loss every n iter: setting n" )
29+ parser .add_argument ("--corpus_lines" , type = int , default = None , help = "total number of lines in corpus" )
30+ parser .add_argument ("--cuda_devices" , type = int , nargs = '+' , default = None , help = "CUDA device ids" )
31+ parser .add_argument ("--on_memory" , type = bool , default = True , help = "Loading on memory: true or false" )
32+
33+ parser .add_argument ("--lr" , type = float , default = 1e-3 , help = "learning rate of adam" )
34+ parser .add_argument ("--adam_weight_decay" , type = float , default = 0.01 , help = "weight_decay of adam" )
35+ parser .add_argument ("--adam_beta1" , type = float , default = 0.9 , help = "adam first beta value" )
36+ parser .add_argument ("--adam_beta2" , type = float , default = 0.999 , help = "adam first beta value" )
3437
3538 args = parser .parse_args ()
3639
@@ -39,11 +42,12 @@ def train():
3942 print ("Vocab Size: " , len (vocab ))
4043
4144 print ("Loading Train Dataset" , args .train_dataset )
42- train_dataset = BERTDataset (args .train_dataset , vocab , seq_len = args .seq_len , corpus_lines = args .corpus_lines )
45+ train_dataset = BERTDataset (args .train_dataset , vocab , seq_len = args .seq_len ,
46+ corpus_lines = args .corpus_lines , on_memory = args .on_memory )
4347
4448 print ("Loading Test Dataset" , args .test_dataset )
45- test_dataset = BERTDataset (args .test_dataset , vocab ,
46- seq_len = args . seq_len ) if args .test_dataset is not None else None
49+ test_dataset = BERTDataset (args .test_dataset , vocab , seq_len = args . seq_len , on_memory = args . on_memory ) \
50+ if args .test_dataset is not None else None
4751
4852 print ("Creating Dataloader" )
4953 train_data_loader = DataLoader (train_dataset , batch_size = args .batch_size , num_workers = args .num_workers )
@@ -56,7 +60,7 @@ def train():
5660 print ("Creating BERT Trainer" )
5761 trainer = BERTTrainer (bert , len (vocab ), train_dataloader = train_data_loader , test_dataloader = test_data_loader ,
5862 lr = args .lr , betas = (args .adam_beta1 , args .adam_beta2 ), weight_decay = args .adam_weight_decay ,
59- with_cuda = args .with_cuda , log_freq = args .log_freq )
63+ with_cuda = args .with_cuda , cuda_devices = args . cuda_devices , log_freq = args .log_freq )
6064
6165 print ("Training Start" )
6266 for epoch in range (args .epochs ):
0 commit comments