diff --git a/inference/core/registries/roboflow.py b/inference/core/registries/roboflow.py index fddb1c0099..c706a5e834 100644 --- a/inference/core/registries/roboflow.py +++ b/inference/core/registries/roboflow.py @@ -43,6 +43,8 @@ from inference.core.utils.roboflow import get_model_id_chunks from inference.models.aliases import resolve_roboflow_model_alias +_MODEL_TYPE_CACHE = {} + GENERIC_MODELS = { "clip": ("embed", "clip"), "sam": ("embed", "sam"), @@ -88,12 +90,16 @@ def get_model( Raises: ModelNotRecognisedError: If the model type is not supported or found. """ - model_type = get_model_type( - model_id, - api_key, - countinference=countinference, - service_secret=service_secret, - ) + cache_key = (api_key, model_id, countinference, service_secret) + model_type = _MODEL_TYPE_CACHE.get(cache_key) + if model_type is None: + model_type = get_model_type( + model_id, + api_key, + countinference=countinference, + service_secret=service_secret, + ) + _MODEL_TYPE_CACHE[cache_key] = model_type logger.debug(f"Model type: {model_type}") if model_type not in self.registry_dict: @@ -162,7 +168,6 @@ def get_model_type( if dataset_id in GENERIC_MODELS: logger.debug(f"Loading generic model: {dataset_id}.") return GENERIC_MODELS[dataset_id] - if MODELS_CACHE_AUTH_ENABLED: if not _check_if_api_key_has_access_to_model( api_key=api_key, @@ -180,6 +185,7 @@ def get_model_type( if cached_metadata is not None: return cached_metadata[0], cached_metadata[1] + if version_id == STUB_VERSION_ID: if api_key is None: raise MissingApiKeyError( @@ -222,8 +228,6 @@ def get_model_type( # some older projects do not have type field - hence defaulting model_type = api_data.get("modelType") if model_type is None or model_type == "ort": - # some very old model versions do not have modelType reported - and API respond in a generic way - - # then we shall attempt using default model for given task type model_type = MODEL_TYPE_DEFAULTS.get(project_task_type) if model_type is None or project_task_type is None: raise ModelArtefactError("Error loading model artifacts from Roboflow API.")