Skip to content

Commit 3f002ca

Browse files
committed
added Trie class, with tests
1 parent 8e21b67 commit 3f002ca

File tree

7 files changed

+6442
-35
lines changed

7 files changed

+6442
-35
lines changed

ml/misc/trie.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""
2+
ml/misc/trie.py
3+
"""
4+
5+
6+
class Trie:
    """
    Prefix tree (trie) of words, stored as nested dictionaries.

    Every node is a dict that maps a single character to its child node.
    A node that also contains the special key ``None`` marks the end of a
    stored word.

    Attributes:
        name: label given to the trie (only stored, never read here).
        count: number of non-blank words fed to the trie so far. A word
            that is added more than once is counted every time.
    """

    def __init__(self, words: list, name='trie'):
        """
        Build a trie from ``words``.

        Each word is stripped of surrounding whitespace; blank words are
        ignored.
        """
        self._root = {}
        _, self.count = Trie._build_tree(words, self._root)
        self.name = name

    def __contains__(self, word):
        """
        Return True only if ``word`` was stored as a complete word
        (a mere prefix of a stored word does not count).
        """
        found, node = self.has_match(word)
        return found and None in node

    def __iter__(self):
        """
        Yield every stored word, following dict insertion order.
        """
        yield from Trie._iterator(self._root)

    def __str__(self):
        """
        Return all stored words joined by newlines.
        """
        return '\n'.join(self)

    @staticmethod
    def _add_node(tree, word, index=0):
        """
        Insert ``word[index:]`` below ``tree``, creating nodes as needed.

        Written as a loop rather than recursion so that long words cannot
        exhaust the interpreter's recursion limit.
        """
        node = tree
        for position in range(index, len(word)):
            node = node.setdefault(word[position], {})
        # adding a special key (None) to indicate the end of a word
        node[None] = '.'

    @staticmethod
    def _build_tree(words: list, tree: dict = None):
        """
        Add every non-blank, stripped word of ``words`` to ``tree``.

        A new dict is created when ``tree`` is omitted, so separate calls
        never share state through the default argument.

        Returns a ``(tree, count)`` tuple where ``count`` is the number of
        words added by this call (duplicates included).
        """
        if tree is None:
            tree = {}
        count = 0
        for raw in words:
            value = raw.strip()  # trimmed string value
            if value:
                Trie._add_node(tree, value)
                count += 1
        return tree, count

    @staticmethod
    def _iterator(tree, chars=None, prefix=''):
        """
        List all words from specific tree node.

        ``chars`` holds the characters on the path walked so far and
        ``prefix`` is prepended to every yielded word. ``chars`` is mutated
        while walking and restored afterwards; a fresh list is used when it
        is omitted.
        """
        if chars is None:
            chars = []
        for key in tree:
            if key is not None:
                chars.append(key)
                try:
                    yield from Trie._iterator(tree[key], chars, prefix)
                finally:
                    # Restore the path even when the consumer abandons the
                    # generator half way through.
                    chars.pop()
            else:
                complete_word = prefix + ''.join(chars)
                if complete_word:
                    yield complete_word

    @staticmethod
    def _list(tree, chars=None, prefix=''):
        """
        Return the words produced by ``_iterator`` as a list.
        """
        return list(Trie._iterator(tree, chars, prefix))

    def extend(self, words: list):
        """
        Add more words to the trie and increase ``count`` accordingly.
        """
        _, count = Trie._build_tree(words, self._root)
        self.count += count

    def get_matches(self, prefix: str):
        """
        Return a list of all stored words that start with ``prefix``.

        The list is empty when nothing matches or ``prefix`` is empty.
        """
        found, node = self.has_match(prefix)
        if not found or node is None:
            return []
        return list(Trie._iterator(node, [], prefix))

    def has(self, word):
        """
        Same as ``word in trie``.
        """
        return self.__contains__(word)

    def has_match(self, prefix: str) -> tuple:
        """
        Check if there is prefix in the trie list.

        Returns ``(True, node)`` where ``node`` is the dict reached after
        the last character of ``prefix``, or ``(False, None)`` when
        ``prefix`` is empty or some character has no matching child.
        """
        if not prefix:
            return False, None
        node = self._root
        for ch in prefix:
            node = node.get(ch)
            if node is None:  # cannot find sub-tree/node for the char
                return False, None
        return True, node

    def has_prefix(self, prefix):
        """
        Return True if at least one stored word starts with ``prefix``.
        """
        found, node = self.has_match(prefix)
        return found and node is not None

    def list(self):
        """
        Return all stored words as a list.
        """
        return Trie._list(self._root)

ml/utils/domain_trie.py

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""
22
DomainTrie class, as a data structure example.
33
"""
4+
from ml.utils.logger import get_logger
5+
6+
LOGGER = get_logger(__name__)
47

58

69
class DomainTrie(object):
@@ -10,7 +13,7 @@ class DomainTrie(object):
1013
"""
1114

1215
def __init__(self, domains):
    """
    Create the trie from the given domains.

    @param domains: iterable of dot separated domain names.
    """
    root = DomainTrie._build_tree(domains)
    self._trie = root
1417

1518
def __contains__(self, domain):
1619
"""
@@ -19,21 +22,19 @@ def __contains__(self, domain):
1922
@param domain: A dot separated domain.
2023
@return boolean: True if domain is in the trie; otherwise, False.
2124
"""
25+
result = False
2226
parts = [part for part in domain.split('.')[::-1] if part != '']
23-
res = False
24-
ref = self._trie
27+
node = self._trie # start from the trie root
2528
for part in parts:
26-
if part in ref:
27-
ref = ref[part]
28-
if '' in ref:
29-
res = True
30-
break
31-
else:
29+
if part not in node:
3230
break
33-
return res
31+
node = node[part]
32+
if '.' in node:
33+
return True
34+
return result
3435

3536
def __iter__(self):
    """
    Yield, one by one, the domains that DomainTrie._list produces for
    the whole trie.
    """
    yield from DomainTrie._list(self._trie, [])
3839

3940
def __str__(self):
@@ -43,37 +44,42 @@ def __str__(self):
4344
return str(result)
4445

4546
@staticmethod
46-
def _add_branch(domain, tree, idx=0):
47-
# BASE CASE: out of character to add to this branch
48-
if idx == len(domain):
49-
tree[''] = None
50-
else:
51-
if '' not in tree:
52-
domain_piece = domain[idx]
53-
if domain_piece in tree:
54-
branch = tree[domain_piece]
55-
else:
56-
branch = {}
57-
tree[domain_piece] = branch
58-
DomainTrie._add_branch(domain, branch, idx + 1)
47+
def _add_node(domain_pieces, tree, idx=0):
    """
    Insert the domain parts ``domain_pieces[idx:]`` below ``tree``.

    Parts are expected top level first (e.g. ``['com', 'example']``).
    The key "." marks the end of a stored domain. Insertion stops without
    changing anything as soon as a node that already carries "." is
    reached, because only the topmost stored level is kept.
    """
    node = tree
    while idx < len(domain_pieces):
        if "." in node:
            # a parent domain is already stored; do not go any deeper
            return
        node = node.setdefault(domain_pieces[idx], {})
        idx += 1
    node['.'] = None
5961

6062
@staticmethod
61-
def _make_tree(domains):
63+
def _build_tree(domains):
    """
    Build the nested-dict trie for ``domains``.

    Each domain is stripped and split on "."; its parts are inserted top
    level first. Entries that are not strings, or that have fewer than two
    non-empty parts, are logged as invalid and skipped.

    @param domains: iterable of dot separated domain names.
    @return dict: root node of the new trie.
    """
    tree = {}
    for domain in domains:
        text = domain.strip() if isinstance(domain, str) else ''
        domain_pieces = [piece for piece in text.split('.') if piece]
        if len(domain_pieces) > 1:
            DomainTrie._add_node(domain_pieces[::-1], tree)
        else:
            # Logger.warn is a deprecated alias (removed in Python 3.13)
            LOGGER.warning('Invalid domain: "%s"', domain)
    return tree
6873

6974
@staticmethod
70-
def _walk(tree, domain_pieces):
71-
result = []
72-
if '' in tree:
73-
result.append('.'.join(domain_pieces[::-1]))
74-
else:
75-
for piece in tree.keys():
75+
def _list(tree, domain_pieces=None):
    """
    Collect the domains stored below ``tree``.

    ``domain_pieces`` holds the parts walked so far, top level first. It
    is extended while walking and always restored afterwards; a fresh list
    is used when it is omitted, so calls never share state through the
    default argument.

    Walking stops at the first node that carries the "." end marker, so
    anything stored beneath such a node is not listed.

    @return list: dot separated domain names.
    """
    pieces = [] if domain_pieces is None else domain_pieces
    results = []

    def _collect(node):
        if '.' in node and pieces:
            results.append('.'.join(reversed(pieces)))
            return
        for piece, child in node.items():
            if piece == '.':
                continue
            pieces.append(piece)
            try:
                _collect(child)
            finally:
                # keep the path consistent even if the walk fails
                pieces.pop()

    _collect(tree)
    return results

0 commit comments

Comments
 (0)