1010import logging
1111
1212from distutils .version import LooseVersion
13- from functools import partial
14- from typing import Union
1513
1614import ipykernel
1715import ipykernel .kernelbase
@@ -45,28 +43,15 @@ class SessionWebsocket(session.Session):
4543
4644 parent : BokehKernel
4745
48- def send (self , stream , msg_type , content = None , parent = None , ident = None , buffers = None , track = False , header = None , metadata = None ):
49- if not isinstance (stream , WebsocketStream ):
50- self .parent .log .warn (f"skipping { msg_type } ${ content } " )
51- return
52-
53- msg = self .msg (msg_type , content = content , parent = parent , header = header , metadata = metadata )
54- msg ['channel' ] = stream .channel
55-
56- doc = self .document
57-
58- # Ensure document message handler is only added once
59- try :
60- doc .remove_on_message ("ipywidgets_bokeh" , self .receive )
61- except Exception :
62- pass
63- finally :
64- doc .on_message ("ipywidgets_bokeh" , self .receive )
46+ def __init__ (self , * args , ** kwargs ):
47+ super ().__init__ (* args , ** kwargs )
48+ self .document .on_message ("ipywidgets_bokeh" , self .receive )
6549
50+ def _encode_msg (self , msg : dict [str , Any ], buffers : list [bytes ]) -> bytes | str :
6651 packed = self .pack (msg )
6752
68- data : Union [ bytes , str ]
69- if buffers is not None and len ( buffers ) != 0 :
53+ data : bytes | str
54+ if buffers :
7055 buffers = [packed ] + buffers
7156 nbufs = len (buffers )
7257
@@ -82,23 +67,28 @@ def send(self, stream, msg_type, content=None, parent=None, ident=None, buffers=
8267 data = b"" .join (items )
8368 else :
8469 data = packed .decode ("utf-8" )
85- event = MessageSentEvent (doc , "ipywidgets_bokeh" , data )
70+
71+ return data
72+
73+ def send (self , stream , msg_type , content = None , parent = None , ident = None , buffers : list [bytes ] | None = None , track = False , header = None , metadata = None ):
74+ msg = self .msg (msg_type , content = content , parent = parent , header = header , metadata = metadata )
75+ msg ["channel" ] = getattr (stream , "channel" , "shell" )
76+ data = self ._encode_msg (msg , buffers or [])
77+ event = MessageSentEvent (self .document , "ipywidgets_bokeh" , data )
8678 self ._trigger_change (event )
8779
8880 def receive (self , data : str ) -> None :
8981 msg = json .loads (data )
90- msg_serialized = self .serialize (msg )
9182 if msg ['channel' ] == 'shell' :
92- stream = StreamWrapper (msg [ 'channel' ] )
83+ msg_serialized = self . serialize (msg )
9384 msg_list = [BytesWrap (k ) for k in msg_serialized ]
94- if kernel_version > '6' :
95- cb = partial (self .parent .dispatch_shell , msg_list )
96- if self .document .session_context : # Bokeh Server
97- self .document .add_next_tick_callback (cb )
98- else : # Other Tornado based server
99- self .parent .io_loop .add_callback (cb )
100- else :
101- self .parent .dispatch_shell (stream , msg_list )
85+ async def dispatch_shell ():
86+ parent = self .parent
87+ await parent .dispatch_shell (msg_list )
88+ if self .document .session_context : # Bokeh Server
89+ self .document .add_next_tick_callback (dispatch_shell )
90+ else : # Other Tornado based server
91+ self .parent .io_loop .add_callback (dispatch_shell )
10292
10393 @property
10494 def document (self ):
@@ -131,7 +121,7 @@ def __init__(self):
131121
132122 self .iopub_socket .channel = 'iopub'
133123 self .session .stream = self .iopub_socket
134- self .comm_manager = CommManager ()
124+ self .comm_manager = CommManager (parent = self , kernel = self )
135125 self .shell = None
136126 self .log = logging .getLogger ("ipywidgets_bokeh" )
137127
0 commit comments