Skip to content

Commit a22d860

Browse files
Potabkskyloevil
authored andcommitted
[Bugfix] Fix get_quant_config when using modelscope (vllm-project#24421)
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent e1ca633 commit a22d860

File tree

2 files changed

+48
-36
lines changed

2 files changed

+48
-36
lines changed

vllm/model_executor/model_loader/default_loader.py

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,19 @@
77
from collections.abc import Generator, Iterable
88
from typing import Optional, cast
99

10-
import huggingface_hub
1110
import torch
1211
from torch import nn
1312
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
1413

15-
from vllm import envs
1614
from vllm.config import LoadConfig, ModelConfig
1715
from vllm.logger import init_logger
1816
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
1917
from vllm.model_executor.model_loader.weight_utils import (
2018
download_safetensors_index_file_from_hf, download_weights_from_hf,
2119
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
22-
filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator,
23-
pt_weights_iterator, safetensors_weights_iterator)
20+
filter_files_not_needed_for_inference, maybe_download_from_modelscope,
21+
np_cache_weights_iterator, pt_weights_iterator,
22+
safetensors_weights_iterator)
2423
from vllm.platforms import current_platform
2524

2625
logger = init_logger(__name__)
@@ -57,35 +56,6 @@ def __init__(self, load_config: LoadConfig):
5756
raise ValueError(f"Model loader extra config is not supported for "
5857
f"load format {load_config.load_format}")
5958

60-
def _maybe_download_from_modelscope(
61-
self, model: str, revision: Optional[str]) -> Optional[str]:
62-
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
63-
64-
Returns the path to the downloaded model, or None if the model is not
65-
downloaded from ModelScope."""
66-
if envs.VLLM_USE_MODELSCOPE:
67-
# download model from ModelScope hub,
68-
# lazy import so that modelscope is not required for normal use.
69-
# pylint: disable=C.
70-
from modelscope.hub.snapshot_download import snapshot_download
71-
72-
# Use file lock to prevent multiple processes from
73-
# downloading the same model weights at the same time.
74-
with get_lock(model, self.load_config.download_dir):
75-
if not os.path.exists(model):
76-
model_path = snapshot_download(
77-
model_id=model,
78-
cache_dir=self.load_config.download_dir,
79-
local_files_only=huggingface_hub.constants.
80-
HF_HUB_OFFLINE,
81-
revision=revision,
82-
ignore_file_pattern=self.load_config.ignore_patterns,
83-
)
84-
else:
85-
model_path = model
86-
return model_path
87-
return None
88-
8959
def _prepare_weights(
9060
self,
9161
model_name_or_path: str,
@@ -96,7 +66,7 @@ def _prepare_weights(
9666
"""Prepare weights for the model.
9767
9868
If the model is not local, it will be downloaded."""
99-
model_name_or_path = (self._maybe_download_from_modelscope(
69+
model_name_or_path = (maybe_download_from_modelscope(
10070
model_name_or_path, revision) or model_name_or_path)
10171

10272
is_local = os.path.isdir(model_name_or_path)

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from safetensors.torch import load_file, safe_open, save_file
2222
from tqdm.auto import tqdm
2323

24+
from vllm import envs
2425
from vllm.config import LoadConfig, ModelConfig
2526
from vllm.distributed import get_tensor_model_parallel_rank
2627
from vllm.logger import init_logger
@@ -95,6 +96,41 @@ def get_lock(model_name_or_path: Union[str, Path],
9596
return lock
9697

9798

99+
def maybe_download_from_modelscope(
100+
model: str,
101+
revision: Optional[str] = None,
102+
download_dir: Optional[str] = None,
103+
ignore_patterns: Optional[Union[str, list[str]]] = None,
104+
allow_patterns: Optional[Union[list[str],
105+
str]] = None) -> Optional[str]:
106+
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
107+
108+
Returns the path to the downloaded model, or None if the model is not
109+
downloaded from ModelScope."""
110+
if envs.VLLM_USE_MODELSCOPE:
111+
# download model from ModelScope hub,
112+
# lazy import so that modelscope is not required for normal use.
113+
# pylint: disable=C.
114+
from modelscope.hub.snapshot_download import snapshot_download
115+
116+
# Use file lock to prevent multiple processes from
117+
# downloading the same model weights at the same time.
118+
with get_lock(model, download_dir):
119+
if not os.path.exists(model):
120+
model_path = snapshot_download(
121+
model_id=model,
122+
cache_dir=download_dir,
123+
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
124+
revision=revision,
125+
ignore_file_pattern=ignore_patterns,
126+
allow_patterns=allow_patterns,
127+
)
128+
else:
129+
model_path = model
130+
return model_path
131+
return None
132+
133+
98134
def _shared_pointers(tensors):
99135
ptrs = defaultdict(list)
100136
for k, v in tensors.items():
@@ -169,7 +205,13 @@ def get_quant_config(model_config: ModelConfig,
169205
# Inflight BNB quantization
170206
if model_config.quantization == "bitsandbytes":
171207
return quant_cls.from_config({})
172-
is_local = os.path.isdir(model_config.model)
208+
model_name_or_path = maybe_download_from_modelscope(
209+
model_config.model,
210+
revision=model_config.revision,
211+
download_dir=load_config.download_dir,
212+
allow_patterns=["*.json"],
213+
) or model_config.model
214+
is_local = os.path.isdir(model_name_or_path)
173215
if not is_local:
174216
# Download the config files.
175217
with get_lock(model_config.model, load_config.download_dir):
@@ -182,7 +224,7 @@ def get_quant_config(model_config: ModelConfig,
182224
tqdm_class=DisabledTqdm,
183225
)
184226
else:
185-
hf_folder = model_config.model
227+
hf_folder = model_name_or_path
186228

187229
possible_config_filenames = quant_cls.get_config_filenames()
188230

0 commit comments

Comments
 (0)