1313from langchain .schema .embeddings import Embeddings
1414from langchain .vectorstores import OpenSearchVectorSearch , FAISS
1515from langchain .chains import RetrievalQA
16- from langchain import llms
17- from langchain .llms import loading
16+ from langchain .llms import Cohere
1817
1918from ads .llm .serializers .retrieval_qa import (
2019 OpenSearchVectorDBSerializer ,
2120 FaissSerializer ,
2221 RetrievalQASerializer ,
2322)
24- from tests .unitary .with_extras .langchain .test_guardrails import FakeLLM
2523
2624
2725class FakeEmbeddings (Serializable , Embeddings ):
@@ -135,7 +133,7 @@ class TestRetrievalQASerializer(unittest.TestCase):
135133 @classmethod
136134 def setUpClass (cls ):
137135 # Create a sample RetrieverQA object for testing
138- cls .llm = FakeLLM ( )
136+ cls .llm = Cohere ( cohere_api_key = "api_key" )
139137 cls .embeddings = FakeEmbeddings ()
140138 text_embedding_pair = [("test" , [1 ] * 1024 )]
141139 try :
@@ -148,18 +146,6 @@ def setUpClass(cls):
148146 llm = cls .llm , chain_type = "stuff" , retriever = cls .retriever
149147 )
150148 cls .serializer = RetrievalQASerializer ()
151- from copy import deepcopy
152-
153- cls .original_type_to_cls_dict = deepcopy (llms .get_type_to_cls_dict ())
154- __lc_llm_dict = llms .get_type_to_cls_dict ()
155- __lc_llm_dict ["custom_embedding" ] = lambda : FakeEmbeddings
156- __lc_llm_dict ["custom" ] = lambda : FakeLLM
157-
158- def __new_type_to_cls_dict ():
159- return __lc_llm_dict
160-
161- llms .get_type_to_cls_dict = __new_type_to_cls_dict
162- loading .get_type_to_cls_dict = __new_type_to_cls_dict
163149
164150 def test_type (self ):
165151 self .assertEqual (self .serializer .type (), "retrieval_qa" )
@@ -176,6 +162,7 @@ def test_save(self):
176162 self .assertIn ("retriever_kwargs" , serialized )
177163 serialized ["vectordb" ]["class" ] == "FAISS"
178164
165+ @mock .patch .dict (os .environ , {"COHERE_API_KEY" : "api_key" })
179166 def test_load (self ):
180167 # Create a sample config dictionary
181168 serialized = self .serializer .save (self .qa )
@@ -186,12 +173,6 @@ def test_load(self):
186173 # Ensure that the deserialized object is an instance of RetrieverQA
187174 self .assertIsInstance (deserialized , RetrievalQA )
188175
189- @classmethod
190- def tearDownClass (cls ) -> None :
191- llms .get_type_to_cls_dict = cls .original_type_to_cls_dict
192- loading .get_type_to_cls_dict = cls .original_type_to_cls_dict
193- return super ().tearDownClass ()
194-
195176
196177if __name__ == "__main__" :
197178 unittest .main ()
0 commit comments