@@ -29,7 +29,7 @@ class Connection:
2929 __slots__ = ('_protocol' , '_transport' , '_loop' , '_types_stmt' ,
3030 '_type_by_name_stmt' , '_top_xact' , '_uid' , '_aborted' ,
3131 '_stmt_cache_max_size' , '_stmt_cache' , '_stmts_to_close' ,
32- '_addr' , '_opts' , '_command_timeout' )
32+ '_addr' , '_opts' , '_command_timeout' , '_listeners' )
3333
3434 def __init__ (self , protocol , transport , loop , addr , opts , * ,
3535 statement_cache_size , command_timeout ):
@@ -51,7 +51,44 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
5151
5252 self ._command_timeout = command_timeout
5353
54+ self ._listeners = {}
55+
56+ async def add_listener (self , channel , callback ):
57+ """Add a listener for Postgres notifications.
58+
59+ :param str channel: Channel to listen on.
60+ :param callable callback:
61+ A callable receiving the following arguments:
62+ **connection**: a Connection the callback is registered with;
63+ **pid**: PID of the Postgres server that sent the notification;
64+ **channel**: name of the channel the notification was sent to;
65+ **payload**: the payload.
66+ """
67+ if channel not in self ._listeners :
68+ await self .fetch ('LISTEN {}' .format (channel ))
69+ self ._listeners [channel ] = set ()
70+ self ._listeners [channel ].add (callback )
71+
72+ async def remove_listener (self , channel , callback ):
73+ """Remove a listening callback on the specified channel."""
74+ if channel not in self ._listeners :
75+ return
76+ if callback not in self ._listeners [channel ]:
77+ return
78+ self ._listeners [channel ].remove (callback )
79+ if not self ._listeners [channel ]:
80+ del self ._listeners [channel ]
81+ await self .fetch ('UNLISTEN {}' .format (channel ))
82+
83+ def get_server_pid (self ):
84+ """Return the PID of the Postgres server the connection is bound to."""
85+ return self ._protocol .get_server_pid ()
86+
5487 def get_settings (self ):
88+ """Return connection settings.
89+
90+ :return: :class:`~asyncpg.ConnectionSettings`.
91+ """
5592 return self ._protocol .get_settings ()
5693
5794 def transaction (self , * , isolation = 'read_committed' , readonly = False ,
@@ -269,17 +306,20 @@ async def close(self):
269306 if self .is_closed ():
270307 return
271308 self ._close_stmts ()
309+ self ._listeners = {}
272310 self ._aborted = True
273311 protocol = self ._protocol
274312 await protocol .close ()
275313
276314 def terminate (self ):
277315 """Terminate the connection without waiting for pending data."""
278316 self ._close_stmts ()
317+ self ._listeners = {}
279318 self ._aborted = True
280319 self ._protocol .abort ()
281320
282321 async def reset (self ):
322+ self ._listeners = {}
283323 await self .execute ('''
284324 SET SESSION AUTHORIZATION DEFAULT;
285325 RESET ALL;
@@ -351,6 +391,20 @@ async def cancel():
351391
352392 self ._loop .create_task (cancel ())
353393
394+ def _notify (self , pid , channel , payload ):
395+ if channel not in self ._listeners :
396+ return
397+
398+ for cb in self ._listeners [channel ]:
399+ try :
400+ cb (self , pid , channel , payload )
401+ except Exception as ex :
402+ self ._loop .call_exception_handler ({
403+ 'message' : 'Unhandled exception in asyncpg notification '
404+ 'listener callback {!r}' .format (cb ),
405+ 'exception' : ex
406+ })
407+
354408
355409async def connect (dsn = None , * ,
356410 host = None , port = None ,
0 commit comments