Skip to content

Commit a51c74d

Browse files
author
Attila Cseh
committed
SliceConfig refactored
1 parent 67a9f5d commit a51c74d

File tree

9 files changed

+109
-74
lines changed

9 files changed

+109
-74
lines changed

invokeai/frontend/web/src/app/store/store.ts

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ import { addModelSelectedListener } from 'app/store/middleware/listenerMiddlewar
1717
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
1818
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
1919
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
20-
import { deepClone } from 'common/util/deepClone';
21-
import { merge } from 'es-toolkit';
22-
import { omit, pick } from 'es-toolkit/compat';
2320
import { changeBoardModalSliceConfig } from 'features/changeBoardModal/store/slice';
2421
import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
2522
import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice';
@@ -117,21 +114,14 @@ const unserialize: UnserializeFunction = (data, key) => {
117114
const { getInitialState, persistConfig } = sliceConfig;
118115
let state;
119116
try {
120-
const initialState = getInitialState();
121-
const parsed = JSON.parse(data);
122-
123-
// We need to inject non-persisted values from initial state into the rehydrated state. These values always are
124-
// required to be in the state, but won't be in the persisted data. Build an object that consists of only these
125-
// values, then merge it with the rehydrated state.
126-
const nonPersistedSubsetOfState = pick(initialState, persistConfig.persistDenylist ?? []);
127-
const stateToMigrate = merge(deepClone(parsed), nonPersistedSubsetOfState);
117+
const parsedState = JSON.parse(data);
128118

129119
// Run migrations to bring old state up to date with the current version.
130-
const migrated = persistConfig.migrate(stateToMigrate);
120+
const migrated = persistConfig.migrate(parsedState);
131121

132122
log.debug(
133123
{
134-
persistedData: parsed as JsonObject,
124+
persistedData: parsedState as JsonObject,
135125
rehydratedData: migrated as JsonObject,
136126
diff: diff(data, migrated) as JsonObject,
137127
},
@@ -146,7 +136,7 @@ const unserialize: UnserializeFunction = (data, key) => {
146136
state = getInitialState();
147137
}
148138

149-
return persistConfig.wrapState ? persistConfig.wrapState(state) : state;
139+
return persistConfig.deserialize ? persistConfig.deserialize(state) : state;
150140
};
151141

152142
const serialize: SerializeFunction = (data, key) => {
@@ -155,12 +145,9 @@ const serialize: SerializeFunction = (data, key) => {
155145
throw new Error(`No persist config for slice "${key}"`);
156146
}
157147

158-
const result = omit(
159-
sliceConfig.persistConfig.unwrapState ? sliceConfig.persistConfig.unwrapState(data) : data,
160-
sliceConfig.persistConfig.persistDenylist ?? []
161-
);
148+
const state = sliceConfig.persistConfig.serialize ? sliceConfig.persistConfig.serialize(data) : data;
162149

163-
return JSON.stringify(result);
150+
return JSON.stringify(state);
164151
};
165152

166153
const PERSISTED_KEYS = Object.values(SLICE_CONFIGS)

invokeai/frontend/web/src/app/store/types.ts

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import type { Slice } from '@reduxjs/toolkit';
22
import type { ZodType } from 'zod';
33

44
type StateFromSlice<T extends Slice> = T extends Slice<infer U> ? U : never;
5+
export type SerializedStateFromDenyList<S, T extends readonly (keyof S)[]> = Omit<S, T[number]>;
56

67
export type SliceConfig<T extends Slice, TInternalState = StateFromSlice<T>, TSerializedState = StateFromSlice<T>> = {
78
/**
@@ -29,22 +30,18 @@ export type SliceConfig<T extends Slice, TInternalState = StateFromSlice<T>, TSe
2930
*/
3031
migrate: (state: unknown) => TSerializedState;
3132
/**
32-
* Keys to omit from the persisted state.
33-
*/
34-
persistDenylist?: (keyof StateFromSlice<T>)[];
35-
/**
36-
* Wraps state into state with history
33+
* Serializes the state
3734
*
38-
* @param state The state without history
39-
* @returns The state with history
35+
* @param state The internal state
36+
* @returns The serialized state
4037
*/
41-
wrapState?: (state: unknown) => TInternalState;
38+
serialize?: (state: TInternalState) => TSerializedState;
4239
/**
43-
* Unwraps state with history
40+
* Deserializes the state
4441
*
45-
* @param state The state with history
46-
* @returns The state without history
42+
* @param state The serialized state
43+
* @returns The internal state
4744
*/
48-
unwrapState?: (state: TInternalState) => TSerializedState;
45+
deserialize?: (state: unknown) => TInternalState;
4946
};
5047
};

invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,7 +1938,7 @@ export const isCanvasInstanceAction = (action: UnknownAction) =>
19381938
isTabParamsStateAction(action) ||
19391939
isCanvasSettingsStateAction(action) ||
19401940
isCanvasStagingAreaStateAction(action);
1941-
export const isCanvasEntityStateAction = isAnyOf(...Object.values(canvasEntityState.actions), canvasReset);
1941+
const isCanvasEntityStateAction = isAnyOf(...Object.values(canvasEntityState.actions), canvasReset);
19421942

19431943
export const {
19441944
// Canvas
@@ -2102,28 +2102,34 @@ export const canvasSliceConfig: SliceConfig<typeof canvasSlice, CanvasStateWithH
21022102
}
21032103
return zCanvasStateWithoutHistory.parse(state);
21042104
},
2105-
wrapState: (state) => {
2106-
const canvasState = state as CanvasState;
2107-
2105+
serialize: (state) => {
21082106
return {
2109-
_version: canvasState._version,
2110-
activeCanvasId: canvasState.activeCanvasId,
2107+
_version: state._version,
2108+
activeCanvasId: state.activeCanvasId,
21112109
canvases: Object.fromEntries(
2112-
Object.entries(canvasState.canvases).map(([canvasId, instance]) => [
2110+
Object.entries(state.canvases).map(([canvasId, instance]) => [
21132111
canvasId,
2114-
{ ...instance, canvas: newHistory([], instance.canvas, []) },
2112+
{
2113+
...instance,
2114+
canvas: instance.canvas.present,
2115+
},
21152116
])
21162117
),
21172118
};
21182119
},
2119-
unwrapState: (state) => {
2120+
deserialize: (state) => {
2121+
const canvasState = state as CanvasState;
2122+
21202123
return {
2121-
_version: state._version,
2122-
activeCanvasId: state.activeCanvasId,
2124+
_version: canvasState._version,
2125+
activeCanvasId: canvasState.activeCanvasId,
21232126
canvases: Object.fromEntries(
2124-
Object.entries(state.canvases).map(([canvasId, instance]) => [
2127+
Object.entries(canvasState.canvases).map(([canvasId, instance]) => [
21252128
canvasId,
2126-
{ ...instance, canvas: instance.canvas.present },
2129+
{
2130+
...instance,
2131+
canvas: newHistory([], instance.canvas, []),
2132+
},
21272133
])
21282134
),
21292135
};

invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ import type { PayloadAction } from '@reduxjs/toolkit';
33
import { createSelector, createSlice } from '@reduxjs/toolkit';
44
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
55
import type { RootState } from 'app/store/store';
6-
import type { SliceConfig } from 'app/store/types';
7-
import { clamp } from 'es-toolkit/compat';
6+
import type { SerializedStateFromDenyList, SliceConfig } from 'app/store/types';
7+
import { clamp, merge, omit } from 'es-toolkit/compat';
88
import { getPrefixedId } from 'features/controlLayers/konva/util';
99
import type {
1010
CroppableImageWithDims,
@@ -279,13 +279,21 @@ export const {
279279
refImagesRecalled,
280280
} = slice.actions;
281281

282-
export const refImagesSliceConfig: SliceConfig<typeof slice> = {
282+
const denyList = ['selectedEntityId', 'isPanelOpen'] as const;
283+
type SerializedRefImagesState = SerializedStateFromDenyList<RefImagesState, typeof denyList>;
284+
285+
export const refImagesSliceConfig: SliceConfig<typeof slice, RefImagesState, SerializedRefImagesState> = {
283286
slice,
284287
schema: zRefImagesState,
285288
getInitialState: getInitialRefImagesState,
286289
persistConfig: {
287290
migrate: (state) => zRefImagesState.parse(state),
288-
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
291+
serialize: (state) => omit(state, denyList),
292+
deserialize: (state) => {
293+
const refImagesState = state as SerializedRefImagesState;
294+
295+
return merge(refImagesState, getInitialRefImagesState());
296+
},
289297
},
290298
};
291299

invokeai/frontend/web/src/features/dynamicPrompts/store/dynamicPromptsSlice.ts

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
22
import { createSelector, createSlice } from '@reduxjs/toolkit';
33
import type { RootState } from 'app/store/store';
4-
import type { SliceConfig } from 'app/store/types';
4+
import type { SerializedStateFromDenyList, SliceConfig } from 'app/store/types';
55
import { buildZodTypeGuard } from 'common/util/zodUtils';
66
import { isPlainObject } from 'es-toolkit';
7+
import { merge, omit } from 'es-toolkit/compat';
78
import { assert } from 'tsafe';
89
import { z } from 'zod';
910

@@ -69,21 +70,30 @@ export const {
6970
seedBehaviourChanged,
7071
} = slice.actions;
7172

72-
export const dynamicPromptsSliceConfig: SliceConfig<typeof slice> = {
73-
slice,
74-
schema: zDynamicPromptsState,
75-
getInitialState,
76-
persistConfig: {
77-
migrate: (state) => {
78-
assert(isPlainObject(state));
79-
if (!('_version' in state)) {
80-
state._version = 1;
81-
}
82-
return zDynamicPromptsState.parse(state);
73+
const denyList = ['prompts', 'parsingError', 'isError', 'isLoading'] as const;
74+
type SerializedDynamicPromptsState = SerializedStateFromDenyList<DynamicPromptsState, typeof denyList>;
75+
76+
export const dynamicPromptsSliceConfig: SliceConfig<typeof slice, DynamicPromptsState, SerializedDynamicPromptsState> =
77+
{
78+
slice,
79+
schema: zDynamicPromptsState,
80+
getInitialState,
81+
persistConfig: {
82+
migrate: (state) => {
83+
assert(isPlainObject(state));
84+
if (!('_version' in state)) {
85+
state._version = 1;
86+
}
87+
return zDynamicPromptsState.parse(state);
88+
},
89+
serialize: (state) => omit(state, denyList),
90+
deserialize: (state) => {
91+
const dynamicPromptsState = state as SerializedDynamicPromptsState;
92+
93+
return merge(dynamicPromptsState, getInitialState());
94+
},
8395
},
84-
persistDenylist: ['prompts', 'parsingError', 'isError', 'isLoading'],
85-
},
86-
};
96+
};
8797

8898
export const selectDynamicPromptsSlice = (state: RootState) => state.dynamicPrompts;
8999
const createDynamicPromptsSelector = <T>(selector: Selector<DynamicPromptsState, T>) =>

invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import type { PayloadAction } from '@reduxjs/toolkit';
22
import { createSlice } from '@reduxjs/toolkit';
33
import type { RootState } from 'app/store/store';
4-
import type { SliceConfig } from 'app/store/types';
4+
import type { SerializedStateFromDenyList, SliceConfig } from 'app/store/types';
55
import { isPlainObject } from 'es-toolkit';
6+
import { merge, omit } from 'es-toolkit/compat';
67
import type { BoardRecordOrderBy } from 'services/api/types';
78
import { assert } from 'tsafe';
89

@@ -176,7 +177,10 @@ export const {
176177

177178
export const selectGallerySlice = (state: RootState) => state.gallery;
178179

179-
export const gallerySliceConfig: SliceConfig<typeof slice> = {
180+
const denyList = ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'] as const;
181+
type SerializedGalleryState = SerializedStateFromDenyList<GalleryState, typeof denyList>;
182+
183+
export const gallerySliceConfig: SliceConfig<typeof slice, GalleryState, SerializedGalleryState> = {
180184
slice,
181185
schema: zGalleryState,
182186
getInitialState,
@@ -188,6 +192,11 @@ export const gallerySliceConfig: SliceConfig<typeof slice> = {
188192
}
189193
return zGalleryState.parse(state);
190194
},
191-
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'imageToCompare'],
195+
serialize: (state) => omit(state, denyList),
196+
deserialize: (state) => {
197+
const galleryState = state as SerializedGalleryState;
198+
199+
return merge(galleryState, getInitialState());
200+
},
192201
},
193202
};

invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import type { PayloadAction } from '@reduxjs/toolkit';
22
import { createSelector, createSlice } from '@reduxjs/toolkit';
33
import type { RootState } from 'app/store/store';
4-
import type { SliceConfig } from 'app/store/types';
4+
import type { SerializedStateFromDenyList, SliceConfig } from 'app/store/types';
55
import { isPlainObject } from 'es-toolkit';
6+
import { merge, omit } from 'es-toolkit/compat';
67
import { zModelType } from 'features/nodes/types/common';
78
import { assert } from 'tsafe';
89
import z from 'zod';
@@ -67,7 +68,10 @@ export const {
6768
shouldInstallInPlaceChanged,
6869
} = slice.actions;
6970

70-
export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
71+
const denyList = ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'] as const;
72+
type SerializedModelManagerState = SerializedStateFromDenyList<ModelManagerState, typeof denyList>;
73+
74+
export const modelManagerSliceConfig: SliceConfig<typeof slice, ModelManagerState, SerializedModelManagerState> = {
7175
slice,
7276
schema: zModelManagerState,
7377
getInitialState,
@@ -79,7 +83,12 @@ export const modelManagerSliceConfig: SliceConfig<typeof slice> = {
7983
}
8084
return zModelManagerState.parse(state);
8185
},
82-
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
86+
serialize: (state) => omit(state, denyList),
87+
deserialize: (state) => {
88+
const modelManagerState = state as SerializedModelManagerState;
89+
90+
return merge(modelManagerState, getInitialState());
91+
},
8392
},
8493
};
8594

invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -819,12 +819,12 @@ export const nodesSliceConfig: SliceConfig<typeof slice, StateWithHistory<NodesS
819819
}
820820
return zNodesState.parse(state);
821821
},
822-
wrapState: (state) => {
822+
serialize: (state) => state.present,
823+
deserialize: (state) => {
823824
const nodesState = state as NodesState;
824825

825826
return newHistory([], nodesState, []);
826827
},
827-
unwrapState: (state) => state.present,
828828
},
829829
};
830830

invokeai/frontend/web/src/features/ui/store/uiSlice.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import type { PayloadAction } from '@reduxjs/toolkit';
22
import { createSlice } from '@reduxjs/toolkit';
33
import type { RootState } from 'app/store/store';
4-
import type { SliceConfig } from 'app/store/types';
4+
import type { SerializedStateFromDenyList, SliceConfig } from 'app/store/types';
55
import { isPlainObject } from 'es-toolkit';
6+
import { merge, omit } from 'es-toolkit/compat';
67
import { assert } from 'tsafe';
78

89
import { getInitialUIState, type UIState, zUIState } from './uiTypes';
@@ -87,7 +88,10 @@ export const {
8788

8889
export const selectUiSlice = (state: RootState) => state.ui;
8990

90-
export const uiSliceConfig: SliceConfig<typeof slice> = {
91+
const denyList = ['shouldShowItemDetails'] as const;
92+
type SerializedUIState = SerializedStateFromDenyList<UIState, typeof denyList>;
93+
94+
export const uiSliceConfig: SliceConfig<typeof slice, UIState, SerializedUIState> = {
9195
slice,
9296
schema: zUIState,
9397
getInitialState: getInitialUIState,
@@ -111,6 +115,11 @@ export const uiSliceConfig: SliceConfig<typeof slice> = {
111115
}
112116
return zUIState.parse(state);
113117
},
114-
persistDenylist: ['shouldShowItemDetails'],
118+
serialize: (state) => omit(state, denyList),
119+
deserialize: (state) => {
120+
const uiState = state as SerializedUIState;
121+
122+
return merge(uiState, getInitialUIState());
123+
},
115124
},
116125
};

0 commit comments

Comments
 (0)