Skip to content

Commit e2b7a0c

Browse files
authored
Merge pull request #186 from stanfordnlp/anyscale_client
Added support for OpenAI completion API string prompting.
2 parents e7f163d + c4f95b5 commit e2b7a0c

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

dsp/modules/hf_client.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -165,26 +165,41 @@ def __init__(self, model, **kwargs):
165165
**kwargs
166166
}
167167

168-
def _generate(self, prompt, **kwargs):
168+
def _generate(self, prompt, use_chat_api=True, **kwargs):
169169
url = f"{self.api_base}/chat/completions"
170170
kwargs = {**self.kwargs, **kwargs}
171171

172172
temperature = kwargs.get("temperature")
173-
messages = [{"role": "system", "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections."}, {"role": "user", "content": prompt}]
174-
175-
body = {
176-
"model": self.model,
177-
"messages": messages,
178-
"temperature": temperature,
179-
"max_tokens": 150
180-
}
173+
max_tokens = kwargs.get("max_tokens", 150)
174+
175+
if use_chat_api:
176+
messages = [
177+
{"role": "system", "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections."},
178+
{"role": "user", "content": prompt}
179+
]
180+
body = {
181+
"model": self.model,
182+
"messages": messages,
183+
"temperature": temperature,
184+
"max_tokens": max_tokens
185+
}
186+
else:
187+
body = {
188+
"model": self.model,
189+
"prompt": f"[INST]{prompt}[/INST]",
190+
"temperature": temperature,
191+
"max_tokens": max_tokens
192+
}
181193

182194
headers = {"Authorization": f"Bearer {self.token}"}
183195

184196
try:
185197
with self.session.post(url, headers=headers, json=body) as resp:
186198
resp_json = resp.json()
187-
completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")]
199+
if use_chat_api:
200+
completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")]
201+
else:
202+
completions = [resp_json.get('choices', [])[0].get('text', "")]
188203
response = {"prompt": prompt, "choices": [{"text": c} for c in completions]}
189204
return response
190205
except Exception as e:

0 commit comments

Comments (0)