Skip to content

Commit 7abf293

Browse files
authored
Merge pull request #3 from curieo-org/initial-groq-support
adding the inspect_history
2 parents cb6e258 + 2223065 commit 7abf293

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

dsp/modules/groq_client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33
import json
4-
from typing import Any, Literal, Optional
4+
from typing import Any, Literal, Optional, Required
55
import backoff
66
from groq import Groq, AsyncGroq
77
import groq
@@ -42,15 +42,18 @@ class GroqLM(LM):
4242

4343
def __init__(
4444
self,
45+
api_key: str,
4546
model: str = "mixtral-8x7b-32768",
46-
api_key: Optional[str] = None,
4747
**kwargs,
4848
):
4949
super().__init__(model)
50-
50+
self.provider = "groq"
5151
if api_key:
5252
self.api_key = api_key
5353
self.client = Groq(api_key = api_key)
54+
else:
55+
raise ValueError("api_key is required for groq")
56+
5457

5558
self.kwargs = {
5659
"temperature": 0.0,
@@ -85,7 +88,7 @@ def basic_request(self, prompt: str, **kwargs):
8588

8689
history = {
8790
"prompt": prompt,
88-
"response": response,
91+
"response": response.choices[0].message.content,
8992
"kwargs": kwargs,
9093
"raw_kwargs": raw_kwargs,
9194
}

dsp/modules/lm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def inspect_history(self, n: int = 1, skip: int = 0):
4646

4747
if prompt != last_prompt:
4848

49-
if provider == "clarifai" or provider == "google" or provider == "claude":
49+
if provider == "clarifai" or provider == "google" or provider == "claude" or provider == "groq":
5050
printed.append(
5151
(
5252
prompt,
@@ -82,6 +82,9 @@ def inspect_history(self, n: int = 1, skip: int = 0):
8282
text = ' ' + self._get_choice_text(choices[0]).strip()
8383
elif provider == "clarifai" or provider == "claude" :
8484
text=choices
85+
elif provider == "groq":
86+
# print(choices)
87+
text = ' ' + choices
8588
elif provider == "google":
8689
text = choices[0].parts[0].text
8790
else:

0 commit comments

Comments
 (0)