1919# limitations under the License.
2020
2121
22+ from sys import maxsize
2223from threading import Lock
2324from time import clock
2425
2526from neo4j .addressing import SocketAddress , resolve
2627from neo4j .bolt import ConnectionPool , ServiceUnavailable , ProtocolError , DEFAULT_PORT , connect
2728from neo4j .compat .collections import MutableSet , OrderedDict
2829from neo4j .exceptions import CypherError
30+ from neo4j .util import ServerVersion
2931from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS , fix_statement , fix_parameters
3032from neo4j .v1 .exceptions import SessionExpired
3133from neo4j .v1 .security import SecurityPlan
3234from neo4j .v1 .session import BoltSession
33- from neo4j .util import ServerVersion
35+
36+
37+ LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0
38+ LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1
39+ LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
3440
3541
3642class RoundRobinSet (MutableSet ):
@@ -52,7 +58,7 @@ def __next__(self):
5258 self ._current = 0
5359 else :
5460 self ._current = (self ._current + 1 ) % len (self ._elements )
55- current = list ( self ._elements . keys ())[ self ._current ]
61+ current = self .get ( self ._current )
5662 return current
5763
5864 def __iter__ (self ):
@@ -90,6 +96,9 @@ def replace(self, elements=()):
9096 e .clear ()
9197 e .update (OrderedDict .fromkeys (elements ))
9298
99+ def get (self , index ):
100+ return list (self ._elements .keys ())[index ]
101+
93102
94103class RoutingTable (object ):
95104
@@ -168,17 +177,109 @@ def __run__(self, ignored, routing_context):
168177 return self ._run (fix_statement (statement ), fix_parameters (parameters ))
169178
170179
180+ class LoadBalancingStrategy (object ):
181+
182+ @classmethod
183+ def build (cls , connection_pool , ** config ):
184+ load_balancing_strategy = config .get ("load_balancing_strategy" , LOAD_BALANCING_STRATEGY_DEFAULT )
185+ if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED :
186+ return LeastConnectedLoadBalancingStrategy (connection_pool )
187+ elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN :
188+ return RoundRobinLoadBalancingStrategy ()
189+ else :
190+ raise ValueError ("Unknown load balancing strategy '%s'" % load_balancing_strategy )
191+ pass
192+
193+ def select_reader (self , known_readers ):
194+ raise NotImplementedError ()
195+
196+ def select_writer (self , known_writers ):
197+ raise NotImplementedError ()
198+
199+
200+ class RoundRobinLoadBalancingStrategy (LoadBalancingStrategy ):
201+
202+ _readers_offset = 0
203+ _writers_offset = 0
204+
205+ def select_reader (self , known_readers ):
206+ address = self .select (self ._readers_offset , known_readers )
207+ self ._readers_offset += 1
208+ return address
209+
210+ def select_writer (self , known_writers ):
211+ address = self .select (self ._writers_offset , known_writers )
212+ self ._writers_offset += 1
213+ return address
214+
215+ def select (self , offset , addresses ):
216+ length = len (addresses )
217+ if length == 0 :
218+ return None
219+ else :
220+ index = offset % length
221+ return addresses .get (index )
222+
223+
224+ class LeastConnectedLoadBalancingStrategy (LoadBalancingStrategy ):
225+
226+ def __init__ (self , connection_pool ):
227+ self ._readers_offset = 0
228+ self ._writers_offset = 0
229+ self ._connection_pool = connection_pool
230+
231+ def select_reader (self , known_readers ):
232+ address = self .select (self ._readers_offset , known_readers )
233+ self ._readers_offset += 1
234+ return address
235+
236+ def select_writer (self , known_writers ):
237+ address = self .select (self ._writers_offset , known_writers )
238+ self ._writers_offset += 1
239+ return address
240+
241+ def select (self , offset , addresses ):
242+ length = len (addresses )
243+ if length == 0 :
244+ return None
245+ else :
246+ start_index = offset % length
247+ index = start_index
248+
249+ least_connected_address = None
250+ least_in_use_connections = maxsize
251+
252+ while True :
253+ address = addresses .get (index )
254+ in_use_connections = self ._connection_pool .in_use_connection_count (address )
255+
256+ if in_use_connections < least_in_use_connections :
257+ least_connected_address = address
258+ least_in_use_connections = in_use_connections
259+
260+ if index == length - 1 :
261+ index = 0
262+ else :
263+ index += 1
264+
265+ if index == start_index :
266+ break
267+
268+ return least_connected_address
269+
270+
171271class RoutingConnectionPool (ConnectionPool ):
172272 """ Connection pool with routing table.
173273 """
174274
175- def __init__ (self , connector , initial_address , routing_context , * routers ):
275+ def __init__ (self , connector , initial_address , routing_context , * routers , ** config ):
176276 super (RoutingConnectionPool , self ).__init__ (connector )
177277 self .initial_address = initial_address
178278 self .routing_context = routing_context
179279 self .routing_table = RoutingTable (routers )
180280 self .missing_writer = False
181281 self .refresh_lock = Lock ()
282+ self .load_balancing_strategy = LoadBalancingStrategy .build (self , ** config )
182283
183284 def fetch_routing_info (self , address ):
184285 """ Fetch raw routing info from a given router address.
@@ -304,14 +405,16 @@ def acquire(self, access_mode=None):
304405 access_mode = WRITE_ACCESS
305406 if access_mode == READ_ACCESS :
306407 server_list = self .routing_table .readers
408+ server_selector = self .load_balancing_strategy .select_reader
307409 elif access_mode == WRITE_ACCESS :
308410 server_list = self .routing_table .writers
411+ server_selector = self .load_balancing_strategy .select_writer
309412 else :
310413 raise ValueError ("Unsupported access mode {}" .format (access_mode ))
311414
312415 self .ensure_routing_table_is_fresh (access_mode )
313416 while True :
314- address = next (server_list )
417+ address = server_selector (server_list )
315418 if address is None :
316419 break
317420 try :
@@ -354,7 +457,7 @@ def __init__(self, uri, **config):
354457 def connector (a ):
355458 return connect (a , security_plan .ssl_context , ** config )
356459
357- pool = RoutingConnectionPool (connector , initial_address , routing_context , * resolve (initial_address ))
460+ pool = RoutingConnectionPool (connector , initial_address , routing_context , * resolve (initial_address ), ** config )
358461 try :
359462 pool .update_routing_table ()
360463 except :
0 commit comments