Skip to content

Commit de89ca0

Browse files
adapt for onnx format
1 parent 10a4e68 commit de89ca0

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

deep_keyphrase/copy_rnn/model.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,7 @@ def __init__(self, vocab2id, embedding, hidden_size,
127127
hidden_size=hidden_size,
128128
num_layers=self.num_layers,
129129
bidirectional=bidirectional,
130-
batch_first=True
131-
)
130+
batch_first=True)
132131

133132
def forward(self, src_dict):
134133
"""
@@ -145,8 +144,7 @@ def forward(self, src_dict):
145144
total_length = src_embed.size(1)
146145
packed_src_embed = nn.utils.rnn.pack_padded_sequence(src_embed,
147146
src_lengths,
148-
batch_first=True,
149-
enforce_sorted=False)
147+
batch_first=True)
150148
state_size = [self.num_layers, batch_size, self.hidden_size]
151149
if self.bidirectional:
152150
state_size[0] *= 2
@@ -237,6 +235,15 @@ def forward_copyrnn(self,
237235
prev_output_tokens,
238236
prev_rnn_state,
239237
src_dict):
238+
"""
239+
240+
:param encoder_output_dict:
241+
:param prev_context_state:
242+
:param prev_output_tokens:
243+
:param prev_rnn_state:
244+
:param src_dict:
245+
:return:
246+
"""
240247
src_tokens = src_dict[TOKENS]
241248
src_tokens_with_oov = src_dict[TOKENS_OOV]
242249
batch_size = len(src_tokens)
@@ -258,7 +265,6 @@ def forward_copyrnn(self,
258265
if self.input_feeding:
259266
prev_context_state = prev_context_state.unsqueeze(1)
260267
decoder_input = torch.cat([src_embed, prev_context_state, copy_state], dim=2)
261-
# print(decoder_input.size())
262268
else:
263269
decoder_input = torch.cat([src_embed, copy_state], dim=2)
264270
decoder_input = F.dropout(decoder_input, p=self.dropout, training=self.training)
@@ -323,6 +329,14 @@ def get_attn_read_input(self, encoder_output, prev_context_state,
323329
return copy_state
324330

325331
def get_copy_score(self, encoder_out, src_tokens_with_oov, decoder_output, encoder_output_mask):
332+
"""
333+
334+
:param encoder_out:
335+
:param src_tokens_with_oov:
336+
:param decoder_output:
337+
:param encoder_output_mask:
338+
:return:
339+
"""
326340
# copy_score: B x L
327341
copy_score_in_seq = torch.bmm(torch.tanh(self.copy_proj(encoder_out)),
328342
decoder_output.permute(0, 2, 1)).squeeze(2)

0 commit comments

Comments
 (0)