11"""Clarifai as retriever to retrieve hits"""
2- from typing import List , Union
32import os
3+ from concurrent .futures import ThreadPoolExecutor
4+ from typing import List , Optional , Union
5+
6+ import requests
7+
48import dspy
59from dsp .utils import dotdict
6- import requests
7- from typing import Optional
8- from concurrent .futures import ThreadPoolExecutor
910
1011try :
1112 from clarifai .client .search import Search
@@ -25,51 +26,55 @@ class ClarifaiRM(dspy.Retrieve):
2526 clarfiai_app_id (str): Clarifai App ID, where the documents are stored.
2627 clarifai_pat (str): Clarifai PAT key.
2728 k (int): Top K documents to retrieve.
28-
29+
2930 Examples:
3031 TODO
3132 """
3233
33- def __init__ (self ,
34- clarifai_user_id : str ,
35- clarfiai_app_id : str ,
36- clarifai_pat : Optional [ str ] = None ,
37- k : int = 3 ,
38-
34+ def __init__ (
35+ self ,
36+ clarifai_user_id : str ,
37+ clarfiai_app_id : str ,
38+ clarifai_pat : Optional [ str ] = None ,
39+ k : int = 3 ,
3940 ):
4041 self .app_id = clarfiai_app_id
4142 self .user_id = clarifai_user_id
42- self .pat = clarifai_pat if clarifai_pat is not None else os .environ ["CLARIFAI_PAT" ]
43- self .k = k
44- self .clarifai_search = Search (user_id = self .user_id , app_id = self .app_id , top_k = k , pat = self .pat )
43+ self .pat = (
44+ clarifai_pat if clarifai_pat is not None else os .environ ["CLARIFAI_PAT" ]
45+ )
46+ self .k = k
47+ self .clarifai_search = Search (
48+ user_id = self .user_id , app_id = self .app_id , top_k = k , pat = self .pat
49+ )
4550 super ().__init__ (k = k )
46-
51+
4752 def retrieve_hits (self , hits ):
48- header = {"Authorization" : f"Key { self .pat } " }
49- request = requests .get (hits .input .data .text .url , headers = header )
50- request .encoding = request .apparent_encoding
51- requested_text = request .text
52- return requested_text
53-
54- def forward (self , query_or_queries : Union [str , List [str ]], k : Optional [int ] = None
55- ) -> dspy .Prediction :
53+ header = {"Authorization" : f"Key { self .pat } " }
54+ request = requests .get (hits .input .data .text .url , headers = header )
55+ request .encoding = request .apparent_encoding
56+ requested_text = request .text
57+ return requested_text
5658
59+ def forward (
60+ self , query_or_queries : Union [str , List [str ]], k : Optional [int ] = None
61+ ) -> dspy .Prediction :
5762 """Uses clarifai-python SDK search function and retrieves top_k similar passages for given query,
58- Args:
59- query_or_queries : single query or list of queries
60- k : Top K relevant documents to return
61-
62- Returns:
63- passages in format of dotdict
64-
65- Examples:
66- Below is a code snippet that shows how to use Marqo as the default retriver:
67- ```python
68- import clarifai
69- llm = dspy.Clarifai(model=MODEL_URL, api_key="YOUR CLARIFAI_PAT")
70- retriever_model = ClarifaiRM(clarifai_user_id="USER_ID", clarfiai_app_id="APP_ID", clarifai_pat="YOUR CLARIFAI_PAT")
71- dspy.settings.configure(lm=llm, rm=retriever_model)
72- ```
63+ Args:
64+ query_or_queries : single query or list of queries
65+ k : Top K relevant documents to return
66+
67+ Returns:
68+ passages in format of dotdict
69+
70+ Examples:
71+ Below is a code snippet that shows how to use Clarifai as the default retriever:
72+ ```python
73+ import clarifai
74+ llm = dspy.Clarifai(model=MODEL_URL, api_key="YOUR CLARIFAI_PAT")
75+ retriever_model = ClarifaiRM(clarifai_user_id="USER_ID", clarfiai_app_id="APP_ID", clarifai_pat="YOUR CLARIFAI_PAT")
76+ dspy.settings.configure(lm=llm, rm=retriever_model)
77+ ```
7378 """
7479 queries = (
7580 [query_or_queries ]
@@ -81,10 +86,10 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No
8186 queries = [q for q in queries if q ]
8287
8388 for query in queries :
84- search_response = self .clarifai_search .query (ranks = [{"text_raw" : query }])
89+ search_response = self .clarifai_search .query (ranks = [{"text_raw" : query }])
8590
8691 # Retrieve hits
87- hits = [hit for data in search_response for hit in data .hits ]
92+ hits = [hit for data in search_response for hit in data .hits ]
8893 with ThreadPoolExecutor (max_workers = 10 ) as executor :
8994 results = list (executor .map (self .retrieve_hits , hits ))
9095 passages .extend (dotdict ({"long_text" : d }) for d in results )
0 commit comments