@@ -108,22 +108,25 @@ def call(self, inputs):
 class Encoder(layers.Layer):
   """Transformer Encoder."""
 
-  def __init__(self,
-               num_layers,
-               mlp_dim,
-               num_heads,
-               dropout_rate=0.1,
-               attention_dropout_rate=0.1,
-               kernel_regularizer=None,
-               inputs_positions=None,
-               init_stochastic_depth_rate=0.0,
-               kernel_initializer='glorot_uniform',
-               add_pos_embed=True,
-               pos_embed_origin_shape=None,
-               pos_embed_target_shape=None,
-               layer_scale_init_value=0.0,
-               transformer_partition_dims=None,
-               **kwargs):
+  def __init__(
+      self,
+      num_layers,
+      mlp_dim,
+      num_heads,
+      dropout_rate=0.1,
+      attention_dropout_rate=0.1,
+      kernel_regularizer=None,
+      inputs_positions=None,
+      init_stochastic_depth_rate=0.0,
+      kernel_initializer='glorot_uniform',
+      add_pos_embed=True,
+      pos_embed_origin_shape=None,
+      pos_embed_target_shape=None,
+      layer_scale_init_value=0.0,
+      transformer_partition_dims=None,
+      output_attention_scores=False,
+      **kwargs,
+  ):
     super().__init__(**kwargs)
     self._num_layers = num_layers
     self._mlp_dim = mlp_dim
@@ -139,6 +142,7 @@ def __init__(self,
     self._pos_embed_target_shape = pos_embed_target_shape
     self._layer_scale_init_value = layer_scale_init_value
     self._transformer_partition_dims = transformer_partition_dims
+    self._output_attention_scores = output_attention_scores
 
   def build(self, input_shape):
     if self._add_pos_embed:
@@ -163,10 +167,13 @@ def build(self, input_shape):
           kernel_initializer=self._kernel_initializer,
           norm_first=True,
           stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
-              self._init_stochastic_depth_rate, i + 1, self._num_layers),
+              self._init_stochastic_depth_rate, i + 1, self._num_layers
+          ),
           norm_epsilon=1e-6,
           layer_scale_init_value=self._layer_scale_init_value,
-          transformer_partition_dims=self._transformer_partition_dims)
+          transformer_partition_dims=self._transformer_partition_dims,
+          return_attention_scores=self._output_attention_scores,
+      )
       self._encoder_layers.append(encoder_layer)
     self._norm = layers.LayerNormalization(epsilon=1e-6)
     super().build(input_shape)
@@ -177,9 +184,16 @@ def call(self, inputs, training=None):
       x = self._pos_embed(x, inputs_positions=self._inputs_positions)
     x = self._dropout(x, training=training)
 
+    attention_scores = None  # Needed to suppress undefined-variable warning.
     for encoder_layer in self._encoder_layers:
-      x = encoder_layer(x, training=training)
+      if self._output_attention_scores:
+        x, attention_scores = encoder_layer(x, training=training)
+      else:
+        x = encoder_layer(x, training=training)
     x = self._norm(x)
+
+    if self._output_attention_scores:
+      return x, attention_scores
     return x
 
   def get_config(self):
@@ -199,6 +213,7 @@ def get_config(self):
         'pos_embed_target_shape': self._pos_embed_target_shape,
         'layer_scale_init_value': self._layer_scale_init_value,
         'transformer_partition_dims': self._transformer_partition_dims,
+        'output_attention_scores': self._output_attention_scores,
     }
     config.update(updates)
     return config
@@ -227,6 +242,7 @@ def __init__(
       pos_embed_shape: Optional[Tuple[int, int]] = None,
       layer_scale_init_value: float = 0.0,
       transformer_partition_dims: Optional[Tuple[int, int, int, int]] = None,
+      output_attention_scores: bool = False,
   ):
     """VisionTransformer initialization function."""
     self._mlp_dim = mlp_dim
@@ -265,20 +281,29 @@ def __init__(
     if pooler == 'token':
       x = TokenLayer(name='cls')(x)
 
-    x = Encoder(
+    encoder_output = Encoder(
         num_layers=num_layers,
         mlp_dim=mlp_dim,
         num_heads=num_heads,
         dropout_rate=dropout_rate,
         attention_dropout_rate=attention_dropout_rate,
         kernel_regularizer=kernel_regularizer,
-        kernel_initializer='glorot_uniform' if original_init else dict(
-            class_name='TruncatedNormal', config=dict(stddev=.02)),
+        kernel_initializer='glorot_uniform'
+        if original_init
+        else dict(class_name='TruncatedNormal', config=dict(stddev=0.02)),
         init_stochastic_depth_rate=init_stochastic_depth_rate,
         pos_embed_origin_shape=pos_embed_shape,
         pos_embed_target_shape=pos_embed_target_shape,
-        layer_scale_init_value=layer_scale_init_value)(
-            x)
+        layer_scale_init_value=layer_scale_init_value,
+        output_attention_scores=output_attention_scores,
+    )(x)
+
+    endpoints = {}
+    if output_attention_scores:
+      x, attention_scores = encoder_output
+      endpoints['attention_scores'] = attention_scores
+    else:
+      x = encoder_output
 
     if pooler == 'token':
       output_feature = x[:, 1:]
@@ -292,7 +317,6 @@ def __init__(
     else:
       raise ValueError(f'unrecognized pooler type: {pooler}')
 
-    endpoints = {}
     if output_2d_feature_maps:
       # Use the closest feature level.
       feat_level = round(math.log2(patch_size))
@@ -376,4 +400,6 @@ def build_vit(input_specs,
       output_2d_feature_maps=backbone_cfg.output_2d_feature_maps,
       layer_scale_init_value=backbone_cfg.layer_scale_init_value,
       pos_embed_shape=backbone_cfg.pos_embed_shape,
-      transformer_partition_dims=backbone_cfg.transformer_partition_dims)
+      transformer_partition_dims=backbone_cfg.transformer_partition_dims,
+      output_attention_scores=backbone_cfg.output_attention_scores,
+  )
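
For reference, a minimal usage sketch of the new output_attention_scores flag against the Encoder layer modified above. The import path, layer sizes, and dummy tensor shapes are assumptions for illustration; only the constructor argument and the (features, attention_scores) return pair come from this change.

import tensorflow as tf

from official.vision.modeling.backbones.vit import Encoder  # module path assumed

# Build an encoder that also returns attention scores (flag added in this change).
encoder = Encoder(
    num_layers=12,
    mlp_dim=3072,
    num_heads=12,
    output_attention_scores=True,
)

# Dummy token sequence of shape (batch, num_tokens, hidden_size); sizes are illustrative.
tokens = tf.random.uniform([2, 197, 768])
encoded, attention_scores = encoder(tokens, training=False)

Note that only the scores of the last encoder block come back, since Encoder.call overwrites attention_scores on every loop iteration. At the model level, the same flag is threaded through VisionTransformer (and build_vit via backbone_cfg.output_attention_scores), where the scores surface in the endpoints dict under the 'attention_scores' key.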