Skip to content

Commit 9c14654

Browse files
committed
Improve support for custom dataset label name/description through HF hub export, via pretrained_cfg
1 parent 1e0b347 commit 9c14654

File tree

4 files changed

+75
-24
lines changed

4 files changed

+75
-24
lines changed

timm/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .constants import *
55
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
66
from .dataset_factory import create_dataset
7-
from .dataset_info import DatasetInfo
7+
from .dataset_info import DatasetInfo, CustomDatasetInfo
88
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
99
from .loader import create_loader
1010
from .mixup import Mixup, FastCollateMixup

timm/data/dataset_info.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Dict, List, Union
2+
from typing import Dict, List, Optional, Union
33

44

55
class DatasetInfo(ABC):
@@ -29,4 +29,45 @@ def index_to_description(self, index: int, detailed: bool = False) -> str:
2929

3030
@abstractmethod
3131
def label_name_to_description(self, label: str, detailed: bool = False) -> str:
32-
pass
32+
pass
33+
34+
35+
class CustomDatasetInfo(DatasetInfo):
    """DatasetInfo that wraps passed values for custom datasets.

    Args:
        label_names: Label index => label name mapping. Either a list/sequence
            (indices are positions) or a dict keyed by label index.
        label_descriptions: Optional label name => label description mapping.
            When provided it must be a dict containing an entry for every label name.

    Raises:
        AssertionError: If ``label_names`` is empty, ``label_descriptions`` is not a
            dict, or a label name has no matching description.
    """

    def __init__(
            self,
            label_names: Union[List[str], Dict[int, str]],
            label_descriptions: Optional[Dict[str, str]] = None
    ):
        super().__init__()
        assert len(label_names) > 0, 'label_names must not be empty.'
        self._label_names = label_names  # label index => label name mapping
        self._label_descriptions = label_descriptions  # label name => label description mapping
        if self._label_descriptions is not None:
            # validate descriptions (an entry is required for every label name)
            assert isinstance(self._label_descriptions, dict), \
                'label_descriptions must be a dict mapping label name => description.'
            # Iterating a dict yields its keys (the label indices), so take the values
            # to validate against the actual label names.
            if isinstance(self._label_names, dict):
                names = self._label_names.values()
            else:
                names = self._label_names
            missing = [n for n in names if n not in self._label_descriptions]
            assert not missing, f'label_descriptions is missing entries for label names: {missing}'

    def num_classes(self):
        """Return the number of label names held by this info object."""
        return len(self._label_names)

    def label_names(self):
        """Return the label index => label name mapping exactly as it was passed in."""
        return self._label_names

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        """Return the label name => description mapping as passed in (``None`` if not provided).

        ``detailed`` and ``as_dict`` are accepted for signature compatibility but are
        not used by this implementation.
        """
        return self._label_descriptions

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        """Return the description for ``label``, or ``label`` itself if no descriptions are set.

        Raises:
            KeyError: If descriptions are set but contain no entry for ``label``.
        """
        if self._label_descriptions:
            return self._label_descriptions[label]
        return label  # return the label name itself if no descriptions are present

    def index_to_label_name(self, index) -> str:
        """Return the label name for a label index.

        Raises:
            AssertionError: If ``index`` is not a valid label index.
        """
        if isinstance(self._label_names, dict):
            # dict mappings may have id gaps, so membership is the only valid check
            assert index in self._label_names, f'Unknown label index: {index}'
        else:
            assert 0 <= index < len(self._label_names), f'Label index out of range: {index}'
        return self._label_names[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        """Return the description for a label index (falls back to the label name)."""
        label = self.index_to_label_name(index)
        return self.label_name_to_description(label, detailed=detailed)

timm/models/_hub.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch.hub import _get_torch_home as get_dir
1717

1818
from timm import __version__
19+
from timm.layers import ClassifierHead, NormMlpClassifierHead
1920
from timm.models._pretrained import filter_pretrained_cfg
2021

2122
try:
@@ -96,7 +97,7 @@ def has_hf_hub(necessary=False):
9697
return _has_hf_hub
9798

9899

99-
def hf_split(hf_id):
100+
def hf_split(hf_id: str):
100101
# FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
101102
rev_split = hf_id.split('@')
102103
assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
@@ -127,19 +128,26 @@ def load_model_config_from_hf(model_id: str):
127128
hf_config = {}
128129
hf_config['architecture'] = pretrained_cfg.pop('architecture')
129130
hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
130-
if 'labels' in pretrained_cfg:
131-
hf_config['label_name'] = pretrained_cfg.pop('labels')
131+
if 'labels' in pretrained_cfg: # deprecated name for 'label_names'
132+
pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
132133
hf_config['pretrained_cfg'] = pretrained_cfg
133134

134135
# NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
135136
pretrained_cfg = hf_config['pretrained_cfg']
136137
pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
137138
pretrained_cfg['source'] = 'hf-hub'
139+
140+
# model should be created with base config num_classes if it exists
138141
if 'num_classes' in hf_config:
139-
# model should be created with parent num_classes if they exist
140142
pretrained_cfg['num_classes'] = hf_config['num_classes']
141-
model_name = hf_config['architecture']
142143

144+
# label meta-data in base config overrides saved pretrained_cfg on load
145+
if 'label_names' in hf_config:
146+
pretrained_cfg['label_names'] = hf_config.pop('label_names')
147+
if 'label_descriptions' in hf_config:
148+
pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
149+
150+
model_name = hf_config['architecture']
143151
return pretrained_cfg, model_name
144152

145153

@@ -150,7 +158,7 @@ def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
150158
return state_dict
151159

152160

153-
def save_config_for_hf(model, config_path, model_config=None):
161+
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
154162
model_config = model_config or {}
155163
hf_config = {}
156164
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@@ -164,22 +172,22 @@ def save_config_for_hf(model, config_path, model_config=None):
164172

165173
if 'labels' in model_config:
166174
_logger.warning(
167-
"'labels' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
168-
"Using provided 'label' field as 'label_name'.")
169-
model_config['label_name'] = model_config.pop('labels')
175+
"'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
176+
" Renaming provided 'labels' field to 'label_names'.")
177+
model_config.setdefault('label_names', model_config.pop('labels'))
170178

171-
label_name = model_config.pop('label_name', None)
172-
if label_name:
173-
assert isinstance(label_name, (dict, list, tuple))
179+
label_names = model_config.pop('label_names', None)
180+
if label_names:
181+
assert isinstance(label_names, (dict, list, tuple))
174182
# map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
175183
# can be a dict id: name if there are id gaps, or tuple/list if no gaps.
176-
hf_config['label_name'] = model_config['label_name']
184+
hf_config['label_names'] = label_names
177185

178-
display_name = model_config.pop('display_name', None)
179-
if display_name:
180-
assert isinstance(display_name, dict)
181-
# map label_name -> user interface display name
182-
hf_config['display_name'] = model_config['display_name']
186+
label_descriptions = model_config.pop('label_descriptions', None)
187+
if label_descriptions:
188+
assert isinstance(label_descriptions, dict)
189+
# maps label names -> descriptions
190+
hf_config['label_descriptions'] = label_descriptions
183191

184192
hf_config['pretrained_cfg'] = pretrained_cfg
185193
hf_config.update(model_config)
@@ -188,7 +196,7 @@ def save_config_for_hf(model, config_path, model_config=None):
188196
json.dump(hf_config, f, indent=2)
189197

190198

191-
def save_for_hf(model, save_directory, model_config=None):
199+
def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
192200
assert has_hf_hub(True)
193201
save_directory = Path(save_directory)
194202
save_directory.mkdir(exist_ok=True, parents=True)
@@ -249,7 +257,7 @@ def push_to_hf_hub(
249257
)
250258

251259

252-
def generate_readme(model_card, model_name):
260+
def generate_readme(model_card: dict, model_name: str):
253261
readme_text = "---\n"
254262
readme_text += "tags:\n- image-classification\n- timm\n"
255263
readme_text += "library_tag: timm\n"

timm/models/_pretrained.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ class PretrainedCfg:
3434
mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
3535
std: Tuple[float, ...] = (0.229, 0.224, 0.225)
3636

37-
# head config
37+
# head / classifier config and meta-data
3838
num_classes: int = 1000
3939
label_offset: Optional[int] = None
40+
label_names: Optional[Tuple[str]] = None
41+
label_descriptions: Optional[Dict[str, str]] = None
4042

4143
# model attributes that vary with above or required for pretrained adaptation
4244
pool_size: Optional[Tuple[int, ...]] = None

0 commit comments

Comments
 (0)