45 changes: 29 additions & 16 deletions src/transformers/models/auto/feature_extraction_auto.py
@@ -24,7 +24,7 @@
from ...configuration_utils import PreTrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
@@ -167,27 +167,40 @@ def get_feature_extractor_config(
feature_extractor.save_pretrained("feature-extractor-test")
feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
```"""
resolved_config_file = cached_file(
pretrained_model_name_or_path,
FEATURE_EXTRACTOR_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if resolved_config_file is None:
resolved_config_files = [
resolved_file
for filename in [FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME]
if (
resolved_file := cached_file(
pretrained_model_name_or_path,
filename=filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
)
is not None
]
if not resolved_config_files:
logger.info(
"Could not locate the feature extractor configuration file, will try to use the model config instead."
)
return {}

resolved_config_file = resolved_config_files[0]
with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)
feature_extractor_dict = json.load(reader)
if "audio_processor" in feature_extractor_dict:
feature_extractor_dict = feature_extractor_dict["audio_processor"]
else:
feature_extractor_dict = feature_extractor_dict.get("feature_extractor", feature_extractor_dict)
return feature_extractor_dict


class AutoFeatureExtractor:
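Note on the change above: `get_feature_extractor_config` now tries the dedicated feature-extractor file first and falls back to the processor-level `PROCESSOR_NAME` file, and when the resolved JSON is a processor-style nested config it pulls out the `audio_processor` or `feature_extractor` entry. A minimal sketch of that extraction step on a made-up nested payload (the key names are the ones checked in the diff; the values are illustrative only):

```python
# Minimal sketch of the nested-config extraction applied to a made-up
# processor-level payload; "audio_processor" / "feature_extractor" are the keys
# checked in the diff above, the values below are purely illustrative.
nested_payload = {
    "processor_class": "Wav2Vec2Processor",
    "feature_extractor": {"feature_size": 1, "sampling_rate": 16000, "padding_value": 0.0},
}

feature_extractor_dict = nested_payload
if "audio_processor" in feature_extractor_dict:
    feature_extractor_dict = feature_extractor_dict["audio_processor"]
else:
    # A flat feature-extractor config has no nested entry and is returned unchanged.
    feature_extractor_dict = feature_extractor_dict.get("feature_extractor", feature_extractor_dict)

print(feature_extractor_dict)
# {'feature_size': 1, 'sampling_rate': 16000, 'padding_value': 0.0}
```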
42 changes: 27 additions & 15 deletions src/transformers/models/auto/image_processing_auto.py
@@ -29,6 +29,7 @@
from ...utils import (
CONFIG_NAME,
IMAGE_PROCESSOR_NAME,
PROCESSOR_NAME,
cached_file,
is_timm_config_dict,
is_timm_local_checkpoint,
@@ -305,27 +306,38 @@ def get_image_processor_config(
image_processor.save_pretrained("image-processor-test")
image_processor_config = get_image_processor_config("image-processor-test")
```"""
resolved_config_file = cached_file(
pretrained_model_name_or_path,
IMAGE_PROCESSOR_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if resolved_config_file is None:
resolved_config_files = [
resolved_file
for filename in [IMAGE_PROCESSOR_NAME, PROCESSOR_NAME]
if (
resolved_file := cached_file(
pretrained_model_name_or_path,
filename=filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
)
is not None
]
# An empty list if none of the possible files is found in the repo
if not resolved_config_files:
logger.info(
"Could not locate the image processor configuration file, will try to use the model config instead."
)
return {}

resolved_config_file = resolved_config_files[0]
with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)
image_processor_dict = json.load(reader)
image_processor_dict = image_processor_dict.get("image_processor", image_processor_dict)
return image_processor_dict


def _warning_fast_image_processor_available(fast_class):
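All three helpers share the same resolution pattern: a walrus-operator list comprehension collects, in priority order, every candidate file that `cached_file` can actually resolve, and the first hit wins, so a dedicated image-processor file still takes precedence over `processor_config.json`. A small sketch of that pattern with a stubbed resolver standing in for `cached_file` (the stub, file names, and paths are illustrative, not the real API):

```python
from typing import Optional

# Stub standing in for cached_file: returns a local path if the file "exists".
def resolve(filename: str) -> Optional[str]:
    available = {"processor_config.json": "/cache/models--demo/processor_config.json"}
    return available.get(filename)

# Candidates in priority order: the dedicated file first, the processor file second.
candidates = ["preprocessor_config.json", "processor_config.json"]
resolved_files = [path for name in candidates if (path := resolve(name)) is not None]

# An empty list means no candidate exists and the caller falls back to the model
# config; otherwise the first entry wins.
config_path = resolved_files[0] if resolved_files else None
print(config_path)  # /cache/models--demo/processor_config.json
```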
39 changes: 26 additions & 13 deletions src/transformers/models/auto/video_processing_auto.py
@@ -23,7 +23,7 @@
# Build the list of all video processors
from ...configuration_utils import PreTrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import CONFIG_NAME, VIDEO_PROCESSOR_NAME, cached_file, is_torchvision_available, logging
from ...utils import CONFIG_NAME, PROCESSOR_NAME, VIDEO_PROCESSOR_NAME, cached_file, is_torchvision_available, logging
from ...utils.import_utils import requires
from ...video_processing_utils import BaseVideoProcessor
from .auto_factory import _LazyAutoMapping
@@ -167,24 +167,37 @@ def get_video_processor_config(
video_processor.save_pretrained("video-processor-test")
video_processor = get_video_processor_config("video-processor-test")
```"""
resolved_config_file = cached_file(
pretrained_model_name_or_path,
VIDEO_PROCESSOR_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
)
if resolved_config_file is None:
resolved_config_files = [
resolved_file
for filename in [VIDEO_PROCESSOR_NAME, PROCESSOR_NAME]
if (
resolved_file := cached_file(
pretrained_model_name_or_path,
filename=filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
)
is not None
]
if not resolved_config_files:
logger.info(
"Could not locate the video processor configuration file, will try to use the model config instead."
)
return {}

resolved_config_file = resolved_config_files[0]
with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)
video_processor_dict = json.load(reader)
video_processor_dict = video_processor_dict.get("video_processor", video_processor_dict)
return video_processor_dict


@requires(backends=("vision", "torchvision"))
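For the video processor, the diff also adds the `_raise_exceptions_for_*=False` arguments that the other two helpers already passed, so a repo that has neither candidate file now logs an info message and returns `{}` instead of raising. A short usage sketch, reusing the checkpoint from the new test below and the directory name from the docstring example (the local directory is an assumption, not part of the PR):

```python
from transformers import AutoProcessor
from transformers.models.auto.video_processing_auto import get_video_processor_config

# Save a full processor locally, then read the video-processor kwargs back as a plain
# dict; checkpoint as in the new test, directory name as in the docstring example.
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
processor.save_pretrained("video-processor-test")

config = get_video_processor_config("video-processor-test")
print(sorted(config))  # flat video-processor kwargs, even if stored nested on disk
```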
49 changes: 49 additions & 0 deletions tests/models/auto/test_processor_auto.py
@@ -33,15 +33,22 @@
AutoFeatureExtractor,
AutoProcessor,
AutoTokenizer,
BaseVideoProcessor,
BertTokenizer,
FeatureExtractionMixin,
ImageProcessingMixin,
LlamaTokenizer,
LlavaOnevisionVideoProcessor,
LlavaProcessor,
ProcessorMixin,
SiglipImageProcessor,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
)
from transformers.models.auto.feature_extraction_auto import get_feature_extractor_config
from transformers.models.auto.image_processing_auto import get_image_processor_config
from transformers.models.auto.video_processing_auto import get_video_processor_config
from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test
from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
from transformers.utils import (
@@ -107,6 +114,48 @@ def test_processor_from_local_directory_from_extractor_config(self):

self.assertIsInstance(processor, Wav2Vec2Processor)

def test_subcomponent_get_config_dict__saved_as_nested_config(self):
"""
Tests that we can get the config dict of a processor's subcomponents,
even if they were saved as a nested dict in `processor_config.json`.
"""
# Test feature extractor first
with tempfile.TemporaryDirectory() as tmpdirname:
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
processor.save_pretrained(tmpdirname)

config_dict_1 = get_feature_extractor_config(tmpdirname)
feature_extractor_1 = Wav2Vec2FeatureExtractor(**config_dict_1)
self.assertIsInstance(feature_extractor_1, Wav2Vec2FeatureExtractor)

config_dict_2, _ = FeatureExtractionMixin.get_feature_extractor_dict(tmpdirname)
feature_extractor_2 = Wav2Vec2FeatureExtractor(**config_dict_2)
self.assertIsInstance(feature_extractor_2, Wav2Vec2FeatureExtractor)
self.assertEqual(config_dict_1, config_dict_2)

# Test image and video processors next
with tempfile.TemporaryDirectory() as tmpdirname:
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
processor.save_pretrained(tmpdirname)

config_dict_1 = get_image_processor_config(tmpdirname)
image_processor_1 = SiglipImageProcessor(**config_dict_1)
self.assertIsInstance(image_processor_1, SiglipImageProcessor)

config_dict_2, _ = ImageProcessingMixin.get_image_processor_dict(tmpdirname)
image_processor_2 = SiglipImageProcessor(**config_dict_2)
self.assertIsInstance(image_processor_2, SiglipImageProcessor)
self.assertEqual(config_dict_1, config_dict_2)

config_dict_1 = get_video_processor_config(tmpdirname)
video_processor_1 = LlavaOnevisionVideoProcessor(**config_dict_1)
self.assertIsInstance(video_processor_1, LlavaOnevisionVideoProcessor)

config_dict_2, _ = BaseVideoProcessor.get_video_processor_dict(tmpdirname)
video_processor_2 = LlavaOnevisionVideoProcessor(**config_dict_2)
self.assertIsInstance(video_processor_2, LlavaOnevisionVideoProcessor)
self.assertEqual(config_dict_1, config_dict_2)

def test_processor_from_processor_class(self):
with tempfile.TemporaryDirectory() as tmpdirname:
feature_extractor = Wav2Vec2FeatureExtractor()
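The new test pins the behaviour down by reading the same saved directory through both paths: the patched module-level `get_*_config` helpers and the corresponding mixin `get_*_dict` methods must return identical flat dicts. A condensed sketch of the audio half of that check (checkpoint and APIs exactly as used in the test):

```python
import tempfile

from transformers import AutoProcessor, FeatureExtractionMixin
from transformers.models.auto.feature_extraction_auto import get_feature_extractor_config

# Condensed version of the audio part of the new test: both read paths must agree.
with tempfile.TemporaryDirectory() as tmpdirname:
    AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h").save_pretrained(tmpdirname)

    via_helper = get_feature_extractor_config(tmpdirname)
    via_mixin, _ = FeatureExtractionMixin.get_feature_extractor_dict(tmpdirname)
    assert via_helper == via_mixin
```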