Skip to content

Commit e032259

Browse files
committed
addressed linting
1 parent 7f15f8f commit e032259

File tree

2 files changed

+80
-72
lines changed

2 files changed

+80
-72
lines changed

dsp/modules/clarifai.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
"""Clarifai LM integration"""
22
from typing import Any, Optional
3+
34
from dsp.modules.lm import LM
45

56
try:
67
from clarifai.client.model import Model
78
except ImportError as err:
89
raise ImportError("ClarifaiLLM requires `pip install clarifai`.") from err
910

11+
1012
class ClarifaiLLM(LM):
1113
"""Integration to call models hosted in clarifai platform.
12-
14+
1315
Args:
1416
model (str, optional): Clarifai URL of the model. Defaults to "Mistral-7B-Instruct".
1517
api_key (Optional[str], optional): CLARIFAI_PAT token. Defaults to None.
1618
**kwargs: Additional arguments to pass to the API provider.
1719
Example:
1820
import dspy
19-
dspy.configure(lm=dspy.Clarifai(model=MODEL_URL,
20-
api_key=CLARIFAI_PAT,
21+
dspy.configure(lm=dspy.Clarifai(model=MODEL_URL,
22+
api_key=CLARIFAI_PAT,
2123
inference_params={"max_tokens":100,'temperature':0.6}))
2224
"""
2325

@@ -30,36 +32,36 @@ def __init__(
3032
super().__init__(model)
3133

3234
self.provider = "clarifai"
33-
self.pat=api_key
34-
self._model= Model(url=model, pat=api_key)
35-
self.kwargs = {
36-
"n": 1,
37-
**kwargs
38-
}
39-
self.history :list[dict[str, Any]] = []
40-
self.kwargs['temperature'] = (
41-
self.kwargs['inference_params']['temperature'] if
42-
'inference_params' in self.kwargs and
43-
'temperature' in self.kwargs['inference_params'] else 0.0
44-
)
45-
self.kwargs['max_tokens'] = (
46-
self.kwargs['inference_params']['max_tokens'] if
47-
'inference_params' in self.kwargs and
48-
'max_tokens' in self.kwargs['inference_params'] else 150
35+
self.pat = api_key
36+
self._model = Model(url=model, pat=api_key)
37+
self.kwargs = {"n": 1, **kwargs}
38+
self.history: list[dict[str, Any]] = []
39+
self.kwargs["temperature"] = (
40+
self.kwargs["inference_params"]["temperature"]
41+
if "inference_params" in self.kwargs
42+
and "temperature" in self.kwargs["inference_params"]
43+
else 0.0
44+
)
45+
self.kwargs["max_tokens"] = (
46+
self.kwargs["inference_params"]["max_tokens"]
47+
if "inference_params" in self.kwargs
48+
and "max_tokens" in self.kwargs["inference_params"]
49+
else 150
4950
)
50-
51+
5152
def basic_request(self, prompt, **kwargs):
5253
params = (
53-
self.kwargs['inference_params'] if 'inference_params' in self.kwargs
54-
else {}
54+
self.kwargs["inference_params"] if "inference_params" in self.kwargs else {}
5555
)
5656
response = (
57-
self._model.predict_by_bytes(
58-
input_bytes= prompt.encode(encoding="utf-8"),
59-
input_type= "text",
60-
inference_params= params,
61-
).outputs[0].data.text.raw
57+
self._model.predict_by_bytes(
58+
input_bytes=prompt.encode(encoding="utf-8"),
59+
input_type="text",
60+
inference_params=params,
6261
)
62+
.outputs[0]
63+
.data.text.raw
64+
)
6365
kwargs = {**self.kwargs, **kwargs}
6466
history = {
6567
"prompt": prompt,
@@ -68,21 +70,22 @@ def basic_request(self, prompt, **kwargs):
6870
}
6971
self.history.append(history)
7072
return response
71-
73+
7274
def request(self, prompt: str, **kwargs):
7375
return self.basic_request(prompt, **kwargs)
74-
75-
def __call__(self,
76+
77+
def __call__(
78+
self,
7679
prompt: str,
7780
only_completed: bool = True,
7881
return_sorted: bool = False,
79-
**kwargs
82+
**kwargs,
8083
):
8184
assert only_completed, "for now"
8285
assert return_sorted is False, "for now"
8386

8487
n = kwargs.pop("n", 1)
85-
completions=[]
88+
completions = []
8689

8790
for i in range(n):
8891
response = self.request(prompt, **kwargs)

dspy/retrieve/clarifai_rm.py

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Clarifai as retriver to retrieve hits"""
2-
from typing import List, Union
32
import os
3+
from concurrent.futures import ThreadPoolExecutor
4+
from typing import List, Optional, Union
5+
6+
import requests
7+
48
import dspy
59
from dsp.utils import dotdict
6-
import requests
7-
from typing import Optional
8-
from concurrent.futures import ThreadPoolExecutor
910

1011
try:
1112
from clarifai.client.search import Search
@@ -25,51 +26,55 @@ class ClarifaiRM(dspy.Retrieve):
2526
clarfiai_app_id (str): Clarifai App ID, where the documents are stored.
2627
clarifai_pat (str): Clarifai PAT key.
2728
k (int): Top K documents to retrieve.
28-
29+
2930
Examples:
3031
TODO
3132
"""
3233

33-
def __init__(self,
34-
clarifai_user_id: str,
35-
clarfiai_app_id: str,
36-
clarifai_pat: Optional[str] = None,
37-
k: int = 3,
38-
34+
def __init__(
35+
self,
36+
clarifai_user_id: str,
37+
clarfiai_app_id: str,
38+
clarifai_pat: Optional[str] = None,
39+
k: int = 3,
3940
):
4041
self.app_id = clarfiai_app_id
4142
self.user_id = clarifai_user_id
42-
self.pat = clarifai_pat if clarifai_pat is not None else os.environ["CLARIFAI_PAT"]
43-
self.k=k
44-
self.clarifai_search = Search(user_id=self.user_id, app_id=self.app_id, top_k=k, pat=self.pat)
43+
self.pat = (
44+
clarifai_pat if clarifai_pat is not None else os.environ["CLARIFAI_PAT"]
45+
)
46+
self.k = k
47+
self.clarifai_search = Search(
48+
user_id=self.user_id, app_id=self.app_id, top_k=k, pat=self.pat
49+
)
4550
super().__init__(k=k)
46-
51+
4752
def retrieve_hits(self, hits):
48-
header = {"Authorization": f"Key {self.pat}"}
49-
request = requests.get(hits.input.data.text.url, headers=header)
50-
request.encoding = request.apparent_encoding
51-
requested_text = request.text
52-
return requested_text
53-
54-
def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None
55-
) -> dspy.Prediction:
53+
header = {"Authorization": f"Key {self.pat}"}
54+
request = requests.get(hits.input.data.text.url, headers=header)
55+
request.encoding = request.apparent_encoding
56+
requested_text = request.text
57+
return requested_text
5658

59+
def forward(
60+
self, query_or_queries: Union[str, List[str]], k: Optional[int] = None
61+
) -> dspy.Prediction:
5762
"""Uses clarifai-python SDK search function and retrieves top_k similar passages for given query,
58-
Args:
59-
query_or_queries : single query or list of queries
60-
k : Top K relevant documents to return
61-
62-
Returns:
63-
passages in format of dotdict
64-
65-
Examples:
66-
Below is a code snippet that shows how to use Marqo as the default retriver:
67-
```python
68-
import clarifai
69-
llm = dspy.Clarifai(model=MODEL_URL, api_key="YOUR CLARIFAI_PAT")
70-
retriever_model = ClarifaiRM(clarifai_user_id="USER_ID", clarfiai_app_id="APP_ID", clarifai_pat="YOUR CLARIFAI_PAT")
71-
dspy.settings.configure(lm=llm, rm=retriever_model)
72-
```
63+
Args:
64+
query_or_queries : single query or list of queries
65+
k : Top K relevant documents to return
66+
67+
Returns:
68+
passages in format of dotdict
69+
70+
Examples:
71+
Below is a code snippet that shows how to use Marqo as the default retriver:
72+
```python
73+
import clarifai
74+
llm = dspy.Clarifai(model=MODEL_URL, api_key="YOUR CLARIFAI_PAT")
75+
retriever_model = ClarifaiRM(clarifai_user_id="USER_ID", clarfiai_app_id="APP_ID", clarifai_pat="YOUR CLARIFAI_PAT")
76+
dspy.settings.configure(lm=llm, rm=retriever_model)
77+
```
7378
"""
7479
queries = (
7580
[query_or_queries]
@@ -81,10 +86,10 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No
8186
queries = [q for q in queries if q]
8287

8388
for query in queries:
84-
search_response= self.clarifai_search.query(ranks=[{"text_raw": query}])
89+
search_response = self.clarifai_search.query(ranks=[{"text_raw": query}])
8590

8691
# Retrieve hits
87-
hits=[hit for data in search_response for hit in data.hits]
92+
hits = [hit for data in search_response for hit in data.hits]
8893
with ThreadPoolExecutor(max_workers=10) as executor:
8994
results = list(executor.map(self.retrieve_hits, hits))
9095
passages.extend(dotdict({"long_text": d}) for d in results)

0 commit comments

Comments
 (0)