From 8db2f2f3fcd76482c1ad90569b47eeb6f9642915 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Tue, 6 May 2025 15:12:36 -0700 Subject: [PATCH] Merge local and remote GPU shapes index, preferring Object Storage overrides --- ads/aqua/common/utils.py | 80 +++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 37 deletions(-) diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index 3b5a15e86..fc9192fa8 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -1158,9 +1158,11 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]: def build_pydantic_error_message(ex: ValidationError): - """Added to handle error messages from pydantic model validator. + """ + Added to handle error messages from pydantic model validator. Combine both loc and msg for errors where loc (field) is present in error details, else only build error - message using msg field.""" + message using msg field. + """ return { ".".join(map(str, e["loc"])): e["msg"] @@ -1185,67 +1187,71 @@ def is_pydantic_model(obj: object) -> bool: @cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now)) def load_gpu_shapes_index( - auth: Optional[Dict] = None, + auth: Optional[Dict[str, Any]] = None, ) -> GPUShapesIndex: """ - Loads the GPU shapes index from Object Storage or a local resource folder. + Load the GPU shapes index, preferring the OS bucket copy over the local one. - The function first attempts to load the file from an Object Storage bucket using fsspec. - If the loading fails (due to connection issues, missing file, etc.), it falls back to - loading the index from a local file. + Attempts to read `gpu_shapes_index.json` from OCI Object Storage first; + if that succeeds, those entries will override the local defaults. Parameters ---------- - auth: (Dict, optional). Defaults to None. - The default authentication is set using `ads.set_auth` API. If you need to override the - default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate - authentication signer and kwargs required to instantiate IdentityClient object. + auth + Optional auth dict (as returned by `ads.common.auth.default_signer()`) + to pass through to `fsspec.open()`. Returns ------- - GPUShapesIndex: The parsed GPU shapes index. + GPUShapesIndex + Merged index where any shape present remotely supersedes the local entry. Raises ------ - FileNotFoundError: If the GPU shapes index cannot be found in either Object Storage or locally. - json.JSONDecodeError: If the JSON is malformed. + json.JSONDecodeError + If any of the JSON is malformed. """ file_name = "gpu_shapes_index.json" - data: Dict[str, Any] = {} - # Check if the CONDA_BUCKET_NS environment variable is set. + # Try remote load + remote_data: Dict[str, Any] = {} if CONDA_BUCKET_NS: try: auth = auth or authutil.default_signer() - # Construct the object storage path. Adjust bucket name and path as needed. storage_path = ( f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/service_pack/{file_name}" ) - logger.debug("Loading GPU shapes index from Object Storage") - with fsspec.open(storage_path, mode="r", **auth) as file_obj: - data = json.load(file_obj) - logger.debug("Successfully loaded GPU shapes index.") - except Exception as ex: logger.debug( - f"Failed to load GPU shapes index from Object Storage. Details: {ex}" + "Loading GPU shapes index from Object Storage: %s", storage_path ) - - # If loading from Object Storage failed, load from the local resource folder. - if not data: - try: - local_path = os.path.join( - os.path.dirname(__file__), "../resources", file_name - ) - logger.debug(f"Loading GPU shapes index from {local_path}.") - with open(local_path) as file_obj: - data = json.load(file_obj) - logger.debug("Successfully loaded GPU shapes index.") - except Exception as e: + with fsspec.open(storage_path, mode="r", **auth) as f: + remote_data = json.load(f) logger.debug( - f"Failed to load GPU shapes index from {local_path}. Details: {e}" + "Loaded %d shapes from Object Storage", + len(remote_data.get("shapes", {})), ) + except Exception as ex: + logger.debug("Remote load failed (%s); falling back to local", ex) + + # Load local copy + local_data: Dict[str, Any] = {} + local_path = os.path.join(os.path.dirname(__file__), "../resources", file_name) + try: + logger.debug("Loading GPU shapes index from local file: %s", local_path) + with open(local_path) as f: + local_data = json.load(f) + logger.debug( + "Loaded %d shapes from local file", len(local_data.get("shapes", {})) + ) + except Exception as ex: + logger.debug("Local load GPU shapes index failed (%s)", ex) + + # Merge: remote shapes override local + local_shapes = local_data.get("shapes", {}) + remote_shapes = remote_data.get("shapes", {}) + merged_shapes = {**local_shapes, **remote_shapes} - return GPUShapesIndex(**data) + return GPUShapesIndex(shapes=merged_shapes) def get_preferred_compatible_family(selected_families: set[str]) -> str: