1313import os
1414import pathlib
1515import platform
16+ import random
1617import re
1718import socket
1819import ssl as ssl_module
@@ -56,6 +57,7 @@ def parse(cls, sslmode):
5657 'direct_tls' ,
5758 'connect_timeout' ,
5859 'server_settings' ,
60+ 'target_session_attrs' ,
5961 ])
6062
6163
@@ -260,7 +262,8 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
260262
261263def _parse_connect_dsn_and_args (* , dsn , host , port , user ,
262264 password , passfile , database , ssl ,
263- direct_tls , connect_timeout , server_settings ):
265+ direct_tls , connect_timeout , server_settings ,
266+ target_session_attrs ):
264267 # `auth_hosts` is the version of host information for the purposes
265268 # of reading the pgpass file.
266269 auth_hosts = None
@@ -607,10 +610,28 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
607610 'server_settings is expected to be None or '
608611 'a Dict[str, str]' )
609612
613+ if target_session_attrs is None :
614+
615+ target_session_attrs = os .getenv (
616+ "PGTARGETSESSIONATTRS" , SessionAttribute .any
617+ )
618+ try :
619+
620+ target_session_attrs = SessionAttribute (target_session_attrs )
621+ except ValueError as exc :
622+ raise exceptions .InterfaceError (
623+ "target_session_attrs is expected to be one of "
624+ "{!r}"
625+ ", got {!r}" .format (
626+ SessionAttribute .__members__ .values , target_session_attrs
627+ )
628+ ) from exc
629+
610630 params = _ConnectionParameters (
611631 user = user , password = password , database = database , ssl = ssl ,
612632 sslmode = sslmode , direct_tls = direct_tls ,
613- connect_timeout = connect_timeout , server_settings = server_settings )
633+ connect_timeout = connect_timeout , server_settings = server_settings ,
634+ target_session_attrs = target_session_attrs )
614635
615636 return addrs , params
616637
@@ -620,8 +641,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
620641 statement_cache_size ,
621642 max_cached_statement_lifetime ,
622643 max_cacheable_statement_size ,
623- ssl , direct_tls , server_settings ):
624-
644+ ssl , direct_tls , server_settings ,
645+ target_session_attrs ):
625646 local_vars = locals ()
626647 for var_name in {'max_cacheable_statement_size' ,
627648 'max_cached_statement_lifetime' ,
@@ -649,7 +670,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
649670 dsn = dsn , host = host , port = port , user = user ,
650671 password = password , passfile = passfile , ssl = ssl ,
651672 direct_tls = direct_tls , database = database ,
652- connect_timeout = timeout , server_settings = server_settings )
673+ connect_timeout = timeout , server_settings = server_settings ,
674+ target_session_attrs = target_session_attrs )
653675
654676 config = _ClientConfiguration (
655677 command_timeout = command_timeout ,
@@ -882,18 +904,84 @@ async def __connect_addr(
882904 return con
883905
884906
907+ class SessionAttribute (str , enum .Enum ):
908+ any = 'any'
909+ primary = 'primary'
910+ standby = 'standby'
911+ prefer_standby = 'prefer-standby'
912+ read_write = "read-write"
913+ read_only = "read-only"
914+
915+
916+ def _accept_in_hot_standby (should_be_in_hot_standby : bool ):
917+ """
918+ If the server didn't report "in_hot_standby" at startup, we must determine
919+ the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
920+ If the server allows a connection and states it is in recovery it must
921+ be a replica/standby server.
922+ """
923+ async def can_be_used (connection ):
924+ settings = connection .get_settings ()
925+ hot_standby_status = getattr (settings , 'in_hot_standby' , None )
926+ if hot_standby_status is not None :
927+ is_in_hot_standby = hot_standby_status == 'on'
928+ else :
929+ is_in_hot_standby = await connection .fetchval (
930+ "SELECT pg_catalog.pg_is_in_recovery()"
931+ )
932+ return is_in_hot_standby == should_be_in_hot_standby
933+
934+ return can_be_used
935+
936+
937+ def _accept_read_only (should_be_read_only : bool ):
938+ """
939+ Verify the server has not set default_transaction_read_only=True
940+ """
941+ async def can_be_used (connection ):
942+ settings = connection .get_settings ()
943+ is_readonly = getattr (settings , 'default_transaction_read_only' , 'off' )
944+
945+ if is_readonly == "on" :
946+ return should_be_read_only
947+
948+ return await _accept_in_hot_standby (should_be_read_only )(connection )
949+ return can_be_used
950+
951+
952+ async def _accept_any (_ ):
953+ return True
954+
955+
956+ target_attrs_check = {
957+ SessionAttribute .any : _accept_any ,
958+ SessionAttribute .primary : _accept_in_hot_standby (False ),
959+ SessionAttribute .standby : _accept_in_hot_standby (True ),
960+ SessionAttribute .prefer_standby : _accept_in_hot_standby (True ),
961+ SessionAttribute .read_write : _accept_read_only (False ),
962+ SessionAttribute .read_only : _accept_read_only (True ),
963+ }
964+
965+
966+ async def _can_use_connection (connection , attr : SessionAttribute ):
967+ can_use = target_attrs_check [attr ]
968+ return await can_use (connection )
969+
970+
885971async def _connect (* , loop , timeout , connection_class , record_class , ** kwargs ):
886972 if loop is None :
887973 loop = asyncio .get_event_loop ()
888974
889975 addrs , params , config = _parse_connect_arguments (timeout = timeout , ** kwargs )
976+ target_attr = params .target_session_attrs
890977
978+ candidates = []
979+ chosen_connection = None
891980 last_error = None
892- addr = None
893981 for addr in addrs :
894982 before = time .monotonic ()
895983 try :
896- return await _connect_addr (
984+ conn = await _connect_addr (
897985 addr = addr ,
898986 loop = loop ,
899987 timeout = timeout ,
@@ -902,12 +990,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
902990 connection_class = connection_class ,
903991 record_class = record_class ,
904992 )
993+ candidates .append (conn )
994+ if await _can_use_connection (conn , target_attr ):
995+ chosen_connection = conn
996+ break
905997 except (OSError , asyncio .TimeoutError , ConnectionError ) as ex :
906998 last_error = ex
907999 finally :
9081000 timeout -= time .monotonic () - before
1001+ else :
1002+ if target_attr == SessionAttribute .prefer_standby and candidates :
1003+ chosen_connection = random .choice (candidates )
1004+
1005+ await asyncio .gather (
1006+ (c .close () for c in candidates if c is not chosen_connection ),
1007+ return_exceptions = True
1008+ )
1009+
1010+ if chosen_connection :
1011+ return chosen_connection
9091012
910- raise last_error
1013+ raise last_error or exceptions .TargetServerAttributeNotMatched (
1014+ 'None of the hosts match the target attribute requirement '
1015+ '{!r}' .format (target_attr )
1016+ )
9111017
9121018
9131019async def _cancel (* , loop , addr , params : _ConnectionParameters ,
0 commit comments