2121from safetensors .torch import load_file , safe_open , save_file
2222from tqdm .auto import tqdm
2323
24+ from vllm import envs
2425from vllm .config import LoadConfig , ModelConfig
2526from vllm .distributed import get_tensor_model_parallel_rank
2627from vllm .logger import init_logger
@@ -95,6 +96,41 @@ def get_lock(model_name_or_path: Union[str, Path],
9596 return lock
9697
9798
def maybe_download_from_modelscope(
    model: str,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    ignore_patterns: Optional[Union[str, list[str]]] = None,
    allow_patterns: Optional[Union[list[str], str]] = None,
) -> Optional[str]:
    """Resolve ``model`` through the ModelScope hub when it is enabled.

    Nothing happens unless ``envs.VLLM_USE_MODELSCOPE`` is truthy; in that
    case ``None`` is returned immediately so the caller can fall back to its
    own resolution.

    Args:
        model: ModelScope model id, or a path that already exists on disk.
        revision: Forwarded to ``snapshot_download`` as ``revision``.
        download_dir: Forwarded to ``snapshot_download`` as ``cache_dir`` and
            also passed to ``get_lock`` together with ``model``.
        ignore_patterns: Forwarded as ``ignore_file_pattern``.
        allow_patterns: Forwarded as ``allow_patterns``.

    Returns:
        ``model`` itself when it already exists as a filesystem path,
        otherwise whatever ``snapshot_download`` returns; ``None`` when
        ModelScope usage is disabled.
    """
    if not envs.VLLM_USE_MODELSCOPE:
        return None

    # Imported lazily so that modelscope stays an optional dependency for
    # users who never enable VLLM_USE_MODELSCOPE.
    # pylint: disable=C.
    from modelscope.hub.snapshot_download import snapshot_download

    # The file lock keeps concurrent processes from fetching the same
    # model weights simultaneously.
    with get_lock(model, download_dir):
        if os.path.exists(model):
            # Already a local path: nothing to download.
            return model

        resolved_path = snapshot_download(
            model_id=model,
            cache_dir=download_dir,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            revision=revision,
            ignore_file_pattern=ignore_patterns,
            allow_patterns=allow_patterns,
        )
    return resolved_path
132+
133+
98134def _shared_pointers (tensors ):
99135 ptrs = defaultdict (list )
100136 for k , v in tensors .items ():
@@ -169,7 +205,13 @@ def get_quant_config(model_config: ModelConfig,
169205 # Inflight BNB quantization
170206 if model_config .quantization == "bitsandbytes" :
171207 return quant_cls .from_config ({})
172- is_local = os .path .isdir (model_config .model )
208+ model_name_or_path = maybe_download_from_modelscope (
209+ model_config .model ,
210+ revision = model_config .revision ,
211+ download_dir = load_config .download_dir ,
212+ allow_patterns = ["*.json" ],
213+ ) or model_config .model
214+ is_local = os .path .isdir (model_name_or_path )
173215 if not is_local :
174216 # Download the config files.
175217 with get_lock (model_config .model , load_config .download_dir ):
@@ -182,7 +224,7 @@ def get_quant_config(model_config: ModelConfig,
182224 tqdm_class = DisabledTqdm ,
183225 )
184226 else :
185- hf_folder = model_config . model
227+ hf_folder = model_name_or_path
186228
187229 possible_config_filenames = quant_cls .get_config_filenames ()
188230
0 commit comments