Skip to content

Commit 3846828

Browse files
author
Andrey Zelenchuk
committed
Fix losing initial payload because of the race.
1 parent b6a9d53 commit 3846828

File tree

1 file changed

+19
-26
lines changed

1 file changed

+19
-26
lines changed

channels_graphql_ws/graphql_ws_consumer.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -706,9 +706,6 @@ async def _register_subscription(
706706
# `_sids_by_group` without any locks.
707707
self._assert_thread()
708708

709-
# The subject we will trigger on the `broadcast` message.
710-
trigger = rx.subjects.Subject()
711-
712709
# The subscription notification queue.
713710
notification_queue = asyncio.Queue(
714711
maxsize=self.subscription_notification_queue_limit
@@ -720,56 +717,41 @@ async def _register_subscription(
720717

721718
# Start an endless task which listens the `notification_queue`
722719
# and invokes subscription "resolver" on new notifications.
723-
async def notifier():
720+
async def notifier(observer: rx.Observer):
724721
"""Watch the notification queue and notify clients."""
725722

726723
# Assert we run in a proper thread.
727724
self._assert_thread()
728-
729-
# Dirty hack to partially workaround the race between:
730-
# 1) call to `result.subscribe` in `_on_gql_start`; and
731-
# 2) call to `trigger.on_next` below in this function.
732-
# The first call must be earlier. Otherwise, first one or more notifications
733-
# may be lost.
734-
await asyncio.sleep(1)
735-
736725
while True:
737726
serialized_payload = await notification_queue.get()
738727

739728
# Run a subscription's `publish` method (invoked by the
740-
# `trigger.on_next` function) within the threadpool used
729+
# `observer.on_next` function) within the threadpool used
741730
# for processing other GraphQL resolver functions.
742731
# NOTE: it is important to run the deserialization
743732
# in the worker thread as well.
744733
def workload():
745734
try:
746735
payload = Serializer.deserialize(serialized_payload)
747736
except Exception as ex: # pylint: disable=broad-except
748-
trigger.on_error(f"Cannot deserialize payload. {ex}")
737+
observer.on_error(f"Cannot deserialize payload. {ex}")
749738
else:
750-
trigger.on_next(payload)
739+
observer.on_next(payload)
751740

752741
await self._run_in_worker(workload)
753742

754743
# Message processed. This allows `Queue.join` to work.
755744
notification_queue.task_done()
756745

757-
# Enqueue the `publish` method execution. But do not notify
758-
# clients when `publish` returns `SKIP`.
759-
stream = trigger.map(publish_callback).filter( # pylint: disable=no-member
760-
lambda publish_returned: publish_returned is not self.SKIP
761-
)
762-
746+
def push_payloads(observer: rx.Observer):
763747
# Start listening for broadcasts (subscribe to the Channels
764748
# groups), spawn the notification processing task and put
765749
# subscription information into the registry.
766750
# NOTE: Update of `_sids_by_group` & `_subscriptions` must be
767751
# atomic i.e. without `awaits` in between.
768-
waitlist = []
769752
for group in groups:
770753
self._sids_by_group.setdefault(group, []).append(operation_id)
771-
waitlist.append(self._channel_layer.group_add(group, self.channel_name))
772-
notifier_task = self._spawn_background_task(notifier())
754+
notifier_task = self._spawn_background_task(notifier(observer))
773755
self._subscriptions[operation_id] = self._SubInf(
774756
groups=groups,
775757
sid=operation_id,
@@ -778,9 +760,20 @@ def workload():
778760
notifier_task=notifier_task,
779761
)
780762

781-
await asyncio.wait(waitlist)
763+
await asyncio.wait(
764+
[
765+
self._channel_layer.group_add(group, self.channel_name)
766+
for group in groups
767+
]
768+
)
782769

783-
return stream
770+
# Enqueue the `publish` method execution. But do not notify
771+
# clients when `publish` returns `SKIP`.
772+
return (
773+
rx.Observable.create(push_payloads) # pylint: disable=no-member
774+
.map(publish_callback)
775+
.filter(lambda publish_returned: publish_returned is not self.SKIP)
776+
)
784777

785778
async def _on_gql_stop(self, operation_id):
786779
"""Process the STOP message.

0 commit comments

Comments
 (0)