1717
1818package org .apache .hugegraph .computer .algorithm .sampling ;
1919
20+ import java .util .ArrayList ;
21+ import java .util .Iterator ;
22+ import java .util .List ;
23+ import java .util .Random ;
24+
2025import org .apache .hugegraph .computer .core .common .exception .ComputerException ;
2126import org .apache .hugegraph .computer .core .config .Config ;
2227import org .apache .hugegraph .computer .core .graph .edge .Edge ;
2328import org .apache .hugegraph .computer .core .graph .edge .Edges ;
2429import org .apache .hugegraph .computer .core .graph .id .Id ;
30+ import org .apache .hugegraph .computer .core .graph .value .DoubleValue ;
2531import org .apache .hugegraph .computer .core .graph .value .IdList ;
2632import org .apache .hugegraph .computer .core .graph .value .IdListList ;
33+ import org .apache .hugegraph .computer .core .graph .value .Value ;
2734import org .apache .hugegraph .computer .core .graph .vertex .Vertex ;
2835import org .apache .hugegraph .computer .core .worker .Computation ;
2936import org .apache .hugegraph .computer .core .worker .ComputationContext ;
3037import org .apache .hugegraph .util .Log ;
3138import org .slf4j .Logger ;
3239
33- import java .util .Iterator ;
34- import java .util .Random ;
35-
3640public class RandomWalk implements Computation <RandomWalkMessage > {
3741
3842 private static final Logger LOG = Log .logger (RandomWalk .class );
3943
40- public static final String OPTION_WALK_PER_NODE = "randomwalk.walk_per_node" ;
41- public static final String OPTION_WALK_LENGTH = "randomwalk.walk_length" ;
44+ public static final String OPTION_WALK_PER_NODE = "random_walk.walk_per_node" ;
45+ public static final String OPTION_WALK_LENGTH = "random_walk.walk_length" ;
46+
47+ public static final String OPTION_WEIGHT_PROPERTY = "random_walk.weight_property" ;
48+ public static final String OPTION_DEFAULT_WEIGHT = "random_walk.default_weight" ;
49+ public static final String OPTION_MIN_WEIGHT_THRESHOLD = "random_walk.min_weight_threshold" ;
50+ public static final String OPTION_MAX_WEIGHT_THRESHOLD = "random_walk.max_weight_threshold" ;
51+
52+ public static final String OPTION_RETURN_FACTOR = "random_walk.return_factor" ;
53+ public static final String OPTION_INOUT_FACTOR = "random_walk.inout_factor" ;
4254
4355 /**
44- * number of times per vertex(source vertex) walks
56+ * Random
57+ */
58+ private Random random ;
59+
60+ /**
61+ * Number of times per vertex(source vertex) walks
4562 */
4663 private Integer walkPerNode ;
4764
4865 /**
49- * walk length
66+ * Walk length
5067 */
5168 private Integer walkLength ;
5269
5370 /**
54- * random
71+ * Weight property, related to the walking probability
5572 */
56- private Random random ;
73+ private String weightProperty ;
74+
75+ /**
76+ * Default weight of an edge that has no weight property
77+ * Default 1
78+ */
79+ private Double defaultWeight ;
80+
81+ /**
82+ * Weight less than this threshold will be truncated.
83+ * Default 0
84+ */
85+ private Double minWeightThreshold ;
86+
87+ /**
88+ * Weight greater than this threshold will be truncated.
89+ * Default Double.MAX_VALUE
90+ */
91+ private Double maxWeightThreshold ;
92+
93+ /**
94+ * Controls the probability of re-walk to a previously walked vertex.
95+ * Default 1
96+ */
97+ private Double returnFactor ;
98+
99+ /**
100+ * Controls whether to walk inward or outward.
101+ * Default 1
102+ */
103+ private Double inOutFactor ;
57104
58105 @ Override
59106 public String category () {
@@ -67,23 +114,63 @@ public String name() {
67114
68115 @ Override
69116 public void init (Config config ) {
117+ this .random = new Random ();
118+
70119 this .walkPerNode = config .getInt (OPTION_WALK_PER_NODE , 3 );
71120 if (this .walkPerNode <= 0 ) {
72121 throw new ComputerException ("The param %s must be greater than 0, " +
73- "actual got '%s'" ,
74- OPTION_WALK_PER_NODE , this .walkPerNode );
122+ "actual got '%s'" ,
123+ OPTION_WALK_PER_NODE , this .walkPerNode );
75124 }
76- LOG .info ("[RandomWalk] algorithm param, {}: {}" , OPTION_WALK_PER_NODE , walkPerNode );
77125
78126 this .walkLength = config .getInt (OPTION_WALK_LENGTH , 3 );
79127 if (this .walkLength <= 0 ) {
80128 throw new ComputerException ("The param %s must be greater than 0, " +
81- "actual got '%s'" ,
82- OPTION_WALK_LENGTH , this .walkLength );
129+ "actual got '%s'" ,
130+ OPTION_WALK_LENGTH , this .walkLength );
83131 }
84- LOG .info ("[RandomWalk] algorithm param, {}: {}" , OPTION_WALK_LENGTH , walkLength );
85132
86- this .random = new Random ();
133+ this .weightProperty = config .getString (OPTION_WEIGHT_PROPERTY , "" );
134+
135+ this .defaultWeight = config .getDouble (OPTION_DEFAULT_WEIGHT , 1 );
136+ if (this .defaultWeight <= 0 ) {
137+ throw new ComputerException ("The param %s must be greater than 0, " +
138+ "actual got '%s'" ,
139+ OPTION_DEFAULT_WEIGHT , this .defaultWeight );
140+ }
141+
142+ this .minWeightThreshold = config .getDouble (OPTION_MIN_WEIGHT_THRESHOLD , 0.0 );
143+ if (this .minWeightThreshold < 0 ) {
144+ throw new ComputerException ("The param %s must be greater than or equal 0, " +
145+ "actual got '%s'" ,
146+ OPTION_MIN_WEIGHT_THRESHOLD , this .minWeightThreshold );
147+ }
148+
149+ this .maxWeightThreshold = config .getDouble (OPTION_MAX_WEIGHT_THRESHOLD , Double .MAX_VALUE );
150+ if (this .maxWeightThreshold < 0 ) {
151+ throw new ComputerException ("The param %s must be greater than or equal 0, " +
152+ "actual got '%s'" ,
153+ OPTION_MAX_WEIGHT_THRESHOLD , this .maxWeightThreshold );
154+ }
155+
156+ if (this .minWeightThreshold > this .maxWeightThreshold ) {
157+ throw new ComputerException ("%s must be greater than or equal %s, " ,
158+ OPTION_MAX_WEIGHT_THRESHOLD , OPTION_MIN_WEIGHT_THRESHOLD );
159+ }
160+
161+ this .returnFactor = config .getDouble (OPTION_RETURN_FACTOR , 1 );
162+ if (this .returnFactor <= 0 ) {
163+ throw new ComputerException ("The param %s must be greater than 0, " +
164+ "actual got '%s'" ,
165+ OPTION_RETURN_FACTOR , this .returnFactor );
166+ }
167+
168+ this .inOutFactor = config .getDouble (OPTION_INOUT_FACTOR , 1 );
169+ if (this .inOutFactor <= 0 ) {
170+ throw new ComputerException ("The param %s must be greater than 0, " +
171+ "actual got '%s'" ,
172+ OPTION_INOUT_FACTOR , this .inOutFactor );
173+ }
87174 }
88175
89176 @ Override
@@ -95,14 +182,16 @@ public void compute0(ComputationContext context, Vertex vertex) {
95182
96183 if (vertex .numEdges () <= 0 ) {
97184 // isolated vertex
98- this .savePath (vertex , message .path ()); // save result
185+ this .savePath (vertex , message .path ());
99186 vertex .inactivate ();
100187 return ;
101188 }
102189
190+ vertex .edges ().forEach (edge -> message .addToPreVertexAdjacence (edge .targetId ()));
191+
103192 for (int i = 0 ; i < walkPerNode ; ++i ) {
104193 // random select one edge and walk
105- Edge selectedEdge = this .randomSelectEdge (vertex .edges ());
194+ Edge selectedEdge = this .randomSelectEdge (null , null , vertex .edges ());
106195 context .sendMessage (selectedEdge .targetId (), message );
107196 }
108197 }
@@ -112,9 +201,11 @@ public void compute(ComputationContext context, Vertex vertex,
112201 Iterator <RandomWalkMessage > messages ) {
113202 while (messages .hasNext ()) {
114203 RandomWalkMessage message = messages .next ();
204+ // the last id of path is the previous id
205+ Id preVertexId = message .path ().getLast ();
115206
116207 if (message .isFinish ()) {
117- this .savePath (vertex , message .path ()); // save result
208+ this .savePath (vertex , message .path ());
118209
119210 vertex .inactivate ();
120211 continue ;
@@ -123,7 +214,7 @@ public void compute(ComputationContext context, Vertex vertex,
123214 message .addToPath (vertex );
124215
125216 if (vertex .numEdges () <= 0 ) {
126- // there is nowhere to walk, finish eariler
217+ // there is nowhere to walk, finish earlier
127218 message .finish ();
128219 context .sendMessage (this .getSourceId (message .path ()), message );
129220
@@ -137,7 +228,7 @@ public void compute(ComputationContext context, Vertex vertex,
137228
138229 if (vertex .id ().equals (sourceId )) {
139230 // current vertex is the source vertex,no need to send message once more
140- this .savePath (vertex , message .path ()); // save result
231+ this .savePath (vertex , message .path ());
141232 } else {
142233 context .sendMessage (sourceId , message );
143234 }
@@ -146,29 +237,133 @@ public void compute(ComputationContext context, Vertex vertex,
146237 continue ;
147238 }
148239
240+ vertex .edges ().forEach (edge -> message .addToPreVertexAdjacence (edge .targetId ()));
241+
149242 // random select one edge and walk
150- Edge selectedEdge = this .randomSelectEdge (vertex .edges ());
243+ Edge selectedEdge = this .randomSelectEdge (preVertexId , message .preVertexAdjacence (),
244+ vertex .edges ());
151245 context .sendMessage (selectedEdge .targetId (), message );
152246 }
153247 }
154248
155249 /**
156250 * random select one edge
157251 */
158- private Edge randomSelectEdge (Edges edges ) {
159- Edge selectedEdge = null ;
160- int randomNum = random .nextInt (edges .size ());
252+ private Edge randomSelectEdge (Id preVertexId , IdList preVertexAdjacenceIdList , Edges edges ) {
253+ // TODO: use primitive array instead, like DoubleArray,
254+ // in order to reduce memory fragmentation generated during calculations
255+ List <Double > weightList = new ArrayList <>();
161256
162- int i = 0 ;
163257 Iterator <Edge > iterator = edges .iterator ();
164258 while (iterator .hasNext ()) {
165- selectedEdge = iterator .next ();
166- if (i == randomNum ) {
259+ Edge edge = iterator .next ();
260+ // calculate edge weight
261+ double weight = this .getEdgeWeight (edge );
262+ double finalWeight = this .calculateEdgeWeight (preVertexId , preVertexAdjacenceIdList ,
263+ edge .targetId (), weight );
264+ // TODO: improve to avoid OOM
265+ weightList .add (finalWeight );
266+ }
267+
268+ int selectedIndex = this .randomSelectIndex (weightList );
269+ Edge selectedEdge = this .selectEdge (edges .iterator (), selectedIndex );
270+ return selectedEdge ;
271+ }
272+
273+ /**
274+ * get the weight of an edge by its weight property
275+ */
276+ private double getEdgeWeight (Edge edge ) {
277+ double weight = this .defaultWeight ;
278+
279+ Value property = edge .property (this .weightProperty );
280+ if (property != null ) {
281+ if (!property .isNumber ()) {
282+ throw new ComputerException ("The value of %s must be a numeric value, " +
283+ "actual got '%s'" ,
284+ this .weightProperty , property .string ());
285+ }
286+
287+ weight = ((DoubleValue ) property ).doubleValue ();
288+ }
289+
290+ // weight threshold truncation
291+ if (weight < this .minWeightThreshold ) {
292+ weight = this .minWeightThreshold ;
293+ }
294+ if (weight > this .maxWeightThreshold ) {
295+ weight = this .maxWeightThreshold ;
296+ }
297+ return weight ;
298+ }
299+
300+ /**
301+ * calculate edge weight
302+ */
303+ private double calculateEdgeWeight (Id preVertexId , IdList preVertexAdjacenceIdList ,
304+ Id nextVertexId , double weight ) {
305+ /*
306+ * 3 types of vertices.
307+ * 1. current vertex, called v
308+ * 2. previous vertex, called t
309+ * 3. current vertex outer vertex, called x(x1, x2.. xn)
310+ *
311+ * Definition of weight correction coefficient α:
312+ * if distance(t, x) = 0, then α = 1.0 / returnFactor
313+ * if distance(t, x) = 1, then α = 1.0
314+ * if distance(t, x) = 2, then α = 1.0 / inOutFactor
315+ *
316+ * Final edge weight π(v, x) = α * edgeWeight
317+ */
318+ double finalWeight = 0.0 ;
319+ if (preVertexId != null && preVertexId .equals (nextVertexId )) {
320+ // distance(t, x) = 0
321+ finalWeight = 1.0 / this .returnFactor * weight ;
322+ } else if (preVertexAdjacenceIdList != null &&
323+ preVertexAdjacenceIdList .contains (nextVertexId )) {
324+ // distance(t, x) = 1
325+ finalWeight = 1.0 * weight ;
326+ } else {
327+ // distance(t, x) = 2
328+ finalWeight = 1.0 / this .inOutFactor * weight ;
329+ }
330+ return finalWeight ;
331+ }
332+
333+ /**
334+ * random select index
335+ */
336+ private int randomSelectIndex (List <Double > weightList ) {
337+ int selectedIndex = 0 ;
338+ double totalWeight = weightList .stream ().mapToDouble (Double ::doubleValue ).sum ();
339+ double randomNum = random .nextDouble () * totalWeight ; // [0, totalWeight)
340+
341+ // determine which interval the random number falls into
342+ double cumulativeWeight = 0 ;
343+ for (int i = 0 ; i < weightList .size (); ++i ) {
344+ cumulativeWeight += weightList .get (i );
345+ if (randomNum < cumulativeWeight ) {
346+ selectedIndex = i ;
167347 break ;
168348 }
169- i ++;
170349 }
350+ return selectedIndex ;
351+ }
352+
353+ /**
354+ * select edge from iterator by index
355+ */
356+ private Edge selectEdge (Iterator <Edge > iterator , int selectedIndex ) {
357+ Edge selectedEdge = null ;
171358
359+ int index = 0 ;
360+ while (iterator .hasNext ()) {
361+ selectedEdge = iterator .next ();
362+ if (index == selectedIndex ) {
363+ break ;
364+ }
365+ index ++;
366+ }
172367 return selectedEdge ;
173368 }
174369
0 commit comments