@@ -127,8 +127,7 @@ def __init__(self, vocab2id, embedding, hidden_size,
                            hidden_size=hidden_size,
                            num_layers=self.num_layers,
                            bidirectional=bidirectional,
-                           batch_first=True
-                           )
+                           batch_first=True)

     def forward(self, src_dict):
         """
@@ -145,8 +144,7 @@ def forward(self, src_dict):
         total_length = src_embed.size(1)
         packed_src_embed = nn.utils.rnn.pack_padded_sequence(src_embed,
                                                              src_lengths,
-                                                             batch_first=True,
-                                                             enforce_sorted=False)
+                                                             batch_first=True)
         state_size = [self.num_layers, batch_size, self.hidden_size]
         if self.bidirectional:
             state_size[0] *= 2
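
Note what the hunk above changes: dropping `enforce_sorted=False` falls back to PyTorch's default `enforce_sorted=True`, under which `pack_padded_sequence` requires the batch to be sorted by descending length before packing (the argument itself only exists since PyTorch 1.1, so this removal possibly restores compatibility with older versions). A minimal sketch of the sorting contract this implies, with hypothetical names not taken from this repo:

```python
import torch
import torch.nn as nn

def pack_sorted(src_embed, src_lengths):
    # src_embed: B x L x E, src_lengths: B  (hypothetical names/shapes)
    sorted_lengths, sort_idx = torch.sort(src_lengths, descending=True)
    packed = nn.utils.rnn.pack_padded_sequence(src_embed[sort_idx],
                                               sorted_lengths,
                                               batch_first=True)
    # after the RNN, undo the permutation so outputs realign with the batch
    unsort_idx = torch.argsort(sort_idx)
    return packed, unsort_idx
```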
@@ -237,6 +235,15 @@ def forward_copyrnn(self,
                          prev_output_tokens,
                          prev_rnn_state,
                          src_dict):
+        """
+        Run one decoding step of the copy-mechanism decoder.
+        :param encoder_output_dict: dict of encoder outputs for the source sequence
+        :param prev_context_state: attention context from the previous decoding step
+        :param prev_output_tokens: tokens emitted at the previous decoding step
+        :param prev_rnn_state: decoder RNN state from the previous decoding step
+        :param src_dict: source input dict, including TOKENS and TOKENS_OOV
+        :return:
+        """
         src_tokens = src_dict[TOKENS]
         src_tokens_with_oov = src_dict[TOKENS_OOV]
         batch_size = len(src_tokens)
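
`TOKENS` and `TOKENS_OOV` are the two standard views of the source in copy models: one with every out-of-vocabulary word collapsed to UNK, and one where each distinct OOV word gets a temporary id past the fixed vocabulary so the copy mechanism can point at it. A hedged sketch of that preprocessing (the repo's actual implementation is not part of this diff; `UNK_ID`, `VOCAB_SIZE`, and `to_ids_with_oov` are hypothetical):

```python
UNK_ID = 1          # assumed id of the unknown token
VOCAB_SIZE = 50000  # assumed size of the fixed vocabulary

def to_ids_with_oov(words, vocab2id):
    tokens, tokens_oov, oov2id = [], [], {}
    for w in words:
        if w in vocab2id:
            tokens.append(vocab2id[w])
            tokens_oov.append(vocab2id[w])
        else:
            tokens.append(UNK_ID)
            # each distinct OOV word gets a temporary id past the vocabulary
            oov2id.setdefault(w, VOCAB_SIZE + len(oov2id))
            tokens_oov.append(oov2id[w])
    return tokens, tokens_oov

# e.g. with vocab {"the": 2, "cat": 3}:
#   "the zyxx cat" -> tokens [2, 1, 3], tokens_oov [2, 50000, 3]
```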
@@ -258,7 +265,6 @@ def forward_copyrnn(self,
         if self.input_feeding:
             prev_context_state = prev_context_state.unsqueeze(1)
             decoder_input = torch.cat([src_embed, prev_context_state, copy_state], dim=2)
-            # print(decoder_input.size())
         else:
             decoder_input = torch.cat([src_embed, copy_state], dim=2)
         decoder_input = F.dropout(decoder_input, p=self.dropout, training=self.training)
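
The `input_feeding` branch above builds the per-step decoder input by concatenating three tensors along the feature dimension: the embedding of the previous output token, the attention context from the previous step, and the attentive copy-read state. A shape-only sketch with hypothetical sizes (the copy-read state is assumed here to share the decoder hidden size `H`):

```python
import torch

B, E, H = 4, 100, 50                       # hypothetical sizes
src_embed = torch.randn(B, 1, E)           # embedding of the previous token
prev_context_state = torch.randn(B, 1, H)  # attention context, after unsqueeze(1)
copy_state = torch.randn(B, 1, H)          # attentive read over copyable positions
decoder_input = torch.cat([src_embed, prev_context_state, copy_state], dim=2)
assert decoder_input.size() == (B, 1, E + 2 * H)
```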
@@ -323,6 +329,14 @@ def get_attn_read_input(self, encoder_output, prev_context_state,
         return copy_state

     def get_copy_score(self, encoder_out, src_tokens_with_oov, decoder_output, encoder_output_mask):
+        """
+        Score each source position for copying given the current decoder state.
+        :param encoder_out: encoder hidden states, B x L x H
+        :param src_tokens_with_oov: source tokens with OOV words mapped to extended ids
+        :param decoder_output: decoder hidden state for the current step, B x 1 x H
+        :param encoder_output_mask: mask marking padded source positions
+        :return:
+        """
         # copy_score: B x L
         copy_score_in_seq = torch.bmm(torch.tanh(self.copy_proj(encoder_out)),
                                       decoder_output.permute(0, 2, 1)).squeeze(2)
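
The copy score shown above is a bilinear match between projected encoder states and the current decoder state: project, apply `tanh`, then a batched matrix product that yields one score per source position. A minimal standalone sketch with hypothetical shapes and an assumed `copy_proj` configuration:

```python
import torch
import torch.nn as nn

B, L, H = 4, 7, 50                       # hypothetical shapes
copy_proj = nn.Linear(H, H, bias=False)  # assumed projection; its real
                                         # configuration is not in this diff
encoder_out = torch.randn(B, L, H)       # encoder hidden states, B x L x H
decoder_output = torch.randn(B, 1, H)    # decoder state for the current step
copy_score_in_seq = torch.bmm(torch.tanh(copy_proj(encoder_out)),
                              decoder_output.permute(0, 2, 1)).squeeze(2)
assert copy_score_in_seq.size() == (B, L)  # one score per source position
```

The `encoder_output_mask` parameter suggests that padded source positions are masked out of this score before it is combined with the generation distribution, though that step is outside the visible hunk.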