
Commit 24443e2

remove redundant code, add tokenized flag
1 parent 34c633b commit 24443e2

File tree

1 file changed (+22, -50 lines)


deep_keyphrase/copy_rnn/predict.py

Lines changed: 22 additions & 50 deletions
@@ -1,19 +1,19 @@
 # -*- coding: UTF-8 -*-
 import os
 import torch
-import json
 from pysenal import read_file, append_jsonlines
-from collections import namedtuple, OrderedDict
+from deep_keyphrase.base_predictor import BasePredictor
 from deep_keyphrase.copy_rnn.model import CopyRNN
 from deep_keyphrase.copy_rnn.beam_search import BeamSearch
-from deep_keyphrase.dataloader import KeyphraseDataLoader, RAW_BATCH, TOKENS
+from deep_keyphrase.dataloader import KeyphraseDataLoader, RAW_BATCH, TOKENS, INFERENCE_MODE, EVAL_MODE
 from deep_keyphrase.utils.vocab_loader import load_vocab
 from deep_keyphrase.utils.constants import BOS_WORD
 from deep_keyphrase.utils.tokenizer import token_char_tokenize


-class CopyRnnPredictor(object):
+class CopyRnnPredictor(BasePredictor):
     def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_length):
+        super().__init__(model_info)
         if isinstance(vocab_info, str):
             self.vocab2id = load_vocab(vocab_info)
         elif isinstance(vocab_info, dict):
@@ -22,7 +22,7 @@ def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_le
             raise ValueError('vocab info type error')
         self.id2vocab = dict(zip(self.vocab2id.values(), self.vocab2id.keys()))
         self.config = self.load_config(model_info)
-        self.model = self.load_model(model_info, self.vocab2id)
+        self.model = self.load_model(model_info, CopyRNN(self.config, self.vocab2id))
         self.model.eval()
         self.beam_size = beam_size
         self.max_target_len = max_target_len
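Note the changed call above: the predictor now constructs the CopyRNN itself and hands it to load_model instead of passing a vocab dict. The load_config and load_model bodies deleted in the next hunk have presumably moved into the new BasePredictor base class. The following is a minimal sketch of what deep_keyphrase/base_predictor.py plausibly provides, reconstructed from the removed code below; it is not the actual file and details may differ.

import os
import json
from collections import namedtuple, OrderedDict

import torch
from pysenal import read_file


class BasePredictor(object):
    def __init__(self, model_info):
        # assumed: the base class just keeps a reference to model_info
        self.model_info = model_info

    def load_config(self, model_info):
        # config path defaults to the checkpoint path with a .json extension
        if 'config' not in model_info:
            if isinstance(model_info['model'], str):
                config_path = os.path.splitext(model_info['model'])[0] + '.json'
            else:
                raise ValueError('config path is not assigned')
        else:
            config_info = model_info['config']
            if isinstance(config_info, str):
                config_path = config_info
            else:
                return config_info
        # parse the json config into an attribute-style object
        return json.loads(read_file(config_path),
                          object_hook=lambda d: namedtuple('X', d.keys())(*d.values()))

    def load_model(self, model_info, model):
        # new signature: the caller constructs the model (e.g.
        # CopyRNN(self.config, self.vocab2id)); only checkpoint loading is shared
        if isinstance(model_info['model'], torch.nn.Module):
            return model_info['model']
        model_path = model_info['model']
        if not isinstance(model_path, str):
            raise TypeError('model path should be str')
        if torch.cuda.is_available():
            checkpoint = torch.load(model_path)
        else:
            checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        # strip the 'module.' prefix left over from DataParallel training
        state_dict = OrderedDict()
        for k, v in checkpoint.items():
            if k.startswith('module.'):
                k = k[7:]
            state_dict[k] = v
        model.load_state_dict(state_dict)
        if torch.cuda.is_available():
            model = model.cuda()
        return model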
@@ -34,78 +34,50 @@ def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_le
                                         bos_idx=self.vocab2id[BOS_WORD],
                                         args=self.config)

-    def load_config(self, model_info):
-        if 'config' not in model_info:
-            if isinstance(model_info['model'], str):
-                config_path = os.path.splitext(model_info['model'])[0] + '.json'
-            else:
-                raise ValueError('config path is not assigned')
-        else:
-            config_info = model_info['config']
-            if isinstance(config_info, str):
-                config_path = config_info
-            else:
-                return config_info
-        # json to object
-        config = json.loads(read_file(config_path),
-                            object_hook=lambda d: namedtuple('X', d.keys())(*d.values()))
-        return config
-
-    def load_model(self, model_info, vocab2id):
-        if isinstance(model_info['model'], torch.nn.Module):
-            return model_info['model']
-
-        model_path = model_info['model']
-        if not isinstance(model_path, str):
-            raise TypeError('model path should be str')
-        model = CopyRNN(self.config, vocab2id)
-        if torch.cuda.is_available():
-            checkpoint = torch.load(model_path)
-        else:
-            checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
-        state_dict = OrderedDict()
-        # avoid error when load parallel trained model
-        for k, v in checkpoint.items():
-            if k.startswith('module.'):
-                k = k[7:]
-            state_dict[k] = v
-        model.load_state_dict(state_dict)
-        if torch.cuda.is_available():
-            model = model.cuda()
-        return model
-
-    def predict(self, text_list, batch_size=10, delimiter=None):
+    def predict(self, text_list, batch_size=10, delimiter=None, tokenized=False):
         """

         :param text_list:
         :param batch_size:
+        :param delimiter:
+        :param tokenized:
         :return:
         """
         self.model.eval()
         if len(text_list) < batch_size:
             batch_size = len(text_list)
-        text_list = [{TOKENS: token_char_tokenize(i)} for i in text_list]
+
+        if tokenized:
+            text_list = [{TOKENS: i} for i in text_list]
+        else:
+            text_list = [{TOKENS: token_char_tokenize(i)} for i in text_list]
+
         loader = KeyphraseDataLoader(data_source=text_list,
                                      vocab2id=self.vocab2id,
                                      batch_size=batch_size,
                                      max_oov_count=self.config.max_oov_count,
                                      max_src_len=self.max_src_len,
                                      max_target_len=self.max_target_len,
-                                     mode='valid')
+                                     mode=INFERENCE_MODE)
         result = []
         for batch in loader:
             with torch.no_grad():
                 result.extend(self.beam_searcher.beam_search(batch, delimiter=delimiter))
         return result

-    def eval_predict(self, src_filename, dest_filename, batch_size, model=None, remove_existed=False):
+    def eval_predict(self, src_filename, dest_filename, batch_size,
+                     model=None, remove_existed=False,
+                     token_field='tokens', keyphrase_field='keyphrases'):
         loader = KeyphraseDataLoader(data_source=src_filename,
                                      vocab2id=self.vocab2id,
                                      batch_size=batch_size,
                                      max_oov_count=self.config.max_oov_count,
                                      max_src_len=self.max_src_len,
                                      max_target_len=self.max_target_len,
-                                     mode='valid')
+                                     mode=EVAL_MODE,
+                                     pre_fetch=True,
+                                     token_field=token_field,
+                                     keyphrase_field=keyphrase_field)

         if os.path.exists(dest_filename):
             print('destination filename {} existed'.format(dest_filename))
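To show the new tokenized flag and the eval_predict field parameters end to end, here is a hedged usage sketch against the API visible in this diff. The checkpoint, config, vocab, and data paths are placeholders, and beam_size and the length limits are arbitrary example values.

from deep_keyphrase.copy_rnn.predict import CopyRnnPredictor

# hypothetical paths; substitute a real checkpoint, config and vocab
predictor = CopyRnnPredictor(model_info={'model': 'copy_rnn.model', 'config': 'copy_rnn.json'},
                             vocab_info='vocab.txt',
                             beam_size=50,
                             max_target_len=8,
                             max_src_length=1500)

# raw text: tokenized internally with token_char_tokenize (the old behaviour)
result = predictor.predict(['this is a document about keyphrase generation'],
                           batch_size=1)

# pre-tokenized input: the new flag skips tokenization entirely
result = predictor.predict([['this', 'is', 'a', 'document']],
                           batch_size=1,
                           tokenized=True)

# batch evaluation over a jsonl file, using the new field-name parameters
predictor.eval_predict(src_filename='test.jsonl', dest_filename='pred.jsonl',
                       batch_size=10,
                       token_field='tokens', keyphrase_field='keyphrases')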
