44# Copyright (c) 2023 Oracle and/or its affiliates.
55# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
7+ import base64
78import json
89import os
910import tempfile
11+ from copy import deepcopy
1012from typing import Any , Dict , List , Optional
1113
1214import fsspec
1315import yaml
1416from langchain import llms
15- from langchain .llms import loading
17+ from langchain .chains import RetrievalQA
1618from langchain .chains .loading import load_chain_from_config
19+ from langchain .llms import loading
20+ from langchain .load import dumpd
1721from langchain .load .load import Reviver
1822from langchain .load .serializable import Serializable
23+ from langchain .vectorstores import FAISS , OpenSearchVectorSearch
24+ from opensearchpy .client import OpenSearch
1925
2026from ads .common .auth import default_signer
2127from ads .common .object_storage_details import ObjectStorageDetails
22- from ads .llm import GenerativeAI , ModelDeploymentVLLM , ModelDeploymentTGI
28+ from ads .llm import GenerativeAI , ModelDeploymentTGI , ModelDeploymentVLLM
2329from ads .llm .chain import GuardrailSequence
2430from ads .llm .guardrails .base import CustomGuardrailBase
2531from ads .llm .patch import RunnableParallel , RunnableParallelSerializer
2632
27-
2833# This is a temp solution for supporting custom LLM in legacy load_chain
2934__lc_llm_dict = llms .get_type_to_cls_dict ()
3035__lc_llm_dict [GenerativeAI .__name__ ] = lambda : GenerativeAI
@@ -39,11 +44,129 @@ def __new_type_to_cls_dict():
3944llms .get_type_to_cls_dict = __new_type_to_cls_dict
4045loading .get_type_to_cls_dict = __new_type_to_cls_dict
4146
47+
class OpenSearchVectorDBSerializer:
    """
    Serializer for OpenSearchVectorSearch class
    """

    @staticmethod
    def type():
        """Returns the name of the OpenSearchVectorSearch class.

        ``save()`` stores this value under the ``_type`` key.
        """
        return OpenSearchVectorSearch.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Builds an OpenSearchVectorSearch from a serialized config.

        Parameters
        ----------
        config : dict
            Serialized vector store. ``config["kwargs"]`` holds the constructor
            keyword arguments, including a serialized ``embedding_function``.
        kwargs
            Passed through to ``load()`` when deserializing the embedding function.

        Returns
        -------
        OpenSearchVectorSearch
            Instance created from ``config["kwargs"]`` plus connection settings
            read from the ``OCI_OPENSEARCH_USERNAME``, ``OCI_OPENSEARCH_PASSWORD``,
            ``OCI_OPENSEARCH_VERIFY_CERTS`` and ``OCI_OPENSEARCH_CA_CERTS``
            environment variables.
        """
        # Work on a shallow copy so the caller's config is not overwritten
        # with the deserialized embedding function.
        init_kwargs = dict(config["kwargs"])
        init_kwargs["embedding_function"] = load(
            init_kwargs["embedding_function"], **kwargs
        )
        # Default to an empty string: calling .lower() on None raised
        # AttributeError whenever the variable was not set.
        # NOTE: certificates are verified only when the variable is "true"
        # (case-insensitive); unset or any other value disables verification.
        verify_certs = (
            os.environ.get("OCI_OPENSEARCH_VERIFY_CERTS", "").lower() == "true"
        )
        return OpenSearchVectorSearch(
            **init_kwargs,
            http_auth=(
                os.environ.get("OCI_OPENSEARCH_USERNAME", None),
                os.environ.get("OCI_OPENSEARCH_PASSWORD", None),
            ),
            verify_certs=verify_certs,
            ca_certs=os.environ.get("OCI_OPENSEARCH_CA_CERTS", None),
        )

    @staticmethod
    def save(obj):
        """Serializes an OpenSearchVectorSearch object into a dictionary.

        The result starts from ``dumpd(obj)``; ``type``, ``_type`` and
        ``kwargs`` are then overwritten. ``kwargs`` is rebuilt from the
        instance attributes of ``obj``: every attribute is passed through
        ``dump()``, except ``client``, which is replaced by an
        ``opensearch_url`` built from the first host of the client's transport.
        The client object itself is not serialized; ``load()`` reads the
        connection settings from environment variables instead.

        Raises
        ------
        NotImplementedError
            If ``obj.client`` is not an ``OpenSearch`` instance.
        """
        serialized = dumpd(obj)
        serialized["type"] = "constructor"
        serialized["_type"] = OpenSearchVectorDBSerializer.type()
        kwargs = {}
        for key, val in obj.__dict__.items():
            if key != "client":
                kwargs[key] = dump(val)
                continue
            if not isinstance(val, OpenSearch):
                raise NotImplementedError("Only support OpenSearch client.")
            # Only the first configured host is kept, always with the https scheme.
            client_info = val.transport.hosts[0]
            kwargs["opensearch_url"] = (
                f"https://{client_info['host']}:{client_info['port']}"
            )
        serialized["kwargs"] = kwargs
        return serialized
91+
92+
class FaissSerializer:
    """
    Serializer for FAISS class
    """

    @staticmethod
    def type():
        """Returns the name of the FAISS class, stored by ``save()`` under ``_type``."""
        return FAISS.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuilds a FAISS vector store from a serialized config.

        ``config["embedding_function"]`` is deserialized with ``load()`` (which
        receives ``kwargs``), and ``config["vectordb"]`` is JSON-decoded and
        base64-decoded back into the bytes handed to
        ``FAISS.deserialize_from_bytes``.
        """
        embeddings = load(config["embedding_function"], **kwargs)
        # Reverse of save(): JSON string -> base64 text -> raw bytes.
        # NOTE(review): the payload appears to be a pickle - only load
        # configs that come from a trusted source; confirm.
        index_bytes = base64.b64decode(json.loads(config["vectordb"]))
        return FAISS.deserialize_from_bytes(
            embeddings=embeddings, serialized=index_bytes
        )

    @staticmethod
    def save(obj):
        """Serializes a FAISS vector store into a dictionary.

        The returned dictionary has three keys: ``_type`` (see ``type()``),
        ``vectordb`` (the output of ``obj.serialize_to_bytes()``,
        base64-encoded and then JSON-encoded) and ``embedding_function`` (the
        object's ``embedding_function`` attribute passed through ``dump()``).
        """
        raw_bytes = obj.serialize_to_bytes()
        # Bytes are not JSON serializable, so they are carried as base64 text.
        b64_text = base64.b64encode(raw_bytes).decode("utf-8")
        return {
            "_type": FaissSerializer.type(),
            "vectordb": json.dumps(b64_text),
            "embedding_function": dump(obj.__dict__["embedding_function"]),
        }
120+
# Registry of vector store serializers, keyed by the class name of the
# vector store they handle.
vectordb_serialization = {
    "OpenSearchVectorSearch": OpenSearchVectorDBSerializer,
    "FAISS": FaissSerializer,
}
123+
124+
class RetrievalQASerializer:
    """
    Serializer for RetrievalQA class
    """

    @staticmethod
    def type():
        """Returns the type identifier of the chain: ``"retrieval_qa"``."""
        return "retrieval_qa"

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuilds a chain from a config produced by ``save()``.

        Parameters
        ----------
        config : dict
            Serialized chain. It must contain ``retriever_kwargs`` and a
            ``vectordb`` dictionary whose ``class`` entry is a key of
            ``vectordb_serialization``. The dictionary is deep-copied and is
            not modified.
        kwargs
            Passed through to the ``load()`` of the vector store serializer.

        Returns
        -------
        The result of ``load_chain_from_config`` called with the remaining
        config and a retriever created from the deserialized vector store.
        """
        config_param = deepcopy(config)
        retriever_kwargs = config_param.pop("retriever_kwargs")
        vectordb_config = config_param.pop("vectordb")
        vectordb_serializer = vectordb_serialization[vectordb_config["class"]]
        vectordb = vectordb_serializer.load(vectordb_config, **kwargs)
        retriever = vectordb.as_retriever(**retriever_kwargs)
        return load_chain_from_config(config=config_param, retriever=retriever)

    @staticmethod
    def save(obj):
        """Serializes a retrieval chain together with its vector store.

        The result is ``obj.dict()`` extended with two keys:
        ``retriever_kwargs`` (the retriever's instance attributes except
        ``tags``, ``metadata`` and ``vectorstore``) and ``vectordb`` (the
        class name of the vector store under ``class``, merged with the output
        of the matching serializer from ``vectordb_serialization``).

        Raises
        ------
        NotImplementedError
            If no serializer is registered for the class of the vector store.
        """
        serialized = obj.dict()
        vectorstore = obj.retriever.vectorstore
        vectordb_class = vectorstore.__class__.__name__

        # Check the registry BEFORE indexing into it. Previously the lookup
        # ran first, so an unsupported vector store failed with a bare
        # KeyError and this error could never be raised.
        if vectordb_class not in vectordb_serialization:
            raise NotImplementedError(
                f"VectorDBSerializer for {vectordb_class} is not implemented."
            )

        serialized["retriever_kwargs"] = {
            key: val
            for key, val in obj.retriever.__dict__.items()
            if key not in ("tags", "metadata", "vectorstore")
        }
        serialized["vectordb"] = {"class": vectordb_class}
        serialized["vectordb"].update(
            vectordb_serialization[vectordb_class].save(vectorstore)
        )
        return serialized
162+
163+
# Custom serialization functions, keyed by the class they serialize.
custom_serialization = {
    # Classes that provide their own ``save``.
    GuardrailSequence: GuardrailSequence.save,
    CustomGuardrailBase: CustomGuardrailBase.save,
    # Classes that are saved through a dedicated serializer class.
    RunnableParallel: RunnableParallelSerializer.save,
    RetrievalQA: RetrievalQASerializer.save,
}
48171
49172# Mapping _type to custom deserialization functions
@@ -52,6 +175,7 @@ def __new_type_to_cls_dict():
52175 GuardrailSequence .type (): GuardrailSequence .load ,
53176 CustomGuardrailBase .type (): CustomGuardrailBase .load ,
54177 RunnableParallelSerializer .type (): RunnableParallelSerializer .load ,
178+ RetrievalQASerializer .type (): RetrievalQASerializer .load ,
55179}
56180
57181
0 commit comments