diff --git a/examples/load_sllm_state.py b/examples/load_sllm_state.py
new file mode 100644
index 000000000000..e3dbdb824d08
--- /dev/null
+++ b/examples/load_sllm_state.py
@@ -0,0 +1,56 @@
+"""
+Loads a checkpoint saved by save_sllm_state.py using the `serverless_llm`
+load format, where each tensor-parallel worker reads only its own shard
+instead of the entire checkpoint.
+
+Example usage:
+
+python load_sllm_state.py \
+    --model /path/to/model \
+    --tensor-parallel-size 2 \
+    --output /path/to/saved/checkpoint
+"""
+import argparse
+
+from vllm import LLM, EngineArgs
+
+parser = argparse.ArgumentParser()
+EngineArgs.add_cli_args(parser)
+parser.add_argument("--output",
+                    "-o",
+                    required=True,
+                    type=str,
+                    help="path to the saved checkpoint")
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    llm = LLM(
+        model=args.output,
+        load_format="serverless_llm",
+        gpu_memory_utilization=0.9,
+        distributed_executor_backend="mp",
+        max_model_len=512,
+        tensor_parallel_size=args.tensor_parallel_size,
+    )
+
+    input_text = "Explain thread and process in python."
+
+    print(llm.generate(input_text))
diff --git a/examples/save_load_sllm_state.sh b/examples/save_load_sllm_state.sh
new file mode 100644
index 000000000000..a33520ef03ff
--- /dev/null
+++ b/examples/save_load_sllm_state.sh
@@ -0,0 +1,9 @@
+CUDA_VISIBLE_DEVICES=0,1 python save_sllm_state.py \
+    --model /mnt/raid0sata1/huggingface/hub/models--facebook--opt-125m/snapshots/27dcfa74d334bc871f3234de431e71c6eeba5dd6 \
+    --tensor-parallel-size 2 \
+    --output /mnt/raid0nvme1/xly/test_data/vllm/opt-125m
+
+CUDA_VISIBLE_DEVICES=0,1 python load_sllm_state.py \
+    --model /mnt/raid0sata1/huggingface/hub/models--facebook--opt-125m/snapshots/27dcfa74d334bc871f3234de431e71c6eeba5dd6 \
+    --tensor-parallel-size 2 \
+    --output /mnt/raid0nvme1/xly/test_data/vllm/opt-125m
diff --git a/examples/save_sllm_state.py b/examples/save_sllm_state.py
new file mode 100644
index 000000000000..4070cb7d7fe7
--- /dev/null
+++ b/examples/save_sllm_state.py
@@ -0,0 +1,92 @@
+"""
+Saves each worker's model state dict directly to a checkpoint, which enables a
+fast load path for large tensor-parallel models where each worker only needs to
+read its own shard rather than the entire checkpoint.
+
+Example usage:
+
+python save_sllm_state.py \
+    --model /path/to/load \
+    --tensor-parallel-size 8 \
+    --output /path/to/save
+
+Then, the model can be loaded with
+
+llm = LLM(
+    model="/path/to/save",
+    load_format="serverless_llm",
+    tensor_parallel_size=8,
+)
+"""
+import argparse
+import dataclasses
+import os
+import shutil
+from pathlib import Path
+
+from vllm import LLM, EngineArgs
+
+parser = argparse.ArgumentParser()
+EngineArgs.add_cli_args(parser)
+parser.add_argument("--output",
+                    "-o",
+                    required=True,
+                    type=str,
+                    help="path to output checkpoint")
+parser.add_argument("--file-pattern",
+                    type=str,
+                    help="string pattern of saved filenames")
+parser.add_argument("--max-file-size",
+                    type=int,
+                    default=5 * 1024**3,
+                    help="max size (in bytes) of each safetensors file")
+
+
+def main(args):
+    engine_args = EngineArgs.from_cli_args(args)
+    # Use a minimal engine configuration: the model only has to be loaded
+    # once so that its weights can be dumped, not to serve requests.
+    engine_args.distributed_executor_backend = "mp"
+    engine_args.gpu_memory_utilization = 0.4
+    engine_args.max_seq_len_to_capture = 512
+    engine_args.max_model_len = 512
+    engine_args.max_num_seqs = 1
+    engine_args.num_gpu_blocks_override = 128
+    if engine_args.enable_lora:
+        raise ValueError("Saving with enable_lora=True is not supported!")
+    model_path = engine_args.model
+    if not Path(model_path).is_dir():
+        raise ValueError("model path must be a local directory")
+    # Create LLM instance from arguments
+    print(dataclasses.asdict(engine_args))
+    llm = LLM(**dataclasses.asdict(engine_args))
+    # Prepare output directory
+    Path(args.output).mkdir(exist_ok=True)
+    # Dump worker states to output directory
+    model_executor = llm.llm_engine.model_executor
+    model_executor.save_serverless_llm_state(path=args.output,
+                                             pattern=args.file_pattern,
+                                             max_size=args.max_file_size)
+    # Copy metadata files (configs, tokenizer, etc.) to the output directory
+    for file in os.listdir(model_path):
+        if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
+            if os.path.isdir(os.path.join(model_path, file)):
+                shutil.copytree(os.path.join(model_path, file),
+                                os.path.join(args.output, file))
+            else:
+                shutil.copy(os.path.join(model_path, file), args.output)
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    main(args)
diff --git a/vllm/config.py b/vllm/config.py
index d9e4a619ee01..9d257858b6f5 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -498,6 +498,7 @@ class LoadFormat(str, enum.Enum):
     TENSORIZER = "tensorizer"
     SHARDED_STATE = "sharded_state"
     BITSANDBYTES = "bitsandbytes"
+    SERVERLESS_LLM = "serverless_llm"
 
 
 @dataclass
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py
index f7c608af1ad3..e9c09c6f6e1f 100644
--- a/vllm/executor/distributed_gpu_executor.py
+++ b/vllm/executor/distributed_gpu_executor.py
@@ -113,6 +113,17 @@ def save_sharded_state(
                           path=path,
                           pattern=pattern,
                           max_size=max_size)
+
+    def save_serverless_llm_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        self._run_workers("save_serverless_llm_state",
+                          path=path,
+                          pattern=pattern,
+                          max_size=max_size)
 
     @abstractmethod
     def _driver_execute_model(
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 3ad201f4757e..d71da2bac696 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -106,6 +106,16 @@ def check_health(self) -> None:
         # GPUExecutor will always be healthy as long as
         # it's running.
         return
+
+    def save_serverless_llm_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        self.driver_worker.save_serverless_llm_state(
+            path=path, pattern=pattern, max_size=max_size
+        )
 
 
 class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 06de2fcc1cc7..eb50af261b73 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -14,6 +14,7 @@
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
+import gc
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, ParallelConfig,
@@ -418,7 +419,6 @@ def save_model(
             tensorizer_config=tensorizer_config,
         )
 
-
 class ShardedStateLoader(BaseModelLoader):
     """
     Model loader that directly loads each worker's model state dict, which
@@ -577,6 +577,145 @@ def save_model(
         )
 
 
+class ServerlessLLMLoader(BaseModelLoader):
+    """
+    Model loader that loads each worker's shard from a checkpoint written by
+    `save_serverless_llm_state` via ServerlessLLM Store.
+    """
+
+    # DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        extra_config = ({} if load_config.model_loader_extra_config is None
+                        else load_config.model_loader_extra_config.copy())
+        # self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
+        if extra_config:
+            raise ValueError(f"Unexpected extra config keys for load format "
+                             f"{load_config.load_format}: "
+                             f"{load_config.model_loader_extra_config.keys()}")
+
+    @staticmethod
+    def _filter_subtensors(
+            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        Filter out all tensors that share the same memory or a subset of the
+        memory of another tensor.
+ """ + same_storage_groups = collections.defaultdict(list) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. + break + else: + result[k] = t + return result + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + from serverless_llm_store.torch import load_dict + from vllm.distributed import get_tensor_model_parallel_rank + + assert os.path.isdir(model_config.model) + + rank = get_tensor_model_parallel_rank() + + local_model_path = model_config.model + local_model_path = os.path.join(local_model_path, f"rank_{rank}") + + def remove_prefix(path, prefix): + # Normalize the paths to ensure consistency across different platforms + path = os.path.normpath(path) + prefix = os.path.normpath(prefix) + + # Check if the path starts with the prefix + if path.startswith(prefix): + # Return the path without the prefix + return path[len(prefix):].lstrip(os.sep) + + # Return the original path if the prefix doesn't exist + return path + + # vLLM needs a local model path to read model config but + # ServerlessLLM Store requires a global model path as the model ID + storage_path = os.getenv("STORAGE_PATH", "./models") + model_path = remove_prefix(local_model_path, storage_path) + + with set_default_torch_dtype(model_config.dtype): + # with torch.device(device_config.device): + with torch.device("cpu"): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config, + cache_config) + model = model.eval() + # set all parameters to meta device + state_dict = self._filter_subtensors(model.state_dict()) + key_list = list(state_dict.keys()) + + for key, param in model.named_parameters(recurse=True): + if key in key_list: + param.data = torch.empty(1, device="cuda") + gc.collect() + + device_id = torch.cuda.current_device() + device_map = {"": device_id} + # Note: storage path is already included in the local model path + sllm_state_dict = load_dict(model_path, device_map) + + for key, param in model.named_parameters(recurse=True): + if key in key_list: + tensor = sllm_state_dict[key] + param.data = tensor + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") + + return model + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from vllm.distributed import get_tensor_model_parallel_rank + from serverless_llm_store.torch import save_dict + + rank = get_tensor_model_parallel_rank() + state_dict = ServerlessLLMLoader._filter_subtensors(model.state_dict()) + + # move all tensors to CPU + for key, tensor in state_dict.items(): + state_dict[key] 
= tensor.cpu().contiguous() + + save_path = os.path.join(path, f"rank_{rank}") + if not os.path.exists(save_path): + os.makedirs(save_path) + + save_dict(state_dict, save_path) + + class BitsAndBytesModelLoader(BaseModelLoader): """Model loader to load model weights with BitAndBytes quantization.""" @@ -826,6 +965,9 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.SHARDED_STATE: return ShardedStateLoader(load_config) + + if load_config.load_format == LoadFormat.SERVERLESS_LLM: + return ServerlessLLMLoader(load_config) if load_config.load_format == LoadFormat.BITSANDBYTES: return BitsAndBytesModelLoader(load_config) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 476e9ba3bb46..02bbceca44f6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -222,6 +222,20 @@ def save_sharded_state( pattern=pattern, max_size=max_size, ) + + def save_serverless_llm_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from vllm.model_executor.model_loader.loader import ServerlessLLMLoader + ServerlessLLMLoader.save_model( + self.model, + path, + pattern=pattern, + max_size=max_size, + ) def save_tensorized_model( self, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7a378a862d0c..b4c73f432f84 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -132,6 +132,18 @@ def save_sharded_state( pattern=pattern, max_size=max_size, ) + + def save_serverless_llm_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.model_runner.save_serverless_llm_state( + path, + pattern=pattern, + max_size=max_size, + ) def save_tensorized_model( self,
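
For reviewers, a minimal end-to-end sketch of the workflow this diff enables: dump per-rank shards with examples/save_sllm_state.py, then load them back through the new `serverless_llm` load format. This is a sketch under stated assumptions, not part of the patch: the model path, output directory, and tensor-parallel size are placeholders, and a ServerlessLLM Store setup (including any required store service and a matching STORAGE_PATH) is assumed to be configured as required by serverless_llm_store.

# Sketch only: paths, model, and tensor-parallel size are placeholders.
import subprocess

from vllm import LLM

# 1) Dump per-rank shards (plus config/tokenizer files) with the example
#    script added in this diff; --model must be a local model directory.
subprocess.run(
    [
        "python", "examples/save_sllm_state.py",
        "--model", "/path/to/local/model",
        "--tensor-parallel-size", "2",
        "--output", "/path/to/save",
    ],
    check=True,
)

# 2) Load directly from the saved shards; each worker reads only its own
#    rank_<i> directory through the new ServerlessLLMLoader.
llm = LLM(
    model="/path/to/save",
    load_format="serverless_llm",
    distributed_executor_backend="mp",
    tensor_parallel_size=2,
)
print(llm.generate("Explain thread and process in python."))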