Skip to content

Commit 7bc1199

Browse files
Add forest generation
1 parent b4b5122 commit 7bc1199

File tree

2 files changed

+124
-48
lines changed

2 files changed

+124
-48
lines changed

cyaron/graph.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from .utils import *
2+
from .vector import Vector
23
import random
34
from typing import TypeVar, Callable
45

5-
66
__all__ = ["Edge", "Graph"]
77

88

@@ -46,7 +46,7 @@ def __init__(self, point_count, directed=False):
4646
"""
4747
self.directed = directed
4848
self.edges = [[] for i in range(point_count + 1)]
49-
49+
5050
def edge_count(self):
5151
"""edge_count(self) -> int
5252
Return the count of the edges in the graph.
@@ -292,9 +292,11 @@ def graph(point_count, edge_count, **kwargs):
292292
self_loop = kwargs.get("self_loop", True)
293293
repeated_edges = kwargs.get("repeated_edges", True)
294294
if not repeated_edges:
295-
max_edge = Graph._calc_max_edge(point_count, directed, self_loop)
295+
max_edge = Graph._calc_max_edge(point_count, directed, self_loop)
296296
if edge_count > max_edge:
297-
raise Exception("the number of edges of this kind of graph which has %d vertexes must be less than or equal to %d." % (point_count, max_edge))
297+
raise Exception(
298+
"the number of edges of this kind of graph which has %d vertexes must be less than or equal to %d."
299+
% (point_count, max_edge))
298300

299301
weight_limit = kwargs.get("weight_limit", (1, 1))
300302
if not list_like(weight_limit):
@@ -349,9 +351,11 @@ def DAG(point_count, edge_count, **kwargs):
349351
repeated_edges = kwargs.get("repeated_edges", True)
350352
loop = kwargs.get("loop", False)
351353
if not repeated_edges:
352-
max_edge = Graph._calc_max_edge(point_count, not loop, self_loop)
354+
max_edge = Graph._calc_max_edge(point_count, not loop, self_loop)
353355
if edge_count > max_edge:
354-
raise Exception("the number of edges of this kind of graph which has %d vertexes must be less than or equal to %d." % (point_count, max_edge))
356+
raise Exception(
357+
"the number of edges of this kind of graph which has %d vertexes must be less than or equal to %d."
358+
% (point_count, max_edge))
355359

356360
weight_limit = kwargs.get("weight_limit", (1, 1))
357361
if not list_like(weight_limit):
@@ -418,9 +422,11 @@ def UDAG(point_count, edge_count, **kwargs):
418422
self_loop = kwargs.get("self_loop", True)
419423
repeated_edges = kwargs.get("repeated_edges", True)
420424
if not repeated_edges:
421-
max_edge = Graph._calc_max_edge(point_count, False, self_loop)
425+
max_edge = Graph._calc_max_edge(point_count, False, self_loop)
422426
if edge_count > max_edge:
423-
raise Exception("the number of edges of this kind of graph which has %d vertexes must be less than or equal to %d." % (point_count, max_edge))
427+
raise Exception(
428+
"the number of edges of this kind of graph which has %d vertexes must be less than or equal to %d."
429+
% (point_count, max_edge))
424430

425431
weight_limit = kwargs.get("weight_limit", (1, 1))
426432
if not list_like(weight_limit):
@@ -456,7 +462,7 @@ def UDAG(point_count, edge_count, **kwargs):
456462
i += 1
457463

458464
return graph
459-
465+
460466
@staticmethod
461467
def connected(point_count, edge_count, directed=False, **kwargs):
462468
"""connected(point_count, edge_count, **kwargs) -> Graph
@@ -519,7 +525,7 @@ def hack_spfa(point_count, **kwargs):
519525
graph.add_edge(u, v, weight=weight_gen())
520526

521527
return graph
522-
528+
523529
@staticmethod
524530
def _calc_max_edge(point_count, directed, self_loop):
525531
max_edge = point_count * (point_count - 1)
@@ -529,6 +535,22 @@ def _calc_max_edge(point_count, directed, self_loop):
529535
max_edge += point_count
530536
return max_edge
531537

538+
@staticmethod
539+
def forest(point_count, tree_count, **kwargs):
540+
if tree_count <= 0 or tree_count > point_count:
541+
raise ValueError("tree_count must be between 1 and point_count")
542+
tree = list(Graph.tree(point_count, **kwargs).iterate_edges())
543+
need_delete = set(
544+
i[0] for i in (Vector.random_unique_vector(tree_count - 1, [(
545+
0, point_count - 2)]) if tree_count > 1 else []))
546+
result = Graph(point_count, 0)
547+
for i in range(point_count - 1):
548+
if i not in need_delete:
549+
result.add_edge(tree[i].start,
550+
tree[i].end,
551+
weight=tree[i].weight)
552+
return result
553+
532554

533555
class GraphMatrix:
534556
"""
@@ -557,7 +579,8 @@ def __init__(self,
557579
self.matrix[edge.start][edge.end], edge)
558580

559581
def __str__(self):
560-
return '\n'.join([' '.join(map(str, row[1:])) for row in self.matrix[1:]])
582+
return '\n'.join(
583+
[' '.join(map(str, row[1:])) for row in self.matrix[1:]])
561584

562585
def __call__(self, u: int, v: int):
563586
return self.matrix[u][v]

cyaron/tests/graph_test.py

Lines changed: 90 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import unittest
22
from cyaron import Graph
3+
from random import randint
34

45

56
class UnionFindSet:
7+
68
def __init__(self, size):
79
self.father = [0] + [i + 1 for i in range(size)]
810

@@ -23,15 +25,17 @@ def test_same(self, l, r):
2325

2426

2527
def tarjan(graph, n):
28+
2629
def new_array(len, val=0):
27-
return [val for _ in range(len+1)]
30+
return [val for _ in range(len + 1)]
2831

2932
instack = new_array(n, False)
3033
low = new_array(n)
3134
dfn = new_array(n, 0)
3235
stap = new_array(n)
3336
belong = new_array(n)
34-
var = [0, 0, 0] # cnt, bc, stop
37+
var = [0, 0, 0] # cnt, bc, stop
38+
3539
# cnt = bc = stop = 0
3640

3741
def dfs(cur):
@@ -49,7 +53,7 @@ def dfs(cur):
4953
low[cur] = min(low[cur], dfn[v.end])
5054

5155
if dfn[cur] == low[cur]:
52-
v = cur + 1 # set v != cur
56+
v = cur + 1 # set v != cur
5357
var[1] += 1
5458
while v != cur:
5559
var[2] -= 1
@@ -58,8 +62,8 @@ def dfs(cur):
5862
belong[v] = var[1]
5963

6064
for i in range(n):
61-
if dfn[i+1] == 0:
62-
dfs(i+1)
65+
if dfn[i + 1] == 0:
66+
dfs(i + 1)
6367

6468
return belong
6569

@@ -69,28 +73,38 @@ class TestGraph(unittest.TestCase):
6973
def test_self_loop(self):
7074
graph_size = 20
7175
for _ in range(20):
72-
graph = Graph.graph(graph_size, int(graph_size*2), self_loop=True)
73-
has_self_loop = max([e.start == e.end for e in graph.iterate_edges()])
76+
graph = Graph.graph(graph_size,
77+
int(graph_size * 2),
78+
self_loop=True)
79+
has_self_loop = max(
80+
[e.start == e.end for e in graph.iterate_edges()])
7481
if has_self_loop:
7582
break
7683
self.assertTrue(has_self_loop)
7784

7885
for _ in range(10):
79-
graph = Graph.graph(graph_size, int(graph_size*2), self_loop=False)
80-
self.assertFalse(max([e.start == e.end for e in graph.iterate_edges()]))
86+
graph = Graph.graph(graph_size,
87+
int(graph_size * 2),
88+
self_loop=False)
89+
self.assertFalse(
90+
max([e.start == e.end for e in graph.iterate_edges()]))
8191

8292
def test_repeated_edges(self):
8393
graph_size = 20
8494
for _ in range(20):
85-
graph = Graph.graph(graph_size, int(graph_size*2), repeated_edges=True)
95+
graph = Graph.graph(graph_size,
96+
int(graph_size * 2),
97+
repeated_edges=True)
8698
edges = [(e.start, e.end) for e in graph.iterate_edges()]
8799
has_repeated_edges = len(edges) > len(set(edges))
88100
if has_repeated_edges:
89101
break
90102
self.assertTrue(has_repeated_edges)
91103

92104
for _ in range(10):
93-
graph = Graph.graph(graph_size, int(graph_size*2), repeated_edges=False)
105+
graph = Graph.graph(graph_size,
106+
int(graph_size * 2),
107+
repeated_edges=False)
94108
edges = [(e.start, e.end) for e in graph.iterate_edges()]
95109
self.assertEqual(len(edges), len(set(edges)))
96110

@@ -101,60 +115,78 @@ def test_tree_connected(self):
101115
tree = Graph.tree(graph_size)
102116
for edge in tree.iterate_edges():
103117
ufs.merge(edge.start, edge.end)
104-
for i in range(graph_size-1):
105-
self.assertTrue(ufs.test_same(i+1, i+2))
106-
118+
for i in range(graph_size - 1):
119+
self.assertTrue(ufs.test_same(i + 1, i + 2))
107120

108121
def test_DAG(self):
109122
graph_size = 20
110-
for _ in range(10): # test 10 times
123+
for _ in range(10): # test 10 times
111124
ufs = UnionFindSet(graph_size)
112-
graph = Graph.DAG(graph_size, int(graph_size*1.6), repeated_edges=False, self_loop=False, loop=True)
125+
graph = Graph.DAG(graph_size,
126+
int(graph_size * 1.6),
127+
repeated_edges=False,
128+
self_loop=False,
129+
loop=True)
113130

114-
self.assertEqual(len(list(graph.iterate_edges())), int(graph_size*1.6))
131+
self.assertEqual(len(list(graph.iterate_edges())),
132+
int(graph_size * 1.6))
115133

116134
for edge in graph.iterate_edges():
117135
ufs.merge(edge.start, edge.end)
118-
for i in range(graph_size-1):
119-
self.assertTrue(ufs.test_same(i+1, i+2))
136+
for i in range(graph_size - 1):
137+
self.assertTrue(ufs.test_same(i + 1, i + 2))
120138

121139
def test_DAG_without_loop(self):
122140
graph_size = 20
123-
for _ in range(10): # test 10 times
141+
for _ in range(10): # test 10 times
124142
ufs = UnionFindSet(graph_size)
125-
graph = Graph.DAG(graph_size, int(graph_size*1.6), repeated_edges=False, self_loop=False, loop=False)
143+
graph = Graph.DAG(graph_size,
144+
int(graph_size * 1.6),
145+
repeated_edges=False,
146+
self_loop=False,
147+
loop=False)
126148

127-
self.assertEqual(len(list(graph.iterate_edges())), int(graph_size*1.6))
149+
self.assertEqual(len(list(graph.iterate_edges())),
150+
int(graph_size * 1.6))
128151

129152
for edge in graph.iterate_edges():
130153
ufs.merge(edge.start, edge.end)
131-
for i in range(graph_size-1):
132-
self.assertTrue(ufs.test_same(i+1, i+2))
154+
for i in range(graph_size - 1):
155+
self.assertTrue(ufs.test_same(i + 1, i + 2))
133156

134157
belong = tarjan(graph, graph_size)
135158
self.assertEqual(max(belong), graph_size)
136159

137160
def test_undirected_graph(self):
138161
graph_size = 20
139-
for _ in range(10): # test 10 times
162+
for _ in range(10): # test 10 times
140163
ufs = UnionFindSet(graph_size)
141-
graph = Graph.UDAG(graph_size, int(graph_size*1.6), repeated_edges=False, self_loop=False)
164+
graph = Graph.UDAG(graph_size,
165+
int(graph_size * 1.6),
166+
repeated_edges=False,
167+
self_loop=False)
142168

143-
self.assertEqual(len(list(graph.iterate_edges())), int(graph_size*1.6))
169+
self.assertEqual(len(list(graph.iterate_edges())),
170+
int(graph_size * 1.6))
144171

145172
for edge in graph.iterate_edges():
146173
ufs.merge(edge.start, edge.end)
147-
for i in range(graph_size-1):
148-
self.assertTrue(ufs.test_same(i+1, i+2))
174+
for i in range(graph_size - 1):
175+
self.assertTrue(ufs.test_same(i + 1, i + 2))
149176

150177
def test_DAG_boundary(self):
151-
with self.assertRaises(Exception, msg="the number of edges of connected graph must more than the number of nodes - 1"):
178+
with self.assertRaises(
179+
Exception,
180+
msg=
181+
"the number of edges of connected graph must more than the number of nodes - 1"
182+
):
152183
Graph.DAG(8, 6)
153184
Graph.DAG(8, 7)
154185

155186
def test_GraphMatrix(self):
156187
g = Graph(3, True)
157-
edge_set = [(2, 3, 3), (3, 3, 1), (2, 3, 7), (2, 3, 4), (3, 2, 1), (1, 3, 3)]
188+
edge_set = [(2, 3, 3), (3, 3, 1), (2, 3, 7), (2, 3, 4), (3, 2, 1),
189+
(1, 3, 3)]
158190
for u, v, w in edge_set:
159191
g.add_edge(u, v, weight=w)
160192
self.assertEqual(str(g.to_matrix()), "-1 -1 3\n-1 -1 4\n-1 1 1")
@@ -166,9 +198,30 @@ def test_GraphMatrix(self):
166198
merge2 = lambda val, edge: max(edge.weight, val)
167199
merge3 = lambda val, edge: min(edge.weight, val)
168200
merge4 = lambda val, edge: gcd(val, edge.weight)
169-
merge5 = lambda val, edge: lcm(val, edge.weight) if val else edge.weight
170-
self.assertEqual(str(g.to_matrix(merge=merge1)), "-1 -1 3\n-1 -1 3\n-1 1 1")
171-
self.assertEqual(str(g.to_matrix(merge=merge2)), "-1 -1 3\n-1 -1 7\n-1 1 1")
172-
self.assertEqual(str(g.to_matrix(default=9, merge=merge3)), "9 9 3\n9 9 3\n9 1 1")
173-
self.assertEqual(str(g.to_matrix(default=0, merge=merge4)), "0 0 3\n0 0 1\n0 1 1")
174-
self.assertEqual(str(g.to_matrix(default=0, merge=merge5)), "0 0 3\n0 0 84\n0 1 1")
201+
merge5 = lambda val, edge: lcm(val, edge.weight
202+
) if val else edge.weight
203+
self.assertEqual(str(g.to_matrix(merge=merge1)),
204+
"-1 -1 3\n-1 -1 3\n-1 1 1")
205+
self.assertEqual(str(g.to_matrix(merge=merge2)),
206+
"-1 -1 3\n-1 -1 7\n-1 1 1")
207+
self.assertEqual(str(g.to_matrix(default=9, merge=merge3)),
208+
"9 9 3\n9 9 3\n9 1 1")
209+
self.assertEqual(str(g.to_matrix(default=0, merge=merge4)),
210+
"0 0 3\n0 0 1\n0 1 1")
211+
self.assertEqual(str(g.to_matrix(default=0, merge=merge5)),
212+
"0 0 3\n0 0 84\n0 1 1")
213+
214+
def test_forest(self):
215+
for i in range(10):
216+
size = randint(1, 100)
217+
part_count = randint(1, size)
218+
forest = Graph.forest(size, part_count)
219+
dsu = UnionFindSet(size)
220+
for edge in forest.iterate_edges():
221+
self.assertFalse(dsu.test_same(edge.start, edge.end))
222+
dsu.merge(edge.start, edge.end)
223+
count = 0
224+
for i in range(1, size + 1):
225+
if dsu.get_father(i) == i:
226+
count += 1
227+
self.assertEqual(count, part_count)

0 commit comments

Comments
 (0)