Skip to content

Commit 2be0c28

Browse files
diaohancaidiaohancaiimbajin
authored
feat(algorithm): support biased second order random walk (#280)
- implement #279 - follow-up #274 (V1 version) The current random walk algorithm requires 2 additional features. - Biased random walk. - Second order random walk. --------- Co-authored-by: diaohancai <diaohancai@cvte.com> Co-authored-by: imbajin <jin@apache.org>
1 parent 9d4d276 commit 2be0c28

File tree

5 files changed

+317
-72
lines changed

5 files changed

+317
-72
lines changed

computer-algorithm/src/main/java/org/apache/hugegraph/computer/algorithm/sampling/RandomWalk.java

Lines changed: 224 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,43 +17,90 @@
1717

1818
package org.apache.hugegraph.computer.algorithm.sampling;
1919

20+
import java.util.ArrayList;
21+
import java.util.Iterator;
22+
import java.util.List;
23+
import java.util.Random;
24+
2025
import org.apache.hugegraph.computer.core.common.exception.ComputerException;
2126
import org.apache.hugegraph.computer.core.config.Config;
2227
import org.apache.hugegraph.computer.core.graph.edge.Edge;
2328
import org.apache.hugegraph.computer.core.graph.edge.Edges;
2429
import org.apache.hugegraph.computer.core.graph.id.Id;
30+
import org.apache.hugegraph.computer.core.graph.value.DoubleValue;
2531
import org.apache.hugegraph.computer.core.graph.value.IdList;
2632
import org.apache.hugegraph.computer.core.graph.value.IdListList;
33+
import org.apache.hugegraph.computer.core.graph.value.Value;
2734
import org.apache.hugegraph.computer.core.graph.vertex.Vertex;
2835
import org.apache.hugegraph.computer.core.worker.Computation;
2936
import org.apache.hugegraph.computer.core.worker.ComputationContext;
3037
import org.apache.hugegraph.util.Log;
3138
import org.slf4j.Logger;
3239

33-
import java.util.Iterator;
34-
import java.util.Random;
35-
3640
public class RandomWalk implements Computation<RandomWalkMessage> {
3741

3842
private static final Logger LOG = Log.logger(RandomWalk.class);
3943

40-
public static final String OPTION_WALK_PER_NODE = "randomwalk.walk_per_node";
41-
public static final String OPTION_WALK_LENGTH = "randomwalk.walk_length";
44+
public static final String OPTION_WALK_PER_NODE = "random_walk.walk_per_node";
45+
public static final String OPTION_WALK_LENGTH = "random_walk.walk_length";
46+
47+
public static final String OPTION_WEIGHT_PROPERTY = "random_walk.weight_property";
48+
public static final String OPTION_DEFAULT_WEIGHT = "random_walk.default_weight";
49+
public static final String OPTION_MIN_WEIGHT_THRESHOLD = "random_walk.min_weight_threshold";
50+
public static final String OPTION_MAX_WEIGHT_THRESHOLD = "random_walk.max_weight_threshold";
51+
52+
public static final String OPTION_RETURN_FACTOR = "random_walk.return_factor";
53+
public static final String OPTION_INOUT_FACTOR = "random_walk.inout_factor";
4254

4355
/**
44-
* number of times per vertex(source vertex) walks
56+
* Random number generator used to pick the next edge of a walk
57+
*/
58+
private Random random;
59+
60+
/**
61+
* Number of times per vertex(source vertex) walks
4562
*/
4663
private Integer walkPerNode;
4764

4865
/**
49-
* walk length
66+
* Walk length
5067
*/
5168
private Integer walkLength;
5269

5370
/**
54-
* random
71+
* Weight property, related to the walking probability
5572
*/
56-
private Random random;
73+
private String weightProperty;
74+
75+
/**
76+
* Default edge weight for the biased walk, used when an edge has no weight property.
77+
* Default 1
78+
*/
79+
private Double defaultWeight;
80+
81+
/**
82+
* Weights less than this threshold are raised to this threshold.
83+
* Default 0
84+
*/
85+
private Double minWeightThreshold;
86+
87+
/**
88+
* Weights greater than this threshold are lowered to this threshold.
89+
* Default Double.MAX_VALUE
90+
*/
91+
private Double maxWeightThreshold;
92+
93+
/**
94+
* Controls the probability of re-walk to a previously walked vertex.
95+
* Default 1
96+
*/
97+
private Double returnFactor;
98+
99+
/**
100+
* Controls whether to walk inward or outward.
101+
* Default 1
102+
*/
103+
private Double inOutFactor;
57104

58105
@Override
59106
public String category() {
@@ -67,23 +114,63 @@ public String name() {
67114

68115
@Override
69116
public void init(Config config) {
117+
this.random = new Random();
118+
70119
this.walkPerNode = config.getInt(OPTION_WALK_PER_NODE, 3);
71120
if (this.walkPerNode <= 0) {
72121
throw new ComputerException("The param %s must be greater than 0, " +
73-
"actual got '%s'",
74-
OPTION_WALK_PER_NODE, this.walkPerNode);
122+
"actual got '%s'",
123+
OPTION_WALK_PER_NODE, this.walkPerNode);
75124
}
76-
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_PER_NODE, walkPerNode);
77125

78126
this.walkLength = config.getInt(OPTION_WALK_LENGTH, 3);
79127
if (this.walkLength <= 0) {
80128
throw new ComputerException("The param %s must be greater than 0, " +
81-
"actual got '%s'",
82-
OPTION_WALK_LENGTH, this.walkLength);
129+
"actual got '%s'",
130+
OPTION_WALK_LENGTH, this.walkLength);
83131
}
84-
LOG.info("[RandomWalk] algorithm param, {}: {}", OPTION_WALK_LENGTH, walkLength);
85132

86-
this.random = new Random();
133+
this.weightProperty = config.getString(OPTION_WEIGHT_PROPERTY, "");
134+
135+
this.defaultWeight = config.getDouble(OPTION_DEFAULT_WEIGHT, 1);
136+
if (this.defaultWeight <= 0) {
137+
throw new ComputerException("The param %s must be greater than 0, " +
138+
"actual got '%s'",
139+
OPTION_DEFAULT_WEIGHT, this.defaultWeight);
140+
}
141+
142+
this.minWeightThreshold = config.getDouble(OPTION_MIN_WEIGHT_THRESHOLD, 0.0);
143+
if (this.minWeightThreshold < 0) {
144+
throw new ComputerException("The param %s must be greater than or equal 0, " +
145+
"actual got '%s'",
146+
OPTION_MIN_WEIGHT_THRESHOLD, this.minWeightThreshold);
147+
}
148+
149+
this.maxWeightThreshold = config.getDouble(OPTION_MAX_WEIGHT_THRESHOLD, Double.MAX_VALUE);
150+
if (this.maxWeightThreshold < 0) {
151+
throw new ComputerException("The param %s must be greater than or equal 0, " +
152+
"actual got '%s'",
153+
OPTION_MAX_WEIGHT_THRESHOLD, this.maxWeightThreshold);
154+
}
155+
156+
if (this.minWeightThreshold > this.maxWeightThreshold) {
157+
throw new ComputerException("%s must be greater than or equal %s, ",
158+
OPTION_MAX_WEIGHT_THRESHOLD, OPTION_MIN_WEIGHT_THRESHOLD);
159+
}
160+
161+
this.returnFactor = config.getDouble(OPTION_RETURN_FACTOR, 1);
162+
if (this.returnFactor <= 0) {
163+
throw new ComputerException("The param %s must be greater than 0, " +
164+
"actual got '%s'",
165+
OPTION_RETURN_FACTOR, this.returnFactor);
166+
}
167+
168+
this.inOutFactor = config.getDouble(OPTION_INOUT_FACTOR, 1);
169+
if (this.inOutFactor <= 0) {
170+
throw new ComputerException("The param %s must be greater than 0, " +
171+
"actual got '%s'",
172+
OPTION_INOUT_FACTOR, this.inOutFactor);
173+
}
87174
}
88175

89176
@Override
@@ -95,14 +182,16 @@ public void compute0(ComputationContext context, Vertex vertex) {
95182

96183
if (vertex.numEdges() <= 0) {
97184
// isolated vertex
98-
this.savePath(vertex, message.path()); // save result
185+
this.savePath(vertex, message.path());
99186
vertex.inactivate();
100187
return;
101188
}
102189

190+
vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId()));
191+
103192
for (int i = 0; i < walkPerNode; ++i) {
104193
// random select one edge and walk
105-
Edge selectedEdge = this.randomSelectEdge(vertex.edges());
194+
Edge selectedEdge = this.randomSelectEdge(null, null, vertex.edges());
106195
context.sendMessage(selectedEdge.targetId(), message);
107196
}
108197
}
@@ -112,9 +201,11 @@ public void compute(ComputationContext context, Vertex vertex,
112201
Iterator<RandomWalkMessage> messages) {
113202
while (messages.hasNext()) {
114203
RandomWalkMessage message = messages.next();
204+
// the last id of path is the previous id
205+
Id preVertexId = message.path().getLast();
115206

116207
if (message.isFinish()) {
117-
this.savePath(vertex, message.path()); // save result
208+
this.savePath(vertex, message.path());
118209

119210
vertex.inactivate();
120211
continue;
@@ -123,7 +214,7 @@ public void compute(ComputationContext context, Vertex vertex,
123214
message.addToPath(vertex);
124215

125216
if (vertex.numEdges() <= 0) {
126-
// there is nowhere to walkfinish eariler
217+
// there is nowhere to walk, finish earlier
127218
message.finish();
128219
context.sendMessage(this.getSourceId(message.path()), message);
129220

@@ -137,7 +228,7 @@ public void compute(ComputationContext context, Vertex vertex,
137228

138229
if (vertex.id().equals(sourceId)) {
139230
// current vertex is the source vertex, no need to send message once more
140-
this.savePath(vertex, message.path()); // save result
231+
this.savePath(vertex, message.path());
141232
} else {
142233
context.sendMessage(sourceId, message);
143234
}
@@ -146,29 +237,133 @@ public void compute(ComputationContext context, Vertex vertex,
146237
continue;
147238
}
148239

240+
vertex.edges().forEach(edge -> message.addToPreVertexAdjacence(edge.targetId()));
241+
149242
// random select one edge and walk
150-
Edge selectedEdge = this.randomSelectEdge(vertex.edges());
243+
Edge selectedEdge = this.randomSelectEdge(preVertexId, message.preVertexAdjacence(),
244+
vertex.edges());
151245
context.sendMessage(selectedEdge.targetId(), message);
152246
}
153247
}
154248

155249
/**
156250
* Randomly select one edge, with probability proportional to its final weight
157251
*/
158-
private Edge randomSelectEdge(Edges edges) {
159-
Edge selectedEdge = null;
160-
int randomNum = random.nextInt(edges.size());
252+
private Edge randomSelectEdge(Id preVertexId, IdList preVertexAdjacenceIdList, Edges edges) {
253+
// TODO: use primitive array instead, like DoubleArray,
254+
// in order to reduce memory fragmentation generated during calculations
255+
List<Double> weightList = new ArrayList<>();
161256

162-
int i = 0;
163257
Iterator<Edge> iterator = edges.iterator();
164258
while (iterator.hasNext()) {
165-
selectedEdge = iterator.next();
166-
if (i == randomNum) {
259+
Edge edge = iterator.next();
260+
// calculate edge weight
261+
double weight = this.getEdgeWeight(edge);
262+
double finalWeight = this.calculateEdgeWeight(preVertexId, preVertexAdjacenceIdList,
263+
edge.targetId(), weight);
264+
// TODO: improve to avoid OOM
265+
weightList.add(finalWeight);
266+
}
267+
268+
int selectedIndex = this.randomSelectIndex(weightList);
269+
Edge selectedEdge = this.selectEdge(edges.iterator(), selectedIndex);
270+
return selectedEdge;
271+
}
272+
273+
/**
274+
* get the weight of an edge by its weight property
275+
*/
276+
private double getEdgeWeight(Edge edge) {
277+
double weight = this.defaultWeight;
278+
279+
Value property = edge.property(this.weightProperty);
280+
if (property != null) {
281+
if (!property.isNumber()) {
282+
throw new ComputerException("The value of %s must be a numeric value, " +
283+
"actual got '%s'",
284+
this.weightProperty, property.string());
285+
}
286+
287+
weight = ((DoubleValue) property).doubleValue();
288+
}
289+
290+
// weight threshold truncation
291+
if (weight < this.minWeightThreshold) {
292+
weight = this.minWeightThreshold;
293+
}
294+
if (weight > this.maxWeightThreshold) {
295+
weight = this.maxWeightThreshold;
296+
}
297+
return weight;
298+
}
299+
300+
/**
301+
* calculate edge weight
302+
*/
303+
private double calculateEdgeWeight(Id preVertexId, IdList preVertexAdjacenceIdList,
304+
Id nextVertexId, double weight) {
305+
/*
306+
* 3 types of vertices.
307+
* 1. current vertex, called v
308+
* 2. previous vertex, called t
309+
* 3. current vertex outer vertex, called x(x1, x2.. xn)
310+
*
311+
* Definition of weight correction coefficient α:
312+
* if distance(t, x) = 0, then α = 1.0 / returnFactor
313+
* if distance(t, x) = 1, then α = 1.0
314+
* if distance(t, x) = 2, then α = 1.0 / inOutFactor
315+
*
316+
* Final edge weight π(v, x) = α * edgeWeight
317+
*/
318+
double finalWeight = 0.0;
319+
if (preVertexId != null && preVertexId.equals(nextVertexId)) {
320+
// distance(t, x) = 0
321+
finalWeight = 1.0 / this.returnFactor * weight;
322+
} else if (preVertexAdjacenceIdList != null &&
323+
preVertexAdjacenceIdList.contains(nextVertexId)) {
324+
// distance(t, x) = 1
325+
finalWeight = 1.0 * weight;
326+
} else {
327+
// distance(t, x) = 2
328+
finalWeight = 1.0 / this.inOutFactor * weight;
329+
}
330+
return finalWeight;
331+
}
332+
333+
/**
334+
* random select index
335+
*/
336+
private int randomSelectIndex(List<Double> weightList) {
337+
int selectedIndex = 0;
338+
double totalWeight = weightList.stream().mapToDouble(Double::doubleValue).sum();
339+
double randomNum = random.nextDouble() * totalWeight; // [0, totalWeight)
340+
341+
// determine which interval the random number falls into
342+
double cumulativeWeight = 0;
343+
for (int i = 0; i < weightList.size(); ++i) {
344+
cumulativeWeight += weightList.get(i);
345+
if (randomNum < cumulativeWeight) {
346+
selectedIndex = i;
167347
break;
168348
}
169-
i++;
170349
}
350+
return selectedIndex;
351+
}
352+
353+
/**
354+
* select edge from iterator by index
355+
*/
356+
private Edge selectEdge(Iterator<Edge> iterator, int selectedIndex) {
357+
Edge selectedEdge = null;
171358

359+
int index = 0;
360+
while (iterator.hasNext()) {
361+
selectedEdge = iterator.next();
362+
if (index == selectedIndex) {
363+
break;
364+
}
365+
index++;
366+
}
172367
return selectedEdge;
173368
}
174369

0 commit comments

Comments
 (0)