Skip to content
52 changes: 45 additions & 7 deletions src/parallax/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,21 @@
--end-layer 28
"""

import argparse
import multiprocessing
import os
import tempfile
import threading

from parallax.p2p.server import ServerState, launch_p2p_server
from parallax.server.executor import Executor
from parallax.server.executor import (
Executor,
run_executor_process,
stop_executor_process,
)
from parallax.server.http_server import launch_http_server, stop_http_server
from parallax.server.server_args import parse_args
from parallax.utils.utils import fetch_model_from_hf, initialize_nccl_port
from parallax_utils.ascii_anime import display_parallax_join
from parallax_utils.logging_config import get_logger, set_log_level
from parallax_utils.version_check import check_latest_release
Expand All @@ -35,6 +41,7 @@
gradient_server = None
http_server_process = None
executor = None
executor_procs = []
try:
args = parse_args()
set_log_level(args.log_level)
Expand All @@ -43,28 +50,31 @@
args.send_to_peer_addr = f"ipc://{tempfile.NamedTemporaryFile().name}"
args.executor_input_ipc = f"ipc://{tempfile.NamedTemporaryFile().name}"
args.executor_output_ipc = f"ipc://{tempfile.NamedTemporaryFile().name}"
if args.nccl_port is None:
args.nccl_port = initialize_nccl_port()

# Silence tokenizer warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger.debug(f"executor_input_addr: {args.executor_input_ipc}")
logger.debug(f"executor_output_addr: {args.executor_output_ipc}")
logger.debug(f"nccl_port: {args.nccl_port}")
if args.scheduler_addr is None:
if args.log_level != "DEBUG":
display_parallax_join(args.model_path)
check_latest_release()

config = fetch_model_from_hf(args.model_path)
# only launch http server on head node
if args.start_layer == 0:
http_server_process = launch_http_server(args)
executor = Executor.create_from_args(args)
launch_p2p_server(
initial_peers=args.initial_peers,
scheduler_addr=args.scheduler_addr,
relay_servers=args.relay_servers,
pp_start_layer=args.start_layer,
pp_end_layer=args.end_layer,
hidden_layers=executor.config.get("num_hidden_layers"),
hidden_layers=config.get("num_hidden_layers"),
tcp_port=args.tcp_port,
udp_port=args.udp_port,
dht_prefix=args.dht_prefix,
Expand All @@ -81,7 +91,18 @@
)
if gradient_server is not None:
gradient_server.status = ServerState.READY
executor.run_loop()
tp_rank_range = range(args.tp_size)
for tp_rank in tp_rank_range:
args_copy = argparse.Namespace(**vars(args))
args_copy.tp_rank = tp_rank
proc = multiprocessing.Process(
target=run_executor_process,
args=(args_copy,),
)
proc.start()
executor_procs.append(proc)
for executor_process in executor_procs:
executor_process.join()
else:
gradient_server = launch_p2p_server(
initial_peers=args.initial_peers,
Expand All @@ -107,7 +128,10 @@
args.start_layer = gradient_server.block_start_index
args.end_layer = gradient_server.block_end_index
args.model_path = gradient_server.model_name
# Hard code for mlx-community models
# TODO: Implement inter-process communication to enable TP.
# For scheduler mode, currently only support tp_rank=0
args.tp_rank = 0

logger.debug(
f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}"
)
Expand Down Expand Up @@ -174,13 +198,27 @@
except Exception as e:
logger.exception(e)
finally:
t = None
thread_pool = []

# Shutdown http server
if http_server_process is not None:
t = threading.Thread(target=stop_http_server, args=(http_server_process,))
t.start()
thread_pool.append(t)

# Shutdown gradient server
if gradient_server is not None:
gradient_server.shutdown()

# Shutdown executor subprocess for scheduler mode
for executor_process in executor_procs:
t = threading.Thread(target=stop_executor_process, args=(executor_process,))
t.start()
thread_pool.append(t)

# Shutdown executor main process for non-scheduler mode
if executor is not None:
executor.shutdown()
if t is not None:

for t in thread_pool:
t.join()
Loading