88import React , { createContext , useCallback , useContext , useEffect , useState } from "react" ;
99import useWebSocket , { ReadyState } from "react-use-websocket" ;
1010import { useLocation } from "react-router-dom" ;
11- import { BotMessage , DebugMessage , AiModel , StateMessage } from "../../types/api/entities/bot" ;
11+ import { BotMessage , DebugMessage , AiModel , StateMessage , StreamMessage } from "../../types/api/entities/bot" ;
1212import { getChatsWithWholeThreads } from "../../api/bot/getChatsWithWholeThreads" ;
1313import { getChats } from "api/bot/getChats" ;
1414import { useAlertSnackbar } from "@postgres.ai/shared/components/AlertSnackbar/useAlertSnackbar" ;
@@ -58,6 +58,8 @@ type UseAiBotReturnType = {
5858 aiModelsLoading : UseAiModelsList [ 'loading' ] ;
5959 debugMessagesLoading : boolean ;
6060 stateMessage : StateMessage | null ;
61+ isStreamingInProcess : boolean
62+ currentStreamMessage : StreamMessage | null
6163}
6264
6365type UseAiBotArgs = {
@@ -90,6 +92,8 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
9092 const [ wsLoading , setWsLoading ] = useState < boolean > ( false ) ;
9193 const [ chatVisibility , setChatVisibility ] = useState < UseAiBotReturnType [ 'chatVisibility' ] > ( 'public' ) ;
9294 const [ stateMessage , setStateMessage ] = useState < StateMessage | null > ( null )
95+ const [ currentStreamMessage , setCurrentStreamMessage ] = useState < StreamMessage | null > ( null )
96+ const [ isStreamingInProcess , setStreamingInProcess ] = useState < boolean > ( false )
9397
9498 const [ isChangeVisibilityLoading , setIsChangeVisibilityLoading ] = useState < boolean > ( false ) ;
9599
@@ -102,51 +106,35 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
102106
103107 const onWebSocketMessage = ( event : WebSocketEventMap [ 'message' ] ) => {
104108 if ( event . data ) {
105- const messageData : BotMessage | DebugMessage | StateMessage = JSON . parse ( event . data ) ;
109+ const messageData : BotMessage | DebugMessage | StateMessage | StreamMessage = JSON . parse ( event . data ) ;
106110 if ( messageData ) {
107111 const isThreadMatching = threadId && threadId === messageData . thread_id ;
108112 const isParentMatching = ! threadId && 'parent_id' in messageData && messageData . parent_id && messages ;
109113 const isDebugMessage = messageData . type === 'debug' ;
110114 const isStateMessage = messageData . type === 'state' ;
111- if ( isThreadMatching || isParentMatching || isDebugMessage || isStateMessage ) {
112- if ( isDebugMessage ) {
113- let currentDebugMessages = [ ...( debugMessages || [ ] ) ] ;
114- currentDebugMessages . push ( messageData )
115- setDebugMessages ( currentDebugMessages )
116- } else if ( isStateMessage ) {
117- if ( isThreadMatching || ! threadId ) {
118- if ( messageData . state ) {
119- setStateMessage ( messageData )
120- } else {
121- setStateMessage ( null )
122- }
123- }
124- } else {
125- // Check if the last message needs its data updated
126- let currentMessages = [ ...( messages || [ ] ) ] ;
127- const lastMessage = currentMessages [ currentMessages . length - 1 ] ;
128- if ( lastMessage && ! lastMessage . id && messageData . parent_id ) {
129- lastMessage . id = messageData . parent_id ;
130- lastMessage . created_at = messageData . created_at ;
131- lastMessage . is_public = messageData . is_public ;
132- }
133-
134- currentMessages . push ( messageData ) ;
135- setMessages ( currentMessages ) ;
136- setWsLoading ( false ) ;
137- if ( document . visibilityState === "hidden" ) {
138- if ( Notification . permission === "granted" ) {
139- new Notification ( "New message" , {
140- body : 'New message from Postgres.AI Bot' ,
141- icon : '/images/bot_avatar.png'
142- } ) ;
143- }
144- }
115+ const isStreamMessage = messageData . type === 'stream' ;
116+
117+ if ( isThreadMatching || isParentMatching || isDebugMessage || isStateMessage || isStreamMessage ) {
118+ switch ( messageData . type ) {
119+ case 'debug' :
120+ handleDebugMessage ( messageData )
121+ break ;
122+ case 'state' :
123+ handleStateMessage ( messageData , Boolean ( isThreadMatching ) )
124+ break ;
125+ case 'stream' :
126+ handleStreamMessage ( messageData , Boolean ( isThreadMatching ) )
127+ break ;
128+ case 'message' :
129+ handleBotMessage ( messageData )
130+ break ;
145131 }
146132 } else if ( threadId !== messageData . thread_id ) {
147133 const threadInList = chatsList ?. find ( ( item ) => item . thread_id === messageData . thread_id )
148134 if ( ! threadInList ) getChatsList ( )
149- setWsLoading ( false ) ;
135+ if ( currentStreamMessage ) setCurrentStreamMessage ( null )
136+ if ( wsLoading ) setWsLoading ( false ) ;
137+ if ( isStreamingInProcess ) setStreamingInProcess ( false )
150138 }
151139 } else {
152140 showMessage ( 'An error occurred. Please try again' )
@@ -158,6 +146,56 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
158146 setLoading ( false ) ;
159147 }
160148
149+ const handleDebugMessage = ( message : DebugMessage ) => {
150+ let currentDebugMessages = [ ...( debugMessages || [ ] ) ] ;
151+ currentDebugMessages . push ( message )
152+ setDebugMessages ( currentDebugMessages )
153+ }
154+
155+ const handleStateMessage = ( message : StateMessage , isThreadMatching ?: boolean ) => {
156+ if ( isThreadMatching || ! threadId ) {
157+ if ( message . state ) {
158+ setStateMessage ( message )
159+ } else {
160+ setStateMessage ( null )
161+ }
162+ }
163+ }
164+
165+ const handleStreamMessage = ( message : StreamMessage , isThreadMatching ?: boolean ) => {
166+ if ( isThreadMatching || ! threadId ) {
167+ if ( ! isStreamingInProcess ) setStreamingInProcess ( true )
168+ setCurrentStreamMessage ( message )
169+ setWsLoading ( false ) ;
170+ }
171+ }
172+
173+ const handleBotMessage = ( message : BotMessage ) => {
174+ if ( messages && messages . length > 0 ) {
175+ let currentMessages = [ ...messages ] ;
176+ const lastMessage = currentMessages [ currentMessages . length - 1 ] ;
177+ if ( lastMessage && ! lastMessage . id && message . parent_id ) {
178+ lastMessage . id = message . parent_id ;
179+ lastMessage . created_at = message . created_at ;
180+ lastMessage . is_public = message . is_public ;
181+ }
182+
183+ currentMessages . push ( message ) ;
184+ if ( currentStreamMessage ) setCurrentStreamMessage ( null )
185+ setMessages ( currentMessages ) ;
186+ setWsLoading ( false ) ;
187+ setStreamingInProcess ( false ) ;
188+ if ( document . visibilityState === "hidden" ) {
189+ if ( Notification . permission === "granted" ) {
190+ new Notification ( "New message" , {
191+ body : 'New message from Postgres.AI Bot' ,
192+ icon : '/images/bot_avatar.png'
193+ } ) ;
194+ }
195+ }
196+ }
197+ }
198+
161199 const onWebSocketOpen = ( ) => {
162200 console . log ( 'WebSocket connection established' ) ;
163201 if ( threadId ) {
@@ -381,6 +419,8 @@ export const useAiBotProviderValue = (args: UseAiBotArgs): UseAiBotReturnType =>
381419 debugMessages,
382420 debugMessagesLoading,
383421 stateMessage,
422+ isStreamingInProcess,
423+ currentStreamMessage
384424 }
385425}
386426
0 commit comments