diff --git a/src/backend/main.py b/src/backend/main.py
index 899df926..43dc1912 100644
--- a/src/backend/main.py
+++ b/src/backend/main.py
@@ -62,6 +62,27 @@ async def scheduler_init(raw_request: Request):
         status_code=200,
     )
 
+@app.post("/weight/refit")
+async def weight_refit(raw_request: Request):
+    request_data = await raw_request.json()
+    status = scheduler_manage.weight_refit(request_data)
+    if status:
+        return JSONResponse(
+            content={
+                "type": "weight_refit",
+                "data": None,
+            },
+            status_code=200,
+        )
+    else:
+        return JSONResponse(
+            content={
+                "type": "weight_refit",
+                "data": "Server not ready",
+            },
+            status_code=500,
+        )
+
 
 @app.get("/node/join/command")
 async def node_join_command():
diff --git a/src/backend/server/rpc_connection_handler.py b/src/backend/server/rpc_connection_handler.py
index 50185d86..7ba70810 100644
--- a/src/backend/server/rpc_connection_handler.py
+++ b/src/backend/server/rpc_connection_handler.py
@@ -74,7 +74,12 @@ def node_update(self, message):
                 new_rtt_to_nodes=node.rtt_to_nodes,
                 is_active=node.is_active,
             )
-            return {}
+            res = {}
+            if self.scheduler.refit_request:
+                if node.node_id not in self.scheduler.refit_set:
+                    res = self.scheduler.refit_request
+                    self.scheduler.refit_set.add(node.node_id)
+            return res
         except Exception as e:
             logger.exception(f"node_update error: {e}")
             return {}
diff --git a/src/backend/server/scheduler_manage.py b/src/backend/server/scheduler_manage.py
index 582a9dc4..674791ab 100644
--- a/src/backend/server/scheduler_manage.py
+++ b/src/backend/server/scheduler_manage.py
@@ -47,6 +47,7 @@ def __init__(
         self.stubs = {}
         self.is_local_network = False
 
+
     def run(self, model_name, init_nodes_num, is_local_network=True):
         """
         Start the scheduler and the P2P service for RPC handling.
@@ -70,6 +71,16 @@ def run(self, model_name, init_nodes_num, is_local_network=True):
             block_end_index=1,
        )
 
+    def weight_refit(self, request_data):
+        """
+        Trigger weight refit on every node.
+        """
+        if self.scheduler is None:
+            return False
+        self.scheduler.refit_request = request_data
+        self.scheduler.refit_set = set()
+        return True
+
     def is_running(self):
         """
         Returns True if the scheduler is running, False otherwise.
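Not part of the patch: a minimal client-side sketch of how the new POST /weight/refit endpoint in src/backend/main.py could be triggered. The host/port and CID values are placeholder assumptions; the payload keys mirror what check_and_run_weight_refit() in src/parallax/p2p/server.py expects (time_stamp, cid, index_map).

import time
import requests  # assumed to be available in the client environment

refit_payload = {
    "time_stamp": time.time(),  # weight refit trigger time
    "cid": ["example-cid-0"],  # placeholder content IDs
    "index_map": {  # weight name -> cid (placeholder entry)
        "model.layers.0.self_attn.q_proj.weight": "example-cid-0",
    },
}
# "http://localhost:3001" is an assumed scheduler address, not one defined by the patch.
resp = requests.post("http://localhost:3001/weight/refit", json=refit_payload)
# Expect 200 on success, or 500 with "Server not ready" if the scheduler is not running.
print(resp.status_code, resp.json())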
diff --git a/src/parallax/launch.py b/src/parallax/launch.py
index a68e16d7..0acbc59b 100644
--- a/src/parallax/launch.py
+++ b/src/parallax/launch.py
@@ -85,6 +85,7 @@
         recv_from_peer_addr=args.recv_from_peer_addr,
         send_to_peer_addr=args.send_to_peer_addr,
         model_name=args.model_path,
+        enable_weight_refit=args.enable_weight_refit,
         max_batch_size=args.max_batch_size,
         max_sequence_length=args.max_sequence_length,
     )
diff --git a/src/parallax/p2p/server.py b/src/parallax/p2p/server.py
index 772a8a77..cdcee051 100644
--- a/src/parallax/p2p/server.py
+++ b/src/parallax/p2p/server.py
@@ -11,6 +11,7 @@
 import enum
 import json
 import logging
+import os
 import threading
 import time
 from typing import List, Optional
@@ -156,6 +157,16 @@ def rpc_abort(
             logger.exception(f"Error in rpc_abort: {e}")
             return forward_pb2.AbortResponse()
 
+    def ipc_weight_refit(
+        self,
+        refit_weight_path: str,
+    ):
+        try:
+            with self._recv_from_peer_lock:
+                self.recv_from_peer.send_multipart([b"refit", refit_weight_path.encode("ascii")])
+        except Exception as e:
+            logger.exception(f"Error in ipc_weight_refit: {e}")
+
     @rpc_stream_iter
     def chat_completion(
         self,
@@ -207,6 +218,7 @@ def __init__(
         announce_maddrs: List[str] = [],
         notify_url: str = None,
         model_name: Optional[str] = None,
+        enable_weight_refit: Optional[bool] = False,
         max_batch_size: Optional[int] = None,
         max_sequence_length: Optional[int] = None,
     ):
@@ -224,6 +236,8 @@ def __init__(
         self.http_port = http_port
         self.notify_url = notify_url
         self.model_name = model_name
+        self.enable_weight_refit = enable_weight_refit
+        self.last_refit_time = 0
         self.max_batch_size = max_batch_size
         self.max_sequence_length = max_sequence_length
         self.prefix_id = f"{dht_prefix}_announce"
@@ -539,6 +553,74 @@ def group_requests_by_next_peer(requests: List[forward_pb2.Req]):
                 logger.exception(f"Error in handle_request: {e}")
                 time.sleep(1)
 
+    def check_and_run_weight_refit(self, message):
+        """
+        Check and trigger the weight refit process.
+        The received message is a dict that contains at least:
+            time_stamp: float, the time at which the weight refit was triggered.
+            cid: List[str], a list of content IDs (CIDs).
+            index_map: Dict[str, str], mapping weight name -> cid.
+        """
+        # step1. Check the weight refit trigger message
+        time_stamp = message.get("time_stamp", None)
+        cid = message.get("cid", None)
+        index_map = message.get("index_map", None)
+        if time_stamp is None or cid is None:
+            return
+        if self.last_refit_time >= time_stamp:
+            # Weights already updated
+            return
+
+        # step2. Download needed weight files from lattica
+        download_cid_set = set()
+        layer_key_prefix = "model.layers"
+        for key in index_map:
+            is_needed = False
+            if self.block_start_index == 0:
+                if "embed_tokens" in key:
+                    is_needed = True
+            elif self.block_end_index == self.hidden_layers:
+                if "embed_tokens" in key:
+                    is_needed = True
+                if "model.norm" in key:
+                    is_needed = True
+                if "lm_head" in key:
+                    is_needed = True
+            if layer_key_prefix in key:
+                try:
+                    parts = key.split(".")
+                    layer_idx = int(parts[2])
+                    if self.block_start_index <= layer_idx < self.block_end_index:
+                        is_needed = True
+                except (ValueError, IndexError):
+                    continue
+            if is_needed:
+                download_cid_set.add(index_map.get(key))
+
+        # step3. Save the downloaded weights to disk
+        weight_dir = os.path.join("/tmp", str(time_stamp))
+        os.makedirs(weight_dir, exist_ok=True)
+        while download_cid_set:
+            cid = download_cid_set.pop()
+            try:
+                logger.info(f"Start downloading refit weight {cid}")
+                raw_data = self.lattica.get_block(cid)
+            except Exception:
+                try:
+                    providers = self.lattica.get_providers(cid)
+                    self.lattica.with_bootstraps(providers)
+                    download_cid_set.add(cid)
+                    continue
+                except Exception as e:
+                    raise RuntimeError(f"Failed to get block: {e}")
+            file_name = cid + ".safetensors"
+            file_name = os.path.join(weight_dir, file_name)
+            with open(file_name, "wb") as f:
+                f.write(raw_data)
+
+        # step4. Send IPC message to trigger the weight update
+        self.connection_handler.ipc_weight_refit(weight_dir)
+        self.last_refit_time = time_stamp
+
     def start_node_announcer(self):
         """Start a thread that regularly announces this module's presence on DHT"""
 
@@ -548,7 +630,11 @@ def _announcer_thread():
             # Announce the range ID
             try:
                 if self.scheduler_peer_id is not None:
-                    self.scheduler_stub.node_update(self.get_node_info(is_update=True))
+                    response = self.scheduler_stub.node_update(
+                        self.get_node_info(is_update=True)
+                    )
+                    if self.enable_weight_refit:
+                        self.check_and_run_weight_refit(response)
                 else:
                     self.lattica.store(
                         key=self.prefix_id,
@@ -661,6 +747,7 @@ def launch_p2p_server(
     recv_from_peer_addr: str,
     send_to_peer_addr: str,
     model_name: Optional[str],
+    enable_weight_refit: Optional[bool] = False,
    max_batch_size: Optional[int] = None,
    max_sequence_length: Optional[int] = None,
 ):
@@ -687,6 +774,7 @@ def launch_p2p_server(
         http_port=http_port,
         notify_url=notify_url,
         model_name=model_name,
+        enable_weight_refit=enable_weight_refit,
         max_batch_size=max_batch_size,
         max_sequence_length=max_sequence_length,
     )
diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py
index dd0b819e..1487b030 100644
--- a/src/parallax/server/executor.py
+++ b/src/parallax/server/executor.py
@@ -19,7 +19,7 @@
 import argparse
 import time
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import mlx.core as mx
 import torch
@@ -86,6 +86,8 @@ def __init__(
         kv_block_size: int = 64,
         kv_cache_memory_fraction: float = 0.8,
         enable_prefix_cache: Optional[bool] = False,
+        # Weight Refit
+        enable_weight_refit: Optional[bool] = False,
         # Communication Configs
         # P2P Communication Configs
         send_to_peer_addr: Optional[str] = None,
@@ -162,6 +164,7 @@ def __init__(
         self.qk_nope_head_dim = self.config.get("qk_nope_head_dim", None)
         self.qk_rope_head_dim = self.config.get("qk_rope_head_dim", None)
         self.enable_prefix_cache = enable_prefix_cache
+        self.enable_weight_refit = enable_weight_refit
         self.linear_key_head_dim = self.config.get("linear_key_head_dim", None)
         self.linear_value_head_dim = self.config.get("linear_value_head_dim", None)
         self.linear_conv_kernel_dim = self.config.get("linear_conv_kernel_dim", None)
@@ -279,8 +282,21 @@ def create_from_args(cls, args: argparse.Namespace):
         """Create executor from command line arguments."""
         return cls(**create_executor_config(args))
 
+    def check_and_refit_weight(self, refit_weight_path: str):
+        if refit_weight_path == "":
+            return
+        if self.device == "cuda":
+            from parallax.sglang.model_runner import refit_sgl_model
+
+            refit_sgl_model(self.model_runner, refit_weight_path)
+        else:
+            self.shard_loader.update_weight_from_disk(self.model_shard, refit_weight_path)
+
     def recv_requests_from_http(self) -> List[Request]:
-        """Receives requests from http frontend"""
+        """
+        Receives requests from http frontend.
+        Also receives refit requests for weight update.
+        """
         recv_reqs = []
         while True:
             try:
@@ -304,9 +320,10 @@
                 logger.debug(f"Received {len(recv_reqs)} HTTP requests")
         return recv_reqs
 
-    def recv_requests_from_peer(self) -> List[Request]:
+    def recv_requests_from_peer(self) -> Tuple[List[Request], str]:
         """Receives requests from the RPC server."""
         recv_reqs = []
+        refit_weight_path = ""
         while True:
             try:
                 recv_req = self.recv_from_peer_socket.recv_multipart(zmq.NOBLOCK)
@@ -326,6 +343,8 @@ def recv_requests_from_peer(self) -> Tuple[List[Request], str]:
                     abort_request.ParseFromString(recv_req[1])
                     recv_req = proto_to_abort_request(abort_request)
                     recv_reqs.extend(recv_req)
+                elif recv_req[0] == b"refit":
+                    refit_weight_path = recv_req[1].decode("ascii")
                 else:
                     raise ValueError(f"Unknown request type: {recv_req[0]}")
                 # First peer is responsible for tokenization
@@ -342,7 +361,7 @@ def recv_requests_from_peer(self) -> Tuple[List[Request], str]:
                 logger.exception(f"Error receiving or deserializing request: {e}")
         if recv_reqs:
             logger.debug(f"Received {len(recv_reqs)} peer requests")
-        return recv_reqs
+        return recv_reqs, refit_weight_path
 
     def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]:
         """
@@ -1105,12 +1124,14 @@ def run_loop(self):
         )
         while True:
             # 1. Ingest new requests from the http frontend
-            if self.is_first_peer:
+            if self.is_first_peer or self.enable_weight_refit:
                 http_requests = self.recv_requests_from_http()
                 self._handle_input_requests(http_requests)
 
             # 2. Ingest new requests from the RPC server
-            incoming_requests = self.recv_requests_from_peer()
+            incoming_requests, refit_weight_path = self.recv_requests_from_peer()
+            if self.enable_weight_refit:
+                self.check_and_refit_weight(refit_weight_path)
             self._handle_input_requests(incoming_requests)
 
             # 3. Send finished batch to next peer
@@ -1232,5 +1253,6 @@ def create_executor_config(args: argparse.Namespace):
         "executor_output_ipc_addr": args.executor_output_ipc,
         "attention_backend": args.attention_backend,
         "moe_runner_backend": args.moe_runner_backend,
+        "enable_weight_refit": args.enable_weight_refit,
     }
     return config
diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py
index 7b8eead8..9139e749 100644
--- a/src/parallax/server/server_args.py
+++ b/src/parallax/server/server_args.py
@@ -99,6 +99,11 @@ def parse_args() -> argparse.Namespace:
         "--enable-prefix-cache", action="store_true", help="Enable prefix cache reuse"
     )
 
+    # Weight refit configuration
+    parser.add_argument(
+        "--enable-weight-refit", action="store_true", help="Enable runtime weight refit"
+    )
+
     # Scheduler configuration
     parser.add_argument(
         "--max-batch-size",
diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py
index f15250d6..7b230bc2 100644
--- a/src/parallax/server/shard_loader.py
+++ b/src/parallax/server/shard_loader.py
@@ -46,6 +46,7 @@ def __init__(
         self.model_path_str = model_path_or_hf_repo
         self.start_layer = start_layer
         self.end_layer = end_layer
+        self.config = None
         self.register_block_class()
 
     def register_block_class(self):
@@ -95,10 +96,10 @@ def load(
             A tuple containing the loaded sharded MLX model and its configuration dictionary.
""" model_path, _ = get_model_path(self.model_path_str) - config = load_config(model_path) - tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) + self.config = load_config(model_path) + tokenizer = load_tokenizer(model_path, eos_token_ids=self.config.get("eos_token_id", None)) - architectures = config.get("architectures", None) + architectures = self.config.get("architectures", None) if architectures is None: raise ValueError("architectures not found in config.json") if len(architectures) != 1: @@ -108,13 +109,13 @@ def load( if block_class is None: raise ValueError(f"block_class not found for architecture: {architecture}") - num_hidden_layers = config.get("num_hidden_layers", 0) + num_hidden_layers = self.config.get("num_hidden_layers", 0) current_start_layer = self.start_layer if self.start_layer is not None else 0 current_end_layer = self.end_layer if self.end_layer is not None else num_hidden_layers # We need the model object to know its structure and which layers it owns. # This part mirrors the logic from the provided utils.py to get model_args. - model_type = config.get("model_type") + model_type = self.config.get("model_type") if model_type == "kimi_k2": model_type = "deepseek_v3" if not model_type: @@ -122,11 +123,11 @@ def load( try: arch_module = importlib.import_module(f"mlx_lm.models.{model_type}") model_args_class = getattr(arch_module, "ModelArgs") - model_args = model_args_class.from_dict(config) + model_args = model_args_class.from_dict(self.config) except (ImportError, AttributeError) as e: raise ValueError(f"Failed to load architecture for model_type '{model_type}'.") from e - dtype = getattr(mx, config.get("torch_dtype", "bfloat16")) + dtype = getattr(mx, self.config.get("torch_dtype", "bfloat16")) # Extract the base model name from model_id_original if it's a repo ID model_id = self.model_path_str @@ -170,7 +171,9 @@ def load( ): is_needed = True remapped_key = key.replace("model.", "", 1) - if model_shard.is_last_shard and config.get("tie_word_embeddings", False): + if model_shard.is_last_shard and self.config.get( + "tie_word_embeddings", False + ): shard_weights["lm_head.weight"] = mx.array(f.get_tensor(key)) elif model_shard.is_last_shard: if "model.norm" in key: @@ -180,7 +183,7 @@ def load( is_needed = True remapped_key = key elif ( - config.get("tie_word_embeddings", False) + self.config.get("tie_word_embeddings", False) and "embed" in key and key.startswith("model.embed_tokens") ): @@ -204,12 +207,12 @@ def load( if is_needed: shard_weights[remapped_key] = mx.array(f.get_tensor(key)) - if (quantization := config.get("quantization", None)) is not None: + if (quantization := self.config.get("quantization", None)) is not None: logger.info("Model is quantized. Applying quantization parameters...") def class_predicate(p, m): # Handle custom per-layer quantizations from the config - qcfg = config.get("quantization", {}) + qcfg = self.config.get("quantization", {}) # Direct key (Parallax remapped keys usually drop the 'model.' 
                 if p in qcfg:
                     override = qcfg[p]
                     if isinstance(override, dict):
                         logger.debug(
                             f"[quantize] Using override for '{p}': bits={override.get('bits')} group_size={override.get('group_size')}"
                         )
                         return override
@@ -251,4 +254,110 @@ def class_predicate(p, m):
             current_end_layer,
             mx.get_active_memory() / 1024**3,
         )
-        return model_shard, config, tokenizer
+        return model_shard, self.config, tokenizer
+
+    def update_weight_from_disk(self, model_shard: nn.Module, refit_weight_path: str):
+        """Runtime weight refit from disk"""
+        weight_files = glob.glob(refit_weight_path + "/*.safetensors")
+        if not weight_files:
+            raise FileNotFoundError(f"No safetensors found in {refit_weight_path}")
+
+        logger.info(f"Begin refit weight from path: {refit_weight_path}")
+        shard_weights = {}
+        layer_key_prefix = "model.layers"  # Common prefix
+
+        for wf in weight_files:
+            # For bf16 models, we need torch tensors as a bridge
+            with safetensors.safe_open(wf, framework="pt") as f:
+                for key in f.keys():
+                    is_needed = False
+                    remapped_key = None
+
+                    # Check if the key belongs to the shard and remap it
+                    if (
+                        model_shard.is_first_shard
+                        and "embed_tokens" in key
+                        and key.startswith("model.")
+                    ):
+                        is_needed = True
+                        remapped_key = key.replace("model.", "", 1)
+                        if model_shard.is_last_shard and self.config.get(
+                            "tie_word_embeddings", False
+                        ):
+                            shard_weights["lm_head.weight"] = mx.array(f.get_tensor(key))
+                    elif model_shard.is_last_shard:
+                        if "model.norm" in key:
+                            is_needed = True
+                            remapped_key = key.replace("model.", "", 1)
+                        if "lm_head" in key:
+                            is_needed = True
+                            remapped_key = key
+                        elif (
+                            self.config.get("tie_word_embeddings", False)
+                            and "embed" in key
+                            and key.startswith("model.embed_tokens")
+                        ):
+                            # TODO: we don't need to load lm_head in this case,
+                            # as we will pass hidden_states to FirstPeer;
+                            # see request.py for details
+                            is_needed = True
+                            remapped_key = "lm_head.weight"
+                    if layer_key_prefix in key:
+                        try:
+                            parts = key.split(".")
+                            layer_idx = int(parts[2])
+                            if self.start_layer <= layer_idx < self.end_layer:
+                                is_needed = True
+                                local_layer_idx = layer_idx - self.start_layer
+                                remapped_key = f"layers.{local_layer_idx}.{'.'.join(parts[3:])}"
+                        except (ValueError, IndexError):
+                            continue
+
+                    # If the key is needed, load only that tensor from the file
+                    if is_needed:
+                        shard_weights[remapped_key] = mx.array(f.get_tensor(key))
+
+        if (quantization := self.config.get("quantization", None)) is not None:
+            logger.info("Model is quantized. Applying quantization parameters...")
+
+            def class_predicate(p, m):
+                # Handle custom per-layer quantizations from the config
+                qcfg = self.config.get("quantization", {})
+                # Direct key (Parallax remapped keys usually drop the 'model.' prefix)
+                if p in qcfg:
+                    override = qcfg[p]
+                    if isinstance(override, dict):
+                        logger.debug(
+                            f"[quantize] Using override for '{p}': bits={override.get('bits')} group_size={override.get('group_size')}"
+                        )
+                        return override
+                # Allow config keys that still include the original 'model.' prefix (as in mlx-lm)
+                prefixed = f"model.{p}"
+                if prefixed in qcfg:
+                    override = qcfg[prefixed]
+                    if isinstance(override, dict):
+                        logger.debug(
+                            f"[quantize] Using override for '{prefixed}' (mapped to '{p}'): bits={override.get('bits')} group_size={override.get('group_size')}"
+                        )
+                        return override
+                if not hasattr(m, "to_quantized"):
+                    return False
+                # Handle legacy models by checking if quantized weights exist
+                return f"{p}.scales" in shard_weights
+
+            nn.quantize(
+                model_shard,
+                group_size=quantization["group_size"],
+                bits=quantization["bits"],
+                mode=quantization.get("mode", "affine"),
+                class_predicate=class_predicate,
+            )
+
+        model_shard.load_weights(list(shard_weights.items()), strict=False)
+        mx.eval(model_shard.parameters())
+        model_shard.eval()
+        logger.info(
+            "Successfully updated model shard from %s, memory usage: %.3f GB",
+            refit_weight_path,
+            mx.get_active_memory() / 1024**3,
+        )
diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index 1e0b4d9f..2f9e9ecb 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -574,9 +574,9 @@ def initialize_sgl_model_runner(
     model_config.hf_config.tie_word_embeddings = False
     model_config.hf_config.start_layer = start_layer
     model_config.hf_config.end_layer = end_layer
-    print("Model config:", model_config)
-    print("model_start_layer:", model_config.hf_config.start_layer)
-    print("model_end_layer:", model_config.hf_config.end_layer)
+    logger.debug(f"SGLang Model config: {model_config}")
+    logger.debug(f"SGLang model_start_layer: {model_config.hf_config.start_layer}")
+    logger.debug(f"SGLang model_end_layer: {model_config.hf_config.end_layer}")
     model_runner = ParallaxModelRunner(
         model_config=model_config,
         mem_fraction_static=kv_cache_memory_fraction,
@@ -593,3 +593,12 @@ def initialize_sgl_model_runner(
         pp_end_layer=end_layer,
     )
     return model_runner, config, tokenizer
+
+
+def refit_sgl_model(
+    model_runner: ParallaxModelRunner,
+    refit_weight_path: str,
+):
+    """Runtime weight refit from disk"""
+    logger.info(f"Begin refit weight from path: {refit_weight_path}")
+    model_runner.update_weights_from_disk(model_path=refit_weight_path, load_format="auto")
diff --git a/src/scheduling/scheduler.py b/src/scheduling/scheduler.py
index 6d143c10..a2f9b18c 100644
--- a/src/scheduling/scheduler.py
+++ b/src/scheduling/scheduler.py
@@ -105,6 +105,10 @@ def __init__(
         )
         self._node_assigned_request_count: Dict[str, int] = {}
 
+        # Weight refit
+        self.refit_request = {}
+        self.refit_set = set()
+
         # Eager bootstrap for initial allocation if enough nodes are present
         try:
             if len(self.nodes) >= self.min_nodes_bootstrapping:
@@ -519,3 +523,6 @@ def stop(self) -> None:
         self._wake_event.set()
         with self._node_count_cv:
             self._node_count_cv.notify_all()
+
+    def trigger_weight_refit(self):
+        """Trigger weight refit event for all nodes"""
diff --git a/tests/test_server_args.py b/tests/test_server_args.py
index deaf5f42..701e340d 100644
--- a/tests/test_server_args.py
+++ b/tests/test_server_args.py
@@ -84,6 +84,7 @@ def test_create_config(self):
             micro_batch_ratio=2,
             scheduler_wait_ms=500,
             enable_prefix_cache=True,
+            enable_weight_refit=True,
             executor_input_ipc="///ipc/1",
             executor_output_ipc="///ipc/2",
             attention_backend="torch_native",
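Not part of the patch: a self-contained sketch of the two-frame ZeroMQ message that carries a refit request from the P2P server to the executor. ipc_weight_refit() in src/parallax/p2p/server.py sends [b"refit", <weight_dir>], and recv_requests_from_peer() in src/parallax/server/executor.py matches on the b"refit" tag before check_and_refit_weight() dispatches to refit_sgl_model() (CUDA) or the shard loader's update_weight_from_disk() (MLX). The socket types and IPC endpoint below are illustrative assumptions, not the addresses used by Parallax.

import zmq

ctx = zmq.Context.instance()
sender = ctx.socket(zmq.PUSH)    # stands in for the P2P server side
receiver = ctx.socket(zmq.PULL)  # stands in for the executor side
receiver.bind("ipc:///tmp/parallax_refit_demo")
sender.connect("ipc:///tmp/parallax_refit_demo")

# The P2P server writes downloaded safetensors into /tmp/<time_stamp> and
# forwards that directory; the value here is a placeholder.
refit_weight_dir = "/tmp/1700000000.0"
sender.send_multipart([b"refit", refit_weight_dir.encode("ascii")])

frames = receiver.recv_multipart()
if frames[0] == b"refit":
    refit_weight_path = frames[1].decode("ascii")
    # The executor would now call check_and_refit_weight(refit_weight_path).
    print("weight refit requested from", refit_weight_path)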