Skip to content

Commit 1cf92d7

Browse files
author
Zhen
committed
Support routing context from bolt routing uri
1 parent 64cbfa5 commit 1cf92d7

File tree

6 files changed

+122
-54
lines changed

6 files changed

+122
-54
lines changed

neo4j/addressing.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,25 @@ def parse(cls, string, default_port=0):
7979
"""
8080
return cls.from_uri("//{}".format(string), default_port)
8181

82+
@classmethod
83+
def parse_routing_context(cls, uri):
84+
query = urlparse(uri).query
85+
if not query:
86+
return {}
87+
88+
context = {}
89+
parameters = [x for x in query.split('&') if x]
90+
for keyValue in parameters:
91+
pair = keyValue.split('=')
92+
if len(pair) != 2 or not pair[0] or not pair[1]:
93+
raise ValueError("Invalid parameters: '" + keyValue + "' in URI '" + uri + "'.")
94+
key = pair[0]
95+
value = pair[1]
96+
if key in context:
97+
raise ValueError("Duplicated query parameters with key '" + key + "' found in URL '" + uri + "'")
98+
context[key] = value
99+
return context
100+
82101

83102
def resolve(socket_address):
84103
try:

neo4j/util.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,25 @@
2525
from sys import stdout
2626

2727

28+
class ServerVersion(object):
29+
def __init__(self, product, version_tuple, tags_tuple):
30+
self.product = product
31+
self.version_tuple = version_tuple
32+
self.tags_tuple = tags_tuple
33+
34+
def at_least_version(self, major, minor):
35+
return self.version_tuple >= (major, minor)
36+
37+
@classmethod
38+
def from_str(cls, full_version):
39+
if full_version is None:
40+
return ServerVersion("Neo4j", (3, 0), ())
41+
product, _, tagged_version = full_version.partition("/")
42+
tags = tagged_version.split("-")
43+
version = map(int, tags[0].split("."))
44+
return ServerVersion(product, tuple(version), tuple(tags[1:]))
45+
46+
2847
class ColourFormatter(logging.Formatter):
2948
""" Colour formatter for pretty log output.
3049
"""

neo4j/v1/routing.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from neo4j.v1.exceptions import SessionExpired
3131
from neo4j.v1.security import SecurityPlan
3232
from neo4j.v1.session import BoltSession
33+
from neo4j.util import ServerVersion
3334

3435

3536
class RoundRobinSet(MutableSet):
@@ -152,14 +153,23 @@ class RoutingConnectionPool(ConnectionPool):
152153
""" Connection pool with routing table.
153154
"""
154155

155-
routing_info_procedure = "dbms.cluster.routing.getServers"
156+
call_get_servers = "CALL dbms.cluster.routing.getServers"
157+
get_routing_table_param = "context"
158+
call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({" + get_routing_table_param + "})"
156159

157-
def __init__(self, connector, initial_address, *routers):
160+
def __init__(self, connector, initial_address, routing_context, *routers):
158161
super(RoutingConnectionPool, self).__init__(connector)
159162
self.initial_address = initial_address
163+
self.routing_context = routing_context
160164
self.routing_table = RoutingTable(routers)
161165
self.refresh_lock = Lock()
162166

167+
def routing_info_procedure(self, connection):
168+
if ServerVersion.from_str(connection.server.version).at_least_version(3, 2):
169+
return self.call_get_routing_table, {self.get_routing_table_param: self.routing_context}
170+
else:
171+
return self.call_get_servers, None
172+
163173
def fetch_routing_info(self, address):
164174
""" Fetch raw routing info from a given router address.
165175
@@ -170,8 +180,9 @@ def fetch_routing_info(self, address):
170180
if routing support is broken
171181
"""
172182
try:
173-
with BoltSession(lambda _: self.acquire_direct(address)) as session:
174-
return list(session.run("CALL %s" % self.routing_info_procedure))
183+
connection = self.acquire_direct(address)
184+
with BoltSession(lambda _: connection) as session:
185+
return list(session.run(*self.routing_info_procedure(connection)))
175186
except CypherError as error:
176187
if error.code == "Neo.ClientError.Procedure.ProcedureNotFound":
177188
raise ServiceUnavailable("Server {!r} does not support routing".format(address))
@@ -313,6 +324,7 @@ def __init__(self, uri, **config):
313324
self.initial_address = initial_address = SocketAddress.from_uri(uri, DEFAULT_PORT)
314325
self.security_plan = security_plan = SecurityPlan.build(**config)
315326
self.encrypted = security_plan.encrypted
327+
routing_context = SocketAddress.parse_routing_context(uri)
316328
if not security_plan.routing_compatible:
317329
# this error message is case-specific as there is only one incompatible
318330
# scenario right now
@@ -321,7 +333,7 @@ def __init__(self, uri, **config):
321333
def connector(a):
322334
return connect(a, security_plan.ssl_context, **config)
323335

324-
pool = RoutingConnectionPool(connector, initial_address, *resolve(initial_address))
336+
pool = RoutingConnectionPool(connector, initial_address, routing_context *resolve(initial_address))
325337
try:
326338
pool.update_routing_table()
327339
except:

test/integration/tools.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from boltkit.controller import WindowsController, UnixController
3434

3535
from neo4j.v1 import GraphDatabase, AuthError
36+
from neo4j.util import ServerVersion
3637

3738
from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD
3839

@@ -89,17 +90,11 @@ def server_version_info(cls):
8990
with GraphDatabase.driver(cls.bolt_uri, auth=cls.auth_token) as driver:
9091
with driver.session() as session:
9192
full_version = session.run("RETURN 1").summary().server.version
92-
if full_version is None:
93-
return "Neo4j", (3, 0), ()
94-
product, _, tagged_version = full_version.partition("/")
95-
tags = tagged_version.split("-")
96-
version = map(int, tags[0].split("."))
97-
return product, tuple(version), tuple(tags[1:])
93+
return ServerVersion.from_str(full_version)
9894

9995
@classmethod
10096
def at_least_version(cls, major, minor):
101-
_, server_version, _ = cls.server_version_info()
102-
return server_version >= (major, minor)
97+
return cls.server_version_info().at_least_version(major, minor);
10398

10499
@classmethod
105100
def delete_known_hosts_file(cls):

0 commit comments

Comments
 (0)