Skip to content

Commit d61e15b

Browse files
authored
Merge pull request #1 from mogith-pn/clarifai-dspy-integration
TT-3084-clarifai-dspy-integration
2 parents 8c6ba34 + e318a2d commit d61e15b

File tree

5 files changed

+211
-7
lines changed

5 files changed

+211
-7
lines changed

dsp/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .cohere import *
77
from .sbert import *
88
from .pyserini import *
9+
from .clarifai import *
910

1011
from .hf_client import HFClientTGI
1112
from .hf_client import Anyscale

dsp/modules/clarifai.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""Clarifai LM integration"""
2+
from typing import Any, Optional
3+
4+
from dsp.modules.lm import LM
5+
6+
try:
7+
from clarifai.client.model import Model
8+
except ImportError as err:
9+
raise ImportError("ClarifaiLLM requires `pip install clarifai`.") from err
10+
11+
12+
class ClarifaiLLM(LM):
    """Integration for calling models hosted on the Clarifai platform.

    Args:
        model (str, optional): Clarifai URL of the model. Defaults to the
            hosted Mistral-7B-Instruct completion model.
        api_key (Optional[str], optional): CLARIFAI_PAT token. Defaults to None.
        **kwargs: Additional arguments to pass to the API provider, e.g.
            ``inference_params={"max_tokens": 100, "temperature": 0.6}``.

    Example:
        import dspy
        dspy.configure(lm=dspy.Clarifai(model=MODEL_URL,
                       api_key=CLARIFAI_PAT,
                       inference_params={"max_tokens":100,'temperature':0.6}))
    """

    def __init__(
        self,
        model: str = "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(model)

        self.provider = "clarifai"
        self.pat = api_key
        self._model = Model(url=model, pat=api_key)
        self.kwargs = {"n": 1, **kwargs}
        self.history: list[dict[str, Any]] = []

        # Mirror temperature/max_tokens from inference_params (when supplied)
        # into the top-level kwargs so DSPy's generic LM machinery sees them;
        # the defaults (0.0 / 150) match the original behavior.
        inference_params = self.kwargs.get("inference_params", {})
        self.kwargs["temperature"] = inference_params.get("temperature", 0.0)
        self.kwargs["max_tokens"] = inference_params.get("max_tokens", 150)

    def basic_request(self, prompt, **kwargs):
        """Send one text prompt to the Clarifai model and return the raw text.

        Records the prompt/response/kwargs triple in ``self.history``.
        """
        params = self.kwargs.get("inference_params", {})
        response = (
            self._model.predict_by_bytes(
                input_bytes=prompt.encode(encoding="utf-8"),
                input_type="text",
                inference_params=params,
            )
            .outputs[0]
            .data.text.raw
        )
        kwargs = {**self.kwargs, **kwargs}
        self.history.append(
            {
                "prompt": prompt,
                "response": response,
                "kwargs": kwargs,
            }
        )
        return response

    def request(self, prompt: str, **kwargs):
        """Thin alias for :meth:`basic_request` (no retry logic here)."""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        """Query the model ``n`` times and return the list of completions.

        Args:
            prompt: The text prompt to send.
            only_completed: Must be True; partial completions are unsupported.
            return_sorted: Must be False; sorting is unsupported.
            **kwargs: Per-call overrides; ``n`` controls the completion count.

        Returns:
            list: ``n`` raw text completions, one per request.
        """
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        n = kwargs.pop("n", 1)
        # Clarifai's predict API returns one output per call, so issue n
        # sequential requests rather than asking for n candidates at once.
        return [self.request(prompt, **kwargs) for _ in range(n)]

dsp/modules/lm.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,23 @@ def inspect_history(self, n: int = 1, skip: int = 0):
4545
prompt = x["prompt"]
4646

4747
if prompt != last_prompt:
48-
printed.append(
49-
(
50-
prompt,
51-
x["response"].generations
52-
if provider == "cohere"
53-
else x["response"]["choices"],
48+
49+
if provider=="clarifai":
50+
printed.append(
51+
(
52+
prompt,
53+
x['response']
54+
)
55+
)
56+
else:
57+
printed.append(
58+
(
59+
prompt,
60+
x["response"].generations
61+
if provider == "cohere"
62+
else x["response"]["choices"],
63+
)
5464
)
55-
)
5665

5766
last_prompt = prompt
5867

@@ -71,6 +80,8 @@ def inspect_history(self, n: int = 1, skip: int = 0):
7180
text = choices[0].text
7281
elif provider == "openai" or provider == "ollama":
7382
text = ' ' + self._get_choice_text(choices[0]).strip()
83+
elif provider == "clarifai":
84+
text=choices
7485
else:
7586
text = choices[0]["text"]
7687
self.print_green(text, end="")

dspy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
OpenAI = dsp.GPT3
1919
ColBERTv2 = dsp.ColBERTv2
2020
Pyserini = dsp.PyseriniRetriever
21+
Clarifai = dsp.ClarifaiLLM
2122

2223
HFClientTGI = dsp.HFClientTGI
2324
HFClientVLLM = HFClientVLLM

dspy/retrieve/clarifai_rm.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""Clarifai as retriver to retrieve hits"""
2+
import os
3+
from concurrent.futures import ThreadPoolExecutor
4+
from typing import List, Optional, Union
5+
6+
import requests
7+
8+
import dspy
9+
from dsp.utils import dotdict
10+
11+
try:
12+
from clarifai.client.search import Search
13+
except ImportError as err:
14+
raise ImportError(
15+
"Clarifai is not installed. Install it using `pip install clarifai`"
16+
) from err
17+
18+
19+
class ClarifaiRM(dspy.Retrieve):
    """Retrieval module that uses Clarifai to return the top-k relevant
    passages for a given query.

    Assumes the source documents were already ingested into a Clarifai app,
    where they are indexed and stored.

    Args:
        clarifai_user_id (str): Clarifai unique user_id.
        clarfiai_app_id (str): Clarifai app ID where the documents are stored.
            (Parameter name kept as-is, typo included, for backward
            compatibility with existing callers.)
        clarifai_pat (Optional[str]): Clarifai PAT key; falls back to the
            CLARIFAI_PAT environment variable when None.
        k (int): Number of top documents to retrieve.

    Examples:
        retriever_model = ClarifaiRM(clarifai_user_id="USER_ID",
                                     clarfiai_app_id="APP_ID",
                                     clarifai_pat="YOUR CLARIFAI_PAT")
        dspy.settings.configure(rm=retriever_model)
    """

    def __init__(
        self,
        clarifai_user_id: str,
        clarfiai_app_id: str,
        clarifai_pat: Optional[str] = None,
        k: int = 3,
    ):
        self.app_id = clarfiai_app_id
        self.user_id = clarifai_user_id
        # Fall back to the CLARIFAI_PAT env var; a KeyError here means no
        # credentials were supplied at all.
        self.pat = (
            clarifai_pat if clarifai_pat is not None else os.environ["CLARIFAI_PAT"]
        )
        self.k = k
        self.clarifai_search = Search(
            user_id=self.user_id, app_id=self.app_id, top_k=k, pat=self.pat
        )
        super().__init__(k=k)

    def retrieve_hits(self, hits):
        """Download and return the raw text behind a single search hit."""
        header = {"Authorization": f"Key {self.pat}"}
        # NOTE(review): no timeout is set, so a stalled download blocks its
        # worker thread indefinitely — consider requests.get(..., timeout=...).
        request = requests.get(hits.input.data.text.url, headers=header)
        # Use the detected encoding so non-UTF-8 documents decode correctly.
        request.encoding = request.apparent_encoding
        return request.text

    def forward(
        self, query_or_queries: Union[str, List[str]], k: Optional[int] = None
    ) -> List[dotdict]:
        """Search the Clarifai app and retrieve the top-k similar passages.

        Uses the clarifai-python SDK search function on each query and fetches
        the matched documents concurrently.

        Args:
            query_or_queries: A single query or a list of queries.
            k: Number of relevant documents to return; defaults to the value
                given at construction time.

        Returns:
            A list of ``dotdict`` passages, each with a ``long_text`` field.

        Examples:
            Below is a code snippet that shows how to use Clarifai as the
            default retriever:
            ```python
            import clarifai
            llm = dspy.Clarifai(model=MODEL_URL, api_key="YOUR CLARIFAI_PAT")
            retriever_model = ClarifaiRM(clarifai_user_id="USER_ID", clarfiai_app_id="APP_ID", clarifai_pat="YOUR CLARIFAI_PAT")
            dspy.settings.configure(lm=llm, rm=retriever_model)
            ```
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        # NOTE(review): this mutates the shared Search client's top_k for all
        # subsequent calls — preserved to match the original behavior.
        self.clarifai_search.top_k = k if k is not None else self.clarifai_search.top_k
        passages = []
        # Drop empty queries before hitting the API.
        queries = [q for q in queries if q]

        for query in queries:
            search_response = self.clarifai_search.query(ranks=[{"text_raw": query}])

            # Flatten hits across all response pages, then fetch their
            # contents concurrently — the downloads are I/O-bound.
            hits = [hit for data in search_response for hit in data.hits]
            with ThreadPoolExecutor(max_workers=10) as executor:
                results = list(executor.map(self.retrieve_hits, hits))
            passages.extend(dotdict({"long_text": d}) for d in results)

        return passages

0 commit comments

Comments
 (0)