Skip to content

Commit abe350c

Browse files
authored
Merge pull request #1043 from tom-doerr/filter_vllm_keywords
Filter vllm keywords
2 parents 5533d39 + 6045e82 commit abe350c

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

dsp/modules/hf_client.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,20 +148,49 @@ def _generate(self, prompt, **kwargs):
148148
# Round robin the urls.
149149
url = self.urls.pop(0)
150150
self.urls.append(url)
151-
151+
152+
list_of_elements_to_allow = [
153+
"n",
154+
"best_of",
155+
"presence_penalty",
156+
"frequency_penalty",
157+
"repetition_penalty",
158+
"temperature",
159+
"top_p",
160+
"top_k",
161+
"min_p",
162+
"seed",
163+
"use_beam_search",
164+
"length_penalty",
165+
"early_stopping",
166+
"stop",
167+
"stop_token_ids",
168+
"include_stop_str_in_output",
169+
"ignore_eos",
170+
"max_tokens",
171+
"min_tokens",
172+
"logprobs",
173+
"prompt_logprobs",
174+
"detokenize",
175+
"skip_special_tokens",
176+
"spaces_between_special_tokens",
177+
"logits_processors",
178+
"truncate_prompt_tokens",
179+
]
180+
req_kwargs = {
181+
k: v for k, v in kwargs.items() if k in list_of_elements_to_allow
182+
}
183+
152184
if self.model_type == "chat":
153185
system_prompt = kwargs.get("system_prompt",None)
154186
messages = [{"role": "user", "content": prompt}]
155187
if system_prompt:
156188
messages.insert(0, {"role": "system", "content": system_prompt})
157189

158-
kwargs.pop("port", None)
159-
kwargs.pop("url", None)
160-
161190
payload = {
162191
"model": self.kwargs["model"],
163192
"messages": messages,
164-
**kwargs,
193+
**req_kwargs,
165194
}
166195
response = send_hfvllm_request_v01_wrapped(
167196
f"{url}/v1/chat/completions",
@@ -184,13 +213,10 @@ def _generate(self, prompt, **kwargs):
184213
print("Failed to parse JSON response:", response.text)
185214
raise Exception("Received invalid JSON response from server")
186215
else:
187-
kwargs.pop("port", None)
188-
kwargs.pop("url", None)
189-
190216
payload = {
191217
"model": self.kwargs["model"],
192218
"prompt": prompt,
193-
**kwargs,
219+
**req_kwargs,
194220
}
195221

196222
response = send_hfvllm_request_v01_wrapped(

0 commit comments

Comments (0)