From 8552899faaea9e84b776fca854849e0bda9a00ea Mon Sep 17 00:00:00 2001 From: Lauren Shin Date: Mon, 28 May 2018 14:12:03 -0700 Subject: [PATCH 1/7] added linear regression --- pom.xml | 29 +++++- src/main/java/regression/LR.java | 122 ++++++++++++++++++++++++++ src/main/java/regression/LRModel.java | 122 ++++++++++++++++++++++++++ src/test/java/regression/LRTest.java | 113 ++++++++++++++++++++++++ 4 files changed, 384 insertions(+), 2 deletions(-) create mode 100644 src/main/java/regression/LR.java create mode 100644 src/main/java/regression/LRModel.java create mode 100644 src/test/java/regression/LRTest.java diff --git a/pom.xml b/pom.xml index 5f84031..ad79085 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ ${encoding} ${java.version} ${java.version} - 3.2.2 + 3.3.4 @@ -54,7 +54,7 @@ org.apache.commons commons-math3 - 3.4.1 + 3.6.1 org.apache.commons @@ -103,6 +103,21 @@ test + + org.neo4j.test + neo4j-harness + ${neo4j.version} + test + + + + + org.neo4j.driver + neo4j-java-driver + 1.5.0 + test + + @@ -119,6 +134,16 @@ + + + maven-compiler-plugin + 3.1 + + + 1.8 + 1.8 + + diff --git a/src/main/java/regression/LR.java b/src/main/java/regression/LR.java new file mode 100644 index 0000000..829519b --- /dev/null +++ b/src/main/java/regression/LR.java @@ -0,0 +1,122 @@ +package regression; + +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.neo4j.graphdb.Entity; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.ResourceIterator; +import org.neo4j.logging.Log; +import org.neo4j.procedure.*; +import org.neo4j.procedure.Mode; + +import java.io.*; +import java.util.*; +import java.util.stream.Stream; + +public class LR { + @Context + public GraphDatabaseService db; + + @Context + public Log log; + + @Procedure(value = "regression.linear.create", mode = Mode.READ) + public Stream create(@Name("model") String model) { + return Stream.of((new LRModel(model)).asResult()); + } + + @Procedure(value = 
"regression.linear.addData", mode = Mode.READ) + public Stream addData(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { + LRModel lrModel = LRModel.from(model); + lrModel.add(given, expected); + return Stream.of(lrModel.asResult()); + } + + @Procedure(value = "regression.linear.removeData", mode = Mode.READ) + public Stream removeData(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { + LRModel lrModel = LRModel.from(model); + lrModel.removeData(given, expected); + return Stream.of(lrModel.asResult()); + } + + @Procedure(value = "regression.linear.removeModel", mode = Mode.READ) + public Stream removeModel(@Name("model") String model) { + return Stream.of(LRModel.removeModel(model)); + } + + @Procedure(value = "regression.linear.predict", mode = Mode.READ) + public Stream predict(@Name("mode") String model, @Name("given") double given) { + LRModel lrModel = LRModel.from(model); + return Stream.of(lrModel.predict(given)); + } + + @Procedure(value = "regression.linear.storeModel", mode = Mode.WRITE) + public Stream storeModel(@Name("model") String model) { + LRModel lrModel = LRModel.from(model); + lrModel.store(db); + return Stream.of(lrModel.asResult()); + } + + @Procedure(value = "regression.linear.createFromStorage", mode = Mode.READ) + public Stream createFromStorage(@Name("model") String model) { + Map parameters = new HashMap<>(); + parameters.put("name", model); + Entity modelNode; + SimpleRegression R; + try { + ResourceIterator n = db.execute("MATCH (n:LRModel {name:$name}) RETURN " + + "n", parameters).columnAs("n"); + modelNode = n.next(); + byte[] m = (byte[]) modelNode.getProperty("serializedModel"); + R = (SimpleRegression) convertFromBytes(m); + } catch (Exception e) { + throw new RuntimeException("no existing model for specified independent and dependent variables and model ID"); + } + return Stream.of(new LRModel(model, R, (String) 
modelNode.getProperty("state")).asResult()); + } + + public static class ModelResult { + public final String model; + public final String state; + public final double N; + public final Map info = new HashMap<>(); + + public ModelResult(String model, LRModel.State state, double N) { + this.model = model; + this.state = state.name(); + this.N = N; + } + + ModelResult withInfo(Object...infos) { + for (int i = 0; i < infos.length; i+=2) { + info.put(infos[i].toString(),infos[i+1]); + } + return this; + } + } + + public static class PredictResult { + public final double prediction; + public PredictResult(double p) { + this.prediction = p; + } + } + + //Serializes the object into a byte array for storage + public static byte[] convertToBytes(Object object) throws IOException { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = new ObjectOutputStream(bos)) { + out.writeObject(object); + return bos.toByteArray(); + } + } + + //de serializes the byte array and returns the stored object + private static Object convertFromBytes(byte[] bytes) throws IOException, ClassNotFoundException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInput in = new ObjectInputStream(bis)) { + return in.readObject(); + } + } + + +} \ No newline at end of file diff --git a/src/main/java/regression/LRModel.java b/src/main/java/regression/LRModel.java new file mode 100644 index 0000000..1edd3a1 --- /dev/null +++ b/src/main/java/regression/LRModel.java @@ -0,0 +1,122 @@ +package regression; +import java.util.concurrent.ConcurrentHashMap; +import java.util.*; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.*; +import org.neo4j.procedure.UserAggregationUpdate; +import java.io.*; + +public class LRModel { + + private static ConcurrentHashMap models = new ConcurrentHashMap<>(); + private final String name; + private State state; + SimpleRegression R; 
+ + + public LRModel(String model) { + if (models.containsKey(model)) + throw new IllegalArgumentException("Model " + model + " already exists, please remove it first"); + this.name = model; + this.state = State.created; + this.R = new SimpleRegression(); + models.put(name, this); + } + + public LRModel(String model, SimpleRegression R, String state) { + if (models.containsKey(model)) + throw new IllegalArgumentException("Model " + model + " already exists, please remove it first"); + this.name = model; + switch(state) { + case "created": this.state = State.created; + case "ready": this.state = State.ready; + case "removed": this.state = State.removed; + case "unknown": this.state = State.unknown; + } + this.R = R; + models.put(name, this); + } + + public static LRModel from(String name) { + LRModel model = models.get(name); + if (model != null) return model; + throw new IllegalArgumentException("No valid LR-Model " + name); + } + + public void add(double given, double expected) { + R.addData(given, expected); + if (R.getN() > 1) { + this.state = State.ready; + } + } + + public LR.PredictResult predict(double given) { + if (this.state == State.ready) + return new LR.PredictResult(R.predict(given)); + throw new IllegalArgumentException("Not enough data in model to predict yet"); + } + + + public void removeData(double given, double expected) { + R.removeData(given, expected); + if (R.getN() < 2) { + this.state = State.created; + } + } + + public static LR.ModelResult removeModel(String model) { + LRModel existing = models.remove(model); + return new LR.ModelResult(model, existing == null ? 
State.unknown : State.removed, 0); + } + + public void store(GraphDatabaseService db) { + try{ byte[] serializedR = LR.convertToBytes(R); + Map parameters = new HashMap<>(); + parameters.put("name", name); + ResourceIterator n = db.execute("MERGE (n:LRModel {name:$name}) RETURN n", parameters).columnAs("n"); + Entity modelNode = n.next(); + modelNode.setProperty("state", state.name()); + modelNode.setProperty("serializedModel", serializedR);} + catch (IOException e) { throw new RuntimeException(name + " cannot be serialized."); } + } + + /*protected void initTypes(Map types, String output) { + if (!types.containsKey(output)) throw new IllegalArgumentException("Outputs not defined: " + output); + int i = 0; + for (Map.Entry entry : types.entrySet()) { + String key = entry.getKey(); + this.types.put(key, DataType.from(entry.getValue())); + if (!key.equals(output)) this.offsets.put(key, i++); + } + this.offsets.put(output, i); + }*/ + + + /*enum DataType { + _class, _float, _order; + + public static DataType from(String type) { + switch (type.toUpperCase()) { + case "CLASS": + return DataType._class; + case "FLOAT": + return DataType._float; + case "ORDER": + return DataType._order; + default: + throw new IllegalArgumentException("Unknown type: " + type); + } + } + }*/ + + public enum State {created, ready, removed, unknown} + + public LR.ModelResult asResult() { + LR.ModelResult result = new LR.ModelResult(this.name, this.state, R.getN()); + return result; + } + + + +} diff --git a/src/test/java/regression/LRTest.java b/src/test/java/regression/LRTest.java new file mode 100644 index 0000000..61dae8a --- /dev/null +++ b/src/test/java/regression/LRTest.java @@ -0,0 +1,113 @@ +package regression; + +import java.util.HashMap; +import org.junit.Rule; +import org.junit.Test; +import org.neo4j.driver.v1.*; +import org.neo4j.harness.junit.Neo4jRule; + +import static org.junit.Assert.*; +import static org.hamcrest.CoreMatchers.equalTo; + +import 
org.apache.commons.math3.stat.regression.SimpleRegression; + +public class LRTest { + + // Start a Neo4j instance + @Rule + public Neo4jRule neo4j = new Neo4jRule() + .withFunction(LR.class) + .withProcedure(LR.class); + + + private static String createKnownRelationships = "CREATE (:Node {id:1}) - [:WORKS_FOR {time:1.0, progress:1.345}] -> " + + "(:Node {id:2}) - [:WORKS_FOR {time:2.0, progress:2.596}] -> " + + "(:Node {id:3}) - [:WORKS_FOR {time:3.0, progress:3.259}] -> (:Node {id:4})"; + + private static String createUnknownRelationships = "CREATE (:Node {id:5}) -[:WORKS_FOR {time:4.0}] -> " + + "(:Node {id:6}) - [:WORKS_FOR {time:5.0}] -> (:Node {id:7})"; + + private static String gatherPredictedValues = "MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND " + + "exists(r.predictedProgress) RETURN r.time as time, r.predictedProgress as predictedProgress"; + + @Test + public void shouldPerformRegression() throws Throwable { + try (Driver driver = GraphDatabase.driver(neo4j.boltURI(), Config.build().withoutEncryption().toConfig()); + Session session = driver.session()) { + + session.run(createKnownRelationships); + session.run(createUnknownRelationships); + //initialize the model + session.run("CALL regression.linear.create('work and progress')"); + //add known data + session.run("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND exists(r.progress) CALL " + + "regression.linear.addData('work and progress', r.time, r.progress) YIELD state RETURN r.time, r.progress, state"); + //store predictions + session.run("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND NOT exists(r.progress) CALL " + + "regression.linear.predict('work and progress', r.time) YIELD prediction SET r.predictedProgress = " + + "prediction"); + + SimpleRegression R = new SimpleRegression(); + R.addData(1.0, 1.345); + R.addData(2.0, 2.596); + R.addData(3.0, 3.259); + + HashMap expected = new HashMap<>(); + expected.put(4.0, R.predict(4.0)); + expected.put(5.0, R.predict(5.0)); + + 
StatementResult result = session.run(gatherPredictedValues); + + while (result.hasNext()) { + Record actual = result.next(); + + double time = actual.get("time").asDouble(); + double expectedPrediction = expected.get(time); + double actualPrediction = actual.get("predictedProgress").asDouble(); + + assertThat(actualPrediction, equalTo(expectedPrediction)); + } + + session.run("CALL regression.linear.storeModel('work and progress')"); + session.run("CALL regression.linear.removeModel('work and progress')"); + session.run("CALL regression.linear.createFromStorage('work and progress')"); + + + //remove data from relationship between nodes 1 and 2 + session.run("MATCH (:Node {id:1})-[r:WORKS_FOR]->(:Node {id:2}) CALL regression.linear.removeData('work " + + "and progress', r.time, r.progress) YIELD state, N RETURN r.time as time, r.progress as progress, state, N"); + + //create a new relationship between nodes 7 and 8 + session.run("MATCH (n7:Node {id:7}) MERGE (n7)-[:WORKS_FOR {time:6.0, progress:5.870}]->(:Node {id:8})"); + + //add data from new relationship to model + session.run("MATCH (:Node {id:7})-[r:WORKS_FOR]->(:Node {id:8}) CALL regression.linear.addData('work " + + "and progress', r.time, r.progress) YIELD state, N RETURN r.time, r.progress, state, N"); + + //map new model on all relationships with unknown progress + session.run("MATCH (:Node)-[r:WORKS_FOR]->(:Node) WHERE exists(r.time) AND NOT exists(r.progress) " + + "CALL regression.linear.predict('work and progress', " + + "r.time) YIELD prediction SET r.predictedProgress = prediction"); + + //replicate the creation and updates of the model + R.removeData(1.0, 1.345); + R.addData(6.0, 5.870); + + expected.put(4.0, R.predict(4.0)); + expected.put(5.0, R.predict(5.0)); + + //make sure predicted values are correct + result = session.run(gatherPredictedValues); + while (result.hasNext()) { + Record actual = result.next(); + + double time = actual.get("time").asDouble(); + double expectedPrediction = 
expected.get(time); + double actualPrediction = actual.get("predictedProgress").asDouble(); + + assertThat( actualPrediction, equalTo( expectedPrediction ) ); + } + + } + } +} From 02546d0232e87f29c5a3c0d3c82e3f6d6163a982 Mon Sep 17 00:00:00 2001 From: Lauren Shin Date: Tue, 29 May 2018 16:35:55 -0700 Subject: [PATCH 2/7] first batch of hunger edits --- pom.xml | 5 ++- src/main/java/regression/LR.java | 65 ++++++++++++--------------- src/main/java/regression/LRModel.java | 26 +++++++---- src/test/java/regression/LRTest.java | 31 +++++++------ 4 files changed, 67 insertions(+), 60 deletions(-) diff --git a/pom.xml b/pom.xml index ad79085..c3fbfd6 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ ${encoding} ${java.version} ${java.version} - 3.3.4 + 3.2.2 @@ -54,8 +54,9 @@ org.apache.commons commons-math3 - 3.6.1 + 3.4.1 + org.apache.commons commons-lang3 diff --git a/src/main/java/regression/LR.java b/src/main/java/regression/LR.java index 829519b..46662b4 100644 --- a/src/main/java/regression/LR.java +++ b/src/main/java/regression/LR.java @@ -1,12 +1,15 @@ package regression; import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.bytedeco.javacv.FrameFilter; import org.neo4j.graphdb.Entity; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.ResourceIterator; import org.neo4j.logging.Log; import org.neo4j.procedure.*; import org.neo4j.procedure.Mode; +import org.neo4j.unsafe.impl.batchimport.cache.ByteArray; +import sun.java2d.pipe.SpanShapeRenderer; import java.io.*; import java.util.*; @@ -24,54 +27,49 @@ public Stream create(@Name("model") String model) { return Stream.of((new LRModel(model)).asResult()); } - @Procedure(value = "regression.linear.addData", mode = Mode.READ) - public Stream addData(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { + @Procedure(value = "regression.linear.info", mode = Mode.READ) + public Stream info(@Name("model") String model) { LRModel 
lrModel = LRModel.from(model); - lrModel.add(given, expected); return Stream.of(lrModel.asResult()); } - @Procedure(value = "regression.linear.removeData", mode = Mode.READ) - public Stream removeData(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { + @Procedure(value = "regression.linear.add", mode = Mode.READ) + public void add(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { + LRModel lrModel = LRModel.from(model); + lrModel.add(given, expected); + } + + @Procedure(value = "regression.linear.remove", mode = Mode.READ) + public void remove(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { LRModel lrModel = LRModel.from(model); lrModel.removeData(given, expected); - return Stream.of(lrModel.asResult()); } - @Procedure(value = "regression.linear.removeModel", mode = Mode.READ) - public Stream removeModel(@Name("model") String model) { + @Procedure(value = "regression.linear.delete", mode = Mode.READ) + public Stream delete(@Name("model") String model) { return Stream.of(LRModel.removeModel(model)); } - @Procedure(value = "regression.linear.predict", mode = Mode.READ) - public Stream predict(@Name("mode") String model, @Name("given") double given) { + @UserFunction(value = "regression.linear.predict") + public double predict(@Name("mode") String model, @Name("given") double given) { LRModel lrModel = LRModel.from(model); - return Stream.of(lrModel.predict(given)); + return lrModel.predict(given); } - @Procedure(value = "regression.linear.storeModel", mode = Mode.WRITE) - public Stream storeModel(@Name("model") String model) { + @UserFunction(value = "regression.linear.serialize") + public Object serialize(@Name("model") String model) { LRModel lrModel = LRModel.from(model); - lrModel.store(db); - return Stream.of(lrModel.asResult()); + return lrModel.serialize(); } - @Procedure(value = "regression.linear.createFromStorage", mode = 
Mode.READ) - public Stream createFromStorage(@Name("model") String model) { - Map parameters = new HashMap<>(); - parameters.put("name", model); - Entity modelNode; + @Procedure(value = "regression.linear.load", mode = Mode.READ) + public Stream load(@Name("model") String model, @Name("data") Object data) { SimpleRegression R; - try { - ResourceIterator n = db.execute("MATCH (n:LRModel {name:$name}) RETURN " + - "n", parameters).columnAs("n"); - modelNode = n.next(); - byte[] m = (byte[]) modelNode.getProperty("serializedModel"); - R = (SimpleRegression) convertFromBytes(m); - } catch (Exception e) { - throw new RuntimeException("no existing model for specified independent and dependent variables and model ID"); + try { R = (SimpleRegression) convertFromBytes((byte[]) data); } + catch (Exception e) { + throw new RuntimeException("invalid data"); } - return Stream.of(new LRModel(model, R, (String) modelNode.getProperty("state")).asResult()); + return Stream.of((new LRModel(model, R)).asResult()); } public static class ModelResult { @@ -94,13 +92,6 @@ ModelResult withInfo(Object...infos) { } } - public static class PredictResult { - public final double prediction; - public PredictResult(double p) { - this.prediction = p; - } - } - //Serializes the object into a byte array for storage public static byte[] convertToBytes(Object object) throws IOException { try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); @@ -119,4 +110,6 @@ private static Object convertFromBytes(byte[] bytes) throws IOException, ClassNo } + + } \ No newline at end of file diff --git a/src/main/java/regression/LRModel.java b/src/main/java/regression/LRModel.java index 1edd3a1..a65d0c6 100644 --- a/src/main/java/regression/LRModel.java +++ b/src/main/java/regression/LRModel.java @@ -24,17 +24,20 @@ public LRModel(String model) { models.put(name, this); } - public LRModel(String model, SimpleRegression R, String state) { + public LRModel(String model, SimpleRegression R) { if 
(models.containsKey(model)) throw new IllegalArgumentException("Model " + model + " already exists, please remove it first"); this.name = model; - switch(state) { - case "created": this.state = State.created; - case "ready": this.state = State.ready; - case "removed": this.state = State.removed; - case "unknown": this.state = State.unknown; + if (R == null) { + this.R = new SimpleRegression(); + this.state = State.created; + } else { + this.R = R; + if (R.getN() < 2) + this.state = State.created; + else + this.state = State.ready; } - this.R = R; models.put(name, this); } @@ -51,9 +54,9 @@ public void add(double given, double expected) { } } - public LR.PredictResult predict(double given) { + public double predict(double given) { if (this.state == State.ready) - return new LR.PredictResult(R.predict(given)); + return R.predict(given); throw new IllegalArgumentException("Not enough data in model to predict yet"); } @@ -65,6 +68,11 @@ public void removeData(double given, double expected) { } } + public byte[] serialize() { + try { return LR.convertToBytes(R); } + catch (IOException e) { throw new RuntimeException(name + " cannot be serialized."); } + } + public static LR.ModelResult removeModel(String model) { LRModel existing = models.remove(model); return new LR.ModelResult(model, existing == null ? 
State.unknown : State.removed, 0); diff --git a/src/test/java/regression/LRTest.java b/src/test/java/regression/LRTest.java index 61dae8a..2b17528 100644 --- a/src/test/java/regression/LRTest.java +++ b/src/test/java/regression/LRTest.java @@ -1,6 +1,7 @@ package regression; import java.util.HashMap; +import java.util.Map; import org.junit.Rule; import org.junit.Test; import org.neo4j.driver.v1.*; @@ -41,11 +42,10 @@ public void shouldPerformRegression() throws Throwable { session.run("CALL regression.linear.create('work and progress')"); //add known data session.run("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND exists(r.progress) CALL " + - "regression.linear.addData('work and progress', r.time, r.progress) YIELD state RETURN r.time, r.progress, state"); + "regression.linear.add('work and progress', r.time, r.progress) RETURN r"); //store predictions - session.run("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND NOT exists(r.progress) CALL " + - "regression.linear.predict('work and progress', r.time) YIELD prediction SET r.predictedProgress = " + - "prediction"); + session.run("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND NOT exists(r.progress) SET " + + "r.predictedProgress = regression.linear.predict('work and progress', r.time)"); SimpleRegression R = new SimpleRegression(); R.addData(1.0, 1.345); @@ -68,26 +68,28 @@ public void shouldPerformRegression() throws Throwable { assertThat(actualPrediction, equalTo(expectedPrediction)); } - session.run("CALL regression.linear.storeModel('work and progress')"); - session.run("CALL regression.linear.removeModel('work and progress')"); - session.run("CALL regression.linear.createFromStorage('work and progress')"); + Record r = session.run("RETURN regression.linear.serialize('work and progress') as data").next(); + byte[] data = r.get("data").asByteArray(); + session.run("CALL regression.linear.delete('work and progress')"); + Map params = new HashMap<>(); + params.put("data", data); + 
session.run("CALL regression.linear.load('work and progress', $data)", params); //remove data from relationship between nodes 1 and 2 - session.run("MATCH (:Node {id:1})-[r:WORKS_FOR]->(:Node {id:2}) CALL regression.linear.removeData('work " + - "and progress', r.time, r.progress) YIELD state, N RETURN r.time as time, r.progress as progress, state, N"); + session.run("MATCH (:Node {id:1})-[r:WORKS_FOR]->(:Node {id:2}) CALL regression.linear.remove('work " + + "and progress', r.time, r.progress) return r"); //create a new relationship between nodes 7 and 8 session.run("MATCH (n7:Node {id:7}) MERGE (n7)-[:WORKS_FOR {time:6.0, progress:5.870}]->(:Node {id:8})"); //add data from new relationship to model - session.run("MATCH (:Node {id:7})-[r:WORKS_FOR]->(:Node {id:8}) CALL regression.linear.addData('work " + - "and progress', r.time, r.progress) YIELD state, N RETURN r.time, r.progress, state, N"); + session.run("MATCH (:Node {id:7})-[r:WORKS_FOR]->(:Node {id:8}) CALL regression.linear.add('work " + + "and progress', r.time, r.progress) RETURN r"); //map new model on all relationships with unknown progress session.run("MATCH (:Node)-[r:WORKS_FOR]->(:Node) WHERE exists(r.time) AND NOT exists(r.progress) " + - "CALL regression.linear.predict('work and progress', " + - "r.time) YIELD prediction SET r.predictedProgress = prediction"); + "SET r.predictedProgress = regression.linear.predict('work and progress', r.time)"); //replicate the creation and updates of the model R.removeData(1.0, 1.345); @@ -108,6 +110,9 @@ public void shouldPerformRegression() throws Throwable { assertThat( actualPrediction, equalTo( expectedPrediction ) ); } + session.run("CALL regression.linear.delete('work and progress')"); + + } } } From 717dc0e30ac70c97825fc7fdfe1d392c1451bf83 Mon Sep 17 00:00:00 2001 From: Lauren Shin Date: Wed, 30 May 2018 14:42:21 -0700 Subject: [PATCH 3/7] embedded testing --- pom.xml | 15 -- src/main/java/regression/LR.java | 13 +- src/main/java/regression/LRModel.java | 
2 +- src/test/java/regression/LRTest.java | 213 ++++++++++++++++---------- 4 files changed, 147 insertions(+), 96 deletions(-) diff --git a/pom.xml b/pom.xml index c3fbfd6..996f60f 100644 --- a/pom.xml +++ b/pom.xml @@ -104,21 +104,6 @@ test - - org.neo4j.test - neo4j-harness - ${neo4j.version} - test - - - - - org.neo4j.driver - neo4j-java-driver - 1.5.0 - test - - diff --git a/src/main/java/regression/LR.java b/src/main/java/regression/LR.java index 46662b4..25f14d4 100644 --- a/src/main/java/regression/LR.java +++ b/src/main/java/regression/LR.java @@ -39,6 +39,17 @@ public void add(@Name("model") String model, @Name("given") double given, @Name( lrModel.add(given, expected); } + @Procedure(value = "regression.linear.addM", mode = Mode.READ) + public void addM(@Name("model") String model, @Name("given") Object given, @Name("expected") Object expected) { + double[] g = (double[]) given; + double[] e = (double[]) expected; + LRModel lrModel = LRModel.from(model); + if (g.length != e.length) throw new IllegalArgumentException("Lengths of the two data arrays are unequal."); + for (int i = 0; i < g.length; i++) { + lrModel.add(g[i], e[i]); + } + } + @Procedure(value = "regression.linear.remove", mode = Mode.READ) public void remove(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { LRModel lrModel = LRModel.from(model); @@ -102,7 +113,7 @@ public static byte[] convertToBytes(Object object) throws IOException { } //de serializes the byte array and returns the stored object - private static Object convertFromBytes(byte[] bytes) throws IOException, ClassNotFoundException { + public static Object convertFromBytes(byte[] bytes) throws IOException, ClassNotFoundException { try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInput in = new ObjectInputStream(bis)) { return in.readObject(); diff --git a/src/main/java/regression/LRModel.java b/src/main/java/regression/LRModel.java index a65d0c6..17b6ff2 100644 --- 
a/src/main/java/regression/LRModel.java +++ b/src/main/java/regression/LRModel.java @@ -121,7 +121,7 @@ public static DataType from(String type) { public enum State {created, ready, removed, unknown} public LR.ModelResult asResult() { - LR.ModelResult result = new LR.ModelResult(this.name, this.state, R.getN()); + LR.ModelResult result = new LR.ModelResult(this.name, this.state, this.R.getN()); return result; } diff --git a/src/test/java/regression/LRTest.java b/src/test/java/regression/LRTest.java index 2b17528..8bcf177 100644 --- a/src/test/java/regression/LRTest.java +++ b/src/test/java/regression/LRTest.java @@ -2,117 +2,172 @@ import java.util.HashMap; import java.util.Map; -import org.junit.Rule; + +import org.junit.After; +import org.junit.Before; import org.junit.Test; -import org.neo4j.driver.v1.*; -import org.neo4j.harness.junit.Neo4jRule; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.*; import static org.junit.Assert.*; import static org.hamcrest.CoreMatchers.equalTo; import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.neo4j.kernel.impl.proc.Procedures; +import org.neo4j.kernel.internal.GraphDatabaseAPI; +import org.neo4j.test.TestGraphDatabaseFactory; public class LRTest { - // Start a Neo4j instance - @Rule - public Neo4jRule neo4j = new Neo4jRule() - .withFunction(LR.class) - .withProcedure(LR.class); - - - private static String createKnownRelationships = "CREATE (:Node {id:1}) - [:WORKS_FOR {time:1.0, progress:1.345}] -> " + - "(:Node {id:2}) - [:WORKS_FOR {time:2.0, progress:2.596}] -> " + - "(:Node {id:3}) - [:WORKS_FOR {time:3.0, progress:3.259}] -> (:Node {id:4})"; + private GraphDatabaseService db; - private static String createUnknownRelationships = "CREATE (:Node {id:5}) -[:WORKS_FOR {time:4.0}] -> " + - "(:Node {id:6}) - [:WORKS_FOR {time:5.0}] -> (:Node {id:7})"; + @Before + public void setUp() throws Exception { + db = new TestGraphDatabaseFactory().newImpermanentDatabase(); + Procedures 
procedures = ((GraphDatabaseAPI) db).getDependencyResolver().resolveDependency(Procedures.class); + procedures.registerProcedure(LR.class); + procedures.registerFunction(LR.class); + } - private static String gatherPredictedValues = "MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND " + - "exists(r.predictedProgress) RETURN r.time as time, r.predictedProgress as predictedProgress"; + @After + public void tearDown() throws Exception { + db.shutdown(); + } @Test - public void shouldPerformRegression() throws Throwable { - try (Driver driver = GraphDatabase.driver(neo4j.boltURI(), Config.build().withoutEncryption().toConfig()); - Session session = driver.session()) { - - session.run(createKnownRelationships); - session.run(createUnknownRelationships); - //initialize the model - session.run("CALL regression.linear.create('work and progress')"); - //add known data - session.run("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND exists(r.progress) CALL " + - "regression.linear.add('work and progress', r.time, r.progress) RETURN r"); - //store predictions - session.run("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND NOT exists(r.progress) SET " + - "r.predictedProgress = regression.linear.predict('work and progress', r.time)"); + public void regression() throws Exception { + //create known relationships for times 1, 2, 3 + db.execute("CREATE (:Node {id:1}) - [:WORKS_FOR {time:1.0, progress:1.345}] -> " + + "(:Node {id:2}) - [:WORKS_FOR {time:2.0, progress:2.596}] -> " + + "(:Node {id:3}) - [:WORKS_FOR {time:3.0, progress:3.259}] -> (:Node {id:4})"); + + //create unknown relationships for times 4, 5 + db.execute("CREATE (:Node {id:5}) -[:WORKS_FOR {time:4.0}] -> " + + "(:Node {id:6}) - [:WORKS_FOR {time:5.0}] -> (:Node {id:7})"); + + //initialize the model + db.execute("CALL regression.linear.create('work and progress')"); + + //add known data + Result r = db.execute("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND exists(r.progress) CALL " + + 
"regression.linear.add('work and progress', r.time, r.progress) RETURN r"); + //these rows are computed lazily so we must iterate through all rows to ensure all data points are added to the model + while(r.hasNext()) r.next(); + + //check that the correct info is stored in the model (should contain 3 data points) + Map info = db.execute("CALL regression.linear.info('work and progress') YIELD model, state, N " + + "RETURN model, state, N").next(); + assertTrue(info.get("model").equals("work and progress")); + assertTrue(info.get("state").equals("ready")); + assertThat(info.get("N"), equalTo(3.0)); + + //store predictions + db.execute("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND NOT exists(r.progress) SET " + + "r.predictedProgress = regression.linear.predict('work and progress', r.time)"); + + //check that predictions are correct + + SimpleRegression R = new SimpleRegression(); + R.addData(1.0, 1.345); + R.addData(2.0, 2.596); + R.addData(3.0, 3.259); + HashMap expected = new HashMap<>(); + expected.put(4.0, R.predict(4.0)); + expected.put(5.0, R.predict(5.0)); + + String gatherPredictedValues = "MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND " + + "exists(r.predictedProgress) RETURN r.time as time, r.predictedProgress as predictedProgress"; + + Result result = db.execute(gatherPredictedValues); + + while (result.hasNext()) { + Map actual = result.next(); + + double time = (double) actual.get("time"); + double expectedPrediction = expected.get(time); + double actualPrediction = (double) actual.get("predictedProgress"); + + assertThat(actualPrediction, equalTo(expectedPrediction)); + } - SimpleRegression R = new SimpleRegression(); - R.addData(1.0, 1.345); - R.addData(2.0, 2.596); - R.addData(3.0, 3.259); + //serialize the model + Map serial = db.execute("RETURN regression.linear.serialize('work and progress') as data").next(); + Object data = serial.get("data"); - HashMap expected = new HashMap<>(); - expected.put(4.0, R.predict(4.0)); - 
expected.put(5.0, R.predict(5.0)); + //check that the model returns same predictions as the model stored in the procedure + SimpleRegression storedR = (SimpleRegression) LR.convertFromBytes((byte[]) data); + assertThat(storedR.predict(4.0), equalTo(expected.get(4.0))); + assertThat(storedR.predict(5.0), equalTo(expected.get(5.0))); - StatementResult result = session.run(gatherPredictedValues); + //delete model then re-create using serialization + db.execute("CALL regression.linear.delete('work and progress')"); + Map params = new HashMap<>(); + params.put("data", data); + db.execute("CALL regression.linear.load('work and progress', $data)", params); - while (result.hasNext()) { - Record actual = result.next(); + //remove data from relationship between nodes 1 and 2 + r = db.execute("MATCH (:Node {id:1})-[r:WORKS_FOR]->(:Node {id:2}) CALL regression.linear.remove('work " + + "and progress', r.time, r.progress) return r"); + while (r.hasNext()) r.next(); - double time = actual.get("time").asDouble(); - double expectedPrediction = expected.get(time); - double actualPrediction = actual.get("predictedProgress").asDouble(); + //create a new relationship between nodes 7 and 8 + db.execute("MATCH (n7:Node {id:7}) MERGE (n7)-[:WORKS_FOR {time:6.0, progress:5.870}]->(:Node {id:8})"); - assertThat(actualPrediction, equalTo(expectedPrediction)); - } + //add data from new relationship to model + r = db.execute("MATCH (:Node {id:7})-[r:WORKS_FOR]->(:Node {id:8}) CALL regression.linear.add('work " + + "and progress', r.time, r.progress) RETURN r.time"); + //again must iterate through rows + while (r.hasNext()) r.next(); - Record r = session.run("RETURN regression.linear.serialize('work and progress') as data").next(); - byte[] data = r.get("data").asByteArray(); - session.run("CALL regression.linear.delete('work and progress')"); - Map params = new HashMap<>(); - params.put("data", data); - session.run("CALL regression.linear.load('work and progress', $data)", params); + //map 
new model on all relationships with unknown progress + db.execute("MATCH (:Node)-[r:WORKS_FOR]->(:Node) WHERE exists(r.time) AND NOT exists(r.progress) " + + "SET r.predictedProgress = regression.linear.predict('work and progress', r.time)"); + //replicate the creation and updates of the model + R.removeData(1.0, 1.345); + R.addData(6.0, 5.870); - //remove data from relationship between nodes 1 and 2 - session.run("MATCH (:Node {id:1})-[r:WORKS_FOR]->(:Node {id:2}) CALL regression.linear.remove('work " + - "and progress', r.time, r.progress) return r"); + expected.put(4.0, R.predict(4.0)); + expected.put(5.0, R.predict(5.0)); - //create a new relationship between nodes 7 and 8 - session.run("MATCH (n7:Node {id:7}) MERGE (n7)-[:WORKS_FOR {time:6.0, progress:5.870}]->(:Node {id:8})"); + //make sure predicted values are correct + result = db.execute(gatherPredictedValues); + while (result.hasNext()) { + Map actual = result.next(); - //add data from new relationship to model - session.run("MATCH (:Node {id:7})-[r:WORKS_FOR]->(:Node {id:8}) CALL regression.linear.add('work " + - "and progress', r.time, r.progress) RETURN r"); + double time = (double) actual.get("time"); + double expectedPrediction = expected.get(time); + double actualPrediction = (double) actual.get("predictedProgress"); - //map new model on all relationships with unknown progress - session.run("MATCH (:Node)-[r:WORKS_FOR]->(:Node) WHERE exists(r.time) AND NOT exists(r.progress) " + - "SET r.predictedProgress = regression.linear.predict('work and progress', r.time)"); + assertThat( actualPrediction, equalTo( expectedPrediction ) ); + } - //replicate the creation and updates of the model - R.removeData(1.0, 1.345); - R.addData(6.0, 5.870); + double[] points = {7.0, 8.0}; + double[] observed = {6.900, 9.234}; + params.put("points", points); + params.put("observed", observed); - expected.put(4.0, R.predict(4.0)); - expected.put(5.0, R.predict(5.0)); + db.execute("CALL regression.linear.addM('work and 
progress', $points, $observed)", params); + db.execute("MATCH (:Node)-[r:WORKS_FOR]->(:Node) WHERE exists(r.time) AND NOT exists(r.progress) " + + "SET r.predictedProgress = regression.linear.predict('work and progress', r.time)"); + R.addData(7.0, 6.900); + R.addData(8.0, 9.234); + expected.put(4.0, R.predict(4.0)); + expected.put(5.0, R.predict(5.0)); + result = db.execute(gatherPredictedValues); - //make sure predicted values are correct - result = session.run(gatherPredictedValues); - while (result.hasNext()) { - Record actual = result.next(); + while (result.hasNext()) { + Map actual = result.next(); - double time = actual.get("time").asDouble(); - double expectedPrediction = expected.get(time); - double actualPrediction = actual.get("predictedProgress").asDouble(); + double time = (double) actual.get("time"); + double expectedPrediction = expected.get(time); + double actualPrediction = (double) actual.get("predictedProgress"); - assertThat( actualPrediction, equalTo( expectedPrediction ) ); - } + assertThat( actualPrediction, equalTo( expectedPrediction ) ); + } - session.run("CALL regression.linear.delete('work and progress')"); + db.execute("CALL regression.linear.delete('work and progress')").close(); - } } } From d56491d8b294cfd25256c8c7002f9fda22584a9e Mon Sep 17 00:00:00 2001 From: Lauren Shin Date: Wed, 30 May 2018 14:56:11 -0700 Subject: [PATCH 4/7] cleaned up test --- src/test/java/regression/LRTest.java | 68 ++++++++++++---------------- 1 file changed, 30 insertions(+), 38 deletions(-) diff --git a/src/test/java/regression/LRTest.java b/src/test/java/regression/LRTest.java index 8bcf177..e57fa62 100644 --- a/src/test/java/regression/LRTest.java +++ b/src/test/java/regression/LRTest.java @@ -1,6 +1,7 @@ package regression; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; import org.junit.After; @@ -34,6 +35,22 @@ public void tearDown() throws Exception { db.shutdown(); } + private void exhaust(Iterator r) { + 
while(r.hasNext()) r.next(); + } + + private void check(Result result, Map expected) { + while (result.hasNext()) { + Map actual = result.next(); + + double time = (double) actual.get("time"); + double expectedPrediction = expected.get(time); + double actualPrediction = (double) actual.get("predictedProgress"); + + assertThat( actualPrediction, equalTo( expectedPrediction ) ); + } + } + @Test public void regression() throws Exception { //create known relationships for times 1, 2, 3 @@ -52,7 +69,7 @@ public void regression() throws Exception { Result r = db.execute("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND exists(r.progress) CALL " + "regression.linear.add('work and progress', r.time, r.progress) RETURN r"); //these rows are computed lazily so we must iterate through all rows to ensure all data points are added to the model - while(r.hasNext()) r.next(); + exhaust(r); //check that the correct info is stored in the model (should contain 3 data points) Map info = db.execute("CALL regression.linear.info('work and progress') YIELD model, state, N " + @@ -62,8 +79,9 @@ public void regression() throws Exception { assertThat(info.get("N"), equalTo(3.0)); //store predictions - db.execute("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND NOT exists(r.progress) SET " + - "r.predictedProgress = regression.linear.predict('work and progress', r.time)"); + String storePredictions = "MATCH (:Node)-[r:WORKS_FOR]->(:Node) WHERE exists(r.time) AND NOT exists(r.progress) " + + "SET r.predictedProgress = regression.linear.predict('work and progress', r.time)"; + db.execute(storePredictions); //check that predictions are correct @@ -80,21 +98,13 @@ public void regression() throws Exception { Result result = db.execute(gatherPredictedValues); - while (result.hasNext()) { - Map actual = result.next(); - - double time = (double) actual.get("time"); - double expectedPrediction = expected.get(time); - double actualPrediction = (double) actual.get("predictedProgress"); - 
- assertThat(actualPrediction, equalTo(expectedPrediction)); - } + check(result, expected); //serialize the model Map serial = db.execute("RETURN regression.linear.serialize('work and progress') as data").next(); Object data = serial.get("data"); - //check that the model returns same predictions as the model stored in the procedure + //check that the byte[] model returns same predictions as the model stored in the procedure SimpleRegression storedR = (SimpleRegression) LR.convertFromBytes((byte[]) data); assertThat(storedR.predict(4.0), equalTo(expected.get(4.0))); assertThat(storedR.predict(5.0), equalTo(expected.get(5.0))); @@ -108,7 +118,7 @@ public void regression() throws Exception { //remove data from relationship between nodes 1 and 2 r = db.execute("MATCH (:Node {id:1})-[r:WORKS_FOR]->(:Node {id:2}) CALL regression.linear.remove('work " + "and progress', r.time, r.progress) return r"); - while (r.hasNext()) r.next(); + exhaust(r); //create a new relationship between nodes 7 and 8 db.execute("MATCH (n7:Node {id:7}) MERGE (n7)-[:WORKS_FOR {time:6.0, progress:5.870}]->(:Node {id:8})"); @@ -117,54 +127,36 @@ public void regression() throws Exception { r = db.execute("MATCH (:Node {id:7})-[r:WORKS_FOR]->(:Node {id:8}) CALL regression.linear.add('work " + "and progress', r.time, r.progress) RETURN r.time"); //again must iterate through rows - while (r.hasNext()) r.next(); + exhaust(r); //map new model on all relationships with unknown progress - db.execute("MATCH (:Node)-[r:WORKS_FOR]->(:Node) WHERE exists(r.time) AND NOT exists(r.progress) " + - "SET r.predictedProgress = regression.linear.predict('work and progress', r.time)"); + db.execute(storePredictions); //replicate the creation and updates of the model R.removeData(1.0, 1.345); R.addData(6.0, 5.870); - expected.put(4.0, R.predict(4.0)); expected.put(5.0, R.predict(5.0)); //make sure predicted values are correct result = db.execute(gatherPredictedValues); - while (result.hasNext()) { - Map actual = 
result.next(); - - double time = (double) actual.get("time"); - double expectedPrediction = expected.get(time); - double actualPrediction = (double) actual.get("predictedProgress"); - - assertThat( actualPrediction, equalTo( expectedPrediction ) ); - } + check(result, expected); + //test addM procedure for adding multiple data points double[] points = {7.0, 8.0}; double[] observed = {6.900, 9.234}; params.put("points", points); params.put("observed", observed); db.execute("CALL regression.linear.addM('work and progress', $points, $observed)", params); - db.execute("MATCH (:Node)-[r:WORKS_FOR]->(:Node) WHERE exists(r.time) AND NOT exists(r.progress) " + - "SET r.predictedProgress = regression.linear.predict('work and progress', r.time)"); + db.execute(storePredictions); R.addData(7.0, 6.900); R.addData(8.0, 9.234); expected.put(4.0, R.predict(4.0)); expected.put(5.0, R.predict(5.0)); result = db.execute(gatherPredictedValues); - while (result.hasNext()) { - Map actual = result.next(); - - double time = (double) actual.get("time"); - double expectedPrediction = expected.get(time); - double actualPrediction = (double) actual.get("predictedProgress"); - - assertThat( actualPrediction, equalTo( expectedPrediction ) ); - } + check(result, expected); db.execute("CALL regression.linear.delete('work and progress')").close(); From fb12338c3ccf86d1d346b50b13365f2ba744f920 Mon Sep 17 00:00:00 2001 From: Lauren Shin Date: Wed, 30 May 2018 15:00:32 -0700 Subject: [PATCH 5/7] removed unnecessary plugin --- pom.xml | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/pom.xml b/pom.xml index 996f60f..5918d2f 100644 --- a/pom.xml +++ b/pom.xml @@ -120,16 +120,6 @@ - - - maven-compiler-plugin - 3.1 - - - 1.8 - 1.8 - - From 718cd7c94ae074535ed672e3056d5e0fc50e9b6d Mon Sep 17 00:00:00 2001 From: Lauren Shin Date: Mon, 4 Jun 2018 15:13:09 -0700 Subject: [PATCH 6/7] added doc and fixed addM --- asciidoc | 151 +++++++++++++++++++++++++++ src/main/java/regression/LR.java | 45 
++++++-- src/test/java/regression/LRTest.java | 6 +- 3 files changed, 193 insertions(+), 9 deletions(-) create mode 100644 asciidoc diff --git a/asciidoc b/asciidoc new file mode 100644 index 0000000..8da411c --- /dev/null +++ b/asciidoc @@ -0,0 +1,151 @@ += Simple Linear Regression + +// tag::introduction[] +Regression is a statistical tool for investigating the relationships between variables. Simple linear regression is the simplest form of regression; it creates a linear model for the relationship between the dependent variable and a single independent variable. Visually, simple linear regression "draws" a trend line on the scatter plot of two variables that best approximates their linear relationship. The model can be expressed with the two parameters of its line: slope and intercept. This is one of the most popular tools in statistics, and it is frequently used as a predictor for machine learning. +// end::introduction[] + +== Explanation and history + +// tag::explanation[] +At the core of linear regression is the method of least squares. In this method, the linear trend line is chosen which minimizes the sum of every data point's squared residual (deviation from the model). The method of least squares was independently discovered by Carl Friedrich Gauss and Adrien-Marie Legendre in the early 19th century. The linear regression methods used today can be primarily attributed to the work of R.A. Fisher in the 1920s. +// end::explanation[] + +== Use-cases + +// tag::use-case[] +In simple linear regression, both the independent and dependent variables must be numeric. The dependent variable (`y`) can then be expressed in terms of the independent variable (`x`) using the two line parameters slope (`m`) and intercept (`b`) with the equation `y = m * x + b`. For these approximations to be meaningful, the dependent variable should take continuous values. 
The relationship between any two variables satisfying these conditions can be analyzed with simple linear regression. However, the model will only be successful for linearly related data. Some common examples include: + +* Predicting housing prices with square footage, number of bedrooms, number of bathrooms, etc. +* Analyzing sales of a product using pricing or performance information +* Calculating causal relationships between parameters in biological systems +// end::use-case[] + +== Constraints + +// tag::constraints[] +Because simple linear regression is so straightforward, it can be used with any numeric data pair. The real question is how well the best model fits the data. There are several measurements which attempt to quantify the success of the model. For example, the coefficient of determination (`r^2^`) is the proportion of the variance in the dependent variable that is predictable from the independent variable. A coefficient `r^2^ = 1` indicates that the variance in the dependent variable is entirely predictable from the independent variable (and thus the model is perfect). +// end::constraints[] + +== Example + +Let's look at a straightforward example--predicting Airbnb listing prices using the listing's number of bedrooms. Run `:play http://guides.neo4j.com/listings` and follow the import statements to load Will Lyon's Airbnb graph. 
+ +.First initialize the model +[source,cypher] +---- +CALL regression.linear.create('airbnb prices') +---- + +.Then add data point by point +[source,cypher] +---- +MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) +WHERE exists(list.bedrooms) + AND exists(list.price) + AND (NOT exists(list.added) OR list.added = false) +CALL regression.linear.add('airbnb prices', list.bedrooms, list.price) +SET list.added = true +RETURN list.listing_id +---- + +.OR add multiple data points at once +[source,cypher] +---- +MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) +WHERE exists(list.bedrooms) + AND exists(list.price) + AND (NOT exists(list.added) OR list.added = false) +SET list.added = true +WITH collect(list.bedrooms) AS bedrooms, collect(list.price) AS prices +CALL regression.linear.addM('airbnb prices', bedrooms, prices) +RETURN bedrooms, prices +---- + +.Next predict price for a four-bedroom listing +[source,cypher] +---- +RETURN regression.linear.predict('airbnb prices', 4) +---- + +.Or make and store many predictions +[source,cypher] +---- +MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) +WHERE exists(list.bedrooms) AND NOT exists(list.price) +SET list.predicted_price = regression.linear.predict('airbnb prices', list.bedrooms) +---- + +.You can remove data +[source,cypher] +---- +MATCH (list:Listing {listing_id:2467149})-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) +CALL regression.linear.remove('airbnb prices', list.bedrooms, list.price) +SET list.added = false +---- + +.Add some data from a nearby neighborhood +[source,cypher] +---- +MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78753'}) +WHERE exists(list.bedrooms) + AND exists(list.price) + AND (NOT exists(list.added) OR list.added = false) +CALL regression.linear.add('airbnb prices', list.bedrooms, list.price) RETURN list +---- + +.Check out the number of data points in your model 
+[source,cypher] +---- +CALL regression.linear.info('airbnb prices') +YIELD model, state, N +RETURN model, state, N +---- + +.And the statistics +[source,cypher] +---- +CALL regression.linear.stats('airbnb prices') +YIELD intercept, slope, rSquare, significance +RETURN intercept, slope, rSquare, significance +---- + +.Make sure that before shutting down the database, you store the model in the graph or externally +[source,cypher] +---- +MERGE (m:ModelNode {model: 'airbnb prices'}) +SET m.data = regression.linear.serialize('airbnb prices') +RETURN m +---- + +.Delete the model +[source,cypher] +---- +CALL regression.linear.delete('airbnb prices') +YIELD model, state, N +RETURN model, state, N +---- + +.And then when you restart the database, load the model from the graph back into the procedure +[source,cypher] +---- +MATCH (m:ModelNode {model: 'airbnb prices'}) +CALL regression.linear.load('airbnb prices', m.data) +---- + +Now the model is ready for further data changes and predictions! + +== Syntax + +// tag::syntax[] + +If your queries return duplicate values (eg: both directions of the same relationship) then data from the same observation may be added to the model multiple times. This will make your model less accurate. It is recommended that you be careful with queries (eg: specify direction of relationship) or store somewhere in relevant nodes/relationships whether this data has been added to the model. This way you can be sure to select relevant data points which have not yet been added to the model. 
+ +// end::syntax[] + +== References + +// tag::references[] +* https://priceonomics.com/the-discovery-of-statistical-regression/ +* https://en.wikipedia.org/wiki/Regression_analysis +* https://dzone.com/articles/decision-trees-vs-clustering-algorithms-vs-linear +// end::references[] diff --git a/src/main/java/regression/LR.java b/src/main/java/regression/LR.java index 25f14d4..8a1377e 100644 --- a/src/main/java/regression/LR.java +++ b/src/main/java/regression/LR.java @@ -1,7 +1,6 @@ package regression; import org.apache.commons.math3.stat.regression.SimpleRegression; -import org.bytedeco.javacv.FrameFilter; import org.neo4j.graphdb.Entity; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.ResourceIterator; @@ -23,57 +22,75 @@ public class LR { public Log log; @Procedure(value = "regression.linear.create", mode = Mode.READ) + @Description("Create a simple linear regression named 'model'. Returns a stream containing its name (model), state (state), and " + + "number of data points (N).") public Stream create(@Name("model") String model) { return Stream.of((new LRModel(model)).asResult()); } @Procedure(value = "regression.linear.info", mode = Mode.READ) + @Description("Returns a stream containing the model's name (model), state (state), and number of data points (N).") public Stream info(@Name("model") String model) { LRModel lrModel = LRModel.from(model); return Stream.of(lrModel.asResult()); } + @Procedure(value = "regression.linear.stats", mode = Mode.READ) + @Description("Returns a stream containing the model's intercept (intercept), slope (slope), coefficient of determination " + + "(rSquare), and significance of the slope (significance).") + public Stream stat(@Name("model") String model) { + LRModel lrModel = LRModel.from(model); + return Stream.of(lrModel.stats()); + } + @Procedure(value = "regression.linear.add", mode = Mode.READ) + @Description("Void procedure which adds a single data point to 'model'.") public void add(@Name("model") 
String model, @Name("given") double given, @Name("expected") double expected) { LRModel lrModel = LRModel.from(model); lrModel.add(given, expected); } @Procedure(value = "regression.linear.addM", mode = Mode.READ) - public void addM(@Name("model") String model, @Name("given") Object given, @Name("expected") Object expected) { - double[] g = (double[]) given; - double[] e = (double[]) expected; + @Description("Void procedure which adds multiple data points (given[i], expected[i]) to 'model'.") + public void addM(@Name("model") String model, @Name("given") List given, @Name("expected") List expected) { LRModel lrModel = LRModel.from(model); - if (g.length != e.length) throw new IllegalArgumentException("Lengths of the two data arrays are unequal."); - for (int i = 0; i < g.length; i++) { - lrModel.add(g[i], e[i]); + if (given.size() != expected.size()) throw new IllegalArgumentException("Lengths of the two data lists are unequal."); + for (int i = 0; i < given.size(); i++) { + lrModel.add(given.get(i), expected.get(i)); } } @Procedure(value = "regression.linear.remove", mode = Mode.READ) + @Description("Void procedure which removes a single data point from 'model'.") public void remove(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { LRModel lrModel = LRModel.from(model); lrModel.removeData(given, expected); } @Procedure(value = "regression.linear.delete", mode = Mode.READ) + @Description("Deletes 'model' from storage. 
Returns a stream containing the model's name (model), state (state), and " + + "number of data points (N).") public Stream delete(@Name("model") String model) { return Stream.of(LRModel.removeModel(model)); } @UserFunction(value = "regression.linear.predict") + @Description("Function which returns a single double which is 'model' evaluated at the point 'given'.") public double predict(@Name("mode") String model, @Name("given") double given) { LRModel lrModel = LRModel.from(model); return lrModel.predict(given); } @UserFunction(value = "regression.linear.serialize") + @Description("Function which serializes the model's Java object and returns the byte[] serialization.") public Object serialize(@Name("model") String model) { LRModel lrModel = LRModel.from(model); return lrModel.serialize(); } @Procedure(value = "regression.linear.load", mode = Mode.READ) + @Description("Loads the model stored in data into the procedure's memory under the name 'model'. 'data' must be a byte array. " + + "Returns a stream containing the model's name (model), state (state), and number of data points (N).") public Stream load(@Name("model") String model, @Name("data") Object data) { SimpleRegression R; try { R = (SimpleRegression) convertFromBytes((byte[]) data); } @@ -103,6 +120,20 @@ ModelResult withInfo(Object...infos) { } } + public static class StatResult { + public final double intercept; + public final double slope; + public final double rSquare; + public final double significance; + + public StatResult(double intercept, double slope, double rSquare, double significance) { + this.intercept = intercept; + this.slope = slope; + this.rSquare = rSquare; + this.significance = significance; + } + } + //Serializes the object into a byte array for storage public static byte[] convertToBytes(Object object) throws IOException { try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); diff --git a/src/test/java/regression/LRTest.java b/src/test/java/regression/LRTest.java index 
e57fa62..fbe06f0 100644 --- a/src/test/java/regression/LRTest.java +++ b/src/test/java/regression/LRTest.java @@ -3,6 +3,8 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; +import java.util.List; +import java.util.Arrays; import org.junit.After; import org.junit.Before; @@ -143,8 +145,8 @@ public void regression() throws Exception { check(result, expected); //test addM procedure for adding multiple data points - double[] points = {7.0, 8.0}; - double[] observed = {6.900, 9.234}; + List points = Arrays.asList(7.0, 8.0); + List observed = Arrays.asList(6.900, 9.234); params.put("points", points); params.put("observed", observed); From 7fb116c25faae1f0c284ab67001e4e5321583032 Mon Sep 17 00:00:00 2001 From: Lauren Shin Date: Mon, 4 Jun 2018 15:15:58 -0700 Subject: [PATCH 7/7] fixed doc --- asciidoc => asciidoc/simple-linear-regression.adoc | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename asciidoc => asciidoc/simple-linear-regression.adoc (100%) diff --git a/asciidoc b/asciidoc/simple-linear-regression.adoc similarity index 100% rename from asciidoc rename to asciidoc/simple-linear-regression.adoc