Skip to content

Commit c4b505d

Browse files
ydshiehCyrilvallez
andauthored
Don't convert to safetensors on the fly if the call is from testing (#41194)
* don't convert * disable * Update src/transformers/modeling_utils.py Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co> * fix * disable * disable * disable --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
1 parent 01c9e1b commit c4b505d

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

.circleci/create_circleci_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"RUN_PIPELINE_TESTS": False,
3030
# will be adjust in `CircleCIJob.to_dict`.
3131
"RUN_FLAKY": True,
32+
"DISABLE_SAFETENSORS_CONVERSION": True,
3233
}
3334
# Disable the use of {"s": None} as the output is way too long, causing the navigation on CircleCI impractical
3435
COMMON_PYTEST_OPTIONS = {"max-worker-restart": 0, "vvv": None, "rsfE":None}

conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def pytest_configure(config):
9090
config.addinivalue_line("markers", "torch_compile_test: mark test which tests torch compile functionality")
9191
config.addinivalue_line("markers", "torch_export_test: mark test which tests torch export functionality")
9292

93+
os.environ['DISABLE_SAFETENSORS_CONVERSION'] = 'true'
94+
9395

9496
def pytest_collection_modifyitems(items):
9597
for item in items:

src/transformers/modeling_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,12 @@ def _get_resolved_checkpoint_files(
10211021
is_sharded = True
10221022
if not local_files_only and not is_offline_mode():
10231023
if resolved_archive_file is not None:
1024-
if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]:
1024+
# In a CI environment (CircleCI / Github Actions workflow runs) or in a pytest run,
1025+
# we set `DISABLE_SAFETENSORS_CONVERSION=true` to prevent the conversion.
1026+
if (
1027+
filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
1028+
and os.getenv("DISABLE_SAFETENSORS_CONVERSION", None) != "true"
1029+
):
10251030
# If the PyTorch file was found, check if there is a safetensors file on the repository
10261031
# If there is no safetensors file on the repositories, start an auto conversion
10271032
safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME

0 commit comments

Comments
 (0)