
Commit 057fa63

Fix the dtype issue
1 parent dc280fa commit 057fa63


2 files changed (+38 lines, -0 lines)


fastNLP/modules/encoder/bert.py

Lines changed: 19 additions & 0 deletions
@@ -464,6 +464,24 @@ def __init__(self, config, *inputs, **kwargs):
             logger.info('DistilBert has NOT pooler, will use hidden states of [CLS] token as pooled output.')
         self.apply(self.init_bert_weights)
 
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
+
     def init_bert_weights(self, module):
         r""" Initialize the weights.
         """
@@ -510,6 +528,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_al
         # effectively the same as removing these entirely.
         # this will case an issue when DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
         # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+        extended_attention_mask = extended_attention_mask.to(self.dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
         embedding_output = self.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
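
For context, the point of routing the cast through `self.dtype` is that the additive attention mask keeps matching the parameters once the model has been converted with `model.half()`. Below is a minimal, self-contained sketch of that pattern; `ToyEncoder` and `build_additive_mask` are made-up names for illustration and are not part of fastNLP.

    import torch
    import torch.nn as nn


    class ToyEncoder(nn.Module):
        """Illustrative stand-in for BertModel/GPT2Model, not a fastNLP class."""

        def __init__(self, hidden=8):
            super().__init__()
            self.linear = nn.Linear(hidden, hidden)

        @property
        def dtype(self):
            # Same idea as the property added in this commit: report the
            # dtype of the module's parameters.
            return next(self.parameters()).dtype

        def build_additive_mask(self, attention_mask):
            # Cast the 0/1 mask to the module's dtype first (mirroring
            # `extended_attention_mask.to(self.dtype)` above) so the additive
            # mask stays in fp16 after `model.half()`.
            attention_mask = attention_mask.to(self.dtype)
            return (1.0 - attention_mask) * -10000.0


    model = ToyEncoder().half()                    # parameters become torch.float16
    mask = torch.ones(2, 8)                        # masks usually arrive as float32
    print(model.dtype)                             # torch.float16
    print(model.build_additive_mask(mask).dtype)   # torch.float16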

fastNLP/modules/encoder/gpt2.py

Lines changed: 19 additions & 0 deletions
@@ -787,6 +787,24 @@ def _prune_heads(self, heads_to_prune):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
 
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
+
     def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None,
                 head_mask=None, output_attentions=True):
         """
@@ -836,6 +854,7 @@ def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=Non
         # effectively the same as removing these entirely.
         # this will case an issue when DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
         # attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+        attention_mask = attention_mask.to(self.dtype)
         attention_mask = (1.0 - attention_mask) * -10000.0
         # attention_mask = attention_mask.masked_fill(attention_mask.eq(0), -10000.0)
 
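
The `except StopIteration` branch is what actually addresses the `nn.DataParallel` problem referenced in the code comments: on PyTorch 1.5 replicas (see the linked pytorch/pytorch#40457 comment), the parameters are no longer reachable through `self.parameters()`, so `next(self.parameters())` raises `StopIteration` and the property falls back to scanning plain tensor attributes via `_named_members`. A rough single-process sketch of that fallback, using a hypothetical `ParamlessModule` that has a tensor attribute but no registered parameters:

    import torch
    import torch.nn as nn


    class ParamlessModule(nn.Module):
        """Hypothetical module with no registered parameters, only a plain tensor
        attribute -- roughly the situation a DataParallel replica ends up in."""

        def __init__(self):
            super().__init__()
            # A plain attribute lands in __dict__, not in _parameters/_buffers.
            self.scale = torch.zeros(3, dtype=torch.float16)

        @property
        def dtype(self):
            try:
                return next(self.parameters()).dtype   # empty -> StopIteration
            except StopIteration:
                # Fallback from the commit: look for tensor attributes instead.
                def find_tensor_attributes(module: nn.Module):
                    return [(k, v) for k, v in module.__dict__.items()
                            if torch.is_tensor(v)]

                gen = self._named_members(get_members_fn=find_tensor_attributes)
                first_tuple = next(gen)
                return first_tuple[1].dtype


    m = ParamlessModule()
    print(m.dtype)  # torch.float16, recovered via the tensor-attribute fallback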
