diff --git a/packages/toolkit/src/createDraftSafeSelector.ts b/packages/toolkit/src/createDraftSafeSelector.ts index 9e1bb0a6da..57348786e8 100644 --- a/packages/toolkit/src/createDraftSafeSelector.ts +++ b/packages/toolkit/src/createDraftSafeSelector.ts @@ -1,18 +1,30 @@ -import { current, isDraft } from 'immer' import { createSelectorCreator, defaultMemoize } from 'reselect' +import type { ImmutableHelpers } from './tsHelpers' +import { immutableHelpers } from './immer' -export const createDraftSafeSelectorCreator: typeof createSelectorCreator = ( - ...args: unknown[] -) => { - const createSelector = (createSelectorCreator as any)(...args) - return (...args: unknown[]) => { - const selector = createSelector(...args) - const wrappedSelector = (value: unknown, ...rest: unknown[]) => - selector(isDraft(value) ? current(value) : value, ...rest) - return wrappedSelector as any +export type BuildCreateDraftSafeSelectorConfiguration = Pick< + ImmutableHelpers, + 'isDraft' | 'current' +> + +export function buildCreateDraftSafeSelectorCreator({ + isDraft, + current, +}: BuildCreateDraftSafeSelectorConfiguration): typeof createSelectorCreator { + return function createDraftSafeSelectorCreator(...args: unknown[]) { + const createSelector = (createSelectorCreator as any)(...args) + return function createDraftSafeSelector(...args: unknown[]) { + const selector = (createSelector as any)(...args) + const wrappedSelector = (value: unknown, ...rest: unknown[]) => + selector(isDraft(value) ? current(value) : value, ...rest) + return wrappedSelector as any + } } } +export const createDraftSafeSelectorCreator = + buildCreateDraftSafeSelectorCreator(immutableHelpers) + /** * "Draft-Safe" version of `reselect`'s `createSelector`: * If an `immer`-drafted object is passed into the resulting selector's first argument, diff --git a/packages/toolkit/src/createReducer.ts b/packages/toolkit/src/createReducer.ts index f0f2ac50be..3a1f8b9733 100644 --- a/packages/toolkit/src/createReducer.ts +++ b/packages/toolkit/src/createReducer.ts @@ -1,10 +1,9 @@ import type { Draft } from 'immer' -import { produce as createNextState, isDraft, isDraftable } from 'immer' -import type { Action, Reducer, UnknownAction } from 'redux' +import type { UnknownAction, Action, Reducer } from 'redux' import type { ActionReducerMapBuilder } from './mapBuilders' import { executeReducerBuilderCallback } from './mapBuilders' -import type { NoInfer, TypeGuard } from './tsHelpers' -import { freezeDraftable } from './utils' +import type { ImmutableHelpers, NoInfer, TypeGuard } from './tsHelpers' +import { immutableHelpers } from './immer' /** * Defines a mapping from action types to corresponding action object shapes. @@ -73,7 +72,8 @@ export type ReducerWithInitialState> = Reducer & { getInitialState: () => S } -/** +export type CreateReducer = { + /** * A utility function that allows defining a reducer as a mapping from action * type to *case reducer* functions that handle these action types. The * reducer's initial state is passed as the first argument. @@ -137,85 +137,105 @@ const reducer = createReducer( ``` * @public */ -export function createReducer>( - initialState: S | (() => S), - mapOrBuilderCallback: (builder: ActionReducerMapBuilder) => void -): ReducerWithInitialState { - if (process.env.NODE_ENV !== 'production') { - if (typeof mapOrBuilderCallback === 'object') { - throw new Error( - "The object notation for `createReducer` has been removed. Please use the 'builder callback' notation instead: https://redux-toolkit.js.org/api/createReducer" - ) - } - } + >( + initialState: S | (() => S), + builderCallback: (builder: ActionReducerMapBuilder) => void + ): ReducerWithInitialState +} + +export type BuildCreateReducerConfiguration = Pick< + ImmutableHelpers, + 'createNextState' | 'isDraft' | 'isDraftable' | 'freeze' +> - let [actionsMap, finalActionMatchers, finalDefaultCaseReducer] = - executeReducerBuilderCallback(mapOrBuilderCallback) +export function buildCreateReducer({ + createNextState, + isDraft, + isDraftable, + freeze, +}: BuildCreateReducerConfiguration): CreateReducer { + return function createReducer>( + initialState: S | (() => S), + mapOrBuilderCallback: (builder: ActionReducerMapBuilder) => void + ): ReducerWithInitialState { + if (process.env.NODE_ENV !== 'production') { + if (typeof mapOrBuilderCallback === 'object') { + throw new Error( + "The object notation for `createReducer` has been removed. Please use the 'builder callback' notation instead: https://redux-toolkit.js.org/api/createReducer" + ) + } + } - // Ensure the initial state gets frozen either way (if draftable) - let getInitialState: () => S - if (isStateFunction(initialState)) { - getInitialState = () => freezeDraftable(initialState()) - } else { - const frozenInitialState = freezeDraftable(initialState) - getInitialState = () => frozenInitialState - } + let [actionsMap, finalActionMatchers, finalDefaultCaseReducer] = + executeReducerBuilderCallback(mapOrBuilderCallback) - function reducer(state = getInitialState(), action: any): S { - let caseReducers = [ - actionsMap[action.type], - ...finalActionMatchers - .filter(({ matcher }) => matcher(action)) - .map(({ reducer }) => reducer), - ] - if (caseReducers.filter((cr) => !!cr).length === 0) { - caseReducers = [finalDefaultCaseReducer] + // Ensure the initial state gets frozen either way (if draftable) + let getInitialState: () => S + if (isStateFunction(initialState)) { + getInitialState = () => freeze(initialState(), true) + } else { + const frozenInitialState = freeze(initialState, true) + getInitialState = () => frozenInitialState } - return caseReducers.reduce((previousState, caseReducer): S => { - if (caseReducer) { - if (isDraft(previousState)) { - // If it's already a draft, we must already be inside a `createNextState` call, - // likely because this is being wrapped in `createReducer`, `createSlice`, or nested - // inside an existing draft. It's safe to just pass the draft to the mutator. - const draft = previousState as Draft // We can assume this is already a draft - const result = caseReducer(draft, action) - - if (result === undefined) { - return previousState - } + function reducer(state = getInitialState(), action: any): S { + let caseReducers = [ + actionsMap[action.type], + ...finalActionMatchers + .filter(({ matcher }) => matcher(action)) + .map(({ reducer }) => reducer), + ] + if (caseReducers.filter((cr) => !!cr).length === 0) { + caseReducers = [finalDefaultCaseReducer] + } - return result as S - } else if (!isDraftable(previousState)) { - // If state is not draftable (ex: a primitive, such as 0), we want to directly - // return the caseReducer func and not wrap it with produce. - const result = caseReducer(previousState as any, action) + return caseReducers.reduce((previousState, caseReducer): S => { + if (caseReducer) { + if (isDraft(previousState)) { + // If it's already a draft, we must already be inside a `createNextState` call, + // likely because this is being wrapped in `createReducer`, `createSlice`, or nested + // inside an existing draft. It's safe to just pass the draft to the mutator. + const draft = previousState as Draft // We can assume this is already a draft + const result = caseReducer(draft, action) - if (result === undefined) { - if (previousState === null) { + if (result === undefined) { return previousState } - throw Error( - 'A case reducer on a non-draftable value must not return undefined' - ) - } - return result as S - } else { - // @ts-ignore createNextState() produces an Immutable> rather - // than an Immutable, and TypeScript cannot find out how to reconcile - // these two types. - return createNextState(previousState, (draft: Draft) => { - return caseReducer(draft, action) - }) + return result as S + } else if (!isDraftable(previousState)) { + // If state is not draftable (ex: a primitive, such as 0), we want to directly + // return the caseReducer func and not wrap it with produce. + const result = caseReducer(previousState as any, action) + + if (result === undefined) { + if (previousState === null) { + return previousState + } + throw Error( + 'A case reducer on a non-draftable value must not return undefined' + ) + } + + return result as S + } else { + // @ts-ignore createNextState() produces an Immutable> rather + // than an Immutable, and TypeScript cannot find out how to reconcile + // these two types. + return createNextState(previousState, (draft: Draft) => { + return caseReducer(draft, action) + }) + } } - } - return previousState - }, state) - } + return previousState + }, state) + } - reducer.getInitialState = getInitialState + reducer.getInitialState = getInitialState - return reducer as ReducerWithInitialState + return reducer as ReducerWithInitialState + } } + +export const createReducer = buildCreateReducer(immutableHelpers) diff --git a/packages/toolkit/src/createSlice.ts b/packages/toolkit/src/createSlice.ts index e102e3dfbd..dcdbdc12ca 100644 --- a/packages/toolkit/src/createSlice.ts +++ b/packages/toolkit/src/createSlice.ts @@ -7,10 +7,15 @@ import type { _ActionCreatorWithPreparedPayload, } from './createAction' import { createAction } from './createAction' -import type { CaseReducer, ReducerWithInitialState } from './createReducer' -import { createReducer } from './createReducer' +import type { + BuildCreateReducerConfiguration, + CaseReducer, + ReducerWithInitialState, +} from './createReducer' +import { buildCreateReducer } from './createReducer' import type { ActionReducerMapBuilder } from './mapBuilders' import { executeReducerBuilderCallback } from './mapBuilders' +import { immutableHelpers } from './immer' import type { Id, Tail } from './tsHelpers' import type { InjectConfig } from './combineSlices' import type { @@ -538,188 +543,222 @@ function getType(slice: string, actionKey: string): string { return `${slice}/${actionKey}` } -/** - * A function that accepts an initial state, an object full of reducer - * functions, and a "slice name", and automatically generates - * action creators and action types that correspond to the - * reducers and state. - * - * @public - */ -export function createSlice< - State, - CaseReducers extends SliceCaseReducers, - Name extends string, - Selectors extends SliceSelectors, - ReducerPath extends string = Name ->( - options: CreateSliceOptions -): Slice { - const { name, reducerPath = name as unknown as ReducerPath } = options - if (!name) { - throw new Error('`name` is a required option for createSlice') - } - - if ( - typeof process !== 'undefined' && - process.env.NODE_ENV === 'development' - ) { - if (options.initialState === undefined) { - console.error( - 'You must provide an `initialState` value that is not `undefined`. You may have misspelled `initialState`' - ) - } - } - - const reducers = - (typeof options.reducers === 'function' - ? options.reducers(buildReducerCreators()) - : options.reducers) || {} +export type CreateSlice = { + /** + * A function that accepts an initial state, an object full of reducer + * functions, and a "slice name", and automatically generates + * action creators and action types that correspond to the + * reducers and state. + * + * The `reducer` argument is passed to `createReducer()`. + * + * @public + */ + < + State, + CaseReducers extends SliceCaseReducers, + Name extends string, + Selectors extends SliceSelectors, + ReducerPath extends string = Name + >( + options: CreateSliceOptions< + State, + CaseReducers, + Name, + ReducerPath, + Selectors + > + ): Slice +} - const reducerNames = Object.keys(reducers) +export interface BuildCreateSliceConfiguration + extends BuildCreateReducerConfiguration {} - const context: ReducerHandlingContext = { - sliceCaseReducersByName: {}, - sliceCaseReducersByType: {}, - actionCreators: {}, - } +export function buildCreateSlice( + configuration: BuildCreateSliceConfiguration +): CreateSlice { + const createReducer = buildCreateReducer(configuration) - reducerNames.forEach((reducerName) => { - const reducerDefinition = reducers[reducerName] - const reducerDetails: ReducerDetails = { - reducerName, - type: getType(name, reducerName), - createNotation: typeof options.reducers === 'function', - } - if (isAsyncThunkSliceReducerDefinition(reducerDefinition)) { - handleThunkCaseReducerDefinition( - reducerDetails, - reducerDefinition, - context - ) - } else { - handleNormalReducerDefinition( - reducerDetails, - reducerDefinition, - context - ) + return function createSlice< + State, + CaseReducers extends SliceCaseReducers, + Name extends string, + Selectors extends SliceSelectors, + ReducerPath extends string = Name + >( + options: CreateSliceOptions< + State, + CaseReducers, + Name, + ReducerPath, + Selectors + > + ): Slice { + const { name, reducerPath = name as unknown as ReducerPath } = options + if (!name) { + throw new Error('`name` is a required option for createSlice') } - }) - function buildReducer() { - if (process.env.NODE_ENV !== 'production') { - if (typeof options.extraReducers === 'object') { - throw new Error( - "The object notation for `createSlice.extraReducers` has been removed. Please use the 'builder callback' notation instead: https://redux-toolkit.js.org/api/createSlice" + if ( + typeof process !== 'undefined' && + process.env.NODE_ENV === 'development' + ) { + if (options.initialState === undefined) { + console.error( + 'You must provide an `initialState` value that is not `undefined`. You may have misspelled `initialState`' ) } } - const [ - extraReducers = {}, - actionMatchers = [], - defaultCaseReducer = undefined, - ] = - typeof options.extraReducers === 'function' - ? executeReducerBuilderCallback(options.extraReducers) - : [options.extraReducers] - - const finalCaseReducers = { - ...extraReducers, - ...context.sliceCaseReducersByType, + const reducers = + (typeof options.reducers === 'function' + ? options.reducers(buildReducerCreators()) + : options.reducers) || {} + + const reducerNames = Object.keys(reducers) + + const context: ReducerHandlingContext = { + sliceCaseReducersByName: {}, + sliceCaseReducersByType: {}, + actionCreators: {}, } - return createReducer(options.initialState, (builder) => { - for (let key in finalCaseReducers) { - builder.addCase(key, finalCaseReducers[key] as CaseReducer) - } - for (let m of actionMatchers) { - builder.addMatcher(m.matcher, m.reducer) + reducerNames.forEach((reducerName) => { + const reducerDefinition = reducers[reducerName] + const reducerDetails: ReducerDetails = { + reducerName, + type: getType(name, reducerName), + createNotation: typeof options.reducers === 'function', } - if (defaultCaseReducer) { - builder.addDefaultCase(defaultCaseReducer) + if (isAsyncThunkSliceReducerDefinition(reducerDefinition)) { + handleThunkCaseReducerDefinition( + reducerDetails, + reducerDefinition, + context + ) + } else { + handleNormalReducerDefinition( + reducerDetails, + reducerDefinition, + context + ) } }) - } - - const defaultSelectSlice = ( - rootState: { [K in ReducerPath]: State } - ): State => rootState[reducerPath] - const selectSelf = (state: State) => state - - const injectedSelectorCache = new WeakMap< - Slice, - WeakMap< - (rootState: any) => State | undefined, - Record any> - > - >() + function buildReducer() { + if (process.env.NODE_ENV !== 'production') { + if (typeof options.extraReducers === 'object') { + throw new Error( + "The object notation for `createSlice.extraReducers` has been removed. Please use the 'builder callback' notation instead: https://redux-toolkit.js.org/api/createSlice" + ) + } + } + const [ + extraReducers = {}, + actionMatchers = [], + defaultCaseReducer = undefined, + ] = + typeof options.extraReducers === 'function' + ? executeReducerBuilderCallback(options.extraReducers) + : [options.extraReducers] + + const finalCaseReducers = { + ...extraReducers, + ...context.sliceCaseReducersByType, + } - let _reducer: ReducerWithInitialState + return createReducer(options.initialState, (builder) => { + for (let key in finalCaseReducers) { + builder.addCase(key, finalCaseReducers[key] as CaseReducer) + } + for (let m of actionMatchers) { + builder.addMatcher(m.matcher, m.reducer) + } + if (defaultCaseReducer) { + builder.addDefaultCase(defaultCaseReducer) + } + }) + } - const slice: Slice = { - name, - reducerPath, - reducer(state, action) { - if (!_reducer) _reducer = buildReducer() + const defaultSelectSlice = ( + rootState: { [K in ReducerPath]: State } + ): State => rootState[reducerPath] - return _reducer(state, action) - }, - actions: context.actionCreators as any, - caseReducers: context.sliceCaseReducersByName as any, - getInitialState() { - if (!_reducer) _reducer = buildReducer() + const selectSelf = (state: State) => state - return _reducer.getInitialState() - }, - getSelectors(selectState: (rootState: any) => State = selectSelf) { - let selectorCache = injectedSelectorCache.get(this) - if (!selectorCache) { - selectorCache = new WeakMap() - injectedSelectorCache.set(this, selectorCache) - } - let cached = selectorCache.get(selectState) - if (!cached) { - cached = {} - for (const [name, selector] of Object.entries( - options.selectors ?? {} - )) { - cached[name] = (rootState: any, ...args: any[]) => { - let sliceState = selectState(rootState) - if (typeof sliceState === 'undefined') { - // check if injectInto has been called - if (this !== slice) { - sliceState = this.getInitialState() - } else if (process.env.NODE_ENV !== 'production') { - throw new Error( - 'selectState returned undefined for an uninjected slice reducer' - ) + const injectedSelectorCache = new WeakMap< + Slice, + WeakMap< + (rootState: any) => State | undefined, + Record any> + > + >() + + let _reducer: ReducerWithInitialState + + const slice: Slice = { + name, + reducerPath, + reducer(state, action) { + if (!_reducer) _reducer = buildReducer() + + return _reducer(state, action) + }, + actions: context.actionCreators as any, + caseReducers: context.sliceCaseReducersByName as any, + getInitialState() { + if (!_reducer) _reducer = buildReducer() + + return _reducer.getInitialState() + }, + getSelectors(selectState: (rootState: any) => State = selectSelf) { + let selectorCache = injectedSelectorCache.get(this) + if (!selectorCache) { + selectorCache = new WeakMap() + injectedSelectorCache.set(this, selectorCache) + } + let cached = selectorCache.get(selectState) + if (!cached) { + cached = {} + for (const [name, selector] of Object.entries( + options.selectors ?? {} + )) { + cached[name] = (rootState: any, ...args: any[]) => { + let sliceState = selectState(rootState) + if (typeof sliceState === 'undefined') { + // check if injectInto has been called + if (this !== slice) { + sliceState = this.getInitialState() + } else if (process.env.NODE_ENV !== 'production') { + throw new Error( + 'selectState returned undefined for an uninjected slice reducer' + ) + } } + return selector(sliceState, ...args) } - return selector(sliceState, ...args) } + selectorCache.set(selectState, cached) } - selectorCache.set(selectState, cached) - } - return cached as any - }, - get selectors() { - return this.getSelectors(defaultSelectSlice) - }, - injectInto(injectable, { reducerPath: pathOpt, ...config } = {}) { - const reducerPath = pathOpt ?? this.reducerPath - injectable.inject({ reducerPath, reducer: this.reducer }, config) - const selectSlice = (state: any) => state[reducerPath] - return { - ...this, - reducerPath, - get selectors() { - return this.getSelectors(selectSlice) - }, - } as any - }, + return cached as any + }, + get selectors() { + return this.getSelectors(defaultSelectSlice) + }, + injectInto(injectable, { reducerPath: pathOpt, ...config } = {}) { + const reducerPath = pathOpt ?? this.reducerPath + injectable.inject({ reducerPath, reducer: this.reducer }, config) + const selectSlice = (state: any) => state[reducerPath] + return { + ...this, + reducerPath, + get selectors() { + return this.getSelectors(selectSlice) + }, + } as any + }, + } + return slice } - return slice } interface ReducerHandlingContext { @@ -851,3 +890,5 @@ function handleThunkCaseReducerDefinition( } function noop() {} + +export const createSlice = buildCreateSlice(immutableHelpers) diff --git a/packages/toolkit/src/entities/create_adapter.ts b/packages/toolkit/src/entities/create_adapter.ts index 83ad6c6515..c94e3e1c50 100644 --- a/packages/toolkit/src/entities/create_adapter.ts +++ b/packages/toolkit/src/entities/create_adapter.ts @@ -6,53 +6,59 @@ import type { EntityId, } from './models' import { createInitialStateFactory } from './entity_state' -import { createSelectorsFactory } from './state_selectors' -import { createSortedStateAdapter } from './sorted_state_adapter' -import { createUnsortedStateAdapter } from './unsorted_state_adapter' +import { buildCreateSelectorsFactory } from './state_selectors' +import { buildCreateSortedStateAdapter } from './sorted_state_adapter' +import { buildCreateUnsortedStateAdapter } from './unsorted_state_adapter' +import type { BuildCreateDraftSafeSelectorConfiguration } from '..' +import type { BuildStateOperatorConfiguration } from './state_adapter' +import { immutableHelpers } from '../immer' -export interface EntityAdapterOptions { - selectId?: IdSelector - sortComparer?: false | Comparer -} - -export function createEntityAdapter(options: { - selectId: IdSelector - sortComparer?: false | Comparer -}): EntityAdapter +export interface BuildCreateEntityAdapterConfiguration + extends BuildCreateDraftSafeSelectorConfiguration, + BuildStateOperatorConfiguration {} -export function createEntityAdapter(options?: { - sortComparer?: false | Comparer -}): EntityAdapter - -/** - * - * @param options - * - * @public - */ -export function createEntityAdapter( - options: { - selectId?: IdSelector +export type CreateEntityAdapter = { + (options?: { + selectId: IdSelector sortComparer?: false | Comparer - } = {} -): EntityAdapter { - const { selectId, sortComparer }: EntityDefinition = { - sortComparer: false, - selectId: (instance: any) => instance.id, - ...options, - } + }): EntityAdapter + (options?: { + sortComparer?: false | Comparer + }): EntityAdapter +} - const stateFactory = createInitialStateFactory() - const selectorsFactory = createSelectorsFactory() - const stateAdapter = sortComparer - ? createSortedStateAdapter(selectId, sortComparer) - : createUnsortedStateAdapter(selectId) +export function buildCreateEntityAdapter( + config: BuildCreateEntityAdapterConfiguration +): CreateEntityAdapter { + const createSelectorsFactory = buildCreateSelectorsFactory(config) + const createUnsortedStateAdapter = buildCreateUnsortedStateAdapter(config) + const createSortedStateAdapter = buildCreateSortedStateAdapter(config) + return function createEntityAdapter( + options: { + selectId?: IdSelector + sortComparer?: false | Comparer + } = {} + ): EntityAdapter { + const { selectId, sortComparer }: EntityDefinition = { + sortComparer: false, + selectId: (instance: any) => instance.id, + ...options, + } - return { - selectId, - sortComparer, - ...stateFactory, - ...selectorsFactory, - ...stateAdapter, + const stateFactory = createInitialStateFactory() + const selectorsFactory = createSelectorsFactory() + const stateAdapter = sortComparer + ? createSortedStateAdapter(selectId, sortComparer) + : createUnsortedStateAdapter(selectId) + + return { + selectId, + sortComparer, + ...stateFactory, + ...selectorsFactory, + ...stateAdapter, + } } } + +export const createEntityAdapter = buildCreateEntityAdapter(immutableHelpers) diff --git a/packages/toolkit/src/entities/sorted_state_adapter.ts b/packages/toolkit/src/entities/sorted_state_adapter.ts index 91645d1af0..0e3b758986 100644 --- a/packages/toolkit/src/entities/sorted_state_adapter.ts +++ b/packages/toolkit/src/entities/sorted_state_adapter.ts @@ -4,165 +4,174 @@ import type { EntityStateAdapter, Update, EntityId, - DraftableEntityState + DraftableEntityState, } from './models' -import { createStateOperator } from './state_adapter' -import { createUnsortedStateAdapter } from './unsorted_state_adapter' +import type { BuildStateOperatorConfiguration } from './state_adapter' +import { buildCreateStateOperator } from './state_adapter' +import { buildCreateUnsortedStateAdapter } from './unsorted_state_adapter' import { selectIdValue, ensureEntitiesArray, splitAddedUpdatedEntities, } from './utils' -export function createSortedStateAdapter( - selectId: IdSelector, - sort: Comparer -): EntityStateAdapter { - type R = DraftableEntityState - - const { removeOne, removeMany, removeAll } = - createUnsortedStateAdapter(selectId) - - function addOneMutably(entity: T, state: R): void { - return addManyMutably([entity], state) - } +export function buildCreateSortedStateAdapter( + config: BuildStateOperatorConfiguration +) { + const createStateOperator = buildCreateStateOperator(config) + const createUnsortedStateAdapter = buildCreateUnsortedStateAdapter(config) + return function createSortedStateAdapter( + selectId: IdSelector, + sort: Comparer + ): EntityStateAdapter { + type R = DraftableEntityState + + const { removeOne, removeMany, removeAll } = + createUnsortedStateAdapter(selectId) + + function addOneMutably(entity: T, state: R): void { + return addManyMutably([entity], state) + } - function addManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) + function addManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) - const models = newEntities.filter( - (model) => !(selectIdValue(model, selectId) in state.entities) - ) + const models = newEntities.filter( + (model) => !(selectIdValue(model, selectId) in state.entities) + ) - if (models.length !== 0) { - merge(models, state) + if (models.length !== 0) { + merge(models, state) + } } - } - - function setOneMutably(entity: T, state: R): void { - return setManyMutably([entity], state) - } - function setManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) - if (newEntities.length !== 0) { - merge(newEntities, state) + function setOneMutably(entity: T, state: R): void { + return setManyMutably([entity], state) } - } - function setAllMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) - state.entities = {} as Record - state.ids = [] + function setManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) + if (newEntities.length !== 0) { + merge(newEntities, state) + } + } - addManyMutably(newEntities, state) - } + function setAllMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) + state.entities = {} as Record + state.ids = [] - function updateOneMutably(update: Update, state: R): void { - return updateManyMutably([update], state) - } + addManyMutably(newEntities, state) + } - function updateManyMutably( - updates: ReadonlyArray>, - state: R - ): void { - let appliedUpdates = false + function updateOneMutably(update: Update, state: R): void { + return updateManyMutably([update], state) + } - for (let update of updates) { - const entity: T | undefined = (state.entities as Record)[update.id] - if (!entity) { - continue + function updateManyMutably( + updates: ReadonlyArray>, + state: R + ): void { + let appliedUpdates = false + + for (let update of updates) { + const entity: T | undefined = (state.entities as Record)[ + update.id + ] + if (!entity) { + continue + } + + appliedUpdates = true + + Object.assign(entity, update.changes) + const newId = selectId(entity) + if (update.id !== newId) { + delete (state.entities as Record)[update.id] + ;(state.entities as Record)[newId] = entity + } } - appliedUpdates = true - - Object.assign(entity, update.changes) - const newId = selectId(entity) - if (update.id !== newId) { - delete (state.entities as Record)[update.id]; - (state.entities as Record)[newId] = entity + if (appliedUpdates) { + resortEntities(state) } } - if (appliedUpdates) { - resortEntities(state) + function upsertOneMutably(entity: T, state: R): void { + return upsertManyMutably([entity], state) } - } - - function upsertOneMutably(entity: T, state: R): void { - return upsertManyMutably([entity], state) - } - function upsertManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - const [added, updated] = splitAddedUpdatedEntities( - newEntities, - selectId, - state - ) - - updateManyMutably(updated, state) - addManyMutably(added, state) - } - - function areArraysEqual(a: readonly unknown[], b: readonly unknown[]) { - if (a.length !== b.length) { - return false + function upsertManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + const [added, updated] = splitAddedUpdatedEntities( + newEntities, + selectId, + state + ) + + updateManyMutably(updated, state) + addManyMutably(added, state) } - for (let i = 0; i < a.length && i < b.length; i++) { - if (a[i] === b[i]) { - continue + function areArraysEqual(a: readonly unknown[], b: readonly unknown[]) { + if (a.length !== b.length) { + return false + } + + for (let i = 0; i < a.length && i < b.length; i++) { + if (a[i] === b[i]) { + continue + } + return false } - return false + return true } - return true - } - function merge(models: readonly T[], state: R): void { - // Insert/overwrite all new/updated - models.forEach((model) => { - (state.entities as Record)[selectId(model)] = model - }) + function merge(models: readonly T[], state: R): void { + // Insert/overwrite all new/updated + models.forEach((model) => { + ;(state.entities as Record)[selectId(model)] = model + }) - resortEntities(state) - } + resortEntities(state) + } - function resortEntities(state: R) { - const allEntities = Object.values(state.entities) as T[] - allEntities.sort(sort) + function resortEntities(state: R) { + const allEntities = Object.values(state.entities) as T[] + allEntities.sort(sort) - const newSortedIds = allEntities.map(selectId) - const { ids } = state + const newSortedIds = allEntities.map(selectId) + const { ids } = state - if (!areArraysEqual(ids, newSortedIds)) { - state.ids = newSortedIds + if (!areArraysEqual(ids, newSortedIds)) { + state.ids = newSortedIds + } } - } - return { - removeOne, - removeMany, - removeAll, - addOne: createStateOperator(addOneMutably), - updateOne: createStateOperator(updateOneMutably), - upsertOne: createStateOperator(upsertOneMutably), - setOne: createStateOperator(setOneMutably), - setMany: createStateOperator(setManyMutably), - setAll: createStateOperator(setAllMutably), - addMany: createStateOperator(addManyMutably), - updateMany: createStateOperator(updateManyMutably), - upsertMany: createStateOperator(upsertManyMutably), + return { + removeOne, + removeMany, + removeAll, + addOne: createStateOperator(addOneMutably), + updateOne: createStateOperator(updateOneMutably), + upsertOne: createStateOperator(upsertOneMutably), + setOne: createStateOperator(setOneMutably), + setMany: createStateOperator(setManyMutably), + setAll: createStateOperator(setAllMutably), + addMany: createStateOperator(addManyMutably), + updateMany: createStateOperator(updateManyMutably), + upsertMany: createStateOperator(upsertManyMutably), + } } } diff --git a/packages/toolkit/src/entities/state_adapter.ts b/packages/toolkit/src/entities/state_adapter.ts index 93738d21c2..99ce3965ba 100644 --- a/packages/toolkit/src/entities/state_adapter.ts +++ b/packages/toolkit/src/entities/state_adapter.ts @@ -1,56 +1,70 @@ -import { produce as createNextState, isDraft } from 'immer' -import type { Draft } from 'immer' +import type { Draft } from 'immer' import type { EntityId, DraftableEntityState, PreventAny } from './models' import type { PayloadAction } from '../createAction' import { isFSA } from '../createAction' +import type { ImmutableHelpers } from '../tsHelpers' -export const isDraftTyped = isDraft as (value: T | Draft) => value is Draft +export type BuildStateOperatorConfiguration = Pick< + ImmutableHelpers, + 'isDraft' | 'createNextState' +> -export function createSingleArgumentStateOperator( - mutator: (state: DraftableEntityState) => void +export function buildCreateSingleArgumentStateOperator( + config: BuildStateOperatorConfiguration ) { - const operator = createStateOperator( - (_: undefined, state: DraftableEntityState) => mutator(state) - ) + const createStateOperator = buildCreateStateOperator(config) + return function createSingleArgumentStateOperator( + mutator: (state: DraftableEntityState) => void + ) { + const operator = createStateOperator( + (_: undefined, state: DraftableEntityState) => mutator(state) + ) - return function operation>( - state: PreventAny - ): S { - return operator(state as S, undefined) + return function operation>( + state: PreventAny + ): S { + return operator(state as S, undefined) + } } } -export function createStateOperator( - mutator: (arg: R, state: DraftableEntityState) => void -) { - return function operation>( - state: S, - arg: R | PayloadAction - ): S { - function isPayloadActionArgument( +export function buildCreateStateOperator({ + createNextState, + isDraft, +}: BuildStateOperatorConfiguration) { + const isDraftTyped = isDraft as (value: T | Draft) => value is Draft + return function createStateOperator( + mutator: (arg: R, state: DraftableEntityState) => void + ) { + return function operation>( + state: S, arg: R | PayloadAction - ): arg is PayloadAction { - return isFSA(arg) - } + ): S { + function isPayloadActionArgument( + arg: R | PayloadAction + ): arg is PayloadAction { + return isFSA(arg) + } - const runMutator = (draft: DraftableEntityState) => { - if (isPayloadActionArgument(arg)) { - mutator(arg.payload, draft) - } else { - mutator(arg, draft) + const runMutator = (draft: DraftableEntityState) => { + if (isPayloadActionArgument(arg)) { + mutator(arg.payload, draft) + } else { + mutator(arg, draft) + } } - } - if (isDraftTyped>(state)) { - // we must already be inside a `createNextState` call, likely because - // this is being wrapped in `createReducer` or `createSlice`. - // It's safe to just pass the draft to the mutator. - runMutator(state) + if (isDraftTyped(state)) { + // we must already be inside a `createNextState` call, likely because + // this is being wrapped in `createReducer` or `createSlice`. + // It's safe to just pass the draft to the mutator. + runMutator(state) - // since it's a draft, we'll just return it - return state + // since it's a draft, we'll just return it + return state + } + // @ts-ignore Type 'Draft' is not assignable to type 'DraftableEntityState'. + return createNextState(state, runMutator) } - - return createNextState(state, runMutator) } } diff --git a/packages/toolkit/src/entities/state_selectors.ts b/packages/toolkit/src/entities/state_selectors.ts index a3743d148a..d96e14d3cb 100644 --- a/packages/toolkit/src/entities/state_selectors.ts +++ b/packages/toolkit/src/entities/state_selectors.ts @@ -1,5 +1,7 @@ import type { CreateSelectorFunction, Selector } from 'reselect' -import { createDraftSafeSelector } from '../createDraftSafeSelector' +import { defaultMemoize } from 'reselect' +import type { BuildCreateDraftSafeSelectorConfiguration } from '../createDraftSafeSelector' +import { buildCreateDraftSafeSelectorCreator } from '../createDraftSafeSelector' import type { EntityState, EntitySelectors, EntityId } from './models' export type AnyCreateSelectorFunction = CreateSelectorFunction< @@ -10,64 +12,67 @@ export type AnyCreateSelectorFunction = CreateSelectorFunction< export interface GetSelectorsOptions { createSelector?: AnyCreateSelectorFunction } +export function buildCreateSelectorsFactory( + config: BuildCreateDraftSafeSelectorConfiguration +) { + const createDraftSafeSelector = + buildCreateDraftSafeSelectorCreator(config)(defaultMemoize) + return function createSelectorsFactory() { + function getSelectors( + selectState?: undefined, + options?: GetSelectorsOptions + ): EntitySelectors, Id> + function getSelectors( + selectState: (state: V) => EntityState, + options?: GetSelectorsOptions + ): EntitySelectors + function getSelectors( + selectState?: (state: V) => EntityState, + options: GetSelectorsOptions = {} + ): EntitySelectors { + const { createSelector = createDraftSafeSelector } = options + const selectIds = (state: EntityState) => state.ids + const selectEntities = (state: EntityState) => state.entities -export function createSelectorsFactory() { - function getSelectors( - selectState?: undefined, - options?: GetSelectorsOptions - ): EntitySelectors, Id> - function getSelectors( - selectState: (state: V) => EntityState, - options?: GetSelectorsOptions - ): EntitySelectors - function getSelectors( - selectState?: (state: V) => EntityState, - options: GetSelectorsOptions = {} - ): EntitySelectors { - const { createSelector = createDraftSafeSelector } = options - const selectIds = (state: EntityState) => state.ids + const selectAll = createSelector( + selectIds, + selectEntities, + (ids, entities): T[] => ids.map((id) => entities[id]!) + ) - const selectEntities = (state: EntityState) => state.entities + const selectId = (_: unknown, id: Id) => id - const selectAll = createSelector( - selectIds, - selectEntities, - (ids, entities): T[] => ids.map((id) => entities[id]!) - ) + const selectById = (entities: Record, id: Id) => entities[id] - const selectId = (_: unknown, id: Id) => id + const selectTotal = createSelector(selectIds, (ids) => ids.length) - const selectById = (entities: Record, id: Id) => entities[id] + if (!selectState) { + return { + selectIds, + selectEntities, + selectAll, + selectTotal, + selectById: createSelector(selectEntities, selectId, selectById), + } + } - const selectTotal = createSelector(selectIds, (ids) => ids.length) + const selectGlobalizedEntities = createSelector( + selectState as Selector>, + selectEntities + ) - if (!selectState) { return { - selectIds, - selectEntities, - selectAll, - selectTotal, - selectById: createSelector(selectEntities, selectId, selectById), + selectIds: createSelector(selectState, selectIds), + selectEntities: selectGlobalizedEntities, + selectAll: createSelector(selectState, selectAll), + selectTotal: createSelector(selectState, selectTotal), + selectById: createSelector( + selectGlobalizedEntities, + selectId, + selectById + ), } } - - const selectGlobalizedEntities = createSelector( - selectState as Selector>, - selectEntities - ) - - return { - selectIds: createSelector(selectState, selectIds), - selectEntities: selectGlobalizedEntities, - selectAll: createSelector(selectState, selectAll), - selectTotal: createSelector(selectState, selectTotal), - selectById: createSelector( - selectGlobalizedEntities, - selectId, - selectById - ), - } + return { getSelectors } } - - return { getSelectors } } diff --git a/packages/toolkit/src/entities/unsorted_state_adapter.ts b/packages/toolkit/src/entities/unsorted_state_adapter.ts index 1b74a01479..096cf611c3 100644 --- a/packages/toolkit/src/entities/unsorted_state_adapter.ts +++ b/packages/toolkit/src/entities/unsorted_state_adapter.ts @@ -4,11 +4,12 @@ import type { IdSelector, Update, EntityId, - DraftableEntityState + DraftableEntityState, } from './models' +import type { BuildStateOperatorConfiguration } from './state_adapter' import { - createStateOperator, - createSingleArgumentStateOperator, + buildCreateStateOperator, + buildCreateSingleArgumentStateOperator, } from './state_adapter' import { selectIdValue, @@ -16,189 +17,200 @@ import { splitAddedUpdatedEntities, } from './utils' -export function createUnsortedStateAdapter( - selectId: IdSelector -): EntityStateAdapter { - type R = DraftableEntityState +export function buildCreateUnsortedStateAdapter( + config: BuildStateOperatorConfiguration +) { + const createSingleArgumentStateOperator = + buildCreateSingleArgumentStateOperator(config) + const createStateOperator = buildCreateStateOperator(config) + return function createUnsortedStateAdapter( + selectId: IdSelector + ): EntityStateAdapter { + type R = DraftableEntityState - function addOneMutably(entity: T, state: R): void { - const key = selectIdValue(entity, selectId) + function addOneMutably(entity: T, state: R): void { + const key = selectIdValue(entity, selectId) - if (key in state.entities) { - return - } + if (key in state.entities) { + return + } - state.ids.push(key as Id & Draft); - (state.entities as Record)[key] = entity - } + state.ids.push(key as Id & Draft) + ;(state.entities as Record)[key] = entity + } - function addManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) + function addManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) - for (const entity of newEntities) { - addOneMutably(entity, state) + for (const entity of newEntities) { + addOneMutably(entity, state) + } } - } - function setOneMutably(entity: T, state: R): void { - const key = selectIdValue(entity, selectId) - if (!(key in state.entities)) { - state.ids.push(key as Id & Draft); + function setOneMutably(entity: T, state: R): void { + const key = selectIdValue(entity, selectId) + if (!(key in state.entities)) { + state.ids.push(key as Id & Draft) + } + ;(state.entities as Record)[key] = entity } - (state.entities as Record)[key] = entity - } - function setManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) - for (const entity of newEntities) { - setOneMutably(entity, state) + function setManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) + for (const entity of newEntities) { + setOneMutably(entity, state) + } } - } - function setAllMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) + function setAllMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) - state.ids = [] - state.entities = {} as Record + state.ids = [] + state.entities = {} as Record - addManyMutably(newEntities, state) - } + addManyMutably(newEntities, state) + } - function removeOneMutably(key: Id, state: R): void { - return removeManyMutably([key], state) - } + function removeOneMutably(key: Id, state: R): void { + return removeManyMutably([key], state) + } - function removeManyMutably(keys: readonly Id[], state: R): void { - let didMutate = false + function removeManyMutably(keys: readonly Id[], state: R): void { + let didMutate = false - keys.forEach((key) => { - if (key in state.entities) { - delete (state.entities as Record)[key] - didMutate = true - } - }) + keys.forEach((key) => { + if (key in state.entities) { + delete (state.entities as Record)[key] + didMutate = true + } + }) - if (didMutate) { - state.ids = (state.ids as Id[]).filter((id) => id in state.entities) as Id[] | Draft + if (didMutate) { + state.ids = (state.ids as Id[]).filter((id) => id in state.entities) as + | Id[] + | Draft + } } - } - - function removeAllMutably(state: R): void { - Object.assign(state, { - ids: [], - entities: {}, - }) - } - function takeNewKey( - keys: { [id: string]: Id }, - update: Update, - state: R - ): boolean { - const original: T | undefined = (state.entities as Record)[update.id] - if (original === undefined) { - return false + function removeAllMutably(state: R): void { + Object.assign(state, { + ids: [], + entities: {}, + }) } - const updated: T = Object.assign({}, original, update.changes) - const newKey = selectIdValue(updated, selectId) - const hasNewKey = newKey !== update.id - if (hasNewKey) { - keys[update.id] = newKey - delete (state.entities as Record)[update.id] - } + function takeNewKey( + keys: { [id: string]: Id }, + update: Update, + state: R + ): boolean { + const original: T | undefined = (state.entities as Record)[ + update.id + ] + if (original === undefined) { + return false + } + const updated: T = Object.assign({}, original, update.changes) + const newKey = selectIdValue(updated, selectId) + const hasNewKey = newKey !== update.id - (state.entities as Record)[newKey] = updated + if (hasNewKey) { + keys[update.id] = newKey + delete (state.entities as Record)[update.id] + } - return hasNewKey - } + ;(state.entities as Record)[newKey] = updated - function updateOneMutably(update: Update, state: R): void { - return updateManyMutably([update], state) - } + return hasNewKey + } - function updateManyMutably( - updates: ReadonlyArray>, - state: R - ): void { - const newKeys: { [id: string]: Id } = {} - - const updatesPerEntity: { [id: string]: Update } = {} - - updates.forEach((update) => { - // Only apply updates to entities that currently exist - if (update.id in state.entities) { - // If there are multiple updates to one entity, merge them together - updatesPerEntity[update.id] = { - id: update.id, - // Spreads ignore falsy values, so this works even if there isn't - // an existing update already at this key - changes: { - ...(updatesPerEntity[update.id] - ? updatesPerEntity[update.id].changes - : null), - ...update.changes, - }, + function updateOneMutably(update: Update, state: R): void { + return updateManyMutably([update], state) + } + + function updateManyMutably( + updates: ReadonlyArray>, + state: R + ): void { + const newKeys: { [id: string]: Id } = {} + + const updatesPerEntity: { [id: string]: Update } = {} + + updates.forEach((update) => { + // Only apply updates to entities that currently exist + if (update.id in state.entities) { + // If there are multiple updates to one entity, merge them together + updatesPerEntity[update.id] = { + id: update.id, + // Spreads ignore falsy values, so this works even if there isn't + // an existing update already at this key + changes: { + ...(updatesPerEntity[update.id] + ? updatesPerEntity[update.id].changes + : null), + ...update.changes, + }, + } } - } - }) + }) - updates = Object.values(updatesPerEntity) + updates = Object.values(updatesPerEntity) - const didMutateEntities = updates.length > 0 + const didMutateEntities = updates.length > 0 - if (didMutateEntities) { - const didMutateIds = - updates.filter((update) => takeNewKey(newKeys, update, state)).length > - 0 + if (didMutateEntities) { + const didMutateIds = + updates.filter((update) => takeNewKey(newKeys, update, state)) + .length > 0 - if (didMutateIds) { - state.ids = Object.values(state.entities).map((e) => - selectIdValue(e as T, selectId) - ) + if (didMutateIds) { + state.ids = Object.values(state.entities).map((e) => + selectIdValue(e as T, selectId) + ) + } } } - } - function upsertOneMutably(entity: T, state: R): void { - return upsertManyMutably([entity], state) - } + function upsertOneMutably(entity: T, state: R): void { + return upsertManyMutably([entity], state) + } - function upsertManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - const [added, updated] = splitAddedUpdatedEntities( - newEntities, - selectId, - state - ) - - updateManyMutably(updated, state) - addManyMutably(added, state) - } + function upsertManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + const [added, updated] = splitAddedUpdatedEntities( + newEntities, + selectId, + state + ) + + updateManyMutably(updated, state) + addManyMutably(added, state) + } - return { - removeAll: createSingleArgumentStateOperator(removeAllMutably), - addOne: createStateOperator(addOneMutably), - addMany: createStateOperator(addManyMutably), - setOne: createStateOperator(setOneMutably), - setMany: createStateOperator(setManyMutably), - setAll: createStateOperator(setAllMutably), - updateOne: createStateOperator(updateOneMutably), - updateMany: createStateOperator(updateManyMutably), - upsertOne: createStateOperator(upsertOneMutably), - upsertMany: createStateOperator(upsertManyMutably), - removeOne: createStateOperator(removeOneMutably), - removeMany: createStateOperator(removeManyMutably), + return { + removeAll: createSingleArgumentStateOperator(removeAllMutably), + addOne: createStateOperator(addOneMutably), + addMany: createStateOperator(addManyMutably), + setOne: createStateOperator(setOneMutably), + setMany: createStateOperator(setManyMutably), + setAll: createStateOperator(setAllMutably), + updateOne: createStateOperator(updateOneMutably), + updateMany: createStateOperator(updateManyMutably), + upsertOne: createStateOperator(upsertOneMutably), + upsertMany: createStateOperator(upsertManyMutably), + removeOne: createStateOperator(removeOneMutably), + removeMany: createStateOperator(removeManyMutably), + } } } diff --git a/packages/toolkit/src/immer.ts b/packages/toolkit/src/immer.ts new file mode 100644 index 0000000000..81a87da99a --- /dev/null +++ b/packages/toolkit/src/immer.ts @@ -0,0 +1,22 @@ +import { + applyPatches, + current, + freeze, + isDraft, + isDraftable, + original, + produce, + produceWithPatches, +} from 'immer' +import { defineImmutableHelpers } from './tsHelpers' + +export const immutableHelpers = defineImmutableHelpers({ + createNextState: produce, + createWithPatches: produceWithPatches, + applyPatches, + isDraft, + isDraftable, + original, + current, + freeze, +}) diff --git a/packages/toolkit/src/index.ts b/packages/toolkit/src/index.ts index 8ee73148de..58525383e2 100644 --- a/packages/toolkit/src/index.ts +++ b/packages/toolkit/src/index.ts @@ -24,9 +24,11 @@ export type { OutputSelector, ParametricSelector, } from 'reselect' +export type { BuildCreateDraftSafeSelectorConfiguration } from './createDraftSafeSelector' export { - createDraftSafeSelector, + buildCreateDraftSafeSelectorCreator, createDraftSafeSelectorCreator, + createDraftSafeSelector, } from './createDraftSafeSelector' export type { ThunkAction, ThunkDispatch, ThunkMiddleware } from 'redux-thunk' @@ -61,21 +63,27 @@ export type { export { // js createReducer, + buildCreateReducer, } from './createReducer' export type { // types Actions, CaseReducer, CaseReducers, + CreateReducer, + BuildCreateReducerConfiguration, } from './createReducer' export { // js createSlice, + buildCreateSlice, ReducerType, } from './createSlice' export type { // types + BuildCreateSliceConfiguration, + CreateSlice, CreateSliceOptions, Slice, CaseReducerActions, @@ -112,7 +120,10 @@ export type { } from './mapBuilders' export { Tuple } from './utils' -export { createEntityAdapter } from './entities/create_adapter' +export { + buildCreateEntityAdapter, + createEntityAdapter, +} from './entities/create_adapter' export type { EntityState, EntityAdapter, @@ -206,8 +217,11 @@ export { } from './autoBatchEnhancer' export type { AutoBatchOptions } from './autoBatchEnhancer' -export { combineSlices } from './combineSlices' +export type { ImmutableHelpers } from './tsHelpers' +export { defineImmutableHelpers } from './tsHelpers' +export { immutableHelpers as immerImmutableHelpers } from './immer' +export { combineSlices } from './combineSlices' export type { WithSlice } from './combineSlices' export type { ExtractDispatchExtensions as TSHelpersExtractDispatchExtensions } from './tsHelpers' diff --git a/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts b/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts index 0bd2a73515..29477dbcfb 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts @@ -1,11 +1,15 @@ import type { InternalHandlerBuilder } from './types' import type { SubscriptionState } from '../apiState' -import { produceWithPatches } from 'immer' import type { Action } from '@reduxjs/toolkit' export const buildBatchedActionsHandler: InternalHandlerBuilder< [actionShouldContinue: boolean, subscriptionExists: boolean] -> = ({ api, queryThunk, internalState }) => { +> = ({ + api, + queryThunk, + internalState, + immutableHelpers: { createWithPatches }, +}) => { const subscriptionsPrefix = `${api.reducerPath}/subscriptions` let previousSubscriptions: SubscriptionState = @@ -108,7 +112,7 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< JSON.stringify(internalState.currentSubscriptions) ) // Figure out a smaller diff between original and current - const [, patches] = produceWithPatches( + const [, patches] = createWithPatches( previousSubscriptions, () => newSubscriptions ) diff --git a/packages/toolkit/src/query/core/buildMiddleware/index.ts b/packages/toolkit/src/query/core/buildMiddleware/index.ts index aa58617d6f..3ac6b8822b 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/index.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/index.ts @@ -17,6 +17,7 @@ import { buildInvalidationByTagsHandler } from './invalidationByTags' import { buildPollingHandler } from './polling' import type { BuildMiddlewareInput, + BuildSubMiddlewareInput, InternalHandlerBuilder, InternalMiddlewareState, } from './types' @@ -63,7 +64,7 @@ export function buildMiddleware< currentSubscriptions: {}, } - const builderArgs = { + const builderArgs: BuildSubMiddlewareInput = { ...(input as any as BuildMiddlewareInput< EndpointDefinitions, string, diff --git a/packages/toolkit/src/query/core/buildMiddleware/types.ts b/packages/toolkit/src/query/core/buildMiddleware/types.ts index c7e4e52e02..e6488659a3 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/types.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/types.ts @@ -4,6 +4,7 @@ import type { Middleware, MiddlewareAPI, ThunkDispatch, + ImmutableHelpers, UnknownAction, } from '@reduxjs/toolkit' @@ -43,6 +44,7 @@ export interface BuildMiddlewareInput< mutationThunk: MutationThunk api: Api assertTagType: AssertTagTypes + immutableHelpers: Pick } export type SubMiddlewareApi = MiddlewareAPI< diff --git a/packages/toolkit/src/query/core/buildSelectors.ts b/packages/toolkit/src/query/core/buildSelectors.ts index efd48b8170..6ef41356e5 100644 --- a/packages/toolkit/src/query/core/buildSelectors.ts +++ b/packages/toolkit/src/query/core/buildSelectors.ts @@ -1,4 +1,5 @@ -import { createNextState, createSelector } from './rtkImports' +import type { ImmutableHelpers } from '@reduxjs/toolkit' +import { createSelector } from './rtkImports' import type { MutationSubState, QuerySubState, @@ -103,30 +104,28 @@ export type MutationResultSelectorResult< Definition extends MutationDefinition > = MutationSubState & RequestStatusFlags -const initialSubState: QuerySubState = { - status: QueryStatus.uninitialized as const, -} - -// abuse immer to freeze default states -const defaultQuerySubState = /* @__PURE__ */ createNextState( - initialSubState, - () => {} -) -const defaultMutationSubState = /* @__PURE__ */ createNextState( - initialSubState as MutationSubState, - () => {} -) - export function buildSelectors< Definitions extends EndpointDefinitions, ReducerPath extends string >({ serializeQueryArgs, reducerPath, + immutableHelpers: { freeze }, }: { serializeQueryArgs: InternalSerializeQueryArgs reducerPath: ReducerPath + immutableHelpers: Pick }) { + const initialSubState: QuerySubState = { + status: QueryStatus.uninitialized as const, + } + + const defaultQuerySubState = freeze(initialSubState, true) + const defaultMutationSubState = freeze( + initialSubState as MutationSubState, + true + ) + type RootState = _RootState const selectSkippedQuery = (state: RootState) => defaultQuerySubState diff --git a/packages/toolkit/src/query/core/buildSlice.ts b/packages/toolkit/src/query/core/buildSlice.ts index dc10799940..047fe59fcc 100644 --- a/packages/toolkit/src/query/core/buildSlice.ts +++ b/packages/toolkit/src/query/core/buildSlice.ts @@ -1,13 +1,16 @@ -import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit' +import type { + PayloadAction, + BuildCreateSliceConfiguration, + ImmutableHelpers, +} from '@reduxjs/toolkit' import { combineReducers, createAction, - createSlice, isAnyOf, isFulfilled, isRejectedWithValue, - createNextState, prepareAutoBatched, + buildCreateSlice, } from './rtkImports' import type { QuerySubstateIdentifier, @@ -32,8 +35,6 @@ import type { QueryDefinition, } from '../endpointDefinitions' import type { Patch } from 'immer' -import { isDraft } from 'immer' -import { applyPatches, original } from 'immer' import { onFocus, onFocusLost, onOffline, onOnline } from './setupListeners' import { isDocumentVisible, @@ -100,6 +101,8 @@ export function buildSlice({ }, assertTagType, config, + immutableHelpers, + immutableHelpers: { applyPatches, createNextState, isDraft, original }, }: { reducerPath: string queryThunk: QueryThunk @@ -110,7 +113,11 @@ export function buildSlice({ ConfigState, 'online' | 'focused' | 'middlewareRegistered' > + immutableHelpers: Pick & + BuildCreateSliceConfiguration }) { + const createSlice = buildCreateSlice(immutableHelpers) + const resetApiState = createAction(`${reducerPath}/resetApiState`) const querySlice = createSlice({ name: `${reducerPath}/queries`, diff --git a/packages/toolkit/src/query/core/buildThunks.ts b/packages/toolkit/src/query/core/buildThunks.ts index 70627228c4..8879937bcc 100644 --- a/packages/toolkit/src/query/core/buildThunks.ts +++ b/packages/toolkit/src/query/core/buildThunks.ts @@ -27,7 +27,11 @@ import { calculateProvidedBy } from '../endpointDefinitions' import type { AsyncThunkPayloadCreator, Draft, + ImmutableHelpers, UnknownAction, + ThunkAction, + ThunkDispatch, + AsyncThunk, } from '@reduxjs/toolkit' import { isAllOf, @@ -39,8 +43,6 @@ import { SHOULD_AUTOBATCH, } from './rtkImports' import type { Patch } from 'immer' -import { isDraftable, produceWithPatches } from 'immer' -import type { ThunkAction, ThunkDispatch, AsyncThunk } from '@reduxjs/toolkit' import { HandledError } from '../HandledError' @@ -225,6 +227,7 @@ export function buildThunks< context: { endpointDefinitions }, serializeQueryArgs, api, + immutableHelpers: { isDraftable, createWithPatches }, assertTagType, }: { baseQuery: BaseQuery @@ -232,6 +235,7 @@ export function buildThunks< context: ApiContext serializeQueryArgs: InternalSerializeQueryArgs api: Api + immutableHelpers: Pick assertTagType: AssertTagTypes }) { type State = RootState @@ -302,7 +306,7 @@ export function buildThunks< let newValue if ('data' in currentState) { if (isDraftable(currentState.data)) { - const [value, patches, inversePatches] = produceWithPatches( + const [value, patches, inversePatches] = createWithPatches( currentState.data, updateRecipe ) diff --git a/packages/toolkit/src/query/core/module.ts b/packages/toolkit/src/query/core/module.ts index 59b2514070..50bf08dc88 100644 --- a/packages/toolkit/src/query/core/module.ts +++ b/packages/toolkit/src/query/core/module.ts @@ -9,12 +9,15 @@ import type { import { buildThunks } from './buildThunks' import type { ActionCreatorWithPayload, + BuildCreateSliceConfiguration, Middleware, Reducer, ThunkAction, ThunkDispatch, + ImmutableHelpers, UnknownAction, } from '@reduxjs/toolkit' +import { immerImmutableHelpers } from './rtkImports' import type { EndpointDefinitions, QueryArgFrom, @@ -431,6 +434,18 @@ export type ListenerActions = { export type InternalActions = SliceActions & ListenerActions +interface CoreModuleOptions { + immutableHelpers?: BuildCreateSliceConfiguration & + Pick< + ImmutableHelpers, + | 'createWithPatches' + | 'applyPatches' + | 'isDraftable' + | 'freeze' + | 'original' + > +} + /** * Creates a module containing the basic redux logic for use with `buildCreateApi`. * @@ -439,7 +454,9 @@ export type InternalActions = SliceActions & ListenerActions * const createBaseApi = buildCreateApi(coreModule()); * ``` */ -export const coreModule = (): Module => ({ +export const coreModule = ({ + immutableHelpers = immerImmutableHelpers, +}: CoreModuleOptions = {}): Module => ({ name: coreModuleName, init( api, @@ -499,6 +516,7 @@ export const coreModule = (): Module => ({ context, api, serializeQueryArgs, + immutableHelpers, assertTagType, }) @@ -515,6 +533,7 @@ export const coreModule = (): Module => ({ keepUnusedDataFor, reducerPath, }, + immutableHelpers, }) safeAssign(api.util, { @@ -533,6 +552,7 @@ export const coreModule = (): Module => ({ mutationThunk, api, assertTagType, + immutableHelpers, }) safeAssign(api.util, middlewareActions) @@ -546,6 +566,7 @@ export const coreModule = (): Module => ({ } = buildSelectors({ serializeQueryArgs: serializeQueryArgs as any, reducerPath, + immutableHelpers, }) safeAssign(api.util, { selectInvalidatedBy, selectCachedArgsForQuery }) diff --git a/packages/toolkit/src/query/core/rtkImports.ts b/packages/toolkit/src/query/core/rtkImports.ts index 4ba180bab4..74f4bcdb11 100644 --- a/packages/toolkit/src/query/core/rtkImports.ts +++ b/packages/toolkit/src/query/core/rtkImports.ts @@ -4,7 +4,6 @@ export { createAction, - createSlice, createSelector, createAsyncThunk, combineReducers, @@ -21,4 +20,6 @@ export { SHOULD_AUTOBATCH, isPlainObject, nanoid, + buildCreateSlice, + immerImmutableHelpers, } from '@reduxjs/toolkit' diff --git a/packages/toolkit/src/tsHelpers.ts b/packages/toolkit/src/tsHelpers.ts index f0ed92a6e1..eae917b5b9 100644 --- a/packages/toolkit/src/tsHelpers.ts +++ b/packages/toolkit/src/tsHelpers.ts @@ -1,4 +1,5 @@ import type { Middleware, StoreEnhancer } from 'redux' +import type { Draft, Patch, applyPatches } from 'immer' import type { Tuple } from './utils' export function safeAssign( @@ -8,6 +9,57 @@ export function safeAssign( Object.assign(target, ...args) } +export interface ImmutableHelpers { + /** + * Function that receives a base object, and a recipe which is called with a draft that the recipe is allowed to mutate. + * The recipe can return a new state which will replace the existing state, or it can not return (in which case the existing draft is used) + * Returns an immutably modified version of the input object. + */ + createNextState: ( + base: Base, + recipe: (draft: Draft) => void | Base | Draft + ) => Base + /** + * Function that receives a base object, and a recipe which is called with a draft that the recipe is allowed to mutate. + * The recipe can return a new state which will replace the existing state, or it can not return (in which case the existing draft is used) + * Returns a tuple of an immutably modified version of the input object, an array of patches describing the changes made, and an array of inverse patches. + */ + createWithPatches: ( + base: Base, + recipe: (draft: Draft) => void | Base | Draft + ) => readonly [Base, Patch[], Patch[]] + /** + * Receives a base object and an array of patches describing changes to apply. + * Returns an immutably modified version of the base object with changes applied. + */ + applyPatches: typeof applyPatches + /** + * Indicates whether the value passed is a draft, meaning it's safe to mutate. + */ + isDraft(value: any): boolean + /** + * Indicates whether the value passed is possible to turn into a mutable draft. + */ + isDraftable(value: any): boolean + /** + * Receives a draft and returns its base object. + */ + original(value: T): T | undefined + /** + * Receives a draft and returns an object with any changes to date immutably applied. + */ + current(value: T): T + /** + * Receives an object and freezes it, causing runtime errors if mutation is attempted after. + */ + freeze(obj: T, deep?: boolean): T +} + +/** + * Define a config object indicating utilities for RTK packages to use for immutable operations. + */ +export const defineImmutableHelpers = (helpers: ImmutableHelpers) => helpers + /** * return True if T is `any`, otherwise return False * taken from https://github.com/joonhocho/tsdef diff --git a/packages/toolkit/src/utils.ts b/packages/toolkit/src/utils.ts index 2e29b2bc9e..e8933f7a59 100644 --- a/packages/toolkit/src/utils.ts +++ b/packages/toolkit/src/utils.ts @@ -1,4 +1,3 @@ -import { produce as createNextState, isDraftable } from 'immer' import type { Middleware, StoreEnhancer } from 'redux' export function getTimeMeasureUtils(maxDelay: number, fnName: string) { @@ -83,7 +82,3 @@ export class Tuple = []> extends Array< return new Tuple(...arr.concat(this)) } } - -export function freezeDraftable(val: T) { - return isDraftable(val) ? createNextState(val, () => {}) : val -}