2222from .h3 import H3Protocol
2323from ..config import Config
2424from ..events import Closed , Event , RawData
25- from ..typing import AppWrapper , TaskGroup , WorkerContext
25+ from ..typing import AppWrapper , TaskGroup , WorkerContext , Timer
2626
2727
2828class QuicProtocol :
@@ -40,6 +40,7 @@ def __init__(
4040 self .context = context
4141 self .connections : Dict [bytes , QuicConnection ] = {}
4242 self .http_connections : Dict [QuicConnection , H3Protocol ] = {}
43+ self .timers : Dict [QuicConnection , Timer ] = {}
4344 self .send = send
4445 self .server = server
4546 self .task_group = task_group
@@ -82,10 +83,12 @@ async def handle(self, event: Event) -> None:
8283 )
8384 self .connections [header .destination_cid ] = connection
8485 self .connections [connection .host_cid ] = connection
86+ # This partial() needs python >= 3.8
87+ self .timers [connection ] = self .task_group .create_timer (partial (self ._timeout , connection ))
8588
8689 if connection is not None :
8790 connection .receive_datagram (event .data , event .address , now = self .context .time ())
88- await self ._handle_events (connection , event . address )
91+ await self ._wake_up_timer (connection )
8992 elif isinstance (event , Closed ):
9093 pass
9194
@@ -99,7 +102,16 @@ async def _handle_events(
99102 event = connection .next_event ()
100103 while event is not None :
101104 if isinstance (event , ConnectionTerminated ):
102- pass
105+ await self .timers [connection ].stop ()
106+ del self .timers [connection ]
107+ # XXXRTH This is not the speediest! Better would be tracking
108+ # assigned ids in a set.
109+ prune = []
110+ for tcid , tconn in self .connections .items ():
111+ if tconn == connection :
112+ prune .append (tcid )
113+ for tcid in prune :
114+ del self .connections [tcid ]
103115 elif isinstance (event , ProtocolNegotiated ):
104116 self .http_connections [connection ] = H3Protocol (
105117 self .app ,
@@ -109,7 +121,7 @@ async def _handle_events(
109121 client ,
110122 self .server ,
111123 connection ,
112- partial (self .send_all , connection ),
124+ partial (self ._wake_up_timer , connection ),
113125 )
114126 elif isinstance (event , ConnectionIdIssued ):
115127 self .connections [event .connection_id ] = connection
@@ -121,15 +133,20 @@ async def _handle_events(
121133
122134 event = connection .next_event ()
123135
136+ async def _wake_up_timer (self , connection : QuicConnection ):
137+ # When new output is send, or new input is received, we
138+ # fire the timer right away so we update our state.
139+ timer = self .timers .get (connection )
140+ if timer is not None :
141+ await timer .schedule (0.0 )
142+
143+ async def _timeout (self , connection : QuicConnection ):
144+ now = self .context .time ()
145+ when = connection .get_timer ()
146+ if when is not None and now > when :
147+ connection .handle_timer (now )
148+ await self ._handle_events (connection , None )
124149 await self .send_all (connection )
125-
126- timer = connection .get_timer ()
150+ timer = self .timers .get (connection )
127151 if timer is not None :
128- self .task_group .spawn (self ._handle_timer , timer , connection )
129-
130- async def _handle_timer (self , timer : float , connection : QuicConnection ) -> None :
131- wait = max (0 , timer - self .context .time ())
132- await self .context .sleep (wait )
133- if connection ._close_at is not None :
134- connection .handle_timer (now = self .context .time ())
135- await self ._handle_events (connection , None )
152+ await timer .schedule (connection .get_timer ())
0 commit comments