Skip to content

Commit 2e3049d

Browse files
model list and eval details changes
1 parent 4f1b45e commit 2e3049d

File tree

7 files changed

+196
-49
lines changed

7 files changed

+196
-49
lines changed

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,8 @@ clean:
1111
@find ./ -name 'Thumbs.db' -exec rm -f {} \;
1212
@find ./ -name '*~' -exec rm -f {} \;
1313
@find ./ -name '.DS_Store' -exec rm -f {} \;
14+
15+
aqua.test:
16+
pip install -e .
17+
jupyter server extension enable --py ads.aqua.extension
18+
jupyter lab --NotebookApp.disable_check_xsrf=True --no-browser
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import json
2+
from typing import List, Union
3+
4+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
5+
from ads.aqua.common.decorator import handle_exceptions
6+
from ads.aqua.common.errors import AquaResourceAccessError
7+
from ads.aqua.common.utils import known_realm
8+
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
9+
from ads.aqua.extension.models.ws_models import RequestResponseType, AdsVersionResponse, AdsVersionRequest, \
10+
CompatibilityCheckResponse
11+
from importlib import metadata
12+
13+
14+
class AquaCommonWsMsgHandler(AquaWSMsgHandler):
15+
16+
@staticmethod
17+
def get_message_types() -> List[RequestResponseType]:
18+
return [RequestResponseType.AdsVersion, RequestResponseType.CompatibilityCheck]
19+
20+
def __init__(self, message: Union[str, bytes]):
21+
super().__init__(message)
22+
23+
@handle_exceptions
24+
def process(self) -> AdsVersionResponse | CompatibilityCheckResponse:
25+
request = json.loads(self.message)
26+
print("request: {}".format(request))
27+
if request.get('kind') == 'AdsVersion':
28+
version = metadata.version("oracle_ads")
29+
response = AdsVersionResponse(
30+
message_id=request.get("message_id"),
31+
kind=RequestResponseType.AdsVersion,
32+
data=version)
33+
return response
34+
if request.get('kind') == 'CompatibilityCheck':
35+
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
36+
return CompatibilityCheckResponse(message_id=request.get("message_id"),
37+
kind=RequestResponseType.CompatibilityCheck,
38+
data={'status': 'ok'})
39+
elif known_realm():
40+
return CompatibilityCheckResponse(message_id=request.get("message_id"),
41+
kind=RequestResponseType.CompatibilityCheck,
42+
data={'status': 'compatible'})
43+
else:
44+
raise AquaResourceAccessError(
45+
f"The AI Quick actions extension is not compatible in the given region."
46+
)
Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import json
12
from typing import List, Union
23

34
from ads.aqua.common.decorator import handle_exceptions
45
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
5-
from ads.aqua.extension.models.ws_models import RequestResponseType, ListDeploymentResponse, ListDeploymentRequest
6+
from ads.aqua.extension.models.ws_models import RequestResponseType, ListDeploymentResponse, ListDeploymentRequest, \
7+
ModelDeploymentDetailsResponse
68
from ads.aqua.modeldeployment import AquaDeploymentApp
79
from ads.config import COMPARTMENT_OCID
810

@@ -14,18 +16,25 @@ def __init__(self, message: Union[str, bytes]):
1416

1517
@staticmethod
1618
def get_message_types() -> List[RequestResponseType]:
17-
return [RequestResponseType.ListDeployments]
19+
return [RequestResponseType.ListDeployments, RequestResponseType.DeploymentDetails]
1820

1921
@handle_exceptions
20-
def process(self) -> ListDeploymentResponse:
21-
list_deployment_request = ListDeploymentRequest.from_json(self.message)
22-
deployment_list = AquaDeploymentApp().list(
23-
compartment_id=list_deployment_request.compartment_id or COMPARTMENT_OCID,
24-
project_id=list_deployment_request.project_id,
25-
)
26-
response = ListDeploymentResponse(
27-
message_id=list_deployment_request.message_id,
28-
kind=RequestResponseType.ListDeployments,
29-
data=deployment_list,
30-
)
31-
return response
22+
def process(self) -> ListDeploymentResponse | ModelDeploymentDetailsResponse:
23+
request = json.loads(self.message)
24+
if request.get("kind") == "ListDeployments":
25+
deployment_list = AquaDeploymentApp().list(
26+
compartment_id=request.get("compartment_id") or COMPARTMENT_OCID,
27+
project_id=request.get("project_id"),
28+
)
29+
response = ListDeploymentResponse(
30+
message_id=request.get("message_id"),
31+
kind=RequestResponseType.ListDeployments,
32+
data=deployment_list,
33+
)
34+
return response
35+
elif request.get("kind") == "DeploymentDetails":
36+
deployment_details = AquaDeploymentApp().get(request.get("model_deployment_id"))
37+
response = ModelDeploymentDetailsResponse(message_id=request.get("message_id"),
38+
kind=RequestResponseType.DeploymentDetails,
39+
data=deployment_details)
40+
return response
Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,57 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
3-
3+
import json
44
# Copyright (c) 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from typing import List, Union
88

9-
from tornado.web import HTTPError
10-
119
from ads.aqua.common.decorator import handle_exceptions
1210
from ads.aqua.evaluation import AquaEvaluationApp
1311
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
1412
from ads.aqua.extension.models.ws_models import (
15-
ListEvaluationsRequest,
1613
ListEvaluationsResponse,
17-
RequestResponseType,
14+
RequestResponseType, EvaluationDetailsResponse,
1815
)
1916
from ads.config import COMPARTMENT_OCID
2017

2118

2219
class AquaEvaluationWSMsgHandler(AquaWSMsgHandler):
2320
@staticmethod
2421
def get_message_types() -> List[RequestResponseType]:
25-
return [RequestResponseType.ListEvaluations]
22+
return [RequestResponseType.ListEvaluations, RequestResponseType.EvaluationDetails]
2623

2724
def __init__(self, message: Union[str, bytes]):
2825
super().__init__(message)
2926

3027
@handle_exceptions
31-
def process(self) -> ListEvaluationsResponse:
32-
list_eval_request = ListEvaluationsRequest.from_json(self.message)
28+
def process(self) -> ListEvaluationsResponse | EvaluationDetailsResponse:
29+
request = json.loads(self.message)
30+
if request['kind'] == "ListEvaluations":
31+
return self.list_evaluations(request)
32+
if request["kind"] == "EvaluationDetails":
33+
return self.evaluation_details(request)
34+
35+
@staticmethod
36+
def list_evaluations(request) -> ListEvaluationsResponse:
3337

3438
eval_list = AquaEvaluationApp().list(
35-
list_eval_request.compartment_id or COMPARTMENT_OCID,
36-
list_eval_request.project_id,
39+
request.get("compartment_id") or COMPARTMENT_OCID,
40+
request.get("project_id"),
3741
)
3842
response = ListEvaluationsResponse(
39-
message_id=list_eval_request.message_id,
43+
message_id=request["message_id"],
4044
kind=RequestResponseType.ListEvaluations,
4145
data=eval_list,
4246
)
4347
return response
48+
49+
@staticmethod
50+
def evaluation_details(request) -> EvaluationDetailsResponse:
51+
evaluation_details = AquaEvaluationApp().get(eval_id=request.get("evaluation_id"))
52+
response = EvaluationDetailsResponse(
53+
message_id=request.get('message_id'),
54+
kind=RequestResponseType.EvaluationDetails,
55+
data=evaluation_details,
56+
)
57+
return response

ads/aqua/extension/models/ws_models.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,30 @@
77
from dataclasses import dataclass
88
from typing import List, Optional
99

10-
from ads.aqua.evaluation.entities import AquaEvaluationSummary
11-
from ads.aqua.model.entities import AquaModelSummary
12-
from ads.aqua.modeldeployment.entities import AquaDeployment
10+
from ads.aqua.evaluation.entities import AquaEvaluationSummary, AquaEvaluationDetail
11+
from ads.aqua.model.entities import AquaModelSummary, AquaModel
12+
from ads.aqua.modeldeployment.entities import AquaDeployment, AquaDeploymentDetail
1313
from ads.common.extended_enum import ExtendedEnumMeta
1414
from ads.common.serializer import DataClassSerializable
1515

1616

1717
class RequestResponseType(str, metaclass=ExtendedEnumMeta):
1818
ListEvaluations = "ListEvaluations"
19+
EvaluationDetails = "EvaluationDetails"
1920
ListDeployments = "ListDeployments"
21+
DeploymentDetails = "DeploymentDetails"
2022
ListModels = "ListModels"
23+
ModelDetails = "ModelDetails"
24+
AdsVersion = "AdsVersion"
25+
CompatibilityCheck = "CompatibilityCheck"
2126
Error = "Error"
2227

2328

2429
@dataclass
2530
class BaseResponse(DataClassSerializable):
2631
message_id: str
2732
kind: RequestResponseType
28-
data: object
33+
data: Optional[object]
2934

3035

3136
@dataclass
@@ -42,6 +47,12 @@ class ListEvaluationsRequest(BaseRequest):
4247
kind = RequestResponseType.ListEvaluations
4348

4449

50+
@dataclass
51+
class EvaluationDetailsRequest(BaseRequest):
52+
kind = RequestResponseType.EvaluationDetails
53+
evaluation_id: str
54+
55+
4556
@dataclass
4657
class ListModelsRequest(BaseRequest):
4758
compartment_id: Optional[str] = None
@@ -51,8 +62,10 @@ class ListModelsRequest(BaseRequest):
5162

5263

5364
@dataclass
54-
class ListEvaluationsResponse(BaseResponse):
55-
data: List[AquaEvaluationSummary]
65+
class ModelDetailsRequest(BaseRequest):
66+
kind = RequestResponseType.ModelDetails
67+
model_id: str
68+
5669

5770
@dataclass
5871
class ListDeploymentRequest(BaseRequest):
@@ -61,15 +74,62 @@ class ListDeploymentRequest(BaseRequest):
6174
kind = RequestResponseType.ListDeployments
6275

6376

77+
@dataclass
78+
class DeploymentDetailsRequest(BaseRequest):
79+
model_deployment_id: str
80+
kind = RequestResponseType.DeploymentDetails
81+
82+
83+
@dataclass
84+
class ListEvaluationsResponse(BaseResponse):
85+
data: List[AquaEvaluationSummary]
86+
87+
88+
@dataclass
89+
class EvaluationDetailsResponse(BaseResponse):
90+
data: AquaEvaluationDetail
91+
92+
6493
@dataclass
6594
class ListDeploymentResponse(BaseResponse):
6695
data: List[AquaDeployment]
67-
96+
97+
98+
@dataclass
99+
class ModelDeploymentDetailsResponse(BaseResponse):
100+
data: AquaDeploymentDetail
101+
102+
68103
@dataclass
69104
class ListModelsResponse(BaseResponse):
70105
data: List[AquaModelSummary]
71106

72107

108+
@dataclass
109+
class ModelDetailsResponse(BaseResponse):
110+
data: AquaModel
111+
112+
113+
@dataclass
114+
class AdsVersionRequest(BaseRequest):
115+
kind: RequestResponseType.AdsVersion
116+
117+
118+
@dataclass
119+
class AdsVersionResponse(BaseResponse):
120+
data: str
121+
122+
123+
@dataclass
124+
class CompatibilityCheckRequest(BaseRequest):
125+
kind: RequestResponseType.CompatibilityCheck
126+
127+
128+
@dataclass
129+
class CompatibilityCheckResponse(BaseResponse):
130+
data: object
131+
132+
73133
@dataclass
74134
class AquaWsError(DataClassSerializable):
75135
status: str
Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import json
12
from typing import List, Union
23

34
from ads.aqua.common.decorator import handle_exceptions
45
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
5-
from ads.aqua.extension.models.ws_models import RequestResponseType,ListModelsResponse, ListModelsRequest
6+
from ads.aqua.extension.models.ws_models import RequestResponseType, ListModelsResponse, ListModelsRequest, \
7+
ModelDetailsResponse
68
from ads.aqua.model import AquaModelApp
79
from ads.config import COMPARTMENT_OCID
810

@@ -14,20 +16,27 @@ def __init__(self, message: Union[str, bytes]):
1416

1517
@staticmethod
1618
def get_message_types() -> List[RequestResponseType]:
17-
return [RequestResponseType.ListModels]
19+
return [RequestResponseType.ListModels,RequestResponseType.ModelDetails]
1820

1921
@handle_exceptions
20-
def process(self) -> ListModelsResponse:
21-
list_models_request = ListModelsRequest.from_json(self.message)
22-
print(list_models_request)
23-
models_list = AquaModelApp().list(
24-
compartment_id=list_models_request.compartment_id or COMPARTMENT_OCID,
25-
project_id=list_models_request.project_id,
26-
model_type=list_models_request.model_type
27-
)
28-
response = ListModelsResponse(
29-
message_id=list_models_request.message_id,
30-
kind=RequestResponseType.ListModels,
31-
data=models_list,
32-
)
33-
return response
22+
def process(self) -> ListModelsResponse | ModelDetailsResponse:
23+
request = json.loads(self.message)
24+
if request.get('kind') == 'ListModels':
25+
models_list = AquaModelApp().list(
26+
compartment_id=request.get("compartment_id") or COMPARTMENT_OCID,
27+
project_id=request.get("project_id"),
28+
model_type=request.get("model_type")
29+
)
30+
response = ListModelsResponse(
31+
message_id=request.get("message_id"),
32+
kind=RequestResponseType.ListModels,
33+
data=models_list,
34+
)
35+
return response
36+
elif request.get('kind') == 'ModelDetails':
37+
model_id=request.get("model_id")
38+
response=AquaModelApp().get(model_id)
39+
return ModelDetailsResponse(message_id=request.get("message_id"),
40+
kind=RequestResponseType.ModelDetails,
41+
data=response)
42+

ads/aqua/extension/ui_websocket_handler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from ads.aqua import logger
1616
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
17+
from ads.aqua.extension.common_ws_msg_handler import AquaCommonWsMsgHandler
1718
from ads.aqua.extension.deployment_ws_msg_handler import AquaDeploymentWSMsgHandler
1819
from ads.aqua.extension.evaluation_ws_msg_handler import AquaEvaluationWSMsgHandler
1920
from ads.aqua.extension.models.ws_models import (
@@ -45,7 +46,10 @@ def get_aqua_internal_error_response(message_id: str) -> ErrorResponse:
4546
class AquaUIWebSocketHandler(WebSocketHandler):
4647
"""Handler for Aqua Websocket."""
4748

48-
_handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler,AquaDeploymentWSMsgHandler,AquaModelWSMsgHandler]
49+
_handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler,
50+
AquaDeploymentWSMsgHandler,
51+
AquaModelWSMsgHandler,
52+
AquaCommonWsMsgHandler]
4953

5054
thread_pool: ThreadPoolExecutor
5155

0 commit comments

Comments
 (0)