66import sys
77from collections import deque
88from signal import Signals
9- from typing import Any , Callable , Deque , List , Optional
9+ from typing import Any , Callable , Deque , List , Optional , cast
1010
11- from pynvim .msgpack_rpc .event_loop .base import BaseEventLoop
11+ if sys .version_info >= (3 , 12 ):
12+ from typing import Final , override
13+ else :
14+ from typing_extensions import Final , override
15+
16+ from pynvim .msgpack_rpc .event_loop .base import BaseEventLoop , TTransportType
1217
1318logger = logging .getLogger (__name__ )
1419debug , info , warn = (logger .debug , logger .info , logger .warning ,)
2732
2833# pylint: disable=logging-fstring-interpolation
2934
30- class AsyncioEventLoop (BaseEventLoop , asyncio .Protocol ,
31- asyncio .SubprocessProtocol ):
32- """`BaseEventLoop` subclass that uses `asyncio` as a backend."""
35+ class Protocol (asyncio .Protocol , asyncio .SubprocessProtocol ):
36+ """The protocol class used for asyncio-based RPC communication."""
3337
34- _queued_data : Deque [bytes ]
35- if os .name != 'nt' :
36- _child_watcher : Optional ['asyncio.AbstractChildWatcher' ]
38+ def __init__ (self , on_data , on_error ):
39+ """Initialize the Protocol object."""
40+ assert on_data is not None
41+ assert on_error is not None
42+ self ._on_data = on_data
43+ self ._on_error = on_error
3744
45+ @override
3846 def connection_made (self , transport ):
3947 """Used to signal `asyncio.Protocol` of a successful connection."""
40- self ._transport = transport
41- self ._raw_transport = transport
42- if isinstance (transport , asyncio .SubprocessTransport ):
43- self ._transport = transport .get_pipe_transport (0 )
48+ del transport # no-op
4449
45- def connection_lost (self , exc ):
50+ @override
51+ def connection_lost (self , exc : Optional [Exception ]) -> None :
4652 """Used to signal `asyncio.Protocol` of a lost connection."""
4753 debug (f"connection_lost: exc = { exc } " )
48- self ._on_error (exc . args [ 0 ] if exc else 'EOF' )
54+ self ._on_error (exc if exc else EOFError () )
4955
56+ @override
5057 def data_received (self , data : bytes ) -> None :
5158 """Used to signal `asyncio.Protocol` of incoming data."""
52- if self ._on_data :
53- self ._on_data (data )
54- return
55- self ._queued_data .append (data )
59+ self ._on_data (data )
5660
57- def pipe_connection_lost (self , fd , exc ):
61+ @override
62+ def pipe_connection_lost (self , fd : int , exc : Optional [Exception ]) -> None :
5863 """Used to signal `asyncio.SubprocessProtocol` of a lost connection."""
5964 debug ("pipe_connection_lost: fd = %s, exc = %s" , fd , exc )
6065 if os .name == 'nt' and fd == 2 : # stderr
6166 # On windows, ignore piped stderr being closed immediately (#505)
6267 return
63- self ._on_error (exc . args [ 0 ] if exc else 'EOF' )
68+ self ._on_error (exc if exc else EOFError () )
6469
70+ @override
6571 def pipe_data_received (self , fd , data ):
6672 """Used to signal `asyncio.SubprocessProtocol` of incoming data."""
6773 if fd == 2 : # stderr fd number
6874 # Ignore stderr message, log only for debugging
6975 debug ("stderr: %s" , str (data ))
70- elif self ._on_data :
71- self ._on_data (data )
72- else :
73- self ._queued_data .append (data )
76+ elif fd == 1 : # stdout
77+ self .data_received (data )
7478
79+ @override
7580 def process_exited (self ) -> None :
7681 """Used to signal `asyncio.SubprocessProtocol` when the child exits."""
7782 debug ("process_exited" )
78- self ._on_error ('EOF' )
83+ self ._on_error (EOFError ())
84+
85+
86+ class AsyncioEventLoop (BaseEventLoop ):
87+ """`BaseEventLoop` subclass that uses core `asyncio` as a backend."""
88+
89+ _protocol : Optional [Protocol ]
90+ _transport : Optional [asyncio .WriteTransport ]
91+ _signals : List [Signals ]
92+ _data_buffer : Deque [bytes ]
93+ if os .name != 'nt' :
94+ _child_watcher : Optional ['asyncio.AbstractChildWatcher' ]
7995
80- def _init (self ) -> None :
81- self ._loop = loop_cls ()
82- self ._queued_data = deque ()
83- self ._fact = lambda : self
96+ def __init__ (self ,
97+ transport_type : TTransportType ,
98+ * args : Any , ** kwargs : Any ):
99+ """asyncio-specific initialization. see BaseEventLoop.__init__."""
100+
101+ # The underlying asyncio event loop.
102+ self ._loop : Final [asyncio .AbstractEventLoop ] = loop_cls ()
103+
104+ # Handle messages from nvim that may arrive before run() starts.
105+ self ._data_buffer = deque ()
106+
107+ def _on_data (data : bytes ) -> None :
108+ if self ._on_data is None :
109+ self ._data_buffer .append (data )
110+ return
111+ self ._on_data (data )
112+
113+ # pylint: disable-next=unnecessary-lambda
114+ self ._protocol_factory = lambda : Protocol (
115+ on_data = _on_data ,
116+ on_error = self ._on_error ,
117+ )
118+ self ._protocol = None
119+
120+ # The communication channel (endpoint) created by _connect_*() method.
121+ self ._transport = None
84122 self ._raw_transport = None
85123 self ._child_watcher = None
86124
125+ super ().__init__ (transport_type , * args , ** kwargs )
126+
127+ @override
87128 def _connect_tcp (self , address : str , port : int ) -> None :
88129 async def connect_tcp ():
89- await self ._loop .create_connection (self ._fact , address , port )
130+ transport , protocol = await self ._loop .create_connection (
131+ self ._protocol_factory , address , port )
90132 debug (f"tcp connection successful: { address } :{ port } " )
133+ self ._transport = transport
134+ self ._protocol = protocol
91135
92136 self ._loop .run_until_complete (connect_tcp ())
93137
138+ @override
94139 def _connect_socket (self , path : str ) -> None :
95140 async def connect_socket ():
96141 if os .name == 'nt' :
97- transport , _ = await self ._loop .create_pipe_connection ( self . _fact , path )
142+ _create_connection = self ._loop .create_pipe_connection
98143 else :
99- transport , _ = await self ._loop .create_unix_connection (self ._fact , path )
100- debug ("socket connection successful: %s" , transport )
144+ _create_connection = self ._loop .create_unix_connection
145+
146+ transport , protocol = await _create_connection (
147+ self ._protocol_factory , path )
148+ debug ("socket connection successful: %s" , self ._transport )
149+ self ._transport = transport
150+ self ._protocol = protocol
101151
102152 self ._loop .run_until_complete (connect_socket ())
103153
154+ @override
104155 def _connect_stdio (self ) -> None :
105156 async def connect_stdin ():
106157 if os .name == 'nt' :
107158 pipe = PipeHandle (msvcrt .get_osfhandle (sys .stdin .fileno ()))
108159 else :
109160 pipe = sys .stdin
110- await self ._loop .connect_read_pipe (self ._fact , pipe )
161+ transport , protocol = await self ._loop .connect_read_pipe (
162+ self ._protocol_factory , pipe )
111163 debug ("native stdin connection successful" )
164+ del transport , protocol
112165 self ._loop .run_until_complete (connect_stdin ())
113166
114167 # Make sure subprocesses don't clobber stdout,
@@ -122,52 +175,74 @@ async def connect_stdout():
122175 else :
123176 pipe = os .fdopen (rename_stdout , 'wb' )
124177
125- await self ._loop .connect_write_pipe (self ._fact , pipe )
178+ transport , protocol = await self ._loop .connect_write_pipe (
179+ self ._protocol_factory , pipe )
126180 debug ("native stdout connection successful" )
127-
181+ self ._transport = transport
182+ self ._protocol = protocol
128183 self ._loop .run_until_complete (connect_stdout ())
129184
185+ @override
130186 def _connect_child (self , argv : List [str ]) -> None :
131187 if os .name != 'nt' :
132188 # see #238, #241
133- _child_watcher = asyncio .get_child_watcher ()
134- _child_watcher .attach_loop (self ._loop )
189+ self . _child_watcher = asyncio .get_child_watcher ()
190+ self . _child_watcher .attach_loop (self ._loop )
135191
136192 async def create_subprocess ():
137- transport : asyncio .SubprocessTransport
138- transport , protocol = await self ._loop .subprocess_exec (self ._fact , * argv )
193+ transport : asyncio .SubprocessTransport # type: ignore
194+ transport , protocol = await self ._loop .subprocess_exec (
195+ self ._protocol_factory , * argv )
139196 pid = transport .get_pid ()
140197 debug ("child subprocess_exec successful, PID = %s" , pid )
141198
199+ self ._transport = cast (asyncio .WriteTransport ,
200+ transport .get_pipe_transport (0 )) # stdin
201+ self ._protocol = protocol
202+
203+ # await until child process have been launched and the transport has
204+ # been established
142205 self ._loop .run_until_complete (create_subprocess ())
143206
207+ @override
144208 def _start_reading (self ) -> None :
145209 pass
146210
211+ @override
147212 def _send (self , data : bytes ) -> None :
213+ assert self ._transport , "connection has not been established."
148214 self ._transport .write (data )
149215
216+ @override
150217 def _run (self ) -> None :
151- while self ._queued_data :
152- data = self ._queued_data .popleft ()
218+ # process the early messages that arrived as soon as the transport
219+ # channels are open and on_data is fully ready to receive messages.
220+ while self ._data_buffer :
221+ data : bytes = self ._data_buffer .popleft ()
153222 if self ._on_data is not None :
154223 self ._on_data (data )
224+
155225 self ._loop .run_forever ()
156226
227+ @override
157228 def _stop (self ) -> None :
158229 self ._loop .stop ()
159230
231+ @override
160232 def _close (self ) -> None :
233+ # TODO close all the transports
161234 if self ._raw_transport is not None :
162- self ._raw_transport .close ()
235+ self ._raw_transport .close () # type: ignore[unreachable]
163236 self ._loop .close ()
164237 if self ._child_watcher is not None :
165238 self ._child_watcher .close ()
166239 self ._child_watcher = None
167240
241+ @override
168242 def _threadsafe_call (self , fn : Callable [[], Any ]) -> None :
169243 self ._loop .call_soon_threadsafe (fn )
170244
245+ @override
171246 def _setup_signals (self , signals : List [Signals ]) -> None :
172247 if os .name == 'nt' :
173248 # add_signal_handler is not supported in win32
@@ -178,6 +253,7 @@ def _setup_signals(self, signals: List[Signals]) -> None:
178253 for signum in self ._signals :
179254 self ._loop .add_signal_handler (signum , self ._on_signal , signum )
180255
256+ @override
181257 def _teardown_signals (self ) -> None :
182258 for signum in self ._signals :
183259 self ._loop .remove_signal_handler (signum )
0 commit comments