Skip to content

Commit 5150dac

Browse files
zucchini-nlpmolbap
andauthored
Fix helper fn for new processor config format (#42085)
* fix the helper fn for new processor config format * change the priority order * maybe we need to explicitly load and then decide * Apply suggestions from code review Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * use helper fn for json decoding --------- Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
1 parent 27c3807 commit 5150dac

File tree

9 files changed

+307
-104
lines changed

9 files changed

+307
-104
lines changed

src/transformers/feature_extraction_utils.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
is_torch_dtype,
4040
logging,
4141
requires_backends,
42+
safe_load_json_file,
4243
)
4344
from .utils.hub import cached_file
4445

@@ -427,35 +428,42 @@ def get_feature_extractor_dict(
427428
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
428429
if os.path.isfile(pretrained_model_name_or_path):
429430
resolved_feature_extractor_file = pretrained_model_name_or_path
431+
resolved_processor_file = None
430432
is_local = True
431433
elif is_remote_url(pretrained_model_name_or_path):
432434
feature_extractor_file = pretrained_model_name_or_path
435+
resolved_processor_file = None
433436
resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
434437
else:
435438
feature_extractor_file = FEATURE_EXTRACTOR_NAME
436439
try:
437440
# Load from local folder or from cache or download from model Hub and cache
438-
resolved_feature_extractor_files = [
439-
resolved_file
440-
for filename in [feature_extractor_file, PROCESSOR_NAME]
441-
if (
442-
resolved_file := cached_file(
443-
pretrained_model_name_or_path,
444-
filename=filename,
445-
cache_dir=cache_dir,
446-
force_download=force_download,
447-
proxies=proxies,
448-
local_files_only=local_files_only,
449-
subfolder=subfolder,
450-
token=token,
451-
user_agent=user_agent,
452-
revision=revision,
453-
_raise_exceptions_for_missing_entries=False,
454-
)
455-
)
456-
is not None
457-
]
458-
resolved_feature_extractor_file = resolved_feature_extractor_files[0]
441+
resolved_processor_file = cached_file(
442+
pretrained_model_name_or_path,
443+
filename=PROCESSOR_NAME,
444+
cache_dir=cache_dir,
445+
force_download=force_download,
446+
proxies=proxies,
447+
local_files_only=local_files_only,
448+
token=token,
449+
user_agent=user_agent,
450+
revision=revision,
451+
subfolder=subfolder,
452+
_raise_exceptions_for_missing_entries=False,
453+
)
454+
resolved_feature_extractor_file = cached_file(
455+
pretrained_model_name_or_path,
456+
filename=feature_extractor_file,
457+
cache_dir=cache_dir,
458+
force_download=force_download,
459+
proxies=proxies,
460+
local_files_only=local_files_only,
461+
token=token,
462+
user_agent=user_agent,
463+
revision=revision,
464+
subfolder=subfolder,
465+
_raise_exceptions_for_missing_entries=False,
466+
)
459467
except OSError:
460468
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
461469
# the original exception.
@@ -469,19 +477,24 @@ def get_feature_extractor_dict(
469477
f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
470478
)
471479

472-
try:
473-
# Load feature_extractor dict
474-
with open(resolved_feature_extractor_file, encoding="utf-8") as reader:
475-
text = reader.read()
476-
feature_extractor_dict = json.loads(text)
477-
if "audio_processor" in feature_extractor_dict:
478-
feature_extractor_dict = feature_extractor_dict["audio_processor"]
479-
else:
480-
feature_extractor_dict = feature_extractor_dict.get("feature_extractor", feature_extractor_dict)
480+
# Load feature_extractor dict. Priority goes as (nested config if found -> image processor config)
481+
# We are downloading both configs because almost all models have a `processor_config.json` but
482+
# not all of these are nested. We need to check if it was saved recebtly as nested or if it is legacy style
483+
feature_extractor_dict = None
484+
if resolved_processor_file is not None:
485+
processor_dict = safe_load_json_file(resolved_processor_file)
486+
if "feature_extractor" in processor_dict or "audio_processor" in processor_dict:
487+
feature_extractor_dict = processor_dict.get("feature_extractor", processor_dict.get("audio_processor"))
488+
489+
if resolved_feature_extractor_file is not None and feature_extractor_dict is None:
490+
feature_extractor_dict = safe_load_json_file(resolved_feature_extractor_file)
481491

482-
except json.JSONDecodeError:
492+
if feature_extractor_dict is None:
483493
raise OSError(
484-
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
494+
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
495+
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
496+
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
497+
f" directory containing a {feature_extractor_file} file"
485498
)
486499

487500
if is_local:

src/transformers/image_processing_base.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
is_offline_mode,
3333
is_remote_url,
3434
logging,
35+
safe_load_json_file,
3536
)
3637
from .utils.hub import cached_file
3738

@@ -280,35 +281,41 @@ def get_image_processor_dict(
280281
image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
281282
if os.path.isfile(pretrained_model_name_or_path):
282283
resolved_image_processor_file = pretrained_model_name_or_path
284+
resolved_processor_file = None
283285
is_local = True
284286
elif is_remote_url(pretrained_model_name_or_path):
285287
image_processor_file = pretrained_model_name_or_path
288+
resolved_processor_file = None
286289
resolved_image_processor_file = download_url(pretrained_model_name_or_path)
287290
else:
288291
image_processor_file = image_processor_filename
289292
try:
290-
# Load from local folder or from cache or download from model Hub and cache
291-
resolved_image_processor_files = [
292-
resolved_file
293-
for filename in [image_processor_file, PROCESSOR_NAME]
294-
if (
295-
resolved_file := cached_file(
296-
pretrained_model_name_or_path,
297-
filename=filename,
298-
cache_dir=cache_dir,
299-
force_download=force_download,
300-
proxies=proxies,
301-
local_files_only=local_files_only,
302-
token=token,
303-
user_agent=user_agent,
304-
revision=revision,
305-
subfolder=subfolder,
306-
_raise_exceptions_for_missing_entries=False,
307-
)
308-
)
309-
is not None
310-
]
311-
resolved_image_processor_file = resolved_image_processor_files[0]
293+
resolved_processor_file = cached_file(
294+
pretrained_model_name_or_path,
295+
filename=PROCESSOR_NAME,
296+
cache_dir=cache_dir,
297+
force_download=force_download,
298+
proxies=proxies,
299+
local_files_only=local_files_only,
300+
token=token,
301+
user_agent=user_agent,
302+
revision=revision,
303+
subfolder=subfolder,
304+
_raise_exceptions_for_missing_entries=False,
305+
)
306+
resolved_image_processor_file = cached_file(
307+
pretrained_model_name_or_path,
308+
filename=image_processor_file,
309+
cache_dir=cache_dir,
310+
force_download=force_download,
311+
proxies=proxies,
312+
local_files_only=local_files_only,
313+
token=token,
314+
user_agent=user_agent,
315+
revision=revision,
316+
subfolder=subfolder,
317+
_raise_exceptions_for_missing_entries=False,
318+
)
312319
except OSError:
313320
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
314321
# the original exception.
@@ -322,16 +329,24 @@ def get_image_processor_dict(
322329
f" directory containing a {image_processor_filename} file"
323330
)
324331

325-
try:
326-
# Load image_processor dict
327-
with open(resolved_image_processor_file, encoding="utf-8") as reader:
328-
text = reader.read()
329-
image_processor_dict = json.loads(text)
330-
image_processor_dict = image_processor_dict.get("image_processor", image_processor_dict)
332+
# Load image_processor dict. Priority goes as (nested config if found -> image processor config)
333+
# We are downloading both configs because almost all models have a `processor_config.json` but
334+
# not all of these are nested. We need to check if it was saved recebtly as nested or if it is legacy style
335+
image_processor_dict = None
336+
if resolved_processor_file is not None:
337+
processor_dict = safe_load_json_file(resolved_processor_file)
338+
if "image_processor" in processor_dict:
339+
image_processor_dict = processor_dict["image_processor"]
340+
341+
if resolved_image_processor_file is not None and image_processor_dict is None:
342+
image_processor_dict = safe_load_json_file(resolved_image_processor_file)
331343

332-
except json.JSONDecodeError:
344+
if image_processor_dict is None:
333345
raise OSError(
334-
f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
346+
f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
347+
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
348+
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
349+
f" directory containing a {image_processor_filename} file"
335350
)
336351

337352
if is_local:

src/transformers/models/auto/feature_extraction_auto.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""AutoFeatureExtractor class."""
1616

1717
import importlib
18-
import json
1918
import os
2019
from collections import OrderedDict
2120
from typing import Optional, Union
@@ -24,7 +23,7 @@
2423
from ...configuration_utils import PreTrainedConfig
2524
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
2625
from ...feature_extraction_utils import FeatureExtractionMixin
27-
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
26+
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging, safe_load_json_file
2827
from .auto_factory import _LazyAutoMapping
2928
from .configuration_auto import (
3029
CONFIG_MAPPING_NAMES,
@@ -175,9 +174,10 @@ def get_feature_extractor_config(
175174
feature_extractor.save_pretrained("feature-extractor-test")
176175
feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
177176
```"""
178-
resolved_config_file = cached_file(
177+
# Load with a priority given to the nested processor config, if available in repo
178+
resolved_processor_file = cached_file(
179179
pretrained_model_name_or_path,
180-
FEATURE_EXTRACTOR_NAME,
180+
filename=PROCESSOR_NAME,
181181
cache_dir=cache_dir,
182182
force_download=force_download,
183183
proxies=proxies,
@@ -186,16 +186,37 @@ def get_feature_extractor_config(
186186
local_files_only=local_files_only,
187187
_raise_exceptions_for_gated_repo=False,
188188
_raise_exceptions_for_missing_entries=False,
189-
_raise_exceptions_for_connection_errors=False,
190189
)
191-
if resolved_config_file is None:
192-
logger.info(
193-
"Could not locate the feature extractor configuration file, will try to use the model config instead."
194-
)
190+
resolved_feature_extractor_file = cached_file(
191+
pretrained_model_name_or_path,
192+
filename=FEATURE_EXTRACTOR_NAME,
193+
cache_dir=cache_dir,
194+
force_download=force_download,
195+
proxies=proxies,
196+
token=token,
197+
revision=revision,
198+
local_files_only=local_files_only,
199+
_raise_exceptions_for_gated_repo=False,
200+
_raise_exceptions_for_missing_entries=False,
201+
)
202+
203+
# An empty list if none of the possible files is found in the repo
204+
if not resolved_feature_extractor_file and not resolved_processor_file:
205+
logger.info("Could not locate the feature extractor configuration file.")
195206
return {}
196207

197-
with open(resolved_config_file, encoding="utf-8") as reader:
198-
return json.load(reader)
208+
# Load feature_extractor dict. Priority goes as (nested config if found -> feature extractor config)
209+
# We are downloading both configs because almost all models have a `processor_config.json` but
210+
# not all of these are nested. We need to check if it was saved recently as nested or if it is legacy style
211+
feature_extractor_dict = {}
212+
if resolved_processor_file is not None:
213+
processor_dict = safe_load_json_file(resolved_processor_file)
214+
if "feature_extractor" in processor_dict:
215+
feature_extractor_dict = processor_dict["feature_extractor"]
216+
217+
if resolved_feature_extractor_file is not None and feature_extractor_dict is None:
218+
feature_extractor_dict = safe_load_json_file(resolved_feature_extractor_file)
219+
return feature_extractor_dict
199220

200221

201222
class AutoFeatureExtractor:

src/transformers/models/auto/image_processing_auto.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""AutoImageProcessor class."""
1616

1717
import importlib
18-
import json
1918
import os
2019
import warnings
2120
from collections import OrderedDict
@@ -29,12 +28,14 @@
2928
from ...utils import (
3029
CONFIG_NAME,
3130
IMAGE_PROCESSOR_NAME,
31+
PROCESSOR_NAME,
3232
cached_file,
3333
is_timm_config_dict,
3434
is_timm_local_checkpoint,
3535
is_torchvision_available,
3636
is_vision_available,
3737
logging,
38+
safe_load_json_file,
3839
)
3940
from ...utils.import_utils import requires
4041
from .auto_factory import _LazyAutoMapping
@@ -319,9 +320,10 @@ def get_image_processor_config(
319320
image_processor.save_pretrained("image-processor-test")
320321
image_processor_config = get_image_processor_config("image-processor-test")
321322
```"""
322-
resolved_config_file = cached_file(
323+
# Load with a priority given to the nested processor config, if available in repo
324+
resolved_processor_file = cached_file(
323325
pretrained_model_name_or_path,
324-
IMAGE_PROCESSOR_NAME,
326+
filename=PROCESSOR_NAME,
325327
cache_dir=cache_dir,
326328
force_download=force_download,
327329
proxies=proxies,
@@ -330,16 +332,38 @@ def get_image_processor_config(
330332
local_files_only=local_files_only,
331333
_raise_exceptions_for_gated_repo=False,
332334
_raise_exceptions_for_missing_entries=False,
333-
_raise_exceptions_for_connection_errors=False,
334335
)
335-
if resolved_config_file is None:
336-
logger.info(
337-
"Could not locate the image processor configuration file, will try to use the model config instead."
338-
)
336+
resolved_image_processor_file = cached_file(
337+
pretrained_model_name_or_path,
338+
filename=IMAGE_PROCESSOR_NAME,
339+
cache_dir=cache_dir,
340+
force_download=force_download,
341+
proxies=proxies,
342+
token=token,
343+
revision=revision,
344+
local_files_only=local_files_only,
345+
_raise_exceptions_for_gated_repo=False,
346+
_raise_exceptions_for_missing_entries=False,
347+
)
348+
349+
# An empty list if none of the possible files is found in the repo
350+
if not resolved_image_processor_file and not resolved_processor_file:
351+
logger.info("Could not locate the image processor configuration file.")
339352
return {}
340353

341-
with open(resolved_config_file, encoding="utf-8") as reader:
342-
return json.load(reader)
354+
# Load image_processor dict. Priority goes as (nested config if found -> image processor config)
355+
# We are downloading both configs because almost all models have a `processor_config.json` but
356+
# not all of these are nested. We need to check if it was saved recently as nested or if it is legacy style
357+
image_processor_dict = {}
358+
if resolved_processor_file is not None:
359+
processor_dict = safe_load_json_file(resolved_processor_file)
360+
if "image_processor" in processor_dict:
361+
image_processor_dict = processor_dict["image_processor"]
362+
363+
if resolved_image_processor_file is not None and image_processor_dict is None:
364+
image_processor_dict = safe_load_json_file(resolved_image_processor_file)
365+
366+
return image_processor_dict
343367

344368

345369
def _warning_fast_image_processor_available(fast_class):

0 commit comments

Comments
 (0)