Commit fe6cdd2

Update huggingface model
Update huggingface model Update README.md Update README.md Update README.md Update huggingface model Update huggingface model
1 parent 3bd2e7b commit fe6cdd2

82 files changed (+20063, -7 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@ segmentation/convertor/
 checkpoint_dir/
 demo/
 pretrained/
+upload.py
README.md

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
---
license: mit
pipeline_tag: image-classification
library_name: transformers
tags:
- internimage
- custom_code
datasets:
- ILSVRC/imagenet-1k
---

# InternImage Model Card

## Introduction

InternImage is an advanced vision foundation model developed by researchers from Shanghai AI Laboratory, Tsinghua University, and other institutions. Unlike Transformer-based models, InternImage uses DCNv3 as its core operator. This gives the model the dynamic and effective receptive fields needed for downstream tasks such as object detection and segmentation, while enabling adaptive spatial aggregation.

<div style="text-align: center;"> <img src="https://github.com/OpenGVLab/InternImage/raw/master/docs/figs/arch.png" style="width:60%;" /> </div>

## Performance

- InternImage achieved a Top-1 accuracy of 90.1% on the ImageNet benchmark using only publicly available training data. Aside from two undisclosed models trained on additional data by Google and Microsoft, it is the only open-source model to exceed 90.0% Top-1 accuracy, and it is also the largest such model in scale.
- InternImage achieved 65.5 mAP on the COCO object detection benchmark, making it the only model to surpass 65 mAP.
- InternImage also achieved the best reported results on 16 other important visual benchmarks covering classification, detection, and segmentation, making it a top-performing model across multiple domains.

## Released Models

### Open-Source Visual Pretrained Models

| huggingface name | model name | pretrain | resolution | #param |
| :--------------: | :--------: | :------: | :--------: | :----: |
| [internimage_l_22k_384](https://huggingface.co/OpenGVLab/internimage_l_22k_384) | InternImage-L | IN-22K | 384x384 | 223M |
| [internimage_xl_22k_384](https://huggingface.co/OpenGVLab/internimage_xl_22k_384) | InternImage-XL | IN-22K | 384x384 | 335M |
| [internimage_h_jointto22k_384](https://huggingface.co/OpenGVLab/internimage_h_jointto22k_384) | InternImage-H | Joint 427M -> IN-22K | 384x384 | 1.08B |
| [internimage_g_jointto22k_384](https://huggingface.co/OpenGVLab/internimage_g_jointto22k_384) | InternImage-G | Joint 427M -> IN-22K | 384x384 | 3B |

### ImageNet-1K Image Classification

| huggingface name | model name | pretrain | resolution | acc@1 | #param | FLOPs |
| :--------------: | :--------: | :------: | :--------: | :---: | :----: | :---: |
| [internimage_t_1k_224](https://huggingface.co/OpenGVLab/internimage_t_1k_224) | InternImage-T | IN-1K | 224x224 | 83.5 | 30M | 5G |
| [internimage_s_1k_224](https://huggingface.co/OpenGVLab/internimage_s_1k_224) | InternImage-S | IN-1K | 224x224 | 84.2 | 50M | 8G |
| [internimage_b_1k_224](https://huggingface.co/OpenGVLab/internimage_b_1k_224) | InternImage-B | IN-1K | 224x224 | 84.9 | 97M | 16G |
| [internimage_l_22kto1k_384](https://huggingface.co/OpenGVLab/internimage_l_22kto1k_384) | InternImage-L | IN-22K | 384x384 | 87.7 | 223M | 108G |
| [internimage_xl_22kto1k_384](https://huggingface.co/OpenGVLab/internimage_xl_22kto1k_384) | InternImage-XL | IN-22K | 384x384 | 88.0 | 335M | 163G |
| [internimage_h_22kto1k_640](https://huggingface.co/OpenGVLab/internimage_h_22kto1k_640) | InternImage-H | Joint 427M -> IN-22K | 640x640 | 89.6 | 1.08B | 1478G |
| [internimage_g_22kto1k_512](https://huggingface.co/OpenGVLab/internimage_g_22kto1k_512) | InternImage-G | Joint 427M -> IN-22K | 512x512 | 90.1 | 3B | 2700G |

## DCNv3 CUDA Kernel Installation

If you do not install the CUDA version of DCNv3, InternImage will automatically fall back to a PyTorch implementation. However, the CUDA implementation can significantly reduce GPU memory usage and improve inference efficiency.

**Installation Tutorial:**

1. Open your terminal and run:

```bash
git clone https://github.com/OpenGVLab/InternImage.git
cd InternImage/classification/ops_dcnv3
```

2. Make sure you have an available GPU for compilation, then run:

```bash
sh make.sh
```

This will compile the CUDA version of DCNv3. Once installed, InternImage will automatically leverage the optimized CUDA implementation for better performance.
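
To check which implementation will be used at runtime, one simple option is to try importing the compiled extension. This is only a sketch: it assumes the `sh make.sh` step above installs the extension as a Python module named `DCNv3`; treat the module name as an assumption and adjust it if your build differs.

```python
# Quick availability check for the compiled DCNv3 CUDA extension.
# Assumption: the ops_dcnv3 build installs a module named `DCNv3`.
try:
    import DCNv3  # noqa: F401
    print("Compiled DCNv3 CUDA kernels found; InternImage will use them.")
except ImportError:
    print("No compiled DCNv3 found; InternImage falls back to the PyTorch implementation.")
```
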
## Usage with Transformers

Below are two usage examples for InternImage with the Transformers framework:

### Example 1: Using InternImage as an Image Backbone

```python
import torch
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor

# Replace 'model_name' with the appropriate model identifier
model_name = "OpenGVLab/internimage_t_1k_224"  # example model

# Prepare the image
image_path = 'img.png'
image_processor = CLIPImageProcessor.from_pretrained(model_name)
image = Image.open(image_path)
image = image_processor(images=image, return_tensors='pt').pixel_values
print('image shape:', image.shape)

# Load the model as a backbone
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
# 'hidden_states' contains the outputs from the 4 stages of the InternImage backbone
hidden_states = model(image).hidden_states
```
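
As a quick sanity check, you can inspect the feature maps returned by the backbone. This sketch only assumes that `hidden_states` holds one tensor per stage, as the comment in the example above indicates:

```python
# Print the shape of each stage's output feature map
for i, feature in enumerate(hidden_states):
    print(f"stage {i}: {tuple(feature.shape)}")
```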

### Example 2: Using InternImage for Image Classification

```python
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, CLIPImageProcessor

# Replace 'model_name' with the appropriate model identifier
model_name = "OpenGVLab/internimage_t_1k_224"  # example model

# Prepare the image
image_path = 'img.png'
image_processor = CLIPImageProcessor.from_pretrained(model_name)
image = Image.open(image_path)
image = image_processor(images=image, return_tensors='pt').pixel_values
print('image shape:', image.shape)

# Load the model as an image classifier
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)
logits = model(image).logits
label = torch.argmax(logits, dim=1)
print("Predicted label:", label.item())
```
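
If the checkpoint's config provides an `id2label` mapping (a standard but optional field on Hugging Face configs; not every checkpoint fills it in), the predicted index can be turned into a human-readable class name:

```python
# Map the predicted index to a class name when the config provides one
id2label = getattr(model.config, "id2label", None) or {}
predicted = id2label.get(label.item(), id2label.get(str(label.item()), "<unknown>"))
print("Predicted class:", predicted)
```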

## Citation

If this work is helpful for your research, please consider citing the following BibTeX entry.

```bibtex
@inproceedings{wang2023internimage,
  title={Internimage: Exploring large-scale vision foundation models with deformable convolutions},
  author={Wang, Wenhai and Dai, Jifeng and Chen, Zhe and Huang, Zhenhang and Li, Zhiqi and Zhu, Xizhou and Hu, Xiaowei and Lu, Tong and Lu, Lewei and Li, Hongsheng and others},
  booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
  pages={14408--14419},
  year={2023}
}
```
config.json

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
{
  "_name_or_path": "OpenGVLab/internimage_g_jointto22k_384",
  "act_layer": "GELU",
  "architectures": [
    "InternImageModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_internimage.InternImageConfig",
    "AutoModel": "modeling_internimage.InternImageModel",
    "AutoModelForImageClassification": "modeling_internimage.InternImageModelForImageClassification"
  },
  "center_feature_scale": true,
  "channels": 512,
  "cls_scale": 1.5,
  "core_op": "DCNv3",
  "depths": [
    2,
    2,
    48,
    4
  ],
  "drop_path_rate": 0.0,
  "drop_path_type": "linear",
  "drop_rate": 0.0,
  "dw_kernel_size": 5,
  "groups": [
    16,
    32,
    64,
    128
  ],
  "layer_scale": null,
  "level2_post_norm": true,
  "level2_post_norm_block_ids": [
    5,
    11,
    17,
    23,
    29,
    35,
    41,
    47
  ],
  "mlp_ratio": 4.0,
  "model_type": "internimage",
  "norm_layer": "LN",
  "num_classes": 21841,
  "offset_scale": 1.0,
  "post_norm": true,
  "remove_center": false,
  "res_post_norm": false,
  "torch_dtype": "float32",
  "transformers_version": "4.37.2",
  "use_clip_projector": true,
  "with_cp": false
}
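
The `auto_map` entries above route `AutoConfig` and `AutoModel` to the custom classes shipped with the checkpoint, so this file loads like any other Transformers config as long as `trust_remote_code=True` is passed. A minimal sketch (the repo id comes from `_name_or_path` above, and the printed values simply mirror the JSON):

```python
from transformers import AutoConfig

# trust_remote_code is needed because the config class lives in the model repo itself
config = AutoConfig.from_pretrained(
    "OpenGVLab/internimage_g_jointto22k_384", trust_remote_code=True
)
print(config.core_op)   # "DCNv3"
print(config.depths)    # [2, 2, 48, 4]
print(config.channels)  # 512
```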
configuration_internimage.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from transformers import PretrainedConfig


class InternImageConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`~InternImageModel`].
    It is used to instantiate an InternImage model according to the specified arguments, defining the model
    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
    the internimage [OpenGVLab/internimage](https://huggingface.co/OpenGVLab/internimage) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used
    to control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Args:
        core_op (`str`, *optional*, defaults to `"DCNv3"`):
            Core operator used in the InternImageModel.
        depths (`tuple`, *optional*, defaults to `(4, 4, 18, 4)`):
            Number of blocks in each stage of the InternImageModel.
        groups (`tuple`, *optional*, defaults to `(4, 8, 16, 32)`):
            Number of DCNv3 groups in each stage of the InternImageModel.
        channels (`int`, *optional*, defaults to `64`):
            Base number of channels in the InternImageModel.
        dw_kernel_size (`int`, *optional*, defaults to `None`):
            Kernel size for depthwise convolutions.
        layer_scale (`float`, *optional*, defaults to `None`):
            Initial value for layer scale.
        offset_scale (`float`, *optional*, defaults to `1.0`):
            Offset scale of the DCNv3 operator.
        mlp_ratio (`float`, *optional*, defaults to `4.0`):
            Ratio of the MLP hidden dimension to the embedding dimension.
        post_norm (`bool`, *optional*, defaults to `False`):
            Whether to use post normalization in the model.
        res_post_norm (`bool`, *optional*, defaults to `False`):
            Whether to use residual post normalization.
        level2_post_norm (`bool`, *optional*, defaults to `False`):
            Whether to use level-2 post normalization.
        level2_post_norm_block_ids (`list`, *optional*, defaults to `None`):
            Block IDs to which level-2 post normalization is applied.
        center_feature_scale (`bool`, *optional*, defaults to `False`):
            Whether to apply center feature scaling.
        use_clip_projector (`bool`, *optional*, defaults to `False`):
            Whether to use the CLIP projector head.
        remove_center (`bool`, *optional*, defaults to `False`):
            Whether to remove the center sampling point in some operations.
        num_classes (`int`, *optional*, defaults to `1000`):
            Number of classes for the model output.
        drop_rate (`float`, *optional*, defaults to `0.0`):
            Dropout rate in the model.
        drop_path_rate (`float`, *optional*, defaults to `0.0`):
            Drop path (stochastic depth) rate in the model.
        drop_path_type (`str`, *optional*, defaults to `"linear"`):
            Type of drop path schedule used in the model.
        act_layer (`str`, *optional*, defaults to `"GELU"`):
            Activation function used in the model.
        norm_layer (`str`, *optional*, defaults to `"LN"`):
            Normalization layer used in the model.
        cls_scale (`float`, *optional*, defaults to `1.5`):
            Scale of the classification head in the model.
        with_cp (`bool`, *optional*, defaults to `False`):
            Whether to use gradient checkpointing in the model.
    """
    model_type = 'internimage'

    def __init__(
        self,
        core_op='DCNv3',
        depths=(4, 4, 18, 4),
        groups=(4, 8, 16, 32),
        channels=64,
        dw_kernel_size=None,
        layer_scale=None,
        offset_scale=1.0,
        mlp_ratio=4.0,
        post_norm=False,
        res_post_norm=False,
        level2_post_norm=False,
        level2_post_norm_block_ids=None,
        center_feature_scale=False,
        use_clip_projector=False,
        remove_center=False,
        num_classes=1000,
        drop_rate=0.0,
        drop_path_rate=0.0,
        drop_path_type='linear',
        act_layer='GELU',
        norm_layer='LN',
        cls_scale=1.5,
        with_cp=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Model configuration parameters
        self.core_op = core_op
        self.depths = depths
        self.groups = groups
        self.channels = channels
        self.dw_kernel_size = dw_kernel_size
        self.layer_scale = layer_scale
        self.offset_scale = offset_scale
        self.mlp_ratio = mlp_ratio
        self.post_norm = post_norm
        self.res_post_norm = res_post_norm
        self.level2_post_norm = level2_post_norm
        self.level2_post_norm_block_ids = level2_post_norm_block_ids
        self.center_feature_scale = center_feature_scale
        self.use_clip_projector = use_clip_projector
        self.remove_center = remove_center
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.drop_path_rate = drop_path_rate
        self.drop_path_type = drop_path_type
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.cls_scale = cls_scale
        self.with_cp = with_cp
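
For illustration, the configuration class above can be instantiated and saved directly. This is only a sketch: it assumes `configuration_internimage.py` is importable from the current directory, and the override values are made up for demonstration rather than taken from a released checkpoint.

```python
from configuration_internimage import InternImageConfig

# Defaults defined by the class above
config = InternImageConfig()
print(config.core_op, config.depths, config.channels)  # DCNv3 (4, 4, 18, 4) 64

# Override selected fields via keyword arguments, as with any PretrainedConfig
custom = InternImageConfig(depths=(5, 5, 22, 5), groups=(10, 20, 40, 80), channels=160)
custom.save_pretrained("./my_internimage_config")  # writes config.json to this directory
```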
