1414from langchain import llms
1515from langchain .llms import loading
1616from langchain .chains .loading import load_chain_from_config
17- from langchain .load .load import load as __lc_load
17+ from langchain .load .load import Reviver , load as __lc_load
1818from langchain .load .serializable import Serializable
1919
2020from ads .common .auth import default_signer
@@ -76,14 +76,32 @@ def load(
7676 Returns:
7777 Revived LangChain objects.
7878 """
79+ # Add ADS as valid namespace
7980 if not valid_namespaces :
8081 valid_namespaces = []
8182 if "ads" not in valid_namespaces :
8283 valid_namespaces .append ("ads" )
8384
85+ reviver = Reviver (secrets_map , valid_namespaces )
86+
87+ def _load (obj : Any ) -> Any :
88+ if isinstance (obj , dict ):
89+ if "_type" in obj and obj ["_type" ] in custom_deserialization :
90+ if valid_namespaces :
91+ kwargs ["valid_namespaces" ] = valid_namespaces
92+ if secrets_map :
93+ kwargs ["secret_map" ] = secrets_map
94+ return custom_deserialization [obj ["_type" ]](obj , ** kwargs )
95+ # Need to revive leaf nodes before reviving this node
96+ loaded_obj = {k : _load (v ) for k , v in obj .items ()}
97+ return reviver (loaded_obj )
98+ if isinstance (obj , list ):
99+ return [_load (o ) for o in obj ]
100+ return obj
101+
84102 if isinstance (obj , dict ) and "_type" in obj :
85103 obj_type = obj ["_type" ]
86- # Check if the object requires a custom function to load.
104+ # Check if the object has custom load function .
87105 if obj_type in custom_deserialization :
88106 if valid_namespaces :
89107 kwargs ["valid_namespaces" ] = valid_namespaces
@@ -93,7 +111,7 @@ def load(
93111 # Legacy chain
94112 return load_chain_from_config (obj , ** kwargs )
95113
96- return __lc_load (obj , secrets_map = secrets_map , valid_namespaces = valid_namespaces )
114+ return _load (obj )
97115
98116
99117def load_from_yaml (
@@ -144,6 +162,9 @@ def default(obj: Any) -> Any:
144162 TypeError
145163 If the object is not LangChain serializable.
146164 """
165+ for super_class , save_fn in custom_serialization .items ():
166+ if isinstance (obj , super_class ):
167+ return save_fn (obj )
147168 if isinstance (obj , Serializable ) and obj .is_lc_serializable ():
148169 return obj .to_json ()
149170 raise TypeError (f"Serialization of { type (obj )} is not supported." )
0 commit comments