@@ -164,6 +164,37 @@ def __init__(self,
         attention_scores=all_attention_scores)
     super().__init__(
         inputs=self.inputs, outputs=outputs, **kwargs)
+    kwargs.pop('name', None)  # avoid a duplicate 'name' key; it is stored via self.name below
+    self._config = dict(
+        name=self.name,
+        word_vocab_size=word_vocab_size,
+        word_embed_size=word_embed_size,
+        type_vocab_size=type_vocab_size,
+        max_sequence_length=max_sequence_length,
+        num_blocks=num_blocks,
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        intermediate_size=intermediate_size,
+        intermediate_act_fn=intermediate_act_fn,
+        hidden_dropout_prob=hidden_dropout_prob,
+        attention_probs_dropout_prob=attention_probs_dropout_prob,
+        intra_bottleneck_size=intra_bottleneck_size,
+        initializer_range=initializer_range,
+        use_bottleneck_attention=use_bottleneck_attention,
+        key_query_shared_bottleneck=key_query_shared_bottleneck,
+        num_feedforward_networks=num_feedforward_networks,
+        normalization_type=normalization_type,
+        classifier_activation=classifier_activation,
+        input_mask_dtype=input_mask_dtype,
+        **kwargs,
+    )
+
+  def get_config(self):
+    return dict(self._config)
+
+  @classmethod
+  def from_config(cls, config):
+    return cls(**config)
 
   def get_embedding_table(self):
     return self.embedding_layer.word_embedding.embeddings
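The added `get_config`/`from_config` pair implements the standard Keras serialization contract: `get_config` returns a plain dict of the constructor arguments captured in `__init__`, and `from_config` re-invokes the constructor with that dict to rebuild an identical architecture. Below is a minimal round-trip sketch; the class name `MobileBERTEncoder`, its import path, and the constructor values are illustrative assumptions, not taken from the patch.

```python
# Hypothetical usage sketch. Assumes the patched class is importable as
# MobileBERTEncoder; the module path and argument values are placeholders.
from mobile_bert_encoder import MobileBERTEncoder  # hypothetical import

encoder = MobileBERTEncoder(word_vocab_size=30522, num_blocks=4)

# get_config() returns a copy of the stored kwargs, so serializing or
# mutating the returned dict cannot affect the live model's state.
config = encoder.get_config()

# from_config() calls cls(**config), rebuilding the same architecture.
# Only the configuration round-trips; weights must be copied explicitly.
restored = MobileBERTEncoder.from_config(config)
restored.set_weights(encoder.get_weights())
```

Capturing the kwargs eagerly in `__init__` (rather than reading attributes back inside `get_config`) keeps the stored config faithful to what the caller actually passed, which is what Keras relies on when re-creating a custom model supplied through `custom_objects` to `tf.keras.models.load_model`.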