@@ -251,7 +251,7 @@ def __init__(self, config):
251251 def forward (self , inputs , seq_len ):
252252 sort_lens , sort_idx = torch .sort (seq_len , dim = 0 , descending = True )
253253 inputs = inputs [sort_idx ]
254- inputs = nn .utils .rnn .pack_padded_sequence (inputs , sort_lens , batch_first = self .batch_first )
254+ inputs = nn .utils .rnn .pack_padded_sequence (inputs , sort_lens . cpu () , batch_first = self .batch_first )
255255 output , hx = self .encoder (inputs , None ) # -> [N,L,C]
256256 output , _ = nn .utils .rnn .pad_packed_sequence (output , batch_first = self .batch_first )
257257 _ , unsort_idx = torch .sort (sort_idx , dim = 0 , descending = False )
@@ -316,7 +316,7 @@ def forward(self, inputs, seq_len):
316316 max_len = inputs .size (1 )
317317 sort_lens , sort_idx = torch .sort (seq_len , dim = 0 , descending = True )
318318 inputs = inputs [sort_idx ]
319- inputs = nn .utils .rnn .pack_padded_sequence (inputs , sort_lens , batch_first = True )
319+ inputs = nn .utils .rnn .pack_padded_sequence (inputs , sort_lens . cpu () , batch_first = True )
320320 output , _ = self ._lstm_forward (inputs , None )
321321 _ , unsort_idx = torch .sort (sort_idx , dim = 0 , descending = False )
322322 output = output [:, unsort_idx ]
0 commit comments