Skip to content

Commit 4f1b45e

Browse files
websocket changes for 1.0.3
1 parent b7b7ae1 commit 4f1b45e

File tree

4 files changed

+82
-1
lines changed

4 files changed

+82
-1
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import List, Union
2+
3+
from ads.aqua.common.decorator import handle_exceptions
4+
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
5+
from ads.aqua.extension.models.ws_models import RequestResponseType, ListDeploymentResponse, ListDeploymentRequest
6+
from ads.aqua.modeldeployment import AquaDeploymentApp
7+
from ads.config import COMPARTMENT_OCID
8+
9+
10+
class AquaDeploymentWSMsgHandler(AquaWSMsgHandler):
11+
12+
def __init__(self, message: Union[str, bytes]):
13+
super().__init__(message)
14+
15+
@staticmethod
16+
def get_message_types() -> List[RequestResponseType]:
17+
return [RequestResponseType.ListDeployments]
18+
19+
@handle_exceptions
20+
def process(self) -> ListDeploymentResponse:
21+
list_deployment_request = ListDeploymentRequest.from_json(self.message)
22+
deployment_list = AquaDeploymentApp().list(
23+
compartment_id=list_deployment_request.compartment_id or COMPARTMENT_OCID,
24+
project_id=list_deployment_request.project_id,
25+
)
26+
response = ListDeploymentResponse(
27+
message_id=list_deployment_request.message_id,
28+
kind=RequestResponseType.ListDeployments,
29+
data=deployment_list,
30+
)
31+
return response

ads/aqua/extension/models/ws_models.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99

1010
from ads.aqua.evaluation.entities import AquaEvaluationSummary
1111
from ads.aqua.model.entities import AquaModelSummary
12+
from ads.aqua.modeldeployment.entities import AquaDeployment
1213
from ads.common.extended_enum import ExtendedEnumMeta
1314
from ads.common.serializer import DataClassSerializable
1415

1516

1617
class RequestResponseType(str, metaclass=ExtendedEnumMeta):
1718
ListEvaluations = "ListEvaluations"
19+
ListDeployments = "ListDeployments"
1820
ListModels = "ListModels"
1921
Error = "Error"
2022

@@ -43,13 +45,26 @@ class ListEvaluationsRequest(BaseRequest):
4345
@dataclass
4446
class ListModelsRequest(BaseRequest):
4547
compartment_id: Optional[str] = None
48+
project_id: Optional[str] = None
49+
model_type: Optional[str] = None
50+
kind = RequestResponseType.ListDeployments
4651

4752

4853
@dataclass
4954
class ListEvaluationsResponse(BaseResponse):
5055
data: List[AquaEvaluationSummary]
5156

57+
@dataclass
58+
class ListDeploymentRequest(BaseRequest):
59+
compartment_id: str
60+
project_id: Optional[str] = None
61+
kind = RequestResponseType.ListDeployments
5262

63+
64+
@dataclass
65+
class ListDeploymentResponse(BaseResponse):
66+
data: List[AquaDeployment]
67+
5368
@dataclass
5469
class ListModelsResponse(BaseResponse):
5570
data: List[AquaModelSummary]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import List, Union
2+
3+
from ads.aqua.common.decorator import handle_exceptions
4+
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
5+
from ads.aqua.extension.models.ws_models import RequestResponseType,ListModelsResponse, ListModelsRequest
6+
from ads.aqua.model import AquaModelApp
7+
from ads.config import COMPARTMENT_OCID
8+
9+
10+
class AquaModelWSMsgHandler(AquaWSMsgHandler):
11+
12+
def __init__(self, message: Union[str, bytes]):
13+
super().__init__(message)
14+
15+
@staticmethod
16+
def get_message_types() -> List[RequestResponseType]:
17+
return [RequestResponseType.ListModels]
18+
19+
@handle_exceptions
20+
def process(self) -> ListModelsResponse:
21+
list_models_request = ListModelsRequest.from_json(self.message)
22+
print(list_models_request)
23+
models_list = AquaModelApp().list(
24+
compartment_id=list_models_request.compartment_id or COMPARTMENT_OCID,
25+
project_id=list_models_request.project_id,
26+
model_type=list_models_request.model_type
27+
)
28+
response = ListModelsResponse(
29+
message_id=list_models_request.message_id,
30+
kind=RequestResponseType.ListModels,
31+
data=models_list,
32+
)
33+
return response

ads/aqua/extension/ui_websocket_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from ads.aqua import logger
1616
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
17+
from ads.aqua.extension.deployment_ws_msg_handler import AquaDeploymentWSMsgHandler
1718
from ads.aqua.extension.evaluation_ws_msg_handler import AquaEvaluationWSMsgHandler
1819
from ads.aqua.extension.models.ws_models import (
1920
AquaWsError,
@@ -22,6 +23,7 @@
2223
ErrorResponse,
2324
RequestResponseType,
2425
)
26+
from ads.aqua.extension.models_ws_msg_handler import AquaModelWSMsgHandler
2527

2628
MAX_WORKERS = 20
2729

@@ -43,7 +45,7 @@ def get_aqua_internal_error_response(message_id: str) -> ErrorResponse:
4345
class AquaUIWebSocketHandler(WebSocketHandler):
4446
"""Handler for Aqua Websocket."""
4547

46-
_handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler]
48+
_handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler,AquaDeploymentWSMsgHandler,AquaModelWSMsgHandler]
4749

4850
thread_pool: ThreadPoolExecutor
4951

0 commit comments

Comments
 (0)