 # -*- coding: UTF-8 -*-
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn.modules.transformer import (TransformerEncoder, TransformerDecoder,
                                           TransformerEncoderLayer, TransformerDecoderLayer)
 from deep_keyphrase.dataloader import TOKENS, TOKENS_LENS, TOKENS_OOV, UNK_WORD, PAD_WORD, OOV_COUNT


-def get_position_encoding(input_tensor, position, dim_size):
+def get_position_encoding(input_tensor):
+    batch_size, position, dim_size = input_tensor.size()
     assert dim_size % 2 == 0
-    batch_size = len(input_tensor)
     num_timescales = dim_size // 2
     time_scales = torch.arange(0, position + 1, dtype=torch.float).unsqueeze(1)
     dim_scales = torch.arange(0, num_timescales, dtype=torch.float).unsqueeze(0)
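For reference, a minimal standalone sketch of the sinusoidal scheme this helper appears to implement (the function and variable names below are illustrative, not part of the module; it only relies on the torch import above):

def sinusoidal_position_encoding(batch_size, seq_len, dim_size):
    # positions vary along the sequence axis; each position gets dim_size/2 sin and cos features
    num_timescales = dim_size // 2
    positions = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)            # seq_len x 1
    dim_scales = torch.arange(num_timescales, dtype=torch.float).unsqueeze(0)    # 1 x num_timescales
    inv_timescales = torch.pow(1.0e4, -dim_scales / max(num_timescales - 1, 1))
    scaled_time = positions * inv_timescales                                     # seq_len x num_timescales
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)  # seq_len x dim_size
    return signal.unsqueeze(0).expand(batch_size, seq_len, dim_size)             # B x seq_len x dim_size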
@@ -68,6 +69,7 @@ def __init__(self, embedding, input_dim, head_size,
                  feed_forward_dim, dropout, num_layers):
         super().__init__()
         self.embedding = embedding
+        self.dropout = dropout
         layer = TransformerEncoderLayer(d_model=input_dim,
                                         nhead=head_size,
                                         dim_feedforward=feed_forward_dim,
@@ -76,22 +78,26 @@ def __init__(self, embedding, input_dim, head_size,

     def forward(self, src_dict):
         batch_size, max_len = src_dict[TOKENS].size()
-        # mask_range = torch.arange(max_len).expand(batch_size, max_len)
-        mask_range = torch.zeros(batch_size, max_len, dtype=torch.bool)
+        mask_range = torch.arange(max_len).unsqueeze(0).repeat(batch_size, 1)
+
         if torch.cuda.is_available():
             mask_range = mask_range.cuda()
-        # mask = mask_range > src_dict[TOKENS_LENS].unsqueeze(1)
+        mask = mask_range >= src_dict[TOKENS_LENS].unsqueeze(1)
+        # mask = (mask_range > src_dict[TOKENS_LENS].unsqueeze(1)).expand(batch_size, max_len, max_len)
         src_embed = self.embedding(src_dict[TOKENS]).transpose(1, 0)
-        # print(src_embed.size(), mask.size())
-        output = self.encoder(src_embed).transpose(1, 0)
-        return output, mask_range
+        # src_embed is seq_len x B x H here, so compute the encoding in B x seq_len x H layout
+        pos_embed = get_position_encoding(src_embed.transpose(1, 0)).transpose(1, 0)
+        src_embed = src_embed + pos_embed
+        src_embed = F.dropout(src_embed, p=self.dropout, training=self.training)
+        output = self.encoder(src_embed, src_key_padding_mask=mask).transpose(1, 0)
+        return output, mask
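A tiny illustrative check of the mask semantics (values invented for the example): indices at or past a sequence's length come out True, which is what src_key_padding_mask treats as padding to ignore.

lengths = torch.tensor([3, 1])                          # two sequences, max_len = 4
mask_range = torch.arange(4).unsqueeze(0).repeat(2, 1)  # 2 x 4 position indices
mask = mask_range >= lengths.unsqueeze(1)
# tensor([[False, False, False,  True],
#         [False,  True,  True,  True]])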


 class CopyTransformerDecoder(nn.Module):
     def __init__(self, embedding, input_dim, vocab2id, head_size, feed_forward_dim,
                  dropout, num_layers, target_max_len, max_oov_count):
         super().__init__()
         self.embedding = embedding
+        self.dropout = dropout
         self.vocab_size = embedding.num_embeddings
         self.vocab2id = vocab2id
         layer = TransformerDecoderLayer(d_model=input_dim,
@@ -124,18 +130,19 @@ def forward(self, prev_output_tokens, prev_decoder_state, position,
         # map copied OOV token ids to the UNK index to avoid an embedding lookup error
         prev_output_tokens[prev_output_tokens >= self.vocab_size] = self.vocab2id[UNK_WORD]
         token_embed = self.embedding(prev_output_tokens)
-        pos_embed = get_position_encoding(token_embed, position, self.input_dim)
+
+        pos_embed = get_position_encoding(token_embed)
         # B x seq_len x H
         src_embed = token_embed + pos_embed
-        # print(token_embed.size(), pos_embed.size())
-        # print(src_embed.size(), copy_state.size())
         decoder_input = self.embed_proj(torch.cat([src_embed, copy_state], dim=2)).transpose(1, 0)
-        # print(decoder_input.size())
+        decoder_input = F.dropout(decoder_input, p=self.dropout, training=self.training)
         decoder_input_mask = torch.triu(torch.ones(self.input_dim, self.input_dim), 1)
         # B x seq_len x H
-        decoder_output = self.decoder(tgt=decoder_input, memory=encoder_output.transpose(1, 0), )
+        # encoder_mask marks padded source positions, so it is the memory key padding mask
+        decoder_output = self.decoder(tgt=decoder_input,
+                                      memory=encoder_output.transpose(1, 0),
+                                      memory_key_padding_mask=encoder_mask)
         decoder_output = decoder_output.transpose(1, 0)
-        # tgt_mask=decoder_input_mask, memory_mask=encoder_mask)
+
         # B x 1 x H
         decoder_output = decoder_output[:, -1:, :]
         generation_logits = self.generate_proj(decoder_output).squeeze(1)
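The rest of the forward pass is cut off by this hunk; purely as a hedged sketch (not this module's actual code) of how a copy decoder of this kind typically merges generation and copy scores over the vocab_size + max_oov_count extended vocabulary, with all names below illustrative:

def merge_generation_and_copy(generation_logits, copy_logits, src_tokens_with_oov,
                              vocab_size, max_oov_count):
    # generation_logits: B x vocab_size, copy_logits: B x src_len,
    # src_tokens_with_oov: B x src_len ids where source OOV words occupy slots >= vocab_size
    batch_size = generation_logits.size(0)
    probs = torch.zeros(batch_size, vocab_size + max_oov_count,
                        device=generation_logits.device)
    probs[:, :vocab_size] = torch.softmax(generation_logits, dim=-1)
    # add the copy probability mass onto the (possibly OOV) source token ids
    return probs.scatter_add(1, src_tokens_with_oov, torch.softmax(copy_logits, dim=-1))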