2121
2222from __future__ import division
2323
24+ from base64 import b64encode
2425from collections import deque
2526from io import BytesIO
2627import logging
27- from os import environ
28+ from os import makedirs , open as os_open , write as os_write , close as os_close , O_CREAT , O_APPEND , O_WRONLY
29+ from os .path import dirname , isfile
2830from select import select
2931from socket import create_connection , SHUT_RDWR
32+ from ssl import HAS_SNI , SSLError
3033from struct import pack as struct_pack , unpack as struct_unpack , unpack_from as struct_unpack_from
3134
32- from ..meta import version
33- from .compat import hex2 , secure_socket
35+ from .constants import DEFAULT_PORT , DEFAULT_USER_AGENT , KNOWN_HOSTS , MAGIC_PREAMBLE , \
36+ SECURITY_DEFAULT , SECURITY_TRUST_ON_FIRST_USE
37+ from .compat import hex2
3438from .exceptions import ProtocolError
3539from .packstream import Packer , Unpacker
3640
3741
38- DEFAULT_PORT = 7687
39- DEFAULT_USER_AGENT = "neo4j-python/%s" % version
40-
41- MAGIC_PREAMBLE = 0x6060B017
42-
4342# Signature bytes for each message type
4443INIT = b"\x01 " # 0000 0001 // INIT <user_agent>
4544RESET = b"\x0F " # 0000 1111 // RESET
@@ -211,14 +210,18 @@ def __init__(self, sock, **config):
211210 user_agent = config .get ("user_agent" , DEFAULT_USER_AGENT )
212211 if isinstance (user_agent , bytes ):
213212 user_agent = user_agent .decode ("UTF-8" )
213+ self .user_agent = user_agent
214+
215+ # Pick up the server certificate, if any
216+ self .der_encoded_server_certificate = config .get ("der_encoded_server_certificate" )
214217
215218 def on_failure (metadata ):
216219 raise ProtocolError ("Initialisation failed" )
217220
218221 response = Response (self )
219222 response .on_failure = on_failure
220223
221- self .append (INIT , (user_agent ,), response = response )
224+ self .append (INIT , (self . user_agent ,), response = response )
222225 self .send ()
223226 while not response .complete :
224227 self .fetch ()
@@ -313,7 +316,53 @@ def close(self):
313316 self .closed = True
314317
315318
316- def connect (host , port = None , ** config ):
319+ class CertificateStore (object ):
320+
321+ def match_or_trust (self , host , der_encoded_certificate ):
322+ """ Check whether the supplied certificate matches that stored for the
323+ specified host. If it does, return ``True``, if it doesn't, return
324+ ``False``. If no entry for that host is found, add it to the store
325+ and return ``True``.
326+
327+ :arg host:
328+ :arg der_encoded_certificate:
329+ :return:
330+ """
331+ raise NotImplementedError ()
332+
333+
334+ class PersonalCertificateStore (CertificateStore ):
335+
336+ def __init__ (self , path = None ):
337+ self .path = path or KNOWN_HOSTS
338+
339+ def match_or_trust (self , host , der_encoded_certificate ):
340+ base64_encoded_certificate = b64encode (der_encoded_certificate )
341+ if isfile (self .path ):
342+ with open (self .path ) as f_in :
343+ for line in f_in :
344+ known_host , _ , known_cert = line .strip ().partition (":" )
345+ known_cert = known_cert .encode ("utf-8" )
346+ if host == known_host :
347+ return base64_encoded_certificate == known_cert
348+ # First use (no hosts match)
349+ try :
350+ makedirs (dirname (self .path ))
351+ except OSError :
352+ pass
353+ f_out = os_open (self .path , O_CREAT | O_APPEND | O_WRONLY , 0o600 ) # TODO: Windows
354+ if isinstance (host , bytes ):
355+ os_write (f_out , host )
356+ else :
357+ os_write (f_out , host .encode ("utf-8" ))
358+ os_write (f_out , b":" )
359+ os_write (f_out , base64_encoded_certificate )
360+ os_write (f_out , b"\n " )
361+ os_close (f_out )
362+ return True
363+
364+
365+ def connect (host , port = None , ssl_context = None , ** config ):
317366 """ Connect and perform a handshake and return a valid Connection object, assuming
318367 a protocol version can be agreed.
319368 """
@@ -323,14 +372,28 @@ def connect(host, port=None, **config):
323372 if __debug__ : log_info ("~~ [CONNECT] %s %d" , host , port )
324373 s = create_connection ((host , port ))
325374
326- # Secure the connection if so requested
327- try :
328- secure = environ ["NEO4J_SECURE" ]
329- except KeyError :
330- secure = config .get ("secure" , False )
331- if secure :
375+ # Secure the connection if an SSL context has been provided
376+ if ssl_context :
332377 if __debug__ : log_info ("~~ [SECURE] %s" , host )
333- s = secure_socket (s , host )
378+ try :
379+ s = ssl_context .wrap_socket (s , server_hostname = host if HAS_SNI else None )
380+ except SSLError as cause :
381+ error = ProtocolError ("Cannot establish secure connection; %s" % cause .args [1 ])
382+ error .__cause__ = cause
383+ raise error
384+ else :
385+ # Check that the server provides a certificate
386+ der_encoded_server_certificate = s .getpeercert (binary_form = True )
387+ if der_encoded_server_certificate is None :
388+ raise ProtocolError ("When using a secure socket, the server should always provide a certificate" )
389+ security = config .get ("security" , SECURITY_DEFAULT )
390+ if security == SECURITY_TRUST_ON_FIRST_USE :
391+ store = PersonalCertificateStore ()
392+ if not store .match_or_trust (host , der_encoded_server_certificate ):
393+ raise ProtocolError ("Server certificate does not match known certificate for %r; check "
394+ "details in file %r" % (host , KNOWN_HOSTS ))
395+ else :
396+ der_encoded_server_certificate = None
334397
335398 # Send details of the protocol versions supported
336399 supported_versions = [1 , 0 , 0 , 0 ]
@@ -364,4 +427,4 @@ def connect(host, port=None, **config):
364427 s .shutdown (SHUT_RDWR )
365428 s .close ()
366429 else :
367- return Connection (s , ** config )
430+ return Connection (s , der_encoded_server_certificate = der_encoded_server_certificate , ** config )
0 commit comments