added linear regression #4
Open
L-Shin wants to merge 7 commits into neo4j-contrib:master from L-Shin:master
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| = Simple Linear Regression | ||
|
|
||
| // tag::introduction[] | ||
| Regression is a statistical tool for investigating the relationships between variables. Simple linear regression is the simplest form of regression; it creates a linear model for the relationship between the dependent variable and a single independent variable. Visually, simple linear regression "draws" a trend line on the scatter plot of two variables that best approximates their linear relationship. The model can be expressed with the two parameters of its line: slope and intercept. This is one of the most popular tools in statistics, and it is frequently used as a predictor for machine learning. | ||
| // end::introduction[] | ||
|
|
||
| == Explanation and history | ||
|
|
||
| // tag::explanation[] | ||
| At the core of linear regression is the method of least squares, which chooses the trend line that minimizes the sum of the squared residuals (each data point's deviation from the model). The method was discovered independently by Carl Friedrich Gauss and Adrien-Marie Legendre in the early 19th century. The linear regression methods used today can be attributed primarily to the work of R.A. Fisher in the 1920s. | ||
| // end::explanation[] | ||
|
|
||
| == Use-cases | ||
|
|
||
| // tag::use-case[] | ||
| In simple linear regression, both the independent and dependent variables must be numeric. The dependent variable (`y`) can then be expressed in terms of the independent variable (`x`) using the two line parameters slope (`m`) and intercept (`b`) with the equation `y = m * x + b`. For these approximations to be meaningful, the dependent variable should take continuous values. The relationship between any two variables satisfying these conditions can be analyzed with simple linear regression. However, the model will only be successful for linearly related data. Some common examples include: | ||
|
|
||
| * Predicting housing prices from a single feature such as square footage, number of bedrooms, or number of bathrooms | ||
| * Analyzing sales of a product using pricing or performance information | ||
| * Quantifying relationships between parameters in biological systems | ||
| // end::use-case[] | ||
|
|
||
| == Constraints | ||
|
|
||
| // tag::constraints[] | ||
| Because simple linear regression is so straightforward, it can be used with any numeric data pair. The real question is how well the best model fits the data. There are several measurements which attempt to quantify the success of the model. For example, the coefficient of determination (`r^2^`) is the proportion of the variance in the dependent variable that is predictable from the independent variable. A coefficient `r^2^ = 1` indicates that the variance in the dependent variable is entirely predictable from the independent variable (and thus the model is perfect). | ||
| // end::constraints[] | ||
|
|
||
| == Example | ||
|
|
||
| Let's look at a straightforward example--predicting Airbnb listing prices using the listing's number of bedrooms. Run `:play http://guides.neo4j.com/listings` and follow the import statements to load Will Lyon's Airbnb graph. | ||
|
|
||
| .First initialize the model | ||
| [source,cypher] | ||
| ---- | ||
| CALL regression.linear.create('airbnb prices') | ||
| ---- | ||
|
|
||
| .Then add data point by point | ||
| [source,cypher] | ||
| ---- | ||
| MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) | ||
| WHERE exists(list.bedrooms) | ||
| AND exists(list.price) | ||
| AND (NOT exists(list.added) OR list.added = false) | ||
| CALL regression.linear.add('airbnb prices', list.bedrooms, list.price) | ||
| SET list.added = true | ||
| RETURN list.listing_id | ||
| ---- | ||
|
|
||
| .Or add multiple data points at once | ||
| [source,cypher] | ||
| ---- | ||
| MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) | ||
| WHERE exists(list.bedrooms) | ||
| AND exists(list.price) | ||
| AND (NOT exists(list.added) OR list.added = false) | ||
| SET list.added = true | ||
| WITH collect(list.bedrooms) AS bedrooms, collect(list.price) AS prices | ||
| CALL regression.linear.addM('airbnb prices', bedrooms, prices) | ||
| RETURN bedrooms, prices | ||
| ---- | ||
|
|
||
| .Next predict price for a four-bedroom listing | ||
| [source,cypher] | ||
| ---- | ||
| RETURN regression.linear.predict('airbnb prices', 4) | ||
| ---- | ||
|
|
||
| .Or make and store many predictions | ||
| [source,cypher] | ||
| ---- | ||
| MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) | ||
| WHERE exists(list.bedrooms) AND NOT exists(list.price) | ||
| SET list.predicted_price = regression.linear.predict('airbnb prices', list.bedrooms) | ||
| ---- | ||
|
|
||
| .You can remove data | ||
| [source,cypher] | ||
| ---- | ||
| MATCH (list:Listing {listing_id:2467149})-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78752'}) | ||
| CALL regression.linear.remove('airbnb prices', list.bedrooms, list.price) | ||
| SET list.added = false | ||
| ---- | ||
|
|
||
| .Add some data from a nearby neighborhood | ||
| [source,cypher] | ||
| ---- | ||
| MATCH (list:Listing)-[:IN_NEIGHBORHOOD]->(:Neighborhood {neighborhood_id:'78753'}) | ||
| WHERE exists(list.bedrooms) | ||
| AND exists(list.price) | ||
| AND (NOT exists(list.added) OR list.added = false) | ||
| CALL regression.linear.add('airbnb prices', list.bedrooms, list.price) | ||
| SET list.added = true | ||
| RETURN list | ||
| ---- | ||
|
|
||
| .Check out the number of data points in your model | ||
| [source,cypher] | ||
| ---- | ||
| CALL regression.linear.info('airbnb prices') | ||
| YIELD model, state, N | ||
| RETURN model, state, N | ||
| ---- | ||
|
|
||
| .And the statistics | ||
| [source,cypher] | ||
| ---- | ||
| CALL regression.linear.stats('airbnb prices') | ||
| YIELD intercept, slope, rSquare, significance | ||
| RETURN intercept, slope, rSquare, significance | ||
| ---- | ||
|
|
||
| .Make sure that before shutting down the database, you store the model in the graph or externally | ||
| [source,cypher] | ||
| ---- | ||
| MERGE (m:ModelNode {model: 'airbnb prices'}) | ||
| SET m.data = regression.linear.serialize('airbnb prices') | ||
| RETURN m | ||
| ---- | ||
|
|
||
| .Delete the model | ||
| [source,cypher] | ||
| ---- | ||
| CALL regression.linear.delete('airbnb prices') | ||
| YIELD model, state, N | ||
| RETURN model, state, N | ||
| ---- | ||
|
|
||
| .And then when you restart the database, load the model from the graph back into the procedure | ||
| [source,cypher] | ||
| ---- | ||
| MATCH (m:ModelNode {model: 'airbnb prices'}) | ||
| CALL regression.linear.load('airbnb prices', m.data) | ||
| YIELD model, state, N | ||
| RETURN model, state, N | ||
| ---- | ||
|
|
||
| Now the model is ready for further data changes and predictions! | ||
|
|
||
| == Syntax | ||
|
|
||
| // tag::syntax[] | ||
|
|
||
| If your queries return duplicate values (e.g., both directions of the same relationship), data from the same observation may be added to the model multiple times, which makes the model less accurate. To avoid this, write queries carefully (e.g., specify the relationship direction) or record on the relevant nodes/relationships whether their data has already been added to the model. That way you can select only the data points that have not yet been added. | ||
|
|
||
| // end::syntax[] | ||
|
|
||
| == References | ||
|
|
||
| // tag::references[] | ||
| * https://priceonomics.com/the-discovery-of-statistical-regression/ | ||
| * https://en.wikipedia.org/wiki/Regression_analysis | ||
| * https://dzone.com/articles/decision-trees-vs-clustering-algorithms-vs-linear | ||
| // end::references[] |
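The least-squares fit described in the documentation above can be written out in closed form. The following is only a minimal sketch, not the code in this pull request (which delegates to Apache Commons Math's `SimpleRegression`): the class name `LeastSquaresSketch`, its method names, and the sample data are all made up for illustration.

```java
/** Illustrative only: closed-form simple linear regression, y = slope * x + intercept. */
final class LeastSquaresSketch {

    /** Returns {slope, intercept, rSquared} for the least-squares line through (x[i], y[i]). */
    static double[] fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("Need two equal-length arrays with at least two points.");
        }
        int n = x.length;
        double meanX = 0, meanY = 0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        // Sums of squared deviations and of cross-deviations from the means.
        double sxx = 0, sxy = 0, syy = 0;
        for (int i = 0; i < n; i++) {
            double dx = x[i] - meanX;
            double dy = y[i] - meanY;
            sxx += dx * dx;
            sxy += dx * dy;
            syy += dy * dy;
        }
        if (sxx == 0) {
            throw new IllegalArgumentException("All x values are identical; the slope is undefined.");
        }
        double slope = sxy / sxx;                  // minimizes the sum of squared residuals
        double intercept = meanY - slope * meanX;  // the fitted line passes through the means
        double rSquared = syy == 0 ? Double.NaN : (sxy * sxy) / (sxx * syy);
        return new double[] {slope, intercept, rSquared};
    }

    /** Evaluates the fitted line at x. */
    static double predict(double[] model, double x) {
        return model[0] * x + model[1];
    }

    public static void main(String[] args) {
        // Made-up sample data: bedrooms vs. price.
        double[] bedrooms = {1, 2, 3, 4};
        double[] prices = {100, 150, 200, 250};
        double[] model = fit(bedrooms, prices);
        System.out.println("slope=" + model[0] + " intercept=" + model[1] + " r^2=" + model[2]);
        System.out.println("predicted price for 5 bedrooms: " + predict(model, 5));
    }
}
```

With the made-up data above, every point lies exactly on one line, so the coefficient of determination comes out as 1; real listing data would give a smaller value.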
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| package regression; | ||
|
|
||
| import org.apache.commons.math3.stat.regression.SimpleRegression; | ||
| import org.neo4j.graphdb.Entity; | ||
| import org.neo4j.graphdb.GraphDatabaseService; | ||
| import org.neo4j.graphdb.ResourceIterator; | ||
| import org.neo4j.logging.Log; | ||
| import org.neo4j.procedure.*; | ||
| import org.neo4j.procedure.Mode; | ||
|
|
||
| import java.io.*; | ||
| import java.util.*; | ||
| import java.util.stream.Stream; | ||
|
|
||
| public class LR { | ||
| @Context | ||
| public GraphDatabaseService db; | ||
|
|
||
| @Context | ||
| public Log log; | ||
|
|
||
| @Procedure(value = "regression.linear.create", mode = Mode.READ) | ||
| @Description("Create a simple linear regression named 'model'. Returns a stream containing its name (model), state (state), and " + | ||
| "number of data points (N).") | ||
| public Stream<ModelResult> create(@Name("model") String model) { | ||
| return Stream.of((new LRModel(model)).asResult()); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.info", mode = Mode.READ) | ||
| @Description("Returns a stream containing the model's name (model), state (state), and number of data points (N).") | ||
| public Stream<ModelResult> info(@Name("model") String model) { | ||
| LRModel lrModel = LRModel.from(model); | ||
| return Stream.of(lrModel.asResult()); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.stats", mode = Mode.READ) | ||
| @Description("Returns a stream containing the model's intercept (intercept), slope (slope), coefficient of determination " + | ||
| "(rSquare), and significance of the slope (significance).") | ||
| public Stream<StatResult> stat(@Name("model") String model) { | ||
| LRModel lrModel = LRModel.from(model); | ||
| return Stream.of(lrModel.stats()); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.add", mode = Mode.READ) | ||
| @Description("Void procedure which adds a single data point to 'model'.") | ||
| public void add(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { | ||
| LRModel lrModel = LRModel.from(model); | ||
| lrModel.add(given, expected); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.addM", mode = Mode.READ) | ||
| @Description("Void procedure which adds multiple data points (given[i], expected[i]) to 'model'.") | ||
| public void addM(@Name("model") String model, @Name("given") List<Double> given, @Name("expected") List<Double> expected) { | ||
| LRModel lrModel = LRModel.from(model); | ||
| if (given.size() != expected.size()) throw new IllegalArgumentException("Lengths of the two data lists are unequal."); | ||
| for (int i = 0; i < given.size(); i++) { | ||
| lrModel.add(given.get(i), expected.get(i)); | ||
| } | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.remove", mode = Mode.READ) | ||
| @Description("Void procedure which removes a single data point from 'model'.") | ||
| public void remove(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { | ||
| LRModel lrModel = LRModel.from(model); | ||
| lrModel.removeData(given, expected); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.delete", mode = Mode.READ) | ||
| @Description("Deletes 'model' from storage. Returns a stream containing the model's name (model), state (state), and " + | ||
| "number of data points (N).") | ||
| public Stream<ModelResult> delete(@Name("model") String model) { | ||
| return Stream.of(LRModel.removeModel(model)); | ||
| } | ||
|
|
||
| @UserFunction(value = "regression.linear.predict") | ||
| @Description("Function which returns a single double which is 'model' evaluated at the point 'given'.") | ||
| public double predict(@Name("model") String model, @Name("given") double given) { | ||
| LRModel lrModel = LRModel.from(model); | ||
| return lrModel.predict(given); | ||
| } | ||
|
|
||
| @UserFunction(value = "regression.linear.serialize") | ||
| @Description("Function which serializes the model's Java object and returns the byte[] serialization.") | ||
| public Object serialize(@Name("model") String model) { | ||
| LRModel lrModel = LRModel.from(model); | ||
| return lrModel.serialize(); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.load", mode = Mode.READ) | ||
| @Description("Loads the model stored in data into the procedure's memory under the name 'model'. 'data' must be a byte array. " + | ||
| "Returns a stream containing the model's name (model), state (state), and number of data points (N).") | ||
| public Stream<ModelResult> load(@Name("model") String model, @Name("data") Object data) { | ||
| SimpleRegression R; | ||
| try { R = (SimpleRegression) convertFromBytes((byte[]) data); } | ||
| catch (Exception e) { | ||
| throw new RuntimeException("invalid data", e); | ||
| } | ||
| return Stream.of((new LRModel(model, R)).asResult()); | ||
| } | ||
|
|
||
| public static class ModelResult { | ||
| public final String model; | ||
| public final String state; | ||
| public final double N; | ||
| public final Map<String,Object> info = new HashMap<>(); | ||
|
|
||
| public ModelResult(String model, LRModel.State state, double N) { | ||
| this.model = model; | ||
| this.state = state.name(); | ||
| this.N = N; | ||
| } | ||
|
|
||
| ModelResult withInfo(Object...infos) { | ||
| for (int i = 0; i < infos.length; i+=2) { | ||
| info.put(infos[i].toString(),infos[i+1]); | ||
| } | ||
| return this; | ||
| } | ||
| } | ||
|
|
||
| public static class StatResult { | ||
| public final double intercept; | ||
| public final double slope; | ||
| public final double rSquare; | ||
| public final double significance; | ||
|
|
||
| public StatResult(double intercept, double slope, double rSquare, double significance) { | ||
| this.intercept = intercept; | ||
| this.slope = slope; | ||
| this.rSquare = rSquare; | ||
| this.significance = significance; | ||
| } | ||
| } | ||
|
|
||
| // Serializes the object into a byte array for storage | ||
| public static byte[] convertToBytes(Object object) throws IOException { | ||
| try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); | ||
| ObjectOutput out = new ObjectOutputStream(bos)) { | ||
| out.writeObject(object); | ||
| return bos.toByteArray(); | ||
| } | ||
| } | ||
|
|
||
| // Deserializes the byte array and returns the stored object | ||
| public static Object convertFromBytes(byte[] bytes) throws IOException, ClassNotFoundException { | ||
| try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); | ||
| ObjectInput in = new ObjectInputStream(bis)) { | ||
| return in.readObject(); | ||
| } | ||
| } | ||
|
|
||
|
|
||
|
|
||
|
|
||
| } | ||
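The procedures above all delegate to an `LRModel` class that is not part of this diff. As a rough, hypothetical sketch of how a name-to-model lookup like `LRModel.from(model)` could be backed by an in-memory map (and why a model disappears on shutdown unless it is serialized first), consider the following. The names `ModelRegistrySketch`, `Model`, `create`, `from`, and `remove` are assumptions for illustration, and the stand-in model only counts points.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical name-to-model registry; the real LRModel class is not shown in this diff. */
final class ModelRegistrySketch {

    /** Stand-in for a regression model; it only counts data points to keep the sketch short. */
    static final class Model {
        final String name;
        private long n;

        Model(String name) {
            this.name = name;
        }

        synchronized void add(double given, double expected) {
            n++; // a real model would also update its regression state here
        }

        synchronized long size() {
            return n;
        }
    }

    // Lives only in JVM memory, so its contents are lost when the database shuts down.
    private static final Map<String, Model> MODELS = new ConcurrentHashMap<>();

    /** Registers a new, empty model; rejecting duplicate names is a design choice of this sketch. */
    static Model create(String name) {
        Model model = new Model(name);
        if (MODELS.putIfAbsent(name, model) != null) {
            throw new IllegalArgumentException("Model already exists: " + name);
        }
        return model;
    }

    /** Looks up an existing model by name. */
    static Model from(String name) {
        Model model = MODELS.get(name);
        if (model == null) {
            throw new IllegalArgumentException("No model named: " + name);
        }
        return model;
    }

    /** Removes a model from the registry and returns it. */
    static Model remove(String name) {
        Model model = MODELS.remove(name);
        if (model == null) {
            throw new IllegalArgumentException("No model named: " + name);
        }
        return model;
    }
}
```

A concurrent map is used here on the assumption that procedures may be called from several transactions at once; whether the actual implementation does the same is not visible from this diff.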
To be future-proof, it would probably be better at some point to serialize using explicit data from the model.
For now it's OK.
Yeah, I agree. It's just difficult because you cannot create a new SimpleRegression object from the information it stores (because it doesn't actually store individual data points). Maybe I should consider using a different library?
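One possible direction for the explicit serialization discussed here, sketched under the assumption that the model is kept as plain running sums rather than as a `SimpleRegression` object: a simple linear regression is fully determined by six numbers (the point count and five sums), so those can be stored and restored directly without Java object serialization. The class `RunningRegressionSketch` and all of its methods are hypothetical and not part of this pull request.

```java
/** Illustrative only: running sums that fully determine a simple linear regression. */
final class RunningRegressionSketch {
    private long n;
    private double sumX, sumY, sumXX, sumXY, sumYY;

    void add(double x, double y) {
        n++;
        sumX += x;
        sumY += y;
        sumXX += x * x;
        sumXY += x * y;
        sumYY += y * y;
    }

    /** Removing a point subtracts its contribution from each sum. */
    void remove(double x, double y) {
        if (n == 0) {
            throw new IllegalStateException("Model has no data points.");
        }
        n--;
        sumX -= x;
        sumY -= y;
        sumXX -= x * x;
        sumXY -= x * y;
        sumYY -= y * y;
    }

    double slope() {
        double denominator = n * sumXX - sumX * sumX;
        if (n < 2 || denominator == 0) {
            return Double.NaN; // fewer than two distinct x values
        }
        return (n * sumXY - sumX * sumY) / denominator;
    }

    double intercept() {
        return (sumY - slope() * sumX) / n;
    }

    double rSquare() {
        double sxx = n * sumXX - sumX * sumX;
        double syy = n * sumYY - sumY * sumY;
        double sxy = n * sumXY - sumX * sumY;
        return (sxy * sxy) / (sxx * syy);
    }

    double predict(double x) {
        return slope() * x + intercept();
    }

    /** Explicit representation: {n, sumX, sumY, sumXX, sumXY, sumYY}. */
    double[] toArray() {
        return new double[] {n, sumX, sumY, sumXX, sumXY, sumYY};
    }

    /** Rebuilds a model from the array produced by toArray(). */
    static RunningRegressionSketch fromArray(double[] s) {
        if (s == null || s.length != 6) {
            throw new IllegalArgumentException("Expected six values.");
        }
        RunningRegressionSketch r = new RunningRegressionSketch();
        r.n = (long) s[0];
        r.sumX = s[1];
        r.sumY = s[2];
        r.sumXX = s[3];
        r.sumXY = s[4];
        r.sumYY = s[5];
        return r;
    }
}
```

The trade-off is numerical: naive running sums can lose precision on large or badly scaled data, so a production version would likely want a more stable update scheme. On the other hand, an array of doubles is a plain value that can be stored as a node property, and it does not depend on the serialized form of any library class.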