44# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
66from copy import deepcopy
7- from typing import Any , Dict , List , Optional , Union
7+ from typing import Any , Dict , List , Optional
88
99from pydantic import Field
1010
1919INFERENCE_DELAY = 0
2020
2121
22- class ModelParamItem (Serializable ):
23- """Represents min, max, and default values for a model parameter."""
24-
25- min : Optional [Union [int , float ]] = None
26- max : Optional [Union [int , float ]] = None
27- default : Optional [Union [int , float ]] = None
28-
29- class Config :
30- extra = "ignore"
31-
32-
3322class ModelParamsOverrides (Serializable ):
3423 """Defines overrides for model parameters, including exclusions and additional inclusions."""
3524
@@ -51,28 +40,11 @@ class Config:
5140 extra = "ignore"
5241
5342
54- class ModelDefaultParams (Serializable ):
55- """Defines default parameters for a model within a specific framework."""
56-
57- model : Optional [str ] = None
58- max_tokens : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
59- temperature : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
60- top_p : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
61- top_k : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
62- presence_penalty : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
63- frequency_penalty : Optional [ModelParamItem ] = Field (default_factory = ModelParamItem )
64- stop : List [str ] = Field (default_factory = list )
43+ class ModelParamsContainer (Serializable ):
44+ """Represents a container's model configuration, including tasks, defaults, and versions."""
6545
66- class Config :
67- extra = "allow"
68-
69-
70- class ModelFramework (Serializable ):
71- """Represents a framework's model configuration, including tasks, defaults, and versions."""
72-
73- framework : Optional [str ] = None
74- task : Optional [List [str ]] = Field (default_factory = list )
75- default : Optional [ModelDefaultParams ] = Field (default_factory = ModelDefaultParams )
46+ name : Optional [str ] = None
47+ default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
7648 versions : Optional [Dict [str , ModelParamsVersion ]] = Field (default_factory = dict )
7749
7850 class Config :
@@ -93,10 +65,10 @@ class Config:
9365 extra = "allow"
9466
9567
96- class InferenceFramework (Serializable ):
97- """Represents the inference parameters specific to a framework ."""
68+ class InferenceContainer (Serializable ):
69+ """Represents the inference parameters specific to a container ."""
9870
99- framework : Optional [str ] = None
71+ name : Optional [str ] = None
10072 params : Optional [Dict [str , Any ]] = Field (default_factory = dict )
10173
10274 class Config :
@@ -113,70 +85,66 @@ class Config:
11385
11486
11587class InferenceParamsConfig (Serializable ):
116- """Combines default inference parameters with framework -specific configurations."""
88+ """Combines default inference parameters with container -specific configurations."""
11789
11890 default : Optional [InferenceParams ] = Field (default_factory = InferenceParams )
119- frameworks : Optional [List [InferenceFramework ]] = Field (default_factory = list )
91+ containers : Optional [List [InferenceContainer ]] = Field (default_factory = list )
12092
121- def get_merged_params (self , framework_name : str ) -> InferenceParams :
93+ def get_merged_params (self , container_name : str ) -> InferenceParams :
12294 """
123- Merges default inference params with those specific to the given framework .
95+ Merges default inference params with those specific to the given container .
12496
12597 Parameters
12698 ----------
127- framework_name (str): The name of the framework .
99+ container_name (str): The name of the container .
128100
129101 Returns
130102 -------
131103 InferenceParams: The merged inference parameters.
132104 """
133105 merged_params = self .default .to_dict ()
134- for framework in self .frameworks :
135- if framework . framework .lower () == framework_name .lower ():
136- merged_params .update (framework .params or {})
106+ for containers in self .containers :
107+ if containers . name .lower () == container_name .lower ():
108+ merged_params .update (containers .params or {})
137109 break
138110 return InferenceParams (** merged_params )
139111
140112 class Config :
141113 extra = "ignore"
142114
143115
144- class ModelParamsConfig (Serializable ):
145- """Encapsulates the model parameters for different frameworks ."""
116+ class InferenceModelParamsConfig (Serializable ):
117+ """Encapsulates the model parameters for different containers ."""
146118
147119 default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
148- frameworks : Optional [List [ModelFramework ]] = Field (default_factory = list )
120+ containers : Optional [List [ModelParamsContainer ]] = Field (default_factory = list )
149121
150- def get_model_params (
122+ def get_merged_model_params (
151123 self ,
152- framework_name : str ,
124+ container_name : str ,
153125 version : Optional [str ] = None ,
154- task : Optional [str ] = None ,
155126 ) -> Dict [str , Any ]:
156127 """
157- Gets the model parameters for a given framework , version, and tasks ,
128+ Gets the model parameters for a given container , version,
158129 merged with the defaults.
159130
160131 Parameters
161132 ----------
162- framework_name (str): The name of the framework.
163- version (Optional[str]): The specific version of the framework.
164- task (Optional[str]): The specific task.
133+ container_name (str): The name of the container.
134+ version (Optional[str]): The specific version of the container.
165135
166136 Returns
167137 -------
168138 Dict[str, Any]: The merged model parameters.
169139 """
170140 params = deepcopy (self .default )
171141
172- for framework in self .frameworks :
173- if framework .framework .lower () == framework_name .lower () and (
174- not task or task .lower () in framework .task
175- ):
176- params .update (framework .default .to_dict ())
142+ for container in self .containers :
143+ if container .name .lower () == container_name .lower ():
144+ params .update (container .default )
177145
178- if version and version in framework .versions :
179- version_overrides = framework .versions [version ].overrides
146+ if version and version in container .versions :
147+ version_overrides = container .versions [version ].overrides
180148 if version_overrides :
181149 if version_overrides .include :
182150 params .update (version_overrides .include )
@@ -228,59 +196,17 @@ class Config:
228196 extra = "ignore"
229197
230198
231- class EvaluationServiceConfig (Serializable ):
232- """
233- Root configuration class for evaluation setup including model,
234- inference, and shape configurations.
235- """
199+ class ModelParamsConfig (Serializable ):
200+ """Encapsulates the default model parameters."""
236201
237- version : Optional [str ] = "1.0"
238- kind : Optional [str ] = "evaluation"
239- report_params : Optional [ReportParams ] = Field (default_factory = ReportParams )
240- inference_params : Optional [InferenceParamsConfig ] = Field (
241- default_factory = InferenceParamsConfig
242- )
202+ default : Optional [Dict [str , Any ]] = Field (default_factory = dict )
203+
204+
205+ class UIConfig (Serializable ):
243206 model_params : Optional [ModelParamsConfig ] = Field (default_factory = ModelParamsConfig )
244207 shapes : List [ShapeConfig ] = Field (default_factory = list )
245208 metrics : List [MetricConfig ] = Field (default_factory = list )
246209
247- def get_merged_inference_params (self , framework_name : str ) -> InferenceParams :
248- """
249- Merges default inference params with those specific to the given framework.
250-
251- Params
252- ------
253- framework_name (str): The name of the framework.
254-
255- Returns
256- -------
257- InferenceParams: The merged inference parameters.
258- """
259- return self .inference_params .get_merged_params (framework_name = framework_name )
260-
261- def get_merged_model_params (
262- self ,
263- framework_name : str ,
264- version : Optional [str ] = None ,
265- task : Optional [str ] = None ,
266- ) -> Dict [str , Any ]:
267- """
268- Gets the model parameters for a given framework, version, and task, merged with the defaults.
269-
270- Parameters
271- ----------
272- framework_name (str): The name of the framework.
273- version (Optional[str]): The specific version of the framework.
274- task (Optional[str]): The task.
275-
276- Returns
277- -------
278- Dict[str, Any]: The merged model parameters.
279- """
280- return self .model_params .get_model_params (
281- framework_name = framework_name , version = version , task = task
282- )
283-
284210 def search_shapes (
285211 self ,
286212 evaluation_container : Optional [str ] = None ,
@@ -315,3 +241,59 @@ def search_shapes(
315241
316242 class Config :
317243 extra = "ignore"
244+
245+
246+ class EvaluationServiceConfig (Serializable ):
247+ """
248+ Root configuration class for evaluation setup including model,
249+ inference, and shape configurations.
250+ """
251+
252+ version : Optional [str ] = "1.0"
253+ kind : Optional [str ] = "evaluation"
254+ report_params : Optional [ReportParams ] = Field (default_factory = ReportParams )
255+ inference_params : Optional [InferenceParamsConfig ] = Field (
256+ default_factory = InferenceParamsConfig
257+ )
258+ inference_model_params : Optional [InferenceModelParamsConfig ] = Field (
259+ default_factory = InferenceModelParamsConfig
260+ )
261+ ui_config : Optional [UIConfig ] = Field (default_factory = UIConfig )
262+
263+ def get_merged_inference_params (self , container_name : str ) -> InferenceParams :
264+ """
265+ Merges default inference params with those specific to the given container.
266+
267+ Params
268+ ------
269+ container_name (str): The name of the container.
270+
271+ Returns
272+ -------
273+ InferenceParams: The merged inference parameters.
274+ """
275+ return self .inference_params .get_merged_params (container_name = container_name )
276+
277+ def get_merged_inference_model_params (
278+ self ,
279+ container_name : str ,
280+ version : Optional [str ] = None ,
281+ ) -> Dict [str , Any ]:
282+ """
283+ Gets the model parameters for a given container, version, and task, merged with the defaults.
284+
285+ Parameters
286+ ----------
287+ container_name (str): The name of the container.
288+ version (Optional[str]): The specific version of the container.
289+
290+ Returns
291+ -------
292+ Dict[str, Any]: The merged model parameters.
293+ """
294+ return self .inference_model_params .get_merged_model_params (
295+ container_name = container_name , version = version
296+ )
297+
298+ class Config :
299+ extra = "ignore"
0 commit comments