Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,27 @@ async def scheduler_init(raw_request: Request):
status_code=200,
)

@app.post("/weight/refit")
async def weight_refit(raw_request: Request):
    """
    HTTP entry point that asks the scheduler to trigger a weight refit.

    The JSON body of the request is forwarded unchanged to
    ``scheduler_manage.weight_refit``.

    Returns:
        JSONResponse with status 200 and ``data: None`` when the scheduler
        accepted the request, or status 500 with an error string when it
        reported that it is not ready.
    """
    request_data = await raw_request.json()
    status = scheduler_manage.weight_refit(request_data)
    if status:
        return JSONResponse(
            content={
                "type": "weight_refit",
                "data": None,
            },
            status_code=200,
        )
    # A falsy status is reported to the caller as "server not ready".
    return JSONResponse(
        content={
            "type": "weight_refit",
            "data": "Server not ready",
        },
        status_code=500,
    )


@app.get("/node/join/command")
async def node_join_command():
Expand Down
7 changes: 6 additions & 1 deletion src/backend/server/rpc_connection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ def node_update(self, message):
new_rtt_to_nodes=node.rtt_to_nodes,
is_active=node.is_active,
)
return {}
res = {}
if self.scheduler.refit_request:
if node.node_id not in self.scheduler.refit_set:
res = self.scheduler.refit_request
self.scheduler.refit_set.add(node.node_id)
return res
except Exception as e:
logger.exception(f"node_update error: {e}")
return {}
Expand Down
11 changes: 11 additions & 0 deletions src/backend/server/scheduler_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
self.stubs = {}
self.is_local_network = False


def run(self, model_name, init_nodes_num, is_local_network=True):
"""
Start the scheduler and the P2P service for RPC handling.
Expand All @@ -70,6 +71,16 @@ def run(self, model_name, init_nodes_num, is_local_network=True):
block_end_index=1,
)

def weight_refit(self, request_data):
    """
    Trigger weight refit on every node.

    Stores ``request_data`` on the scheduler as the pending refit request
    and replaces the scheduler's ``refit_set`` with an empty set.

    Returns:
        True when the request was recorded, False when no scheduler has
        been created yet.
    """
    scheduler = self.scheduler
    is_ready = scheduler is not None
    if is_ready:
        # Publish the request first, then start from an empty refit set.
        scheduler.refit_request = request_data
        scheduler.refit_set = set()
    return is_ready

def is_running(self):
"""
Returns True if the scheduler is running, False otherwise.
Expand Down
1 change: 1 addition & 0 deletions src/parallax/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
recv_from_peer_addr=args.recv_from_peer_addr,
send_to_peer_addr=args.send_to_peer_addr,
model_name=args.model_path,
enable_weight_refit=args.enable_weight_refit,
max_batch_size=args.max_batch_size,
max_sequence_length=args.max_sequence_length,
)
Expand Down
90 changes: 89 additions & 1 deletion src/parallax/p2p/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import enum
import json
import logging
import os
import threading
import time
from typing import List, Optional
Expand Down Expand Up @@ -156,6 +157,16 @@ def rpc_abort(
logger.exception(f"Error in rpc_abort: {e}")
return forward_pb2.AbortResponse()

def ipc_weight_refit(
    self,
    refit_weight_path: str,
):
    """
    Send a two-frame ``refit`` message carrying the weight directory path.

    The path is ASCII-encoded and sent on ``self.recv_from_peer`` while
    holding ``self._recv_from_peer_lock``. Any failure (including an
    encoding error) is logged and swallowed.
    """
    try:
        encoded_path = refit_weight_path.encode("ascii")
        frames = [b"refit", encoded_path]
        with self._recv_from_peer_lock:
            self.recv_from_peer.send_multipart(frames)
    except Exception as e:
        logger.exception(f"Error in ipc_weight_refit: {e}")

@rpc_stream_iter
def chat_completion(
self,
Expand Down Expand Up @@ -207,6 +218,7 @@ def __init__(
announce_maddrs: List[str] = [],
notify_url: str = None,
model_name: Optional[str] = None,
enable_weight_refit: Optional[bool] = False,
max_batch_size: Optional[int] = None,
max_sequence_length: Optional[int] = None,
):
Expand All @@ -224,6 +236,8 @@ def __init__(
self.http_port = http_port
self.notify_url = notify_url
self.model_name = model_name
self.enable_weight_refit = enable_weight_refit
self.last_refit_time = 0
self.max_batch_size = max_batch_size
self.max_sequence_length = max_sequence_length
self.prefix_id = f"{dht_prefix}_announce"
Expand Down Expand Up @@ -539,6 +553,74 @@ def group_requests_by_next_peer(requests: List[forward_pb2.Req]):
logger.exception(f"Error in handle_request: {e}")
time.sleep(1)

def check_and_run_weight_refit(self, message, weight_root: str = "/tmp"):
    """
    Check a scheduler response for a weight refit trigger and run it.

    Received message is a Dict which at least contains:
        time_stamp: float, indicating weight refit trigger time.
        cid: List[str], cid list.
        index_map: Dict[str, str], key(weight_name): value(cid)

    Only the weight files needed by this node's layer range
    ``[block_start_index, block_end_index)`` are downloaded. They are
    written to ``<weight_root>/<time_stamp>/<cid>.safetensors`` and the
    directory path is then sent through
    ``self.connection_handler.ipc_weight_refit``.

    Args:
        message: Response of the scheduler ``node_update`` call. An empty
            message, or one without a trigger, is ignored.
        weight_root: Directory under which the per-refit weight directory
            is created.

    Raises:
        RuntimeError: If a block cannot be downloaded, or its providers
            cannot be resolved.
    """
    # step1. Check weight refit trigger message
    if not message:
        return
    time_stamp = message.get("time_stamp", None)
    cid_list = message.get("cid", None)
    index_map = message.get("index_map", None)
    if time_stamp is None or cid_list is None:
        return
    if self.last_refit_time >= time_stamp:
        # Weight already updated
        return
    if index_map is None:
        # Without the name -> cid map we cannot tell which files are needed.
        logger.warning("Weight refit message has no index_map, skipping")
        return

    # step2. Select needed weight files
    layer_key_prefix = "model.layers"
    # These are two independent checks: a node that holds the whole model
    # is both the first and the last shard and needs both groups of weights.
    is_first_shard = self.block_start_index == 0
    is_last_shard = self.block_end_index == self.hidden_layers
    last_shard_tags = ("embed_tokens", "model.norm", "lm_head")
    download_cid_set = set()
    for key, weight_cid in index_map.items():
        is_needed = False
        if is_first_shard and "embed_tokens" in key:
            is_needed = True
        if is_last_shard and any(tag in key for tag in last_shard_tags):
            is_needed = True
        if layer_key_prefix in key:
            try:
                # Expected key layout: model.layers.<idx>.<...>
                layer_idx = int(key.split(".")[2])
            except (ValueError, IndexError):
                continue
            if self.block_start_index <= layer_idx < self.block_end_index:
                is_needed = True
        if is_needed:
            download_cid_set.add(weight_cid)

    # step3. Download from lattica and save weight to disk
    weight_dir = os.path.join(weight_root, str(time_stamp))
    os.makedirs(weight_dir, exist_ok=True)
    # Bounded retries so a block that never becomes available cannot block
    # the calling thread forever; the refit stays pending because
    # last_refit_time is only updated on success.
    max_attempts = 3
    for block_cid in download_cid_set:
        raw_data = None
        for attempt in range(1, max_attempts + 1):
            try:
                logger.info(f"Start downloading refit weight {block_cid}")
                raw_data = self.lattica.get_block(block_cid)
                break
            except Exception as download_error:
                if attempt == max_attempts:
                    raise RuntimeError(
                        f"Failed to get block {block_cid} after "
                        f"{max_attempts} attempts: {download_error}"
                    ) from download_error
                try:
                    # Add the block's providers as bootstraps, then retry.
                    providers = self.lattica.get_providers(block_cid)
                    self.lattica.with_bootstraps(providers)
                except Exception as e:
                    raise RuntimeError(f"Failed to get block: {e}") from e
        file_name = os.path.join(weight_dir, block_cid + ".safetensors")
        # NOTE(review): assumes get_block returns a bytes-like object - confirm.
        with open(file_name, "wb") as f:
            f.write(raw_data)

    # step4. send ipc message to update weight
    self.connection_handler.ipc_weight_refit(weight_dir)
    self.last_refit_time = time_stamp

def start_node_announcer(self):
"""Start a thread that regularly announces this module's presence on DHT"""

Expand All @@ -548,7 +630,11 @@ def _announcer_thread():
# Announce the range ID
try:
if self.scheduler_peer_id is not None:
self.scheduler_stub.node_update(self.get_node_info(is_update=True))
response = self.scheduler_stub.node_update(
self.get_node_info(is_update=True)
)
if self.enable_weight_refit:
self.check_and_run_weight_refit(response)
else:
self.lattica.store(
key=self.prefix_id,
Expand Down Expand Up @@ -661,6 +747,7 @@ def launch_p2p_server(
recv_from_peer_addr: str,
send_to_peer_addr: str,
model_name: Optional[str],
enable_weight_refit: Optional[bool] = False,
max_batch_size: Optional[int] = None,
max_sequence_length: Optional[int] = None,
):
Expand All @@ -687,6 +774,7 @@ def launch_p2p_server(
http_port=http_port,
notify_url=notify_url,
model_name=model_name,
enable_weight_refit=enable_weight_refit,
max_batch_size=max_batch_size,
max_sequence_length=max_sequence_length,
)
Expand Down
34 changes: 28 additions & 6 deletions src/parallax/server/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import argparse
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import mlx.core as mx
import torch
Expand Down Expand Up @@ -86,6 +86,8 @@ def __init__(
kv_block_size: int = 64,
kv_cache_memory_fraction: float = 0.8,
enable_prefix_cache: Optional[bool] = False,
# Weight Refit
enable_weight_refit: Optional[bool] = False,
# Communication Configs
# P2P Communication Configs
send_to_peer_addr: Optional[str] = None,
Expand Down Expand Up @@ -162,6 +164,7 @@ def __init__(
self.qk_nope_head_dim = self.config.get("qk_nope_head_dim", None)
self.qk_rope_head_dim = self.config.get("qk_rope_head_dim", None)
self.enable_prefix_cache = enable_prefix_cache
self.enable_weight_refit = enable_weight_refit
self.linear_key_head_dim = self.config.get("linear_key_head_dim", None)
self.linear_value_head_dim = self.config.get("linear_value_head_dim", None)
self.linear_conv_kernel_dim = self.config.get("linear_conv_kernel_dim", None)
Expand Down Expand Up @@ -279,8 +282,21 @@ def create_from_args(cls, args: argparse.Namespace):
"""Create executor from command line arguments."""
return cls(**create_executor_config(args))

def check_and_refit_weight(self, refit_weight_path: str):
    """
    Reload model weights from ``refit_weight_path`` if a refit was requested.

    An empty path means no refit is pending and nothing is done. On a
    ``cuda`` device the refit is delegated to ``refit_sgl_model``; on any
    other device the shard loader updates the model shard from disk.
    """
    if refit_weight_path == "":
        return
    if self.device != "cuda":
        self.shard_loader.update_weight_from_disk(self.model_shard, refit_weight_path)
        return

    # Imported lazily so the sglang runner is only loaded on the cuda path.
    from parallax.sglang.model_runner import refit_sgl_model

    refit_sgl_model(self.model_runner, refit_weight_path)

def recv_requests_from_http(self) -> List[Request]:
"""Receives requests from http frontend"""
"""
Receives requests from http frontend.
Also receives refit requests for weight update.
"""
recv_reqs = []
while True:
try:
Expand All @@ -304,9 +320,10 @@ def recv_requests_from_http(self) -> List[Request]:
logger.debug(f"Received {len(recv_reqs)} HTTP requests")
return recv_reqs

def recv_requests_from_peer(self) -> List[Request]:
def recv_requests_from_peer(self) -> Tuple[List[Request], str]:
"""Receives requests from the RPC server."""
recv_reqs = []
refit_weight_path = ""
while True:
try:
recv_req = self.recv_from_peer_socket.recv_multipart(zmq.NOBLOCK)
Expand All @@ -326,6 +343,8 @@ def recv_requests_from_peer(self) -> List[Request]:
abort_request.ParseFromString(recv_req[1])
recv_req = proto_to_abort_request(abort_request)
recv_reqs.extend(recv_req)
elif recv_req[0] == b"refit":
refit_weight_path = recv_req[1].decode("ascii")
else:
raise ValueError(f"Unknown request type: {recv_req[0]}")
# First peer is responsible for tokenization
Expand All @@ -342,7 +361,7 @@ def recv_requests_from_peer(self) -> List[Request]:
logger.exception(f"Error receiving or deserializing request: {e}")
if recv_reqs:
logger.debug(f"Received {len(recv_reqs)} peer requests")
return recv_reqs
return recv_reqs, refit_weight_path

def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -1105,12 +1124,14 @@ def run_loop(self):
)
while True:
# 1. Ingest new requests from the http frontend
if self.is_first_peer:
if self.is_first_peer or self.enable_weight_refit:
http_requests = self.recv_requests_from_http()
self._handle_input_requests(http_requests)

# 2. Ingest new requests from the RPC server
incoming_requests = self.recv_requests_from_peer()
incoming_requests, refit_weight_path = self.recv_requests_from_peer()
if self.enable_weight_refit:
self.check_and_refit_weight(refit_weight_path)
self._handle_input_requests(incoming_requests)

# 3. Send finished batch to next peer
Expand Down Expand Up @@ -1232,5 +1253,6 @@ def create_executor_config(args: argparse.Namespace):
"executor_output_ipc_addr": args.executor_output_ipc,
"attention_backend": args.attention_backend,
"moe_runner_backend": args.moe_runner_backend,
"enable_weight_refit": args.enable_weight_refit,
}
return config
5 changes: 5 additions & 0 deletions src/parallax/server/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def parse_args() -> argparse.Namespace:
"--enable-prefix-cache", action="store_true", help="Enable prefix cache reuse"
)

# Weight refit configuration
parser.add_argument(
"--enable-weight-refit", action="store_true", help="Enable runtime weight refit"
)

# Scheduler configuration
parser.add_argument(
"--max-batch-size",
Expand Down
Loading