1515 build_pydantic_error_message ,
1616 get_resource_type ,
1717 load_config ,
18+ load_gpu_shapes_index ,
1819)
1920from ads .aqua .shaperecommend .constants import (
2021 SAFETENSORS ,
22+ SHAPE_MAP ,
2123 TEXT_GENERATION ,
2224 TROUBLESHOOT_MSG ,
2325)
3032 ShapeReport ,
3133)
3234from ads .model .datascience_model import DataScienceModel
35+ from ads .model .service .oci_datascience_model_deployment import (
36+ OCIDataScienceModelDeployment ,
37+ )
3338
3439
35- class AquaShapeRecommend ( BaseModel ) :
40+ class AquaShapeRecommend :
3641 """
3742 Interface for recommending GPU shapes for machine learning model deployments
3843 on Oracle Cloud Infrastructure Data Science service.
@@ -42,7 +47,7 @@ class AquaShapeRecommend(BaseModel):
4247 Must be used within a properly configured and authenticated OCI environment.
4348 """
4449
45- def which_shapes (self , ** kwargs ) -> Union [ShapeRecommendationReport , Table ]:
50+ def which_shapes (self , request : RequestRecommend ) -> Union [ShapeRecommendationReport , Table ]:
4651 """
4752 Lists valid GPU deployment shapes for the provided model and configuration.
4853
@@ -77,7 +82,8 @@ def which_shapes(self, **kwargs) -> Union[ShapeRecommendationReport, Table]:
7782 If parameters are missing or invalid, or if no valid sequence length is requested.
7883 """
7984 try :
80- request = RequestRecommend (** kwargs )
85+ shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
86+
8187 ds_model = self ._validate_model_ocid (request .model_id )
8288 data = self ._get_model_config (ds_model )
8389
@@ -86,7 +92,7 @@ def which_shapes(self, **kwargs) -> Union[ShapeRecommendationReport, Table]:
8692 model_name = ds_model .display_name if ds_model .display_name else ""
8793
8894 shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
89- llm_config , request . shapes , model_name
95+ llm_config , shapes , model_name
9096 )
9197
9298 if request .generate_table and shape_recommendation_report .recommendations :
@@ -107,10 +113,61 @@ def which_shapes(self, **kwargs) -> Union[ShapeRecommendationReport, Table]:
107113 ) from ex
108114 except AquaValueError as ex :
109115 logger .error (f"Error with LLM config: { ex } " )
110- raise
116+ raise AquaValueError ( # noqa: B904
117+ f"An error occured while producing recommendations: { ex } "
118+ )
111119
112120 return shape_recommendation_report
113121
def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary"]:
    """
    Builds ComputeShapeSummary objects for every shape in the GPU shapes index,
    marking each one as available or unavailable in the given compartment.

    A shape is marked ``available=True`` when its name also appears in the result
    of ``OCIDataScienceModelDeployment.shapes`` for the compartment; in that case
    the core count, memory and shape series are taken from that result. Shapes
    found only in the GPU shapes index are still returned, with
    ``available=False`` and ``shape_series="GPU"``.

    Parameters
    ----------
    compartment_id : str
        Compartment identifier forwarded to ``OCIDataScienceModelDeployment.shapes``.

    Returns
    -------
    List[ComputeShapeSummary]
        One entry per shape in the GPU shapes index, sorted by
        ``gpu_specs.gpu_memory_in_gbs`` in descending order.
    """
    oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id)
    # Index the compartment's shapes by name so each GPU shape needs one lookup.
    shapes_by_name = {shape.name: shape for shape in oci_shapes}

    gpu_shapes_metadata = load_gpu_shapes_index().shapes

    valid_shapes = []
    # only loops through GPU shapes, update later to include CPU shapes
    for name, spec in gpu_shapes_metadata.items():
        oci_shape = shapes_by_name.get(name)
        if oci_shape is not None:
            compute_shape = ComputeShapeSummary(
                available=True,
                core_count=oci_shape.core_count,
                memory_in_gbs=oci_shape.memory_in_gbs,
                # Fall back to "GPU" when the series has no SHAPE_MAP entry.
                shape_series=SHAPE_MAP.get(oci_shape.shape_series, "GPU"),
                name=oci_shape.name,
                gpu_specs=spec,
            )
        else:
            compute_shape = ComputeShapeSummary(
                available=False, name=name, shape_series="GPU", gpu_specs=spec
            )
        valid_shapes.append(compute_shape)

    # NOTE(review): assumes gpu_memory_in_gbs may be unset (None) for some index
    # entries -- confirm against GPUSpecs. Treating a missing value as 0 keeps the
    # sort from raising TypeError and places such shapes last.
    valid_shapes.sort(
        key=lambda shape: shape.gpu_specs.gpu_memory_in_gbs or 0, reverse=True
    )
    return valid_shapes
170+
114171 @staticmethod
115172 def _rich_diff_table (shape_report : ShapeRecommendationReport ) -> Table :
116173 """
@@ -321,7 +378,7 @@ def _summarize_shapes_for_seq_lens(
321378 recommendations = []
322379
323380 if not shapes :
324- raise ValueError (
381+ raise AquaValueError (
325382 "No GPU shapes were passed for recommendation. Ensure shape parsing succeeded."
326383 )
327384
0 commit comments