2 changes: 2 additions & 0 deletions mindone/transformers/__init__.py
@@ -445,6 +445,8 @@
Dinov2WithRegistersModel,
Dinov2WithRegistersPreTrainedModel,
)
from .models.dinov3_convnext import DINOv3ConvNextModel, DINOv3ConvNextPreTrainedModel
from .models.dinov3_vit import DINOv3ViTModel, DINOv3ViTPreTrainedModel, DINOv3ViTImageProcessorFast
from .models.distilbert import (
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
2 changes: 2 additions & 0 deletions mindone/transformers/models/__init__.py
@@ -64,6 +64,8 @@
diffllama,
dinov2,
dinov2_with_registers,
dinov3_convnext,
dinov3_vit,
distilbert,
dpr,
dpt,
4 changes: 4 additions & 0 deletions mindone/transformers/models/auto/configuration_auto.py
@@ -86,6 +86,8 @@
("diffllama", "DiffLlamaConfig"),
("dinov2", "Dinov2Config"),
("dinov2_with_registers", "Dinov2WithRegistersConfig"),
("dinov3_convnext", "DINOv3ConvNextConfig"),
("dinov3_vit", "DINOv3ViTConfig"),
("deit", "DeiTConfig"),
("distilbert", "DistilBertConfig"),
("dpr", "DPRConfig"),
@@ -355,6 +357,8 @@
("diffllama", "DiffLlama"),
("dinov2", "DINOv2"),
("dinov2_with_registers", "DINOv2 with Registers"),
("dinov3_convnext", "DINOv3 ConvNext"),
("dinov3_vit", "DINOv3 ViT"),
("distilbert", "DistilBERT"),
("dpr", "DPR"),
("dpt", "DPT"),
mindone/transformers/models/auto/image_processing_auto.py
@@ -61,6 +61,7 @@
("depth_anything", ("DPTImageProcessor",)),
("depth_pro", ("DepthProImageProcessor",)),
("dinov2", ("BitImageProcessor",)),
("dinov3_vit", ("DINOv3ViTImageProcessorFast",)),
("dpt", ("DPTImageProcessor",)),
("efficientnet", ("EfficientNetImageProcessor",)),
("flava", ("FlavaImageProcessor",)),
4 changes: 4 additions & 0 deletions mindone/transformers/models/auto/modeling_auto.py
@@ -84,6 +84,8 @@
("diffllama", "DiffLlamaModel"),
("dinov2", "Dinov2Model"),
("dinov2_with_registers", "Dinov2WithRegistersModel"),
("dinov3_convnext", "DINOv3ConvNextModel"),
("dinov3_vit", "DINOv3ViTModel"),
("distilbert", "DistilBertModel"),
("dpr", "DPRQuestionEncoder"),
("dpt", "DPTModel"),
@@ -514,6 +516,8 @@
("depth_pro", "DepthProModel"),
("dinov2", "Dinov2Model"),
("dinov2_with_registers", "Dinov2WithRegistersModel"),
("dinov3_convnext", "DINOv3ConvNextModel"),
("dinov3_vit", "DINOv3ViTModel"),
("dpt", "DPTModel"),
("efficientnet", "EfficientNetModel"),
("focalnet", "FocalNetModel"),
17 changes: 17 additions & 0 deletions mindone/transformers/models/dinov3_convnext/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling_dinov3_convnext import *
medium

Wildcard imports (from ... import *) are discouraged by PEP 8 as they make it unclear which names are present in the namespace. It's better to explicitly import the required names. Based on __all__ in modeling_dinov3_convnext.py, you should import DINOv3ConvNextModel and DINOv3ConvNextPreTrainedModel.

Suggested change
from .modeling_dinov3_convnext import *
from .modeling_dinov3_convnext import DINOv3ConvNextModel, DINOv3ConvNextPreTrainedModel

mindone/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py
@@ -0,0 +1,254 @@
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MindSpore ConvNext model."""

from typing import Optional

import numpy as np
import mindspore as ms
from mindspore import mint, nn

from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention,
)
from ...modeling_utils import PreTrainedModel
from transformers.models.dinov3_convnext.configuration_dinov3_convnext import DINOv3ConvNextConfig


# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: ms.Tensor, drop_prob: float = 0.0, training: bool = False) -> ms.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + mint.rand(shape, dtype=input.dtype)
random_tensor.floor_() # binarize
output = input.div(keep_prob) * random_tensor
return output
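A quick behavioural sketch of stochastic depth (hypothetical shapes, reusing the mint API above): each sample in the batch either keeps its residual branch, rescaled by 1/keep_prob, or has it zeroed entirely, so the expected activation matches eval-time behaviour.

x = mint.ones((4, 8, 14, 14))                            # (batch, channels, height, width)
out = drop_path(x, drop_prob=0.25, training=True)        # per sample: all zeros or all 1/0.75
out_eval = drop_path(x, drop_prob=0.25, training=False)  # identity when not training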


# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->DINOv3ConvNext
class DINOv3ConvNextDropPath(nn.Cell):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob

def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)

def extra_repr(self) -> str:
return f"p={self.drop_prob}"


class DINOv3ConvNextLayerNorm(mint.nn.LayerNorm):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
"""

def __init__(self, *args, data_format="channels_last", **kwargs):
super().__init__(*args, **kwargs)
if data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError(f"Unsupported data format: {data_format}")
self.data_format = data_format

def construct(self, features: ms.Tensor) -> ms.Tensor:
"""
Args:
features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
"""
if self.data_format == "channels_first":
features = features.permute(0, 2, 3, 1)
features = super().construct(features)
features = features.permute(0, 3, 1, 2)
else:
features = super().construct(features)
return features
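A minimal shape sketch (hypothetical sizes) of the channels_first path: the input is permuted to channels_last, normalized over the channel axis by the parent LayerNorm, and permuted back, so the output keeps the (batch_size, channels, height, width) layout.

norm = DINOv3ConvNextLayerNorm(96, eps=1e-6, data_format="channels_first")
features = mint.ones((2, 96, 56, 56))  # (batch_size, channels, height, width)
normalized = norm(features)            # normalized over the 96 channels, same shape as the input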


class DINOv3ConvNextLayer(nn.Cell):
"""This corresponds to the `Block` class in the original implementation.

There are two equivalent implementations:
1) DwConv, LayerNorm (channels_first), Conv, GELU, Conv (all in (N, C, H, W) format)
2) DwConv, Permute, LayerNorm (channels_last), Linear, GELU, Linear, Permute

The authors used (2) as they find it slightly faster in PyTorch.

Args:
config ([`DINOv3ConvNextConfig`]):
Model config.
channels (`int`):
Number of input (and output) channels.
drop_path (`float`):
Drop path rate. Default: 0.0.
"""

def __init__(self, config: DINOv3ConvNextConfig, channels: int, drop_path: float = 0.0):
super().__init__()
self.depthwise_conv = mint.nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels)
self.layer_norm = DINOv3ConvNextLayerNorm(channels, eps=config.layer_norm_eps)
self.pointwise_conv1 = mint.nn.Linear(channels, 4 * channels) # can be seen as a 1x1 conv
self.activation_fn = ACT2FN[config.hidden_act]
self.pointwise_conv2 = mint.nn.Linear(4 * channels, channels) # can be seen as a 1x1 conv
self.gamma = ms.Parameter(mint.full((channels,), config.layer_scale_init_value), requires_grad=True)
self.drop_path = DINOv3ConvNextDropPath(drop_path) if drop_path > 0.0 else mint.nn.Identity()

def construct(self, features: ms.Tensor) -> ms.Tensor:
"""
Args:
features: Tensor of shape (batch_size, channels, height, width)
"""
residual = features
features = self.depthwise_conv(features)
features = features.permute(0, 2, 3, 1) # to channels last
features = self.layer_norm(features)
features = self.pointwise_conv1(features)
features = self.activation_fn(features)
features = self.pointwise_conv2(features)
features = features * self.gamma
features = features.permute(0, 3, 1, 2) # back to channels first
features = residual + self.drop_path(features)
return features
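The "can be seen as a 1x1 conv" comments rest on a simple equivalence: a mint.nn.Linear applied to a channels_last tensor mixes channels pointwise at every spatial position, which is exactly what a kernel_size=1 Conv2d does in channels_first layout. A minimal sketch (hypothetical sizes):

features = mint.ones((2, 56, 56, 96))  # channels_last, as inside this block
pointwise = mint.nn.Linear(96, 4 * 96)
expanded = pointwise(features)         # (2, 56, 56, 384): per-position channel mixing, i.e. a 1x1 conv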


class DINOv3ConvNextStage(nn.Cell):
""" """

def __init__(self, config: DINOv3ConvNextConfig, stage_idx: int):
super().__init__()

in_channels = config.hidden_sizes[stage_idx - 1] if stage_idx > 0 else config.num_channels
out_channels = config.hidden_sizes[stage_idx]

if stage_idx == 0:
self.downsample_layers = nn.CellList(
[
mint.nn.Conv2d(config.num_channels, out_channels, kernel_size=4, stride=4),
DINOv3ConvNextLayerNorm(out_channels, eps=config.layer_norm_eps, data_format="channels_first"),
]
)
else:
self.downsample_layers = nn.CellList(
[
DINOv3ConvNextLayerNorm(in_channels, eps=config.layer_norm_eps, data_format="channels_first"),
mint.nn.Conv2d(in_channels, out_channels, kernel_size=2, stride=2),
]
)

num_stage_layers = config.depths[stage_idx]
num_previous_layers = sum(config.depths[:stage_idx])
num_total_layers = sum(config.depths)
drop_path_rates = np.linspace(0, config.drop_path_rate, num_total_layers).tolist()

self.layers = nn.CellList(
[
DINOv3ConvNextLayer(config, channels=out_channels, drop_path=drop_path_rates[i])
for i in range(num_previous_layers, num_previous_layers + num_stage_layers)
]
)

def construct(self, features: ms.Tensor) -> ms.Tensor:
"""
Args:
features: Tensor of shape (batch_size, channels, height, width)
"""
for layer in self.downsample_layers:
features = layer(features)
for layer in self.layers:
features = layer(features)
return features


class DINOv3ConvNextPreTrainedModel(PreTrainedModel):
config: DINOv3ConvNextConfig
base_model_prefix = "dinov3_convnext"
main_input_name = "pixel_values"
_no_split_modules = ["DINOv3ConvNextLayer"]

def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (mint.nn.Linear, mint.nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (mint.nn.LayerNorm, DINOv3ConvNextLayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, DINOv3ConvNextLayer):
if module.gamma is not None:
module.gamma.data.fill_(self.config.layer_scale_init_value)
Comment on lines +196 to +209
critical

The weight initialization method _init_weights uses PyTorch-style in-place modification on .data, which is not supported for mindspore.Parameter. You should use helper functions like normal_, zeros_, and constant_ from mindone.models.utils to initialize the parameters correctly. Please also add from mindone.models.utils import constant_, normal_, zeros_ to the imports at the top of the file.

Suggested change
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (mint.nn.Linear, mint.nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (mint.nn.LayerNorm, DINOv3ConvNextLayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, DINOv3ConvNextLayer):
if module.gamma is not None:
module.gamma.data.fill_(self.config.layer_scale_init_value)
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (mint.nn.Linear, mint.nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
zeros_(module.bias)
elif isinstance(module, (mint.nn.LayerNorm, DINOv3ConvNextLayerNorm)):
zeros_(module.bias)
constant_(module.weight, 1.0)
elif isinstance(module, DINOv3ConvNextLayer):
if module.gamma is not None:
constant_(module.gamma, self.config.layer_scale_init_value)
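For reference, a minimal usage sketch of the suggested helpers (assuming they are importable from mindone.models.utils as the comment states); they rewrite a Parameter's value in place, which is the MindSpore-friendly replacement for mutating .data:

import mindspore as ms
from mindspore import mint
from mindone.models.utils import constant_, normal_, zeros_

linear = mint.nn.Linear(16, 32)
normal_(linear.weight, mean=0.0, std=0.02)  # draw weights from N(0, 0.02^2)
zeros_(linear.bias)                         # zero the bias
gamma = ms.Parameter(mint.zeros(16))
constant_(gamma, 1e-6)                      # fill a layer-scale parameter with a constant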



class DINOv3ConvNextModel(DINOv3ConvNextPreTrainedModel):
def __init__(self, config: DINOv3ConvNextConfig):
super().__init__(config)
self.config = config
self.stages = nn.CellList([DINOv3ConvNextStage(config, stage_idx) for stage_idx in range(config.num_stages)])
self.layer_norm = mint.nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps) # final norm layer
self.pool = mint.nn.AdaptiveAvgPool2d(1)
self.post_init()

def construct(
self, pixel_values: ms.Tensor, output_hidden_states: Optional[bool] = None
) -> BaseModelOutputWithPoolingAndNoAttention:
hidden_states = pixel_values

output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
all_hidden_states = [hidden_states] if output_hidden_states else []

for stage in self.stages:
hidden_states = stage(hidden_states)

# store intermediate stage outputs
if output_hidden_states:
all_hidden_states.append(hidden_states)

# make global representation, a.k.a [CLS] token
pooled_output = self.pool(hidden_states)

# (batch_size, channels, height, width) -> (batch_size, height * width, channels)
pooled_output = pooled_output.flatten(2).transpose(1, 2)
hidden_states = hidden_states.flatten(2).transpose(1, 2)

# concat "cls" and "patch tokens" as (batch_size, 1 + height * width, channels)
hidden_states = mint.cat([pooled_output, hidden_states], dim=1)
hidden_states = self.layer_norm(hidden_states)

return BaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=hidden_states,
pooler_output=hidden_states[:, 0],
hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
)


__all__ = ["DINOv3ConvNextModel", "DINOv3ConvNextPreTrainedModel"]
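Putting the pieces together, a rough usage sketch of the new backbone (instantiated here from a default DINOv3ConvNextConfig; this assumes a transformers version that ships the config class, and a real checkpoint would be loaded with from_pretrained instead):

from mindspore import mint
from transformers import DINOv3ConvNextConfig
from mindone.transformers import DINOv3ConvNextModel

model = DINOv3ConvNextModel(DINOv3ConvNextConfig())  # randomly initialized, for illustration only
model.set_train(False)
pixel_values = mint.rand((1, 3, 224, 224))
outputs = model(pixel_values)
cls_token = outputs.pooler_output                    # (1, hidden_sizes[-1]) pooled global token
patch_tokens = outputs.last_hidden_state[:, 1:]      # (1, num_patches, hidden_sizes[-1]) patch tokens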
18 changes: 18 additions & 0 deletions mindone/transformers/models/dinov3_vit/__init__.py
@@ -0,0 +1,18 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/transformers
# with modifications to run transformers on mindspore.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .image_processing_dinov3_vit_fast import DINOv3ViTImageProcessorFast
from .modeling_dinov3_vit import *
medium

Wildcard imports (from ... import *) are discouraged by PEP 8 as they make it unclear which names are present in the namespace. It's better to explicitly import the required names. Based on __all__ in modeling_dinov3_vit.py, you should import DINOv3ViTModel and DINOv3ViTPreTrainedModel.

Suggested change
from .modeling_dinov3_vit import *
from .modeling_dinov3_vit import DINOv3ViTModel, DINOv3ViTPreTrainedModel
