````diff
@@ -158,7 +158,10 @@ class VllmModelDownloader:
         from vllm import LLM
         from vllm.config import LoadFormat
 
-        def _run_writer(input_dir, output_dir):
+        # set the model storage path
+        storage_path = os.getenv("STORAGE_PATH", "./models")
+
+        def _run_writer(input_dir, model_name):
             # load models from the input directory
             llm_writer = LLM(
                 model=input_dir,
@@ -169,10 +172,11 @@ class VllmModelDownloader:
                 enforce_eager=True,
                 max_model_len=1,
             )
+            model_path = os.path.join(storage_path, model_name)
             model_executer = llm_writer.llm_engine.model_executor
             # save the models in the ServerlessLLM format
             model_executer.save_serverless_llm_state(
-                path=output_dir, pattern=pattern, max_size=max_size
+                path=model_path, pattern=pattern, max_size=max_size
             )
             for file in os.listdir(input_dir):
                 # Copy the metadata files into the output directory
@@ -182,48 +186,37 @@ class VllmModelDownloader:
                     ".safetensors",
                 ):
                     src_path = os.path.join(input_dir, file)
-                    dest_path = os.path.join(output_dir, file)
+                    dest_path = os.path.join(model_path, file)
                     if os.path.isdir(src_path):
                         shutil.copytree(src_path, dest_path)
                     else:
-                        shutil.copy(src_path, output_dir)
+                        shutil.copy(src_path, dest_path)
             del model_executer
             del llm_writer
             gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
                 torch.cuda.synchronize()
 
-        # set the model storage path
-        storage_path = os.getenv("STORAGE_PATH", "./models")
-        model_dir = os.path.join(storage_path, model_name)
-
-        # create the output directory
-        if os.path.exists(model_dir):
-            print(f"Already exists: {model_dir}")
-            return
-        os.makedirs(model_dir, exist_ok=True)
-
         try:
             with TemporaryDirectory() as cache_dir:
-                # download model from huggingface
+                # download from huggingface
                 input_dir = snapshot_download(
                     model_name,
                     cache_dir=cache_dir,
                     allow_patterns=["*.safetensors", "*.bin", "*.json", "*.txt"],
                 )
-                _run_writer(input_dir, model_dir)
+                _run_writer(input_dir, model_name)
         except Exception as e:
             print(f"An error occurred while saving the model: {e}")
             # remove the output dir
-            shutil.rmtree(model_dir)
+            shutil.rmtree(os.path.join(storage_path, model_name))
             raise RuntimeError(
-                f"Failed to save model {model_name} for vllm backend: {e}"
+                f"Failed to save {model_name} for vllm backend: {e}"
             )
 
 downloader = VllmModelDownloader()
 downloader.download_vllm_model("facebook/opt-1.3b", "float16", 1)
-
 ```
````
228221
After downloading the model, you can launch the checkpoint store server and load the model in vLLM through the `serverless_llm` load format.
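
For example, here is a minimal sketch of the loading side, assuming the checkpoint store server is already running (launched with ServerlessLLM's `sllm-store` CLI; the exact command and flags may vary between versions) and that `STORAGE_PATH` points at the same directory the downloader wrote to:

```python
import os

from vllm import LLM

# Assumes the ServerlessLLM checkpoint store server is already running and
# that STORAGE_PATH matches the directory used by the downloader above.
storage_path = os.getenv("STORAGE_PATH", "./models")
model_path = os.path.join(storage_path, "facebook/opt-1.3b")

# Load the converted checkpoint through the serverless_llm load format
# (available in ServerlessLLM's patched build of vLLM).
llm = LLM(
    model=model_path,
    load_format="serverless_llm",
    dtype="float16",
)

outputs = llm.generate(["Hello, my name is"])
print(outputs[0].outputs[0].text)
```

The model name and dtype here mirror the `download_vllm_model("facebook/opt-1.3b", "float16", 1)` call above, so the checkpoint on disk matches what the loader expects.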