Skip to content

Commit d75208d

Browse files
committed
Refactor EventBroadcaster
1 parent 671c189 commit d75208d

File tree

2 files changed

+98
-194
lines changed

2 files changed

+98
-194
lines changed
Lines changed: 86 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
from typing import Any
34

45
from broadcaster import Broadcast
@@ -30,84 +31,6 @@ class BroadcasterAlreadyStarted(EventBroadcasterException):
3031
pass
3132

3233

33-
class EventBroadcasterContextManager:
34-
"""
35-
Manages the context for the EventBroadcaster
36-
Friend-like class of EventBroadcaster (accessing "protected" members )
37-
"""
38-
39-
def __init__(
40-
self,
41-
event_broadcaster: "EventBroadcaster",
42-
listen: bool = True,
43-
share: bool = True,
44-
) -> None:
45-
"""
46-
Provide a context manager for an EventBroadcaster, managing if it listens to events coming from the broadcaster
47-
and if it subscribes to the internal notifier to share its events with the broadcaster
48-
49-
Args:
50-
event_broadcaster (EventBroadcaster): the broadcaster we manage the context for.
51-
share (bool, optional): Should we share events with the broadcaster. Defaults to True.
52-
listen (bool, optional): Should we listen for incoming events from the broadcaster. Defaults to True.
53-
"""
54-
self._event_broadcaster = event_broadcaster
55-
self._share: bool = share
56-
self._listen: bool = listen
57-
58-
async def __aenter__(self):
59-
async with self._event_broadcaster._context_manager_lock:
60-
if self._listen:
61-
self._event_broadcaster._listen_count += 1
62-
if self._event_broadcaster._listen_count == 1:
63-
# We have our first listener start the read-task for it (And all those who'd follow)
64-
logger.info(
65-
"Listening for incoming events from broadcast channel (first listener started)"
66-
)
67-
# Start task listening on incoming broadcasts
68-
await self._event_broadcaster.start_reader_task()
69-
70-
if self._share:
71-
self._event_broadcaster._share_count += 1
72-
if self._event_broadcaster._share_count == 1:
73-
# We have our first publisher
74-
# Init the broadcast used for sharing (reading has its own)
75-
logger.debug(
76-
"Subscribing to ALL_TOPICS, and sharing messages with broadcast channel"
77-
)
78-
# Subscribe to internal events form our own event notifier and broadcast them
79-
await self._event_broadcaster._subscribe_to_all_topics()
80-
else:
81-
logger.debug(
82-
f"Did not subscribe to ALL_TOPICS: share count == {self._event_broadcaster._share_count}"
83-
)
84-
return self
85-
86-
async def __aexit__(self, exc_type, exc, tb):
87-
async with self._event_broadcaster._context_manager_lock:
88-
try:
89-
if self._listen:
90-
self._event_broadcaster._listen_count -= 1
91-
# if this was last listener - we can stop the reading task
92-
if self._event_broadcaster._listen_count == 0:
93-
# Cancel task reading broadcast subscriptions
94-
if self._event_broadcaster._subscription_task is not None:
95-
logger.info("Cancelling broadcast listen task")
96-
self._event_broadcaster._subscription_task.cancel()
97-
self._event_broadcaster._subscription_task = None
98-
99-
if self._share:
100-
self._event_broadcaster._share_count -= 1
101-
# if this was last sharer - we can stop subscribing to internal events - we aren't sharing anymore
102-
if self._event_broadcaster._share_count == 0:
103-
# Unsubscribe from internal events
104-
logger.debug("Unsubscribing from ALL TOPICS")
105-
await self._event_broadcaster._unsubscribe_from_topics()
106-
107-
except:
108-
logger.exception("Failed to exit EventBroadcaster context")
109-
110-
11134
class EventBroadcaster:
11235
"""
11336
Bridge EventNotifier to work across processes and machines by sharing their events through a broadcasting channel
@@ -135,31 +58,57 @@ def __init__(
13558
notifier (EventNotifier): the event notifier managing our internal events - which will be bridge via the broadcaster
13659
channel (str, optional): Channel name. Defaults to "EventNotifier".
13760
broadcast_type (Broadcast, optional): Broadcast class to use. None - Defaults to Broadcast.
138-
is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False
61+
DEPRECATED is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False
62+
# TODO: Like that?
13963
"""
140-
# Broadcast init params
14164
self._broadcast_url = broadcast_url
14265
self._broadcast_type = broadcast_type or Broadcast
143-
# Publish broadcast (initialized within async with statement)
144-
self._sharing_broadcast_channel = None
145-
# channel to operate on
14666
self._channel = channel
147-
# Async-io task for reading broadcasts (initialized within async with statement)
14867
self._subscription_task = None
149-
# Uniqueue instance id (used to avoid reading own notifications sent in broadcast)
15068
self._id = gen_uid()
151-
# The internal events notifier
15269
self._notifier = notifier
153-
self._is_publish_only = is_publish_only
154-
self._publish_lock = None
155-
# used to track creation / removal of resources needed per type (reader task->listen, and subscription to internal events->share)
156-
self._listen_count: int = 0
157-
self._share_count: int = 0
158-
# If we opt to manage the context directly (i.e. call async with on the event broadcaster itself)
159-
self._context_manager = None
160-
self._context_manager_lock = asyncio.Lock()
161-
self._tasks = set()
162-
self.listening_broadcast_channel = None
70+
self._broadcast_channel = None
71+
self._connect_lock = asyncio.Lock()
72+
self._refcount = 0
73+
is_publish_only = is_publish_only # Depracated
74+
75+
async def connect(self):
76+
async with self._connect_lock:
77+
if self._refcount == 0: # TODO: Is that needed?
78+
try:
79+
self._broadcast_channel = self._broadcast_type(self._broadcast_url)
80+
await self._broadcast_channel.connect()
81+
except Exception as e:
82+
logger.error(
83+
f"Failed to connect to broadcast channel for reading incoming events: {e}"
84+
)
85+
raise e
86+
await self._subscribe_notifier()
87+
self._subscription_task = asyncio.create_task(
88+
self.__read_notifications__()
89+
)
90+
self._refcount += 1
91+
92+
async def _close(self):
93+
if self._broadcast_channel is not None:
94+
await self._unsubscribe_notifier()
95+
await self._broadcast_channel.disconnect()
96+
await self.wait_until_done()
97+
self._broadcast_channel = None
98+
99+
async def close(self):
100+
async with self._connect_lock:
101+
if self._refcount == 0:
102+
return
103+
self._refcount -= 1
104+
if self._refcount == 0:
105+
await self._close()
106+
107+
async def __aenter__(self):
108+
await self.connect()
109+
110+
async def __aexit__(self, exc_type, exc, tb):
111+
await self.close()
163112

164113
async def __broadcast_notifications__(self, subscription: Subscription, data):
165114
"""
@@ -174,136 +123,89 @@ async def __broadcast_notifications__(self, subscription: Subscription, data):
174123
{"topic": subscription.topic, "notifier_id": self._id}
175124
)
176125
)
126+
177127
note = BroadcastNotification(
178128
notifier_id=self._id, topics=[subscription.topic], data=data
179129
)
180130

181-
# Publish event to broadcast
131+
# Publish event to broadcast using a new connection from connection pool
182132
async with self._broadcast_type(
183133
self._broadcast_url
184134
) as sharing_broadcast_channel:
185135
await sharing_broadcast_channel.publish(
186136
self._channel, pydantic_serialize(note)
187137
)
188138

189-
async def _subscribe_to_all_topics(self):
139+
async def _subscribe_notifier(self):
190140
return await self._notifier.subscribe(
191141
self._id, ALL_TOPICS, self.__broadcast_notifications__
192142
)
193143

194-
async def _unsubscribe_from_topics(self):
144+
async def _unsubscribe_notifier(self):
195145
return await self._notifier.unsubscribe(self._id)
196146

197147
def get_context(self, listen=True, share=True):
198-
"""
199-
Create a new context manager you can call 'async with' on, configuring the broadcaster for listening, sharing, or both.
200-
201-
Args:
202-
listen (bool, optional): Should we listen for events incoming from the broadcast channel. Defaults to True.
203-
share (bool, optional): Should we share events with the broadcast channel. Defaults to True.
204-
205-
Returns:
206-
EventBroadcasterContextManager: the context
207-
"""
208-
return EventBroadcasterContextManager(self, listen=listen, share=share)
148+
"""Backward compatibility for the old interface"""
149+
return self
209150

210151
def get_listening_context(self):
211-
return EventBroadcasterContextManager(self, listen=True, share=False)
152+
"""Backward compatibility for the old interface"""
153+
return self
212154

213155
def get_sharing_context(self):
214-
return EventBroadcasterContextManager(self, listen=False, share=True)
215-
216-
async def __aenter__(self):
217-
"""
218-
Convince caller (also backward compaltability)
219-
"""
220-
if self._context_manager is None:
221-
self._context_manager = self.get_context(listen=not self._is_publish_only)
222-
return await self._context_manager.__aenter__()
223-
224-
async def __aexit__(self, exc_type, exc, tb):
225-
await self._context_manager.__aexit__(exc_type, exc, tb)
226-
227-
async def start_reader_task(self):
228-
"""Spawn a task reading incoming broadcasts and posting them to the intreal notifier
229-
Raises:
230-
BroadcasterAlreadyStarted: if called more than once per context
231-
Returns:
232-
the spawned task
233-
"""
234-
# Make sure a task wasn't started already
235-
if self._subscription_task is not None:
236-
# we already started a task for this worker process
237-
logger.debug(
238-
"No need for listen task, already started broadcast listen task for this notifier"
239-
)
240-
return
241-
242-
# Init new broadcast channel for reading
243-
try:
244-
if self.listening_broadcast_channel is None:
245-
self.listening_broadcast_channel = self._broadcast_type(
246-
self._broadcast_url
247-
)
248-
await self.listening_broadcast_channel.connect()
249-
except Exception as e:
250-
logger.error(
251-
f"Failed to connect to broadcast channel for reading incoming events: {e}"
252-
)
253-
raise e
254-
255-
# Trigger the task
256-
logger.debug("Spawning broadcast listen task")
257-
self._subscription_task = asyncio.create_task(self.__read_notifications__())
258-
return self._subscription_task
156+
"""Backward compatibility for the old interface"""
157+
return self
259158

260159
def get_reader_task(self):
261160
return self._subscription_task
262161

162+
async def wait_until_done(self):
163+
if self._subscription_task is not None:
164+
await self._subscription_task
165+
self._subscription_task = None
166+
263167
async def __read_notifications__(self):
264168
"""
265169
read incoming broadcasts and posting them to the intreal notifier
266170
"""
267171
logger.debug("Starting broadcaster listener")
172+
173+
notify_tasks = set()
268174
try:
269175
# Subscribe to our channel
270-
async with self.listening_broadcast_channel.subscribe(
176+
async with self._broadcast_channel.subscribe(
271177
channel=self._channel
272178
) as subscriber:
273179
async for event in subscriber:
274-
try:
275-
notification = BroadcastNotification.parse_raw(event.message)
276-
# Avoid re-publishing our own broadcasts
277-
if notification.notifier_id != self._id:
278-
logger.debug(
279-
"Handling incoming broadcast event: {}".format(
280-
{
281-
"topics": notification.topics,
282-
"src": notification.notifier_id,
283-
}
284-
)
180+
notification = BroadcastNotification.parse_raw(event.message)
181+
# Avoid re-publishing our own broadcasts
182+
if notification.notifier_id != self._id:
183+
logger.debug(
184+
"Handling incoming broadcast event: {}".format(
185+
{
186+
"topics": notification.topics,
187+
"src": notification.notifier_id,
188+
}
285189
)
286-
# Notify subscribers of message received from broadcast
287-
task = asyncio.create_task(
288-
self._notifier.notify(
289-
notification.topics,
290-
notification.data,
291-
notifier_id=self._id,
292-
)
190+
)
191+
# Notify subscribers of message received from broadcast
192+
task = asyncio.create_task(
193+
self._notifier.notify(
194+
notification.topics,
195+
notification.data,
196+
notifier_id=self._id,
293197
)
198+
)
294199

295-
self._tasks.add(task)
200+
notify_tasks.add(task)
296201

297-
def cleanup(task):
298-
self._tasks.remove(task)
202+
def cleanup(t):
203+
notify_tasks.remove(t)
299204

300-
task.add_done_callback(cleanup)
301-
except:
302-
logger.exception("Failed handling incoming broadcast")
205+
task.add_done_callback(cleanup)
303206
logger.info(
304207
"No more events to read from subscriber (underlying connection closed)"
305208
)
306209
finally:
307-
if self.listening_broadcast_channel is not None:
308-
await self.listening_broadcast_channel.disconnect()
309-
self.listening_broadcast_channel = None
210+
# TODO: return_exceptions?
211+
await asyncio.gather(*notify_tasks, return_exceptions=True)

fastapi_websocket_pubsub/pub_sub_server.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
on_connect: List[Coroutine] = None,
3535
on_disconnect: List[Coroutine] = None,
3636
rpc_channel_get_remote_id: bool = False,
37-
ignore_broadcaster_disconnected = True,
37+
ignore_broadcaster_disconnected=True,
3838
):
3939
"""
4040
The PubSub endpoint recives subscriptions from clients and publishes data back to them upon receiving relevant publications.
@@ -65,7 +65,7 @@ def __init__(
6565
broadcaster
6666
if isinstance(broadcaster, EventBroadcaster) or broadcaster is None
6767
else EventBroadcaster(broadcaster, self.notifier)
68-
)
68+
) # TODO: Connect broadcaster if needed
6969
self.methods = (
7070
methods_class(self.notifier)
7171
if methods_class is not None
@@ -132,21 +132,23 @@ async def main_loop(self, websocket: WebSocket, client_id: str = None, **kwargs)
132132
async with self.broadcaster:
133133
logger.debug("Entering endpoint's main loop with broadcaster")
134134
if self._ignore_broadcaster_disconnected:
135-
await self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
135+
await self.endpoint.main_loop(
136+
websocket, client_id=client_id, **kwargs
137+
)
136138
else:
137139
main_loop_task = asyncio.create_task(
138-
self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
140+
self.endpoint.main_loop(
141+
websocket, client_id=client_id, **kwargs
142+
)
143+
)
144+
done, pending = await asyncio.wait(
145+
[main_loop_task, self.broadcaster.get_reader_task()],
146+
return_when=asyncio.FIRST_COMPLETED,
139147
)
140-
done, pending = await asyncio.wait([main_loop_task,
141-
self.broadcaster.get_reader_task()],
142-
return_when=asyncio.FIRST_COMPLETED)
143148
logger.debug(f"task is done: {done}")
144149
# broadcaster's reader task is used by other endpoints and shouldn't be cancelled
145150
if main_loop_task in pending:
146151
main_loop_task.cancel()
147-
else:
148-
logger.debug("Entering endpoint's main loop without broadcaster")
149-
await self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
150152

151153
logger.debug("Leaving endpoint's main loop")
152154

0 commit comments

Comments
 (0)