Skip to content

Commit c4f95b5

Browse files
Add support for OpenAI completion API string prompting
1 parent 5f0aa61 commit c4f95b5

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

dsp/modules/hf_client.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,26 +113,41 @@ def __init__(self, model, **kwargs):
113113
**kwargs
114114
}
115115

116-
def _generate(self, prompt, **kwargs):
116+
def _generate(self, prompt, use_chat_api=True, **kwargs):
117117
url = f"{self.api_base}/chat/completions"
118118
kwargs = {**self.kwargs, **kwargs}
119119

120120
temperature = kwargs.get("temperature")
121-
messages = [{"role": "system", "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections."}, {"role": "user", "content": prompt}]
122-
123-
body = {
124-
"model": self.model,
125-
"messages": messages,
126-
"temperature": temperature,
127-
"max_tokens": 150
128-
}
121+
max_tokens = kwargs.get("max_tokens", 150)
122+
123+
if use_chat_api:
124+
messages = [
125+
{"role": "system", "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections."},
126+
{"role": "user", "content": prompt}
127+
]
128+
body = {
129+
"model": self.model,
130+
"messages": messages,
131+
"temperature": temperature,
132+
"max_tokens": max_tokens
133+
}
134+
else:
135+
body = {
136+
"model": self.model,
137+
"prompt": f"[INST]{prompt}[/INST]",
138+
"temperature": temperature,
139+
"max_tokens": max_tokens
140+
}
129141

130142
headers = {"Authorization": f"Bearer {self.token}"}
131143

132144
try:
133145
with self.session.post(url, headers=headers, json=body) as resp:
134146
resp_json = resp.json()
135-
completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")]
147+
if use_chat_api:
148+
completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")]
149+
else:
150+
completions = [resp_json.get('choices', [])[0].get('text', "")]
136151
response = {"prompt": prompt, "choices": [{"text": c} for c in completions]}
137152
return response
138153
except Exception as e:

0 commit comments

Comments
 (0)