11#!/usr/bin/env python
22
3- # Copyright (c) 2024 Oracle and/or its affiliates.
3+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
44# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
6- from copy import deepcopy
76from typing import Any , Dict , List , Optional
87
98from pydantic import Field
109
1110from ads .aqua .config .utils .serializer import Serializable
1211
1312
14- class ModelParamsOverrides (Serializable ):
15- """Defines overrides for model parameters, including exclusions and additional inclusions."""
16-
17- exclude : Optional [List [str ]] = Field (default_factory = list )
18- include : Optional [Dict [str , Any ]] = Field (default_factory = dict )
19-
20- class Config :
21- extra = "ignore"
22-
23-
24- class ModelParamsVersion (Serializable ):
25- """Handles version-specific model parameter overrides."""
26-
27- overrides : Optional [ModelParamsOverrides ] = Field (
28- default_factory = ModelParamsOverrides
29- )
30-
31- class Config :
32- extra = "ignore"
33-
34-
35- class ModelParamsContainer (Serializable ):
36- """Represents a container's model configuration, including tasks, defaults, and versions."""
37-
38- name : Optional [str ] = None
39- default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
40- versions : Optional [Dict [str , ModelParamsVersion ]] = Field (default_factory = dict )
41-
42- class Config :
43- extra = "ignore"
44-
45-
46- class InferenceParams (Serializable ):
47- """Contains inference-related parameters with defaults."""
48-
49- class Config :
50- extra = "allow"
51-
52-
53- class InferenceContainer (Serializable ):
54- """Represents the inference parameters specific to a container."""
55-
56- name : Optional [str ] = None
57- params : Optional [Dict [str , Any ]] = Field (default_factory = dict )
58-
59- class Config :
60- extra = "ignore"
61-
62-
63- class ReportParams (Serializable ):
64- """Handles the report-related parameters."""
65-
66- default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
67-
68- class Config :
69- extra = "ignore"
70-
71-
72- class InferenceParamsConfig (Serializable ):
73- """Combines default inference parameters with container-specific configurations."""
74-
75- default : Optional [InferenceParams ] = Field (default_factory = InferenceParams )
76- containers : Optional [List [InferenceContainer ]] = Field (default_factory = list )
77-
78- def get_merged_params (self , container_name : str ) -> InferenceParams :
79- """
80- Merges default inference params with those specific to the given container.
81-
82- Parameters
83- ----------
84- container_name (str): The name of the container.
85-
86- Returns
87- -------
88- InferenceParams: The merged inference parameters.
89- """
90- merged_params = self .default .to_dict ()
91- for containers in self .containers :
92- if containers .name .lower () == container_name .lower ():
93- merged_params .update (containers .params or {})
94- break
95- return InferenceParams (** merged_params )
96-
97- class Config :
98- extra = "ignore"
99-
100-
101- class InferenceModelParamsConfig (Serializable ):
102- """Encapsulates the model parameters for different containers."""
103-
104- default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
105- containers : Optional [List [ModelParamsContainer ]] = Field (default_factory = list )
106-
107- def get_merged_model_params (
108- self ,
109- container_name : str ,
110- version : Optional [str ] = None ,
111- ) -> Dict [str , Any ]:
112- """
113- Gets the model parameters for a given container, version,
114- merged with the defaults.
115-
116- Parameters
117- ----------
118- container_name (str): The name of the container.
119- version (Optional[str]): The specific version of the container.
120-
121- Returns
122- -------
123- Dict[str, Any]: The merged model parameters.
124- """
125- params = deepcopy (self .default )
126-
127- for container in self .containers :
128- if container .name .lower () == container_name .lower ():
129- params .update (container .default )
130-
131- if version and version in container .versions :
132- version_overrides = container .versions [version ].overrides
133- if version_overrides :
134- if version_overrides .include :
135- params .update (version_overrides .include )
136- if version_overrides .exclude :
137- for key in version_overrides .exclude :
138- params .pop (key , None )
139- break
140-
141- return params
142-
143- class Config :
144- extra = "ignore"
145-
146-
14713class ShapeFilterConfig (Serializable ):
14814 """Represents the filtering options for a specific shape."""
14915
15016 evaluation_container : Optional [List [str ]] = Field (default_factory = list )
15117 evaluation_target : Optional [List [str ]] = Field (default_factory = list )
15218
15319 class Config :
154- extra = "ignore "
20+ extra = "allow "
15521
15622
15723class ShapeConfig (Serializable ):
@@ -178,7 +44,7 @@ class MetricConfig(Serializable):
17844 tags : Optional [List [str ]] = Field (default_factory = list )
17945
18046 class Config :
181- extra = "ignore "
47+ extra = "allow "
18248
18349
18450class ModelParamsConfig (Serializable ):
@@ -223,7 +89,7 @@ def search_shapes(
22389 ]
22490
22591 class Config :
226- extra = "ignore "
92+ extra = "allow "
22793 protected_namespaces = ()
22894
22995
@@ -235,49 +101,7 @@ class EvaluationServiceConfig(Serializable):
235101
236102 version : Optional [str ] = "1.0"
237103 kind : Optional [str ] = "evaluation_service_config"
238- report_params : Optional [ReportParams ] = Field (default_factory = ReportParams )
239- inference_params : Optional [InferenceParamsConfig ] = Field (
240- default_factory = InferenceParamsConfig
241- )
242- inference_model_params : Optional [InferenceModelParamsConfig ] = Field (
243- default_factory = InferenceModelParamsConfig
244- )
245104 ui_config : Optional [UIConfig ] = Field (default_factory = UIConfig )
246105
247- def get_merged_inference_params (self , container_name : str ) -> InferenceParams :
248- """
249- Merges default inference params with those specific to the given container.
250-
251- Params
252- ------
253- container_name (str): The name of the container.
254-
255- Returns
256- -------
257- InferenceParams: The merged inference parameters.
258- """
259- return self .inference_params .get_merged_params (container_name = container_name )
260-
261- def get_merged_inference_model_params (
262- self ,
263- container_name : str ,
264- version : Optional [str ] = None ,
265- ) -> Dict [str , Any ]:
266- """
267- Gets the model parameters for a given container, version, and task, merged with the defaults.
268-
269- Parameters
270- ----------
271- container_name (str): The name of the container.
272- version (Optional[str]): The specific version of the container.
273-
274- Returns
275- -------
276- Dict[str, Any]: The merged model parameters.
277- """
278- return self .inference_model_params .get_merged_model_params (
279- container_name = container_name , version = version
280- )
281-
282106 class Config :
283- extra = "ignore "
107+ extra = "allow "
0 commit comments