
Commit 9c47a69

Add filtering vllm arguments
1 parent 9c1fff9 commit 9c47a69

File tree

1 file changed: +35 -3 lines changed

dsp/modules/hf_client.py

Lines changed: 35 additions & 3 deletions
@@ -148,7 +148,39 @@ def _generate(self, prompt, **kwargs):
         # Round robin the urls.
         url = self.urls.pop(0)
         self.urls.append(url)
-
+
+        list_of_elements_to_allow = [
+            "n",
+            "best_of",
+            "presence_penalty",
+            "frequency_penalty",
+            "repetition_penalty",
+            "temperature",
+            "top_p",
+            "top_k",
+            "min_p",
+            "seed",
+            "use_beam_search",
+            "length_penalty",
+            "early_stopping",
+            "stop",
+            "stop_token_ids",
+            "include_stop_str_in_output",
+            "ignore_eos",
+            "max_tokens",
+            "min_tokens",
+            "logprobs",
+            "prompt_logprobs",
+            "detokenize",
+            "skip_special_tokens",
+            "spaces_between_special_tokens",
+            "logits_processors",
+            "truncate_prompt_tokens",
+        ]
+        req_kwargs = {
+            k: v for k, v in kwargs.items() if k in list_of_elements_to_allow
+        }
+
         if self.model_type == "chat":
             system_prompt = kwargs.get("system_prompt",None)
             messages = [{"role": "user", "content": prompt}]
@@ -161,7 +193,7 @@ def _generate(self, prompt, **kwargs):
             payload = {
                 "model": self.kwargs["model"],
                 "messages": messages,
-                **kwargs,
+                **req_kwargs,
             }
             response = send_hfvllm_request_v01_wrapped(
                 f"{url}/v1/chat/completions",
@@ -190,7 +222,7 @@ def _generate(self, prompt, **kwargs):
             payload = {
                 "model": self.kwargs["model"],
                 "prompt": prompt,
-                **kwargs,
+                **req_kwargs,
             }

             response = send_hfvllm_request_v01_wrapped(
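
The change restricts the keyword arguments forwarded to the vLLM server to a fixed allowlist of sampling parameters, so client-side options such as system_prompt are consumed locally instead of being sent in the request payload. Below is a minimal, self-contained sketch of the same pattern; the helper name filter_vllm_kwargs and the example call are hypothetical illustrations, not part of the commit.

```python
# Minimal sketch of the allowlist filtering added by this commit.
# The helper name `filter_vllm_kwargs` is hypothetical; the parameter
# names mirror the allowlist in dsp/modules/hf_client.py.

VLLM_SAMPLING_PARAMS = {
    "n", "best_of", "presence_penalty", "frequency_penalty",
    "repetition_penalty", "temperature", "top_p", "top_k", "min_p",
    "seed", "use_beam_search", "length_penalty", "early_stopping",
    "stop", "stop_token_ids", "include_stop_str_in_output",
    "ignore_eos", "max_tokens", "min_tokens", "logprobs",
    "prompt_logprobs", "detokenize", "skip_special_tokens",
    "spaces_between_special_tokens", "logits_processors",
    "truncate_prompt_tokens",
}

def filter_vllm_kwargs(kwargs: dict) -> dict:
    """Drop any keyword argument the vLLM server would reject."""
    return {k: v for k, v in kwargs.items() if k in VLLM_SAMPLING_PARAMS}

if __name__ == "__main__":
    # "system_prompt" is consumed client-side when building the messages
    # list, so it is filtered out of the request payload here.
    raw = {"temperature": 0.7, "max_tokens": 256, "system_prompt": "be brief"}
    print(filter_vllm_kwargs(raw))  # {'temperature': 0.7, 'max_tokens': 256}
```

Filtering with an explicit allowlist fails closed: any future client-side option is dropped from the payload by default rather than triggering a server-side validation error for an unknown parameter.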
