1+ import unittest
2+ from langchain .load .serializable import Serializable
3+ from langchain .schema .embeddings import Embeddings
4+
5+ from langchain .vectorstores import OpenSearchVectorSearch , FAISS
6+
7+
8+ import unittest
9+ from ads .llm .serialize import OpenSearchVectorDBSerializer , FaissSerializer , RetrievalQASerializer
10+ from tests .unitary .with_extras .langchain .test_guardrails import FakeLLM
11+ import os
12+ from unittest import mock
13+ from typing import Any , Dict , List , Mapping , Optional
14+ from langchain .chains import RetrievalQA
15+ from langchain import llms
16+ from langchain .llms import loading
17+
18+
19+
20+
class FakeEmbeddings(Serializable, Embeddings):
    """Deterministic stand-in embedding model for the serializer tests.

    Every text is mapped to the same vector of 1024 ones, so the tests do
    not depend on a real embedding backend.
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that default LangChain serialization can be used."""
        return True

    @property
    def _llm_type(self) -> str:
        """Type identifier string reported by this fake."""
        return "custom_embeddings"

    def embed_query(self, text: str) -> List[float]:
        """Return the constant vector of 1024 ones; ``text`` is ignored."""
        vector_size = 1024
        return [1] * vector_size

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one freshly built constant vector per entry of ``texts``."""
        return [self.embed_query(text) for text in texts]
38+
39+
class TestOpensearchSearchVectorSerializers(unittest.TestCase):
    """Tests for ``OpenSearchVectorDBSerializer``.

    ``setUpClass`` patches the ``oci_opensearch_*`` environment variables and
    builds one ``OpenSearchVectorSearch`` instance shared by all tests;
    ``tearDownClass`` removes the patch again.
    """

    @classmethod
    def setUpClass(cls):
        """Patch the environment and build the shared store and serializer."""
        super().setUpClass()
        cls.env_patcher = mock.patch.dict(
            os.environ,
            {
                "oci_opensearch_username": "username",
                "oci_opensearch_password": "password",
                "oci_opensearch_verify_certs": "True",
                "oci_opensearch_ca_certs": "/path/to/cert.pem",
            },
        )
        cls.env_patcher.start()
        cls.index_name = "test_index"
        cls.embeddings = FakeEmbeddings()
        cls.opensearch = OpenSearchVectorSearch(
            "https://localhost:8888",
            embedding_function=cls.embeddings,
            index_name=cls.index_name,
            engine="lucene",
            http_auth=(
                os.environ["oci_opensearch_username"],
                os.environ["oci_opensearch_password"],
            ),
            # Passed through exactly as stored in the environment (a string).
            verify_certs=os.environ["oci_opensearch_verify_certs"],
            ca_certs=os.environ["oci_opensearch_ca_certs"],
        )
        cls.serializer = OpenSearchVectorDBSerializer()

    @classmethod
    def tearDownClass(cls):
        """Undo the environment patch so it cannot leak into other tests."""
        cls.env_patcher.stop()
        super().tearDownClass()

    def test_type(self):
        """``type()`` returns the OpenSearch vector store identifier."""
        self.assertEqual(self.serializer.type(), "OpenSearchVectorSearch")

    def test_save(self):
        """``save()`` emits the class id, constructor kwargs and ``_type``."""
        serialized = self.serializer.save(self.opensearch)
        self.assertEqual(
            serialized["id"],
            [
                "langchain",
                "vectorstores",
                "opensearch_vector_search",
                "OpenSearchVectorSearch",
            ],
        )
        self.assertEqual(
            serialized["kwargs"]["opensearch_url"], "https://localhost:8888"
        )
        self.assertEqual(serialized["kwargs"]["engine"], "lucene")
        self.assertEqual(serialized["_type"], "OpenSearchVectorSearch")

    def test_load(self):
        """``load()`` rebuilds an ``OpenSearchVectorSearch`` from ``save()`` output."""
        serialized = self.serializer.save(self.opensearch)
        new_opensearch = self.serializer.load(serialized, valid_namespaces=["tests"])
        self.assertIsInstance(new_opensearch, OpenSearchVectorSearch)
77+
78+
class TestFAISSSerializers(unittest.TestCase):
    """Tests for ``FaissSerializer`` on a one-document FAISS store."""

    @classmethod
    def setUpClass(cls):
        """Build the shared FAISS store from a single (text, vector) pair."""
        super().setUpClass()
        cls.embeddings = FakeEmbeddings()
        text_embedding_pairs = [("test", [1] * 1024)]
        cls.db = FAISS.from_embeddings(text_embedding_pairs, cls.embeddings)
        cls.serializer = FaissSerializer()

    def test_type(self):
        """``type()`` returns the FAISS identifier."""
        self.assertEqual(self.serializer.type(), "FAISS")

    def test_save(self):
        """``save()`` records the embedding class id and a string payload."""
        serialized = self.serializer.save(self.db)
        # The expected id is the dotted import path of FakeEmbeddings, so it
        # depends on where this test module lives in the package tree.
        self.assertEqual(
            serialized["embedding_function"]["id"],
            [
                "tests",
                "unitary",
                "with_extras",
                "langchain",
                "test_serializers",
                "FakeEmbeddings",
            ],
        )
        self.assertIsInstance(serialized["vectordb"], str)

    def test_load(self):
        """``load()`` rebuilds a ``FAISS`` store from ``save()`` output."""
        serialized = self.serializer.save(self.db)
        new_db = self.serializer.load(serialized, valid_namespaces=["tests"])
        self.assertIsInstance(new_db, FAISS)
100+
101+
class TestRetrievalQASerializer(unittest.TestCase):
    """Tests for ``RetrievalQASerializer`` on a FAISS-backed ``RetrievalQA``.

    While the class runs, ``get_type_to_cls_dict`` in ``langchain.llms`` and
    ``langchain.llms.loading`` is replaced with a version that also knows the
    fake classes; ``tearDownClass`` puts the original functions back.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared chain and register the fake classes with LangChain."""
        super().setUpClass()
        # Create a sample RetrievalQA object for testing.
        cls.llm = FakeLLM()
        cls.embeddings = FakeEmbeddings()
        text_embedding_pairs = [("test", [1] * 1024)]
        cls.db = FAISS.from_embeddings(text_embedding_pairs, cls.embeddings)
        cls.retriever = cls.db.as_retriever()
        cls.qa = RetrievalQA.from_chain_type(
            llm=cls.llm,
            chain_type="stuff",
            retriever=cls.retriever,
        )
        cls.serializer = RetrievalQASerializer()

        # Remember the real lookup *functions* (not their return value) so
        # tearDownClass can restore callables rather than a plain dict.
        cls._original_llms_lookup = llms.get_type_to_cls_dict
        cls._original_loading_lookup = loading.get_type_to_cls_dict

        # Work on a copy so the mapping returned by LangChain is not mutated.
        patched_type_to_cls = dict(llms.get_type_to_cls_dict())
        # NOTE(review): FakeEmbeddings._llm_type is "custom_embeddings" while
        # this key is "custom_embedding" -- confirm which one is intended.
        patched_type_to_cls["custom_embedding"] = lambda: FakeEmbeddings
        patched_type_to_cls["custom"] = lambda: FakeLLM

        def patched_lookup():
            return patched_type_to_cls

        # Patching is done last so that a failure earlier in setUpClass
        # (which skips tearDownClass) cannot leave LangChain patched.
        llms.get_type_to_cls_dict = patched_lookup
        loading.get_type_to_cls_dict = patched_lookup

    @classmethod
    def tearDownClass(cls) -> None:
        """Restore the original LangChain lookup functions."""
        llms.get_type_to_cls_dict = cls._original_llms_lookup
        loading.get_type_to_cls_dict = cls._original_loading_lookup
        super().tearDownClass()

    def test_type(self):
        """``type()`` returns the retrieval QA identifier."""
        self.assertEqual(self.serializer.type(), "retrieval_qa")

    def test_save(self):
        """``save()`` returns a dict describing the chain and its vector store."""
        # Serialize the RetrievalQA object.
        serialized = self.serializer.save(self.qa)

        # The serialized object must be a dictionary with the expected keys.
        self.assertIsInstance(serialized, dict)
        self.assertIn("combine_documents_chain", serialized)
        self.assertIn("retriever_kwargs", serialized)
        self.assertEqual(serialized["vectordb"]["class"], "FAISS")

    def test_load(self):
        """``load()`` rebuilds a ``RetrievalQA`` from ``save()`` output."""
        serialized = self.serializer.save(self.qa)

        # Deserialize the serialized object.
        deserialized = self.serializer.load(serialized, valid_namespaces=["tests"])

        # The round-tripped object must be a RetrievalQA instance.
        self.assertIsInstance(deserialized, RetrievalQA)
158+
159+
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments