Skip to content

Commit 8a02a90

Browse files
committed
TLS settings and tests
1 parent 63ca7a0 commit 8a02a90

File tree

6 files changed

+169
-39
lines changed

6 files changed

+169
-39
lines changed

neo4j/v1/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@
1818
# See the License for the specific language governing permissions and
1919
# limitations under the License.
2020

21+
from .constants import *
2122
from .session import *
2223
from .typesystem import *

neo4j/v1/compat.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,3 @@ def perf_counter():
9090
from urllib.parse import urlparse
9191
except ImportError:
9292
from urlparse import urlparse
93-
94-
95-
try:
96-
from ssl import SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, HAS_SNI
97-
except ImportError:
98-
from ssl import wrap_socket, PROTOCOL_SSLv23
99-
100-
def secure_socket(s, host):
101-
return wrap_socket(s, ssl_version=PROTOCOL_SSLv23)
102-
103-
else:
104-
105-
def secure_socket(s, host):
106-
ssl_context = SSLContext(PROTOCOL_SSLv23)
107-
ssl_context.options |= OP_NO_SSLv2
108-
return ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None)

neo4j/v1/connection.py

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,24 @@
2121

2222
from __future__ import division
2323

24+
from base64 import b64encode
2425
from collections import deque
2526
from io import BytesIO
2627
import logging
27-
from os import environ
28+
from os import makedirs, open as os_open, write as os_write, close as os_close, O_CREAT, O_APPEND, O_WRONLY
29+
from os.path import dirname, isfile
2830
from select import select
2931
from socket import create_connection, SHUT_RDWR
32+
from ssl import HAS_SNI, SSLError
3033
from struct import pack as struct_pack, unpack as struct_unpack, unpack_from as struct_unpack_from
3134

32-
from ..meta import version
33-
from .compat import hex2, secure_socket
35+
from .constants import DEFAULT_PORT, DEFAULT_USER_AGENT, KNOWN_HOSTS, MAGIC_PREAMBLE, \
36+
SECURITY_NONE, SECURITY_TRUST_ON_FIRST_USE
37+
from .compat import hex2
3438
from .exceptions import ProtocolError
3539
from .packstream import Packer, Unpacker
3640

3741

38-
DEFAULT_PORT = 7687
39-
DEFAULT_USER_AGENT = "neo4j-python/%s" % version
40-
41-
MAGIC_PREAMBLE = 0x6060B017
42-
4342
# Signature bytes for each message type
4443
INIT = b"\x01" # 0000 0001 // INIT <user_agent>
4544
RESET = b"\x0F" # 0000 1111 // RESET
@@ -211,14 +210,18 @@ def __init__(self, sock, **config):
211210
user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
212211
if isinstance(user_agent, bytes):
213212
user_agent = user_agent.decode("UTF-8")
213+
self.user_agent = user_agent
214+
215+
# Pick up the server certificate, if any
216+
self.der_encoded_server_certificate = config.get("der_encoded_server_certificate")
214217

215218
def on_failure(metadata):
216219
raise ProtocolError("Initialisation failed")
217220

218221
response = Response(self)
219222
response.on_failure = on_failure
220223

221-
self.append(INIT, (user_agent,), response=response)
224+
self.append(INIT, (self.user_agent,), response=response)
222225
self.send()
223226
while not response.complete:
224227
self.fetch()
@@ -313,7 +316,39 @@ def close(self):
313316
self.closed = True
314317

315318

316-
def connect(host, port=None, **config):
319+
def verify_certificate(host, der_encoded_certificate):
320+
base64_encoded_certificate = b64encode(der_encoded_certificate)
321+
if isfile(KNOWN_HOSTS):
322+
with open(KNOWN_HOSTS) as f_in:
323+
for line in f_in:
324+
known_host, _, known_cert = line.strip().partition(":")
325+
if host == known_host:
326+
if base64_encoded_certificate == known_cert:
327+
# Certificate match
328+
return
329+
else:
330+
# Certificate mismatch
331+
print(base64_encoded_certificate)
332+
print(known_cert)
333+
raise ProtocolError("Server certificate does not match known certificate for %r; check "
334+
"details in file %r" % (host, KNOWN_HOSTS))
335+
# First use (no hosts match)
336+
try:
337+
makedirs(dirname(KNOWN_HOSTS))
338+
except OSError:
339+
pass
340+
f_out = os_open(KNOWN_HOSTS, O_CREAT | O_APPEND | O_WRONLY, 0o600) # TODO: Windows
341+
if isinstance(host, bytes):
342+
os_write(f_out, host)
343+
else:
344+
os_write(f_out, host.encode("utf-8"))
345+
os_write(f_out, b":")
346+
os_write(f_out, base64_encoded_certificate)
347+
os_write(f_out, b"\n")
348+
os_close(f_out)
349+
350+
351+
def connect(host, port=None, ssl_context=None, **config):
317352
""" Connect and perform a handshake and return a valid Connection object, assuming
318353
a protocol version can be agreed.
319354
"""
@@ -323,10 +358,25 @@ def connect(host, port=None, **config):
323358
if __debug__: log_info("~~ [CONNECT] %s %d", host, port)
324359
s = create_connection((host, port))
325360

326-
# Secure the connection if so requested
327-
if config.get("secure", False):
361+
# Secure the connection if an SSL context has been provided
362+
if ssl_context:
328363
if __debug__: log_info("~~ [SECURE] %s", host)
329-
s = secure_socket(s, host)
364+
try:
365+
s = ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None)
366+
except SSLError as cause:
367+
error = ProtocolError("Cannot establish secure connection; %s" % cause.args[1])
368+
error.__cause__ = cause
369+
raise error
370+
else:
371+
# Check that the server provides a certificate
372+
der_encoded_server_certificate = s.getpeercert(binary_form=True)
373+
if der_encoded_server_certificate is None:
374+
raise ProtocolError("When using a secure socket, the server should always provide a certificate")
375+
security = config.get("security", SECURITY_NONE)
376+
if security == SECURITY_TRUST_ON_FIRST_USE:
377+
verify_certificate(host, der_encoded_server_certificate)
378+
else:
379+
der_encoded_server_certificate = None
330380

331381
# Send details of the protocol versions supported
332382
supported_versions = [1, 0, 0, 0]
@@ -360,4 +410,4 @@ def connect(host, port=None, **config):
360410
s.shutdown(SHUT_RDWR)
361411
s.close()
362412
else:
363-
return Connection(s, **config)
413+
return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config)

neo4j/v1/constants.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) 2002-2016 "Neo Technology,"
5+
# Network Engine for Objects in Lund AB [http://neotechnology.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
22+
from os.path import expanduser, join
23+
24+
from ..meta import version
25+
26+
27+
DEFAULT_PORT = 7687
28+
DEFAULT_USER_AGENT = "neo4j-python/%s" % version
29+
30+
KNOWN_HOSTS = join(expanduser("~"), ".neo4j", "known_hosts")
31+
32+
MAGIC_PREAMBLE = 0x6060B017
33+
34+
SECURITY_NONE = 0
35+
SECURITY_TRUST_ON_FIRST_USE = 1
36+
SECURITY_VERIFIED = 2

neo4j/v1/session.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ class which can be used to obtain `Driver` instances that are used for
2929
from __future__ import division
3030

3131
from collections import deque, namedtuple
32+
from ssl import SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED, Purpose
3233

3334
from .compat import integer, string, urlparse
3435
from .connection import connect, Response, RUN, PULL_ALL
36+
from .constants import SECURITY_NONE, SECURITY_VERIFIED
3537
from .exceptions import CypherError, ResultError
3638
from .typesystem import hydrated
3739

@@ -77,6 +79,16 @@ def __init__(self, url, **config):
7779
self.config = config
7880
self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE)
7981
self.session_pool = deque()
82+
self.security = security = config.get("security", SECURITY_NONE)
83+
if security > SECURITY_NONE:
84+
ssl_context = SSLContext(PROTOCOL_SSLv23)
85+
ssl_context.options |= OP_NO_SSLv2
86+
if security >= SECURITY_VERIFIED:
87+
ssl_context.verify_mode = CERT_REQUIRED
88+
ssl_context.load_default_certs(Purpose.SERVER_AUTH)
89+
self.ssl_context = ssl_context
90+
else:
91+
self.ssl_context = None
8092

8193
def session(self):
8294
""" Create a new session based on the graph database details
@@ -425,7 +437,7 @@ class Session(object):
425437

426438
def __init__(self, driver):
427439
self.driver = driver
428-
self.connection = connect(driver.host, driver.port, **driver.config)
440+
self.connection = connect(driver.host, driver.port, driver.ssl_context, **driver.config)
429441
self.transaction = None
430442
self.last_cursor = None
431443

@@ -654,6 +666,7 @@ def __eq__(self, other):
654666
def __ne__(self, other):
655667
return not self.__eq__(other)
656668

669+
657670
def record(obj):
658671
""" Obtain an immutable record for the given object
659672
(either by calling obj.__record__() or by copying out the record data)

test/test_session.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,22 @@
1919
# limitations under the License.
2020

2121

22+
from os import remove, rename
23+
from os.path import isfile
2224
from socket import socket
2325
from ssl import SSLSocket
2426
from unittest import TestCase
2527

2628
from mock import patch
27-
from neo4j.v1.exceptions import ResultError
28-
from neo4j.v1.session import GraphDatabase, CypherError, Record, record
29+
from neo4j.v1.constants import KNOWN_HOSTS, SECURITY_NONE, SECURITY_TRUST_ON_FIRST_USE, SECURITY_VERIFIED
30+
from neo4j.v1.exceptions import CypherError, ResultError
31+
from neo4j.v1.session import GraphDatabase, Record, record
2932
from neo4j.v1.typesystem import Node, Relationship, Path
3033

3134

35+
KNOWN_HOSTS_BACKUP = KNOWN_HOSTS + ".backup"
36+
37+
3238
class DriverTestCase(TestCase):
3339

3440
def test_healthy_session_will_be_returned_to_the_pool_on_close(self):
@@ -82,17 +88,57 @@ def test_sessions_are_not_reused_if_still_in_use(self):
8288
session_1.close()
8389
assert session_1 is not session_2
8490

85-
def test_insecure_session_uses_insecure_socket(self):
86-
driver = GraphDatabase.driver("bolt://localhost", secure=False)
91+
92+
class SecurityTestCase(TestCase):
93+
94+
def setUp(self):
95+
if isfile(KNOWN_HOSTS):
96+
rename(KNOWN_HOSTS, KNOWN_HOSTS_BACKUP)
97+
98+
def tearDown(self):
99+
if isfile(KNOWN_HOSTS_BACKUP):
100+
rename(KNOWN_HOSTS_BACKUP, KNOWN_HOSTS)
101+
102+
def test_default_session_uses_security_none(self):
103+
# TODO: verify this is the correct default (maybe TOFU?)
104+
driver = GraphDatabase.driver("bolt://localhost")
105+
assert driver.security == SECURITY_NONE
106+
107+
def test_insecure_session_uses_normal_socket(self):
108+
driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_NONE)
109+
session = driver.session()
110+
connection = session.connection
111+
assert isinstance(connection.channel.socket, socket)
112+
assert connection.der_encoded_server_certificate is None
113+
session.close()
114+
115+
def test_tofu_session_uses_secure_socket(self):
116+
driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_TRUST_ON_FIRST_USE)
87117
session = driver.session()
88-
assert isinstance(session.connection.channel.socket, socket)
118+
connection = session.connection
119+
assert isinstance(connection.channel.socket, SSLSocket)
120+
assert connection.der_encoded_server_certificate is not None
89121
session.close()
90122

91-
def test_secure_session_uses_secure_socket(self):
92-
driver = GraphDatabase.driver("bolt://localhost", secure=True)
123+
def test_tofu_session_trusts_certificate_after_first_use(self):
124+
driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_TRUST_ON_FIRST_USE)
93125
session = driver.session()
94-
assert isinstance(session.connection.channel.socket, SSLSocket)
126+
connection = session.connection
127+
certificate = connection.der_encoded_server_certificate
95128
session.close()
129+
session = driver.session()
130+
connection = session.connection
131+
assert connection.der_encoded_server_certificate == certificate
132+
session.close()
133+
134+
# TODO: Find a way to run this test
135+
# def test_verified_session_uses_secure_socket(self):
136+
# driver = GraphDatabase.driver("bolt://localhost", security=SECURITY_VERIFIED)
137+
# session = driver.session()
138+
# connection = session.connection
139+
# assert isinstance(connection.channel.socket, SSLSocket)
140+
# assert connection.der_encoded_server_certificate is not None
141+
# session.close()
96142

97143

98144
class RunTestCase(TestCase):

0 commit comments

Comments
 (0)