Skip to content

Commit bb181d6

Browse files
authored
Ensure that max_completion_tokens=1 in Prefill (#403)
* Make sure that max_completion_tokens=1 in Prefill

  Signed-off-by: Shmuel Kallner <kallner@il.ibm.com>

* Remove/undo setting of max_completion_tokens to 1, for decode

  Signed-off-by: Shmuel Kallner <kallner@il.ibm.com>

---------

Signed-off-by: Shmuel Kallner <kallner@il.ibm.com>
1 parent 8e98c80 commit bb181d6

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

pkg/sidecar/proxy/connector_lmcache.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,8 +49,8 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref
4949
ctx := r.Context()
5050
preq := r.Clone(ctx)
5151

52-
completionRequest["max_tokens"] = 1
53-
completionRequest["max_completion_tokens"] = 1
52+
completionRequest[requestFieldMaxTokens] = 1
53+
completionRequest[requestFieldMaxCompletionTokens] = 1
5454

5555
pbody, err := json.Marshal(completionRequest)
5656
if err != nil {

pkg/sidecar/proxy/connector_nixlv2.go

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,7 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
6767
streamValue, streamOk := completionRequest[requestFieldStream]
6868
streamOptionsValue, streamOptionsOk := completionRequest[requestFieldStreamOptions]
6969
maxTokensValue, maxTokensOk := completionRequest[requestFieldMaxTokens]
70+
maxCompletionTokensValue, maxCompletionTokensOk := completionRequest[requestFieldMaxCompletionTokens]
7071

7172
completionRequest[requestFieldKVTransferParams] = map[string]any{
7273
requestFieldDoRemoteDecode: true,
@@ -80,6 +81,7 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
8081
completionRequest[requestFieldStream] = false
8182
delete(completionRequest, requestFieldStreamOptions)
8283
completionRequest[requestFieldMaxTokens] = 1
84+
completionRequest[requestFieldMaxCompletionTokens] = 1
8385

8486
pbody, err := json.Marshal(completionRequest)
8587
if err != nil {
@@ -146,6 +148,10 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
146148
if maxTokensOk {
147149
completionRequest[requestFieldMaxTokens] = maxTokensValue
148150
}
151+
delete(completionRequest, requestFieldMaxCompletionTokens)
152+
if maxCompletionTokensOk {
153+
completionRequest[requestFieldMaxCompletionTokens] = maxCompletionTokensValue
154+
}
149155
completionRequest[requestFieldKVTransferParams] = pKVTransferParams
150156

151157
dbody, err := json.Marshal(completionRequest)

pkg/sidecar/proxy/proxy.go

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -38,16 +38,17 @@ const (
3838
requestHeaderPrefillURL = "x-prefiller-url"
3939
requestHeaderRequestID = "x-request-id"
4040

41-
requestFieldKVTransferParams = "kv_transfer_params"
42-
requestFieldMaxTokens = "max_tokens"
43-
requestFieldDoRemotePrefill = "do_remote_prefill"
44-
requestFieldDoRemoteDecode = "do_remote_decode"
45-
requestFieldRemoteBlockIDs = "remote_block_ids"
46-
requestFieldRemoteEngineID = "remote_engine_id"
47-
requestFieldRemoteHost = "remote_host"
48-
requestFieldRemotePort = "remote_port"
49-
requestFieldStream = "stream"
50-
requestFieldStreamOptions = "stream_options"
41+
requestFieldKVTransferParams = "kv_transfer_params"
42+
requestFieldMaxTokens = "max_tokens"
43+
requestFieldMaxCompletionTokens = "max_completion_tokens"
44+
requestFieldDoRemotePrefill = "do_remote_prefill"
45+
requestFieldDoRemoteDecode = "do_remote_decode"
46+
requestFieldRemoteBlockIDs = "remote_block_ids"
47+
requestFieldRemoteEngineID = "remote_engine_id"
48+
requestFieldRemoteHost = "remote_host"
49+
requestFieldRemotePort = "remote_port"
50+
requestFieldStream = "stream"
51+
requestFieldStreamOptions = "stream_options"
5152

5253
// ConnectorNIXLV2 enables the P/D NIXL v2 protocol
5354
ConnectorNIXLV2 = "nixlv2"

0 commit comments

Comments
 (0)