11"""
22DomainTrie class, as a data structure example.
33"""
4+ from ml .utils .logger import get_logger
5+
6+ LOGGER = get_logger (__name__ )
47
58
69class DomainTrie (object ):
@@ -10,7 +13,7 @@ class DomainTrie(object):
1013 """
1114
1215 def __init__ (self , domains ):
13- self ._trie = DomainTrie ._make_tree (domains )
16+ self ._trie = DomainTrie ._build_tree (domains )
1417
1518 def __contains__ (self , domain ):
1619 """
@@ -19,21 +22,19 @@ def __contains__(self, domain):
1922 @param domain: A dot separated domain.
2023 @return boolean: True if domain is in the trie; otherwise, False.
2124 """
25+ result = False
2226 parts = [part for part in domain .split ('.' )[::- 1 ] if part != '' ]
23- res = False
24- ref = self ._trie
27+ node = self ._trie # start from the trie root
2528 for part in parts :
26- if part in ref :
27- ref = ref [part ]
28- if '' in ref :
29- res = True
30- break
31- else :
29+ if part not in node :
3230 break
33- return res
31+ node = node [part ]
32+ if '.' in node :
33+ return True
34+ return result
3435
3536 def __iter__ (self ):
36- for result in DomainTrie ._walk (self ._trie , []):
37+ for result in DomainTrie ._list (self ._trie , []):
3738 yield result
3839
3940 def __str__ (self ):
@@ -43,37 +44,42 @@ def __str__(self):
4344 return str (result )
4445
4546 @staticmethod
46- def _add_branch (domain , tree , idx = 0 ):
47- # BASE CASE: out of character to add to this branch
48- if idx == len (domain ):
49- tree ['' ] = None
50- else :
51- if '' not in tree :
52- domain_piece = domain [idx ]
53- if domain_piece in tree :
54- branch = tree [domain_piece ]
55- else :
56- branch = {}
57- tree [domain_piece ] = branch
58- DomainTrie ._add_branch (domain , branch , idx + 1 )
47+ def _add_node (domain_pieces , tree , idx = 0 ):
48+ """
49+ add specific level of domain part to the tree.
50+ """
51+ if idx >= len (domain_pieces ):
52+ tree ['.' ] = None
53+ return
54+ # stop at "." because it is only building topper levels
55+ if "." not in tree :
56+ part = domain_pieces [idx ]
57+ node = tree .get (part , {})
58+ if node == {}:
59+ tree [part ] = node # adding a new node
60+ DomainTrie ._add_node (domain_pieces , node , idx + 1 )
5961
6062 @staticmethod
61- def _make_tree (domains ):
63+ def _build_tree (domains ):
6264 tree = {}
6365 for domain in domains :
64- domain_pieces = [piece for piece in domain .split ('.' ) if piece != '' ]
66+ d = domain .strip () if isinstance (domain , str ) else ''
67+ domain_pieces = [piece for piece in d .split ('.' ) if piece != '' ]
6568 if len (domain_pieces ) > 1 :
66- DomainTrie ._add_branch (domain_pieces [::- 1 ], tree )
69+ DomainTrie ._add_node (domain_pieces [::- 1 ], tree )
70+ else :
71+ LOGGER .warn ('Invalid domain: "%s"' , domain )
6772 return tree
6873
6974 @staticmethod
70- def _walk (tree , domain_pieces ):
71- result = []
72- if '' in tree :
73- result .append ('.' .join (domain_pieces [::- 1 ]))
74- else :
75- for piece in tree .keys ():
75+ def _list (tree , domain_pieces = []):
76+ results = []
77+ if '.' in tree and len (domain_pieces ) > 0 :
78+ results .append ('.' .join (domain_pieces [::- 1 ]))
79+ return results
80+ for piece in tree .keys ():
81+ if not piece == '.' :
7682 domain_pieces .append (piece )
77- result .extend (DomainTrie ._walk (tree [piece ], domain_pieces ))
83+ results .extend (DomainTrie ._list (tree [piece ], domain_pieces ))
7884 domain_pieces .pop ()
79- return result
85+ return results
0 commit comments