@@ -787,6 +787,24 @@ def _prune_heads(self, heads_to_prune):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)

+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
+
     def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None,
                 head_mask=None, output_attentions=True):
         """
@@ -836,6 +854,7 @@ def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None,
         # effectively the same as removing these entirely.
         # this will cause an issue with DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
         # attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        attention_mask = attention_mask.to(self.dtype)
         attention_mask = (1.0 - attention_mask) * -10000.0
         # attention_mask = attention_mask.masked_fill(attention_mask.eq(0), -10000.0)

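For context, here is a minimal standalone sketch (not part of this commit) of the same `dtype` fallback and mask cast. Under `nn.DataParallel` in PyTorch 1.5, a replica can expose no parameters, so `next(self.parameters())` raises `StopIteration` and the property falls back to scanning tensor attributes via `_named_members`. `DtypeMixin` and `TinyModel` below are hypothetical names used only for illustration.

```python
# Minimal sketch, assuming only a plain PyTorch install; not the repository code.
import torch
import torch.nn as nn


class DtypeMixin(nn.Module):
    @property
    def dtype(self):
        try:
            # Normal path: take the dtype of the first parameter.
            return next(self.parameters()).dtype
        except StopIteration:
            # nn.DataParallel replicas in PyTorch 1.5 may expose no parameters,
            # so fall back to any tensor stored directly on a (sub)module.
            def find_tensor_attributes(module: nn.Module):
                return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype


class TinyModel(DtypeMixin):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)


model = TinyModel().half()
attention_mask = torch.ones(1, 4)
# Same cast as the added forward() line: keep the mask in the model's dtype (fp16 here)
# before turning it into a large negative additive bias for masked positions.
attention_mask = attention_mask.to(model.dtype)
attention_mask = (1.0 - attention_mask) * -10000.0
print(model.dtype, attention_mask.dtype)  # torch.float16 torch.float16
```

Casting through `self.dtype` instead of `next(self.parameters()).dtype` keeps the fp16 compatibility fix while avoiding the DataParallel issue linked in the comment above.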