Skip to content

Commit aa2c1fa

Browse files
committed
Decoupled connection and session
1 parent 2e88154 commit aa2c1fa

File tree

2 files changed

+83
-67
lines changed

2 files changed

+83
-67
lines changed

neo4j/v1/bolt.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -400,26 +400,26 @@ def match_or_trust(self, host, der_encoded_certificate):
400400
return True
401401

402402

403-
def connect(host, port=None, ssl_context=None, **config):
403+
def connect(host_port, ssl_context=None, **config):
404404
""" Connect and perform a handshake and return a valid Connection object, assuming
405405
a protocol version can be agreed.
406406
"""
407407

408408
# Establish a connection to the host and port specified
409409
# Catches refused connections see:
410410
# https://docs.python.org/2/library/errno.html
411-
port = port or DEFAULT_PORT
412-
if __debug__: log_info("~~ [CONNECT] %s %d", host, port)
411+
if __debug__: log_info("~~ [CONNECT] %s", host_port)
413412
try:
414-
s = create_connection((host, port))
413+
s = create_connection(host_port)
415414
except SocketError as error:
416415
if error.errno == 111 or error.errno == 61:
417-
raise ProtocolError("Unable to connect to %s on port %d - is the server running?" % (host, port))
416+
raise ProtocolError("Unable to connect to %s on port %d - is the server running?" % host_port)
418417
else:
419418
raise
420419

421420
# Secure the connection if an SSL context has been provided
422421
if ssl_context and SSL_AVAILABLE:
422+
host, port = host_port
423423
if __debug__: log_info("~~ [SECURE] %s", host)
424424
try:
425425
s = ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None)

neo4j/v1/session.py

Lines changed: 78 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class which can be used to obtain `Driver` instances that are used for
3232

3333
from .compat import integer, string, urlparse
3434
from .bolt import connect, Response, RUN, PULL_ALL
35-
from .constants import ENCRYPTED_DEFAULT, TRUST_DEFAULT, TRUST_SIGNED_CERTIFICATES
35+
from .constants import DEFAULT_PORT, ENCRYPTED_DEFAULT, TRUST_DEFAULT, TRUST_SIGNED_CERTIFICATES
3636
from .exceptions import CypherError, ProtocolError, ResultError
3737
from .ssl_compat import SSL_AVAILABLE, SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED
3838
from .types import hydrated
@@ -100,15 +100,21 @@ class Driver(object):
100100
""" Accessor for a specific graph database resource.
101101
"""
102102

103-
def __init__(self, url, **config):
104-
self.url = url
105-
parsed = urlparse(self.url)
106-
if parsed.scheme == "bolt":
107-
self.host = parsed.hostname
108-
self.port = parsed.port
103+
def __init__(self, address, **config):
104+
if "://" in address:
105+
parsed = urlparse(address)
106+
if parsed.scheme == "bolt":
107+
host = parsed.hostname
108+
port = parsed.port or DEFAULT_PORT
109+
else:
110+
raise ProtocolError("Only the 'bolt' URI scheme is supported [%s]" % address)
111+
elif ":" in address:
112+
host, port = address.split(":")
113+
port = int(port)
109114
else:
110-
raise ProtocolError("Unsupported URI scheme: '%s' in url: '%s'. Currently only supported 'bolt'." %
111-
(parsed.scheme, url))
115+
host = address
116+
port = DEFAULT_PORT
117+
self.address = (host, port)
112118
self.config = config
113119
self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE)
114120
self.session_pool = deque()
@@ -137,20 +143,20 @@ def session(self):
137143
>>> from neo4j.v1 import GraphDatabase
138144
>>> driver = GraphDatabase.driver("bolt://localhost")
139145
>>> session = driver.session()
140-
141146
"""
142147
session = None
143-
done = False
144-
while not done:
148+
connected = False
149+
while not connected:
145150
try:
146151
session = self.session_pool.pop()
147152
except IndexError:
148-
session = Session(self)
149-
done = True
153+
connection = connect(self.address, self.ssl_context, **self.config)
154+
session = Session(self, connection)
155+
connected = True
150156
else:
151157
if session.healthy:
152158
session.connection.reset()
153-
done = session.healthy
159+
connected = session.healthy
154160
return session
155161

156162
def recycle(self, session):
@@ -450,17 +456,51 @@ def make_plan(plan_dict):
450456
return Plan(operator_type, identifiers, arguments, children)
451457

452458

459+
def run(connection, statement, parameters=None):
460+
""" Run a Cypher statement on a given connection.
461+
462+
:param connection: connection to carry the request and response
463+
:param statement: Cypher statement
464+
:param parameters: optional dictionary of parameters
465+
:return: statement result
466+
"""
467+
# Ensure the statement is a Unicode value
468+
if isinstance(statement, bytes):
469+
statement = statement.decode("UTF-8")
470+
471+
params = {}
472+
for key, value in (parameters or {}).items():
473+
if isinstance(key, bytes):
474+
key = key.decode("UTF-8")
475+
if isinstance(value, bytes):
476+
params[key] = value.decode("UTF-8")
477+
else:
478+
params[key] = value
479+
parameters = params
480+
481+
run_response = Response(connection)
482+
pull_all_response = Response(connection)
483+
result = StatementResult(connection, run_response, pull_all_response)
484+
result.statement = statement
485+
result.parameters = parameters
486+
487+
connection.append(RUN, (statement, parameters), response=run_response)
488+
connection.append(PULL_ALL, response=pull_all_response)
489+
connection.send()
490+
491+
return result
492+
493+
453494
class Session(object):
454495
""" Logical session carried out over an established TCP connection.
455496
Sessions should generally be constructed using the :meth:`.Driver.session`
456497
method.
457498
"""
458499

459-
def __init__(self, driver):
500+
def __init__(self, driver, connection):
460501
self.driver = driver
461-
self.connection = connect(driver.host, driver.port, driver.ssl_context, **driver.config)
502+
self.connection = connection
462503
self.transaction = None
463-
self.last_result = None
464504

465505
def __enter__(self):
466506
return self
@@ -473,8 +513,7 @@ def healthy(self):
473513
""" Return ``True`` if this session is healthy, ``False`` if
474514
unhealthy and ``None`` if closed.
475515
"""
476-
connection = self.connection
477-
return None if connection.closed else not connection.defunct
516+
return self.connection.healthy
478517

479518
def run(self, statement, parameters=None):
480519
""" Run a parameterised Cypher statement.
@@ -487,41 +526,13 @@ def run(self, statement, parameters=None):
487526
if self.transaction:
488527
raise ProtocolError("Statements cannot be run directly on a session with an open transaction;"
489528
" either run from within the transaction or use a different session.")
490-
return self._run(statement, parameters)
491-
492-
def _run(self, statement, parameters=None):
493-
# Ensure the statement is a Unicode value
494-
if isinstance(statement, bytes):
495-
statement = statement.decode("UTF-8")
496-
497-
params = {}
498-
for key, value in (parameters or {}).items():
499-
if isinstance(key, bytes):
500-
key = key.decode("UTF-8")
501-
if isinstance(value, bytes):
502-
params[key] = value.decode("UTF-8")
503-
else:
504-
params[key] = value
505-
parameters = params
506-
507-
run_response = Response(self.connection)
508-
pull_all_response = Response(self.connection)
509-
result = StatementResult(self.connection, run_response, pull_all_response)
510-
result.statement = statement
511-
result.parameters = parameters
512-
513-
self.connection.append(RUN, (statement, parameters), response=run_response)
514-
self.connection.append(PULL_ALL, response=pull_all_response)
515-
self.connection.send()
516-
517-
self.last_result = result
518-
return result
529+
return run(self.connection, statement, parameters)
519530

520531
def close(self):
521532
""" Recycle this session through the driver it came from.
522533
"""
523-
if self.last_result:
524-
self.last_result.buffer()
534+
if self.connection and not self.connection.closed:
535+
self.connection.fetch_all()
525536
if self.transaction:
526537
self.transaction.close()
527538
self.driver.recycle(self)
@@ -534,7 +545,11 @@ def begin_transaction(self):
534545
if self.transaction:
535546
raise ProtocolError("You cannot begin a transaction on a session with an open transaction;"
536547
" either run from within the transaction or use a different session.")
537-
self.transaction = Transaction(self)
548+
549+
def clear_transaction():
550+
self.transaction = None
551+
552+
self.transaction = Transaction(self.connection, on_close=clear_transaction)
538553
return self.transaction
539554

540555

@@ -559,9 +574,10 @@ class Transaction(object):
559574
#: with commit or rollback.
560575
closed = False
561576

562-
def __init__(self, session):
563-
self.session = session
564-
self.session._run("BEGIN")
577+
def __init__(self, connection, on_close):
578+
self.connection = connection
579+
self.on_close = on_close
580+
run(self.connection, "BEGIN")
565581

566582
def __enter__(self):
567583
return self
@@ -574,12 +590,12 @@ def __exit__(self, exc_type, exc_value, traceback):
574590
def run(self, statement, parameters=None):
575591
""" Run a Cypher statement within the context of this transaction.
576592
577-
:param statement:
578-
:param parameters:
579-
:return:
593+
:param statement: Cypher statement
594+
:param parameters: dictionary of parameters
595+
:return: result object
580596
"""
581597
assert not self.closed
582-
return self.session._run(statement, parameters)
598+
return run(self.connection, statement, parameters)
583599

584600
def commit(self):
585601
""" Mark this transaction as successful and close in order to
@@ -600,11 +616,11 @@ def close(self):
600616
"""
601617
assert not self.closed
602618
if self.success:
603-
self.session._run("COMMIT")
619+
run(self.connection, "COMMIT")
604620
else:
605-
self.session._run("ROLLBACK")
621+
run(self.connection, "ROLLBACK")
606622
self.closed = True
607-
self.session.transaction = None
623+
self.on_close()
608624

609625

610626
class Record(object):

0 commit comments

Comments
 (0)