|
| 1 | +"use client"; |
| 2 | + |
| 3 | +import type { StreamQuery, StreamQueryArgs } from "./types.js"; |
| 4 | +import type { FunctionArgs } from "convex/server"; |
| 5 | +import type { StreamArgs, StreamDelta, StreamMessage } from "../validators.js"; |
| 6 | +import type { SyncStreamsReturnValue } from "@convex-dev/agent"; |
| 7 | +import { useQuery } from "convex/react"; |
| 8 | +import { useState } from "react"; |
| 9 | +import { assert } from "convex-helpers"; |
| 10 | + |
| 11 | +export function useDeltaStreams< |
| 12 | + // eslint-disable-next-line @typescript-eslint/no-explicit-any |
| 13 | + Query extends StreamQuery<any> = StreamQuery<object>, |
| 14 | +>( |
| 15 | + query: Query, |
| 16 | + args: StreamQueryArgs<Query> | "skip", |
| 17 | + options?: { |
| 18 | + startOrder?: number; |
| 19 | + skipStreamIds?: string[]; |
| 20 | + }, |
| 21 | +): { streamMessage: StreamMessage; deltas: StreamDelta[] }[] | undefined { |
| 22 | + // We hold onto and modify state directly to avoid re-running unnecessarily. |
| 23 | + const [state] = useState<{ |
| 24 | + startOrder: number; |
| 25 | + threadId: string | undefined; |
| 26 | + deltaStreams: |
| 27 | + | Array<{ |
| 28 | + streamMessage: StreamMessage; |
| 29 | + deltas: StreamDelta[]; |
| 30 | + }> |
| 31 | + | undefined; |
| 32 | + }>({ |
| 33 | + startOrder: options?.startOrder ?? 0, |
| 34 | + deltaStreams: undefined, |
| 35 | + threadId: args === "skip" ? undefined : args.threadId, |
| 36 | + }); |
| 37 | + const [cursors, setCursors] = useState<Record<string, number>>({}); |
| 38 | + if (args !== "skip" && state.threadId !== args.threadId) { |
| 39 | + state.threadId = args.threadId; |
| 40 | + state.deltaStreams = undefined; |
| 41 | + state.startOrder = options?.startOrder ?? 0; |
| 42 | + setCursors({}); |
| 43 | + } |
| 44 | + if ( |
| 45 | + state.deltaStreams?.length || |
| 46 | + (options?.startOrder && options.startOrder < state.startOrder) |
| 47 | + ) { |
| 48 | + const cacheFriendlyStartOrder = options?.startOrder |
| 49 | + ? // round down to the nearest 10 for some cache benefits |
| 50 | + options.startOrder - (options.startOrder % 10) |
| 51 | + : 0; |
| 52 | + if (cacheFriendlyStartOrder !== state.startOrder) { |
| 53 | + state.startOrder = cacheFriendlyStartOrder; |
| 54 | + } |
| 55 | + } |
| 56 | + |
| 57 | + // Get all the active streams |
| 58 | + const streamList = useQuery( |
| 59 | + query, |
| 60 | + args === "skip" |
| 61 | + ? args |
| 62 | + : ({ |
| 63 | + ...args, |
| 64 | + streamArgs: { |
| 65 | + kind: "list", |
| 66 | + startOrder: state.startOrder, |
| 67 | + } as StreamArgs, |
| 68 | + } as FunctionArgs<Query>), |
| 69 | + ) as |
| 70 | + | { streams: Extract<SyncStreamsReturnValue, { kind: "list" }> } |
| 71 | + | undefined; |
| 72 | + |
| 73 | + const streamMessages = |
| 74 | + args === "skip" |
| 75 | + ? undefined |
| 76 | + : !streamList |
| 77 | + ? state.deltaStreams?.map(({ streamMessage }) => streamMessage) |
| 78 | + : streamList.streams.messages.filter( |
| 79 | + ({ streamId, order }) => |
| 80 | + !options?.skipStreamIds?.includes(streamId) && |
| 81 | + (!options?.startOrder || order >= options.startOrder), |
| 82 | + ); |
| 83 | + |
| 84 | + // Get the deltas for all the active streams, if any. |
| 85 | + const cursorQuery = useQuery( |
| 86 | + query, |
| 87 | + args === "skip" || !streamMessages?.length |
| 88 | + ? ("skip" as const) |
| 89 | + : ({ |
| 90 | + ...args, |
| 91 | + streamArgs: { |
| 92 | + kind: "deltas", |
| 93 | + cursors: streamMessages.map(({ streamId }) => ({ |
| 94 | + streamId, |
| 95 | + cursor: cursors[streamId] ?? 0, |
| 96 | + })), |
| 97 | + } as StreamArgs, |
| 98 | + } as FunctionArgs<Query>), |
| 99 | + ) as |
| 100 | + | { streams: Extract<SyncStreamsReturnValue, { kind: "deltas" }> } |
| 101 | + | undefined; |
| 102 | + |
| 103 | + const newDeltas = cursorQuery?.streams.deltas; |
| 104 | + if (newDeltas?.length && streamMessages) { |
| 105 | + const newDeltasByStreamId = new Map<string, StreamDelta[]>(); |
| 106 | + for (const delta of newDeltas) { |
| 107 | + const oldCursor = cursors[delta.streamId]; |
| 108 | + if (oldCursor && delta.start < oldCursor) continue; |
| 109 | + const existing = newDeltasByStreamId.get(delta.streamId); |
| 110 | + if (existing) { |
| 111 | + const previousEnd = existing.at(-1)!.end; |
| 112 | + assert( |
| 113 | + previousEnd === delta.start, |
| 114 | + `Gap found in deltas for ${delta.streamId} jumping to ${delta.start} from ${previousEnd}`, |
| 115 | + ); |
| 116 | + existing.push(delta); |
| 117 | + } else { |
| 118 | + assert( |
| 119 | + !oldCursor || oldCursor === delta.start, |
| 120 | + `Gap found - first delta after ${oldCursor} is ${delta.start} for stream ${delta.streamId}`, |
| 121 | + ); |
| 122 | + newDeltasByStreamId.set(delta.streamId, [delta]); |
| 123 | + } |
| 124 | + } |
| 125 | + const newCursors: Record<string, number> = {}; |
| 126 | + for (const { streamId } of streamMessages) { |
| 127 | + const cursor = |
| 128 | + newDeltasByStreamId.get(streamId)?.at(-1)?.end ?? cursors[streamId]; |
| 129 | + if (cursor !== undefined) { |
| 130 | + newCursors[streamId] = cursor; |
| 131 | + } |
| 132 | + } |
| 133 | + setCursors(newCursors); |
| 134 | + |
| 135 | + // we defensively create a new object so object identity matches contents |
| 136 | + state.deltaStreams = streamMessages.map((streamMessage) => { |
| 137 | + const streamId = streamMessage.streamId; |
| 138 | + const old = state.deltaStreams?.find( |
| 139 | + (ds) => ds.streamMessage.streamId === streamId, |
| 140 | + ); |
| 141 | + const newDeltas = newDeltasByStreamId.get(streamId); |
| 142 | + if (!newDeltas && streamMessage === old?.streamMessage) { |
| 143 | + return old; |
| 144 | + } |
| 145 | + return { |
| 146 | + streamMessage, |
| 147 | + deltas: [...(old?.deltas ?? []), ...(newDeltas ?? [])], |
| 148 | + }; |
| 149 | + }); |
| 150 | + } |
| 151 | + return state.deltaStreams; |
| 152 | +} |
0 commit comments