Skip to content

Commit e5ee1bb

Browse files
Addressing review comments
1 parent a80395a commit e5ee1bb

File tree

7 files changed

+30
-38
lines changed

7 files changed

+30
-38
lines changed

ads/aqua/config/container_config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,13 @@ def from_service_config(
203203
{
204204
"PORT": container.workload_configuration_details_list[
205205
0
206-
].additional_configurations.get("PORT", ""),
206+
].additional_configurations.get("PORT", "")
207+
},
208+
{
207209
"HEALTH_CHECK_PORT": container.workload_configuration_details_list[
208210
0
209211
].additional_configurations.get("HEALTH_CHECK_PORT", UNKNOWN),
210212
},
211-
{},
212213
]
213214
container_spec = AquaContainerConfigSpec(
214215
cli_param=container.workload_configuration_details_list[0].cmd,

ads/aqua/extension/common_ws_msg_handler.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55

66
import json
77
from importlib import metadata
8-
from typing import List, Union
8+
from typing import List, Optional, Union
99

1010
from ads.aqua.common.decorator import handle_exceptions
1111
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
1212
from ads.aqua.extension.models.ws_models import (
1313
AdsVersionResponse,
14-
CompatibilityCheckResponse,
1514
RequestResponseType,
1615
)
1716

@@ -25,7 +24,7 @@ def __init__(self, message: Union[str, bytes]):
2524
super().__init__(message)
2625

2726
@handle_exceptions
28-
def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
27+
def process(self) -> Optional[AdsVersionResponse]:
2928
request = json.loads(self.message)
3029
if request.get("kind") == "AdsVersion":
3130
version = metadata.version("oracle_ads")
@@ -35,9 +34,3 @@ def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
3534
data=version,
3635
)
3736
return response
38-
if request.get("kind") == "CompatibilityCheck":
39-
return CompatibilityCheckResponse(
40-
message_id=request.get("message_id"),
41-
kind=RequestResponseType.CompatibilityCheck,
42-
data={"status": "ok"},
43-
)

ads/aqua/extension/model_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ads.aqua.extension.errors import Errors
1616
from ads.aqua.model import AquaModelApp
1717
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
18-
from ads.config import USER
18+
from ads.config import SERVICE
1919
from ads.model.common.utils import MetadataArtifactPathType
2020

2121

@@ -82,7 +82,7 @@ def list(self):
8282
# project_id is no needed.
8383
project_id = self.get_argument("project_id", default=None)
8484
model_type = self.get_argument("model_type", default=None)
85-
category = self.get_argument("category", default=USER)
85+
category = self.get_argument("category", default=SERVICE)
8686
return self.finish(
8787
AquaModelApp().list(
8888
compartment_id=compartment_id,

ads/aqua/model/model.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def create(
185185
target_compartment = compartment_id or COMPARTMENT_OCID
186186

187187
# Skip model copying if it is registered model
188-
if service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM) is not None:
188+
if service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None:
189189
logger.info(
190190
f"Aqua Model {model_id} already exists in the user's compartment."
191191
"Skipped copying."
@@ -919,12 +919,13 @@ def _process_model(
919919
def list(
920920
self,
921921
compartment_id: str = None,
922+
category: str = None,
922923
project_id: str = None,
923924
model_type: str = None,
924925
**kwargs,
925926
) -> List["AquaModelSummary"]:
926927
"""Lists all Aqua models within a specified compartment and/or project.
927-
If `compartment_id` is not specified, the method defaults to returning
928+
If `category` is not specified, the method defaults to returning
928929
the service models within the pre-configured default compartment. By default, the list
929930
of models in the service compartment are cached. Use clear_model_list_cache() to invalidate
930931
the cache.
@@ -933,6 +934,8 @@ def list(
933934
----------
934935
compartment_id: (str, optional). Defaults to `None`.
935936
The compartment OCID.
937+
category: (str,optional). Defaults to `SERVICE`
938+
The category of the models to fetch. Can be either `USER` or `SERVICE`
936939
project_id: (str, optional). Defaults to `None`.
937940
The project OCID.
938941
model_type: (str, optional). Defaults to `None`.
@@ -946,9 +949,9 @@ def list(
946949
The list of the `ads.aqua.model.AquaModelSummary`.
947950
"""
948951

949-
models = []
950-
category = kwargs.pop("category", USER)
951-
if compartment_id and category != SERVICE:
952+
category = category or kwargs.pop("category", SERVICE)
953+
compartment_id = compartment_id or COMPARTMENT_OCID
954+
if category == USER:
952955
# tracks number of times custom model listing was called
953956
self.telemetry.record_event_async(
954957
category="aqua/custom/model", action="list"
@@ -957,33 +960,32 @@ def list(
957960
logger.info(f"Fetching custom models from compartment_id={compartment_id}.")
958961
model_type = model_type.upper() if model_type else ModelType.FT
959962
models = self._rqs(compartment_id, model_type=model_type)
963+
logger.info(
964+
f"Fetched {len(models)} models from {compartment_id or COMPARTMENT_OCID}."
965+
)
960966
else:
961967
# tracks number of times service model listing was called
962968
self.telemetry.record_event_async(
963969
category="aqua/service/model", action="list"
964970
)
965971

966972
if AQUA_SERVICE_MODELS in self._service_models_cache:
967-
logger.info(
968-
f"Returning service models list in {AQUA_SERVICE_MODELS} from cache."
969-
)
973+
logger.info("Returning service models list from cache.")
970974
return self._service_models_cache.get(AQUA_SERVICE_MODELS)
971-
logger.info("Fetching service models.")
975+
logger.info("Fetching service models from cache.")
972976
lifecycle_state = kwargs.pop(
973977
"lifecycle_state", Model.LIFECYCLE_STATE_ACTIVE
974978
)
975979

976980
models = self.list_resource(
977981
self.ds_client.list_models,
978-
compartment_id=compartment_id or COMPARTMENT_OCID,
982+
compartment_id=compartment_id,
979983
lifecycle_state=lifecycle_state,
980984
category=category,
981985
**kwargs,
982986
)
987+
logger.info(f"Fetched {len(models)} service models.")
983988

984-
logger.info(
985-
f"Fetched {len(models)} models from {AQUA_SERVICE_MODELS if category==SERVICE else compartment_id}."
986-
)
987989
aqua_models = []
988990
inference_containers = self.get_container_config().to_dict().get("inference")
989991
for model in models:
@@ -997,7 +999,6 @@ def list(
997999
project_id=project_id or UNKNOWN,
9981000
)
9991001
)
1000-
10011002
if category == SERVICE:
10021003
self._service_models_cache.__setitem__(
10031004
key=AQUA_SERVICE_MODELS, value=aqua_models

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1483,7 +1483,7 @@ def test_list_custom_models(self, mock_get_container_config):
14831483
]
14841484
)
14851485

1486-
results = self.app.list(TestDataset.COMPARTMENT_ID)
1486+
results = self.app.list(TestDataset.COMPARTMENT_ID, category=ads.config.USER)
14871487

14881488
self.app._rqs.assert_called_with(TestDataset.COMPARTMENT_ID, model_type="FT")
14891489

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828
from ads.aqua.model import AquaModelApp
2929
from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary
30-
from ads.config import USER
30+
from ads.config import USER, SERVICE
3131

3232

3333
class ModelHandlerTestCase(TestCase):
@@ -134,7 +134,7 @@ def test_list(self, mock_list):
134134
mock_finish.side_effect = lambda x: x
135135
self.model_handler.list()
136136
mock_list.assert_called_with(
137-
compartment_id=None, project_id=None, model_type=None, category=USER
137+
compartment_id=None, project_id=None, model_type=None, category=SERVICE
138138
)
139139

140140
@parameterized.expand(

tests/unitary/with_extras/aqua/test_ui.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,6 @@ def test_list_containers(self, mock_list_service_containers):
541541
mock_list_service_containers.return_value = TestDataset.CONTAINERS_LIST
542542

543543
test_result = self.app.list_containers()
544-
print("test_result: ", test_result)
545544
expected_result = {
546545
"evaluate": [
547546
{
@@ -566,13 +565,11 @@ def test_list_containers(self, mock_list_service_containers):
566565
"spec": {
567566
"cli_param": "--served-model-name odsc-llm --disable-custom-all-reduce --seed 42 ",
568567
"env_vars": [
569-
{
570-
"HEALTH_CHECK_PORT": "8080",
571-
"MODEL_DEPLOY_ENABLE_STREAMING": "true",
572-
"MODEL_DEPLOY_HEALTH_ENDPOINT": "",
573-
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions",
574-
"PORT": "8080",
575-
}
568+
{"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions"},
569+
{"MODEL_DEPLOY_HEALTH_ENDPOINT": ""},
570+
{"MODEL_DEPLOY_ENABLE_STREAMING": "true"},
571+
{"PORT": "8080"},
572+
{"HEALTH_CHECK_PORT": "8080"},
576573
],
577574
"health_check_port": "8080",
578575
"restricted_params": [

0 commit comments

Comments
 (0)