@@ -75,12 +75,15 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
     """Patches the image and embeds the patches.
 
     Args:
-        image_size: int. Size of the input image (height or width).
-            Assumed to be square.
-        patch_size: int. Size of each image patch.
+        image_size: (int, int). Size of the input image as `(height, width)`.
+        patch_size: (int, int). Size of each image patch as `(height, width)`.
         hidden_dim: int. Dimensionality of the patch embeddings.
         num_channels: int. Number of channels in the input image. Defaults to
             `3`.
+        use_class_token: bool. Whether to prepend a learnable class token to
+            the patch embeddings. Defaults to `True`.
+        use_patch_bias: bool. Whether to use a bias term in the patch
+            embedding convolution. Defaults to `True`.
         data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
             `None` (which uses `"channels_last"`).
         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
@@ -92,12 +95,15 @@ def __init__(
         patch_size,
         hidden_dim,
         num_channels=3,
+        use_class_token=True,
+        use_patch_bias=True,
         data_format=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        num_patches = (image_size // patch_size) ** 2
-        num_positions = num_patches + 1
+        grid_size = tuple([s // p for s, p in zip(image_size, patch_size)])
+        num_patches = grid_size[0] * grid_size[1]
+        num_positions = num_patches + 1 if use_class_token else num_patches
 
         # === Config ===
         self.image_size = image_size
@@ -106,19 +112,22 @@ def __init__(
         self.num_channels = num_channels
         self.num_patches = num_patches
         self.num_positions = num_positions
+        self.use_class_token = use_class_token
+        self.use_patch_bias = use_patch_bias
         self.data_format = standardize_data_format(data_format)
 
     def build(self, input_shape):
-        self.class_token = self.add_weight(
-            shape=(
-                1,
-                1,
-                self.hidden_dim,
-            ),
-            initializer="random_normal",
-            dtype=self.variable_dtype,
-            name="class_token",
-        )
+        if self.use_class_token:
+            self.class_token = self.add_weight(
+                shape=(
+                    1,
+                    1,
+                    self.hidden_dim,
+                ),
+                initializer="random_normal",
+                dtype=self.variable_dtype,
+                name="class_token",
+            )
         self.patch_embedding = keras.layers.Conv2D(
             filters=self.hidden_dim,
             kernel_size=self.patch_size,
@@ -127,6 +136,7 @@ def build(self, input_shape):
             activation=None,
             dtype=self.dtype_policy,
             data_format=self.data_format,
+            use_bias=self.use_patch_bias,
             name="patch_embedding",
         )
         self.patch_embedding.build(input_shape)
@@ -153,10 +163,16 @@ def call(self, inputs):
         patch_embeddings = ops.reshape(
             patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
         )
-        class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
         position_embeddings = self.position_embedding(self.position_ids)
-        embeddings = ops.concatenate([class_token, patch_embeddings], axis=1)
-        return ops.add(embeddings, position_embeddings)
+
+        if self.use_class_token:
+            class_token = ops.tile(
+                self.class_token, (embeddings_shape[0], 1, 1)
+            )
+            patch_embeddings = ops.concatenate(
+                [class_token, patch_embeddings], axis=1
+            )
+        return ops.add(patch_embeddings, position_embeddings)
 
     def compute_output_shape(self, input_shape):
         return (
@@ -175,6 +191,8 @@ def get_config(self):
                 "num_channels": self.num_channels,
                 "num_patches": self.num_patches,
                 "num_positions": self.num_positions,
+                "use_class_token": self.use_class_token,
+                "use_patch_bias": self.use_patch_bias,
             }
         )
         return config
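For reviewers who want to sanity-check the new behavior, here is a minimal usage sketch. It assumes the layer is importable from `keras_hub.src.models.vit.vit_layers` (the apparent home of this diff) and that the surrounding class builds `position_embedding` and `position_ids` as the unchanged context lines suggest; treat the import path as an assumption, not part of this PR.

# Hypothetical usage sketch; the import path is an assumption.
import numpy as np
from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding

# image_size and patch_size are now (height, width) tuples, so
# rectangular inputs are supported.
layer = ViTPatchingAndEmbedding(
    image_size=(224, 160),
    patch_size=(16, 16),
    hidden_dim=768,
    use_class_token=False,  # skip the learnable class token
)
images = np.zeros((2, 224, 160, 3), dtype="float32")
embeddings = layer(images)
# grid_size = (224 // 16, 160 // 16) = (14, 10), so num_patches = 140.
# With use_class_token=False, num_positions == num_patches (no +1).
print(embeddings.shape)  # (2, 140, 768)

With `use_class_token=True` (the default), the tiled class token is concatenated ahead of the patches and the sequence length becomes 141.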