66depth and bucket size.
77"""
88import inspect
9+ import math
10+ from collections import deque
911
1012from pygorithm .geometry import (vector2 , polygon2 , rect2 )
1113
@@ -25,7 +27,7 @@ def __init__(self, aabb):
2527 :param aabb: axis-aligned bounding box
2628 :type aabb: :class:`pygorithm.geometry.rect2.Rect2`
2729 """
28- pass
30+ self . aabb = aabb
2931
3032 def __repr__ (self ):
3133 """
@@ -46,7 +48,7 @@ def __repr__(self):
4648 :returns: unambiguous representation of this quad tree entity
4749 :rtype: string
4850 """
49- pass
51+ return "quadtreeentity(aabb={})" . format ( repr ( self . aabb ))
5052
5153 def __str__ (self ):
5254 """
@@ -67,7 +69,7 @@ def __str__(self):
6769 :returns: human readable representation of this entity
6870 :rtype: string
6971 """
70- pass
72+ return "entity(at {})" . format ( str ( self . aabb ))
7173
7274class QuadTree (object ):
7375 """
@@ -129,7 +131,12 @@ def __init__(self, bucket_size, max_depth, location, depth = 0, entities = None)
129131 :param entities: the entities to initialize this quadtree with
130132 :type entities: list of :class:`.QuadTreeEntity` or None for empty list
131133 """
132- pass
134+ self .bucket_size = bucket_size
135+ self .max_depth = max_depth
136+ self .location = location
137+ self .depth = depth
138+ self .entities = entities if entities is not None else []
139+ self .children = None
133140
134141 def think (self , recursive = False ):
135142 """
@@ -145,7 +152,13 @@ def think(self, recursive = False):
145152 :param recursive: if `think(True)` should be called on :py:attr:`.children` (if there are any)
146153 :type recursive: bool
147154 """
148- pass
155+ if not self .children and self .depth < self .max_depth and len (self .entities ) > self .bucket_size :
156+ self .split ()
157+
158+ if recursive :
159+ if self .children :
160+ for child in self .children :
161+ child .think (True )
149162
150163 def split (self ):
151164 """
@@ -164,12 +177,43 @@ def split(self):
164177
165178 :raises ValueError: if :py:attr:`.children` is not empty
166179 """
167- pass
180+ if self .children :
181+ raise ValueError ("cannot split twice" )
182+
183+ _cls = type (self )
184+ def _cstr (r ):
185+ return _cls (self .bucket_size , self .max_depth , r , self .depth + 1 )
186+
187+ _halfwidth = self .location .width / 2
188+ _halfheight = self .location .height / 2
189+ _x = self .location .mincorner .x
190+ _y = self .location .mincorner .y
191+
192+ self .children = [
193+ _cstr (rect2 .Rect2 (_halfwidth , _halfheight , vector2 .Vector2 (_x , _y ))),
194+ _cstr (rect2 .Rect2 (_halfwidth , _halfheight , vector2 .Vector2 (_x + _halfwidth , _y ))),
195+ _cstr (rect2 .Rect2 (_halfwidth , _halfheight , vector2 .Vector2 (_x + _halfwidth , _y + _halfheight ))),
196+ _cstr (rect2 .Rect2 (_halfwidth , _halfheight , vector2 .Vector2 (_x , _y + _halfheight ))) ]
197+
198+ _newents = []
199+ for ent in self .entities :
200+ quad = self .get_quadrant (ent )
201+
202+ if quad < 0 :
203+ _newents .append (ent )
204+ else :
205+ self .children [quad ].entities .append (ent )
206+ self .entities = _newents
207+
208+
168209
169210 def get_quadrant (self , entity ):
170211 """
171212 Calculate the quadrant that the specified entity belongs to.
172213
214+ Touching a line is considered overlapping a line. Touching is
215+ determined using :py:meth:`math.isclose`
216+
173217 Quadrants are:
174218
175219 - -1: None (it overlaps 2 or more quadrants)
@@ -189,7 +233,48 @@ def get_quadrant(self, entity):
189233 :returns: quadrant
190234 :rtype: int
191235 """
192- pass
236+
237+ _aabb = entity .aabb
238+ _halfwidth = self .location .width / 2
239+ _halfheight = self .location .height / 2
240+ _x = self .location .mincorner .x
241+ _y = self .location .mincorner .y
242+
243+ if math .isclose (_aabb .mincorner .x , _x + _halfwidth ):
244+ return - 1
245+ if math .isclose (_aabb .mincorner .x + _aabb .width , _x + _halfwidth ):
246+ return - 1
247+ if math .isclose (_aabb .mincorner .y , _y + _halfheight ):
248+ return - 1
249+ if math .isclose (_aabb .mincorner .y + _aabb .height , _y + _halfheight ):
250+ return - 1
251+
252+ _leftside_isleft = _aabb .mincorner .x < _x + _halfwidth
253+ _rightside_isleft = _aabb .mincorner .x + _aabb .width < _x + _halfwidth
254+
255+ if _leftside_isleft != _rightside_isleft :
256+ return - 1
257+
258+ _topside_istop = _aabb .mincorner .y < _y + _halfheight
259+ _botside_istop = _aabb .mincorner .y + _aabb .height < _y + _halfheight
260+
261+ if _topside_istop != _botside_istop :
262+ return - 1
263+
264+ _left = _leftside_isleft
265+ _top = _topside_istop
266+
267+ if _left :
268+ if _top :
269+ return 0
270+ else :
271+ return 3
272+ else :
273+ if _top :
274+ return 1
275+ else :
276+ return 2
277+
193278
194279 def insert_and_think (self , entity ):
195280 """
@@ -204,7 +289,14 @@ def insert_and_think(self, entity):
204289 :param entity: the entity to insert
205290 :type entity: :class:`.QuadTreeEntity`
206291 """
207- pass
292+ if not self .children and len (self .entities ) == self .bucket_size and self .depth < self .max_depth :
293+ self .split ()
294+
295+ quad = self .get_quadrant (entity ) if self .children else - 1
296+ if quad < 0 :
297+ self .entities .append (entity )
298+ else :
299+ self .children [quad ].insert_and_think (entity )
208300
209301 def retrieve_collidables (self , entity , predicate = None ):
210302 """
@@ -227,19 +319,71 @@ def retrieve_collidables(self, entity, predicate = None):
227319 :returns: potential collidables (never `None)
228320 :rtype: list of :class:`.QuadTreeEntity`
229321 """
230- pass
322+ result = list (filter (predicate , self .entities ))
323+ quadrant = self .get_quadrant (entity ) if self .children else - 1
324+
325+ if quadrant >= 0 :
326+ result .extend (self .children [quadrant ].retrieve_collidables (entity , predicate ))
327+ elif self .children :
328+ for child in self .children :
329+ touching , overlapping , alwaysNone = rect2 .Rect2 .find_intersection (entity .aabb , child .location , find_mtv = False )
330+ if touching or overlapping :
331+ result .extend (child .retrieve_collidables (entity , predicate ))
332+
333+ return result
334+
335+ def _iter_helper (self , pred ):
336+ """
337+ Calls pred on each child and childs child, iteratively.
338+
339+ pred takes one positional argument (the child).
340+
341+ :param pred: function to call
342+ :type pred: `types.FunctionType`
343+ """
344+
345+ _stack = deque ()
346+ _stack .append (self )
231347
348+ while _stack :
349+ curr = _stack .pop ()
350+ if curr .children :
351+ for child in curr .children :
352+ _stack .append (child )
353+
354+ pred (curr )
355+
232356 def find_entities_per_depth (self ):
233357 """
234358 Calculate the number of nodes and entities at each depth level in this
235359 quad tree. Only returns for depth levels at or equal to this node.
236360
237361 This is implemented iteratively. See :py:meth:`.__str__` for usage example.
238362
239- :returns: dict of depth level to (number of nodes, number of entities)
240- :rtype: dict int: (int, int)
363+ :returns: dict of depth level to number of entities
364+ :rtype: dict int: int
365+ """
366+
367+ container = { 'result' : {} }
368+ def handler (curr , container = container ):
369+ container ['result' ][curr .depth ] = container ['result' ].get (curr .depth , 0 ) + len (curr .entities )
370+ self ._iter_helper (handler )
371+
372+ return container ['result' ]
373+
374+ def find_nodes_per_depth (self ):
375+ """
376+ Calculate the number of nodes at each depth level.
377+
378+ This is implemented iteratively. See :py:meth:`.__str__` for usage example.
379+
380+ :returns: dict of depth level to number of nodes
381+ :rtype: dict int: int
241382 """
242- pass
383+
384+ nodes_per_depth = {}
385+ self ._iter_helper (lambda curr , d = nodes_per_depth : d .update ({ (curr .depth , d .get (curr .depth , 0 ) + 1 ) }))
386+ return nodes_per_depth
243387
244388 def sum_entities (self , entities_per_depth = None ):
245389 """
@@ -254,7 +398,15 @@ def sum_entities(self, entities_per_depth=None):
254398 :returns: number of entities in this and child nodes
255399 :rtype: int
256400 """
257- pass
401+ if entities_per_depth is not None :
402+ return sum (entities_per_depth .values ())
403+
404+ container = { 'result' : 0 }
405+ def handler (curr , container = container ):
406+ container ['result' ] += len (curr .entities )
407+ self ._iter_helper (handler )
408+
409+ return container ['result' ]
258410
259411 def calculate_avg_ents_per_leaf (self ):
260412 """
@@ -270,7 +422,13 @@ def calculate_avg_ents_per_leaf(self):
270422 :returns: average number of entities at each leaf node
271423 :rtype: :class:`numbers.Number`
272424 """
273- pass
425+ container = { 'leafs' : 0 , 'total' : 0 }
426+ def handler (curr , container = container ):
427+ if not curr .children :
428+ container ['leafs' ] += 1
429+ container ['total' ] += len (curr .entities )
430+ self ._iter_helper (handler )
431+ return container ['total' ] / container ['leafs' ]
274432
275433 def calculate_weight_misplaced_ents (self , sum_entities = None ):
276434 """
@@ -293,11 +451,40 @@ def calculate_weight_misplaced_ents(self, sum_entities=None):
293451 :returns: weight of misplaced entities
294452 :rtype: :class:`numbers.Number`
295453 """
296- pass
297454
455+ # this iteration requires more context than _iter_helper provides.
456+ # we must keep track of parents as well in order to correctly update
457+ # weights
458+
459+ nonleaf_to_max_child_depth_dict = {}
460+
461+ # stack will be (quadtree, list (of parents) or None)
462+ _stack = deque ()
463+ _stack .append ((self , None ))
464+ while _stack :
465+ curr , parents = _stack .pop ()
466+ if parents :
467+ for p in parents :
468+ nonleaf_to_max_child_depth_dict [p ] = max (nonleaf_to_max_child_depth_dict .get (p , 0 ), curr .depth )
469+
470+ if curr .children :
471+ new_parents = list (parents ) if parents else []
472+ new_parents .append (curr )
473+ for child in curr .children :
474+ _stack .append ((child , new_parents ))
475+
476+ _weight = 0
477+ for nonleaf , maxchilddepth in nonleaf_to_max_child_depth_dict .items ():
478+ _weight += len (nonleaf .entities ) * 4 * (maxchilddepth - nonleaf .depth )
479+
480+ _sum = self .sum_entities () if sum_entities is None else sum_entities
481+ return _weight / _sum
482+
298483 def __repr__ (self ):
299484 """
300- Create an unambiguous, recursive representation of this quad tree.
485+ Create an unambiguous representation of this quad tree.
486+
487+ This is implemented iteratively.
301488
302489 Example:
303490
@@ -308,19 +495,18 @@ def __repr__(self):
308495
309496 # create a tree with a up to 2 entities in a bucket that
310497 # can have a depth of up to 5.
311- _tree = quadtree.QuadTree(2 , 5, rect2.Rect2(100, 100))
498+ _tree = quadtree.QuadTree(1 , 5, rect2.Rect2(100, 100))
312499
313500 # add a few entities to the tree
314501 _tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(5, 5))))
315502 _tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(95, 5))))
316503
317- # prints quadtree(bucket_size=2, max_depth=5, location=rect2(width=100, height=100, mincorner=vector2(x=0, y=0)), depth=0, entities=[], children=[ quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=0, y=0)), depth=1, entities=[ quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=5, y=5))) ], children=[]), quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=50, y=0)), depth=1, entities=[ quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=95, y=5))) ], children=[]), quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=50, y=50)), depth=1, entities=[], children=[]), quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=0, y=50)), depth=1, entities=[], children=[]) ])
318- print(repr(_tree))
504+ # prints quadtree(bucket_size=1, max_depth=5, location=rect2(width=100, height=100, mincorner=vector2(x=0, y=0)), depth=0, entities=[], children=[quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=0, y=0)), depth=1, entities=[quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=5, y=5)))], children=None), quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=50.0, y=0)), depth=1, entities=[quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=95, y=5)))], children=None), quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=50.0, y=50.0)), depth=1, entities=[], children=None), quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=0, y=50.0)), depth=1, entities=[], children=None)])
319505
320506 :returns: unambiguous, recursive representation of this quad tree
321507 :rtype: string
322508 """
323- pass
509+ return "quadtree(bucket_size={}, max_depth={}, location={}, depth={}, entities={}, children={})" . format ( self . bucket_size , self . max_depth , repr ( self . location ), self . depth , self . entities , self . children )
324510
325511 def __str__ (self ):
326512 """
@@ -347,12 +533,23 @@ def __str__(self):
347533 _tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(5, 5))))
348534 _tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(95, 5))))
349535
350- # prints quadtree(at rect(100x100 at <0, 0>) with 0 entities here (2 in total); (nodes, entities) per depth: [ 0: (1, 0), 1: (4, 2) ] (max depth: 5), avg ent/leaf: 0.5 (target 2), misplaced weight = 0 (0 best, >1 bad))
536+ # prints quadtree(at rect(100x100 at <0, 0>) with 0 entities here (2 in total); (nodes, entities) per depth: [ 0: (1, 0), 1: (4, 2) ] (allowed max depth: 5, actual: 1), avg ent/leaf: 0.5 (target 1), misplaced weight 0.0 (0 best, >1 bad)
537+ print(_tree)
351538
352539 :returns: human-readable representation of this quad tree
353540 :rtype: string
354541 """
355- pass
542+
543+ nodes_per_depth = self .find_nodes_per_depth ()
544+ _ents_per_depth = self .find_entities_per_depth ()
545+
546+ _nodes_ents_per_depth_str = "[ {} ]" .format (', ' .join ("{}: ({}, {})" .format (dep , nodes_per_depth [dep ], _ents_per_depth [dep ]) for dep in nodes_per_depth .keys ()))
547+
548+ _sum = self .sum_entities (entities_per_depth = _ents_per_depth )
549+ _max_depth = max (_ents_per_depth .keys ())
550+ _avg_ent_leaf = self .calculate_avg_ents_per_leaf ()
551+ _mispl_weight = self .calculate_weight_misplaced_ents (sum_entities = _sum )
552+ return "quadtree(at {} with {} entities here ({} in total); (nodes, entities) per depth: {} (allowed max depth: {}, actual: {}), avg ent/leaf: {} (target {}), misplaced weight {} (0 best, >1 bad)" .format (self .location , len (self .entities ), _sum , _nodes_ents_per_depth_str , self .max_depth , _max_depth , _avg_ent_leaf , self .bucket_size , _mispl_weight )
356553
357554 @staticmethod
358555 def get_code ():
0 commit comments