 # -*- coding: UTF-8 -*-
 import os
 import torch
-import json
 from pysenal import read_file, append_jsonlines
-from collections import namedtuple, OrderedDict
+from deep_keyphrase.base_predictor import BasePredictor
 from deep_keyphrase.copy_rnn.model import CopyRNN
 from deep_keyphrase.copy_rnn.beam_search import BeamSearch
-from deep_keyphrase.dataloader import KeyphraseDataLoader, RAW_BATCH, TOKENS
+from deep_keyphrase.dataloader import KeyphraseDataLoader, RAW_BATCH, TOKENS, INFERENCE_MODE, EVAL_MODE
 from deep_keyphrase.utils.vocab_loader import load_vocab
 from deep_keyphrase.utils.constants import BOS_WORD
 from deep_keyphrase.utils.tokenizer import token_char_tokenize
 
 
-class CopyRnnPredictor(object):
+class CopyRnnPredictor(BasePredictor):
     def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_length):
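+        # BasePredictor provides load_config/load_model, replacing the local copies removed below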
+        super().__init__(model_info)
         if isinstance(vocab_info, str):
             self.vocab2id = load_vocab(vocab_info)
         elif isinstance(vocab_info, dict):
@@ -22,7 +22,7 @@ def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_le
             raise ValueError('vocab info type error')
         self.id2vocab = dict(zip(self.vocab2id.values(), self.vocab2id.keys()))
         self.config = self.load_config(model_info)
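+        # the inherited load_model now receives a constructed CopyRNN instead of building one itself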
-        self.model = self.load_model(model_info, self.vocab2id)
+        self.model = self.load_model(model_info, CopyRNN(self.config, self.vocab2id))
         self.model.eval()
         self.beam_size = beam_size
         self.max_target_len = max_target_len
@@ -34,78 +34,50 @@ def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_le
             bos_idx=self.vocab2id[BOS_WORD],
             args=self.config)
 
-    def load_config(self, model_info):
-        if 'config' not in model_info:
-            if isinstance(model_info['model'], str):
-                config_path = os.path.splitext(model_info['model'])[0] + '.json'
-            else:
-                raise ValueError('config path is not assigned')
-        else:
-            config_info = model_info['config']
-            if isinstance(config_info, str):
-                config_path = config_info
-            else:
-                return config_info
-        # json to object
-        config = json.loads(read_file(config_path),
-                            object_hook=lambda d: namedtuple('X', d.keys())(*d.values()))
-        return config
-
-    def load_model(self, model_info, vocab2id):
-        if isinstance(model_info['model'], torch.nn.Module):
-            return model_info['model']
-
-        model_path = model_info['model']
-        if not isinstance(model_path, str):
-            raise TypeError('model path should be str')
-        model = CopyRNN(self.config, vocab2id)
-        if torch.cuda.is_available():
-            checkpoint = torch.load(model_path)
-        else:
-            checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
-        state_dict = OrderedDict()
-        # avoid errors when loading a DataParallel-trained model
-        for k, v in checkpoint.items():
-            if k.startswith('module.'):
-                k = k[7:]
-            state_dict[k] = v
-        model.load_state_dict(state_dict)
-        if torch.cuda.is_available():
-            model = model.cuda()
-        return model
-
-    def predict(self, text_list, batch_size=10, delimiter=None):
+    def predict(self, text_list, batch_size=10, delimiter=None, tokenized=False):
         """
 
         :param text_list:
         :param batch_size:
+        :param delimiter:
+        :param tokenized:
         :return:
         """
         self.model.eval()
         if len(text_list) < batch_size:
             batch_size = len(text_list)
-        text_list = [{TOKENS: token_char_tokenize(i)} for i in text_list]
+
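+        # pass pre-tokenized input straight through; otherwise apply char-level tokenization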
+        if tokenized:
+            text_list = [{TOKENS: i} for i in text_list]
+        else:
+            text_list = [{TOKENS: token_char_tokenize(i)} for i in text_list]
+
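+        # these items carry only TOKENS (no gold keyphrases), hence INFERENCE_MODE below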
         loader = KeyphraseDataLoader(data_source=text_list,
                                      vocab2id=self.vocab2id,
                                      batch_size=batch_size,
                                      max_oov_count=self.config.max_oov_count,
                                      max_src_len=self.max_src_len,
                                      max_target_len=self.max_target_len,
-                                     mode='valid')
+                                     mode=INFERENCE_MODE)
         result = []
         for batch in loader:
             with torch.no_grad():
                 result.extend(self.beam_searcher.beam_search(batch, delimiter=delimiter))
         return result
 
-    def eval_predict(self, src_filename, dest_filename, batch_size, model=None, remove_existed=False):
+    def eval_predict(self, src_filename, dest_filename, batch_size,
+                     model=None, remove_existed=False,
+                     token_field='tokens', keyphrase_field='keyphrases'):
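+        # token_field / keyphrase_field make the JSON keys of the eval file configurable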
         loader = KeyphraseDataLoader(data_source=src_filename,
                                      vocab2id=self.vocab2id,
                                      batch_size=batch_size,
                                      max_oov_count=self.config.max_oov_count,
                                      max_src_len=self.max_src_len,
                                      max_target_len=self.max_target_len,
-                                     mode='valid')
+                                     mode=EVAL_MODE,
+                                     pre_fetch=True,
+                                     token_field=token_field,
+                                     keyphrase_field=keyphrase_field)
 
         if os.path.exists(dest_filename):
             print('destination filename {} existed'.format(dest_filename))
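
Usage sketch for the updated API (a minimal example; the module path is assumed from the package layout, and the file paths and hyperparameter values below are illustrative, not taken from this commit):

    from deep_keyphrase.copy_rnn.predict import CopyRnnPredictor

    predictor = CopyRnnPredictor(model_info={'model': 'copy_rnn.model', 'config': 'copy_rnn.json'},
                                 vocab_info='vocab.txt',
                                 beam_size=50,
                                 max_target_len=8,
                                 max_src_length=1500)
    # raw strings are char-tokenized internally; pass tokenized=True to supply token lists
    keyphrases = predictor.predict(['copy mechanism for keyphrase generation'],
                                   batch_size=10, delimiter=' ')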