Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
<project.reporting.outputEncoding>${encoding}</project.reporting.outputEncoding>
<maven.compiler.source>${java.version}</maven.compiler.source>
<maven.compiler.target>${java.version}</maven.compiler.target>
<neo4j.version>3.2.2</neo4j.version>
<neo4j.version>3.3.4</neo4j.version>
</properties>

<dependencies>
Expand Down Expand Up @@ -54,7 +54,7 @@
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.4.1</version>
<version>3.6.1</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
Expand Down Expand Up @@ -103,6 +103,21 @@
<scope>test</scope>
</dependency>

<dependency>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not do this but use embedded testing for the library, it's faster than going through server + driver

<groupId>org.neo4j.test</groupId>
<artifactId>neo4j-harness</artifactId>
<version>${neo4j.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<!-- Used to send cypher statements to our procedure. -->
<groupId>org.neo4j.driver</groupId>
<artifactId>neo4j-java-driver</artifactId>
<version>1.5.0</version>
<scope>test</scope>
</dependency>

</dependencies>

<build>
Expand All @@ -119,6 +134,16 @@
</execution>
</executions>
</plugin>

<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.1</version>
<configuration>
<!-- Neo4j Procedures require Java 8 -->
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</build>
</project>
122 changes: 122 additions & 0 deletions src/main/java/regression/LR.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package regression;

import org.apache.commons.math3.stat.regression.SimpleRegression;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.logging.Log;
import org.neo4j.procedure.*;
import org.neo4j.procedure.Mode;

import java.io.*;
import java.util.*;
import java.util.stream.Stream;

public class LR {
@Context
public GraphDatabaseService db;

@Context
public Log log;

@Procedure(value = "regression.linear.create", mode = Mode.READ)
public Stream<ModelResult> create(@Name("model") String model) {
return Stream.of((new LRModel(model)).asResult());
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an info/details function or procedure that just returns data on the model by name.


@Procedure(value = "regression.linear.addData", mode = Mode.READ)
public Stream<ModelResult> addData(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should have this as a void procedure (or have a variant that's void for perf-reasons)
Perhaps also a variant that takes a batch, i.e. two double[] or lists

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I agree. I will make the remove data procedure void also. If the user desires information about the model, they can call the info procedure

LRModel lrModel = LRModel.from(model);
lrModel.add(given, expected);
return Stream.of(lrModel.asResult());
}

@Procedure(value = "regression.linear.removeData", mode = Mode.READ)
public Stream<ModelResult> removeData(@Name("model") String model, @Name("given") double given, @Name("expected") double expected) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

LRModel lrModel = LRModel.from(model);
lrModel.removeData(given, expected);
return Stream.of(lrModel.asResult());
}

@Procedure(value = "regression.linear.removeModel", mode = Mode.READ)
public Stream<ModelResult> removeModel(@Name("model") String model) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete

return Stream.of(LRModel.removeModel(model));
}

@Procedure(value = "regression.linear.predict", mode = Mode.READ)
public Stream<PredictResult> predict(@Name("mode") String model, @Name("given") double given) {
LRModel lrModel = LRModel.from(model);
return Stream.of(lrModel.predict(given));
}

@Procedure(value = "regression.linear.storeModel", mode = Mode.WRITE)
public Stream<ModelResult> storeModel(@Name("model") String model) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it's easier to just have a function instead that turns the model into a byte[] so the user can decide themselves what to do with it? i.e. where to store it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I agree. This was my original idea, but I ran into issues because I was trying to return byte[] as part of a stream (this is silly because I should have just written a user function). I will fix this, and also convert predict into a function that returns a single value rather than a stream.

LRModel lrModel = LRModel.from(model);
lrModel.store(db);
return Stream.of(lrModel.asResult());
}

@Procedure(value = "regression.linear.createFromStorage", mode = Mode.READ)
public Stream<ModelResult> createFromStorage(@Name("model") String model) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just consume byte[] and turn into Model

Map<String, Object> parameters = new HashMap<>();
parameters.put("name", model);
Entity modelNode;
SimpleRegression R;
try {
ResourceIterator<Entity> n = db.execute("MATCH (n:LRModel {name:$name}) RETURN " +
"n", parameters).columnAs("n");
modelNode = n.next();
byte[] m = (byte[]) modelNode.getProperty("serializedModel");
R = (SimpleRegression) convertFromBytes(m);
} catch (Exception e) {
throw new RuntimeException("no existing model for specified independent and dependent variables and model ID");
}
return Stream.of(new LRModel(model, R, (String) modelNode.getProperty("state")).asResult());
}

public static class ModelResult {
public final String model;
public final String state;
public final double N;
public final Map<String,Object> info = new HashMap<>();

public ModelResult(String model, LRModel.State state, double N) {
this.model = model;
this.state = state.name();
this.N = N;
}

ModelResult withInfo(Object...infos) {
for (int i = 0; i < infos.length; i+=2) {
info.put(infos[i].toString(),infos[i+1]);
}
return this;
}
}

public static class PredictResult {
public final double prediction;
public PredictResult(double p) {
this.prediction = p;
}
}

//Serializes the object into a byte array for storage
public static byte[] convertToBytes(Object object) throws IOException {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future proof it would probably be better at some point to serialize using some explicit data from the model.
For now it's ok

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I agree, it's just difficult because you cannot create a new SimpleRegression object using the information it stores (because is doesn't actually store individual data points). Maybe I should consider using a different library?

try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutput out = new ObjectOutputStream(bos)) {
out.writeObject(object);
return bos.toByteArray();
}
}

//de serializes the byte array and returns the stored object
private static Object convertFromBytes(byte[] bytes) throws IOException, ClassNotFoundException {
try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
ObjectInput in = new ObjectInputStream(bis)) {
return in.readObject();
}
}


}
122 changes: 122 additions & 0 deletions src/main/java/regression/LRModel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package regression;
import java.util.concurrent.ConcurrentHashMap;
import java.util.*;
import org.apache.commons.math3.stat.regression.SimpleRegression;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.*;
import org.neo4j.procedure.UserAggregationUpdate;
import java.io.*;

public class LRModel {

private static ConcurrentHashMap<String, LRModel> models = new ConcurrentHashMap<>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm torn about having too many "model" storage maps in the plugin, perhaps we should unify this into one.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we can leave that for later.

private final String name;
private State state;
SimpleRegression R;


public LRModel(String model) {
if (models.containsKey(model))
throw new IllegalArgumentException("Model " + model + " already exists, please remove it first");
this.name = model;
this.state = State.created;
this.R = new SimpleRegression();
models.put(name, this);
}

public LRModel(String model, SimpleRegression R, String state) {
if (models.containsKey(model))
throw new IllegalArgumentException("Model " + model + " already exists, please remove it first");
this.name = model;
switch(state) {
case "created": this.state = State.created;
case "ready": this.state = State.ready;
case "removed": this.state = State.removed;
case "unknown": this.state = State.unknown;
}
this.R = R;
models.put(name, this);
}

public static LRModel from(String name) {
LRModel model = models.get(name);
if (model != null) return model;
throw new IllegalArgumentException("No valid LR-Model " + name);
}

public void add(double given, double expected) {
R.addData(given, expected);
if (R.getN() > 1) {
this.state = State.ready;
}
}

public LR.PredictResult predict(double given) {
if (this.state == State.ready)
return new LR.PredictResult(R.predict(given));
throw new IllegalArgumentException("Not enough data in model to predict yet");
}


public void removeData(double given, double expected) {
R.removeData(given, expected);
if (R.getN() < 2) {
this.state = State.created;
}
}

public static LR.ModelResult removeModel(String model) {
LRModel existing = models.remove(model);
return new LR.ModelResult(model, existing == null ? State.unknown : State.removed, 0);
}

public void store(GraphDatabaseService db) {
try{ byte[] serializedR = LR.convertToBytes(R);
Map<String, Object> parameters = new HashMap<>();
parameters.put("name", name);
ResourceIterator<Entity> n = db.execute("MERGE (n:LRModel {name:$name}) RETURN n", parameters).columnAs("n");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change to embedded API or just return byte[]

Entity modelNode = n.next();
modelNode.setProperty("state", state.name());
modelNode.setProperty("serializedModel", serializedR);}
catch (IOException e) { throw new RuntimeException(name + " cannot be serialized."); }
}

/*protected void initTypes(Map<String, String> types, String output) {
if (!types.containsKey(output)) throw new IllegalArgumentException("Outputs not defined: " + output);
int i = 0;
for (Map.Entry<String, String> entry : types.entrySet()) {
String key = entry.getKey();
this.types.put(key, DataType.from(entry.getValue()));
if (!key.equals(output)) this.offsets.put(key, i++);
}
this.offsets.put(output, i);
}*/


/*enum DataType {
_class, _float, _order;
public static DataType from(String type) {
switch (type.toUpperCase()) {
case "CLASS":
return DataType._class;
case "FLOAT":
return DataType._float;
case "ORDER":
return DataType._order;
default:
throw new IllegalArgumentException("Unknown type: " + type);
}
}
}*/

public enum State {created, ready, removed, unknown}

public LR.ModelResult asResult() {
LR.ModelResult result = new LR.ModelResult(this.name, this.state, R.getN());
return result;
}



}
Loading