77
88import asyncio
99import collections
10+ import functools
1011import getpass
1112import os
1213import pathlib
1314import platform
1415import re
1516import socket
17+ import ssl as ssl_module
1618import stat
1719import struct
1820import time
3234 'password' ,
3335 'database' ,
3436 'ssl' ,
37+ 'ssl_is_advisory' ,
3538 'connect_timeout' ,
3639 'server_settings' ,
3740 ])
@@ -208,6 +211,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
208211 if passfile is None :
209212 passfile = val
210213
214+ if 'sslmode' in query :
215+ val = query .pop ('sslmode' )
216+ if ssl is None :
217+ ssl = val
218+
211219 if query :
212220 if server_settings is None :
213221 server_settings = query
@@ -303,6 +311,47 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
303311 raise ValueError (
304312 'could not determine the database address to connect to' )
305313
314+ if ssl is None :
315+ ssl = os .getenv ('PGSSLMODE' )
316+
317+ # ssl_is_advisory is only allowed to come from the sslmode parameter.
318+ ssl_is_advisory = None
319+ if isinstance (ssl , str ):
320+ SSLMODES = {
321+ 'disable' : 0 ,
322+ 'allow' : 1 ,
323+ 'prefer' : 2 ,
324+ 'require' : 3 ,
325+ 'verify-ca' : 4 ,
326+ 'verify-full' : 5 ,
327+ }
328+ try :
329+ sslmode = SSLMODES [ssl ]
330+ except KeyError :
331+ modes = ', ' .join (SSLMODES .keys ())
332+ raise ValueError ('`sslmode` parameter must be one of ' + modes )
333+
334+ # sslmode 'allow' is currently handled as 'prefer' because we're
335+ # missing the "retry with SSL" behavior for 'allow', but do have the
336+ # "retry without SSL" behavior for 'prefer'.
337+ # Not changing 'allow' to 'prefer' here would be effectively the same
338+ # as changing 'allow' to 'disable'.
339+ if sslmode == SSLMODES ['allow' ]:
340+ sslmode = SSLMODES ['prefer' ]
341+
342+ # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
343+ # Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
344+ if sslmode <= SSLMODES ['allow' ]:
345+ ssl = False
346+ ssl_is_advisory = sslmode >= SSLMODES ['allow' ]
347+ else :
348+ ssl = ssl_module .create_default_context ()
349+ ssl .check_hostname = sslmode >= SSLMODES ['verify-full' ]
350+ ssl .verify_mode = ssl_module .CERT_REQUIRED
351+ if sslmode <= SSLMODES ['require' ]:
352+ ssl .verify_mode = ssl_module .CERT_NONE
353+ ssl_is_advisory = sslmode <= SSLMODES ['prefer' ]
354+
306355 if ssl :
307356 for addr in addrs :
308357 if isinstance (addr , str ):
@@ -321,7 +370,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
321370
322371 params = _ConnectionParameters (
323372 user = user , password = password , database = database , ssl = ssl ,
324- connect_timeout = connect_timeout , server_settings = server_settings )
373+ ssl_is_advisory = ssl_is_advisory , connect_timeout = connect_timeout ,
374+ server_settings = server_settings )
325375
326376 return addrs , params
327377
@@ -384,11 +434,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
384434
385435 if isinstance (addr , str ):
386436 # UNIX socket
387- assert params .ssl is None
437+ assert not params .ssl
388438 connector = loop .create_unix_connection (proto_factory , addr )
389439 elif params .ssl :
390440 connector = _create_ssl_connection (
391- proto_factory , * addr , loop = loop , ssl_context = params .ssl )
441+ proto_factory , * addr , loop = loop , ssl_context = params .ssl ,
442+ ssl_is_advisory = params .ssl_is_advisory )
392443 else :
393444 connector = loop .create_connection (proto_factory , * addr )
394445
@@ -435,7 +486,12 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
435486 raise last_error
436487
437488
438- async def _get_ssl_ready_socket (host , port , * , loop ):
489+ async def _negotiate_ssl_connection (host , port , conn_factory , * , loop , ssl ,
490+ server_hostname , ssl_is_advisory = False ):
491+ # Note: ssl_is_advisory only affects behavior when the server does not
492+ # accept SSLRequests. If the SSLRequest is accepted but either the SSL
493+ # negotiation fails or the PostgreSQL user isn't permitted to use SSL,
494+ # there's nothing that would attempt to reconnect with a non-SSL socket.
439495 reader , writer = await asyncio .open_connection (host , port , loop = loop )
440496
441497 tr = writer .transport
@@ -448,44 +504,55 @@ async def _get_ssl_ready_socket(host, port, *, loop):
448504 resp = await reader .readexactly (1 )
449505
450506 if resp == b'S' :
451- return sock .dup ()
507+ conn_factory = functools .partial (
508+ conn_factory , ssl = ssl , server_hostname = server_hostname )
509+ elif (ssl_is_advisory and
510+ ssl .verify_mode == ssl_module .CERT_NONE and
511+ resp == b'N' ):
512+ # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
513+ # since the only way to get ssl_is_advisory is from sslmode=prefer
514+ # (or sslmode=allow). But be extra sure to disallow insecure
515+ # connections when the ssl context asks for real security.
516+ pass
452517 else :
453518 raise ConnectionError (
454519 'PostgreSQL server at "{}:{}" rejected SSL upgrade' .format (
455520 host , port ))
521+
522+ sock = sock .dup () # Must come before tr.close()
456523 finally :
457524 tr .close ()
458525
459-
460- async def _create_ssl_connection (protocol_factory , host , port , * ,
461- loop , ssl_context ):
462- sock = await _get_ssl_ready_socket (host , port , loop = loop )
463526 try :
464- return await loop .create_connection (
465- protocol_factory , sock = sock , ssl = ssl_context ,
466- server_hostname = host )
527+ return await conn_factory (sock = sock ) # Must come after tr.close()
467528 except Exception :
468529 sock .close ()
469530 raise
470531
471532
533+ async def _create_ssl_connection (protocol_factory , host , port , * ,
534+ loop , ssl_context , ssl_is_advisory = False ):
535+ return await _negotiate_ssl_connection (
536+ host , port ,
537+ functools .partial (loop .create_connection , protocol_factory ),
538+ loop = loop ,
539+ ssl = ssl_context ,
540+ server_hostname = host ,
541+ ssl_is_advisory = ssl_is_advisory )
542+
543+
472544async def _open_connection (* , loop , addr , params : _ConnectionParameters ):
473545 if isinstance (addr , str ):
474546 r , w = await asyncio .open_unix_connection (addr , loop = loop )
475547 else :
476548 if params .ssl :
477- sock = await _get_ssl_ready_socket (* addr , loop = loop )
478-
479- try :
480- r , w = await asyncio .open_connection (
481- sock = sock ,
482- loop = loop ,
483- ssl = params .ssl ,
484- server_hostname = addr [0 ])
485- except Exception :
486- sock .close ()
487- raise
488-
549+ r , w = await _negotiate_ssl_connection (
550+ * addr ,
551+ functools .partial (asyncio .open_connection , loop = loop ),
552+ loop = loop ,
553+ ssl = params .ssl ,
554+ server_hostname = addr [0 ],
555+ ssl_is_advisory = params .ssl_is_advisory )
489556 else :
490557 r , w = await asyncio .open_connection (* addr , loop = loop )
491558 _set_nodelay (_get_socket (w .transport ))
0 commit comments