
Commit 715e847

Merge pull request #811 from EmbeddedLLM/vllm-chat
Add vllm chat completion endpoints and fix inspect_history
2 parents 7b1e49a + ca4efef

2 files changed (+65 -27 lines)

dsp/modules/hf_client.py

Lines changed: 64 additions & 24 deletions
@@ -3,6 +3,7 @@
 import re
 import shutil
 import subprocess
+from typing import Literal

 # from dsp.modules.adapter import TurboAdapter, DavinciAdapter, LlamaAdapter
 import backoff

@@ -117,7 +118,7 @@ def send_hftgi_request_v00(arg, **kwargs):


 class HFClientVLLM(HFModel):
-    def __init__(self, model, port, url="http://localhost", **kwargs):
+    def __init__(self, model, port, model_type: Literal['chat', 'text'] = 'text', url="http://localhost", **kwargs):
         super().__init__(model=model, is_client=True)

         if isinstance(url, list):

@@ -129,49 +130,88 @@ def __init__(self, model, port, url="http://localhost", **kwargs):
         else:
             raise ValueError(f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type {type(url)}.")

+        self.model_type = model_type
         self.headers = {"Content-Type": "application/json"}
         self.kwargs |= kwargs
+        # kwargs needs to have model, port and url for the lm.copy() to work properly
+        self.kwargs.update({
+            'port': port,
+            'url': url,
+        })


     def _generate(self, prompt, **kwargs):
         kwargs = {**self.kwargs, **kwargs}
-
-        payload = {
-            "model": self.kwargs["model"],
-            "prompt": prompt,
-            **kwargs,
-        }

         # Round robin the urls.
         url = self.urls.pop(0)
         self.urls.append(url)
+
+        if self.model_type == "chat":
+            system_prompt = kwargs.get("system_prompt", None)
+            messages = [{"role": "user", "content": prompt}]
+            if system_prompt:
+                messages.insert(0, {"role": "system", "content": system_prompt})
+            payload = {
+                "model": self.kwargs["model"],
+                "messages": messages,
+                **kwargs,
+            }
+            response = send_hfvllm_chat_request_v00(
+                f"{url}/v1/chat/completions",
+                json=payload,
+                headers=self.headers,
+            )
+
+            try:
+                json_response = response.json()
+                completions = json_response["choices"]
+                response = {
+                    "prompt": prompt,
+                    "choices": [{"text": c["message"]["content"]} for c in completions],
+                }
+                return response

-        response = send_hfvllm_request_v00(
-            f"{url}/v1/completions",
-            json=payload,
-            headers=self.headers,
-        )
-
-        try:
-            json_response = response.json()
-            completions = json_response["choices"]
-            response = {
+            except Exception:
+                print("Failed to parse JSON response:", response.text)
+                raise Exception("Received invalid JSON response from server")
+        else:
+            payload = {
+                "model": self.kwargs["model"],
                 "prompt": prompt,
-                "choices": [{"text": c["text"]} for c in completions],
+                **kwargs,
             }
-            return response
+
+            response = send_hfvllm_request_v00(
+                f"{url}/v1/completions",
+                json=payload,
+                headers=self.headers,
+            )
+
+            try:
+                json_response = response.json()
+                completions = json_response["choices"]
+                response = {
+                    "prompt": prompt,
+                    "choices": [{"text": c["text"]} for c in completions],
+                }
+                return response

-        except Exception:
-            print("Failed to parse JSON response:", response.text)
-            raise Exception("Received invalid JSON response from server")
+            except Exception:
+                print("Failed to parse JSON response:", response.text)
+                raise Exception("Received invalid JSON response from server")


 @CacheMemory.cache
 def send_hfvllm_request_v00(arg, **kwargs):
     return requests.post(arg, **kwargs)


+@CacheMemory.cache
+def send_hfvllm_chat_request_v00(arg, **kwargs):
+    return requests.post(arg, **kwargs)
+
+
 class HFServerTGI:
     def __init__(self, user_dir):
         self.model_weights_dir = os.path.abspath(os.path.join(os.getcwd(), "text-generation-inference", user_dir))

@@ -438,4 +478,4 @@ def _generate(self, prompt, **kwargs):

 @CacheMemory.cache
 def send_hfsglang_request_v00(arg, **kwargs):
-    return requests.post(arg, **kwargs)
+    return requests.post(arg, **kwargs)
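
For reference, a minimal usage sketch of the new chat mode (not part of the commit; the model name and port are placeholders standing in for whatever the local vLLM OpenAI-compatible server was actually launched with, and _generate is called directly only to show the response shape from the diff):

    # Sketch only: model name and port are assumptions, not from the commit.
    from dsp.modules.hf_client import HFClientVLLM

    lm = HFClientVLLM(
        model="meta-llama/Llama-2-7b-chat-hf",  # assumed: any chat model served by vLLM
        port=8000,                              # assumed: the vLLM server's port
        model_type="chat",                      # new flag: route to /v1/chat/completions
    )

    # The prompt becomes the user message; an optional system_prompt kwarg
    # is inserted ahead of it as a system message.
    result = lm._generate("What is the capital of France?", system_prompt="Answer briefly.")
    print(result["choices"][0]["text"])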

dsp/modules/lm.py

Lines changed: 1 addition & 3 deletions
@@ -63,13 +63,11 @@ def inspect_history(self, n: int = 1, skip: int = 0):
             if len(printed) >= n:
                 break

+        printing_value = ""
         for idx, (prompt, choices) in enumerate(reversed(printed)):
-            printing_value = ""
-
             # skip the first `skip` prompts
             if (n - idx - 1) < skip:
                 continue
-
             printing_value += "\n\n\n"
             printing_value += prompt

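The inspect_history change fixes an accumulation bug: printing_value was re-initialized at the top of every loop iteration, so text gathered for earlier prompts was discarded and only the most recent one was printed. A standalone sketch of the pattern (illustrative only, not code from the commit):

    entries = ["first", "second", "third"]

    # Before the fix: resetting the accumulator inside the loop
    # discards everything gathered on earlier iterations.
    buggy = ""
    for e in entries:
        buggy = ""   # bug: clobbers prior iterations
        buggy += e
    print(buggy)     # -> "third"

    # After the fix: initialize once, accumulate across iterations.
    fixed = ""
    for e in entries:
        fixed += e
    print(fixed)     # -> "firstsecondthird"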
