1414from langchain import llms
1515from langchain .llms import loading
1616from langchain .chains .loading import load_chain_from_config
17- from langchain .load .load import load as __lc_load
17+ from langchain .load .load import Reviver , load as __lc_load
1818from langchain .load .serializable import Serializable
1919
2020from ads .common .auth import default_signer
@@ -76,14 +76,32 @@ def load(
7676 Returns:
7777 Revived LangChain objects.
7878 """
79+ # Add ADS as valid namespace
7980 if not valid_namespaces :
8081 valid_namespaces = []
8182 if "ads" not in valid_namespaces :
8283 valid_namespaces .append ("ads" )
8384
85+ reviver = Reviver (secrets_map , valid_namespaces )
86+
87+ def _load (obj : Any ) -> Any :
88+ if isinstance (obj , dict ):
89+ if "_type" in obj and obj ["_type" ] in custom_deserialization :
90+ if valid_namespaces :
91+ kwargs ["valid_namespaces" ] = valid_namespaces
92+ if secrets_map :
93+ kwargs ["secret_map" ] = secrets_map
94+ return custom_deserialization [obj ["_type" ]](obj , ** kwargs )
95+ # Need to revive leaf nodes before reviving this node
96+ loaded_obj = {k : _load (v ) for k , v in obj .items ()}
97+ return reviver (loaded_obj )
98+ if isinstance (obj , list ):
99+ return [_load (o ) for o in obj ]
100+ return obj
101+
84102 if isinstance (obj , dict ) and "_type" in obj :
85103 obj_type = obj ["_type" ]
86- # Check if the object requires a custom function to load.
104+ # Check if the object has custom load function .
87105 if obj_type in custom_deserialization :
88106 if valid_namespaces :
89107 kwargs ["valid_namespaces" ] = valid_namespaces
@@ -93,7 +111,7 @@ def load(
93111 # Legacy chain
94112 return load_chain_from_config (obj , ** kwargs )
95113
96- return __lc_load (obj , secrets_map = secrets_map , valid_namespaces = valid_namespaces )
114+ return _load (obj )
97115
98116
99117def load_from_yaml (
@@ -144,11 +162,30 @@ def default(obj: Any) -> Any:
144162 TypeError
145163 If the object is not LangChain serializable.
146164 """
165+ for super_class , save_fn in custom_serialization .items ():
166+ if isinstance (obj , super_class ):
167+ return save_fn (obj )
147168 if isinstance (obj , Serializable ) and obj .is_lc_serializable ():
148169 return obj .to_json ()
149170 raise TypeError (f"Serialization of { type (obj )} is not supported." )
150171
151172
173+ def __save (obj ):
174+ """Calls the legacy save method to save the object to temp json
175+ then load it into a dictionary.
176+ """
177+ try :
178+ temp_file = tempfile .NamedTemporaryFile (
179+ mode = "w" , encoding = "utf-8" , suffix = ".json" , delete = False
180+ )
181+ temp_file .close ()
182+ obj .save (temp_file .name )
183+ with open (temp_file .name , "r" , encoding = "utf-8" ) as f :
184+ return json .load (f )
185+ finally :
186+ os .unlink (temp_file .name )
187+
188+
152189def dump (obj : Any ) -> Dict [str , Any ]:
153190 """Return a json dict representation of an object.
154191
@@ -167,14 +204,14 @@ def dump(obj: Any) -> Dict[str, Any]:
167204 ):
168205 # The object is not is_lc_serializable.
169206 # However, it supports the legacy save() method.
170- try :
171- temp_file = tempfile . NamedTemporaryFile (
172- mode = "w" , encoding = "utf-8" , suffix = ".json" , delete = False
173- )
174- temp_file . close ()
175- obj . save ( temp_file . name )
176- with open ( temp_file . name , "r" , encoding = "utf-8" ) as f :
177- return json . load ( f )
178- finally :
179- os . unlink ( temp_file . name )
180- return json . loads ( json . dumps ( obj , default = default ))
207+ return __save ( obj )
208+ # The object is is_lc_serializable.
209+ # However, some properties may not be serializable
210+ # Here we try to dump the object and fallback to the save() method
211+ # if there is an error.
212+ try :
213+ return json . loads ( json . dumps ( obj , default = default ))
214+ except TypeError as ex :
215+ if isinstance ( obj , Serializable ) and hasattr ( obj , "save" ) :
216+ return __save ( obj )
217+ raise ex
0 commit comments