1616 from torch .hub import _get_torch_home as get_dir
1717
1818from timm import __version__
19+ from timm .layers import ClassifierHead , NormMlpClassifierHead
1920from timm .models ._pretrained import filter_pretrained_cfg
2021
2122try :
@@ -96,7 +97,7 @@ def has_hf_hub(necessary=False):
9697 return _has_hf_hub
9798
9899
99- def hf_split (hf_id ):
100+ def hf_split (hf_id : str ):
100101 # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
101102 rev_split = hf_id .split ('@' )
102103 assert 0 < len (rev_split ) <= 2 , 'hf_hub id should only contain one @ character to identify revision.'
@@ -127,19 +128,26 @@ def load_model_config_from_hf(model_id: str):
127128 hf_config = {}
128129 hf_config ['architecture' ] = pretrained_cfg .pop ('architecture' )
129130 hf_config ['num_features' ] = pretrained_cfg .pop ('num_features' , None )
130- if 'labels' in pretrained_cfg :
131- hf_config [ 'label_name ' ] = pretrained_cfg .pop ('labels' )
131+ if 'labels' in pretrained_cfg : # deprecated name for 'label_names'
132+ pretrained_cfg [ 'label_names ' ] = pretrained_cfg .pop ('labels' )
132133 hf_config ['pretrained_cfg' ] = pretrained_cfg
133134
134135 # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
135136 pretrained_cfg = hf_config ['pretrained_cfg' ]
136137 pretrained_cfg ['hf_hub_id' ] = model_id # insert hf_hub id for pretrained weight load during model creation
137138 pretrained_cfg ['source' ] = 'hf-hub'
139+
 140+  # model should be created with base config num_classes if it exists
138141 if 'num_classes' in hf_config :
139- # model should be created with parent num_classes if they exist
140142 pretrained_cfg ['num_classes' ] = hf_config ['num_classes' ]
141- model_name = hf_config ['architecture' ]
142143
144+ # label meta-data in base config overrides saved pretrained_cfg on load
145+ if 'label_names' in hf_config :
146+ pretrained_cfg ['label_names' ] = hf_config .pop ('label_names' )
147+ if 'label_descriptions' in hf_config :
148+ pretrained_cfg ['label_descriptions' ] = hf_config .pop ('label_descriptions' )
149+
150+ model_name = hf_config ['architecture' ]
143151 return pretrained_cfg , model_name
144152
145153
@@ -150,7 +158,7 @@ def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
150158 return state_dict
151159
152160
153- def save_config_for_hf (model , config_path , model_config = None ):
161+ def save_config_for_hf (model , config_path : str , model_config : Optional [ dict ] = None ):
154162 model_config = model_config or {}
155163 hf_config = {}
156164 pretrained_cfg = filter_pretrained_cfg (model .pretrained_cfg , remove_source = True , remove_null = True )
@@ -164,22 +172,22 @@ def save_config_for_hf(model, config_path, model_config=None):
164172
165173 if 'labels' in model_config :
166174 _logger .warning (
167- "'labels' as a config field for timm models is deprecated. Please use 'label_name ' and 'display_name'. "
168- "Using provided 'label ' field as 'label_name '." )
169- model_config [ 'label_name' ] = model_config .pop ('labels' )
175+ "'labels' as a config field for is deprecated. Please use 'label_names ' and 'label_descriptions'. "
176+ " Renaming provided 'labels ' field to 'label_names '." )
177+ model_config . setdefault ( 'label_names' , model_config .pop ('labels' ) )
170178
171- label_name = model_config .pop ('label_name ' , None )
172- if label_name :
173- assert isinstance (label_name , (dict , list , tuple ))
179+ label_names = model_config .pop ('label_names ' , None )
180+ if label_names :
181+ assert isinstance (label_names , (dict , list , tuple ))
174182 # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
175183 # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
176- hf_config ['label_name ' ] = model_config [ 'label_name' ]
184+ hf_config ['label_names ' ] = label_names
177185
178- display_name = model_config .pop ('display_name ' , None )
179- if display_name :
180- assert isinstance (display_name , dict )
181- # map label_name -> user interface display name
182- hf_config ['display_name ' ] = model_config [ 'display_name' ]
186+ label_descriptions = model_config .pop ('label_descriptions ' , None )
187+ if label_descriptions :
188+ assert isinstance (label_descriptions , dict )
189+ # maps label names -> descriptions
190+ hf_config ['label_descriptions ' ] = label_descriptions
183191
184192 hf_config ['pretrained_cfg' ] = pretrained_cfg
185193 hf_config .update (model_config )
@@ -188,7 +196,7 @@ def save_config_for_hf(model, config_path, model_config=None):
188196 json .dump (hf_config , f , indent = 2 )
189197
190198
191- def save_for_hf (model , save_directory , model_config = None ):
199+ def save_for_hf (model , save_directory : str , model_config : Optional [ dict ] = None ):
192200 assert has_hf_hub (True )
193201 save_directory = Path (save_directory )
194202 save_directory .mkdir (exist_ok = True , parents = True )
@@ -249,7 +257,7 @@ def push_to_hf_hub(
249257 )
250258
251259
252- def generate_readme (model_card , model_name ):
260+ def generate_readme (model_card : dict , model_name : str ):
253261 readme_text = "---\n "
254262 readme_text += "tags:\n - image-classification\n - timm\n "
255263 readme_text += "library_tag: timm\n "
0 commit comments