diff --git a/asciidoc/simple-linear-regression.adoc b/asciidoc/simple-linear-regression.adoc new file mode 100644 index 0000000..8da411c --- /dev/null +++ b/asciidoc/simple-linear-regression.adoc @@ -0,0 +1,151 @@ += Simple Linear Regression + +// tag::introduction[] +Regression is a statistical tool for investigating the relationships between variables. Simple linear regression is the simplest form of regression; it creates a linear model for the relationship between the dependent variable and a single independent variable. Visually, simple linear regression "draws" a trend line on the scatter plot of two variables that best approximates their linear relationship. The model can be expressed with the two parameters of its line: slope and intercept. This is one of the most popular tools in statistics, and it is frequently used as a predictor for machine learning. +// end::introduction[] + +== Explanation and history + +// tag::explanation[] +At the core of linear regression is the method of least squares. In this method, the linear trend line is chosen which minimizes the sum of every data point's squared residual (deviation from the model). The method of least squares was independently discovered by Carl Friedrich-Gauss and Adrien-Marie Legendre in the early 19th century. The linear regression methods used today can be primarily attributed to the work of R.A. Fisher in the 1920s. +// end::explanation[] + +== Use-cases + +// tag::use-case[] +In simple linear regression, both the independent and dependent variables must be numeric. The dependent variable (`y`) can then be expressed in terms of the independent variable (`x`) using the two line parameters slope (`m`) and intercept (`b`) with the equation `y = m * x + b`. For these approximations to be meaningful, the dependent variable should take continuous values. The relationship between any two variables satisfying these conditions can be analyzed with simple linear regression. However, the model will only be successful for linearly related data. Some common examples include: + +* Predicting housing prices with square footage, number of bedrooms, number of bathrooms, etc. +* Analyzing sales of a product using pricing or performance information +* Calculating causal relationships between parameters in biological systems +// end::use-case[] + +== Constraints + +// tag::constraints[] +Because simple linear regression is so straightforward, it can be used with any numeric data pair. The real question is how well the best model fits the data. There are several measurements which attempt to quantify the success of the model. For example, the coefficient of determination (`r^2^`) is the proportion of the variance in the dependent variable that is predictable from the independent variable. A coefficient `r^2^ = 1` indicates that the variance in the dependent variable is entirely predictable from the independent variable (and thus the model is perfect). +// end::use-case[] + +== Example + +Let's look at a straightforward example--predicting Airbnb listing prices using the listing's number of bedrooms. Run `:play http://guides.neo4j.com/listings` and follow the import statements to load Will Lyon's Airbnb graph. + +.First initialize the model +[source,cypher] +---- +CALL regression.linear.create('airbnb prices') +---- + +.Then add data point by point +[source,cypher] +---- +MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) +WHERE exists(list.bedrooms) + AND exists(list.price) + AND NOT exists(list.added) OR list.added = false +CALL regression.linear.add('airbnb prices', list.bedrooms, list.price) +SET list.added = true +RETURN list.listing_id +---- + +.OR add multiple data points at once +[source,cypher] +---- +MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) +WHERE exists(list.bedrooms) + AND exists(list.price) + AND NOT exists(list.added) OR list.added = false +SET list.added = true +WITH collect(list.bedrooms) AS bedrooms, collect(list.price) AS prices +CALL regression.linear.addM('airbnb prices', bedrooms, prices) +RETURN bedrooms, prices +---- + +.Next predict price for a four-bedroom listing +[source,cypher] +---- +RETURN regression.linear.predict('airbnb prices', 4) +---- + +.Or make and store many predictions +[source,cypher] +---- +MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) +WHERE exists(list.bedrooms) AND NOT exists(list.price) +SET list.predicted_price = regression.linear.predict(list.bedrooms) +---- + +.You can remove data +[source,cypher] +---- +MATCH (list:Listing {listing_id:2467149})-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) +CALL regression.linear.remove('airbnb prices', list.bedrooms, list.price) +SET list.added = false +---- + +.Add some data from a nearby neighborhood +[source,cypher] +---- +MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78753'}) +WHERE exists(list.bedrooms) + AND exists(list.price) + AND NOT exists(list.added) OR list.added = false +CALL regression.linear.add('airbnb prices', list.bedrooms, list.price) RETURN list +---- + +.Check out the number of data points in your model +[source,cypher] +---- +CALL regression.linear.info('airbnb prices') +YIELD model, state, N +RETURN model, state, N +---- + +.And the statistics +[source,cypher] +---- +CALL regression.linear.stats('airbnb prices') +YIELD intercept, slope, rSquare, significance +RETURN intercept, slope, rSquare, significance +---- + +.Make sure that before shutting down the database, you store the model in the graph or externally +[source,cypher] +---- +MERGE (m:ModelNode {model: 'airbnb prices'}) +SET m.data = regression.linear.serialize('airbnb prices') +RETURN m +---- + +.Delete the model +[source,cypher] +---- +CALL regression.linear.delete('airbnb prices') +YIELD model, state, N +RETURN model, state, N +---- + +.And then when you restart the database, load the model from the graph back into the procedure +[source,cypher] +---- +MATCH (m:ModelNode {model: 'airbnb prices'}) +CALL regression.linear.load('airbnb prices', m.data) +---- + +Now the model is ready for further data changes and predictions! + +== Syntax + +// tag::syntax[] + +If your queries return duplicate values (eg: both directions of the same relationship) then data from the same observation may be added to the model multiple times. This will make your model less accurate. It is recommended that you be careful with queries (eg: specify direction of relationship) or store somewhere in relevant nodes/relationships whether this data has been added to the model. This way you can be sure to select relevant data points which have not yet been added to the model. + +// end::syntax[] + +== References + +// tag::references[] +* https://priceonomics.com/the-discovery-of-statistical-regression/ +* https://en.wikipedia.org/wiki/Regression_analysis +* https://dzone.com/articles/decision-trees-vs-clustering-algorithms-vs-linear +// end::references[] diff --git a/pom.xml b/pom.xml index 5f84031..5918d2f 100644 --- a/pom.xml +++ b/pom.xml @@ -56,6 +56,7 @@ commons-math3 3.4.1 + org.apache.commons commons-lang3 diff --git a/src/main/java/regression/LR.java b/src/main/java/regression/LR.java new file mode 100644 index 0000000..8a1377e --- /dev/null +++ b/src/main/java/regression/LR.java @@ -0,0 +1,157 @@ +package regression; + +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.neo4j.graphdb.Entity; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.ResourceIterator; +import org.neo4j.logging.Log; +import org.neo4j.procedure.*; +import org.neo4j.procedure.Mode; +import org.neo4j.unsafe.impl.batchimport.cache.ByteArray; +import sun.java2d.pipe.SpanShapeRenderer; + +import java.io.*; +import java.util.*; +import java.util.stream.Stream; + +public class LR { + @Context + public GraphDatabaseService db; + + @Context + public Log log; + + @Procedure(value = "regression.linear.create", mode = Mode.READ) + @Description("Create a simple linear regression named 'model'. Returns a stream containing its name (model), state (state), and " + + "number of data points (N).") + public Stream create(@Name("model") String model) { + return Stream.of((new LRModel(model)).asResult()); + } + + @Procedure(value = "regression.linear.info", mode = Mode.READ) + @Description("Returns a stream containing the model's name (model), state (state), and number of data points (N).") + public Stream info(@Name("model") String model) { + LRModel lrModel = LRModel.from(model); + return Stream.of(lrModel.asResult()); + } + + @Procedure(value = "regression.linear.stats", mode = Mode.READ) + @Description("Returns a stream containing the model's intercept (intercept), slope (slope), coefficient of determination " + + "(rSquare), and significance of the slope (significance).") + public Stream stat(@Name("model") String model) { + LRModel lrModel = LRModel.from(model); + return Stream.of(lrModel.stats()); + } + + @Procedure(value = "regression.linear.add", mode = Mode.READ) + @Description("Void procedure which adds a single data point to 'model'.") + public void add(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { + LRModel lrModel = LRModel.from(model); + lrModel.add(given, expected); + } + + @Procedure(value = "regression.linear.addM", mode = Mode.READ) + @Description("Void procedure which adds multiple data points (given[i], expected[i]) to 'model'.") + public void addM(@Name("model") String model, @Name("given") List given, @Name("expected") List expected) { + LRModel lrModel = LRModel.from(model); + if (given.size() != expected.size()) throw new IllegalArgumentException("Lengths of the two data lists are unequal."); + for (int i = 0; i < given.size(); i++) { + lrModel.add(given.get(i), expected.get(i)); + } + } + + @Procedure(value = "regression.linear.remove", mode = Mode.READ) + @Description("Void procedure which removes a single data point from 'model'.") + public void remove(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { + LRModel lrModel = LRModel.from(model); + lrModel.removeData(given, expected); + } + + @Procedure(value = "regression.linear.delete", mode = Mode.READ) + @Description("Deletes 'model' from storage. Returns a stream containing the model's name (model), state (state), and " + + "number of data points (N).") + public Stream delete(@Name("model") String model) { + return Stream.of(LRModel.removeModel(model)); + } + + @UserFunction(value = "regression.linear.predict") + @Description("Function which returns a single double which is 'model' evaluated at the point 'given'.") + public double predict(@Name("mode") String model, @Name("given") double given) { + LRModel lrModel = LRModel.from(model); + return lrModel.predict(given); + } + + @UserFunction(value = "regression.linear.serialize") + @Description("Function which serializes the model's Java object and returns the byte[] serialization.") + public Object serialize(@Name("model") String model) { + LRModel lrModel = LRModel.from(model); + return lrModel.serialize(); + } + + @Procedure(value = "regression.linear.load", mode = Mode.READ) + @Description("Loads the model stored in data into the procedure's memory under the name 'model'. 'data' must be a byte array. " + + "Returns a stream containing the model's name (model), state (state), and number of data points (N).") + public Stream load(@Name("model") String model, @Name("data") Object data) { + SimpleRegression R; + try { R = (SimpleRegression) convertFromBytes((byte[]) data); } + catch (Exception e) { + throw new RuntimeException("invalid data"); + } + return Stream.of((new LRModel(model, R)).asResult()); + } + + public static class ModelResult { + public final String model; + public final String state; + public final double N; + public final Map info = new HashMap<>(); + + public ModelResult(String model, LRModel.State state, double N) { + this.model = model; + this.state = state.name(); + this.N = N; + } + + ModelResult withInfo(Object...infos) { + for (int i = 0; i < infos.length; i+=2) { + info.put(infos[i].toString(),infos[i+1]); + } + return this; + } + } + + public static class StatResult { + public final double intercept; + public final double slope; + public final double rSquare; + public final double significance; + + public StatResult(double intercept, double slope, double rSquare, double significance) { + this.intercept = intercept; + this.slope = slope; + this.rSquare = rSquare; + this.significance = significance; + } + } + + //Serializes the object into a byte array for storage + public static byte[] convertToBytes(Object object) throws IOException { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = new ObjectOutputStream(bos)) { + out.writeObject(object); + return bos.toByteArray(); + } + } + + //de serializes the byte array and returns the stored object + public static Object convertFromBytes(byte[] bytes) throws IOException, ClassNotFoundException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInput in = new ObjectInputStream(bis)) { + return in.readObject(); + } + } + + + + +} \ No newline at end of file diff --git a/src/main/java/regression/LRModel.java b/src/main/java/regression/LRModel.java new file mode 100644 index 0000000..17b6ff2 --- /dev/null +++ b/src/main/java/regression/LRModel.java @@ -0,0 +1,130 @@ +package regression; +import java.util.concurrent.ConcurrentHashMap; +import java.util.*; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.*; +import org.neo4j.procedure.UserAggregationUpdate; +import java.io.*; + +public class LRModel { + + private static ConcurrentHashMap models = new ConcurrentHashMap<>(); + private final String name; + private State state; + SimpleRegression R; + + + public LRModel(String model) { + if (models.containsKey(model)) + throw new IllegalArgumentException("Model " + model + " already exists, please remove it first"); + this.name = model; + this.state = State.created; + this.R = new SimpleRegression(); + models.put(name, this); + } + + public LRModel(String model, SimpleRegression R) { + if (models.containsKey(model)) + throw new IllegalArgumentException("Model " + model + " already exists, please remove it first"); + this.name = model; + if (R == null) { + this.R = new SimpleRegression(); + this.state = State.created; + } else { + this.R = R; + if (R.getN() < 2) + this.state = State.created; + else + this.state = State.ready; + } + models.put(name, this); + } + + public static LRModel from(String name) { + LRModel model = models.get(name); + if (model != null) return model; + throw new IllegalArgumentException("No valid LR-Model " + name); + } + + public void add(double given, double expected) { + R.addData(given, expected); + if (R.getN() > 1) { + this.state = State.ready; + } + } + + public double predict(double given) { + if (this.state == State.ready) + return R.predict(given); + throw new IllegalArgumentException("Not enough data in model to predict yet"); + } + + + public void removeData(double given, double expected) { + R.removeData(given, expected); + if (R.getN() < 2) { + this.state = State.created; + } + } + + public byte[] serialize() { + try { return LR.convertToBytes(R); } + catch (IOException e) { throw new RuntimeException(name + " cannot be serialized."); } + } + + public static LR.ModelResult removeModel(String model) { + LRModel existing = models.remove(model); + return new LR.ModelResult(model, existing == null ? State.unknown : State.removed, 0); + } + + public void store(GraphDatabaseService db) { + try{ byte[] serializedR = LR.convertToBytes(R); + Map parameters = new HashMap<>(); + parameters.put("name", name); + ResourceIterator n = db.execute("MERGE (n:LRModel {name:$name}) RETURN n", parameters).columnAs("n"); + Entity modelNode = n.next(); + modelNode.setProperty("state", state.name()); + modelNode.setProperty("serializedModel", serializedR);} + catch (IOException e) { throw new RuntimeException(name + " cannot be serialized."); } + } + + /*protected void initTypes(Map types, String output) { + if (!types.containsKey(output)) throw new IllegalArgumentException("Outputs not defined: " + output); + int i = 0; + for (Map.Entry entry : types.entrySet()) { + String key = entry.getKey(); + this.types.put(key, DataType.from(entry.getValue())); + if (!key.equals(output)) this.offsets.put(key, i++); + } + this.offsets.put(output, i); + }*/ + + + /*enum DataType { + _class, _float, _order; + + public static DataType from(String type) { + switch (type.toUpperCase()) { + case "CLASS": + return DataType._class; + case "FLOAT": + return DataType._float; + case "ORDER": + return DataType._order; + default: + throw new IllegalArgumentException("Unknown type: " + type); + } + } + }*/ + + public enum State {created, ready, removed, unknown} + + public LR.ModelResult asResult() { + LR.ModelResult result = new LR.ModelResult(this.name, this.state, this.R.getN()); + return result; + } + + + +} diff --git a/src/test/java/regression/LRTest.java b/src/test/java/regression/LRTest.java new file mode 100644 index 0000000..fbe06f0 --- /dev/null +++ b/src/test/java/regression/LRTest.java @@ -0,0 +1,167 @@ +package regression; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.List; +import java.util.Arrays; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.*; + +import static org.junit.Assert.*; +import static org.hamcrest.CoreMatchers.equalTo; + +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.neo4j.kernel.impl.proc.Procedures; +import org.neo4j.kernel.internal.GraphDatabaseAPI; +import org.neo4j.test.TestGraphDatabaseFactory; + +public class LRTest { + + private GraphDatabaseService db; + + @Before + public void setUp() throws Exception { + db = new TestGraphDatabaseFactory().newImpermanentDatabase(); + Procedures procedures = ((GraphDatabaseAPI) db).getDependencyResolver().resolveDependency(Procedures.class); + procedures.registerProcedure(LR.class); + procedures.registerFunction(LR.class); + } + + @After + public void tearDown() throws Exception { + db.shutdown(); + } + + private void exhaust(Iterator r) { + while(r.hasNext()) r.next(); + } + + private void check(Result result, Map expected) { + while (result.hasNext()) { + Map actual = result.next(); + + double time = (double) actual.get("time"); + double expectedPrediction = expected.get(time); + double actualPrediction = (double) actual.get("predictedProgress"); + + assertThat( actualPrediction, equalTo( expectedPrediction ) ); + } + } + + @Test + public void regression() throws Exception { + //create known relationships for times 1, 2, 3 + db.execute("CREATE (:Node {id:1}) - [:WORKS_FOR {time:1.0, progress:1.345}] -> " + + "(:Node {id:2}) - [:WORKS_FOR {time:2.0, progress:2.596}] -> " + + "(:Node {id:3}) - [:WORKS_FOR {time:3.0, progress:3.259}] -> (:Node {id:4})"); + + //create unknown relationships for times 4, 5 + db.execute("CREATE (:Node {id:5}) -[:WORKS_FOR {time:4.0}] -> " + + "(:Node {id:6}) - [:WORKS_FOR {time:5.0}] -> (:Node {id:7})"); + + //initialize the model + db.execute("CALL regression.linear.create('work and progress')"); + + //add known data + Result r = db.execute("MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND exists(r.progress) CALL " + + "regression.linear.add('work and progress', r.time, r.progress) RETURN r"); + //these rows are computed lazily so we must iterate through all rows to ensure all data points are added to the model + exhaust(r); + + //check that the correct info is stored in the model (should contain 3 data points) + Map info = db.execute("CALL regression.linear.info('work and progress') YIELD model, state, N " + + "RETURN model, state, N").next(); + assertTrue(info.get("model").equals("work and progress")); + assertTrue(info.get("state").equals("ready")); + assertThat(info.get("N"), equalTo(3.0)); + + //store predictions + String storePredictions = "MATCH (:Node)-[r:WORKS_FOR]->(:Node) WHERE exists(r.time) AND NOT exists(r.progress) " + + "SET r.predictedProgress = regression.linear.predict('work and progress', r.time)"; + db.execute(storePredictions); + + //check that predictions are correct + + SimpleRegression R = new SimpleRegression(); + R.addData(1.0, 1.345); + R.addData(2.0, 2.596); + R.addData(3.0, 3.259); + HashMap expected = new HashMap<>(); + expected.put(4.0, R.predict(4.0)); + expected.put(5.0, R.predict(5.0)); + + String gatherPredictedValues = "MATCH () - [r:WORKS_FOR] -> () WHERE exists(r.time) AND " + + "exists(r.predictedProgress) RETURN r.time as time, r.predictedProgress as predictedProgress"; + + Result result = db.execute(gatherPredictedValues); + + check(result, expected); + + //serialize the model + Map serial = db.execute("RETURN regression.linear.serialize('work and progress') as data").next(); + Object data = serial.get("data"); + + //check that the byte[] model returns same predictions as the model stored in the procedure + SimpleRegression storedR = (SimpleRegression) LR.convertFromBytes((byte[]) data); + assertThat(storedR.predict(4.0), equalTo(expected.get(4.0))); + assertThat(storedR.predict(5.0), equalTo(expected.get(5.0))); + + //delete model then re-create using serialization + db.execute("CALL regression.linear.delete('work and progress')"); + Map params = new HashMap<>(); + params.put("data", data); + db.execute("CALL regression.linear.load('work and progress', $data)", params); + + //remove data from relationship between nodes 1 and 2 + r = db.execute("MATCH (:Node {id:1})-[r:WORKS_FOR]->(:Node {id:2}) CALL regression.linear.remove('work " + + "and progress', r.time, r.progress) return r"); + exhaust(r); + + //create a new relationship between nodes 7 and 8 + db.execute("MATCH (n7:Node {id:7}) MERGE (n7)-[:WORKS_FOR {time:6.0, progress:5.870}]->(:Node {id:8})"); + + //add data from new relationship to model + r = db.execute("MATCH (:Node {id:7})-[r:WORKS_FOR]->(:Node {id:8}) CALL regression.linear.add('work " + + "and progress', r.time, r.progress) RETURN r.time"); + //again must iterate through rows + exhaust(r); + + //map new model on all relationships with unknown progress + db.execute(storePredictions); + + //replicate the creation and updates of the model + R.removeData(1.0, 1.345); + R.addData(6.0, 5.870); + expected.put(4.0, R.predict(4.0)); + expected.put(5.0, R.predict(5.0)); + + //make sure predicted values are correct + result = db.execute(gatherPredictedValues); + check(result, expected); + + //test addM procedure for adding multiple data points + List points = Arrays.asList(7.0, 8.0); + List observed = Arrays.asList(6.900, 9.234); + params.put("points", points); + params.put("observed", observed); + + db.execute("CALL regression.linear.addM('work and progress', $points, $observed)", params); + db.execute(storePredictions); + R.addData(7.0, 6.900); + R.addData(8.0, 9.234); + expected.put(4.0, R.predict(4.0)); + expected.put(5.0, R.predict(5.0)); + result = db.execute(gatherPredictedValues); + + check(result, expected); + + db.execute("CALL regression.linear.delete('work and progress')").close(); + + + } +}