11import os
2- from unittest import TestCase , mock
2+ from unittest import TestCase , mock , SkipTest
33
44from langchain .llms import Cohere
55from langchain .chains import LLMChain
66from langchain .prompts import PromptTemplate
77from langchain .schema .runnable import RunnableParallel , RunnablePassthrough
88
99from ads .llm .serialize import load , dump
10- from ads .llm import GenerativeAI , ModelDeploymentTGI , GenerativeAIEmbeddings
10+ from ads .llm import (
11+ GenerativeAI ,
12+ GenerativeAIEmbeddings ,
13+ ModelDeploymentTGI ,
14+ ModelDeploymentVLLM ,
15+ )
1116
1217
1318class ChainSerializationTest (TestCase ):
@@ -51,7 +56,7 @@ class ChainSerializationTest(TestCase):
5156 "_type" : "llm_chain" ,
5257 }
5358
54- EXPECTED_LLM_CHAIN_WITH_OCI_GEN_AI = {
59+ EXPECTED_LLM_CHAIN_WITH_OCI_MD = {
5560 "lc" : 1 ,
5661 "type" : "constructor" ,
5762 "id" : ["langchain" , "chains" , "llm" , "LLMChain" ],
@@ -70,12 +75,10 @@ class ChainSerializationTest(TestCase):
7075 "llm" : {
7176 "lc" : 1 ,
7277 "type" : "constructor" ,
73- "id" : ["ads" , "llm" , "GenerativeAI " ],
78+ "id" : ["ads" , "llm" , "ModelDeploymentVLLM " ],
7479 "kwargs" : {
75- "compartment_id" : "<ocid>" ,
76- "client_kwargs" : {
77- "service_endpoint" : "https://endpoint.oraclecloud.com"
78- },
80+ "endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" ,
81+ "model" : "my_model" ,
7982 },
8083 },
8184 },
@@ -166,31 +169,31 @@ def test_llm_chain_serialization_with_cohere(self):
166169 self .assertIsInstance (llm_chain .llm , Cohere )
167170 self .assertEqual (llm_chain .input_keys , ["subject" ])
168171
169- def test_llm_chain_serialization_with_oci_gen_ai (self ):
172+ def test_llm_chain_serialization_with_oci (self ):
170173 """Tests serialization of LLMChain with OCI Gen AI."""
171- llm = GenerativeAI (
172- compartment_id = self .COMPARTMENT_ID ,
173- client_kwargs = self .GEN_AI_KWARGS ,
174- )
174+ llm = ModelDeploymentVLLM (endpoint = self .ENDPOINT , model = "my_model" )
175175 template = PromptTemplate .from_template (self .PROMPT_TEMPLATE )
176176 llm_chain = LLMChain (prompt = template , llm = llm )
177177 serialized = dump (llm_chain )
178- self .assertEqual (serialized , self .EXPECTED_LLM_CHAIN_WITH_OCI_GEN_AI )
178+ self .assertEqual (serialized , self .EXPECTED_LLM_CHAIN_WITH_OCI_MD )
179179 llm_chain = load (serialized )
180180 self .assertIsInstance (llm_chain , LLMChain )
181181 self .assertIsInstance (llm_chain .prompt , PromptTemplate )
182182 self .assertEqual (llm_chain .prompt .template , self .PROMPT_TEMPLATE )
183- self .assertIsInstance (llm_chain .llm , GenerativeAI )
184- self .assertEqual (llm_chain .llm .compartment_id , self .COMPARTMENT_ID )
185- self .assertEqual (llm_chain .llm .client_kwargs , self . GEN_AI_KWARGS )
183+ self .assertIsInstance (llm_chain .llm , ModelDeploymentVLLM )
184+ self .assertEqual (llm_chain .llm .endpoint , self .ENDPOINT )
185+ self .assertEqual (llm_chain .llm .model , "my_model" )
186186 self .assertEqual (llm_chain .input_keys , ["subject" ])
187187
188188 def test_oci_gen_ai_serialization (self ):
189189 """Tests serialization of OCI Gen AI LLM."""
190- llm = GenerativeAI (
191- compartment_id = self .COMPARTMENT_ID ,
192- client_kwargs = self .GEN_AI_KWARGS ,
193- )
190+ try :
191+ llm = GenerativeAI (
192+ compartment_id = self .COMPARTMENT_ID ,
193+ client_kwargs = self .GEN_AI_KWARGS ,
194+ )
195+ except ImportError :
196+ raise SkipTest ("OCI SDK does not support Generative AI." )
194197 serialized = dump (llm )
195198 self .assertEqual (serialized , self .EXPECTED_GEN_AI_LLM )
196199 llm = load (serialized )
@@ -199,9 +202,12 @@ def test_oci_gen_ai_serialization(self):
199202
200203 def test_gen_ai_embeddings_serialization (self ):
201204 """Tests serialization of OCI Gen AI embeddings."""
202- embeddings = GenerativeAIEmbeddings (
203- compartment_id = self .COMPARTMENT_ID , client_kwargs = self .GEN_AI_KWARGS
204- )
205+ try :
206+ embeddings = GenerativeAIEmbeddings (
207+ compartment_id = self .COMPARTMENT_ID , client_kwargs = self .GEN_AI_KWARGS
208+ )
209+ except ImportError :
210+ raise SkipTest ("OCI SDK does not support Generative AI." )
205211 serialized = dump (embeddings )
206212 self .assertEqual (serialized , self .EXPECTED_GEN_AI_EMBEDDINGS )
207213 embeddings = load (serialized )
0 commit comments