33import re
44import shutil
55import subprocess
6+ from typing import Literal
67
78# from dsp.modules.adapter import TurboAdapter, DavinciAdapter, LlamaAdapter
89import backoff
@@ -117,7 +118,7 @@ def send_hftgi_request_v00(arg, **kwargs):
117118
118119
119120class HFClientVLLM (HFModel ):
120- def __init__ (self , model , port , url = "http://localhost" , ** kwargs ):
121+ def __init__ (self , model , port , model_type : Literal [ 'chat' , 'text' ] = 'text' , url = "http://localhost" , ** kwargs ):
121122 super ().__init__ (model = model , is_client = True )
122123
123124 if isinstance (url , list ):
@@ -129,49 +130,88 @@ def __init__(self, model, port, url="http://localhost", **kwargs):
129130 else :
130131 raise ValueError (f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type { type (url )} ." )
131132
133+ self .model_type = model_type
132134 self .headers = {"Content-Type" : "application/json" }
133135 self .kwargs |= kwargs
136+ # kwargs needs to have model, port and url for the lm.copy() to work properly
137+ self .kwargs .update ({
138+ 'port' : port ,
139+ 'url' : url ,
140+ })
134141
135142
136143 def _generate (self , prompt , ** kwargs ):
137144 kwargs = {** self .kwargs , ** kwargs }
138-
139- payload = {
140- "model" : self .kwargs ["model" ],
141- "prompt" : prompt ,
142- ** kwargs ,
143- }
144-
145145
146146 # Round robin the urls.
147147 url = self .urls .pop (0 )
148148 self .urls .append (url )
149+
150+ if self .model_type == "chat" :
151+ system_prompt = kwargs .get ("system_prompt" ,None )
152+ messages = [{"role" : "user" , "content" : prompt }]
153+ if system_prompt :
154+ messages .insert (0 , {"role" : "system" , "content" : system_prompt })
155+ payload = {
156+ "model" : self .kwargs ["model" ],
157+ "messages" : messages ,
158+ ** kwargs ,
159+ }
160+ response = send_hfvllm_chat_request_v00 (
161+ f"{ url } /v1/chat/completions" ,
162+ json = payload ,
163+ headers = self .headers ,
164+ )
165+
166+ try :
167+ json_response = response .json ()
168+ completions = json_response ["choices" ]
169+ response = {
170+ "prompt" : prompt ,
171+ "choices" : [{"text" : c ["message" ]['content' ]} for c in completions ],
172+ }
173+ return response
149174
150- response = send_hfvllm_request_v00 (
151- f"{ url } /v1/completions" ,
152- json = payload ,
153- headers = self .headers ,
154- )
155-
156- try :
157- json_response = response .json ()
158- completions = json_response ["choices" ]
159- response = {
175+ except Exception :
176+ print ("Failed to parse JSON response:" , response .text )
177+ raise Exception ("Received invalid JSON response from server" )
178+ else :
179+ payload = {
180+ "model" : self .kwargs ["model" ],
160181 "prompt" : prompt ,
161- "choices" : [{ "text" : c [ "text" ]} for c in completions ] ,
182+ ** kwargs ,
162183 }
163- return response
184+
185+ response = send_hfvllm_request_v00 (
186+ f"{ url } /v1/completions" ,
187+ json = payload ,
188+ headers = self .headers ,
189+ )
190+
191+ try :
192+ json_response = response .json ()
193+ completions = json_response ["choices" ]
194+ response = {
195+ "prompt" : prompt ,
196+ "choices" : [{"text" : c ["text" ]} for c in completions ],
197+ }
198+ return response
164199
165- except Exception :
166- print ("Failed to parse JSON response:" , response .text )
167- raise Exception ("Received invalid JSON response from server" )
200+ except Exception :
201+ print ("Failed to parse JSON response:" , response .text )
202+ raise Exception ("Received invalid JSON response from server" )
168203
169204
170205@CacheMemory .cache
171206def send_hfvllm_request_v00 (arg , ** kwargs ):
172207 return requests .post (arg , ** kwargs )
173208
174209
210+ @CacheMemory .cache
211+ def send_hfvllm_chat_request_v00 (arg , ** kwargs ):
212+ return requests .post (arg , ** kwargs )
213+
214+
175215class HFServerTGI :
176216 def __init__ (self , user_dir ):
177217 self .model_weights_dir = os .path .abspath (os .path .join (os .getcwd (), "text-generation-inference" , user_dir ))
@@ -438,4 +478,4 @@ def _generate(self, prompt, **kwargs):
438478
439479@CacheMemory .cache
440480def send_hfsglang_request_v00 (arg , ** kwargs ):
441- return requests .post (arg , ** kwargs )
481+ return requests .post (arg , ** kwargs )
0 commit comments