Skip to content

Commit 8285006

Browse files
version sniffing fix (#392)
Version sniffin are returning the Bolt Protocol Version for Bolt Protocol 4.0+ the older versions are still using the ServerInfo.agen string to determine the server version.
1 parent 9894561 commit 8285006

File tree

6 files changed

+62
-22
lines changed

6 files changed

+62
-22
lines changed

neo4j/__init__.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -348,18 +348,11 @@ def supports_multi_db(self):
348348
:return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false.
349349
:rtype: bool
350350
"""
351-
from neo4j.io._bolt4x0 import Bolt4x0
352-
353-
multi_database = False
354351
cx = self._pool.acquire(access_mode=READ_ACCESS, timeout=self._pool.workspace_config.connection_acquisition_timeout, database=self._pool.workspace_config.database)
355-
356-
# TODO: This logic should be inside the Bolt subclasses, because it can change depending on Bolt Protocol Version.
357-
if cx.PROTOCOL_VERSION >= Bolt4x0.PROTOCOL_VERSION and cx.server_info.version_info() >= Version(4, 0, 0):
358-
multi_database = True
359-
352+
support = cx.supports_multiple_databases
360353
self._pool.release(cx)
361354

362-
return multi_database
355+
return support
363356

364357

365358
class BoltDriver(Direct, Driver):

neo4j/api.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
parse_qs,
2424
)
2525
from.exceptions import (
26+
DriverError,
2627
ConfigurationError,
2728
)
2829

@@ -149,12 +150,38 @@ def __init__(self, address, protocol_version):
149150

150151
@property
151152
def agent(self):
153+
"""The server agent string the server responded with.
154+
155+
:return: Server agent string
156+
:rtype: string
157+
"""
158+
# Example "Neo4j/4.0.5"
159+
# Example "Neo4j/4"
152160
return self.metadata.get("server")
153161

154162
def version_info(self):
163+
"""Return the server version if available.
164+
165+
:return: Server Version or None
166+
:rtype: tuple
167+
"""
155168
if not self.agent:
156169
return None
157-
_, _, value = self.agent.partition("/")
170+
# Note: Confirm that the server agent string begins with "Neo4j/" and fail gracefully if not.
171+
# This is intended to help prevent drivers working for non-genuine Neo4j instances.
172+
173+
neo4j, _, value = self.agent.partition("/")
174+
try:
175+
assert neo4j == "Neo4j"
176+
except AssertionError:
177+
raise DriverError("Server name does not start with Neo4j/")
178+
179+
try:
180+
if self.protocol_version >= (4, 0):
181+
return self.protocol_version
182+
except TypeError:
183+
pass
184+
158185
value = value.replace("-", ".").split(".")
159186
for i, v in enumerate(value):
160187
try:

neo4j/io/_bolt3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No
8888
self._max_connection_lifetime = max_connection_lifetime
8989
self._creation_timestamp = perf_counter()
9090
self.supports_multiple_results = False
91+
self.supports_multiple_databases = False
9192
self._is_reset = True
9293

9394
# Determine the user agent

neo4j/io/_bolt4x0.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No
8787
self._max_connection_lifetime = max_connection_lifetime # self.pool_config.max_connection_lifetime
8888
self._creation_timestamp = perf_counter()
8989
self.supports_multiple_results = True
90+
self.supports_multiple_databases = True
9091
self._is_reset = True
9192

9293
# Determine the user agent

tests/integration/test_bolt_driver.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
unit_of_work,
3232
Transaction,
3333
Result,
34+
ServerInfo,
3435
)
3536
from neo4j.exceptions import (
3637
ServiceUnavailable,
@@ -39,6 +40,7 @@
3940
ClientError,
4041
)
4142
from neo4j._exceptions import BoltHandshakeError
43+
from neo4j.io._bolt3 import Bolt3
4244

4345
# python -m pytest tests/integration/test_bolt_driver.py -s -v
4446

@@ -139,21 +141,28 @@ def test_supports_multi_db(bolt_uri, auth):
139141

140142
with driver.session() as session:
141143
result = session.run("RETURN 1")
142-
value = result.single().value() # Consumes the result
144+
_ = result.single().value() # Consumes the result
143145
summary = result.consume()
144146
server_info = summary.server
145147

148+
assert isinstance(summary, ResultSummary)
149+
assert isinstance(server_info, ServerInfo)
150+
assert server_info.version_info() is not None
151+
assert isinstance(server_info.protocol_version, Version)
152+
146153
result = driver.supports_multi_db()
147154
driver.close()
148155

149-
if server_info.version_info() >= Version(4, 0, 0) and server_info.protocol_version >= Version(4, 0):
150-
assert result is True
151-
assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server
152-
assert summary.query_type == "r"
153-
else:
156+
if server_info.protocol_version == Bolt3.PROTOCOL_VERSION:
154157
assert result is False
155158
assert summary.database is None
156159
assert summary.query_type == "r"
160+
else:
161+
assert result is True
162+
assert server_info.version_info() >= Version(4, 0)
163+
assert server_info.protocol_version >= Version(4, 0)
164+
assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server
165+
assert summary.query_type == "r"
157166

158167

159168
def test_test_multi_db_specify_database(bolt_uri, auth):

tests/integration/test_neo4j_driver.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Version,
2828
READ_ACCESS,
2929
ResultSummary,
30+
ServerInfo,
3031
)
3132
from neo4j.exceptions import (
3233
ServiceUnavailable,
@@ -39,6 +40,7 @@
3940
from neo4j.conf import (
4041
RoutingConfig,
4142
)
43+
from neo4j.io._bolt3 import Bolt3
4244

4345
# python -m pytest tests/integration/test_neo4j_driver.py -s -v
4446

@@ -72,21 +74,28 @@ def test_supports_multi_db(neo4j_uri, auth, target):
7274

7375
with driver.session() as session:
7476
result = session.run("RETURN 1")
75-
value = result.single().value() # Consumes the result
77+
_ = result.single().value() # Consumes the result
7678
summary = result.consume()
7779
server_info = summary.server
7880

81+
assert isinstance(summary, ResultSummary)
82+
assert isinstance(server_info, ServerInfo)
83+
assert server_info.version_info() is not None
84+
assert isinstance(server_info.protocol_version, Version)
85+
7986
result = driver.supports_multi_db()
8087
driver.close()
8188

82-
if server_info.version_info() >= Version(4, 0, 0) and server_info.protocol_version >= Version(4, 0):
83-
assert result is True
84-
assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server
85-
assert summary.query_type == "r"
86-
else:
89+
if server_info.protocol_version == Bolt3.PROTOCOL_VERSION:
8790
assert result is False
8891
assert summary.database is None
8992
assert summary.query_type == "r"
93+
else:
94+
assert result is True
95+
assert server_info.version_info() >= Version(4, 0)
96+
assert server_info.protocol_version >= Version(4, 0)
97+
assert summary.database == "neo4j" # This is the default database name if not set explicitly on the Neo4j Server
98+
assert summary.query_type == "r"
9099

91100

92101
def test_test_multi_db_specify_database(neo4j_uri, auth, target):

0 commit comments

Comments
 (0)