Skip to content

Commit 350f0ea

Browse files
author
Ziqun Ye
committed
adding unit test
1 parent 429d4ea commit 350f0ea

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,7 +1404,7 @@ def _update_from_oci_model(self, oci_model_instance) -> "ModelDeployment":
14041404
infrastructure.CONST_WEB_CONCURRENCY,
14051405
runtime.env.get("WEB_CONCURRENCY", None),
14061406
)
1407-
if runtime.env.pop("CONTAINER_TYPE", None) == "TRITON":
1407+
if runtime.env.get("CONTAINER_TYPE", None) == "TRITON":
14081408
runtime.set_spec(
14091409
runtime.CONST_INFERENCE_SERVER, "triton"
14101410
)
@@ -1584,7 +1584,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
15841584
infrastructure.web_concurrency
15851585
)
15861586
runtime.set_spec(runtime.CONST_ENV, environment_variables)
1587-
if runtime.inference_server.lower() == "triton":
1587+
if hasattr(runtime, "inference_server") and runtime.inference_server and runtime.inference_server.lower() == "triton":
15881588
environment_variables["CONTAINER_TYPE"] = "TRITON"
15891589
runtime.set_spec(runtime.CONST_ENV, environment_variables)
15901590
environment_configuration_details = {

tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,62 @@ def initialize_model_deployment_from_spec(self):
308308
"runtime": runtime,
309309
}
310310
)
311+
312+
def initialize_model_deployment_triton_builder(self):
313+
infrastructure = ModelDeploymentInfrastructure()\
314+
.with_compartment_id("fakeid.compartment.oc1..xxx")\
315+
.with_project_id("fakeid.datascienceproject.oc1.iad.xxx")\
316+
.with_shape_name("VM.Standard.E4.Flex")\
317+
.with_replica(2)\
318+
.with_bandwidth_mbps(10)\
319+
320+
runtime = ModelDeploymentContainerRuntime()\
321+
.with_image("fake_image")\
322+
.with_server_port(5000)\
323+
.with_health_check_port(5000)\
324+
.with_model_uri("fake_model_id")\
325+
.with_env({"key":"value", "key2":"value2"})\
326+
.with_inference_server("triton")
327+
328+
deployment = ModelDeployment()\
329+
.with_display_name("triton case")\
330+
.with_infrastructure(infrastructure)\
331+
.with_runtime(runtime)
332+
return deployment
333+
334+
def initialize_model_deployment_triton_yaml(self):
335+
yaml_string = """
336+
kind: deployment
337+
spec:
338+
displayName: triton
339+
infrastructure:
340+
kind: infrastructure
341+
spec:
342+
bandwidthMbps: 10
343+
compartmentId: fake_compartment_id
344+
deploymentType: SINGLE_MODEL
345+
policyType: FIXED_SIZE
346+
replica: 2
347+
shapeConfigDetails:
348+
memoryInGBs: 16.0
349+
ocpus: 1.0
350+
shapeName: VM.Standard.E4.Flex
351+
type: datascienceModelDeployment
352+
runtime:
353+
kind: runtime
354+
spec:
355+
env:
356+
key: value
357+
key2: value2
358+
inference_server: triton
359+
healthCheckPort: 8000
360+
image: fake_image
361+
modelUri: fake_model_id
362+
serverPort: 8000
363+
type: container
364+
"""
365+
deployment_from_yaml = ModelDeployment.from_yaml(yaml_string)
366+
return deployment_from_yaml
311367

312368
def initialize_model_deployment_from_kwargs(self):
313369
infrastructure = (
@@ -435,11 +491,34 @@ def test_initialize_model_deployment_with_error(self):
435491
},
436492
)
437493

494+
438495
def test_initialize_model_deployment_with_spec_kwargs(self):
439496
model_deployment_kwargs = self.initialize_model_deployment_from_kwargs()
440497
model_deployment_builder = self.initialize_model_deployment()
441498

442499
assert model_deployment_kwargs.to_dict() == model_deployment_builder.to_dict()
500+
501+
502+
def test_initialize_model_deployment_triton_builder(self):
503+
temp_model_deployment = self.initialize_model_deployment_triton_builder()
504+
assert isinstance(
505+
temp_model_deployment.runtime, ModelDeploymentContainerRuntime
506+
)
507+
assert isinstance(
508+
temp_model_deployment.infrastructure, ModelDeploymentInfrastructure
509+
)
510+
assert temp_model_deployment.runtime.inference_server == "triton"
511+
512+
def test_initialize_model_deployment_triton_yaml(self):
513+
temp_model_deployment = self.initialize_model_deployment_triton_yaml()
514+
assert isinstance(
515+
temp_model_deployment.runtime, ModelDeploymentContainerRuntime
516+
)
517+
assert isinstance(
518+
temp_model_deployment.infrastructure, ModelDeploymentInfrastructure
519+
)
520+
assert temp_model_deployment.runtime.inference_server == "triton"
521+
443522

444523
def test_model_deployment_to_dict(self):
445524
model_deployment = self.initialize_model_deployment()

0 commit comments

Comments
 (0)