|
1 | 1 | #!/usr/bin/env python |
2 | 2 | # -*- coding: utf-8 -*-- |
3 | | - |
| 3 | +import json |
4 | 4 | # Copyright (c) 2024 Oracle and/or its affiliates. |
5 | 5 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
6 | 6 |
|
@@ -37,6 +37,21 @@ def mock_auth(): |
37 | 37 | yield mock_default_signer |
38 | 38 |
|
39 | 39 |
|
| 40 | +@pytest.fixture(autouse=True, scope="class") |
| 41 | +def mock_get_container_config(): |
| 42 | + with patch("ads.aqua.ui.get_container_config") as mock_config: |
| 43 | + with open( |
| 44 | + os.path.join( |
| 45 | + os.path.dirname(os.path.abspath(__file__)), |
| 46 | + "test_data/ui/container_index.json", |
| 47 | + ), |
| 48 | + "r", |
| 49 | + ) as _file: |
| 50 | + container_index_json = json.load(_file) |
| 51 | + mock_config.return_value = container_index_json |
| 52 | + yield mock_config |
| 53 | + |
| 54 | + |
40 | 55 | @pytest.fixture(autouse=True, scope="class") |
41 | 56 | def mock_init_client(): |
42 | 57 | with patch( |
@@ -256,6 +271,7 @@ def test_get_foundation_models( |
256 | 271 | mock_from_id, |
257 | 272 | mock_read_file, |
258 | 273 | foundation_model_type, |
| 274 | + mock_get_container_config, |
259 | 275 | mock_auth, |
260 | 276 | ): |
261 | 277 | ds_model = MagicMock() |
@@ -334,7 +350,7 @@ def test_get_foundation_models( |
334 | 350 | "model_card": f"{mock_read_file.return_value}", |
335 | 351 | "model_format": ModelFormat.SAFETENSORS, |
336 | 352 | "name": f"{ds_model.display_name}", |
337 | | - "nvidia_gpu_supported": False, |
| 353 | + "nvidia_gpu_supported": True, |
338 | 354 | "organization": f'{ds_model.freeform_tags["organization"]}', |
339 | 355 | "project_id": f"{ds_model.project_id}", |
340 | 356 | "ready_to_deploy": False if foundation_model_type == "verified" else True, |
@@ -366,6 +382,7 @@ def test_get_model_fine_tuned( |
366 | 382 | mock_from_id, |
367 | 383 | mock_read_file, |
368 | 384 | mock_query_resource, |
| 385 | + mock_get_container_config, |
369 | 386 | mock_auth, |
370 | 387 | ): |
371 | 388 | ds_model = MagicMock() |
@@ -507,7 +524,7 @@ def test_get_model_fine_tuned( |
507 | 524 | "model_card": f"{mock_read_file.return_value}", |
508 | 525 | "model_format": ModelFormat.SAFETENSORS, |
509 | 526 | "name": f"{ds_model.display_name}", |
510 | | - "nvidia_gpu_supported": False, |
| 527 | + "nvidia_gpu_supported": True, |
511 | 528 | "organization": "test_organization", |
512 | 529 | "project_id": f"{ds_model.project_id}", |
513 | 530 | "ready_to_deploy": True, |
@@ -709,12 +726,16 @@ def test_import_model_with_project_compartment_override(self, mock_load_config): |
709 | 726 | assert model.project_id == project_override |
710 | 727 |
|
711 | 728 | @patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError) |
712 | | - def test_import_model_with_missing_config(self, mock_load_config): |
| 729 | + def test_import_model_with_missing_config( |
| 730 | + self, mock_load_config, mock_get_container_config |
| 731 | + ): |
713 | 732 | """Test for validating if error is returned when model artifacts are incomplete or not available.""" |
| 733 | + |
714 | 734 | os_path = "oci://aqua-bkt@aqua-ns/prefix/path" |
715 | 735 | model_name = "oracle/aqua-1t-mega-model" |
716 | 736 | reload(ads.aqua.model.model) |
717 | 737 | app = AquaModelApp() |
| 738 | + app.list_resource = MagicMock(return_value=[]) |
718 | 739 | with pytest.raises(AquaRuntimeError): |
719 | 740 | model: AquaModel = app.register( |
720 | 741 | model=model_name, |
|
0 commit comments