 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """
-    Lifespan context manager to handle startup and shutdown events.
+    Lifespan context manager to initialize clients based on mode.
     """
-    # Startup: Initialize client pools for prefiller and decoder services
     app.state.prefill_clients = []
     app.state.decode_clients = []
-
-    # Create prefill clients
-    for i, (host, port) in enumerate(global_args.prefiller_instances):
-        prefiller_base_url = f"http://{host}:{port}/v1"
-        app.state.prefill_clients.append(
-            {
-                "client": httpx.AsyncClient(timeout=None, base_url=prefiller_base_url),
-                "host": host,
-                "port": port,
-                "id": i,
-            }
+    app.state.worker_clients = []  # For PD-mixed workers
+
+    if global_args.pd_disaggregation:
+        # === PD disaggregation ===
+        for i, (host, port) in enumerate(global_args.prefiller_instances):
+            base_url = f"http://{host}:{port}/v1"
+            app.state.prefill_clients.append(
+                {
+                    "client": httpx.AsyncClient(timeout=None, base_url=base_url),
+                    "host": host,
+                    "port": port,
+                    "id": i,
+                }
+            )
+
+        for i, (host, port) in enumerate(global_args.decoder_instances):
+            base_url = f"http://{host}:{port}/v1"
+            app.state.decode_clients.append(
+                {
+                    "client": httpx.AsyncClient(timeout=None, base_url=base_url),
+                    "host": host,
+                    "port": port,
+                    "id": i,
+                }
+            )
+
+        app.state.prefill_iterator = itertools.cycle(
+            range(len(app.state.prefill_clients))
         )
-
-    # Create decode clients
-    for i, (host, port) in enumerate(global_args.decoder_instances):
-        decoder_base_url = f"http://{host}:{port}/v1"
-        app.state.decode_clients.append(
-            {
-                "client": httpx.AsyncClient(timeout=None, base_url=decoder_base_url),
-                "host": host,
-                "port": port,
-                "id": i,
-            }
+        app.state.decode_iterator = itertools.cycle(
+            range(len(app.state.decode_clients))
         )

-    # Initialize round-robin iterators
-    app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients)))
-    app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))
+        print(
+            f"[PD Mode] Initialized {len(app.state.prefill_clients)} prefillers "
+            f"and {len(app.state.decode_clients)} decoders."
+        )

-    print(
-        f"Initialized {len(app.state.prefill_clients)} prefill clients "
-        f"and {len(app.state.decode_clients)} decode clients."
-    )
+    else:
+        # === PD mix ===
+        for i, (host, port) in enumerate(global_args.worker_instances):
+            base_url = f"http://{host}:{port}/v1"
+            app.state.worker_clients.append(
+                {
+                    "client": httpx.AsyncClient(timeout=None, base_url=base_url),
+                    "host": host,
+                    "port": port,
+                    "id": i,
+                }
+            )
+
+        app.state.worker_iterator = itertools.cycle(
+            range(len(app.state.worker_clients))
+        )
+        print(
+            f"[Mixed Mode] Initialized {len(app.state.worker_clients)} PD-mixed workers."
+        )

     yield

-    # Shutdown: Close all clients
-    for client_info in app.state.prefill_clients:
-        await client_info["client"].aclose()
-
-    for client_info in app.state.decode_clients:
-        await client_info["client"].aclose()
+    # Close all clients
+    for client_list in [
+        app.state.prefill_clients,
+        app.state.decode_clients,
+        app.state.worker_clients,
+    ]:
+        for client_info in client_list:
+            await client_info["client"].aclose()


6994# Update FastAPI app initialization to use lifespan
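The updated construction itself is not shown in this patch; presumably it is the standard FastAPI lifespan hookup, roughly like the sketch below (the variable name app is assumed here):

    app = FastAPI(lifespan=lifespan)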
@@ -75,6 +100,26 @@ def parse_args():

     parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument(
+        "--pd-disaggregation",
+        action="store_true",
+        help="Enable PD disaggregation mode (prefill and decode separation)",
+    )
+    # For PD mix instances
+    parser.add_argument(
+        "--worker-hosts",
+        "--work-host",
+        type=str,
+        nargs="+",
+        default=["localhost"],
+    )
+    parser.add_argument(
+        "--worker-ports",
+        "--work-port",
+        type=int,
+        nargs="+",
+        default=[8100],
+    )

     # For prefiller instances
     parser.add_argument(
@@ -107,9 +152,15 @@ def parse_args():
     if len(args.decoder_hosts) != len(args.decoder_ports):
         raise ValueError("Number of decoder hosts must match number of decoder ports")

-    # Create tuples of (host, port) for each service type
+    if len(args.worker_hosts) != len(args.worker_ports):
+        raise ValueError("Number of worker hosts must match number of worker ports")
+
+    # Create instance tuples
     args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
     args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
+    args.worker_instances = list(
+        zip(args.worker_hosts, args.worker_ports)
+    )  # Mixed workers

     return args

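For intuition on the new paired flags: --worker-hosts and --worker-ports are matched positionally, so the n-th host is paired with the n-th port. A minimal standalone sketch of the resulting tuples (the addresses below are illustrative, not defaults from the patch):

    worker_hosts = ["10.0.0.1", "10.0.0.2"]
    worker_ports = [8100, 8101]
    worker_instances = list(zip(worker_hosts, worker_ports))
    print(worker_instances)  # [('10.0.0.1', 8100), ('10.0.0.2', 8101)]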
@@ -120,12 +171,15 @@ def get_next_client(app, service_type: str):

     Args:
         app: The FastAPI app instance
-        service_type: Either 'prefill' or 'decode'
+        service_type: One of 'worker', 'prefill', or 'decode'

     Returns:
         The next client to use
     """
-    if service_type == "prefill":
+    if service_type == "worker":
+        worker_idx = next(app.state.worker_iterator)
+        return app.state.worker_clients[worker_idx]
+    elif service_type == "prefill":
         client_idx = next(app.state.prefill_iterator)
         return app.state.prefill_clients[client_idx]
     elif service_type == "decode":
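For reference, the round-robin selection above is just an itertools.cycle over client indices, so successive requests rotate evenly across the instances of a pool. A minimal standalone sketch of the same pattern (the client list here is illustrative):

    import itertools

    clients = [{"id": 0}, {"id": 1}, {"id": 2}]
    iterator = itertools.cycle(range(len(clients)))
    picked = [clients[next(iterator)]["id"] for _ in range(5)]
    print(picked)  # [0, 1, 2, 0, 1]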
@@ -183,37 +237,72 @@ async def _handle_completions(api: str, request: Request):
         req_data = await request.json()
         request_id = str(uuid.uuid4())

-        # Get the next prefill client in round-robin fashion
-        prefill_client_info = get_next_client(request.app, "prefill")
-
-        # Send request to prefill service
-        response = await send_request_to_service(
-            prefill_client_info, api, req_data, request_id
-        )
-
-        # Extract the needed fields
-        response_json = response.json()
-
-        # Get the next decode client in round-robin fashion
-        decode_client_info = get_next_client(request.app, "decode")
-
-        logger.debug("Using %s %s", prefill_client_info, decode_client_info)
-
-        # Stream response from decode service
-        async def generate_stream():
-            async for chunk in stream_service_response(
-                decode_client_info, api, req_data, request_id=request_id
-            ):
-                yield chunk
-
-        return StreamingResponse(generate_stream(), media_type="application/json")
+        headers = {
+            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+            "X-Request-Id": request_id,
+        }
+
+        if global_args.pd_disaggregation:
+            # === PD disaggregation logic ===
+
+            # Step 1: Send request to prefiller (to trigger computation and cache KV)
+            prefill_client_info = get_next_client(request.app, "prefill")
+            prefill_req_data = req_data.copy()
+            prefill_req_data["stream"] = False
+            prefill_req_data["max_tokens"] = 1
+            if "stream_options" in prefill_req_data:
+                del prefill_req_data["stream_options"]
+
+            response = await prefill_client_info["client"].post(
+                api, json=prefill_req_data, headers=headers
+            )
+            response.raise_for_status()
+
+            # Step 2: Stream full output from decoder
+            decode_client_info = get_next_client(request.app, "decode")
+
+            logger.debug(
+                "PD-DISAGG: Prefill=%s:%d, Decode=%s:%d",
+                prefill_client_info["host"],
+                prefill_client_info["port"],
+                decode_client_info["host"],
+                decode_client_info["port"],
+            )
+
+            async def generate_stream():
+                async for chunk in stream_service_response(
+                    decode_client_info, api, req_data, request_id
+                ):
+                    yield chunk
+
+            return StreamingResponse(generate_stream(), media_type="application/json")
+
+        else:
+            # === PD mixed mode: Directly forward the entire stream using round-robin ===
+            worker_client_info = get_next_client(request.app, "worker")
+
+            logger.debug(
+                "PD-MIXED: Forwarding to %s:%d",
+                worker_client_info["host"],
+                worker_client_info["port"],
+            )
+
+            async def generate_stream():
+                async with worker_client_info["client"].stream(
+                    "POST", api, json=req_data, headers=headers
+                ) as resp:
+                    resp.raise_for_status()
+                    async for chunk in resp.aiter_bytes():
+                        yield chunk
+
+            return StreamingResponse(generate_stream(), media_type="application/json")

     except Exception as e:
         import sys
         import traceback

         exc_info = sys.exc_info()
-        print("Error occurred in disagg prefill proxy server" f" - {api} endpoint")
+        print(f"Error in proxy server - {api} endpoint")
         print(e)
         print("".join(traceback.format_exception(*exc_info)))
         raise
@@ -231,12 +320,19 @@ async def handle_chat_completions(request: Request):

 @app.get("/healthcheck")
 async def healthcheck():
-    """Simple endpoint to check if the server is running."""
-    return {
-        "status": "ok",
-        "prefill_instances": len(app.state.prefill_clients),
-        "decode_instances": len(app.state.decode_clients),
-    }
+    if global_args.pd_disaggregation:
+        return {
+            "status": "ok",
+            "mode": "pd-disaggregation",
+            "prefill_instances": len(app.state.prefill_clients),
+            "decode_instances": len(app.state.decode_clients),
+        }
+    else:
+        return {
+            "status": "ok",
+            "mode": "pd-mixed",
+            "worker_instances": len(app.state.worker_clients),
+        }


 if __name__ == "__main__":
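After starting the proxy, the new /healthcheck response reports which mode is active. A small client-side sketch, assuming the proxy is running on the defaults from parse_args (localhost:8000):

    import httpx

    # Query the proxy's healthcheck endpoint and print the reported mode
    resp = httpx.get("http://localhost:8000/healthcheck", timeout=5.0)
    resp.raise_for_status()
    info = resp.json()
    print(info["mode"])  # "pd-disaggregation" or "pd-mixed"
    print(info)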