1717# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1818# See the License for the specific language governing permissions and
1919# limitations under the License.
20-
21-
20+ from abc import abstractmethod
21+ from sys import maxsize
2222from threading import Lock
2323from time import clock
2424
2525from neo4j .addressing import SocketAddress , resolve
2626from neo4j .bolt import ConnectionPool , ServiceUnavailable , ProtocolError , DEFAULT_PORT , connect
2727from neo4j .compat .collections import MutableSet , OrderedDict
2828from neo4j .exceptions import CypherError
29+ from neo4j .util import ServerVersion
2930from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS , fix_statement , fix_parameters
3031from neo4j .v1 .exceptions import SessionExpired
3132from neo4j .v1 .security import SecurityPlan
3233from neo4j .v1 .session import BoltSession
33- from neo4j .util import ServerVersion
3434
3535
36- class RoundRobinSet (MutableSet ):
36+ LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0
37+ LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1
38+ LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
39+
40+
41+ class OrderedSet (MutableSet ):
3742
3843 def __init__ (self , elements = ()):
3944 self ._elements = OrderedDict .fromkeys (elements )
@@ -45,22 +50,15 @@ def __repr__(self):
4550 def __contains__ (self , element ):
4651 return element in self ._elements
4752
48- def __next__ (self ):
49- current = None
50- if self ._elements :
51- if self ._current is None :
52- self ._current = 0
53- else :
54- self ._current = (self ._current + 1 ) % len (self ._elements )
55- current = list (self ._elements .keys ())[self ._current ]
56- return current
57-
5853 def __iter__ (self ):
5954 return iter (self ._elements )
6055
6156 def __len__ (self ):
6257 return len (self ._elements )
6358
59+ def __getitem__ (self , index ):
60+ return list (self ._elements .keys ())[index ]
61+
6462 def add (self , element ):
6563 self ._elements [element ] = None
6664
@@ -73,9 +71,6 @@ def discard(self, element):
7371 except KeyError :
7472 pass
7573
76- def next (self ):
77- return self .__next__ ()
78-
7974 def remove (self , element ):
8075 try :
8176 del self ._elements [element ]
@@ -126,9 +121,9 @@ def parse_routing_info(cls, records):
126121 return cls (routers , readers , writers , ttl )
127122
128123 def __init__ (self , routers = (), readers = (), writers = (), ttl = 0 ):
129- self .routers = RoundRobinSet (routers )
130- self .readers = RoundRobinSet (readers )
131- self .writers = RoundRobinSet (writers )
124+ self .routers = OrderedSet (routers )
125+ self .readers = OrderedSet (readers )
126+ self .writers = OrderedSet (writers )
132127 self .last_updated_time = self .timer ()
133128 self .ttl = ttl
134129
@@ -168,17 +163,102 @@ def __run__(self, ignored, routing_context):
168163 return self ._run (fix_statement (statement ), fix_parameters (parameters ))
169164
170165
166+ class LoadBalancingStrategy (object ):
167+
168+ @classmethod
169+ def build (cls , connection_pool , ** config ):
170+ load_balancing_strategy = config .get ("load_balancing_strategy" , LOAD_BALANCING_STRATEGY_DEFAULT )
171+ if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED :
172+ return LeastConnectedLoadBalancingStrategy (connection_pool )
173+ elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN :
174+ return RoundRobinLoadBalancingStrategy ()
175+ else :
176+ raise ValueError ("Unknown load balancing strategy '%s'" % load_balancing_strategy )
177+
178+ @abstractmethod
179+ def select_reader (self , known_readers ):
180+ raise NotImplementedError ()
181+
182+ @abstractmethod
183+ def select_writer (self , known_writers ):
184+ raise NotImplementedError ()
185+
186+
187+ class RoundRobinLoadBalancingStrategy (LoadBalancingStrategy ):
188+
189+ _readers_offset = 0
190+ _writers_offset = 0
191+
192+ def select_reader (self , known_readers ):
193+ address = self ._select (self ._readers_offset , known_readers )
194+ self ._readers_offset += 1
195+ return address
196+
197+ def select_writer (self , known_writers ):
198+ address = self ._select (self ._writers_offset , known_writers )
199+ self ._writers_offset += 1
200+ return address
201+
202+ @classmethod
203+ def _select (cls , offset , addresses ):
204+ if not addresses :
205+ return None
206+ return addresses [offset % len (addresses )]
207+
208+
209+ class LeastConnectedLoadBalancingStrategy (LoadBalancingStrategy ):
210+
211+ def __init__ (self , connection_pool ):
212+ self ._readers_offset = 0
213+ self ._writers_offset = 0
214+ self ._connection_pool = connection_pool
215+
216+ def select_reader (self , known_readers ):
217+ address = self ._select (self ._readers_offset , known_readers )
218+ self ._readers_offset += 1
219+ return address
220+
221+ def select_writer (self , known_writers ):
222+ address = self ._select (self ._writers_offset , known_writers )
223+ self ._writers_offset += 1
224+ return address
225+
226+ def _select (self , offset , addresses ):
227+ if not addresses :
228+ return None
229+ num_addresses = len (addresses )
230+ start_index = offset % num_addresses
231+ index = start_index
232+
233+ least_connected_address = None
234+ least_in_use_connections = maxsize
235+
236+ while True :
237+ address = addresses [index ]
238+ index = (index + 1 ) % num_addresses
239+
240+ in_use_connections = self ._connection_pool .in_use_connection_count (address )
241+
242+ if in_use_connections < least_in_use_connections :
243+ least_connected_address = address
244+ least_in_use_connections = in_use_connections
245+
246+ if index == start_index :
247+ return least_connected_address
248+
249+
171250class RoutingConnectionPool (ConnectionPool ):
172251 """ Connection pool with routing table.
173252 """
174253
175- def __init__ (self , connector , initial_address , routing_context , * routers ):
254+ def __init__ (self , connector , initial_address , routing_context , * routers , ** config ):
176255 super (RoutingConnectionPool , self ).__init__ (connector )
177256 self .initial_address = initial_address
178257 self .routing_context = routing_context
179258 self .routing_table = RoutingTable (routers )
180259 self .missing_writer = False
181260 self .refresh_lock = Lock ()
261+ self .load_balancing_strategy = LoadBalancingStrategy .build (self , ** config )
182262
183263 def fetch_routing_info (self , address ):
184264 """ Fetch raw routing info from a given router address.
@@ -304,14 +384,16 @@ def acquire(self, access_mode=None):
304384 access_mode = WRITE_ACCESS
305385 if access_mode == READ_ACCESS :
306386 server_list = self .routing_table .readers
387+ server_selector = self .load_balancing_strategy .select_reader
307388 elif access_mode == WRITE_ACCESS :
308389 server_list = self .routing_table .writers
390+ server_selector = self .load_balancing_strategy .select_writer
309391 else :
310392 raise ValueError ("Unsupported access mode {}" .format (access_mode ))
311393
312394 self .ensure_routing_table_is_fresh (access_mode )
313395 while True :
314- address = next (server_list )
396+ address = server_selector (server_list )
315397 if address is None :
316398 break
317399 try :
@@ -354,7 +436,7 @@ def __init__(self, uri, **config):
354436 def connector (a ):
355437 return connect (a , security_plan .ssl_context , ** config )
356438
357- pool = RoutingConnectionPool (connector , initial_address , routing_context , * resolve (initial_address ))
439+ pool = RoutingConnectionPool (connector , initial_address , routing_context , * resolve (initial_address ), ** config )
358440 try :
359441 pool .update_routing_table ()
360442 except :
0 commit comments