1010from importlib import reload
1111from unittest .mock import MagicMock
1212
13+ from mock import patch
1314import oci
1415from parameterized import parameterized
1516
1617import ads .aqua .model
1718import ads .config
1819from ads .aqua .model import AquaModelApp , AquaModelSummary
20+ from ads .model .datascience_model import DataScienceModel
21+ from ads .model .model_metadata import ModelCustomMetadata , ModelProvenanceMetadata , ModelTaxonomyMetadata
1922
2023
2124class TestDataset :
@@ -71,6 +74,8 @@ class TestAquaModel(unittest.TestCase):
7174 """Contains unittests for AquaModelApp."""
7275
7376 def setUp (self ):
77+ import ads
78+ ads .set_auth ("security_token" )
7479 self .app = AquaModelApp ()
7580
7681 @classmethod
@@ -89,11 +94,279 @@ def tearDownClass(cls):
8994 reload (ads .aqua )
9095 reload (ads .aqua .model )
9196
92- def test_create_model (self ):
93- pass
97+ @patch .object (DataScienceModel , "create" )
98+ @patch ("ads.model.datascience_model.validate" )
99+ @patch .object (DataScienceModel , "from_id" )
100+ def test_create_model (self , mock_from_id , mock_validate , mock_create ):
101+ service_model = MagicMock ()
102+ service_model .model_file_description = {"test_key" :"test_value" }
103+ service_model .display_name = "test_display_name"
104+ service_model .description = "test_description"
105+ service_model .freeform_tags = {"test_key" :"test_value" }
106+ custom_metadata_list = ModelCustomMetadata ()
107+ custom_metadata_list .add (
108+ key = "test_metadata_item_key" ,
109+ value = "test_metadata_item_value"
110+ )
111+ service_model .custom_metadata_list = custom_metadata_list
112+ service_model .provenance_metadata = ModelProvenanceMetadata (
113+ training_id = "test_training_id"
114+ )
115+ mock_from_id .return_value = service_model
116+
117+ # will not copy service model
118+ self .app .create (
119+ model_id = "test_model_id" ,
120+ project_id = "test_project_id" ,
121+ compartment_id = "test_compartment_id" ,
122+ )
123+
124+ mock_from_id .assert_called_with ("test_model_id" )
125+ mock_validate .assert_not_called ()
126+ mock_create .assert_not_called ()
127+
128+ service_model .compartment_id = TestDataset .SERVICE_COMPARTMENT_ID
129+ mock_from_id .return_value = service_model
130+
131+ # will copy service model
132+ self .app .create (
133+ model_id = "test_model_id" ,
134+ project_id = "test_project_id" ,
135+ compartment_id = "test_compartment_id"
136+ )
137+
138+ mock_from_id .assert_called_with ("test_model_id" )
139+ mock_validate .assert_called ()
140+ mock_create .assert_called_with (
141+ model_by_reference = True
142+ )
143+
144+ @patch ("ads.aqua.model.read_file" )
145+ @patch .object (DataScienceModel , "from_id" )
146+ def test_get_model_not_fine_tuned (self , mock_from_id , mock_read_file ):
147+ ds_model = MagicMock ()
148+ ds_model .id = "test_id"
149+ ds_model .compartment_id = "test_compartment_id"
150+ ds_model .project_id = "test_project_id"
151+ ds_model .display_name = "test_display_name"
152+ ds_model .description = "test_description"
153+ ds_model .freeform_tags = {
154+ "OCI_AQUA" :"ACTIVE" ,
155+ "license" :"test_license" ,
156+ "organization" :"test_organization" ,
157+ "task" :"test_task"
158+ }
159+ ds_model .time_created = "2024-01-19T17:57:39.158000+00:00"
160+ custom_metadata_list = ModelCustomMetadata ()
161+ custom_metadata_list .add (
162+ key = "artifact_location" ,
163+ value = "oci://bucket@namespace/prefix/"
164+ )
165+ ds_model .custom_metadata_list = custom_metadata_list
166+
167+ mock_from_id .return_value = ds_model
168+ mock_read_file .return_value = "test_model_card"
169+
170+ aqua_model = self .app .get (model_id = "test_model_id" )
171+
172+ mock_from_id .assert_called_with ("test_model_id" )
173+ mock_read_file .assert_called_with (
174+ file_path = "oci://bucket@namespace/prefix/README.md" ,
175+ auth = self .app ._auth ,
176+ )
177+
178+ assert asdict (aqua_model ) == {
179+ 'compartment_id' : f'{ ds_model .compartment_id } ' ,
180+ 'console_link' : (
181+ f'https://cloud.oracle.com/data-science/models/{ ds_model .id } ?region={ self .app .region } ' ,
182+ ),
183+ 'icon' : '' ,
184+ 'id' : f'{ ds_model .id } ' ,
185+ 'is_fine_tuned_model' : False ,
186+ 'license' : f'{ ds_model .freeform_tags ["license" ]} ' ,
187+ 'model_card' : f'{ mock_read_file .return_value } ' ,
188+ 'name' : f'{ ds_model .display_name } ' ,
189+ 'organization' : f'{ ds_model .freeform_tags ["organization" ]} ' ,
190+ 'project_id' : f'{ ds_model .project_id } ' ,
191+ 'ready_to_deploy' : True ,
192+ 'ready_to_finetune' : False ,
193+ 'search_text' : 'ACTIVE,test_license,test_organization,test_task' ,
194+ 'tags' : ds_model .freeform_tags ,
195+ 'task' : f'{ ds_model .freeform_tags ["task" ]} ' ,
196+ 'time_created' : f'{ ds_model .time_created } '
197+ }
198+
199+ @patch ("ads.aqua.utils.query_resource" )
200+ @patch ("ads.aqua.model.read_file" )
201+ @patch .object (DataScienceModel , "from_id" )
202+ def test_get_model_fine_tuned (self , mock_from_id , mock_read_file , mock_query_resource ):
203+ ds_model = MagicMock ()
204+ ds_model .id = "test_id"
205+ ds_model .compartment_id = "test_model_compartment_id"
206+ ds_model .project_id = "test_project_id"
207+ ds_model .display_name = "test_display_name"
208+ ds_model .description = "test_description"
209+ ds_model .model_version_set_id = "test_model_version_set_id"
210+ ds_model .model_version_set_name = "test_model_version_set_name"
211+ ds_model .freeform_tags = {
212+ "OCI_AQUA" :"ACTIVE" ,
213+ "license" :"test_license" ,
214+ "organization" :"test_organization" ,
215+ "task" :"test_task" ,
216+ "aqua_fine_tuned_model" :"test_finetuned_model"
217+ }
218+ ds_model .time_created = "2024-01-19T17:57:39.158000+00:00"
219+ ds_model .lifecycle_state = "ACTIVE"
220+ custom_metadata_list = ModelCustomMetadata ()
221+ custom_metadata_list .add (
222+ key = "artifact_location" ,
223+ value = "oci://bucket@namespace/prefix/"
224+ )
225+ custom_metadata_list .add (
226+ key = "fine_tune_source" ,
227+ value = "test_fine_tuned_source_id"
228+ )
229+ custom_metadata_list .add (
230+ key = "fine_tune_source_name" ,
231+ value = "test_fine_tuned_source_name"
232+ )
233+ ds_model .custom_metadata_list = custom_metadata_list
234+ defined_metadata_list = ModelTaxonomyMetadata ()
235+ defined_metadata_list ["Hyperparameters" ].value = {
236+ "training_data" : "test_training_data" ,
237+ "val_set_size" : "test_val_set_size"
238+ }
239+ ds_model .defined_metadata_list = defined_metadata_list
240+ ds_model .provenance_metadata = ModelProvenanceMetadata (
241+ training_id = "test_training_job_run_id"
242+ )
243+
244+ mock_from_id .return_value = ds_model
245+ mock_read_file .return_value = "test_model_card"
246+
247+ response = MagicMock ()
248+ job_run = MagicMock ()
249+ job_run .id = "test_job_run_id"
250+ job_run .lifecycle_state = "SUCCEEDED"
251+ job_run .lifecycle_details = "test lifecycle details"
252+ job_run .identifier = "test_job_id" ,
253+ job_run .display_name = "test_job_name"
254+ job_run .compartment_id = "test_job_run_compartment_id"
255+ job_infrastructure_configuration_details = MagicMock ()
256+ job_infrastructure_configuration_details .shape_name = "test_shape_name"
257+
258+ job_configuration_override_details = MagicMock ()
259+ job_configuration_override_details .environment_variables = {
260+ "NODE_COUNT" : 1
261+ }
262+ job_run .job_infrastructure_configuration_details = job_infrastructure_configuration_details
263+ job_run .job_configuration_override_details = job_configuration_override_details
264+ log_details = MagicMock ()
265+ log_details .log_id = "test_log_id"
266+ log_details .log_group_id = "test_log_group_id"
267+ job_run .log_details = log_details
268+ response .data = job_run
269+ self .app .ds_client .get_job_run = MagicMock (
270+ return_value = response
271+ )
272+
273+ query_resource = MagicMock ()
274+ query_resource .display_name = "test_display_name"
275+ mock_query_resource .return_value = query_resource
276+
277+ model = self .app .get (model_id = "test_model_id" )
278+
279+ mock_from_id .assert_called_with ("test_model_id" )
280+ mock_read_file .assert_called_with (
281+ file_path = "oci://bucket@namespace/prefix/README.md" ,
282+ auth = self .app ._auth ,
283+ )
284+ mock_query_resource .assert_called ()
285+
286+ assert asdict (model ) == {
287+ 'compartment_id' : f'{ ds_model .compartment_id } ' ,
288+ 'console_link' : (
289+ f'https://cloud.oracle.com/data-science/models/{ ds_model .id } ?region={ self .app .region } ' ,
290+ ),
291+ 'dataset' : 'test_training_data' ,
292+ 'experiment' : {'id' : '' , 'name' : '' , 'url' : '' },
293+ 'icon' : '' ,
294+ 'id' : f'{ ds_model .id } ' ,
295+ 'is_fine_tuned_model' : True ,
296+ 'job' : {'id' : '' , 'name' : '' , 'url' : '' },
297+ 'license' : 'test_license' ,
298+ 'lifecycle_details' : f'{ job_run .lifecycle_details } ' ,
299+ 'lifecycle_state' : f'{ ds_model .lifecycle_state } ' ,
300+ 'log' : {
301+ 'id' : f'{ log_details .log_id } ' ,
302+ 'name' : f'{ query_resource .display_name } ' ,
303+ 'url' : 'https://cloud.oracle.com/logging/search?searchQuery=search '
304+ f'"{ job_run .compartment_id } /{ log_details .log_group_id } /{ log_details .log_id } " | '
305+ f"source='{ job_run .id } ' | sort by datetime desc®ions={ self .app .region } "
306+ },
307+ 'log_group' : {
308+ 'id' : f'{ log_details .log_group_id } ' ,
309+ 'name' : f'{ query_resource .display_name } ' ,
310+ 'url' : f'https://cloud.oracle.com/logging/log-groups/{ log_details .log_group_id } ?region={ self .app .region } '
311+ },
312+ 'metrics' : [
313+ {
314+ 'category' : 'validation' ,
315+ 'name' : 'validation_metrics' ,
316+ 'scores' : []
317+ },
318+ {
319+ 'category' : 'training' ,
320+ 'name' : 'training_metrics' ,
321+ 'scores' : []
322+ },
323+ {
324+ 'category' : 'validation' ,
325+ 'name' : 'validation_metrics_final' ,
326+ 'scores' : []
327+ },
328+ {
329+ 'category' : 'training' ,
330+ 'name' : 'training_metrics_final' ,
331+ 'scores' : []
332+ }
333+ ],
334+ 'model_card' : f'{ mock_read_file .return_value } ' ,
335+ 'name' : f'{ ds_model .display_name } ' ,
336+ 'organization' : 'test_organization' ,
337+ 'project_id' : f'{ ds_model .project_id } ' ,
338+ 'ready_to_deploy' : True ,
339+ 'ready_to_finetune' : False ,
340+ 'search_text' : 'ACTIVE,test_license,test_organization,test_task,test_finetuned_model' ,
341+ 'shape_info' : {
342+ 'instance_shape' : f'{ job_infrastructure_configuration_details .shape_name } ' ,
343+ 'replica' : 1 ,
344+ },
345+ 'source' : {'id' : '' , 'name' : '' , 'url' : '' },
346+ 'tags' : ds_model .freeform_tags ,
347+ 'task' : 'test_task' ,
348+ 'time_created' : f'{ ds_model .time_created } ' ,
349+ 'validation' : {
350+ 'type' : 'Automatic split' ,
351+ 'value' : 'test_val_set_size'
352+ }
353+ }
354+
355+ @patch ("ads.aqua.model.read_file" )
356+ @patch ("ads.aqua.model.get_artifact_path" )
357+ def test_load_license (self , mock_get_artifact_path , mock_read_file ):
358+ self .app .ds_client .get_model = MagicMock ()
359+ mock_get_artifact_path .return_value = "oci://bucket@namespace/prefix/config/LICENSE.txt"
360+ mock_read_file .return_value = "test_license"
361+
362+ license = self .app .load_license (model_id = "test_model_id" )
363+
364+ mock_get_artifact_path .assert_called ()
365+ mock_read_file .assert_called ()
94366
95- def test_get_model (self ):
96- pass
367+ assert asdict (license ) == {
368+ 'id' : 'test_model_id' , 'license' : 'test_license'
369+ }
97370
98371 def test_list_service_models (self ):
99372 """Tests listing service models succesfully."""
0 commit comments