@@ -13,9 +13,15 @@
 from threading import Lock
 from typing import Any, Dict, List, Union

+import oci
 from cachetools import TTLCache
+from oci.data_science.models import (
+    JobRun,
+    Metadata,
+    UpdateModelDetails,
+    UpdateModelProvenanceDetails,
+)

-import oci
 from ads.aqua import logger
 from ads.aqua.app import AquaApp
 from ads.aqua.common import utils
@@ -41,7 +47,6 @@
 )
 from ads.aqua.constants import (
     CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
-    EVALUATION_INFERENCE_DEFAULT_THREADS,
     EVALUATION_REPORT,
     EVALUATION_REPORT_JSON,
     EVALUATION_REPORT_MD,
@@ -97,12 +102,6 @@
 )
 from ads.model.model_version_set import ModelVersionSet
 from ads.telemetry import telemetry
-from oci.data_science.models import (
-    JobRun,
-    Metadata,
-    UpdateModelDetails,
-    UpdateModelProvenanceDetails,
-)


 class AquaEvaluationApp(AquaApp):
@@ -171,6 +170,7 @@ def create(
171170 "Specify either a model or model deployment id."
172171 )
173172 evaluation_source = None
173+ eval_inference_configuration = None
174174 if (
175175 DataScienceResource .MODEL_DEPLOYMENT
176176 in create_aqua_evaluation_details .evaluation_source_id
@@ -182,29 +182,14 @@ def create(
             runtime = ModelDeploymentContainerRuntime.from_dict(
                 evaluation_source.runtime.to_dict()
             )
-            container_config = AquaContainerConfig.from_container_index_json(
+            inference_config = AquaContainerConfig.from_container_index_json(
                 enable_spec=True
-            )
-            for container in container_config.inference.values():
+            ).inference
+            for container in inference_config.values():
                 if container.name == runtime.image.split(":")[0]:
-                    max_threads = container.spec.evaluation_configuration.evaluation_max_threads
-                    if (
-                        max_threads
-                        and create_aqua_evaluation_details.inference_max_threads
-                        and max_threads
-                        < create_aqua_evaluation_details.inference_max_threads
-                    ):
-                        raise AquaValueError(
-                            f"Invalid inference max threads. The maximum allowed value for {runtime.image} is {max_threads}."
-                        )
-                    if not create_aqua_evaluation_details.inference_max_threads:
-                        create_aqua_evaluation_details.inference_max_threads = container.spec.evaluation_configuration.evaluation_default_threads
-                    break
-            if not create_aqua_evaluation_details.inference_max_threads:
-                create_aqua_evaluation_details.inference_max_threads = (
-                    EVALUATION_INFERENCE_DEFAULT_THREADS
-                )
-
+                    eval_inference_configuration = (
+                        container.spec.evaluation_configuration
+                    )
         elif (
             DataScienceResource.MODEL
             in create_aqua_evaluation_details.evaluation_source_id
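For orientation, a minimal, self-contained sketch of the lookup this hunk performs. Only the identifiers used in the diff (name, spec.evaluation_configuration, and the two thread fields from the removed code) are taken from it; the index layout, extra field names, and the image string are assumptions for illustration:

from dataclasses import dataclass, field
from typing import Dict, Optional

@dataclass
class EvaluationConfiguration:
    # Pre-PR these two fields drove the inference_max_threads validation and
    # default; post-PR the whole object is forwarded to the evaluation job.
    evaluation_max_threads: Optional[int] = None
    evaluation_default_threads: Optional[int] = None

@dataclass
class ContainerSpec:
    evaluation_configuration: EvaluationConfiguration = field(
        default_factory=EvaluationConfiguration
    )

@dataclass
class InferenceContainer:
    name: str  # repository part of the container image, without the tag
    spec: ContainerSpec = field(default_factory=ContainerSpec)

# Hypothetical index and deployment image, matched the same way as above:
inference_config: Dict[str, InferenceContainer] = {
    "VLLM": InferenceContainer(name="dsmc://odsc-vllm-serving"),
}
runtime_image = "dsmc://odsc-vllm-serving:0.4.1"

eval_inference_configuration = None
for container in inference_config.values():
    if container.name == runtime_image.split(":")[0]:
        eval_inference_configuration = container.spec.evaluation_configuration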
@@ -420,7 +405,9 @@ def create(
                 report_path=create_aqua_evaluation_details.report_path,
                 model_parameters=create_aqua_evaluation_details.model_parameters,
                 metrics=create_aqua_evaluation_details.metrics,
-                inference_max_threads=create_aqua_evaluation_details.inference_max_threads,
+                inference_configuration=eval_inference_configuration.to_filtered_dict()
+                if eval_inference_configuration
+                else {},
             )
         ).create(**kwargs)  ## TODO: decide what parameters will be needed
         logger.debug(
@@ -542,7 +529,7 @@ def _build_evaluation_runtime(
         report_path: str,
         model_parameters: dict,
         metrics: List = None,
-        inference_max_threads: int = None,
+        inference_configuration: dict = None,
     ) -> Runtime:
         """Builds evaluation runtime for Job."""
         # TODO the image name needs to be extracted from the mapping index.json file.
@@ -552,17 +539,19 @@ def _build_evaluation_runtime(
             .with_environment_variable(
                 **{
                     "AIP_SMC_EVALUATION_ARGUMENTS": json.dumps(
-                        asdict(
-                            self._build_launch_cmd(
-                                evaluation_id=evaluation_id,
-                                evaluation_source_id=evaluation_source_id,
-                                dataset_path=dataset_path,
-                                report_path=report_path,
-                                model_parameters=model_parameters,
-                                metrics=metrics,
-                                inference_max_threads=inference_max_threads,
-                            )
-                        )
+                        {
+                            **asdict(
+                                self._build_launch_cmd(
+                                    evaluation_id=evaluation_id,
+                                    evaluation_source_id=evaluation_source_id,
+                                    dataset_path=dataset_path,
+                                    report_path=report_path,
+                                    model_parameters=model_parameters,
+                                    metrics=metrics,
+                                ),
+                            ),
+                            **inference_configuration,
+                        },
                     ),
                     "CONDA_BUCKET_NS": CONDA_BUCKET_NS,
                 },
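A standalone sketch of the merged payload this hunk produces. LaunchCmd stands in for AquaEvaluationCommands, and the evaluation_max_threads key is an assumed example of what the container spec's to_filtered_dict() yields; the point is that the container settings land as flat, top-level keys next to the launch-command fields:

import json
from dataclasses import asdict, dataclass
from typing import Dict, List

@dataclass
class LaunchCmd:  # stand-in for AquaEvaluationCommands
    evaluation_id: str
    metrics: List[str]
    params: Dict

cmd = LaunchCmd(evaluation_id="ocid1.example", metrics=["bertscore"], params={})
inference_configuration = {"evaluation_max_threads": 10}  # assumed to_filtered_dict() output

payload = json.dumps({**asdict(cmd), **inference_configuration})
# {"evaluation_id": "ocid1.example", "metrics": ["bertscore"],
#  "params": {}, "evaluation_max_threads": 10}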
@@ -620,7 +609,6 @@ def _build_launch_cmd(
         report_path: str,
         model_parameters: dict,
         metrics: List = None,
-        inference_max_threads: int = None,
     ):
         return AquaEvaluationCommands(
             evaluation_id=evaluation_id,
@@ -637,7 +625,6 @@ def _build_launch_cmd(
             metrics=metrics,
             output_dir=report_path,
             params=model_parameters,
-            inference_max_threads=inference_max_threads,
         )

     @telemetry(entry_point="plugin=evaluation&action=get", name="aqua")
@@ -1227,7 +1214,7 @@ def _delete_job_and_model(job, model):
12271214 f"Exception message: { ex } "
12281215 )
12291216
1230- def load_evaluation_config (self , _ ):
1217+ def load_evaluation_config (self , eval_id ):
12311218 """Loads evaluation config."""
12321219 return {
12331220 "model_params" : {