Skip to content

Commit ae4e38c

Browse files
authored
Merge branch 'main' into bugfix/heathen711/rocm-docker
2 parents 1cdd4b5 + a9f3f1a commit ae4e38c

File tree

369 files changed

+9997
-5230
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

369 files changed

+9997
-5230
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,5 @@ installer/update.bat
190190
installer/update.sh
191191
installer/InvokeAI-Installer/
192192
.aider*
193+
194+
.claude/

docs/contributing/frontend/workflows.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ Migration logic is in [migrations.ts].
297297
<!-- links -->
298298

299299
[pydantic]: https://github.com/pydantic/pydantic 'pydantic'
300-
[zod]: https://github.com/colinhacks/zod 'zod/v4'
300+
[zod]: https://github.com/colinhacks/zod 'zod'
301301
[openapi-types]: https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types 'openapi-types'
302302
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
303303
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions

invokeai/app/api/dependencies.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
1111
from invokeai.app.services.boards.boards_default import BoardService
1212
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
13+
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
1314
from invokeai.app.services.config.config_default import InvokeAIAppConfig
1415
from invokeai.app.services.download.download_default import DownloadQueueService
1516
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
@@ -151,6 +152,7 @@ def initialize(
151152
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
152153
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
153154
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
155+
client_state_persistence = ClientStatePersistenceSqlite(db=db)
154156

155157
services = InvocationServices(
156158
board_image_records=board_image_records,
@@ -181,6 +183,7 @@ def initialize(
181183
style_preset_records=style_preset_records,
182184
style_preset_image_files=style_preset_image_files,
183185
workflow_thumbnails=workflow_thumbnails,
186+
client_state_persistence=client_state_persistence,
184187
)
185188

186189
ApiDependencies.invoker = Invoker(services)

invokeai/app/api/routers/app_info.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from typing import Optional
66

77
import torch
8-
from fastapi import Body
8+
from fastapi import Body, HTTPException, Query
99
from fastapi.routing import APIRouter
10-
from pydantic import BaseModel, Field
10+
from pydantic import BaseModel, Field, JsonValue
1111

1212
from invokeai.app.api.dependencies import ApiDependencies
1313
from invokeai.app.invocations.upscale import ESRGAN_MODELS
@@ -173,3 +173,50 @@ async def disable_invocation_cache() -> None:
173173
async def get_invocation_cache_status() -> InvocationCacheStatus:
174174
"""Clears the invocation cache"""
175175
return ApiDependencies.invoker.services.invocation_cache.get_status()
176+
177+
178+
@app_router.get(
179+
"/client_state",
180+
operation_id="get_client_state_by_key",
181+
response_model=JsonValue | None,
182+
)
183+
async def get_client_state_by_key(
184+
key: str = Query(..., description="Key to get"),
185+
) -> JsonValue | None:
186+
"""Gets the client state"""
187+
try:
188+
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(key)
189+
except Exception as e:
190+
logging.error(f"Error getting client state: {e}")
191+
raise HTTPException(status_code=500, detail="Error setting client state")
192+
193+
194+
@app_router.post(
195+
"/client_state",
196+
operation_id="set_client_state",
197+
response_model=None,
198+
)
199+
async def set_client_state(
200+
key: str = Query(..., description="Key to set"),
201+
value: JsonValue = Body(..., description="Value of the key"),
202+
) -> None:
203+
"""Sets the client state"""
204+
try:
205+
ApiDependencies.invoker.services.client_state_persistence.set_by_key(key, value)
206+
except Exception as e:
207+
logging.error(f"Error setting client state: {e}")
208+
raise HTTPException(status_code=500, detail="Error setting client state")
209+
210+
211+
@app_router.delete(
212+
"/client_state",
213+
operation_id="delete_client_state",
214+
responses={204: {"description": "Client state deleted"}},
215+
)
216+
async def delete_client_state() -> None:
217+
"""Deletes the client state"""
218+
try:
219+
ApiDependencies.invoker.services.client_state_persistence.delete()
220+
except Exception as e:
221+
logging.error(f"Error deleting client state: {e}")
222+
raise HTTPException(status_code=500, detail="Error deleting client state")
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from abc import ABC, abstractmethod
2+
3+
from pydantic import JsonValue
4+
5+
6+
class ClientStatePersistenceABC(ABC):
7+
"""
8+
Base class for client persistence implementations.
9+
This class defines the interface for persisting client data.
10+
"""
11+
12+
@abstractmethod
13+
def set_by_key(self, key: str, value: JsonValue) -> None:
14+
"""
15+
Store the data for the client.
16+
17+
:param data: The client data to be stored.
18+
"""
19+
pass
20+
21+
@abstractmethod
22+
def get_by_key(self, key: str) -> JsonValue | None:
23+
"""
24+
Get the data for the client.
25+
26+
:return: The client data.
27+
"""
28+
pass
29+
30+
@abstractmethod
31+
def delete(self) -> None:
32+
"""
33+
Delete the data for the client.
34+
"""
35+
pass
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import json
2+
3+
from pydantic import JsonValue
4+
5+
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
6+
from invokeai.app.services.invoker import Invoker
7+
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
8+
9+
10+
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
11+
"""
12+
Base class for client persistence implementations.
13+
This class defines the interface for persisting client data.
14+
"""
15+
16+
def __init__(self, db: SqliteDatabase) -> None:
17+
super().__init__()
18+
self._db = db
19+
self._default_row_id = 1
20+
21+
def start(self, invoker: Invoker) -> None:
22+
self._invoker = invoker
23+
24+
def set_by_key(self, key: str, value: JsonValue) -> None:
25+
state = self.get() or {}
26+
state.update({key: value})
27+
28+
with self._db.transaction() as cursor:
29+
cursor.execute(
30+
f"""
31+
INSERT INTO client_state (id, data)
32+
VALUES ({self._default_row_id}, ?)
33+
ON CONFLICT(id) DO UPDATE
34+
SET data = excluded.data;
35+
""",
36+
(json.dumps(state),),
37+
)
38+
39+
def get(self) -> dict[str, JsonValue] | None:
40+
with self._db.transaction() as cursor:
41+
cursor.execute(
42+
f"""
43+
SELECT data FROM client_state
44+
WHERE id = {self._default_row_id}
45+
"""
46+
)
47+
row = cursor.fetchone()
48+
if row is None:
49+
return None
50+
return json.loads(row[0])
51+
52+
def get_by_key(self, key: str) -> JsonValue | None:
53+
state = self.get()
54+
if state is None:
55+
return None
56+
return state.get(key, None)
57+
58+
def delete(self) -> None:
59+
with self._db.transaction() as cursor:
60+
cursor.execute(
61+
f"""
62+
DELETE FROM client_state
63+
WHERE id = {self._default_row_id}
64+
"""
65+
)

invokeai/app/services/invocation_services.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
1818
from invokeai.app.services.boards.boards_base import BoardServiceABC
1919
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
20+
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
2021
from invokeai.app.services.config import InvokeAIAppConfig
2122
from invokeai.app.services.download import DownloadQueueServiceBase
2223
from invokeai.app.services.events.events_base import EventServiceBase
@@ -73,6 +74,7 @@ def __init__(
7374
style_preset_records: "StylePresetRecordsStorageBase",
7475
style_preset_image_files: "StylePresetImageFileStorageBase",
7576
workflow_thumbnails: "WorkflowThumbnailServiceBase",
77+
client_state_persistence: "ClientStatePersistenceABC",
7678
):
7779
self.board_images = board_images
7880
self.board_image_records = board_image_records
@@ -102,3 +104,4 @@ def __init__(
102104
self.style_preset_records = style_preset_records
103105
self.style_preset_image_files = style_preset_image_files
104106
self.workflow_thumbnails = workflow_thumbnails
107+
self.client_state_persistence = client_state_persistence

invokeai/app/services/model_install/model_install_default.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
5252
from invokeai.backend.model_manager.search import ModelSearch
5353
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
54+
from invokeai.backend.model_manager.util.lora_metadata_extractor import apply_lora_metadata
5455
from invokeai.backend.util import InvokeAILogger
5556
from invokeai.backend.util.catch_sigint import catch_sigint
5657
from invokeai.backend.util.devices import TorchDevice
@@ -185,7 +186,8 @@ def install_path(
185186
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
186187

187188
if preferred_name := config.name:
188-
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
189+
# Careful! Don't use pathlib.Path(...).with_suffix - it can will strip everything after the first dot.
190+
preferred_name = f"{preferred_name}{model_path.suffix}"
189191

190192
dest_path = (
191193
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
@@ -667,6 +669,10 @@ def _register(
667669

668670
info = info or self._probe(model_path, config)
669671

672+
# Apply LoRA metadata if applicable
673+
model_images_path = self.app_config.models_path / "model_images"
674+
apply_lora_metadata(info, model_path.resolve(), model_images_path)
675+
670676
model_path = model_path.resolve()
671677

672678
# Models in the Invoke-managed models dir should use relative paths.

invokeai/app/services/shared/sqlite/sqlite_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
2424
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
2525
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
26+
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
2627
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
2728

2829

@@ -63,6 +64,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
6364
migrator.register_migration(build_migration_18())
6465
migrator.register_migration(build_migration_19(app_config=config))
6566
migrator.register_migration(build_migration_20())
67+
migrator.register_migration(build_migration_21())
6668
migrator.run_migrations()
6769

6870
return db
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import sqlite3
2+
3+
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
4+
5+
6+
class Migration21Callback:
7+
def __call__(self, cursor: sqlite3.Cursor) -> None:
8+
cursor.execute(
9+
"""
10+
CREATE TABLE client_state (
11+
id INTEGER PRIMARY KEY CHECK(id = 1),
12+
data TEXT NOT NULL, -- Frontend will handle the shape of this data
13+
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
14+
);
15+
"""
16+
)
17+
cursor.execute(
18+
"""
19+
CREATE TRIGGER tg_client_state_updated_at
20+
AFTER UPDATE ON client_state
21+
FOR EACH ROW
22+
BEGIN
23+
UPDATE client_state
24+
SET updated_at = CURRENT_TIMESTAMP
25+
WHERE id = OLD.id;
26+
END;
27+
"""
28+
)
29+
30+
31+
def build_migration_21() -> Migration:
32+
"""Builds the migration object for migrating from version 20 to version 21. This includes:
33+
- Creating the `client_state` table.
34+
- Adding a trigger to update the `updated_at` field on updates.
35+
"""
36+
return Migration(
37+
from_version=20,
38+
to_version=21,
39+
callback=Migration21Callback(),
40+
)

0 commit comments

Comments
 (0)