66
77
88import os
9+ from copy import deepcopy
910from unittest import TestCase , mock , SkipTest
1011
1112from langchain .llms import Cohere
2526class ChainSerializationTest (TestCase ):
2627 """Contains tests for chain serialization."""
2728
29+ # LangChain is updating frequently on the module organization,
30+ # mainly affecting the id field of the serialization.
31+ # In the test, we will not check the id field of some components.
32+ # We expect users to use the same LangChain version for serialize and de-serialize
33+
2834 def setUp (self ) -> None :
2935 self .maxDiff = None
3036 return super ().setUp ()
@@ -75,7 +81,6 @@ def setUp(self) -> None:
7581 "prompt" : {
7682 "lc" : 1 ,
7783 "type" : "constructor" ,
78- "id" : ["langchain_core" , "prompts" , "prompt" , "PromptTemplate" ],
7984 "kwargs" : {
8085 "input_variables" : ["subject" ],
8186 "template" : "Tell me a joke about {subject}" ,
@@ -118,12 +123,10 @@ def setUp(self) -> None:
118123 EXPECTED_RUNNABLE_SEQUENCE = {
119124 "lc" : 1 ,
120125 "type" : "constructor" ,
121- "id" : ["langchain_core" , "runnables" , "RunnableSequence" ],
122126 "kwargs" : {
123127 "first" : {
124128 "lc" : 1 ,
125129 "type" : "constructor" ,
126- "id" : ["langchain_core" , "runnables" , "RunnableParallel" ],
127130 "kwargs" : {
128131 "steps" : {
129132 "text" : {
@@ -144,7 +147,6 @@ def setUp(self) -> None:
144147 {
145148 "lc" : 1 ,
146149 "type" : "constructor" ,
147- "id" : ["langchain_core" , "prompts" , "prompt" , "PromptTemplate" ],
148150 "kwargs" : {
149151 "input_variables" : ["subject" ],
150152 "template" : "Tell me a joke about {subject}" ,
@@ -185,7 +187,10 @@ def test_llm_chain_serialization_with_oci(self):
185187 template = PromptTemplate .from_template (self .PROMPT_TEMPLATE )
186188 llm_chain = LLMChain (prompt = template , llm = llm )
187189 serialized = dump (llm_chain )
188- self .assertEqual (serialized , self .EXPECTED_LLM_CHAIN_WITH_OCI_MD )
190+ # Do not check the ID field.
191+ expected = deepcopy (self .EXPECTED_LLM_CHAIN_WITH_OCI_MD )
192+ expected ["kwargs" ]["prompt" ]["id" ] = serialized ["kwargs" ]["prompt" ]["id" ]
193+ self .assertEqual (serialized , expected )
189194 llm_chain = load (serialized )
190195 self .assertIsInstance (llm_chain , LLMChain )
191196 self .assertIsInstance (llm_chain .prompt , PromptTemplate )
@@ -202,8 +207,8 @@ def test_oci_gen_ai_serialization(self):
202207 compartment_id = self .COMPARTMENT_ID ,
203208 client_kwargs = self .GEN_AI_KWARGS ,
204209 )
205- except ImportError :
206- raise SkipTest ("OCI SDK does not support Generative AI." )
210+ except ImportError as ex :
211+ raise SkipTest ("OCI SDK does not support Generative AI." ) from ex
207212 serialized = dump (llm )
208213 self .assertEqual (serialized , self .EXPECTED_GEN_AI_LLM )
209214 llm = load (serialized )
@@ -216,8 +221,8 @@ def test_gen_ai_embeddings_serialization(self):
216221 embeddings = GenerativeAIEmbeddings (
217222 compartment_id = self .COMPARTMENT_ID , client_kwargs = self .GEN_AI_KWARGS
218223 )
219- except ImportError :
220- raise SkipTest ("OCI SDK does not support Generative AI." )
224+ except ImportError as ex :
225+ raise SkipTest ("OCI SDK does not support Generative AI." ) from ex
221226 serialized = dump (embeddings )
222227 self .assertEqual (serialized , self .EXPECTED_GEN_AI_EMBEDDINGS )
223228 embeddings = load (serialized )
@@ -232,7 +237,15 @@ def test_runnable_sequence_serialization(self):
232237
233238 chain = map_input | template | llm
234239 serialized = dump (chain )
235- self .assertEqual (serialized , self .EXPECTED_RUNNABLE_SEQUENCE )
240+ # Do not check the ID fields.
241+ expected = deepcopy (self .EXPECTED_RUNNABLE_SEQUENCE )
242+ expected ["id" ] = serialized ["id" ]
243+ expected ["kwargs" ]["first" ]["id" ] = serialized ["kwargs" ]["first" ]["id" ]
244+ expected ["kwargs" ]["first" ]["kwargs" ]["steps" ]["text" ]["id" ] = serialized [
245+ "kwargs"
246+ ]["first" ]["kwargs" ]["steps" ]["text" ]["id" ]
247+ expected ["kwargs" ]["middle" ][0 ]["id" ] = serialized ["kwargs" ]["middle" ][0 ]["id" ]
248+ self .assertEqual (serialized , expected )
236249 chain = load (serialized )
237250 self .assertEqual (len (chain .steps ), 3 )
238251 self .assertIsInstance (chain .steps [0 ], RunnableParallel )
0 commit comments