1- import unittest
2- from langchain .load .serializable import Serializable
3- from langchain .schema .embeddings import Embeddings
1+ #!/usr/bin/env python
2+ # -*- coding: utf-8 -*--
43
5- from langchain .vectorstores import OpenSearchVectorSearch , FAISS
4+ # Copyright (c) 2023 Oracle and/or its affiliates.
5+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
77
8- import unittest
9- from ads .llm .serialize import OpenSearchVectorDBSerializer , FaissSerializer , RetrievalQASerializer
10- from tests .unitary .with_extras .langchain .test_guardrails import FakeLLM
118import os
9+ import unittest
1210from unittest import mock
13- from typing import Any , Dict , List , Mapping , Optional
11+ from typing import List
12+ from langchain .load .serializable import Serializable
13+ from langchain .schema .embeddings import Embeddings
14+ from langchain .vectorstores import OpenSearchVectorSearch , FAISS
1415from langchain .chains import RetrievalQA
1516from langchain import llms
1617from langchain .llms import loading
1718
18-
19+ from ads .llm .serializers .retrieval_qa import (
20+ OpenSearchVectorDBSerializer ,
21+ FaissSerializer ,
22+ RetrievalQASerializer ,
23+ )
24+ from tests .unitary .with_extras .langchain .test_guardrails import FakeLLM
1925
2026
2127class FakeEmbeddings (Serializable , Embeddings ):
@@ -35,27 +41,38 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
3541
3642 def embed_query (self , text : str ) -> List [float ]:
3743 return [1 ] * 1024
38-
39-
44+
45+
4046class TestOpensearchSearchVectorSerializers (unittest .TestCase ):
4147 @classmethod
4248 def setUpClass (cls ):
43- cls .env_patcher = mock .patch .dict (os .environ , {"OCI_OPENSEARCH_USERNAME" : "username" ,
44- "OCI_OPENSEARCH_PASSWORD" : "password" ,
45- "OCI_OPENSEARCH_VERIFY_CERTS" : "True" ,
46- "OCI_OPENSEARCH_CA_CERTS" : "/path/to/cert.pem" })
49+ cls .env_patcher = mock .patch .dict (
50+ os .environ ,
51+ {
52+ "OCI_OPENSEARCH_USERNAME" : "username" ,
53+ "OCI_OPENSEARCH_PASSWORD" : "password" ,
54+ "OCI_OPENSEARCH_VERIFY_CERTS" : "True" ,
55+ "OCI_OPENSEARCH_CA_CERTS" : "/path/to/cert.pem" ,
56+ },
57+ )
4758 cls .env_patcher .start ()
4859 cls .index_name = "test_index"
4960 cls .embeddings = FakeEmbeddings ()
50- cls .opensearch = OpenSearchVectorSearch (
51- "https://localhost:8888" ,
52- embedding_function = cls .embeddings ,
53- index_name = cls .index_name ,
54- engine = "lucene" ,
55- http_auth = (os .environ ["OCI_OPENSEARCH_USERNAME" ], os .environ ["OCI_OPENSEARCH_PASSWORD" ]),
56- verify_certs = os .environ ["OCI_OPENSEARCH_VERIFY_CERTS" ],
57- ca_certs = os .environ ["OCI_OPENSEARCH_CA_CERTS" ],
58- )
61+ try :
62+ cls .opensearch = OpenSearchVectorSearch (
63+ "https://localhost:8888" ,
64+ embedding_function = cls .embeddings ,
65+ index_name = cls .index_name ,
66+ engine = "lucene" ,
67+ http_auth = (
68+ os .environ ["OCI_OPENSEARCH_USERNAME" ],
69+ os .environ ["OCI_OPENSEARCH_PASSWORD" ],
70+ ),
71+ verify_certs = os .environ ["OCI_OPENSEARCH_VERIFY_CERTS" ],
72+ ca_certs = os .environ ["OCI_OPENSEARCH_CA_CERTS" ],
73+ )
74+ except ImportError as ex :
75+ raise unittest .SkipTest ("opensearch-py is not installed." ) from ex
5976 cls .serializer = OpenSearchVectorDBSerializer ()
6077 super ().setUpClass ()
6178
@@ -65,7 +82,12 @@ def test_type(self):
6582
6683 def test_save (self ):
6784 serialized = self .serializer .save (self .opensearch )
68- assert serialized ["id" ] == ['langchain' , 'vectorstores' , 'opensearch_vector_search' , 'OpenSearchVectorSearch' ]
85+ assert serialized ["id" ] == [
86+ "langchain" ,
87+ "vectorstores" ,
88+ "opensearch_vector_search" ,
89+ "OpenSearchVectorSearch" ,
90+ ]
6991 assert serialized ["kwargs" ]["opensearch_url" ] == "https://localhost:8888"
7092 assert serialized ["kwargs" ]["engine" ] == "lucene"
7193 assert serialized ["_type" ] == "OpenSearchVectorSearch"
@@ -81,7 +103,10 @@ class TestFAISSSerializers(unittest.TestCase):
81103 def setUpClass (cls ):
82104 cls .embeddings = FakeEmbeddings ()
83105 text_embedding_pair = [("test" , [1 ] * 1024 )]
84- cls .db = FAISS .from_embeddings (text_embedding_pair , cls .embeddings )
106+ try :
107+ cls .db = FAISS .from_embeddings (text_embedding_pair , cls .embeddings )
108+ except ImportError as ex :
109+ raise unittest .SkipTest (ex .msg ) from ex
85110 cls .serializer = FaissSerializer ()
86111 super ().setUpClass ()
87112
@@ -90,7 +115,14 @@ def test_type(self):
90115
91116 def test_save (self ):
92117 serialized = self .serializer .save (self .db )
93- assert serialized ["embedding_function" ]["id" ] == ["tests" , "unitary" , "with_extras" , "langchain" , "test_serializers" , "FakeEmbeddings" ]
118+ assert serialized ["embedding_function" ]["id" ] == [
119+ "tests" ,
120+ "unitary" ,
121+ "with_extras" ,
122+ "langchain" ,
123+ "test_serializers" ,
124+ "FakeEmbeddings" ,
125+ ]
94126 assert isinstance (serialized ["vectordb" ], str )
95127
96128 def test_load (self ):
@@ -106,14 +138,18 @@ def setUpClass(cls):
106138 cls .llm = FakeLLM ()
107139 cls .embeddings = FakeEmbeddings ()
108140 text_embedding_pair = [("test" , [1 ] * 1024 )]
109- cls .db = FAISS .from_embeddings (text_embedding_pair , cls .embeddings )
141+ try :
142+ cls .db = FAISS .from_embeddings (text_embedding_pair , cls .embeddings )
143+ except ImportError as ex :
144+ raise unittest .SkipTest (ex .msg ) from ex
110145 cls .serializer = FaissSerializer ()
111146 cls .retriever = cls .db .as_retriever ()
112- cls .qa = RetrievalQA .from_chain_type (llm = cls . llm ,
113- chain_type = "stuff" ,
114- retriever = cls . retriever )
147+ cls .qa = RetrievalQA .from_chain_type (
148+ llm = cls . llm , chain_type = "stuff" , retriever = cls . retriever
149+ )
115150 cls .serializer = RetrievalQASerializer ()
116151 from copy import deepcopy
152+
117153 cls .original_type_to_cls_dict = deepcopy (llms .get_type_to_cls_dict ())
118154 __lc_llm_dict = llms .get_type_to_cls_dict ()
119155 __lc_llm_dict ["custom_embedding" ] = lambda : FakeEmbeddings
@@ -158,4 +194,4 @@ def tearDownClass(cls) -> None:
158194
159195
160196if __name__ == "__main__" :
161- unittest .main ()
197+ unittest .main ()
0 commit comments