Skip to content

Commit 7452a25

Browse files
whitelist of prefix when doing hello messsage (#442)
1 parent 43f571e commit 7452a25

File tree

8 files changed

+69
-22
lines changed

8 files changed

+69
-22
lines changed

neo4j/api.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def values(self):
149149
"""
150150
return self._values
151151

152+
152153
class ServerInfo:
153154

154155
def __init__(self, address, protocol_version):
@@ -178,9 +179,9 @@ def version_info(self):
178179
# Note: Confirm that the server agent string begins with "Neo4j/" and fail gracefully if not.
179180
# This is intended to help prevent drivers working for non-genuine Neo4j instances.
180181

181-
neo4j, _, value = self.agent.partition("/")
182+
prefix, _, value = self.agent.partition("/")
182183
try:
183-
assert neo4j == "Neo4j"
184+
assert prefix in ["Neo4j"]
184185
except AssertionError:
185186
raise DriverError("Server name does not start with Neo4j/")
186187

@@ -198,6 +199,15 @@ def version_info(self):
198199
pass
199200
return tuple(value)
200201

202+
def _update_metadata(self, metadata):
203+
"""Internal, update the metadata and perform check that the prefix is whitelisted by calling self.version()
204+
205+
:param metadata: metadata from the server
206+
:type metadata: dict
207+
"""
208+
self.metadata.update(metadata)
209+
_ = self.version_info()
210+
201211

202212
class Version(tuple):
203213

neo4j/io/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,14 @@ def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_
218218
supported_versions = Bolt.protocol_handlers().keys()
219219
raise BoltHandshakeError("The Neo4J server does not support communication with this driver. This driver have support for Bolt Protocols {}".format(supported_versions), address=address, request_data=handshake, response_data=data)
220220

221-
connection.hello()
221+
try:
222+
connection.hello()
223+
except Exception as error:
224+
log.debug("[#%04X] C: <CLOSE> %s", s.getsockname()[1], str(error))
225+
s.shutdown(SHUT_RDWR)
226+
s.close()
227+
raise error
228+
222229
return connection
223230

224231
@property

neo4j/io/_bolt3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def hello(self):
143143
logged_headers["credentials"] = "*******"
144144
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
145145
self._append(b"\x01", (headers,),
146-
response=InitResponse(self, on_success=self.server_info.metadata.update))
146+
response=InitResponse(self, on_success=self.server_info._update_metadata))
147147
self.send_all()
148148
self.fetch_all()
149149

neo4j/io/_bolt4x0.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def hello(self):
143143
logged_headers["credentials"] = "*******"
144144
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
145145
self._append(b"\x01", (headers,),
146-
response=InitResponse(self, on_success=self.server_info.metadata.update))
146+
response=InitResponse(self, on_success=self.server_info._update_metadata))
147147
self.send_all()
148148
self.fetch_all()
149149

neo4j/io/_bolt4x1.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,7 @@ def hello(self):
143143
if "credentials" in logged_headers:
144144
logged_headers["credentials"] = "*******"
145145
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
146-
self._append(b"\x01", (headers,),
147-
response=InitResponse(self, on_success=self.server_info.metadata.update))
146+
self._append(b"\x01", (headers,), response=InitResponse(self, on_success=self.server_info._update_metadata))
148147
self.send_all()
149148
self.fetch_all()
150149

neo4j/work/simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def run(self, query, parameters=None, **kwparameters):
192192
:type parameters: dict
193193
:param kwparameters: additional keyword parameters
194194
:returns: a new :class:`neo4j.Result` object
195-
:type: :class:`neo4j.Result`
195+
:rtype: :class:`neo4j.Result`
196196
"""
197197
if not query:
198198
raise ValueError("Cannot run an empty query")
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
!: BOLT 4.1
2+
!: PORT 9001
3+
4+
C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test", "routing": {"address": "localhost:9001"}}
5+
S: SUCCESS {"server": "Bogus/4.1.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"}
6+
C: RUN "RETURN 1 AS x" {} {"mode": "r"}
7+
PULL {"n": -1}
8+
S: SUCCESS {"fields": ["x"]}
9+
RECORD [1]
10+
SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5, "db": "neo4j"}

tests/stub/test_directdriver.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from neo4j.exceptions import (
2525
ServiceUnavailable,
2626
ConfigurationError,
27+
DriverError,
2728
)
2829
from neo4j._exceptions import (
2930
BoltHandshakeError,
@@ -88,27 +89,47 @@ def test_bolt_uri_constructs_bolt_driver(driver_info, test_script):
8889

8990

9091
@pytest.mark.parametrize(
91-
"test_script, test_expected",
92+
"test_script",
9293
[
93-
# ("v1/empty_explicit_hello_goodbye.script", ServiceUnavailable), # skip: cant close stub server gracefully
94-
# ("v2/empty_explicit_hello_goodbye.script", ServiceUnavailable), # skip: cant close stub server gracefully
95-
("v3/empty_explicit_hello_goodbye.script", None),
96-
("v4x0/empty_explicit_hello_goodbye.script", None),
97-
("v4x1/empty_explicit_hello_goodbye.script", None),
94+
"v3/empty_explicit_hello_goodbye.script",
95+
"v4x0/empty_explicit_hello_goodbye.script",
96+
"v4x1/empty_explicit_hello_goodbye.script",
9897
]
9998
)
100-
def test_direct_driver_handshake_negotiation(driver_info, test_script, test_expected):
99+
def test_direct_driver_handshake_negotiation(driver_info, test_script):
101100
# python -m pytest tests/stub/test_directdriver.py -s -v -k test_direct_driver_handshake_negotiation
102101
with StubCluster(test_script):
103102
uri = "bolt://localhost:9001"
104-
if test_expected:
105-
with pytest.raises(test_expected) as error:
106-
driver = GraphDatabase.driver(uri, auth=driver_info["auth_token"], **driver_config)
107-
assert isinstance(error.value.__cause__, BoltHandshakeError)
108-
else:
109-
driver = GraphDatabase.driver(uri, auth=driver_info["auth_token"], **driver_config)
103+
driver = GraphDatabase.driver(uri, auth=driver_info["auth_token"], **driver_config)
104+
assert isinstance(driver, BoltDriver)
105+
driver.close()
106+
107+
108+
@pytest.mark.parametrize(
109+
"test_script, test_expected",
110+
[
111+
("v3/return_1_port_9001.script", "Neo4j/3.0.0"),
112+
("v4x0/return_1_port_9001.script", "Neo4j/4.0.0"),
113+
("v4x1/return_1_port_9001_bogus_server.script", DriverError),
114+
]
115+
)
116+
def test_return_1_as_x(driver_info, test_script, test_expected):
117+
# python -m pytest tests/stub/test_directdriver.py -s -v -k test_return_1_as_x
118+
with StubCluster(test_script):
119+
uri = "bolt://localhost:9001"
120+
try:
121+
driver = GraphDatabase.driver(uri, auth=driver_info["auth_token"], user_agent="test")
110122
assert isinstance(driver, BoltDriver)
123+
with driver.session(default_access_mode=READ_ACCESS, fetch_size=-1) as session:
124+
result = session.run("RETURN 1 AS x")
125+
value = result.single().value()
126+
assert value == 1
127+
summary = result.consume()
128+
assert summary.server.agent == test_expected
129+
assert summary.server.agent.startswith("Neo4j")
111130
driver.close()
131+
except DriverError as error:
132+
assert isinstance(error, test_expected)
112133

113134

114135
def test_direct_driver_with_wrong_port(driver_info):
@@ -128,7 +149,7 @@ def test_direct_driver_with_wrong_port(driver_info):
128149
def test_direct_verify_connectivity(driver_info, test_script, test_expected):
129150
# python -m pytest tests/stub/test_directdriver.py -s -v -k test_direct_verify_connectivity
130151
with StubCluster(test_script):
131-
uri = "bolt://127.0.0.1:9001"
152+
uri = "bolt://localhost:9001"
132153
with GraphDatabase.driver(uri, auth=driver_info["auth_token"], **driver_config) as driver:
133154
assert isinstance(driver, BoltDriver)
134155
assert driver.verify_connectivity(default_access_mode=READ_ACCESS) == test_expected

0 commit comments

Comments
 (0)