Skip to content

Commit dcfd4ea

Browse files
feat(mm): reidentify models
Add route and model record service method to reidentify a model. This re-probes the model files and replaces the model's config with the new one if it does not error.
1 parent 093f8d6 commit dcfd4ea

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

invokeai/app/api/routers/model_manager.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
UnknownModelException,
2929
)
3030
from invokeai.app.util.suppress_output import SuppressOutput
31-
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
31+
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
3232
from invokeai.backend.model_manager.configs.main import (
3333
Main_Checkpoint_SD1_Config,
3434
Main_Checkpoint_SD2_Config,
@@ -38,6 +38,7 @@
3838
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
3939
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
4040
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
41+
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
4142
from invokeai.backend.model_manager.search import ModelSearch
4243
from invokeai.backend.model_manager.starter_models import (
4344
STARTER_BUNDLES,
@@ -191,6 +192,40 @@ async def get_model_record(
191192
raise HTTPException(status_code=404, detail=str(e))
192193

193194

195+
@model_manager_router.post(
196+
"/i/{key}/reidentify",
197+
operation_id="reidentify_model",
198+
responses={
199+
200: {
200+
"description": "The model configuration was retrieved successfully",
201+
"content": {"application/json": {"example": example_model_config}},
202+
},
203+
400: {"description": "Bad request"},
204+
404: {"description": "The model could not be found"},
205+
},
206+
)
207+
async def reidentify_model(
208+
key: Annotated[str, Path(description="Key of the model to reidentify.")],
209+
) -> AnyModelConfig:
210+
"""Attempt to reidentify a model by re-probing its weights file."""
211+
try:
212+
config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
213+
models_path = ApiDependencies.invoker.services.configuration.models_path
214+
if pathlib.Path(config.path).is_relative_to(models_path):
215+
model_path = pathlib.Path(config.path)
216+
else:
217+
model_path = models_path / config.path
218+
mod = ModelOnDisk(model_path)
219+
result = ModelConfigFactory.from_model_on_disk(mod)
220+
if result.config is None:
221+
raise InvalidModelException("Unable to identify model format")
222+
result.config.key = config.key # retain the same key
223+
new_config = ApiDependencies.invoker.services.model_manager.store.replace_model(config.key, result.config)
224+
return new_config
225+
except UnknownModelException as e:
226+
raise HTTPException(status_code=404, detail=str(e))
227+
228+
194229
class FoundModel(BaseModel):
195230
path: str = Field(description="Path to the model")
196231
is_installed: bool = Field(description="Whether or not the model is already installed")

invokeai/app/services/model_records/model_records_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,18 @@ def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change
138138
"""
139139
pass
140140

141+
@abstractmethod
142+
def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
143+
"""
144+
Replace the model record entirely, returning the new record.
145+
146+
This is used when we re-identify a model and have a new config object.
147+
148+
:param key: Unique key for the model to be updated.
149+
:param new_config: The new model config to write.
150+
"""
151+
pass
152+
141153
@abstractmethod
142154
def get_model(self, key: str) -> AnyModelConfig:
143155
"""

invokeai/app/services/model_records/model_records_sql.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,23 @@ def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change
179179

180180
return self.get_model(key)
181181

182+
def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
183+
if key != new_config.key:
184+
raise ValueError("key does not match new_config.key")
185+
with self._db.transaction() as cursor:
186+
cursor.execute(
187+
"""--sql
188+
UPDATE models
189+
SET
190+
config=?
191+
WHERE id=?;
192+
""",
193+
(new_config.model_dump_json(), key),
194+
)
195+
if cursor.rowcount == 0:
196+
raise UnknownModelException("model not found")
197+
return self.get_model(key)
198+
182199
def get_model(self, key: str) -> AnyModelConfig:
183200
"""
184201
Retrieve the ModelConfigBase instance for the indicated model.

0 commit comments

Comments
 (0)