Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 43 additions & 37 deletions ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,9 +1158,11 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:


def build_pydantic_error_message(ex: ValidationError):
"""Added to handle error messages from pydantic model validator.
"""
Added to handle error messages from pydantic model validator.
Combine both loc and msg for errors where loc (field) is present in error details, else only build error
message using msg field."""
message using msg field.
"""

return {
".".join(map(str, e["loc"])): e["msg"]
Expand All @@ -1185,67 +1187,71 @@ def is_pydantic_model(obj: object) -> bool:

@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
def load_gpu_shapes_index(
auth: Optional[Dict] = None,
auth: Optional[Dict[str, Any]] = None,
) -> GPUShapesIndex:
"""
Loads the GPU shapes index from Object Storage or a local resource folder.
Load the GPU shapes index, preferring the OS bucket copy over the local one.

The function first attempts to load the file from an Object Storage bucket using fsspec.
If the loading fails (due to connection issues, missing file, etc.), it falls back to
loading the index from a local file.
Attempts to read `gpu_shapes_index.json` from OCI Object Storage first;
if that succeeds, those entries will override the local defaults.

Parameters
----------
auth: (Dict, optional). Defaults to None.
The default authentication is set using `ads.set_auth` API. If you need to override the
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
authentication signer and kwargs required to instantiate IdentityClient object.
auth
Optional auth dict (as returned by `ads.common.auth.default_signer()`)
to pass through to `fsspec.open()`.

Returns
-------
GPUShapesIndex: The parsed GPU shapes index.
GPUShapesIndex
Merged index where any shape present remotely supersedes the local entry.

Raises
------
FileNotFoundError: If the GPU shapes index cannot be found in either Object Storage or locally.
json.JSONDecodeError: If the JSON is malformed.
json.JSONDecodeError
If any of the JSON is malformed.
"""
file_name = "gpu_shapes_index.json"
data: Dict[str, Any] = {}

# Check if the CONDA_BUCKET_NS environment variable is set.
# Try remote load
remote_data: Dict[str, Any] = {}
if CONDA_BUCKET_NS:
try:
auth = auth or authutil.default_signer()
# Construct the object storage path. Adjust bucket name and path as needed.
storage_path = (
f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/service_pack/{file_name}"
)
logger.debug("Loading GPU shapes index from Object Storage")
with fsspec.open(storage_path, mode="r", **auth) as file_obj:
data = json.load(file_obj)
logger.debug("Successfully loaded GPU shapes index.")
except Exception as ex:
logger.debug(
f"Failed to load GPU shapes index from Object Storage. Details: {ex}"
"Loading GPU shapes index from Object Storage: %s", storage_path
)

# If loading from Object Storage failed, load from the local resource folder.
if not data:
try:
local_path = os.path.join(
os.path.dirname(__file__), "../resources", file_name
)
logger.debug(f"Loading GPU shapes index from {local_path}.")
with open(local_path) as file_obj:
data = json.load(file_obj)
logger.debug("Successfully loaded GPU shapes index.")
except Exception as e:
with fsspec.open(storage_path, mode="r", **auth) as f:
remote_data = json.load(f)
logger.debug(
f"Failed to load GPU shapes index from {local_path}. Details: {e}"
"Loaded %d shapes from Object Storage",
len(remote_data.get("shapes", {})),
)
except Exception as ex:
logger.debug("Remote load failed (%s); falling back to local", ex)

# Load local copy
local_data: Dict[str, Any] = {}
local_path = os.path.join(os.path.dirname(__file__), "../resources", file_name)
try:
logger.debug("Loading GPU shapes index from local file: %s", local_path)
with open(local_path) as f:
local_data = json.load(f)
logger.debug(
"Loaded %d shapes from local file", len(local_data.get("shapes", {}))
)
except Exception as ex:
logger.debug("Local load GPU shapes index failed (%s)", ex)

# Merge: remote shapes override local
local_shapes = local_data.get("shapes", {})
remote_shapes = remote_data.get("shapes", {})
merged_shapes = {**local_shapes, **remote_shapes}

return GPUShapesIndex(**data)
return GPUShapesIndex(shapes=merged_shapes)


def get_preferred_compatible_family(selected_families: set[str]) -> str:
Expand Down