@@ -40,39 +40,6 @@ def setUp(self) -> None:
4040 GEN_AI_KWARGS = {"service_endpoint" : "https://endpoint.oraclecloud.com" }
4141 ENDPOINT = "https://modeldeployment.customer-oci.com/ocid/predict"
4242
43- EXPECTED_LLM_CHAIN_WITH_COHERE = {
44- "memory" : None ,
45- "verbose" : True ,
46- "tags" : None ,
47- "metadata" : None ,
48- "prompt" : {
49- "input_variables" : ["subject" ],
50- "input_types" : {},
51- "output_parser" : None ,
52- "partial_variables" : {},
53- "template" : "Tell me a joke about {subject}" ,
54- "template_format" : "f-string" ,
55- "validate_template" : False ,
56- "_type" : "prompt" ,
57- },
58- "llm" : {
59- "model" : None ,
60- "max_tokens" : 256 ,
61- "temperature" : 0.75 ,
62- "k" : 0 ,
63- "p" : 1 ,
64- "frequency_penalty" : 0.0 ,
65- "presence_penalty" : 0.0 ,
66- "truncate" : None ,
67- "_type" : "cohere" ,
68- },
69- "output_key" : "text" ,
70- "output_parser" : {"_type" : "default" },
71- "return_final_only" : True ,
72- "llm_kwargs" : {},
73- "_type" : "llm_chain" ,
74- }
75-
7643 EXPECTED_LLM_CHAIN_WITH_OCI_MD = {
7744 "lc" : 1 ,
7845 "type" : "constructor" ,
@@ -173,7 +140,23 @@ def test_llm_chain_serialization_with_cohere(self):
173140 template = PromptTemplate .from_template (self .PROMPT_TEMPLATE )
174141 llm_chain = LLMChain (prompt = template , llm = llm , verbose = True )
175142 serialized = dump (llm_chain )
176- self .assertEqual (serialized , self .EXPECTED_LLM_CHAIN_WITH_COHERE )
143+
144+ # Check the serialized chain
145+ self .assertTrue (serialized .get ("verbose" ))
146+ self .assertEqual (serialized .get ("_type" ), "llm_chain" )
147+
148+ # Check the serialized prompt template
149+ serialized_prompt = serialized .get ("prompt" )
150+ self .assertIsInstance (serialized_prompt , dict )
151+ self .assertEqual (serialized_prompt .get ("_type" ), "prompt" )
152+ self .assertEqual (set (serialized_prompt .get ("input_variables" )), {"subject" })
153+ self .assertEqual (serialized_prompt .get ("template" ), self .PROMPT_TEMPLATE )
154+
155+ # Check the serialized LLM
156+ serialized_llm = serialized .get ("llm" )
157+ self .assertIsInstance (serialized_llm , dict )
158+ self .assertEqual (serialized_llm .get ("_type" ), "cohere" )
159+
177160 llm_chain = load (serialized )
178161 self .assertIsInstance (llm_chain , LLMChain )
179162 self .assertIsInstance (llm_chain .prompt , PromptTemplate )
@@ -237,21 +220,37 @@ def test_runnable_sequence_serialization(self):
237220
238221 chain = map_input | template | llm
239222 serialized = dump (chain )
240- # Do not check the ID fields.
241- expected = deepcopy (self .EXPECTED_RUNNABLE_SEQUENCE )
242- expected ["id" ] = serialized ["id" ]
243- expected ["kwargs" ]["first" ]["id" ] = serialized ["kwargs" ]["first" ]["id" ]
244- expected ["kwargs" ]["first" ]["kwargs" ]["steps" ]["text" ]["id" ] = serialized [
245- "kwargs"
246- ]["first" ]["kwargs" ]["steps" ]["text" ]["id" ]
247- expected ["kwargs" ]["middle" ][0 ]["id" ] = serialized ["kwargs" ]["middle" ][0 ]["id" ]
248- self .assertEqual (serialized , expected )
223+
224+ self .assertEqual (serialized .get ("type" ), "constructor" )
225+ self .assertNotIn ("_type" , serialized )
226+
227+ kwargs = serialized .get ("kwargs" )
228+ self .assertIsInstance (kwargs , dict )
229+
230+ element_1 = kwargs .get ("first" )
231+ self .assertEqual (element_1 .get ("_type" ), "RunnableParallel" )
232+ step = element_1 .get ("kwargs" ).get ("steps" ).get ("text" )
233+ self .assertEqual (step .get ("id" )[- 1 ], "RunnablePassthrough" )
234+
235+ element_2 = kwargs .get ("middle" )[0 ]
236+ self .assertNotIn ("_type" , element_2 )
237+ self .assertEqual (element_2 .get ("kwargs" ).get ("template" ), self .PROMPT_TEMPLATE )
238+ self .assertEqual (element_2 .get ("kwargs" ).get ("input_variables" ), ["subject" ])
239+
240+ element_3 = kwargs .get ("last" )
241+ self .assertNotIn ("_type" , element_3 )
242+ self .assertEqual (element_3 .get ("id" ), ["ads" , "llm" , "ModelDeploymentTGI" ])
243+ self .assertEqual (
244+ element_3 .get ("kwargs" ),
245+ {"endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" },
246+ )
247+
249248 chain = load (serialized )
250249 self .assertEqual (len (chain .steps ), 3 )
251250 self .assertIsInstance (chain .steps [0 ], RunnableParallel )
252251 self .assertEqual (
253- chain .steps [0 ].dict (),
254- { "steps" : { " text": { "input_type" : None , "func" : None , "afunc" : None }}} ,
252+ list ( chain .steps [0 ].dict (). get ( "steps" ). keys () ),
253+ [ " text"] ,
255254 )
256255 self .assertIsInstance (chain .steps [1 ], PromptTemplate )
257256 self .assertIsInstance (chain .steps [2 ], ModelDeploymentTGI )
0 commit comments