-
Notifications
You must be signed in to change notification settings - Fork 13
added linear regression #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| package regression; | ||
|
|
||
| import org.apache.commons.math3.stat.regression.SimpleRegression; | ||
| import org.neo4j.graphdb.Entity; | ||
| import org.neo4j.graphdb.GraphDatabaseService; | ||
| import org.neo4j.graphdb.ResourceIterator; | ||
| import org.neo4j.logging.Log; | ||
| import org.neo4j.procedure.*; | ||
| import org.neo4j.procedure.Mode; | ||
|
|
||
| import java.io.*; | ||
| import java.util.*; | ||
| import java.util.stream.Stream; | ||
|
|
||
| public class LR { | ||
| @Context | ||
| public GraphDatabaseService db; | ||
|
|
||
| @Context | ||
| public Log log; | ||
|
|
||
| @Procedure(value = "regression.linear.create", mode = Mode.READ) | ||
| public Stream<ModelResult> create(@Name("model") String model) { | ||
| return Stream.of((new LRModel(model)).asResult()); | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add an info/details function or procedure that just returns data on the model by name. |
||
|
|
||
| @Procedure(value = "regression.linear.addData", mode = Mode.READ) | ||
| public Stream<ModelResult> addData(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { | ||
|
||
| LRModel lrModel = LRModel.from(model); | ||
| lrModel.add(given, expected); | ||
| return Stream.of(lrModel.asResult()); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.removeData", mode = Mode.READ) | ||
| public Stream<ModelResult> removeData(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) { | ||
|
||
| LRModel lrModel = LRModel.from(model); | ||
| lrModel.removeData(given, expected); | ||
| return Stream.of(lrModel.asResult()); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.removeModel", mode = Mode.READ) | ||
| public Stream<ModelResult> removeModel(@Name("model") String model) { | ||
|
||
| return Stream.of(LRModel.removeModel(model)); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.predict", mode = Mode.READ) | ||
| public Stream<PredictResult> predict(@Name("mode") String model, @Name("given") double given) { | ||
| LRModel lrModel = LRModel.from(model); | ||
| return Stream.of(lrModel.predict(given)); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.storeModel", mode = Mode.WRITE) | ||
| public Stream<ModelResult> storeModel(@Name("model") String model) { | ||
|
||
| LRModel lrModel = LRModel.from(model); | ||
| lrModel.store(db); | ||
| return Stream.of(lrModel.asResult()); | ||
| } | ||
|
|
||
| @Procedure(value = "regression.linear.createFromStorage", mode = Mode.READ) | ||
| public Stream<ModelResult> createFromStorage(@Name("model") String model) { | ||
|
||
| Map<String, Object> parameters = new HashMap<>(); | ||
| parameters.put("name", model); | ||
| Entity modelNode; | ||
| SimpleRegression R; | ||
| try { | ||
| ResourceIterator<Entity> n = db.execute("MATCH (n:LRModel {name:$name}) RETURN " + | ||
| "n", parameters).columnAs("n"); | ||
| modelNode = n.next(); | ||
| byte[] m = (byte[]) modelNode.getProperty("serializedModel"); | ||
| R = (SimpleRegression) convertFromBytes(m); | ||
| } catch (Exception e) { | ||
| throw new RuntimeException("no existing model for specified independent and dependent variables and model ID"); | ||
| } | ||
| return Stream.of(new LRModel(model, R, (String) modelNode.getProperty("state")).asResult()); | ||
| } | ||
|
|
||
| public static class ModelResult { | ||
| public final String model; | ||
| public final String state; | ||
| public final double N; | ||
| public final Map<String,Object> info = new HashMap<>(); | ||
|
|
||
| public ModelResult(String model, LRModel.State state, double N) { | ||
| this.model = model; | ||
| this.state = state.name(); | ||
| this.N = N; | ||
| } | ||
|
|
||
| ModelResult withInfo(Object...infos) { | ||
| for (int i = 0; i < infos.length; i+=2) { | ||
| info.put(infos[i].toString(),infos[i+1]); | ||
| } | ||
| return this; | ||
| } | ||
| } | ||
|
|
||
| public static class PredictResult { | ||
| public final double prediction; | ||
| public PredictResult(double p) { | ||
| this.prediction = p; | ||
| } | ||
| } | ||
|
|
||
| //Serializes the object into a byte array for storage | ||
| public static byte[] convertToBytes(Object object) throws IOException { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For future proof it would probably be better at some point to serialize using some explicit data from the model.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea I agree, it's just difficult because you cannot create a new SimpleRegression object using the information it stores (because is doesn't actually store individual data points). Maybe I should consider using a different library? |
||
| try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); | ||
| ObjectOutput out = new ObjectOutputStream(bos)) { | ||
| out.writeObject(object); | ||
| return bos.toByteArray(); | ||
| } | ||
| } | ||
|
|
||
| //de serializes the byte array and returns the stored object | ||
| private static Object convertFromBytes(byte[] bytes) throws IOException, ClassNotFoundException { | ||
| try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); | ||
| ObjectInput in = new ObjectInputStream(bis)) { | ||
| return in.readObject(); | ||
| } | ||
| } | ||
|
|
||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| package regression; | ||
| import java.util.concurrent.ConcurrentHashMap; | ||
| import java.util.*; | ||
| import org.apache.commons.math3.stat.regression.SimpleRegression; | ||
| import org.neo4j.graphdb.GraphDatabaseService; | ||
| import org.neo4j.graphdb.*; | ||
| import org.neo4j.procedure.UserAggregationUpdate; | ||
| import java.io.*; | ||
|
|
||
| public class LRModel { | ||
|
|
||
| private static ConcurrentHashMap<String, LRModel> models = new ConcurrentHashMap<>(); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm torn about having too many "model" storage maps in the plugin, perhaps we should unify this into one.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But we can leave that for later. |
||
| private final String name; | ||
| private State state; | ||
| SimpleRegression R; | ||
|
|
||
|
|
||
| public LRModel(String model) { | ||
| if (models.containsKey(model)) | ||
| throw new IllegalArgumentException("Model " + model + " already exists, please remove it first"); | ||
| this.name = model; | ||
| this.state = State.created; | ||
| this.R = new SimpleRegression(); | ||
| models.put(name, this); | ||
| } | ||
|
|
||
| public LRModel(String model, SimpleRegression R, String state) { | ||
| if (models.containsKey(model)) | ||
| throw new IllegalArgumentException("Model " + model + " already exists, please remove it first"); | ||
| this.name = model; | ||
| switch(state) { | ||
| case "created": this.state = State.created; | ||
| case "ready": this.state = State.ready; | ||
| case "removed": this.state = State.removed; | ||
| case "unknown": this.state = State.unknown; | ||
| } | ||
| this.R = R; | ||
| models.put(name, this); | ||
| } | ||
|
|
||
| public static LRModel from(String name) { | ||
| LRModel model = models.get(name); | ||
| if (model != null) return model; | ||
| throw new IllegalArgumentException("No valid LR-Model " + name); | ||
| } | ||
|
|
||
| public void add(double given, double expected) { | ||
| R.addData(given, expected); | ||
| if (R.getN() > 1) { | ||
| this.state = State.ready; | ||
| } | ||
| } | ||
|
|
||
| public LR.PredictResult predict(double given) { | ||
| if (this.state == State.ready) | ||
| return new LR.PredictResult(R.predict(given)); | ||
| throw new IllegalArgumentException("Not enough data in model to predict yet"); | ||
| } | ||
|
|
||
|
|
||
| public void removeData(double given, double expected) { | ||
| R.removeData(given, expected); | ||
| if (R.getN() < 2) { | ||
| this.state = State.created; | ||
| } | ||
| } | ||
|
|
||
| public static LR.ModelResult removeModel(String model) { | ||
| LRModel existing = models.remove(model); | ||
| return new LR.ModelResult(model, existing == null ? State.unknown : State.removed, 0); | ||
| } | ||
|
|
||
| public void store(GraphDatabaseService db) { | ||
| try{ byte[] serializedR = LR.convertToBytes(R); | ||
| Map<String, Object> parameters = new HashMap<>(); | ||
| parameters.put("name", name); | ||
| ResourceIterator<Entity> n = db.execute("MERGE (n:LRModel {name:$name}) RETURN n", parameters).columnAs("n"); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Change to embedded API or just return byte[] |
||
| Entity modelNode = n.next(); | ||
| modelNode.setProperty("state", state.name()); | ||
| modelNode.setProperty("serializedModel", serializedR);} | ||
| catch (IOException e) { throw new RuntimeException(name + " cannot be serialized."); } | ||
| } | ||
|
|
||
| /*protected void initTypes(Map<String, String> types, String output) { | ||
| if (!types.containsKey(output)) throw new IllegalArgumentException("Outputs not defined: " + output); | ||
| int i = 0; | ||
| for (Map.Entry<String, String> entry : types.entrySet()) { | ||
| String key = entry.getKey(); | ||
| this.types.put(key, DataType.from(entry.getValue())); | ||
| if (!key.equals(output)) this.offsets.put(key, i++); | ||
| } | ||
| this.offsets.put(output, i); | ||
| }*/ | ||
|
|
||
|
|
||
| /*enum DataType { | ||
| _class, _float, _order; | ||
| public static DataType from(String type) { | ||
| switch (type.toUpperCase()) { | ||
| case "CLASS": | ||
| return DataType._class; | ||
| case "FLOAT": | ||
| return DataType._float; | ||
| case "ORDER": | ||
| return DataType._order; | ||
| default: | ||
| throw new IllegalArgumentException("Unknown type: " + type); | ||
| } | ||
| } | ||
| }*/ | ||
|
|
||
| public enum State {created, ready, removed, unknown} | ||
|
|
||
| public LR.ModelResult asResult() { | ||
| LR.ModelResult result = new LR.ModelResult(this.name, this.state, R.getN()); | ||
| return result; | ||
| } | ||
|
|
||
|
|
||
|
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's not do this but use embedded testing for the library, it's faster than going through server + driver