Skip to content

Commit 1a49c56

Browse files
committed
Remove hard-coded min/max period values that were used for testing, clarify comment
1 parent 91e8dc5 commit 1a49c56

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

timm/layers/pos_embed_sincos.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -898,8 +898,8 @@ def __init__(
898898
self,
899899
dim: int,
900900
temperature: Optional[float] = 100.0,
901-
min_period: Optional[float] = 0.5,
902-
max_period: Optional[float] = 90.,
901+
min_period: Optional[float] = None,
902+
max_period: Optional[float] = None,
903903
feat_shape: Optional[List[int]] = None,
904904
ref_feat_shape: Optional[List[int]] = None,
905905
normalize_coords: str = "separate", # 'min', 'max', 'separate'
@@ -957,8 +957,8 @@ def _compute_periods(self, device='cpu', dtype=torch.float32) -> torch.Tensor:
957957
exponents = 2.0 * torch.arange(dim, device=device, dtype=dtype) / (self.dim // 2)
958958
periods = self.temperature ** exponents
959959

960-
# NOTE: original has periods downcast to bfloat16 in persistent buffers, so loaded models
961-
# BTW orignal and timm might differ a bit here
960+
# NOTE: The original dinv3 model weights have periods downcast to bfloat16 in persistent buffers,
961+
# loaded models will differ a bit vs timm as periods is not persistent and generated in float32 by default
962962

963963
return periods
964964

0 commit comments

Comments
 (0)