Skip to content

Commit 343b168

Browse files
committed
Remove duplication between Bolt4xN classes
1 parent 4a77f28 commit 343b168

File tree

5 files changed

+47
-458
lines changed

5 files changed

+47
-458
lines changed

neo4j/io/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,7 @@ def protocol_handlers(cls, protocol_version=None):
134134

135135
# Carry out Bolt subclass imports locally to avoid circular dependency issues.
136136
from neo4j.io._bolt3 import Bolt3
137-
from neo4j.io._bolt4x0 import Bolt4x0
138-
from neo4j.io._bolt4x1 import Bolt4x1
137+
from neo4j.io._bolt4 import Bolt4x0, Bolt4x1
139138

140139
handlers = {
141140
Bolt3.PROTOCOL_VERSION: Bolt3,
@@ -204,11 +203,11 @@ def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_
204203
connection = Bolt3(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent, routing_context=routing_context)
205204
elif pool_config.protocol_version == (4, 0):
206205
# Carry out Bolt subclass imports locally to avoid circular dependency issues.
207-
from neo4j.io._bolt4x0 import Bolt4x0
206+
from neo4j.io._bolt4 import Bolt4x0
208207
connection = Bolt4x0(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent, routing_context=routing_context)
209208
elif pool_config.protocol_version == (4, 1):
210209
# Carry out Bolt subclass imports locally to avoid circular dependency issues.
211-
from neo4j.io._bolt4x1 import Bolt4x1
210+
from neo4j.io._bolt4 import Bolt4x1
212211
connection = Bolt4x1(address, s, pool_config.max_connection_lifetime, auth=auth, user_agent=pool_config.user_agent, routing_context=routing_context)
213212
else:
214213
log.debug("[#%04X] S: <CLOSE>", s.getpeername()[1])
@@ -673,8 +672,7 @@ def fetch_routing_info(self, *, address, timeout, database):
673672

674673
# Carry out Bolt subclass imports locally to avoid circular dependency issues.
675674
from neo4j.io._bolt3 import Bolt3
676-
from neo4j.io._bolt4x0 import Bolt4x0
677-
from neo4j.io._bolt4x1 import Bolt4x1
675+
from neo4j.io._bolt4 import Bolt4x0, Bolt4x1
678676

679677
from neo4j.api import (
680678
SYSTEM_DATABASE,

neo4j/io/_bolt3.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@
6363

6464

6565
class Bolt3(Bolt):
66+
""" Protocol handler for Bolt 3.
67+
68+
This is supported by Neo4j versions 3.5, 4.0 and 4.1.
69+
"""
6670

6771
PROTOCOL_VERSION = Version(3, 0)
6872

@@ -81,7 +85,7 @@ class Bolt3(Bolt):
8185
def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None):
8286
self.unresolved_address = unresolved_address
8387
self.socket = sock
84-
self.server_info = ServerInfo(Address(sock.getpeername()), Bolt3.PROTOCOL_VERSION)
88+
self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION)
8589
self.outbox = Outbox()
8690
self.inbox = Inbox(self.socket, on_error=self._set_defunct)
8791
self.packer = Packer(self.outbox)
@@ -136,8 +140,13 @@ def local_port(self):
136140
except IOError:
137141
return 0
138142

143+
def get_base_headers(self):
144+
return {
145+
"user_agent": self.user_agent,
146+
}
147+
139148
def hello(self):
140-
headers = {"user_agent": self.user_agent}
149+
headers = self.get_base_headers()
141150
headers.update(self.auth_dict)
142151
logged_headers = dict(headers)
143152
if "credentials" in logged_headers:

neo4j/io/_bolt4x0.py renamed to neo4j/io/_bolt4.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@
6363

6464

6565
class Bolt4x0(Bolt):
66+
""" Protocol handler for Bolt 4.0.
67+
68+
This is supported by Neo4j versions 4.0 and 4.1.
69+
"""
6670

6771
PROTOCOL_VERSION = Version(4, 0)
6872

@@ -81,7 +85,7 @@ class Bolt4x0(Bolt):
8185
def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None):
8286
self.unresolved_address = unresolved_address
8387
self.socket = sock
84-
self.server_info = ServerInfo(Address(sock.getpeername()), Bolt4x0.PROTOCOL_VERSION)
88+
self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION)
8589
self.outbox = Outbox()
8690
self.inbox = Inbox(self.socket, on_error=self._set_defunct)
8791
self.packer = Packer(self.outbox)
@@ -136,8 +140,13 @@ def local_port(self):
136140
except IOError:
137141
return 0
138142

143+
def get_base_headers(self):
144+
return {
145+
"user_agent": self.user_agent,
146+
}
147+
139148
def hello(self):
140-
headers = {"user_agent": self.user_agent}
149+
headers = self.get_base_headers()
141150
headers.update(self.auth_dict)
142151
logged_headers = dict(headers)
143152
if "credentials" in logged_headers:
@@ -445,3 +454,23 @@ def closed(self):
445454

446455
def defunct(self):
447456
return self._defunct
457+
458+
459+
class Bolt4x1(Bolt4x0):
460+
""" Protocol handler for Bolt 4.1.
461+
462+
This is supported by Neo4j version 4.1.
463+
"""
464+
465+
PROTOCOL_VERSION = Version(4, 1)
466+
467+
def get_base_headers(self):
468+
""" Bolt 4.1 passes the routing context, originally taken from
469+
the URI, into the connection initialisation message. This
470+
enables server-side routing to propagate the same behaviour
471+
through its driver.
472+
"""
473+
return {
474+
"user_agent": self.user_agent,
475+
"routing": self.routing_context,
476+
}

0 commit comments

Comments
 (0)