@@ -32,7 +32,7 @@ class which can be used to obtain `Driver` instances that are used for
3232
3333from .compat import integer , string , urlparse
3434from .bolt import connect , Response , RUN , PULL_ALL
35- from .constants import ENCRYPTED_DEFAULT , TRUST_DEFAULT , TRUST_SIGNED_CERTIFICATES
35+ from .constants import DEFAULT_PORT , ENCRYPTED_DEFAULT , TRUST_DEFAULT , TRUST_SIGNED_CERTIFICATES
3636from .exceptions import CypherError , ProtocolError , ResultError
3737from .ssl_compat import SSL_AVAILABLE , SSLContext , PROTOCOL_SSLv23 , OP_NO_SSLv2 , CERT_REQUIRED
3838from .types import hydrated
@@ -100,15 +100,21 @@ class Driver(object):
100100 """ Accessor for a specific graph database resource.
101101 """
102102
103- def __init__ (self , url , ** config ):
104- self .url = url
105- parsed = urlparse (self .url )
106- if parsed .scheme == "bolt" :
107- self .host = parsed .hostname
108- self .port = parsed .port
103+ def __init__ (self , address , ** config ):
104+ if "://" in address :
105+ parsed = urlparse (address )
106+ if parsed .scheme == "bolt" :
107+ host = parsed .hostname
108+ port = parsed .port or DEFAULT_PORT
109+ else :
110+ raise ProtocolError ("Only the 'bolt' URI scheme is supported [%s]" % address )
111+ elif ":" in address :
112+ host , port = address .split (":" )
113+ port = int (port )
109114 else :
110- raise ProtocolError ("Unsupported URI scheme: '%s' in url: '%s'. Currently only supported 'bolt'." %
111- (parsed .scheme , url ))
115+ host = address
116+ port = DEFAULT_PORT
117+ self .address = (host , port )
112118 self .config = config
113119 self .max_pool_size = config .get ("max_pool_size" , DEFAULT_MAX_POOL_SIZE )
114120 self .session_pool = deque ()
@@ -137,20 +143,20 @@ def session(self):
137143 >>> from neo4j.v1 import GraphDatabase
138144 >>> driver = GraphDatabase.driver("bolt://localhost")
139145 >>> session = driver.session()
140-
141146 """
142147 session = None
143- done = False
144- while not done :
148+ connected = False
149+ while not connected :
145150 try :
146151 session = self .session_pool .pop ()
147152 except IndexError :
148- session = Session (self )
149- done = True
153+ connection = connect (self .address , self .ssl_context , ** self .config )
154+ session = Session (self , connection )
155+ connected = True
150156 else :
151157 if session .healthy :
152158 session .connection .reset ()
153- done = session .healthy
159+ connected = session .healthy
154160 return session
155161
156162 def recycle (self , session ):
@@ -450,17 +456,51 @@ def make_plan(plan_dict):
450456 return Plan (operator_type , identifiers , arguments , children )
451457
452458
459+ def run (connection , statement , parameters = None ):
460+ """ Run a Cypher statement on a given connection.
461+
462+ :param connection: connection to carry the request and response
463+ :param statement: Cypher statement
464+ :param parameters: optional dictionary of parameters
465+ :return: statement result
466+ """
467+ # Ensure the statement is a Unicode value
468+ if isinstance (statement , bytes ):
469+ statement = statement .decode ("UTF-8" )
470+
471+ params = {}
472+ for key , value in (parameters or {}).items ():
473+ if isinstance (key , bytes ):
474+ key = key .decode ("UTF-8" )
475+ if isinstance (value , bytes ):
476+ params [key ] = value .decode ("UTF-8" )
477+ else :
478+ params [key ] = value
479+ parameters = params
480+
481+ run_response = Response (connection )
482+ pull_all_response = Response (connection )
483+ result = StatementResult (connection , run_response , pull_all_response )
484+ result .statement = statement
485+ result .parameters = parameters
486+
487+ connection .append (RUN , (statement , parameters ), response = run_response )
488+ connection .append (PULL_ALL , response = pull_all_response )
489+ connection .send ()
490+
491+ return result
492+
493+
453494class Session (object ):
454495 """ Logical session carried out over an established TCP connection.
455496 Sessions should generally be constructed using the :meth:`.Driver.session`
456497 method.
457498 """
458499
459- def __init__ (self , driver ):
500+ def __init__ (self , driver , connection ):
460501 self .driver = driver
461- self .connection = connect ( driver . host , driver . port , driver . ssl_context , ** driver . config )
502+ self .connection = connection
462503 self .transaction = None
463- self .last_result = None
464504
465505 def __enter__ (self ):
466506 return self
@@ -473,8 +513,7 @@ def healthy(self):
473513 """ Return ``True`` if this session is healthy, ``False`` if
474514 unhealthy and ``None`` if closed.
475515 """
476- connection = self .connection
477- return None if connection .closed else not connection .defunct
516+ return self .connection .healthy
478517
479518 def run (self , statement , parameters = None ):
480519 """ Run a parameterised Cypher statement.
@@ -487,41 +526,13 @@ def run(self, statement, parameters=None):
487526 if self .transaction :
488527 raise ProtocolError ("Statements cannot be run directly on a session with an open transaction;"
489528 " either run from within the transaction or use a different session." )
490- return self ._run (statement , parameters )
491-
492- def _run (self , statement , parameters = None ):
493- # Ensure the statement is a Unicode value
494- if isinstance (statement , bytes ):
495- statement = statement .decode ("UTF-8" )
496-
497- params = {}
498- for key , value in (parameters or {}).items ():
499- if isinstance (key , bytes ):
500- key = key .decode ("UTF-8" )
501- if isinstance (value , bytes ):
502- params [key ] = value .decode ("UTF-8" )
503- else :
504- params [key ] = value
505- parameters = params
506-
507- run_response = Response (self .connection )
508- pull_all_response = Response (self .connection )
509- result = StatementResult (self .connection , run_response , pull_all_response )
510- result .statement = statement
511- result .parameters = parameters
512-
513- self .connection .append (RUN , (statement , parameters ), response = run_response )
514- self .connection .append (PULL_ALL , response = pull_all_response )
515- self .connection .send ()
516-
517- self .last_result = result
518- return result
529+ return run (self .connection , statement , parameters )
519530
520531 def close (self ):
521532 """ Recycle this session through the driver it came from.
522533 """
523- if self .last_result :
524- self .last_result . buffer ()
534+ if self .connection and not self . connection . closed :
535+ self .connection . fetch_all ()
525536 if self .transaction :
526537 self .transaction .close ()
527538 self .driver .recycle (self )
@@ -534,7 +545,11 @@ def begin_transaction(self):
534545 if self .transaction :
535546 raise ProtocolError ("You cannot begin a transaction on a session with an open transaction;"
536547 " either run from within the transaction or use a different session." )
537- self .transaction = Transaction (self )
548+
549+ def clear_transaction ():
550+ self .transaction = None
551+
552+ self .transaction = Transaction (self .connection , on_close = clear_transaction )
538553 return self .transaction
539554
540555
@@ -559,9 +574,10 @@ class Transaction(object):
559574 #: with commit or rollback.
560575 closed = False
561576
562- def __init__ (self , session ):
563- self .session = session
564- self .session ._run ("BEGIN" )
577+ def __init__ (self , connection , on_close ):
578+ self .connection = connection
579+ self .on_close = on_close
580+ run (self .connection , "BEGIN" )
565581
566582 def __enter__ (self ):
567583 return self
@@ -574,12 +590,12 @@ def __exit__(self, exc_type, exc_value, traceback):
574590 def run (self , statement , parameters = None ):
575591 """ Run a Cypher statement within the context of this transaction.
576592
577- :param statement:
578- :param parameters:
579- :return:
593+ :param statement: Cypher statement
594+ :param parameters: dictionary of parameters
595+ :return: result object
580596 """
581597 assert not self .closed
582- return self .session . _run ( statement , parameters )
598+ return run ( self .connection , statement , parameters )
583599
584600 def commit (self ):
585601 """ Mark this transaction as successful and close in order to
@@ -600,11 +616,11 @@ def close(self):
600616 """
601617 assert not self .closed
602618 if self .success :
603- self .session . _run ( "COMMIT" )
619+ run ( self .connection , "COMMIT" )
604620 else :
605- self .session . _run ( "ROLLBACK" )
621+ run ( self .connection , "ROLLBACK" )
606622 self .closed = True
607- self .session . transaction = None
623+ self .on_close ()
608624
609625
610626class Record (object ):
0 commit comments