11import asyncio
2+ import contextlib
23from typing import Any
34
45from broadcaster import Broadcast
@@ -22,14 +23,6 @@ class BroadcastNotification(BaseModel):
2223 data : Any = None
2324
2425
25- class EventBroadcasterException (Exception ):
26- pass
27-
28-
29- class BroadcasterAlreadyStarted (EventBroadcasterException ):
30- pass
31-
32-
3326class EventBroadcasterContextManager :
3427 """
3528 Manages the context for the EventBroadcaster
@@ -56,56 +49,18 @@ def __init__(
5649 self ._listen : bool = listen
5750
5851 async def __aenter__ (self ):
59- async with self ._event_broadcaster ._context_manager_lock :
60- if self ._listen :
61- self ._event_broadcaster ._listen_count += 1
62- if self ._event_broadcaster ._listen_count == 1 :
63- # We have our first listener start the read-task for it (And all those who'd follow)
64- logger .info (
65- "Listening for incoming events from broadcast channel (first listener started)"
66- )
67- # Start task listening on incoming broadcasts
68- await self ._event_broadcaster .start_reader_task ()
69-
70- if self ._share :
71- self ._event_broadcaster ._share_count += 1
72- if self ._event_broadcaster ._share_count == 1 :
73- # We have our first publisher
74- # Init the broadcast used for sharing (reading has its own)
75- logger .debug (
76- "Subscribing to ALL_TOPICS, and sharing messages with broadcast channel"
77- )
78- # Subscribe to internal events form our own event notifier and broadcast them
79- await self ._event_broadcaster ._subscribe_to_all_topics ()
80- else :
81- logger .debug (
82- f"Did not subscribe to ALL_TOPICS: share count == { self ._event_broadcaster ._share_count } "
83- )
84- return self
52+ await self ._event_broadcaster .connect (self ._listen , self ._share )
8553
8654 async def __aexit__ (self , exc_type , exc , tb ):
87- async with self ._event_broadcaster ._context_manager_lock :
88- try :
89- if self ._listen :
90- self ._event_broadcaster ._listen_count -= 1
91- # if this was last listener - we can stop the reading task
92- if self ._event_broadcaster ._listen_count == 0 :
93- # Cancel task reading broadcast subscriptions
94- if self ._event_broadcaster ._subscription_task is not None :
95- logger .info ("Cancelling broadcast listen task" )
96- self ._event_broadcaster ._subscription_task .cancel ()
97- self ._event_broadcaster ._subscription_task = None
98-
99- if self ._share :
100- self ._event_broadcaster ._share_count -= 1
101- # if this was last sharer - we can stop subscribing to internal events - we aren't sharing anymore
102- if self ._event_broadcaster ._share_count == 0 :
103- # Unsubscribe from internal events
104- logger .debug ("Unsubscribing from ALL TOPICS" )
105- await self ._event_broadcaster ._unsubscribe_from_topics ()
106-
107- except :
108- logger .exception ("Failed to exit EventBroadcaster context" )
55+ await self ._event_broadcaster .close (self ._listen , self ._share )
56+
57+
58+ class EventBroadcasterException (Exception ):
59+ pass
60+
61+
62+ class BroadcasterAlreadyStarted (EventBroadcasterException ):
63+ pass
10964
11065
11166class EventBroadcaster :
@@ -137,62 +92,46 @@ def __init__(
13792 broadcast_type (Broadcast, optional): Broadcast class to use. None - Defaults to Broadcast.
13893 is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False
13994 """
140- # Broadcast init params
14195 self ._broadcast_url = broadcast_url
14296 self ._broadcast_type = broadcast_type or Broadcast
143- # Publish broadcast (initialized within async with statement)
144- self ._sharing_broadcast_channel = None
145- # channel to operate on
14697 self ._channel = channel
147- # Async-io task for reading broadcasts (initialized within async with statement)
14898 self ._subscription_task = None
149- # Uniqueue instance id (used to avoid reading own notifications sent in broadcast)
15099 self ._id = gen_uid ()
151- # The internal events notifier
152100 self ._notifier = notifier
101+ self ._broadcast_channel = None
102+ self ._connect_lock = asyncio .Lock ()
103+ self ._listen_refcount = 0
104+ self ._share_refcount = 0
153105 self ._is_publish_only = is_publish_only
154- self ._publish_lock = None
155- # used to track creation / removal of resources needed per type (reader task->listen, and subscription to internal events->share)
156- self ._listen_count : int = 0
157- self ._share_count : int = 0
158- # If we opt to manage the context directly (i.e. call async with on the event broadcaster itself)
159- self ._context_manager = None
160- self ._context_manager_lock = asyncio .Lock ()
161- self ._tasks = set ()
162- self .listening_broadcast_channel = None
163106
164- async def __broadcast_notifications__ (self , subscription : Subscription , data ):
107+ async def connect (self , listen = True , share = True ):
165108 """
166- Share incoming internal notifications with the entire broadcast channel
167-
168- Args:
169- subscription (Subscription): the subscription that got triggered
170- data: the event data
109+ This connects the listening channel
171110 """
172- logger .info (
173- "Broadcasting incoming event: {}" .format (
174- {"topic" : subscription .topic , "notifier_id" : self ._id }
175- )
176- )
177- note = BroadcastNotification (
178- notifier_id = self ._id , topics = [subscription .topic ], data = data
179- )
111+ async with self ._connect_lock :
112+ if listen :
113+ await self ._connect_listen ()
114+ self ._listen_refcount += 1
180115
181- # Publish event to broadcast
182- async with self ._broadcast_type (
183- self ._broadcast_url
184- ) as sharing_broadcast_channel :
185- await sharing_broadcast_channel .publish (
186- self ._channel , pydantic_serialize (note )
187- )
116+ if share :
117+ await self ._connect_share ()
118+ self ._share_refcount += 1
188119
189- async def _subscribe_to_all_topics (self ):
190- return await self ._notifier .subscribe (
191- self ._id , ALL_TOPICS , self .__broadcast_notifications__
192- )
120+ async def close (self , listen = True , share = True ):
121+ async with self ._connect_lock :
122+ if listen :
123+ await self ._close_listen ()
124+ self ._listen_refcount -= 1
125+
126+ if share :
127+ await self ._close_share ()
128+ self ._share_refcount -= 1
129+
130+ async def __aenter__ (self ):
131+ await self .connect (listen = not self ._is_publish_only )
193132
194- async def _unsubscribe_from_topics (self ):
195- return await self ._notifier . unsubscribe ( self ._id )
133+ async def __aexit__ (self , exc_type , exc , tb ):
134+ await self .close ( listen = not self ._is_publish_only )
196135
197136 def get_context (self , listen = True , share = True ):
198137 """
@@ -213,97 +152,115 @@ def get_listening_context(self):
213152 def get_sharing_context (self ):
214153 return EventBroadcasterContextManager (self , listen = False , share = True )
215154
216- async def __aenter__ (self ):
155+ async def __broadcast_notifications__ (self , subscription : Subscription , data ):
217156 """
218- Convince caller (also backward compaltability)
157+ Share incoming internal notifications with the entire broadcast channel
158+
159+ Args:
160+ subscription (Subscription): the subscription that got triggered
161+ data: the event data
219162 """
220- if self ._context_manager is None :
221- self ._context_manager = self .get_context (listen = not self ._is_publish_only )
222- return await self ._context_manager .__aenter__ ()
163+ logger .info (
164+ "Broadcasting incoming event: {}" .format (
165+ {"topic" : subscription .topic , "notifier_id" : self ._id }
166+ )
167+ )
223168
224- async def __aexit__ (self , exc_type , exc , tb ):
225- await self ._context_manager .__aexit__ (exc_type , exc , tb )
169+ note = BroadcastNotification (
170+ notifier_id = self ._id , topics = [subscription .topic ], data = data
171+ )
226172
227- async def start_reader_task (self ):
228- """Spawn a task reading incoming broadcasts and posting them to the intreal notifier
229- Raises:
230- BroadcasterAlreadyStarted: if called more than once per context
231- Returns:
232- the spawned task
233- """
234- # Make sure a task wasn't started already
235- if self ._subscription_task is not None :
236- # we already started a task for this worker process
237- logger .debug (
238- "No need for listen task, already started broadcast listen task for this notifier"
173+ # Publish event to broadcast using a new connection from connection pool
174+ async with self ._broadcast_type (
175+ self ._broadcast_url
176+ ) as sharing_broadcast_channel :
177+ await sharing_broadcast_channel .publish (
178+ self ._channel , pydantic_serialize (note )
239179 )
240- return
241180
242- # Init new broadcast channel for reading
243- try :
244- if self .listening_broadcast_channel is None :
245- self .listening_broadcast_channel = self ._broadcast_type (
246- self ._broadcast_url
247- )
248- await self .listening_broadcast_channel .connect ()
249- except Exception as e :
250- logger .error (
251- f"Failed to connect to broadcast channel for reading incoming events: { e } "
181+ async def _connect_share (self ):
182+ if self ._share_refcount == 0 :
183+ return await self ._notifier .subscribe (
184+ self ._id , ALL_TOPICS , self .__broadcast_notifications__
252185 )
253- raise e
254186
255- # Trigger the task
256- logger .debug ("Spawning broadcast listen task" )
257- self ._subscription_task = asyncio .create_task (self .__read_notifications__ ())
258- return self ._subscription_task
187+ async def _close_share (self ):
188+ if self ._share_refcount == 1 :
189+ return await self ._notifier .unsubscribe (self ._id )
190+
191+ async def _connect_listen (self ):
192+ if self ._listen_refcount == 0 :
193+ if self ._listen_refcount == 0 :
194+ try :
195+ self ._broadcast_channel = self ._broadcast_type (self ._broadcast_url )
196+ await self ._broadcast_channel .connect ()
197+ except Exception as e :
198+ logger .error (
199+ f"Failed to connect to broadcast channel for reading incoming events: { e } "
200+ )
201+ raise e
202+ self ._subscription_task = asyncio .create_task (
203+ self .__read_notifications__ ()
204+ )
205+ return await self ._notifier .subscribe (
206+ self ._id , ALL_TOPICS , self .__broadcast_notifications__
207+ )
208+
209+ async def _close_listen (self ):
210+ if self ._listen_refcount == 1 and self ._broadcast_channel is not None :
211+ await self ._broadcast_channel .disconnect ()
212+ await self .wait_until_done ()
213+ self ._broadcast_channel = None
259214
260215 def get_reader_task (self ):
261216 return self ._subscription_task
262217
218+ async def wait_until_done (self ):
219+ if self ._subscription_task is not None :
220+ await self ._subscription_task
221+ self ._subscription_task = None
222+
263223 async def __read_notifications__ (self ):
264224 """
265225 read incoming broadcasts and posting them to the intreal notifier
266226 """
267227 logger .debug ("Starting broadcaster listener" )
228+
229+ notify_tasks = set ()
268230 try :
269231 # Subscribe to our channel
270- async with self .listening_broadcast_channel .subscribe (
232+ async with self ._broadcast_channel .subscribe (
271233 channel = self ._channel
272234 ) as subscriber :
273235 async for event in subscriber :
274- try :
275- notification = BroadcastNotification .parse_raw (event .message )
276- # Avoid re-publishing our own broadcasts
277- if notification .notifier_id != self ._id :
278- logger .debug (
279- "Handling incoming broadcast event: {}" .format (
280- {
281- "topics" : notification .topics ,
282- "src" : notification .notifier_id ,
283- }
284- )
236+ notification = BroadcastNotification .parse_raw (event .message )
237+ # Avoid re-publishing our own broadcasts
238+ if notification .notifier_id != self ._id :
239+ logger .debug (
240+ "Handling incoming broadcast event: {}" .format (
241+ {
242+ "topics" : notification .topics ,
243+ "src" : notification .notifier_id ,
244+ }
285245 )
286- # Notify subscribers of message received from broadcast
287- task = asyncio . create_task (
288- self . _notifier . notify (
289- notification . topics ,
290- notification .data ,
291- notifier_id = self . _id ,
292- )
246+ )
247+ # Notify subscribers of message received from broadcast
248+ task = asyncio . create_task (
249+ self . _notifier . notify (
250+ notification .topics ,
251+ notification . data ,
252+ notifier_id = self . _id ,
293253 )
254+ )
294255
295- self . _tasks .add (task )
256+ notify_tasks .add (task )
296257
297- def cleanup (task ):
298- self . _tasks . remove (task )
258+ def cleanup (t ):
259+ notify_tasks . remove (t )
299260
300- task .add_done_callback (cleanup )
301- except :
302- logger .exception ("Failed handling incoming broadcast" )
261+ task .add_done_callback (cleanup )
303262 logger .info (
304263 "No more events to read from subscriber (underlying connection closed)"
305264 )
306265 finally :
307- if self .listening_broadcast_channel is not None :
308- await self .listening_broadcast_channel .disconnect ()
309- self .listening_broadcast_channel = None
266+ await asyncio .gather (* notify_tasks , return_exceptions = True )
0 commit comments