@@ -885,6 +885,54 @@ def no_weight_decay(self):
885885 return {'freqs' }
886886
887887
@torch.fx.wrap
@register_notrace_function
def make_coords_dinov3(
        height: int,
        width: int,
        normalize_coords: str = 'separate',
        grid_indexing: str = 'ij',
        grid_offset: float = 0.,
        device: torch.device = 'cpu',
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build the flattened (h, w) coordinate grid for DINOv3 rotary embeddings.

    Each axis takes positions ``i + 0.5 + grid_offset`` for ``i`` in
    ``range(size)``, divides them by a normalization denominator and maps the
    result through ``2 * x - 1``.

    Args:
        height: Number of grid rows.
        width: Number of grid columns.
        normalize_coords: Denominator selection. ``'separate'`` divides each
            axis by its own size; ``'max'`` / ``'min'`` divide both axes by the
            larger / smaller of ``height`` and ``width``.
        grid_indexing: ``'xy'`` builds the mesh with ``indexing="xy"``; any
            other value uses ``indexing="ij"``. The stacked result is ordered
            (h, w) in both cases.
        grid_offset: Constant added to the 0.5-centered indices before
            normalization.
        device: Device passed to ``torch.arange``.
        dtype: Dtype passed to ``torch.arange``.

    Returns:
        Tensor of shape (height * width, 2). Values lie inside [-1, 1] when
        ``grid_offset`` is 0 and ``normalize_coords`` is ``'separate'`` or
        ``'max'``; ``'min'`` or a non-zero offset can leave that range.

    Raises:
        ValueError: If ``normalize_coords`` is not ``'max'``, ``'min'`` or
            ``'separate'``.
    """
    # Cell centers (i + 0.5), shifted by the offset before any scaling.
    centers_h = torch.arange(0.5, height, device=device, dtype=dtype) + grid_offset
    centers_w = torch.arange(0.5, width, device=device, dtype=dtype) + grid_offset

    # Pick the per-axis denominators.
    if normalize_coords == "separate":
        scale_h, scale_w = float(height), float(width)
    elif normalize_coords in ("max", "min"):
        reducer = max if normalize_coords == "max" else min
        scale_h = scale_w = float(reducer(height, width))
    else:
        raise ValueError(f"Unknown normalize_coords: {normalize_coords}")

    unit_h = centers_h / scale_h
    unit_w = centers_w / scale_w

    # Both indexing modes are unpacked so that the stacked last dim is (h, w).
    if grid_indexing == "xy":
        mesh_w, mesh_h = torch.meshgrid(unit_w, unit_h, indexing="xy")
    else:
        mesh_h, mesh_w = torch.meshgrid(unit_h, unit_w, indexing="ij")
    grid = torch.stack([mesh_h, mesh_w], dim=-1)  # (H, W, 2)

    # Flatten the spatial dims, then apply the affine map x -> 2x - 1.
    return 2.0 * grid.flatten(0, 1) - 1.0  # (HW, 2)
934+
935+
888936class RotaryEmbeddingDinoV3 (nn .Module ):
889937 """RoPE for timm DinoV3 port, numerically matching original.
890938
@@ -960,49 +1008,6 @@ def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = to
9601008
9611009 return periods
9621010
def _make_coords(
        self,
        height: int,
        width: int,
        device: torch.device = 'cpu',
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build the flattened (h, w) coordinate grid from the module's settings.

    Reads ``self.grid_offset``, ``self.normalize_coords`` and
    ``self.grid_indexing``. Each axis takes positions
    ``i + 0.5 + grid_offset``, divides them by the selected denominator and
    maps the result through ``2 * x - 1``.

    Args:
        height: Number of grid rows.
        width: Number of grid columns.
        device: Device passed to ``torch.arange``.
        dtype: Dtype passed to ``torch.arange``.

    Returns:
        Tensor of shape (height * width, 2) with the last dim ordered (h, w).
        Values lie inside [-1, 1] for a zero offset with ``'separate'`` or
        ``'max'`` normalization; other settings can leave that range.

    Raises:
        ValueError: If ``self.normalize_coords`` is not ``'max'``, ``'min'``
            or ``'separate'``.
    """
    offset = self.grid_offset
    mode = self.normalize_coords

    # Cell centers (i + 0.5), shifted by the configured offset.
    rows = torch.arange(0.5, height, device=device, dtype=dtype) + offset
    cols = torch.arange(0.5, width, device=device, dtype=dtype) + offset

    # Resolve the denominators for the configured normalization mode.
    if mode == "max":
        denom_h = denom_w = float(max(height, width))
    elif mode == "min":
        denom_h = denom_w = float(min(height, width))
    elif mode == "separate":
        denom_h, denom_w = float(height), float(width)
    else:
        raise ValueError(f"Unknown normalize_coords: {mode}")

    rows = rows / denom_h
    cols = cols / denom_w

    # Build the (H, W, 2) mesh; the last dim is (h, w) for either indexing mode.
    if self.grid_indexing == "xy":
        mesh_cols, mesh_rows = torch.meshgrid(cols, rows, indexing="xy")
        pairs = torch.stack([mesh_rows, mesh_cols], dim=-1)
    else:
        pairs = torch.stack(torch.meshgrid(rows, cols, indexing="ij"), dim=-1)

    flat = pairs.flatten(0, 1)  # (HW, 2)
    return 2.0 * flat - 1.0
1005-
10061011 def _apply_coord_augs (self , coords : torch .Tensor ) -> torch .Tensor :
10071012 """Apply shift/jitter/rescale train time augmentations."""
10081013 if not self .training or not self .aug_active :
@@ -1063,7 +1068,12 @@ def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tenso
10631068
10641069 def _create_embed (self , feat_shape : List [int ], no_aug : bool = False ) -> torch .Tensor :
10651070 H , W = feat_shape
1066- coords = self ._make_coords (H , W ) # (HW, 2)
1071+ coords = make_coords_dinov3 (
1072+ H , W ,
1073+ normalize_coords = self .normalize_coords ,
1074+ grid_indexing = self .grid_indexing ,
1075+ grid_offset = self .grid_offset
1076+ ) # (HW, 2)
10671077 if not no_aug :
10681078 coords = self ._apply_coord_augs (coords )
10691079 sin , cos = self ._get_pos_embed_from_coords (coords ) # 2 * (HW, dim)
0 commit comments