Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions inference/core/registries/roboflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
from inference.core.utils.roboflow import get_model_id_chunks
from inference.models.aliases import resolve_roboflow_model_alias

_MODEL_TYPE_CACHE = {}

GENERIC_MODELS = {
"clip": ("embed", "clip"),
"sam": ("embed", "sam"),
Expand Down Expand Up @@ -88,12 +90,16 @@ def get_model(
Raises:
ModelNotRecognisedError: If the model type is not supported or found.
"""
model_type = get_model_type(
model_id,
api_key,
countinference=countinference,
service_secret=service_secret,
)
cache_key = (api_key, model_id, countinference, service_secret)
model_type = _MODEL_TYPE_CACHE.get(cache_key)
if model_type is None:
model_type = get_model_type(
model_id,
api_key,
countinference=countinference,
service_secret=service_secret,
)
_MODEL_TYPE_CACHE[cache_key] = model_type
logger.debug(f"Model type: {model_type}")

if model_type not in self.registry_dict:
Expand Down Expand Up @@ -162,7 +168,6 @@ def get_model_type(
if dataset_id in GENERIC_MODELS:
logger.debug(f"Loading generic model: {dataset_id}.")
return GENERIC_MODELS[dataset_id]

if MODELS_CACHE_AUTH_ENABLED:
if not _check_if_api_key_has_access_to_model(
api_key=api_key,
Expand All @@ -180,6 +185,7 @@ def get_model_type(

if cached_metadata is not None:
return cached_metadata[0], cached_metadata[1]

if version_id == STUB_VERSION_ID:
if api_key is None:
raise MissingApiKeyError(
Expand Down Expand Up @@ -222,8 +228,6 @@ def get_model_type(
# some older projects do not have type field - hence defaulting
model_type = api_data.get("modelType")
if model_type is None or model_type == "ort":
# some very old model versions do not have modelType reported - and API respond in a generic way -
# then we shall attempt using default model for given task type
model_type = MODEL_TYPE_DEFAULTS.get(project_task_type)
if model_type is None or project_task_type is None:
raise ModelArtefactError("Error loading model artifacts from Roboflow API.")
Expand Down