2121
2222from __future__ import division
2323
24+ from base64 import b64encode
2425from collections import deque
2526from io import BytesIO
2627import logging
27- from os import environ
28+ from os import makedirs , open as os_open , write as os_write , close as os_close , O_CREAT , O_APPEND , O_WRONLY
29+ from os .path import dirname , isfile
2830from select import select
2931from socket import create_connection , SHUT_RDWR
32+ from ssl import HAS_SNI , SSLError
3033from struct import pack as struct_pack , unpack as struct_unpack , unpack_from as struct_unpack_from
3134
32- from ..meta import version
33- from .compat import hex2 , secure_socket
35+ from .constants import DEFAULT_PORT , DEFAULT_USER_AGENT , KNOWN_HOSTS , MAGIC_PREAMBLE , \
36+ SECURITY_NONE , SECURITY_TRUST_ON_FIRST_USE
37+ from .compat import hex2
3438from .exceptions import ProtocolError
3539from .packstream import Packer , Unpacker
3640
3741
38- DEFAULT_PORT = 7687
39- DEFAULT_USER_AGENT = "neo4j-python/%s" % version
40-
41- MAGIC_PREAMBLE = 0x6060B017
42-
4342# Signature bytes for each message type
4443INIT = b"\x01 " # 0000 0001 // INIT <user_agent>
4544RESET = b"\x0F " # 0000 1111 // RESET
@@ -211,14 +210,18 @@ def __init__(self, sock, **config):
211210 user_agent = config .get ("user_agent" , DEFAULT_USER_AGENT )
212211 if isinstance (user_agent , bytes ):
213212 user_agent = user_agent .decode ("UTF-8" )
213+ self .user_agent = user_agent
214+
215+ # Pick up the server certificate, if any
216+ self .der_encoded_server_certificate = config .get ("der_encoded_server_certificate" )
214217
215218 def on_failure (metadata ):
216219 raise ProtocolError ("Initialisation failed" )
217220
218221 response = Response (self )
219222 response .on_failure = on_failure
220223
221- self .append (INIT , (user_agent ,), response = response )
224+ self .append (INIT , (self . user_agent ,), response = response )
222225 self .send ()
223226 while not response .complete :
224227 self .fetch ()
@@ -313,7 +316,39 @@ def close(self):
313316 self .closed = True
314317
315318
316- def connect (host , port = None , ** config ):
319+ def verify_certificate (host , der_encoded_certificate ):
320+ base64_encoded_certificate = b64encode (der_encoded_certificate )
321+ if isfile (KNOWN_HOSTS ):
322+ with open (KNOWN_HOSTS ) as f_in :
323+ for line in f_in :
324+ known_host , _ , known_cert = line .strip ().partition (":" )
325+ if host == known_host :
326+ if base64_encoded_certificate == known_cert :
327+ # Certificate match
328+ return
329+ else :
330+ # Certificate mismatch
331+ print (base64_encoded_certificate )
332+ print (known_cert )
333+ raise ProtocolError ("Server certificate does not match known certificate for %r; check "
334+ "details in file %r" % (host , KNOWN_HOSTS ))
335+ # First use (no hosts match)
336+ try :
337+ makedirs (dirname (KNOWN_HOSTS ))
338+ except OSError :
339+ pass
340+ f_out = os_open (KNOWN_HOSTS , O_CREAT | O_APPEND | O_WRONLY , 0o600 ) # TODO: Windows
341+ if isinstance (host , bytes ):
342+ os_write (f_out , host )
343+ else :
344+ os_write (f_out , host .encode ("utf-8" ))
345+ os_write (f_out , b":" )
346+ os_write (f_out , base64_encoded_certificate )
347+ os_write (f_out , b"\n " )
348+ os_close (f_out )
349+
350+
351+ def connect (host , port = None , ssl_context = None , ** config ):
317352 """ Connect and perform a handshake and return a valid Connection object, assuming
318353 a protocol version can be agreed.
319354 """
@@ -323,10 +358,25 @@ def connect(host, port=None, **config):
323358 if __debug__ : log_info ("~~ [CONNECT] %s %d" , host , port )
324359 s = create_connection ((host , port ))
325360
326- # Secure the connection if so requested
327- if config . get ( "secure" , False ) :
361+ # Secure the connection if an SSL context has been provided
362+ if ssl_context :
328363 if __debug__ : log_info ("~~ [SECURE] %s" , host )
329- s = secure_socket (s , host )
364+ try :
365+ s = ssl_context .wrap_socket (s , server_hostname = host if HAS_SNI else None )
366+ except SSLError as cause :
367+ error = ProtocolError ("Cannot establish secure connection; %s" % cause .args [1 ])
368+ error .__cause__ = cause
369+ raise error
370+ else :
371+ # Check that the server provides a certificate
372+ der_encoded_server_certificate = s .getpeercert (binary_form = True )
373+ if der_encoded_server_certificate is None :
374+ raise ProtocolError ("When using a secure socket, the server should always provide a certificate" )
375+ security = config .get ("security" , SECURITY_NONE )
376+ if security == SECURITY_TRUST_ON_FIRST_USE :
377+ verify_certificate (host , der_encoded_server_certificate )
378+ else :
379+ der_encoded_server_certificate = None
330380
331381 # Send details of the protocol versions supported
332382 supported_versions = [1 , 0 , 0 , 0 ]
@@ -360,4 +410,4 @@ def connect(host, port=None, **config):
360410 s .shutdown (SHUT_RDWR )
361411 s .close ()
362412 else :
363- return Connection (s , ** config )
413+ return Connection (s , der_encoded_server_certificate = der_encoded_server_certificate , ** config )
0 commit comments