diff --git a/dsp/modules/hf_client.py b/dsp/modules/hf_client.py index bd2061e7a2..caea2f5e1e 100644 --- a/dsp/modules/hf_client.py +++ b/dsp/modules/hf_client.py @@ -113,26 +113,41 @@ def __init__(self, model, **kwargs): **kwargs } - def _generate(self, prompt, **kwargs): + def _generate(self, prompt, use_chat_api=True, **kwargs): url = f"{self.api_base}/chat/completions" kwargs = {**self.kwargs, **kwargs} temperature = kwargs.get("temperature") - messages = [{"role": "system", "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections."}, {"role": "user", "content": prompt}] - - body = { - "model": self.model, - "messages": messages, - "temperature": temperature, - "max_tokens": 150 - } + max_tokens = kwargs.get("max_tokens", 150) + + if use_chat_api: + messages = [ + {"role": "system", "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections."}, + {"role": "user", "content": prompt} + ] + body = { + "model": self.model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens + } + else: + body = { + "model": self.model, + "prompt": f"[INST]{prompt}[/INST]", + "temperature": temperature, + "max_tokens": max_tokens + } headers = {"Authorization": f"Bearer {self.token}"} try: with self.session.post(url, headers=headers, json=body) as resp: resp_json = resp.json() - completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")] + if use_chat_api: + completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")] + else: + completions = [resp_json.get('choices', [])[0].get('text', "")] response = {"prompt": prompt, "choices": [{"text": c} for c in completions]} return response except Exception as e: