Skip to content

Commit 8c06637

Browse files
bongwoobak and hhk7734 authored
Add SGLang Connector for Prefill/Decode Disaggregation (migrated from llm-d-routing-sidecar#64) (#456)
* add sglang connector Co-authored-by: Hyeonki Hong <hhk7734@gmail.com> Signed-off-by: bongwoobak <bongwoobak@gmail.com> * fix error log Signed-off-by: bongwoobak <bongwoobak@gmail.com> * refactor: address code review feedback for SGLang connector - Move bootstrap port initialization to init() for performance - Remove prefill host validation (handled by proxy layer) - Move handler creation outside goroutine for proper error handling Signed-off-by: bongwoobak <bongwoobak@gmail.com> * refactor: simplify getBootstrapHost to return only hostname Remove unnecessary port return. Signed-off-by: bongwoobak <bongwoobak@gmail.com> * Update cmd/pd-sidecar/main.go Co-authored-by: Hyeonki Hong <hhk7734@gmail.com> Signed-off-by: bongwoobak <66110096+bongwoobak@users.noreply.github.com> * Update cmd/pd-sidecar/main.go Co-authored-by: Hyeonki Hong <hhk7734@gmail.com> Signed-off-by: bongwoobak <66110096+bongwoobak@users.noreply.github.com> * fix: add missing strings import in main.go Signed-off-by: bongwoobak <bongwoobak@gmail.com> --------- Signed-off-by: bongwoobak <bongwoobak@gmail.com> Signed-off-by: bongwoobak <66110096+bongwoobak@users.noreply.github.com> Co-authored-by: Hyeonki Hong <hhk7734@gmail.com>
1 parent e549a9f commit 8c06637

File tree

3 files changed

+184
-3
lines changed

3 files changed

+184
-3
lines changed

cmd/pd-sidecar/main.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"flag"
2121
"net/url"
2222
"os"
23+
"strings"
2324

2425
"k8s.io/klog/v2"
2526
ctrl "sigs.k8s.io/controller-runtime"
@@ -28,11 +29,20 @@ import (
2829
"github.com/llm-d/llm-d-inference-scheduler/pkg/sidecar/version"
2930
)
3031

32+
var (
33+
// supportedConnectors defines all valid P/D connector types
34+
supportedConnectors = []string{
35+
proxy.ConnectorNIXLV2,
36+
proxy.ConnectorLMCache,
37+
proxy.ConnectorSGLang,
38+
}
39+
)
40+
3141
func main() {
3242
port := flag.String("port", "8000", "the port the sidecar is listening on")
3343
vLLMPort := flag.String("vllm-port", "8001", "the port vLLM is listening on")
3444
vLLMDataParallelSize := flag.Int("data-parallel-size", 1, "the vLLM DATA-PARALLEL-SIZE value")
35-
connector := flag.String("connector", "nixlv2", "the P/D connector being used. Either nixl, nixlv2 or lmcache")
45+
connector := flag.String("connector", proxy.ConnectorNIXLV2, "the P/D connector being used. Supported: "+strings.Join(supportedConnectors, ", "))
3646
prefillerUseTLS := flag.Bool("prefiller-use-tls", false, "whether to use TLS when sending requests to prefillers")
3747
decoderUseTLS := flag.Bool("decoder-use-tls", false, "whether to use TLS when sending requests to the decoder")
3848
prefillerInsecureSkipVerify := flag.Bool("prefiller-tls-insecure-skip-verify", false, "configures the proxy to skip TLS verification for requests to prefiller")
@@ -57,8 +67,16 @@ func main() {
5767

5868
logger.Info("Proxy starting", "Built on", version.BuildRef, "From Git SHA", version.CommitSHA)
5969

60-
if *connector != proxy.ConnectorNIXLV2 && *connector != proxy.ConnectorLMCache {
61-
logger.Info("Error: --connector must either be 'nixlv2' or 'lmcache'")
70+
// Validate connector
71+
isValidConnector := false
72+
for _, validConnector := range supportedConnectors {
73+
if *connector == validConnector {
74+
isValidConnector = true
75+
break
76+
}
77+
}
78+
if !isValidConnector {
79+
logger.Info("Error: --connector must be one of: " + strings.Join(supportedConnectors, ", "))
6280
return
6381
}
6482
logger.Info("p/d connector validated", "connector", connector)
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
Copyright 2025 The llm-d Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package proxy
18+
19+
import (
20+
"bytes"
21+
"encoding/json"
22+
"fmt"
23+
"io"
24+
"math/rand"
25+
"net/http"
26+
"os"
27+
"strconv"
28+
"strings"
29+
"time"
30+
)
31+
32+
// defaultSGLangBootstrapPort is the bootstrap port used when the
// SGLANG_BOOTSTRAP_PORT environment variable is unset or unusable.
const defaultSGLangBootstrapPort = 8998

// sglangBootstrapPort is the port written into the "bootstrap_port" field of
// requests handled by the SGLang connector. It is resolved once, at package
// initialization, so the environment is not consulted on every request.
var sglangBootstrapPort int

// init resolves sglangBootstrapPort from the SGLANG_BOOTSTRAP_PORT environment
// variable, falling back to defaultSGLangBootstrapPort.
func init() {
	sglangBootstrapPort = resolveSGLangBootstrapPort(os.Getenv("SGLANG_BOOTSTRAP_PORT"))
}

// resolveSGLangBootstrapPort converts the raw environment value into a port.
// It returns defaultSGLangBootstrapPort when raw is empty, is not an integer,
// or lies outside the valid TCP port range 1-65535, so a misconfigured
// environment can never inject an impossible port into outgoing requests.
func resolveSGLangBootstrapPort(raw string) int {
	port, err := strconv.Atoi(raw)
	if err != nil || port < 1 || port > 65535 {
		return defaultSGLangBootstrapPort
	}
	return port
}
47+
48+
func (s *Server) runSGLangProtocol(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) {
49+
s.logger.V(4).Info("running SGLang protocol", "url", prefillPodHostPort)
50+
51+
// Make Request
52+
requestData, err := s.parseSGLangRequest(r)
53+
54+
if err != nil {
55+
if err := errorJSONInvalid(err, w); err != nil {
56+
s.logger.Error(err, "failed to send error response to client")
57+
}
58+
return
59+
}
60+
61+
roomID := s.generateSGLangRoomID()
62+
63+
// Inject bootstrap info for both prefill and decode
64+
bootstrapInfo := s.addSGLangBootstrapInfo(requestData, prefillPodHostPort, roomID)
65+
66+
body, err := json.Marshal(bootstrapInfo)
67+
if err != nil {
68+
if err := errorJSONInvalid(err, w); err != nil {
69+
s.logger.Error(err, "failed to send error response to client")
70+
}
71+
return
72+
}
73+
74+
// Send concurrent prefill and decode requests
75+
s.sendSGLangConcurrentRequests(w, r, body, prefillPodHostPort)
76+
}
77+
78+
func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Request, body []byte, prefillHost string) {
79+
// Create separate requests for prefill and decode
80+
prefillReq := cloneWithJSONBody(r, body)
81+
decodeReq := cloneWithJSONBody(r, body)
82+
83+
prefillHandler, err := s.prefillerProxyHandler(prefillHost)
84+
if err != nil {
85+
if err := errorBadGateway(err, w); err != nil {
86+
s.logger.Error(err, "failed to send error response to client")
87+
}
88+
return
89+
}
90+
91+
// Send prefill request asynchronously
92+
go func() {
93+
pw := &bufferedResponseWriter{}
94+
prefillHandler.ServeHTTP(pw, prefillReq)
95+
s.logger.V(5).Info("prefill request completed", "status", pw.statusCode)
96+
}()
97+
98+
// Send decode request synchronously
99+
s.decoderProxy.ServeHTTP(w, decodeReq)
100+
}
101+
102+
// cloneWithJSONBody returns a deep copy of r (sharing r's context) whose body
// is replaced by body. ContentLength is updated to match, and GetBody is set
// so that an HTTP client transport can obtain a fresh reader over the same
// bytes if it needs to replay the request. Without this the clone would
// inherit whatever GetBody r carried, which would replay the old payload.
// The original request is not modified.
func cloneWithJSONBody(r *http.Request, body []byte) *http.Request {
	req := r.Clone(r.Context())
	req.Body = io.NopCloser(bytes.NewReader(body))
	req.ContentLength = int64(len(body))
	req.GetBody = func() (io.ReadCloser, error) {
		// A new reader per call: every replay starts at offset zero.
		return io.NopCloser(bytes.NewReader(body)), nil
	}
	return req
}
108+
109+
func (s *Server) addSGLangBootstrapInfo(requestData map[string]interface{}, prefillHostPort string, roomID int64) map[string]interface{} {
110+
modifiedRequest := make(map[string]interface{})
111+
for k, v := range requestData {
112+
modifiedRequest[k] = v
113+
}
114+
115+
// Generate bootstrap host from prefill host
116+
bootstrapHost := s.getBootstrapHost(prefillHostPort)
117+
118+
// Add bootstrap information
119+
modifiedRequest[requestFieldBootstrapHost] = bootstrapHost
120+
modifiedRequest[requestFieldBootstrapPort] = sglangBootstrapPort
121+
modifiedRequest[requestFieldBootstrapRoom] = roomID
122+
123+
s.logger.V(5).Info("bootstrap info added",
124+
"bootstrap_host", bootstrapHost,
125+
"bootstrap_port", sglangBootstrapPort,
126+
"bootstrap_room", roomID)
127+
128+
return modifiedRequest
129+
}
130+
131+
func (s *Server) parseSGLangRequest(r *http.Request) (map[string]interface{}, error) {
132+
body, err := io.ReadAll(r.Body)
133+
if err != nil {
134+
return nil, fmt.Errorf("failed to read request body: %w", err)
135+
}
136+
137+
var requestData map[string]interface{}
138+
if err := json.Unmarshal(body, &requestData); err != nil {
139+
return nil, fmt.Errorf("failed to parse request body: %w", err)
140+
}
141+
142+
return requestData, nil
143+
}
144+
145+
func (s *Server) generateSGLangRoomID() int64 {
146+
return time.Now().UnixNano() + int64(rand.Intn(1000))
147+
}
148+
149+
func (s *Server) getBootstrapHost(prefillHostPort string) string {
150+
// Extract hostname from prefill host
151+
parts := strings.Split(prefillHostPort, ":")
152+
return parts[0]
153+
}

pkg/sidecar/proxy/proxy.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,19 @@ const (
4646
requestFieldStream = "stream"
4747
requestFieldStreamOptions = "stream_options"
4848

49+
// SGLang bootstrap fields
50+
requestFieldBootstrapHost = "bootstrap_host"
51+
requestFieldBootstrapPort = "bootstrap_port"
52+
requestFieldBootstrapRoom = "bootstrap_room"
53+
4954
// ConnectorNIXLV2 enables the P/D NIXL v2 protocol
5055
ConnectorNIXLV2 = "nixlv2"
5156

5257
// ConnectorLMCache enables (now deprecated) P/D LMCache protocol
5358
ConnectorLMCache = "lmcache"
59+
60+
// ConnectorSGLang enables SGLang P/D disaggregation protocol
61+
ConnectorSGLang = "sglang"
5462
)
5563

5664
// Config represents the proxy server configuration
@@ -108,6 +116,8 @@ func NewProxy(port string, decodeURL *url.URL, config Config) *Server {
108116
switch config.Connector {
109117
case ConnectorLMCache:
110118
server.runConnectorProtocol = server.runLMCacheProtocol
119+
case ConnectorSGLang:
120+
server.runConnectorProtocol = server.runSGLangProtocol
111121
case ConnectorNIXLV2:
112122
fallthrough
113123
default:

0 commit comments

Comments
 (0)