Skip to content

Commit b01501a

Browse files
authored
[Feat] Toy proxy now supports PD-mixed round-robin scheduling (#316)
* [Feat] Toy proxy now supports PD-mixed round-robin scheduling * [Docs] modify the path of toy_proxy_server.py
1 parent 1827635 commit b01501a

File tree

4 files changed

+172
-76
lines changed

4 files changed

+172
-76
lines changed

docs/source/user-guide/pd-disaggregation/1p1d.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ CUDA_VISIBLE_DEVICES=1 vllm serve /home/models/Qwen2.5-7B-Instruct \
6868
### Run proxy server
6969
Make sure prefill nodes and decode nodes can connect to each other.
7070
```bash
71-
cd vllm-workspace/unified-cache-management/test/
72-
python3 toy_proxy_server.py --host localhost --port 7802 --prefiller-host <prefill-node-ip> --prefiller-port 7800 --decoder-host <decode-node-ip> --decoder-port 7801
71+
cd vllm-workspace/unified-cache-management/ucm/pd
72+
python3 toy_proxy_server.py --pd-disaggregation --host localhost --port 7802 --prefiller-host <prefill-node-ip> --prefiller-port 7800 --decoder-host <decode-node-ip> --decoder-port 7801
7373
```
7474

7575
## Testing and Benchmarking

docs/source/user-guide/pd-disaggregation/npgd.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ CUDA_VISIBLE_DEVICES=0 vllm serve /home/models/Qwen2.5-7B-Instruct \
7777
### Run proxy server
7878
Make sure prefill nodes and decode nodes can connect to each other.
7979
```bash
80-
cd vllm-workspace/unified-cache-management/test/
80+
cd vllm-workspace/unified-cache-management/ucm/pd
8181
python3 toy_proxy_server.py --pd-disaggregation --host localhost --port 7802 --prefiller-host <prefill-node-ip> --prefiller-port 7800 --decoder-host <decode-node-ip> --decoder-port 7801
8282
```
8383

docs/source/user-guide/pd-disaggregation/xpyd.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ CUDA_VISIBLE_DEVICES=3 vllm serve /home/models/Qwen2.5-7B-Instruct \
121121
### Run proxy server
122122
Make sure prefill nodes and decode nodes can connect to each other. The number of prefill/decode hosts should be equal to the number of prefill/decode ports.
123123
```bash
124-
cd vllm-workspace/unified-cache-management/test/
125-
python3 toy_proxy_server.py --host localhost --port 7805 --prefiller-hosts <prefill-node-ip-1> <prefill-node-ip-2> --prefiller-port 7800 7801 --decoder-hosts <decoder-node-ip-1> <decoder-node-ip-2> --decoder-ports 7802 7803
124+
cd vllm-workspace/unified-cache-management/ucm/pd
125+
python3 toy_proxy_server.py --pd-disaggregation --host localhost --port 7805 --prefiller-hosts <prefill-node-ip-1> <prefill-node-ip-2> --prefiller-ports 7800 7801 --decoder-hosts <decoder-node-ip-1> <decoder-node-ip-2> --decoder-ports 7802 7803
126126
```
127127

128128
## Testing and Benchmarking

ucm/pd/toy_proxy_server.py

Lines changed: 167 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -17,53 +17,78 @@
1717
@asynccontextmanager
1818
async def lifespan(app: FastAPI):
1919
"""
20-
Lifespan context manager to handle startup and shutdown events.
20+
Lifespan context manager to initialize clients based on mode.
2121
"""
22-
# Startup: Initialize client pools for prefiller and decoder services
2322
app.state.prefill_clients = []
2423
app.state.decode_clients = []
25-
26-
# Create prefill clients
27-
for i, (host, port) in enumerate(global_args.prefiller_instances):
28-
prefiller_base_url = f"http://{host}:{port}/v1"
29-
app.state.prefill_clients.append(
30-
{
31-
"client": httpx.AsyncClient(timeout=None, base_url=prefiller_base_url),
32-
"host": host,
33-
"port": port,
34-
"id": i,
35-
}
24+
app.state.worker_clients = [] # For PD-mixed workers
25+
26+
if global_args.pd_disaggregation:
27+
# === PD disaggregation ===
28+
for i, (host, port) in enumerate(global_args.prefiller_instances):
29+
base_url = f"http://{host}:{port}/v1"
30+
app.state.prefill_clients.append(
31+
{
32+
"client": httpx.AsyncClient(timeout=None, base_url=base_url),
33+
"host": host,
34+
"port": port,
35+
"id": i,
36+
}
37+
)
38+
39+
for i, (host, port) in enumerate(global_args.decoder_instances):
40+
base_url = f"http://{host}:{port}/v1"
41+
app.state.decode_clients.append(
42+
{
43+
"client": httpx.AsyncClient(timeout=None, base_url=base_url),
44+
"host": host,
45+
"port": port,
46+
"id": i,
47+
}
48+
)
49+
50+
app.state.prefill_iterator = itertools.cycle(
51+
range(len(app.state.prefill_clients))
3652
)
37-
38-
# Create decode clients
39-
for i, (host, port) in enumerate(global_args.decoder_instances):
40-
decoder_base_url = f"http://{host}:{port}/v1"
41-
app.state.decode_clients.append(
42-
{
43-
"client": httpx.AsyncClient(timeout=None, base_url=decoder_base_url),
44-
"host": host,
45-
"port": port,
46-
"id": i,
47-
}
53+
app.state.decode_iterator = itertools.cycle(
54+
range(len(app.state.decode_clients))
4855
)
4956

50-
# Initialize round-robin iterators
51-
app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients)))
52-
app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))
57+
print(
58+
f"[PD Mode] Initialized {len(app.state.prefill_clients)} prefillers "
59+
f"and {len(app.state.decode_clients)} decoders."
60+
)
5361

54-
print(
55-
f"Initialized {len(app.state.prefill_clients)} prefill clients "
56-
f"and {len(app.state.decode_clients)} decode clients."
57-
)
62+
else:
63+
# === PD mix ===
64+
for i, (host, port) in enumerate(global_args.worker_instances):
65+
base_url = f"http://{host}:{port}/v1"
66+
app.state.worker_clients.append(
67+
{
68+
"client": httpx.AsyncClient(timeout=None, base_url=base_url),
69+
"host": host,
70+
"port": port,
71+
"id": i,
72+
}
73+
)
74+
75+
app.state.worker_iterator = itertools.cycle(
76+
range(len(app.state.worker_clients))
77+
)
78+
print(
79+
f"[Mixed Mode] Initialized {len(app.state.worker_clients)} PD-mixed workers."
80+
)
5881

5982
yield
6083

61-
# Shutdown: Close all clients
62-
for client_info in app.state.prefill_clients:
63-
await client_info["client"].aclose()
64-
65-
for client_info in app.state.decode_clients:
66-
await client_info["client"].aclose()
84+
# Close all clients
85+
for client_list in [
86+
app.state.prefill_clients,
87+
app.state.decode_clients,
88+
app.state.worker_clients,
89+
]:
90+
for client_info in client_list:
91+
await client_info["client"].aclose()
6792

6893

6994
# Update FastAPI app initialization to use lifespan
@@ -75,6 +100,26 @@ def parse_args():
75100

76101
parser.add_argument("--port", type=int, default=8000)
77102
parser.add_argument("--host", type=str, default="localhost")
103+
parser.add_argument(
104+
"--pd-disaggregation",
105+
action="store_true",
106+
help="Enable PD disaggregation mode (prefill and decode separation)",
107+
)
108+
# For PD mix instances
109+
parser.add_argument(
110+
"--worker-hosts",
111+
"--work-host",
112+
type=str,
113+
nargs="+",
114+
default=["localhost"],
115+
)
116+
parser.add_argument(
117+
"--worker-ports",
118+
"--work-port",
119+
type=int,
120+
nargs="+",
121+
default=[8100],
122+
)
78123

79124
# For prefiller instances
80125
parser.add_argument(
@@ -107,9 +152,15 @@ def parse_args():
107152
if len(args.decoder_hosts) != len(args.decoder_ports):
108153
raise ValueError("Number of decoder hosts must match number of decoder ports")
109154

110-
# Create tuples of (host, port) for each service type
155+
if len(args.worker_hosts) != len(args.worker_ports):
156+
raise ValueError("Number of worker hosts must match number of worker ports")
157+
158+
# Create instance tuples
111159
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
112160
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
161+
args.worker_instances = list(
162+
zip(args.worker_hosts, args.worker_ports)
163+
) # Mixed workers
113164

114165
return args
115166

@@ -120,12 +171,15 @@ def get_next_client(app, service_type: str):
120171
121172
Args:
122173
app: The FastAPI app instance
123-
service_type: Either 'prefill' or 'decode'
174+
service_type: 'worker', 'prefill', or 'decode'
124175
125176
Returns:
126177
The next client to use
127178
"""
128-
if service_type == "prefill":
179+
if service_type == "worker":
180+
worker_idx = next(app.state.worker_iterator)
181+
return app.state.worker_clients[worker_idx]
182+
elif service_type == "prefill":
129183
client_idx = next(app.state.prefill_iterator)
130184
return app.state.prefill_clients[client_idx]
131185
elif service_type == "decode":
@@ -183,37 +237,72 @@ async def _handle_completions(api: str, request: Request):
183237
req_data = await request.json()
184238
request_id = str(uuid.uuid4())
185239

186-
# Get the next prefill client in round-robin fashion
187-
prefill_client_info = get_next_client(request.app, "prefill")
188-
189-
# Send request to prefill service
190-
response = await send_request_to_service(
191-
prefill_client_info, api, req_data, request_id
192-
)
193-
194-
# Extract the needed fields
195-
response_json = response.json()
196-
197-
# Get the next decode client in round-robin fashion
198-
decode_client_info = get_next_client(request.app, "decode")
199-
200-
logger.debug("Using %s %s", prefill_client_info, decode_client_info)
201-
202-
# Stream response from decode service
203-
async def generate_stream():
204-
async for chunk in stream_service_response(
205-
decode_client_info, api, req_data, request_id=request_id
206-
):
207-
yield chunk
208-
209-
return StreamingResponse(generate_stream(), media_type="application/json")
240+
headers = {
241+
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
242+
"X-Request-Id": request_id,
243+
}
244+
245+
if global_args.pd_disaggregation:
246+
# === PD disaggregation logic ===
247+
248+
# Step 1: Send request to prefiller (to trigger computation and cache KV)
249+
prefill_client_info = get_next_client(request.app, "prefill")
250+
prefill_req_data = req_data.copy()
251+
prefill_req_data["stream"] = False
252+
prefill_req_data["max_tokens"] = 1
253+
if "stream_options" in prefill_req_data:
254+
del prefill_req_data["stream_options"]
255+
256+
response = await prefill_client_info["client"].post(
257+
api, json=prefill_req_data, headers=headers
258+
)
259+
response.raise_for_status()
260+
261+
# Step 2: Stream full output from decoder
262+
decode_client_info = get_next_client(request.app, "decode")
263+
264+
logger.debug(
265+
"PD-DISAGG: Prefill=%s:%d, Decode=%s:%d",
266+
prefill_client_info["host"],
267+
prefill_client_info["port"],
268+
decode_client_info["host"],
269+
decode_client_info["port"],
270+
)
271+
272+
async def generate_stream():
273+
async for chunk in stream_service_response(
274+
decode_client_info, api, req_data, request_id
275+
):
276+
yield chunk
277+
278+
return StreamingResponse(generate_stream(), media_type="application/json")
279+
280+
else:
281+
# === PD mixed mode: Directly forward the entire stream using round-robin ===
282+
worker_client_info = get_next_client(request.app, "worker")
283+
284+
logger.debug(
285+
"PD-MIXED: Forwarding to %s:%d",
286+
worker_client_info["host"],
287+
worker_client_info["port"],
288+
)
289+
290+
async def generate_stream():
291+
async with worker_client_info["client"].stream(
292+
"POST", api, json=req_data, headers=headers
293+
) as resp:
294+
resp.raise_for_status()
295+
async for chunk in resp.aiter_bytes():
296+
yield chunk
297+
298+
return StreamingResponse(generate_stream(), media_type="application/json")
210299

211300
except Exception as e:
212301
import sys
213302
import traceback
214303

215304
exc_info = sys.exc_info()
216-
print("Error occurred in disagg prefill proxy server" f" - {api} endpoint")
305+
print(f"Error in proxy server - {api} endpoint")
217306
print(e)
218307
print("".join(traceback.format_exception(*exc_info)))
219308
raise
@@ -231,12 +320,19 @@ async def handle_chat_completions(request: Request):
231320

232321
@app.get("/healthcheck")
233322
async def healthcheck():
234-
"""Simple endpoint to check if the server is running."""
235-
return {
236-
"status": "ok",
237-
"prefill_instances": len(app.state.prefill_clients),
238-
"decode_instances": len(app.state.decode_clients),
239-
}
323+
if global_args.pd_disaggregation:
324+
return {
325+
"status": "ok",
326+
"mode": "pd-disaggregation",
327+
"prefill_instances": len(app.state.prefill_clients),
328+
"decode_instances": len(app.state.decode_clients),
329+
}
330+
else:
331+
return {
332+
"status": "ok",
333+
"mode": "pd-mixed",
334+
"worker_instances": len(app.state.worker_clients),
335+
}
240336

241337

242338
if __name__ == "__main__":

0 commit comments

Comments
 (0)