@@ -1158,9 +1158,11 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
11581158
11591159
11601160def build_pydantic_error_message (ex : ValidationError ):
1161- """Added to handle error messages from pydantic model validator.
1161+ """
1162+ Added to handle error messages from pydantic model validator.
11621163 Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1163- message using msg field."""
1164+ message using msg field.
1165+ """
11641166
11651167 return {
11661168 "." .join (map (str , e ["loc" ])): e ["msg" ]
@@ -1185,67 +1187,71 @@ def is_pydantic_model(obj: object) -> bool:
11851187
11861188@cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 5 ), timer = datetime .now ))
11871189def load_gpu_shapes_index (
1188- auth : Optional [Dict ] = None ,
1190+ auth : Optional [Dict [ str , Any ] ] = None ,
11891191) -> GPUShapesIndex :
11901192 """
1191- Loads the GPU shapes index from Object Storage or a local resource folder .
1193+ Load the GPU shapes index, preferring the OS bucket copy over the local one .
11921194
1193- The function first attempts to load the file from an Object Storage bucket using fsspec.
1194- If the loading fails (due to connection issues, missing file, etc.), it falls back to
1195- loading the index from a local file.
1195+ Attempts to read `gpu_shapes_index.json` from OCI Object Storage first;
1196+ if that succeeds, those entries will override the local defaults.
11961197
11971198 Parameters
11981199 ----------
1199- auth: (Dict, optional). Defaults to None.
1200- The default authentication is set using `ads.set_auth` API. If you need to override the
1201- default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
1202- authentication signer and kwargs required to instantiate IdentityClient object.
1200+ auth
1201+ Optional auth dict (as returned by `ads.common.auth.default_signer()`)
1202+ to pass through to `fsspec.open()`.
12031203
12041204 Returns
12051205 -------
1206- GPUShapesIndex: The parsed GPU shapes index.
1206+ GPUShapesIndex
1207+ Merged index where any shape present remotely supersedes the local entry.
12071208
12081209 Raises
12091210 ------
1210- FileNotFoundError: If the GPU shapes index cannot be found in either Object Storage or locally.
1211- json.JSONDecodeError: If the JSON is malformed.
1211+ json.JSONDecodeError
1212+ If any of the JSON is malformed.
12121213 """
12131214 file_name = "gpu_shapes_index.json"
1214- data : Dict [str , Any ] = {}
12151215
1216- # Check if the CONDA_BUCKET_NS environment variable is set.
1216+ # Try remote load
1217+ remote_data : Dict [str , Any ] = {}
12171218 if CONDA_BUCKET_NS :
12181219 try :
12191220 auth = auth or authutil .default_signer ()
1220- # Construct the object storage path. Adjust bucket name and path as needed.
12211221 storage_path = (
12221222 f"oci://{ CONDA_BUCKET_NAME } @{ CONDA_BUCKET_NS } /service_pack/{ file_name } "
12231223 )
1224- logger .debug ("Loading GPU shapes index from Object Storage" )
1225- with fsspec .open (storage_path , mode = "r" , ** auth ) as file_obj :
1226- data = json .load (file_obj )
1227- logger .debug ("Successfully loaded GPU shapes index." )
1228- except Exception as ex :
12291224 logger .debug (
1230- f"Failed to load GPU shapes index from Object Storage. Details: { ex } "
1225+ "Loading GPU shapes index from Object Storage: %s" , storage_path
12311226 )
1232-
1233- # If loading from Object Storage failed, load from the local resource folder.
1234- if not data :
1235- try :
1236- local_path = os .path .join (
1237- os .path .dirname (__file__ ), "../resources" , file_name
1238- )
1239- logger .debug (f"Loading GPU shapes index from { local_path } ." )
1240- with open (local_path ) as file_obj :
1241- data = json .load (file_obj )
1242- logger .debug ("Successfully loaded GPU shapes index." )
1243- except Exception as e :
1227+ with fsspec .open (storage_path , mode = "r" , ** auth ) as f :
1228+ remote_data = json .load (f )
12441229 logger .debug (
1245- f"Failed to load GPU shapes index from { local_path } . Details: { e } "
1230+ "Loaded %d shapes from Object Storage" ,
1231+ len (remote_data .get ("shapes" , {})),
12461232 )
1233+ except Exception as ex :
1234+ logger .debug ("Remote load failed (%s); falling back to local" , ex )
1235+
1236+ # Load local copy
1237+ local_data : Dict [str , Any ] = {}
1238+ local_path = os .path .join (os .path .dirname (__file__ ), "../resources" , file_name )
1239+ try :
1240+ logger .debug ("Loading GPU shapes index from local file: %s" , local_path )
1241+ with open (local_path ) as f :
1242+ local_data = json .load (f )
1243+ logger .debug (
1244+ "Loaded %d shapes from local file" , len (local_data .get ("shapes" , {}))
1245+ )
1246+ except Exception as ex :
1247+ logger .debug ("Local load GPU shapes index failed (%s)" , ex )
1248+
1249+ # Merge: remote shapes override local
1250+ local_shapes = local_data .get ("shapes" , {})
1251+ remote_shapes = remote_data .get ("shapes" , {})
1252+ merged_shapes = {** local_shapes , ** remote_shapes }
12471253
1248- return GPUShapesIndex (** data )
1254+ return GPUShapesIndex (shapes = merged_shapes )
12491255
12501256
12511257def get_preferred_compatible_family (selected_families : set [str ]) -> str :
0 commit comments