
Commit 9795865

Authored by alfieroddanintel, waschsalz, samet-akcay, and rajeshgangireddy
feat(model): add DINOv2 official implementation and AnomalyDINO (#3105)
* init commit
* add cdist instead of euclidean distance, divide by 2 for cosine and improve comments
* remove redundant euclidean distance
* Add AnomalyDINO to model list; also alphabetically re-order some models
* add precision modifier
* remove fit comments; fix small typo in shape dimensions
* update docs for AnomalyDINO
* cleanup docstrings
* add unit tests for AnomalyDINO; change distance computation from cdist to matmul to work with half tensors
* add ViT/DINO implementation (no xformers); implement factory class for generating DINOv2; update anomaly_dino to use the factory method
* change DinoV2Loader to factory method; remove duplicated components from Dinomaly
* add from_name back
* add tests for ViT and DinoV2Loader
* update docstrings
* fix(accelerator): add name method in XPUAccelerator (#3108)
  * Update xpu.py regarding PR #3092: added the name method to fix an issue related to a newly added feature in lightning 2.5.6
  * Update xpu.py with correct docstring
  * added name method for XPUAccelerator
* change licensing with Meta; Tensor is torch.Tensor; remove __future__
* Update src/anomalib/models/components/dinov2/layers/block.py
* Update src/anomalib/models/components/dinov2/layers/attention.py
* Update src/anomalib/models/components/dinov2/layers/dino_head.py
* Update src/anomalib/models/components/dinov2/layers/drop_path.py
* Update src/anomalib/models/components/dinov2/layers/layer_scale.py
* Update src/anomalib/models/components/dinov2/layers/mlp.py
* Update src/anomalib/models/components/dinov2/layers/patch_embed.py
* Update src/anomalib/models/components/dinov2/layers/swiglu_ffn.py
* Update src/anomalib/models/components/dinov2/vision_transformer.py

Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com>
Signed-off-by: waschsalz <niclas.zschach@icloud.com>
Signed-off-by: Samet Akcay <smt.akcay@gmail.com>
Co-authored-by: Niclas <152474825+waschsalz@users.noreply.github.com>
Co-authored-by: Samet Akcay <smt.akcay@gmail.com>
Co-authored-by: Samet Akcay <samet.akcay@intel.com>
Co-authored-by: Rajesh Gangireddy <rajesh.gangireddy@intel.com>
1 parent 1fedbd6 commit 9795865
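The distance-computation bullets above describe the core scoring idea: patch features are L2-normalized so that a single matrix multiply yields cosine similarities (half-precision friendly, unlike `torch.cdist`), and the resulting cosine distance is divided by 2 to land in [0, 1]. A minimal sketch of that patch-to-memory-bank scoring, with illustrative names that are not the actual anomalib API:

```python
import torch


def cosine_anomaly_scores(patch_feats: torch.Tensor, memory_bank: torch.Tensor) -> torch.Tensor:
    """Score patches by cosine distance to their nearest memory-bank feature.

    patch_feats:  (N, D) patch embeddings from a test image.
    memory_bank:  (M, D) patch embeddings collected from normal images.
    Returns (N,) anomaly scores in [0, 1].
    """
    # Normalize so the matmul below yields cosine similarity; works with half tensors.
    q = torch.nn.functional.normalize(patch_feats, dim=-1)
    m = torch.nn.functional.normalize(memory_bank, dim=-1)
    sim = q @ m.T                      # (N, M) cosine similarities
    nearest = sim.max(dim=-1).values   # similarity to the closest normal patch
    return (1.0 - nearest) / 2.0       # map [-1, 1] similarity to [0, 1] distance
```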

File tree

31 files changed: +2891 −247 lines
docs/source/markdown/guides/reference/models/image/anomaly_dino.md

Lines changed: 13 additions & 0 deletions

````markdown
# AnomalyDINO

```{eval-rst}
.. automodule:: anomalib.models.image.anomaly_dino.lightning_model
    :members: AnomalyDINO
    :show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.anomaly_dino.torch_model
    :members: AnomalyDINOModel
    :show-inheritance:
```
````

docs/source/markdown/guides/reference/models/image/index.md

Lines changed: 8 additions & 0 deletions

```diff
@@ -4,6 +4,13 @@
 :margin: 1 1 0 0
 :gutter: 1
 
+:::{grid-item-card} {material-regular}`model_training;1.5em` AnomalyDINO
+:link: ./anomaly_dino
+:link-type: doc
+
+Boosting Patch-based Few-shot Anomaly Detection with DINOv2
+:::
+
 :::{grid-item-card} {material-regular}`model_training;1.5em` CFA
 :link: ./cfa
 :link-type: doc
@@ -157,6 +164,7 @@ WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation
 :caption: Data
 :hidden:
 
+./anomaly_dino
 ./cfa
 ./cflow
 ./csflow
```
examples/configs/README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -21,6 +21,7 @@ configs/
 │   └── visa.yaml
 └── model
     ├── ai_vad.yaml
+    ├── anomaly_dino.yaml
     ├── cfa.yaml
     ├── cflow.yaml
     ├── csflow.yaml
```
examples/configs/model/anomaly_dino.yaml

Lines changed: 7 additions & 0 deletions

```yaml
model:
  class_path: anomalib.models.AnomalyDINO
  init_args:
    num_neighbours: 1
    encoder_name: dinov2_vit_small_14
    masking: False
    coreset_subsampling: False
```
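For comparison, the same `init_args` can be passed directly in Python; a minimal training sketch assuming anomalib's usual `Engine` workflow (the `MVTecAD` datamodule and category here are illustrative choices, not part of this commit):

```python
from anomalib.data import MVTecAD
from anomalib.engine import Engine
from anomalib.models import AnomalyDINO

# Mirrors examples/configs/model/anomaly_dino.yaml: 1-NN patch matching on
# DINOv2 ViT-S/14 features, with masking and coreset subsampling disabled.
model = AnomalyDINO(
    num_neighbours=1,
    encoder_name="dinov2_vit_small_14",
    masking=False,
    coreset_subsampling=False,
)

datamodule = MVTecAD(category="bottle")  # illustrative dataset choice
engine = Engine()
engine.fit(model=model, datamodule=datamodule)
```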

src/anomalib/models/__init__.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -58,6 +58,7 @@
 from anomalib.utils.path import convert_snake_to_pascal_case, convert_to_snake_case, convert_to_title_case
 
 from .image import (
+    AnomalyDINO,
     Cfa,
     Cflow,
     Csflow,
@@ -96,6 +97,8 @@ class UnknownModelError(ModuleNotFoundError):
 
 
 __all__ = [
+    "AiVad",
+    "AnomalyDINO",
     "Cfa",
     "Cflow",
     "Csflow",
@@ -107,6 +110,7 @@ class UnknownModelError(ModuleNotFoundError):
     "EfficientAd",
     "Fastflow",
     "Fre",
+    "Fuvas",
     "Ganomaly",
     "Padim",
     "Patchcore",
@@ -117,8 +121,6 @@ class UnknownModelError(ModuleNotFoundError):
     "UniNet",
     "VlmAd",
     "WinClip",
-    "AiVad",
-    "Fuvas",
 ]
 
 logger = logging.getLogger(__name__)
```
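Since `__all__` and the snake/pascal conversion helpers drive anomalib's name-based model lookup, the new entry should also resolve by name; a small sketch, assuming the `get_model` helper that `anomalib.models` exposes alongside this registry:

```python
from anomalib.models import AnomalyDINO, get_model

# "anomaly_dino" is converted to "AnomalyDINO" via the snake-to-pascal helper;
# unknown names raise the UnknownModelError defined in this module.
model = get_model("anomaly_dino")
assert isinstance(model, AnomalyDINO)
```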
src/anomalib/models/components/dinov2/__init__.py

Lines changed: 35 additions & 0 deletions

```python
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


"""Anomalib's Vision Transformer implementation.

References:
    https://github.com/facebookresearch/dinov2/blob/main/dinov2/

Classes:
    DinoVisionTransformer: DINOv2 implementation.
    DinoV2Loader: Loader class to support downloading and loading weights.
"""

# vision transformer
# loader
from .dinov2_loader import DinoV2Loader
from .vision_transformer import (
    DinoVisionTransformer,
    vit_base,
    vit_giant2,
    vit_large,
    vit_small,
)

__all__ = [
    # vision transformer
    "DinoVisionTransformer",
    "vit_base",
    "vit_giant2",
    "vit_large",
    "vit_small",
    # loader
    "DinoV2Loader",
]
```
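For orientation, the factory functions exported above can also be called directly; a short sketch in which the kwargs simply mirror those that `DinoV2Loader.create_model` passes in the loader file below:

```python
from anomalib.models.components.dinov2 import vit_small

# Build a ViT-S/14 backbone from the factory function; kwargs mirror
# DinoV2Loader.create_model (see dinov2_loader.py below).
model = vit_small(
    patch_size=14,
    img_size=518,
    block_chunks=0,
    init_values=1e-8,
    interpolate_antialias=False,
    interpolate_offset=0.1,
)
print(sum(p.numel() for p in model.parameters()))  # roughly 22M parameters for ViT-S
```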
src/anomalib/models/components/dinov2/dinov2_loader.py

Lines changed: 243 additions & 0 deletions

```python
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Loading pre-trained DINOv2 Vision Transformer models.

This module provides the :class:`DinoV2Loader` class for constructing and loading
pre-trained DINOv2 Vision Transformer models used in the Dinomaly anomaly detection
framework. It supports both standard DINOv2 models and register-token variants, and
allows custom Vision Transformer factories to be supplied.

Example:
    >>> from anomalib.models.components.dinov2 import DinoV2Loader
    >>> loader = DinoV2Loader()
    >>> model = loader.load("dinov2_vit_base_14")
    >>> model = loader.load("vit_base_14")
    >>> custom_loader = DinoV2Loader(vit_factory=my_custom_vit_module)
    >>> model = custom_loader.load("dinov2reg_vit_base_14")

The DINOv2 loader handles:

- Parsing model names and validating architecture types
- Constructing the appropriate Vision Transformer model
- Locating or downloading the corresponding pre-trained weights
- Supporting custom ViT implementations via a pluggable factory

This enables a simple, unified interface for accessing DINOv2-based backbones in
downstream anomaly detection tasks.
"""

import logging
from pathlib import Path
from typing import ClassVar
from urllib.request import urlretrieve

import torch

from anomalib.data.utils import DownloadInfo
from anomalib.data.utils.download import DownloadProgressBar
from anomalib.models.components.dinov2 import vision_transformer as dinov2_models

logger = logging.getLogger(__name__)

MODEL_FACTORIES: dict[str, object] = {
    "dinov2": dinov2_models,
    "dinov2_reg": dinov2_models,
}


class DinoV2Loader:
    """Simple loader for DINOv2 Vision Transformer models.

    Supports loading dinov2, dinov2_reg, and dinomaly model variants across small, base,
    and large architectures.
    """

    DINOV2_BASE_URL: ClassVar[str] = "https://dl.fbaipublicfiles.com/dinov2"

    MODEL_CONFIGS: ClassVar[dict[str, dict[str, int]]] = {
        "small": {"embed_dim": 384, "num_heads": 6},
        "base": {"embed_dim": 768, "num_heads": 12},
        "large": {"embed_dim": 1024, "num_heads": 16},
    }

    def __init__(
        self,
        cache_dir: str | Path = "./pre_trained/",
        vit_factory: object | None = None,
    ) -> None:
        self.cache_dir = Path(cache_dir)
        self.vit_factory = vit_factory
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def load(self, model_name: str) -> torch.nn.Module:
        """Load a DINOv2 model by name.

        Args:
            model_name: Model identifier such as "dinov2_vit_base_14".

        Returns:
            A fully constructed and weight-loaded PyTorch module.

        Raises:
            ValueError: If the requested model name is malformed or unsupported.
        """
        model_type, architecture, patch_size = self._parse_name(model_name)
        model = self.create_model(model_type, architecture, patch_size)
        self._load_weights(model, model_type, architecture, patch_size)

        logger.info(f"Loaded model: {model_name}")
        return model

    @classmethod
    def from_name(
        cls,
        model_name: str,
        cache_dir: str | Path = "./pre_trained/",
    ) -> torch.nn.Module:
        """Instantiate a loader and return the requested model."""
        loader = cls(cache_dir)
        return loader.load(model_name)

    def _parse_name(self, name: str) -> tuple[str, str, int]:
        """Parse a model name string into components.

        Args:
            name: Full model name string.

        Returns:
            Tuple of (model_type, architecture_name, patch_size).

        Raises:
            ValueError: If the prefix or architecture is unknown.
        """
        parts = name.split("_")
        prefix = parts[0]
        architecture = parts[-2]
        patch_size = int(parts[-1])

        if prefix == "dinov2reg":
            model_type = "dinov2_reg"
        elif prefix == "dinov2":
            model_type = "dinov2"
        elif prefix == "dinomaly":
            model_type = "dinomaly"
        else:
            msg = f"Unknown model type prefix '{prefix}'."
            raise ValueError(msg)

        if architecture not in self.MODEL_CONFIGS:
            msg = f"Invalid architecture '{architecture}'. Expected one of: {list(self.MODEL_CONFIGS)}"
            raise ValueError(
                msg,
            )

        return model_type, architecture, patch_size

    def create_model(self, model_type: str, architecture: str, patch_size: int) -> torch.nn.Module:
        """Create a Vision Transformer model.

        Args:
            model_type: Normalized model family name (e.g., "dinov2", "dinov2_reg").
            architecture: Architecture size (e.g., "small", "base", "large").
            patch_size: ViT patch size.

        Returns:
            Instantiated Vision Transformer model.

        Raises:
            ValueError: If no matching constructor exists.
        """
        model_kwargs = {
            "patch_size": patch_size,
            "img_size": 518,
            "block_chunks": 0,
            "init_values": 1e-8,
            "interpolate_antialias": False,
            "interpolate_offset": 0.1,
        }

        if model_type == "dinov2_reg":
            model_kwargs["num_register_tokens"] = 4

        # If user supplied a custom ViT module, use it
        module = self.vit_factory if self.vit_factory is not None else MODEL_FACTORIES[model_type]

        ctor = getattr(module, f"vit_{architecture}", None)
        if ctor is None:
            msg = f"No constructor vit_{architecture} in module {module}"
            raise ValueError(msg)

        return ctor(**model_kwargs)

    def _load_weights(
        self,
        model: torch.nn.Module,
        model_type: str,
        architecture: str,
        patch_size: int,
    ) -> None:
        """Load pre-trained weights from disk, downloading them if necessary."""
        weight_path = self._get_weight_path(model_type, architecture, patch_size)

        if not weight_path.exists():
            self._download_weights(model_type, architecture, patch_size)

        # Using weights_only=True for safety mitigation (see Anomalib PR #2729)
        state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)  # nosec B614
        model.load_state_dict(state_dict, strict=False)

    def _get_weight_path(
        self,
        model_type: str,
        architecture: str,
        patch_size: int,
    ) -> Path:
        """Return the expected local path for downloaded weights."""
        arch_code = architecture[0]

        if model_type == "dinov2_reg":
            filename = f"dinov2_vit{arch_code}{patch_size}_reg4_pretrain.pth"
        else:
            filename = f"dinov2_vit{arch_code}{patch_size}_pretrain.pth"

        return self.cache_dir / filename

    def _download_weights(
        self,
        model_type: str,
        architecture: str,
        patch_size: int,
    ) -> None:
        """Download DINOv2 weight files using Anomalib's standardized utilities."""
        weight_path = self._get_weight_path(model_type, architecture, patch_size)
        arch_code = architecture[0]

        model_dir = f"dinov2_vit{arch_code}{patch_size}"
        url = f"{self.DINOV2_BASE_URL}/{model_dir}/{weight_path.name}"

        download_info = DownloadInfo(
            name=f"DINOv2 {model_type} {architecture} weights",
            url=url,
            hashsum="",  # DINOv2 publishes no official hash
            filename=weight_path.name,
        )

        logger.info(
            f"Downloading DINOv2 weights: {weight_path.name} to {self.cache_dir}",
        )

        self.cache_dir.mkdir(parents=True, exist_ok=True)

        with DownloadProgressBar(
            unit="B",
            unit_scale=True,
            miniters=1,
            desc=download_info.name,
        ) as progress_bar:
            # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected  # noqa: ERA001, E501
            urlretrieve(  # noqa: S310  # nosec B310
                url=url,
                filename=weight_path,
                reporthook=progress_bar.update_to,
            )
```