44# Copyright (c) 2023 Oracle and/or its affiliates.
55# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
7- import base64
87import json
98import os
109import tempfile
11- from copy import deepcopy
1210from typing import Any , Dict , List , Optional
1311
1412import fsspec
2018from langchain .load import dumpd
2119from langchain .load .load import Reviver
2220from langchain .load .serializable import Serializable
23- from langchain .vectorstores import FAISS , OpenSearchVectorSearch
24- from opensearchpy .client import OpenSearch
21+ from langchain .schema .runnable import RunnableParallel
2522
2623from ads .common .auth import default_signer
2724from ads .common .object_storage_details import ObjectStorageDetails
2825from ads .llm import GenerativeAI , ModelDeploymentTGI , ModelDeploymentVLLM
2926from ads .llm .chain import GuardrailSequence
3027from ads .llm .guardrails .base import CustomGuardrailBase
31- from ads .llm .patch import RunnableParallel , RunnableParallelSerializer
28+ from ads .llm .serializers .runnable_parallel import RunnableParallelSerializer
29+ from ads .llm .serializers .retrieval_qa import RetrievalQASerializer
3230
3331# This is a temp solution for supporting custom LLM in legacy load_chain
3432__lc_llm_dict = llms .get_type_to_cls_dict ()
@@ -45,122 +43,6 @@ def __new_type_to_cls_dict():
4543loading .get_type_to_cls_dict = __new_type_to_cls_dict
4644
4745
48- class OpenSearchVectorDBSerializer :
49- """
50- Serializer for OpenSearchVectorSearch class
51- """
52- @staticmethod
53- def type ():
54- return OpenSearchVectorSearch .__name__
55-
56- @staticmethod
57- def load (config : dict , ** kwargs ):
58- config ["kwargs" ]["embedding_function" ] = load (
59- config ["kwargs" ]["embedding_function" ], ** kwargs
60- )
61- return OpenSearchVectorSearch (
62- ** config ["kwargs" ],
63- http_auth = (
64- os .environ .get ("OCI_OPENSEARCH_USERNAME" , None ),
65- os .environ .get ("OCI_OPENSEARCH_PASSWORD" , None ),
66- ),
67- verify_certs = True if os .environ .get ("OCI_OPENSEARCH_VERIFY_CERTS" , None ).lower () == "true" else False ,
68- ca_certs = os .environ .get ("OCI_OPENSEARCH_CA_CERTS" , None ),
69- )
70-
71- @staticmethod
72- def save (obj ):
73- serialized = dumpd (obj )
74- serialized ["type" ] = "constructor"
75- serialized ["_type" ] = OpenSearchVectorDBSerializer .type ()
76- kwargs = {}
77- for key , val in obj .__dict__ .items ():
78- if key == "client" :
79- if isinstance (val , OpenSearch ):
80- client_info = val .transport .hosts [0 ]
81- opensearch_url = (
82- f"https://{ client_info ['host' ]} :{ client_info ['port' ]} "
83- )
84- kwargs .update ({"opensearch_url" : opensearch_url })
85- else :
86- raise NotImplementedError ("Only support OpenSearch client." )
87- continue
88- kwargs [key ] = dump (val )
89- serialized ["kwargs" ] = kwargs
90- return serialized
91-
92-
93- class FaissSerializer :
94- """
95- Serializer for OpenSearchVectorSearch class
96- """
97- @staticmethod
98- def type ():
99- return FAISS .__name__
100-
101- @staticmethod
102- def load (config : dict , ** kwargs ):
103- embedding_function = load (config ["embedding_function" ], ** kwargs )
104- decoded_pkl = base64 .b64decode (json .loads (config ["vectordb" ]))
105- return FAISS .deserialize_from_bytes (
106- embeddings = embedding_function , serialized = decoded_pkl
107- ) # Load the index
108-
109- @staticmethod
110- def save (obj ):
111- serialized = {}
112- serialized ["_type" ] = FaissSerializer .type ()
113- pkl = obj .serialize_to_bytes ()
114- # Encoding bytes to a base64 string
115- encoded_pkl = base64 .b64encode (pkl ).decode ('utf-8' )
116- # Serializing the base64 string
117- serialized ["vectordb" ] = json .dumps (encoded_pkl )
118- serialized ["embedding_function" ] = dump (obj .__dict__ ["embedding_function" ])
119- return serialized
120-
121- # Mapping class to vector store serialization functions
122- vectordb_serialization = {"OpenSearchVectorSearch" : OpenSearchVectorDBSerializer , "FAISS" : FaissSerializer }
123-
124-
125- class RetrievalQASerializer :
126- """
127- Serializer for RetrieverQA class
128- """
129- @staticmethod
130- def type ():
131- return "retrieval_qa"
132-
133- @staticmethod
134- def load (config : dict , ** kwargs ):
135- config_param = deepcopy (config )
136- retriever_kwargs = config_param .pop ("retriever_kwargs" )
137- vectordb_serializer = vectordb_serialization [config_param ["vectordb" ]["class" ]]
138- vectordb = vectordb_serializer .load (config_param .pop ("vectordb" ), ** kwargs )
139- retriever = vectordb .as_retriever (** retriever_kwargs )
140- return load_chain_from_config (config = config_param , retriever = retriever )
141-
142- @staticmethod
143- def save (obj ):
144- serialized = obj .dict ()
145- retriever_kwargs = {}
146- for key , val in obj .retriever .__dict__ .items ():
147- if key not in ["tags" , "metadata" , "vectorstore" ]:
148- retriever_kwargs [key ] = val
149- serialized ["retriever_kwargs" ] = retriever_kwargs
150- serialized ["vectordb" ] = {"class" : obj .retriever .vectorstore .__class__ .__name__ }
151-
152- vectordb_serializer = vectordb_serialization [serialized ["vectordb" ]["class" ]]
153- serialized ["vectordb" ].update (
154- vectordb_serializer .save (obj .retriever .vectorstore )
155- )
156-
157- if serialized ["vectordb" ]["class" ] not in vectordb_serialization :
158- raise NotImplementedError (
159- f"VectorDBSerializer for { serialized ['vectordb' ]['class' ]} is not implemented."
160- )
161- return serialized
162-
163-
16446# Mapping class to custom serialization functions
16547custom_serialization = {
16648 GuardrailSequence : GuardrailSequence .save ,
0 commit comments