diff --git a/flink-tensorflow/pom.xml b/flink-tensorflow/pom.xml index be2512b..a81fc90 100644 --- a/flink-tensorflow/pom.xml +++ b/flink-tensorflow/pom.xml @@ -101,6 +101,11 @@ under the License. tensorflow ${tf.version} + + org.tensorflow + proto + ${tf.version} + com.google.protobuf protobuf-java @@ -176,23 +181,6 @@ under the License. - - org.xolstice.maven.plugins - protobuf-maven-plugin - 0.5.0 - - - - - - - compile - test-compile - - - - - net.alchim31.maven scala-maven-plugin diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/example/example.proto b/flink-tensorflow/src/main/proto/tensorflow/core/example/example.proto deleted file mode 100644 index 956a374..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/example/example.proto +++ /dev/null @@ -1,294 +0,0 @@ -// Protocol messages for describing input data Examples for machine learning -// model training or inference. -syntax = "proto3"; - -import "tensorflow/core/example/feature.proto"; -option cc_enable_arenas = true; -option java_outer_classname = "ExampleProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.example"; - -package tensorflow; - -// An Example is a mostly-normalized data format for storing data for -// training and inference. It contains a key-value store (features); where -// each key (string) maps to a Feature message (which is oneof packed BytesList, -// FloatList, or Int64List). This flexible and compact format allows the -// storage of large amounts of typed data, but requires that the data shape -// and use be determined by the configuration files and parsers that are used to -// read and write this format. That is, the Example is mostly *not* a -// self-describing format. In TensorFlow, Examples are read in row-major -// format, so any configuration that describes data with rank-2 or above -// should keep this in mind. 
For example, to store an M x N matrix of Bytes, -// the BytesList must contain M*N bytes, with M rows of N contiguous values -// each. That is, the BytesList value must store the matrix as: -// .... row 0 .... .... row 1 .... // ........... // ... row M-1 .... -// -// An Example for a movie recommendation application: -// features { -// feature { -// key: "age" -// value { float_list { -// value: 29.0 -// }} -// } -// feature { -// key: "movie" -// value { bytes_list { -// value: "The Shawshank Redemption" -// value: "Fight Club" -// }} -// } -// feature { -// key: "movie_ratings" -// value { float_list { -// value: 9.0 -// value: 9.7 -// }} -// } -// feature { -// key: "suggestion" -// value { bytes_list { -// value: "Inception" -// }} -// } -// # Note that this feature exists to be used as a label in training. -// # E.g., if training a logistic regression model to predict purchase -// # probability in our learning tool we would set the label feature to -// # "suggestion_purchased". -// feature { -// key: "suggestion_purchased" -// value { float_list { -// value: 1.0 -// }} -// } -// # Similar to "suggestion_purchased" above this feature exists to be used -// # as a label in training. -// # E.g., if training a linear regression model to predict purchase -// # price in our learning tool we would set the label feature to -// # "purchase_price". -// feature { -// key: "purchase_price" -// value { float_list { -// value: 9.99 -// }} -// } -// } -// -// A conformant Example data set obeys the following conventions: -// - If a Feature K exists in one example with data type T, it must be of -// type T in all other examples when present. It may be omitted. -// - The number of instances of Feature K list data may vary across examples, -// depending on the requirements of the model. -// - If a Feature K doesn't exist in an example, a K-specific default will be -// used, if configured. 
-// - If a Feature K exists in an example but contains no items, the intent -// is considered to be an empty tensor and no default will be used. - -message Example { - Features features = 1; -}; - -// A SequenceExample is an Example representing one or more sequences, and -// some context. The context contains features which apply to the entire -// example. The feature_lists contain a key, value map where each key is -// associated with a repeated set of Features (a FeatureList). -// A FeatureList thus represents the values of a feature identified by its key -// over time / frames. -// -// Below is a SequenceExample for a movie recommendation application recording a -// sequence of ratings by a user. The time-independent features ("locale", -// "age", "favorites") describing the user are part of the context. The sequence -// of movies the user rated are part of the feature_lists. For each movie in the -// sequence we have information on its name and actors and the user's rating. -// This information is recorded in three separate feature_list(s). -// In the example below there are only two movies. All three feature_list(s), -// namely "movie_ratings", "movie_names", and "actors" have a feature value for -// both movies. Note, that "actors" is itself a bytes_list with multiple -// strings per movie. 
-// -// context: { -// feature: { -// key : "locale" -// value: { -// bytes_list: { -// value: [ "pt_BR" ] -// } -// } -// } -// feature: { -// key : "age" -// value: { -// float_list: { -// value: [ 19.0 ] -// } -// } -// } -// feature: { -// key : "favorites" -// value: { -// bytes_list: { -// value: [ "Majesty Rose", "Savannah Outen", "One Direction" ] -// } -// } -// } -// } -// feature_lists: { -// feature_list: { -// key : "movie_ratings" -// value: { -// feature: { -// float_list: { -// value: [ 4.5 ] -// } -// } -// feature: { -// float_list: { -// value: [ 5.0 ] -// } -// } -// } -// } -// feature_list: { -// key : "movie_names" -// value: { -// feature: { -// bytes_list: { -// value: [ "The Shawshank Redemption" ] -// } -// } -// feature: { -// bytes_list: { -// value: [ "Fight Club" ] -// } -// } -// } -// } -// feature_list: { -// key : "actors" -// value: { -// feature: { -// bytes_list: { -// value: [ "Tim Robbins", "Morgan Freeman" ] -// } -// } -// feature: { -// bytes_list: { -// value: [ "Brad Pitt", "Edward Norton", "Helena Bonham Carter" ] -// } -// } -// } -// } -// } -// -// A conformant SequenceExample data set obeys the following conventions: -// -// Context: -// - All conformant context features K must obey the same conventions as -// a conformant Example's features (see above). -// Feature lists: -// - A FeatureList L may be missing in an example; it is up to the -// parser configuration to determine if this is allowed or considered -// an empty list (zero length). -// - If a FeatureList L exists, it may be empty (zero length). -// - If a FeatureList L is non-empty, all features within the FeatureList -// must have the same data type T. Even across SequenceExamples, the type T -// of the FeatureList identified by the same key must be the same. -// - If a FeatureList L is non-empty, it is up to the parser configuration -// to determine if all features within the FeatureList must -// have the same size. 
The same holds for this FeatureList across multiple -// examples. -// -// Examples of conformant and non-conformant examples' FeatureLists: -// -// Conformant FeatureLists: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// -// Non-conformant FeatureLists (mismatched types): -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { int64_list: { value: [ 5 ] } } } -// } } -// -// Conditionally conformant FeatureLists, the parser configuration determines -// if the feature sizes must match: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0, 6.0 ] } } } -// } } -// -// Conformant pair of SequenceExample -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// and: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } -// feature: { float_list: { value: [ 2.0 ] } } } -// } } -// -// Conformant pair of SequenceExample -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// and: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { } -// } } -// -// Conditionally conformant pair of SequenceExample, the parser configuration -// determines if the second feature_lists is consistent (zero-length) or -// invalid (missing "movie_ratings"): -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { 
value: [ 5.0 ] } } } -// } } -// and: -// feature_lists: { } -// -// Non-conformant pair of SequenceExample (mismatched types) -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// and: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { int64_list: { value: [ 4 ] } } -// feature: { int64_list: { value: [ 5 ] } } -// feature: { int64_list: { value: [ 2 ] } } } -// } } -// -// Conditionally conformant pair of SequenceExample; the parser configuration -// determines if the feature sizes must match: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// and: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.0 ] } } -// feature: { float_list: { value: [ 5.0, 3.0 ] } } -// } } - -message SequenceExample { - Features context = 1; - FeatureLists feature_lists = 2; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/example/feature.proto b/flink-tensorflow/src/main/proto/tensorflow/core/example/feature.proto deleted file mode 100644 index da3dc59..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/example/feature.proto +++ /dev/null @@ -1,105 +0,0 @@ -// Protocol messages for describing features for machine learning model -// training or inference. -// -// There are three base Feature types: -// - bytes -// - float -// - int64 -// -// A Feature contains Lists which may hold zero or more values. These -// lists are the base values BytesList, FloatList, Int64List. -// -// Features are organized into categories by name. The Features message -// contains the mapping from name to Feature. 
-// -// Example Features for a movie recommendation application: -// feature { -// key: "age" -// value { float_list { -// value: 29.0 -// }} -// } -// feature { -// key: "movie" -// value { bytes_list { -// value: "The Shawshank Redemption" -// value: "Fight Club" -// }} -// } -// feature { -// key: "movie_ratings" -// value { float_list { -// value: 9.0 -// value: 9.7 -// }} -// } -// feature { -// key: "suggestion" -// value { bytes_list { -// value: "Inception" -// }} -// } -// feature { -// key: "suggestion_purchased" -// value { int64_list { -// value: 1 -// }} -// } -// feature { -// key: "purchase_price" -// value { float_list { -// value: 9.99 -// }} -// } -// - -syntax = "proto3"; -option cc_enable_arenas = true; -option java_outer_classname = "FeatureProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.example"; - -package tensorflow; - -// Containers to hold repeated fundamental values. -message BytesList { - repeated bytes value = 1; -} -message FloatList { - repeated float value = 1 [packed = true]; -} -message Int64List { - repeated int64 value = 1 [packed = true]; -} - -// Containers for non-sequential data. -message Feature { - // Each feature can be exactly one kind. - oneof kind { - BytesList bytes_list = 1; - FloatList float_list = 2; - Int64List int64_list = 3; - } -}; - -message Features { - // Map from feature name to feature. - map<string, Feature> feature = 1; -}; - -// Containers for sequential data. -// -// A FeatureList contains lists of Features. These may hold zero or more -// Feature values. -// -// FeatureLists are organized into categories by name. The FeatureLists message -// contains the mapping from name to FeatureList. -// -message FeatureList { - repeated Feature feature = 1; -}; - -message FeatureLists { - // Map from feature name to feature list.
- map<string, FeatureList> feature_list = 1; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/allocation_description.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/allocation_description.proto deleted file mode 100644 index bb1037c..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/allocation_description.proto +++ /dev/null @@ -1,27 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "AllocationDescriptionProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -message AllocationDescription { - // Total number of bytes requested - int64 requested_bytes = 1; - - // Total number of bytes allocated if known - int64 allocated_bytes = 2; - - // Name of the allocator used - string allocator_name = 3; - - // Identifier of the allocated buffer if known - int64 allocation_id = 4; - - // Set if this tensor only has one remaining reference - bool has_single_reference = 5; - - // Address of the allocation. - uint64 ptr = 6; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/attr_value.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/attr_value.proto deleted file mode 100644 index f115329..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/attr_value.proto +++ /dev/null @@ -1,62 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "AttrValueProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/tensor.proto"; -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/types.proto"; - -// Protocol buffer representing the value for an attr used to configure an Op. -// Comment indicates the corresponding attr type. Only the field matching the -// attr type may be filled.
-message AttrValue { - // LINT.IfChange - message ListValue { - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated DataType type = 6 [packed = true]; // "list(type)" - repeated TensorShapeProto shape = 7; // "list(shape)" - repeated TensorProto tensor = 8; // "list(tensor)" - // TODO(zhifengc/josh11b): implements list(func) if needed. - } - // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) - - oneof value { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - DataType type = 6; // "type" - TensorShapeProto shape = 7; // "shape" - TensorProto tensor = 8; // "tensor" - ListValue list = 1; // any "list(...)" - - // "func" represents a function. func.name is a function's name or - // a primitive op's name. func.attr.first is the name of an attr - // defined for that function. func.attr.second is the value for - // that attr in the instantiation. - NameAttrList func = 10; - - // This is a placeholder only used in nodes defined inside a - // function. It indicates the attr value will be supplied when - // the function is instantiated. For example, let us suppose a - // node "N" in function "FN". "N" has an attr "A" with value - // placeholder = "foo". When FN is instantiated with attr "foo" - // set to "bar", the instantiated node N's attr A will have been - // given the value "bar". - string placeholder = 9; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. 
-message NameAttrList { - string name = 1; - map<string, AttrValue> attr = 2; -} diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/cost_graph.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/cost_graph.proto deleted file mode 100644 index 8145486..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/cost_graph.proto +++ /dev/null @@ -1,59 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "CostGraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/types.proto"; - -message CostGraphDef { - message Node { - // The name of the node. Names are globally unique. - string name = 1; - - // The device of the node. Can be empty if the node is mapped to the - // default partition or partitioning hasn't been run yet. - string device = 2; - - // The id of the node. Node ids are only unique inside a partition. - int32 id = 3; - - // Inputs of this node. They must be executed before this node can be - // executed. An input is a particular output of another node, specified - // by the node id and the output index. - message InputInfo { - int32 preceding_node = 1; - int32 preceding_port = 2; - } - repeated InputInfo input_info = 4; - - // Outputs of this node. - message OutputInfo { - int64 size = 1; - // If >= 0, the output is an alias of an input. Note that an alias input - // may itself be an alias. The algorithm will therefore need to follow - // those pointers. - int64 alias_input_port = 2; - TensorShapeProto shape = 3; - DataType dtype = 4; - } - repeated OutputInfo output_info = 5; - - // Temporary memory used by this node. - int64 temporary_memory_size = 6; - - // Estimate of the computational cost of this node.
- int64 compute_cost = 9; - - // If true, the output is permanent: it can't be discarded, because this - // node is part of the "final output". Nodes may depend on final nodes. - bool is_final = 7; - - // Ids of the control inputs for this node. - repeated int32 control_input = 8; - } - repeated Node node = 1; -} diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/device_attributes.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/device_attributes.proto deleted file mode 100644 index 9983bcb..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/device_attributes.proto +++ /dev/null @@ -1,35 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "DeviceAttributesProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -message DeviceLocality { - // Optional bus locality of device. Default value of 0 means - // no specific locality. Specific localities are indexed from 1. - int32 bus_id = 1; -}; - -message DeviceAttributes { - // Fully specified name of the device within a cluster. - string name = 1; - - // String representation of device_type. - string device_type = 2; - - // Memory capacity of device in bytes. - int64 memory_limit = 4; - - // Platform-specific data about device that may be useful - // for supporting efficient data transfers. - DeviceLocality locality = 5; - - // A device is assigned a global unique number each time it is - // initialized. "incarnation" should never be 0. - fixed64 incarnation = 6; - - // String representation of the physical device that this device maps to. 
- string physical_device_desc = 7; -} diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/function.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/function.proto deleted file mode 100644 index 5a394d6..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/function.proto +++ /dev/null @@ -1,151 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "FunctionProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/attr_value.proto"; -import "tensorflow/core/framework/node_def.proto"; -import "tensorflow/core/framework/op_def.proto"; - -// A library is a set of named functions. -message FunctionDefLibrary { - repeated FunctionDef function = 1; - repeated GradientDef gradient = 2; -} - -// A function can be instantiated when the runtime can bind every attr -// with a value. When a GraphDef has a call to a function, it must -// have binding for every attr defined in the signature. -// -// TODO(zhifengc): -// * device spec, etc. -message FunctionDef { - // The definition of the function's name, arguments, return values, - // attrs etc. - OpDef signature = 1; - - // Attributes specific to this function definition. - map<string, AttrValue> attr = 5; - - // TO BE REPLACED - - // The body of the function. - repeated Node node = 2; // function.node.ret[*] are unique. - - // A node is a multi-value assignment: - // (ret[0], ret[1], ...) = func(arg[0], arg[1], ...) - // - // By convention, "func" is resolved by consulting with a user-defined - // library first. If not resolved, "func" is assumed to be a builtin op. - message Node { - // This node produces multiple outputs. They are named ret[0], - // ret[1], ..., etc. - // - // REQUIRES: function.node.ret[*] are unique across all nodes. - // REQUIRES: ret.size == func/op def's number of output args. - repeated string ret = 1; - - // The op/function name.
- string op = 2; - - // Arguments passed to this func/op. - // - // arg[i] must be either one of - // function.signature.input_args[*].name or one of - // function.node[*].ret[*]. - // - // REQUIRES: arg.size == func/op def's number of input args. - repeated string arg = 3; - - // Control dependencies. - // - // dep[i] must be one of function.node[*].ret[*] or one of - // function.signature.input_args[*].name. - repeated string dep = 4; - - // Attrs. - // - // 'attr' maps names defined by 'func's attr defs to attr values. - // attr values may have placeholders which are substituted - // recursively by concrete values when this node is instantiated. - // These placeholders must name an attr listed in the FunctionDef's - // signature. - map<string, AttrValue> attr = 5; - } - - // WILL REPLACE THE ABOVE - - // If node_def is present, and the consumer is at GraphDef version - // >= 12, then these fields are used and `node` is ignored. If the - // consumer's GraphDef version is < 12 or this field is empty, then - // `node` is used. This allows producers to fill both fields to - // remain compatible with old consumers. At some future GraphDef - // version, `node` will be ignored even if `node_def` is empty. - // TODO(josh11b): Finish this transition. - - // In both of the following fields, there is the need to specify an - // output that is used as either the input to another node (in - // `node_def`) or as a return value of the function (in `ret`). - // Unlike the NodeDefs in GraphDef, we need to be able to specify a - // list in some cases (instead of just single outputs). Also, we - // need to be able to deal with lists of unknown length (so the - // output index may not be known at function definition time). So - // we use the following format instead: - // * "fun_in" where "fun_in" is the name of a function input arg in - // the `signature` field above. This represents that input, whether - // it is a single tensor or a list.
- // * "fun_in:0" gives the first element of a function input arg (a - // non-list input is considered a list of length 1 for these - // purposes). - // * "node:out" where "node" is the name of a node in `node_def` and - // "out" is the name one of its op's output arguments (the name - // comes from the OpDef of the node's op). This represents that - // node's output, whether it is a single tensor or a list. - // Note: We enforce that an op's output arguments are never - // renamed in the backwards-compatibility test. - // * "node:out:0" gives the first element of a node output arg (a - // non-list output is considered a list of length 1 for these - // purposes). - // - // NOT CURRENTLY SUPPORTED (but may be in the future): - // * "node:out:-1" gives last element in a node output list - // * "node:out:1:" gives a list with all but the first element in a - // node output list - // * "node:out::-1" gives a list with all but the last element in a - // node output list - - // The body of the function. Unlike the NodeDefs in a GraphDef, attrs - // may have values of type `placeholder` and the `input` field uses - // the "output" format above. - repeated NodeDef node_def = 3; - - // A mapping from the output arg names from `signature` to the - // outputs from `node_def` that should be returned by the function. - map<string, string> ret = 4; -} - -// GradientDef defines the gradient function of a function defined in -// a function library. -// -// A gradient function g (specified by gradient_func) for a function f -// (specified by function_name) must follow the following: -// -// The function 'f' must be a numerical function which takes N inputs -// and produces M outputs. Its gradient function 'g', which is a -// function taking N + M inputs and produces N outputs. -// -// I.e.
if we have -// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), -// then, g is -// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, -// dL/dy1, dL/dy2, ..., dL/dy_M), -// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the -// loss function). dL/dx_i is the partial derivative of L with respect -// to x_i. -message GradientDef { - string function_name = 1; // The function name. - string gradient_func = 2; // The gradient function's name. -} diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/graph.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/graph.proto deleted file mode 100644 index 7d6e16d..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/graph.proto +++ /dev/null @@ -1,56 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "GraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/node_def.proto"; -import "tensorflow/core/framework/function.proto"; -import "tensorflow/core/framework/versions.proto"; - -// Represents the graph of operations -message GraphDef { - repeated NodeDef node = 1; - - // Compatibility versions of the graph. See core/public/version.h for version - // history. The GraphDef version is distinct from the TensorFlow version, and - // each release of TensorFlow will support a range of GraphDef versions. - VersionDef versions = 4; - - // Deprecated single version field; use versions above instead. Since all - // GraphDef changes before "versions" was introduced were forward - // compatible, this field is entirely ignored. - int32 version = 3 [deprecated = true]; - - // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. - // - // "library" provides user-defined functions. - // - // Naming: - // * library.function.name are in a flat namespace. - // NOTE: We may need to change it to be hierarchical to support - // different orgs. 
E.g., - // { "/google/nn", { ... }}, - // { "/google/vision", { ... }} - // { "/org_foo/module_bar", { ... }} - // map<string, FunctionDefLib> named_lib; - // * If node[i].op is the name of one function in "library", - // node[i] is deemed as a function call. Otherwise, node[i].op - // must be a primitive operation supported by the runtime. - // - // - // Function call semantics: - // - // * The callee may start execution as soon as some of its inputs - // are ready. The caller may want to use Tuple() mechanism to - // ensure all inputs are ready in the same time. - // - // * The consumer of return values may start executing as soon as - // the return values the consumer depends on are ready. The - // consumer may want to use Tuple() mechanism to ensure the - // consumer does not start until all return values of the callee - // function are ready. - FunctionDefLibrary library = 2; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/kernel_def.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/kernel_def.proto deleted file mode 100644 index 65e9ef0..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/kernel_def.proto +++ /dev/null @@ -1,36 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "KernelDefProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/attr_value.proto"; - -message KernelDef { - // Must match the name of an Op. - string op = 1; - - // Type of device this kernel runs on. - string device_type = 2; - - message AttrConstraint { - // Name of an attr from the Op. - string name = 1; - - // A list of values that this kernel supports for this attr. - // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops.
- AttrValue allowed_values = 2; - } - repeated AttrConstraint constraint = 3; - - // Names of the Op's input_/output_args that reside in host memory - // instead of device memory. - repeated string host_memory_arg = 4; - - // This allows experimental kernels to be registered for an op that - // won't be used unless the user specifies a "_kernel" attr with - // value matching this. - string label = 5; -} diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/log_memory.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/log_memory.proto deleted file mode 100644 index d1e1263..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/log_memory.proto +++ /dev/null @@ -1,93 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "LogMemoryProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/tensor_description.proto"; - -message MemoryLogStep { - // Process-unique step id. - int64 step_id = 1; - - // Handle describing the feeds and fetches of the step. - string handle = 2; -}; - -message MemoryLogTensorAllocation { - // Process-unique step id. - int64 step_id = 1; - - // Name of the kernel making the allocation as set in GraphDef, - // e.g., "affine2/weights/Assign". - string kernel_name = 2; - - // Allocated tensor details. - TensorDescription tensor = 3; -}; - -message MemoryLogTensorDeallocation { - // Id of the tensor buffer being deallocated, used to match to a - // corresponding allocation. - int64 allocation_id = 1; - - // Name of the allocator used. - string allocator_name = 2; -}; - -message MemoryLogTensorOutput { - // Process-unique step id. - int64 step_id = 1; - - // Name of the kernel producing an output as set in GraphDef, e.g., - // "affine2/weights/Assign". - string kernel_name = 2; - - // Index of the output being set. - int32 index = 3; - - // Output tensor details. 
- TensorDescription tensor = 4; -} - -message MemoryLogRawAllocation { - // Process-unique step id. - int64 step_id = 1; - - // Name of the operation making the allocation. - string operation = 2; - - // Number of bytes in the allocation. - int64 num_bytes = 3; - - // Address of the allocation. - uint64 ptr = 4; - - // Id of the tensor buffer being allocated, used to match to a - // corresponding deallocation. - int64 allocation_id = 5; - - // Name of the allocator used. - string allocator_name = 6; -}; - -message MemoryLogRawDeallocation { - // Process-unique step id. - int64 step_id = 1; - - // Name of the operation making the deallocation. - string operation = 2; - - // Id of the tensor buffer being deallocated, used to match to a - // corresponding allocation. - int64 allocation_id = 3; - - // Name of the allocator used. - string allocator_name = 4; - - // True if the deallocation is queued and will be performed later, - // e.g. for GPU lazy freeing of buffers. - bool deferred = 5; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/node_def.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/node_def.proto deleted file mode 100644 index 8d38115..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/node_def.proto +++ /dev/null @@ -1,65 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "NodeProto"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/attr_value.proto"; - -message NodeDef { - // The name given to this operator. Used for naming inputs, - // logging, visualization, etc. Unique within a single GraphDef. - // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". - string name = 1; - - // The operation name. There may be custom parameters in attrs. - // Op names starting with an underscore are reserved for internal use. 
- string op = 2; - - // Each input is "node:src_output" with "node" being a string name and - // "src_output" indicating which output tensor to use from "node". If - // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs - // may optionally be followed by control inputs that have the format - // "^node". - repeated string input = 3; - - // A (possibly partial) specification for the device on which this - // node should be placed. - // The expected syntax for this string is as follows: - // - // DEVICE_SPEC ::= COLOCATED_NODE | PARTIAL_SPEC - // - // COLOCATED_NODE ::= "@" NODE_NAME // See NodeDef.name above. - // PARTIAL_SPEC ::= ("/" CONSTRAINT) * - // CONSTRAINT ::= ("job:" JOB_NAME) - // | ("replica:" [1-9][0-9]*) - // | ("task:" [1-9][0-9]*) - // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) - // - // Valid values for this string include: - // * "@other/node" (colocate with "other/node") - // * "/job:worker/replica:0/task:1/gpu:3" (full specification) - // * "/job:worker/gpu:3" (partial specification) - // * "" (no specification) - // - // If the constraints do not resolve to a single device (or if this - // field is empty or not present), the runtime will attempt to - // choose a device automatically. - string device = 4; - - // Operation-specific graph-construction-time configuration. - // Note that this should include all attrs defined in the - // corresponding OpDef, including those with a value matching - // the default -- this allows the default to change and makes - // NodeDefs easier to interpret on their own. However, if - // an attr with a default is not specified in this list, the - // default will be used. - // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and - // one of the names from the corresponding OpDef's attr field). - // The values must have a type matching the corresponding OpDef - // attr's type field. - // TODO(josh11b): Add some examples here showing best practices. 
- map attr = 5; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/op_def.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/op_def.proto deleted file mode 100644 index acb480e..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/op_def.proto +++ /dev/null @@ -1,157 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "OpDefProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/attr_value.proto"; -import "tensorflow/core/framework/types.proto"; - -// Defines an operation. A NodeDef in a GraphDef specifies an Op by -// using the "op" field which should match the name of a OpDef. -message OpDef { - // Op names starting with an underscore are reserved for internal use. - // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". - string name = 1; - - // For describing inputs and outputs. - message ArgDef { - // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". - string name = 1; - - // Human readable description. - string description = 2; - - // Describes the type of one or more tensors that are accepted/produced - // by this input/output arg. The only legal combinations are: - // * For a single tensor: either the "type" field is set or the - // "type_attr" field is set to the name of an attr with type "type". - // * For a sequence of tensors with the same type: the "number_attr" - // field will be set to the name of an attr with type "int", and - // either the "type" or "type_attr" field will be set as for - // single tensors. - // * For a sequence of tensors, the "type_list_attr" field will be set - // to the name of an attr with type "list(type)". 
- DataType type = 3; - string type_attr = 4; // if specified, attr must have type "type" - string number_attr = 5; // if specified, attr must have type "int" - // If specified, attr must have type "list(type)", and none of - // type, type_attr, and number_attr may be specified. - string type_list_attr = 6; - - // For inputs: if true, the inputs are required to be refs. - // By default, inputs can be either refs or non-refs. - // For outputs: if true, outputs are refs, otherwise they are not. - bool is_ref = 16; - }; - - // Description of the input(s). - repeated ArgDef input_arg = 2; - - // Description of the output(s). - repeated ArgDef output_arg = 3; - - // Description of the graph-construction-time configuration of this - // Op. That is to say, this describes the attr fields that will - // be specified in the NodeDef. - message AttrDef { - // A descriptive name for the argument. May be used, e.g. by the - // Python client, as a keyword argument name, and so should match - // the regexp "[a-z][a-z0-9_]+". - string name = 1; - - // One of the type names from attr_value.proto ("string", "list(string)", - // "int", etc.). - string type = 2; - - // A reasonable default for this attribute if the user does not supply - // a value. If not specified, the user must supply a value. - AttrValue default_value = 3; - - // Human-readable description. - string description = 4; - - // TODO(josh11b): bool is_optional? - - // --- Constraints --- - // These constraints are only in effect if specified. Default is no - // constraints. - - // For type == "int", this is a minimum value. For "list(___)" - // types, this is the minimum length. - bool has_minimum = 5; - int64 minimum = 6; - - // The set of allowed values. Has type that is the "list" version - // of the "type" field above (uses the "list" field of AttrValue). - // If type == "type" or "list(type)" above, then the "type" field - // of "allowed_values.list" has the set of allowed DataTypes. 
- // If type == "string" or "list(string)", then the "s" field of - // "allowed_values.list" has the set of allowed strings. - AttrValue allowed_values = 7; - } - repeated AttrDef attr = 4; - - // Optional deprecation based on GraphDef versions. - OpDeprecation deprecation = 8; - - // One-line human-readable description of what the Op does. - string summary = 5; - - // Additional, longer human-readable description of what the Op does. - string description = 6; - - // ------------------------------------------------------------------------- - // Which optimizations this operation can participate in. - - // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) - bool is_commutative = 18; - - // If is_aggregate is true, then this operation accepts N >= 2 - // inputs and produces 1 output all of the same type. Should be - // associative and commutative, and produce output with the same - // shape as the input. The optimizer may replace an aggregate op - // taking input from multiple devices with a tree of aggregate ops - // that aggregate locally within each device (and possibly within - // groups of nearby devices) before communicating. - // TODO(josh11b): Implement that optimization. - bool is_aggregate = 16; // for things like add - - // Other optimizations go here, like - // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. - - // ------------------------------------------------------------------------- - // Optimization constraints. - - // By default Ops may be moved between devices. Stateful ops should - // either not be moved, or should only be moved if that state can also - // be moved (e.g. via some sort of save / restore). - // Stateful ops are guaranteed to never be optimized away by Common - // Subexpression Elimination (CSE). - bool is_stateful = 17; // for things like variables, queue - - // ------------------------------------------------------------------------- - // Non-standard options. 
- - // By default, all inputs to an Op must be initialized Tensors. Ops - // that may initialize tensors for the first time should set this - // field to true, to allow the Op to take an uninitialized Tensor as - // input. - bool allows_uninitialized_input = 19; // for Assign, etc. -}; - -// Information about version-dependent deprecation of an op -message OpDeprecation { - // First GraphDef version at which the op is disallowed. - int32 version = 1; - - // Explanation of why it was deprecated and what to use instead. - string explanation = 2; -}; - -// A collection of OpDefs -message OpList { - repeated OpDef op = 1; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/resource_handle.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/resource_handle.proto deleted file mode 100644 index f9f19ca..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/resource_handle.proto +++ /dev/null @@ -1,29 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "ResourceHandleProto"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Protocol buffer representing a handle to a tensorflow resource. Handles are -// not valid across executions, but can be serialized back and forth from within -// a single run. -message ResourceHandle { - // Unique name for the device containing the resource. - string device = 1; - - // Container in which this resource is placed. - string container = 2; - - // Unique name of this resource. - string name = 3; - - // Hash code for the type of the resource. Is only valid in the same device - // and in the same execution. - uint64 hash_code = 4; - - // For debug-only, the name of the type pointed to by this handle, if - // available. 
- string maybe_type_name = 5; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/step_stats.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/step_stats.proto deleted file mode 100644 index 4488f98..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/step_stats.proto +++ /dev/null @@ -1,54 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "StepStatsProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/allocation_description.proto"; -import "tensorflow/core/framework/tensor_description.proto"; - -// TODO(tucker): The next 4 message defs are very similar to -// the *LogEntry messages in profile.proto. They should be -// unified in one place. - -message AllocatorMemoryUsed { - string allocator_name = 1; - int64 total_bytes = 2; - int64 peak_bytes = 3; -} - -// Output sizes recorded for a single execution of a graph node. -message NodeOutput { - int32 slot = 1; - TensorDescription tensor_description = 3; -}; - -// Time/size stats recorded for a single execution of a graph node. -message NodeExecStats { - // TODO(tucker): Use some more compact form of node identity than - // the full string name. Either all processes should agree on a - // global id (cost_id?) for each node, or we should use a hash of - // the name. 
- string node_name = 1; - int64 all_start_micros = 2; - int64 op_start_rel_micros = 3; - int64 op_end_rel_micros = 4; - int64 all_end_rel_micros = 5; - repeated AllocatorMemoryUsed memory = 6; - repeated NodeOutput output = 7; - string timeline_label = 8; - int64 scheduled_micros = 9; - uint32 thread_id = 10; - repeated AllocationDescription referenced_tensor = 11; -}; - -message DeviceStepStats { - string device = 1; - repeated NodeExecStats node_stats = 2; -} - -message StepStats { - repeated DeviceStepStats dev_stats = 1; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/summary.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/summary.proto deleted file mode 100644 index 3560b96..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/summary.proto +++ /dev/null @@ -1,103 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "SummaryProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/tensor.proto"; - -// Metadata associated with a series of Summary data -message SummaryDescription { - // Hint on how plugins should process the data in this series. - // Supported values include "scalar", "histogram", "image", "audio" - string type_hint = 1; -} - -// Serialization format for histogram module in -// core/lib/histogram/histogram.h -message HistogramProto { - double min = 1; - double max = 2; - double num = 3; - double sum = 4; - double sum_squares = 5; - - // Parallel arrays encoding the bucket boundaries and the bucket values. - // bucket(i) is the count for the bucket i. The range for - // a bucket is: - // i == 0: -DBL_MAX .. bucket_limit(0) - // i != 0: bucket_limit(i-1) .. 
bucket_limit(i) - repeated double bucket_limit = 6 [packed = true]; - repeated double bucket = 7 [packed = true]; -}; - -// A Summary is a set of named values to be displayed by the -// visualizer. -// -// Summaries are produced regularly during training, as controlled by -// the "summary_interval_secs" attribute of the training operation. -// Summaries are also produced at the end of an evaluation. -message Summary { - message Image { - // Dimensions of the image. - int32 height = 1; - int32 width = 2; - // Valid colorspace values are - // 1 - grayscale - // 2 - grayscale + alpha - // 3 - RGB - // 4 - RGBA - // 5 - DIGITAL_YUV - // 6 - BGRA - int32 colorspace = 3; - // Image data in encoded format. All image formats supported by - // image_codec::CoderUtil can be stored here. - bytes encoded_image_string = 4; - } - - message Audio { - // Sample rate of the audio in Hz. - float sample_rate = 1; - // Number of channels of audio. - int64 num_channels = 2; - // Length of the audio in frames (samples per channel). - int64 length_frames = 3; - // Encoded audio data and its associated RFC 2045 content type (e.g. - // "audio/wav"). - bytes encoded_audio_string = 4; - string content_type = 5; - } - - message Value { - // Name of the node that output this summary; in general, the name of a - // TensorSummary node. If the node in question has multiple outputs, then - // a ":\d+" suffix will be appended, like "some_op:13". - // Might not be set for legacy summaries (i.e. those not using the tensor - // value field) - string node_name = 7; - - // Tag name for the data. Will only be used by legacy summaries - // (ie. those not using the tensor value field) - // For legacy summaries, will be used as the title of the graph - // in the visualizer. - // - // Tag is usually "op_name:value_name", where "op_name" itself can have - // structure to indicate grouping. - string tag = 1; - - // Value associated with the tag. 
- oneof value { - float simple_value = 2; - bytes obsolete_old_style_histogram = 3; - Image image = 4; - HistogramProto histo = 5; - Audio audio = 6; - TensorProto tensor = 8; - } - } - - // Set of values for the summary. - repeated Value value = 1; -} diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor.proto deleted file mode 100644 index 5d383bc..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor.proto +++ /dev/null @@ -1,75 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TensorProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/resource_handle.proto"; -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/types.proto"; - -// Protocol buffer representing a tensor. -message TensorProto { - DataType dtype = 1; - - // Shape of the tensor. TODO(touts): sort out the 0-rank issues. - TensorShapeProto tensor_shape = 2; - - // Only one of the representations below is set, one of "tensor_contents" and - // the "xxx_val" attributes. We are not using oneof because as oneofs cannot - // contain repeated fields it would require another extra set of messages. - - // Version number. - // - // In version 0, if the "repeated xxx" representations contain only one - // element, that element is repeated to fill the shape. This makes it easy - // to represent a constant Tensor with a single value. - int32 version_number = 3; - - // Serialized raw tensor content from either Tensor::AsProtoTensorContent or - // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation - // can be used for all tensor types. The purpose of this representation is to - // reduce serialization overhead during RPC call by avoiding serialization of - // many repeated small items. 
- bytes tensor_content = 4; - - // Type specific representations that make it easy to create tensor protos in - // all languages. Only the representation corresponding to "dtype" can - // be set. The values hold the flattened representation of the tensor in - // row major order. - - // DT_HALF. Note that since protobuf has no int16 type, we'll have some - // pointless zero padding for each value here. - repeated int32 half_val = 13 [packed = true]; - - // DT_FLOAT. - repeated float float_val = 5 [packed = true]; - - // DT_DOUBLE. - repeated double double_val = 6 [packed = true]; - - // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. - repeated int32 int_val = 7 [packed = true]; - - // DT_STRING - repeated bytes string_val = 8; - - // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real - // and imaginary parts of i-th single precision complex. - repeated float scomplex_val = 9 [packed = true]; - - // DT_INT64 - repeated int64 int64_val = 10 [packed = true]; - - // DT_BOOL - repeated bool bool_val = 11 [packed = true]; - - // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real - // and imaginary parts of i-th double precision complex. 
- repeated double dcomplex_val = 12 [packed = true]; - - // DT_RESOURCE - repeated ResourceHandle resource_handle_val = 14; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor_description.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor_description.proto deleted file mode 100644 index 6ac3c1b..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor_description.proto +++ /dev/null @@ -1,22 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TensorDescriptionProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/framework/types.proto"; -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/allocation_description.proto"; - -message TensorDescription { - // Data type of tensor elements - DataType dtype = 1; - - // Shape of the tensor. - TensorShapeProto shape = 2; - - // Information about the size and allocator used for the data - AllocationDescription allocation_description = 4; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor_shape.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor_shape.proto deleted file mode 100644 index 1ec3c53..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor_shape.proto +++ /dev/null @@ -1,45 +0,0 @@ -// Protocol buffer representing the shape of tensors. - -syntax = "proto3"; -option cc_enable_arenas = true; -option java_outer_classname = "TensorShapeProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -package tensorflow; - -// Dimensions of a tensor. -message TensorShapeProto { - // One dimension of the tensor. - message Dim { - // Size of the tensor in that dimension. 
- // This value must be >= -1, but values of -1 are reserved for "unknown" - // shapes (values of -1 mean "unknown" dimension). Certain wrappers - // that work with TensorShapeProto may fail at runtime when deserializing - // a TensorShapeProto containing a dim value of -1. - int64 size = 1; - - // Optional name of the tensor dimension. - string name = 2; - }; - - // Dimensions of the tensor, such as {"input", 30}, {"output", 40} - // for a 30 x 40 2D tensor. If an entry has size -1, this - // corresponds to a dimension of unknown size. The names are - // optional. - // - // The order of entries in "dim" matters: It indicates the layout of the - // values in the tensor in-memory representation. - // - // The first entry in "dim" is the outermost dimension used to layout the - // values, the last entry is the innermost dimension. This matches the - // in-memory layout of RowMajor Eigen tensors. - // - // If "dim.size()" > 0, "unknown_rank" must be false. - repeated Dim dim = 2; - - // If true, the number of dimensions in the shape is unknown. - // - // If true, "dim.size()" must be 0. - bool unknown_rank = 3; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor_slice.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor_slice.proto deleted file mode 100644 index 24b0166..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/tensor_slice.proto +++ /dev/null @@ -1,37 +0,0 @@ -// Protocol buffer representing slices of a tensor - -syntax = "proto3"; -option cc_enable_arenas = true; -option java_outer_classname = "TensorSliceProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -package tensorflow; - -// Can only be interpreted if you know the corresponding TensorShape. -message TensorSliceProto { - // Extent of the slice in one dimension. - message Extent { - // Either both or no attributes must be set. 
When no attribute is set - // means: All data in that dimension. - - // Start index of the slice, starting at 0. - int64 start = 1; - - // Length of the slice: if the length is missing or -1 we will - // interpret this as "everything in this dimension". We use - // "oneof" to preserve information about whether the length is - // present without changing the serialization format from the - // prior proto2 version of this proto. - oneof has_length { - int64 length = 2; - } - }; - - // Extent of the slice in all tensor dimensions. - // - // Must have one entry for each of the dimension of the tensor that this - // slice belongs to. The order of sizes is the same as the order of - // dimensions in the TensorShape. - repeated Extent extent = 1; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/types.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/types.proto deleted file mode 100644 index b80e2b3..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/types.proto +++ /dev/null @@ -1,64 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TypesProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// LINT.IfChange -enum DataType { - // Not a legal value for DataType. Used to indicate a DataType field - // has not been set. - DT_INVALID = 0; - - // Data types that all computation devices are expected to be - // capable to support. - DT_FLOAT = 1; - DT_DOUBLE = 2; - DT_INT32 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_INT8 = 6; - DT_STRING = 7; - DT_COMPLEX64 = 8; // Single-precision complex - DT_INT64 = 9; - DT_BOOL = 10; - DT_QINT8 = 11; // Quantized int8 - DT_QUINT8 = 12; // Quantized uint8 - DT_QINT32 = 13; // Quantized int32 - DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. 
- DT_QINT16 = 15; // Quantized int16 - DT_QUINT16 = 16; // Quantized uint16 - DT_UINT16 = 17; - DT_COMPLEX128 = 18; // Double-precision complex - DT_HALF = 19; - DT_RESOURCE = 20; - - // TODO(josh11b): DT_GENERIC_PROTO = ??; - // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? - - // Do not use! These are only for parameters. Every enum above - // should have a corresponding value below (verified by types_test). - DT_FLOAT_REF = 101; - DT_DOUBLE_REF = 102; - DT_INT32_REF = 103; - DT_UINT8_REF = 104; - DT_INT16_REF = 105; - DT_INT8_REF = 106; - DT_STRING_REF = 107; - DT_COMPLEX64_REF = 108; - DT_INT64_REF = 109; - DT_BOOL_REF = 110; - DT_QINT8_REF = 111; - DT_QUINT8_REF = 112; - DT_QINT32_REF = 113; - DT_BFLOAT16_REF = 114; - DT_QINT16_REF = 115; - DT_QUINT16_REF = 116; - DT_UINT16_REF = 117; - DT_COMPLEX128_REF = 118; - DT_HALF_REF = 119; - DT_RESOURCE_REF = 120; -} -// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/variable.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/variable.proto deleted file mode 100644 index e793f5a..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/variable.proto +++ /dev/null @@ -1,33 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "VariableProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Protocol buffer representing a Variable. -message VariableDef { - // Name of the variable tensor. - string variable_name = 1; - - // Name of the initializer op. - string initializer_name = 2; - - // Name of the snapshot tensor. - string snapshot_name = 3; - - // Support for saving variables as slices of a larger variable. - SaveSliceInfoDef save_slice_info_def = 4; -} - -message SaveSliceInfoDef { - // Name of the full variable of which this is a slice. 
- string full_name = 1; - // Shape of the full variable. - repeated int64 full_shape = 2; - // Offset of this variable into the full variable. - repeated int64 var_offset = 3; - // Shape of this variable. - repeated int64 var_shape = 4; -} diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/framework/versions.proto b/flink-tensorflow/src/main/proto/tensorflow/core/framework/versions.proto deleted file mode 100644 index 7d5e58a..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/framework/versions.proto +++ /dev/null @@ -1,31 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "VersionsProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Version information for a piece of serialized data -// -// There are different types of versions for each type of data -// (GraphDef, etc.), but they all have the same common shape -// described here. -// -// Each consumer has "consumer" and "min_producer" versions (specified -// elsewhere). A consumer is allowed to consume this data if -// -// producer >= min_producer -// consumer >= min_consumer -// consumer not in bad_consumers -// -message VersionDef { - // The version of the code that produced this data. - int32 producer = 1; - - // Any consumer below this version is not allowed to consume this data. - int32 min_consumer = 2; - - // Specific consumer versions which are disallowed (e.g. due to bugs). 
- repeated int32 bad_consumers = 3; -}; diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/protobuf/meta_graph.proto b/flink-tensorflow/src/main/proto/tensorflow/core/protobuf/meta_graph.proto deleted file mode 100644 index 5b20223..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/protobuf/meta_graph.proto +++ /dev/null @@ -1,292 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "MetaGraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "google/protobuf/any.proto"; - -import "tensorflow/core/framework/graph.proto"; -import "tensorflow/core/framework/op_def.proto"; -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/types.proto"; -import "tensorflow/core/protobuf/saver.proto"; - -// NOTE: This protocol buffer is evolving, and will go through revisions in the -// coming months. -// -// Protocol buffer containing the following which are necessary to restart -// training, run inference. It can be used to serialize/de-serialize memory -// objects necessary for running computation in a graph when crossing the -// process boundary. It can be used for long term storage of graphs, -// cross-language execution of graphs, etc. -// MetaInfoDef -// GraphDef -// SaverDef -// CollectionDef -// TensorInfo -// SignatureDef -message MetaGraphDef { - // Meta information regarding the graph to be exported. To be used by users - // of this protocol buffer to encode information regarding their meta graph. - message MetaInfoDef { - // User specified Version string. Can be the name of the model and revision, - // steps this model has been trained to, etc. - string meta_graph_version = 1; - - // A copy of the OpDefs used by the producer of this graph_def. - // Descriptions and Ops not used in graph_def are stripped out. - OpList stripped_op_list = 2; - - // A serialized protobuf. 
Can be the time this meta graph is created, or - // modified, or name of the model. - google.protobuf.Any any_info = 3; - - // User supplied tag(s) on the meta_graph and included graph_def. - // - // MetaGraphDefs should be tagged with their capabilities or use-cases. - // Examples: "train", "serve", "gpu", "tpu", etc. - // These tags enable loaders to access the MetaGraph(s) appropriate for a - // specific use-case or runtime environment. - repeated string tags = 4; - - // The __version__ string of the tensorflow build used to write this graph. - // This will be populated by the framework, which will overwrite any user - // supplied value. - string tensorflow_version = 5; - - // The __git_version__ string of the tensorflow build used to write this - // graph. This will be populated by the framework, which will overwrite any - // user supplied value. - string tensorflow_git_version = 6; - } - MetaInfoDef meta_info_def = 1; - - // GraphDef. - GraphDef graph_def = 2; - - // SaverDef. - SaverDef saver_def = 3; - - // collection_def: Map from collection name to collections. - // See CollectionDef section for details. - map collection_def = 4; - - // signature_def: Map from user supplied key for a signature to a single - // SignatureDef. - map signature_def = 5; - - // Asset file def to be used with the defined graph. - repeated AssetFileDef asset_file_def = 6; -} - -// CollectionDef should cover most collections. -// To add a user-defined collection, do one of the following: -// 1. For simple data types, such as string, int, float: -// tf.add_to_collection("your_collection_name", your_simple_value) -// strings will be stored as bytes_list. -// -// 2. 
For Protobuf types, there are three ways to add them: -// 1) tf.add_to_collection("your_collection_name", -// your_proto.SerializeToString()) -// -// collection_def { -// key: "user_defined_bytes_collection" -// value { -// bytes_list { -// value: "queue_name: \"test_queue\"\n" -// } -// } -// } -// -// or -// -// 2) tf.add_to_collection("your_collection_name", str(your_proto)) -// -// collection_def { -// key: "user_defined_string_collection" -// value { -// bytes_list { -// value: "\n\ntest_queue" -// } -// } -// } -// -// or -// -// 3) any_buf = any_pb2.Any() -// tf.add_to_collection("your_collection_name", -// any_buf.Pack(your_proto)) -// -// collection_def { -// key: "user_defined_any_collection" -// value { -// any_list { -// value { -// type_url: "type.googleapis.com/tensorflow.QueueRunnerDef" -// value: "\n\ntest_queue" -// } -// } -// } -// } -// -// 3. For Python objects, implement to_proto() and from_proto(), and register -// them in the following manner: -// ops.register_proto_function("your_collection_name", -// proto_type, -// to_proto=YourPythonObject.to_proto, -// from_proto=YourPythonObject.from_proto) -// These functions will be invoked to serialize and de-serialize the -// collection. For example, -// ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES, -// proto_type=variable_pb2.VariableDef, -// to_proto=Variable.to_proto, -// from_proto=Variable.from_proto) -message CollectionDef { - // NodeList is used for collecting nodes in graph. For example - // collection_def { - // key: "summaries" - // value { - // node_list { - // value: "input_producer/ScalarSummary:0" - // value: "shuffle_batch/ScalarSummary:0" - // value: "ImageSummary:0" - // } - // } - message NodeList { - repeated string value = 1; - } - - // BytesList is used for collecting strings and serialized protobufs. 
For - // example: - // collection_def { - // key: "trainable_variables" - // value { - // bytes_list { - // value: "\n\017conv1/weights:0\022\024conv1/weights/Assign - // \032\024conv1/weights/read:0" - // value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032 - // \023conv1/biases/read:0" - // } - // } - // } - message BytesList { - repeated bytes value = 1; - } - - // Int64List is used for collecting int, int64 and long values. - message Int64List { - repeated int64 value = 1 [packed = true]; - } - - // FloatList is used for collecting float values. - message FloatList { - repeated float value = 1 [packed = true]; - } - - // AnyList is used for collecting Any protos. - message AnyList { - repeated google.protobuf.Any value = 1; - } - - oneof kind { - NodeList node_list = 1; - BytesList bytes_list = 2; - Int64List int64_list = 3; - FloatList float_list = 4; - AnyList any_list = 5; - } -} - -// Information about a Tensor necessary for feeding or retrieval. -message TensorInfo { - string name = 1; - DataType dtype = 2; - TensorShapeProto tensor_shape = 3; -} - -// SignatureDef defines the signature of a computation supported by a TensorFlow -// graph. -// -// For example, a model with two loss computations, sharing a single input, -// might have the following signature_def map. -// -// Note that across the two SignatureDefs "loss_A" and "loss_B", the input key, -// output key, and method_name are identical, and will be used by system(s) that -// implement or rely upon this particular loss method. The output tensor names -// differ, demonstrating how different outputs can exist for the same method. -// -// signature_def { -// key: "loss_A" -// value { -// inputs { -// key: "input" -// value { -// name: "input:0" -// dtype: DT_STRING -// tensor_shape: ... -// } -// } -// outputs { -// key: "loss_output" -// value { -// name: "loss_output_A:0" -// dtype: DT_FLOAT -// tensor_shape: ... -// } -// } -// } -// ... 
-// method_name: "some/package/compute_loss" -// } -// signature_def { -// key: "loss_B" -// value { -// inputs { -// key: "input" -// value { -// name: "input:0" -// dtype: DT_STRING -// tensor_shape: ... -// } -// } -// outputs { -// key: "loss_output" -// value { -// name: "loss_output_B:0" -// dtype: DT_FLOAT -// tensor_shape: ... -// } -// } -// } -// ... -// method_name: "some/package/compute_loss" -// } -message SignatureDef { - // Named input parameters. - map inputs = 1; - // Named output parameters. - map outputs = 2; - // Extensible method_name information enabling third-party users to mark a - // SignatureDef as supporting a particular method. This enables producers and - // consumers of SignatureDefs, e.g. a model definition library and a serving - // library to have a clear hand-off regarding the semantics of a computation. - // - // Note that multiple SignatureDefs in a single MetaGraphDef may have the same - // method_name. This is commonly used to support multi-headed computation, - // where a single graph computation may return multiple results. - string method_name = 3; -} - -// An asset file def for a single file or a set of sharded files with the same -// name. -message AssetFileDef { - // The tensor to bind the asset filename to. - TensorInfo tensor_info = 1; - // The filename within an assets directory. Note: does not include the path - // prefix, i.e. directories. For an asset at /tmp/path/vocab.txt, the filename - // would be "vocab.txt". 
- string filename = 2; -} diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/protobuf/saved_model.proto b/flink-tensorflow/src/main/proto/tensorflow/core/protobuf/saved_model.proto deleted file mode 100644 index c2595dd..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/protobuf/saved_model.proto +++ /dev/null @@ -1,21 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "SavedModelProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensorflow/core/protobuf/meta_graph.proto"; - -// SavedModel is the high level serialization format for TensorFlow Models. -// See [todo: doc links, similar to session_bundle] for more information. -message SavedModel { - // The schema version of the SavedModel instance. Used for versioning when - // making future changes to the specification/implementation. Initial value - // at release will be 1. - int64 saved_model_schema_version = 1; - - // One or more MetaGraphs. - repeated MetaGraphDef meta_graphs = 2; -} diff --git a/flink-tensorflow/src/main/proto/tensorflow/core/protobuf/saver.proto b/flink-tensorflow/src/main/proto/tensorflow/core/protobuf/saver.proto deleted file mode 100644 index 65fe9c4..0000000 --- a/flink-tensorflow/src/main/proto/tensorflow/core/protobuf/saver.proto +++ /dev/null @@ -1,46 +0,0 @@ -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "SaverProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.util"; - -// Protocol buffer representing the configuration of a Saver. -message SaverDef { - // The name of the tensor in which to specify the filename when saving or - // restoring a model checkpoint. - string filename_tensor_name = 1; - - // The operation to run when saving a model checkpoint. - string save_tensor_name = 2; - - // The operation to run when restoring a model checkpoint. 
- string restore_op_name = 3; - - // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. - int32 max_to_keep = 4; - - // Shard the save files, one per device that has Variable nodes. - bool sharded = 5; - - // How often to keep an additional checkpoint. If not specified, only the last - // "max_to_keep" checkpoints are kept; if specified, in addition to keeping - // the last "max_to_keep" checkpoints, an additional checkpoint will be kept - // for every n hours of training. - float keep_checkpoint_every_n_hours = 6; - - // A version number that identifies a different on-disk checkpoint format. - // Usually, each subclass of BaseSaverBuilder works with a particular - // version/format. However, it is possible that the same builder may be - // upgraded to support a newer checkpoint format in the future. - enum CheckpointFormatVersion { - // Internal legacy format. - LEGACY = 0; - // Current format: tf.Saver() which works with tensorflow::table::Table. - V1 = 1; - // Experimental format under development. - V2 = 2; - } - CheckpointFormatVersion version = 7; -} diff --git a/flink-tensorflow/src/main/scala/org/apache/flink/contrib/tensorflow/ml/signatures/PredictionMethod.scala b/flink-tensorflow/src/main/scala/org/apache/flink/contrib/tensorflow/ml/signatures/PredictionMethod.scala new file mode 100644 index 0000000..2a78894 --- /dev/null +++ b/flink-tensorflow/src/main/scala/org/apache/flink/contrib/tensorflow/ml/signatures/PredictionMethod.scala @@ -0,0 +1,29 @@ +package org.apache.flink.contrib.tensorflow.ml.signatures + +import org.apache.flink.contrib.tensorflow.graphs.GraphMethod +import org.apache.flink.contrib.tensorflow.models.savedmodel.SignatureConstants._ +import org.tensorflow.Tensor + +/** + * The standard prediction signature. 
+ * Input and output are each a single `Tensor`. + * See https://github.com/tensorflow/serving/blob/master/tensorflow_serving/servables/tensorflow/predict_impl.cc + */ +sealed trait PredictionMethod extends GraphMethod { + val name = PREDICT_METHOD_NAME + override type Input = Tensor + override type Output = Tensor +} + +object PredictionMethod { + + /** Default implicit instance: feeds the input tensor under the PREDICT_INPUTS key and + * returns the tensor stored under PREDICT_OUTPUTS (the score for each possible classification). + */ + implicit val impl = new PredictionMethod { + type Result = Tensor + def inputs(i: Tensor): Map[String, Tensor] = Map(PREDICT_INPUTS -> i) + def outputs(o: Map[String, Tensor]): Tensor = o(PREDICT_OUTPUTS) + } +} + diff --git a/flink-tensorflow/src/main/scala/org/apache/flink/contrib/tensorflow/util/RegistrationUtils.java b/flink-tensorflow/src/main/scala/org/apache/flink/contrib/tensorflow/util/RegistrationUtils.java index 77e3b25..987b781 100644 --- a/flink-tensorflow/src/main/scala/org/apache/flink/contrib/tensorflow/util/RegistrationUtils.java +++ b/flink-tensorflow/src/main/scala/org/apache/flink/contrib/tensorflow/util/RegistrationUtils.java @@ -70,7 +70,7 @@ public static void registerTypes(ExecutionConfig config) { config.registerTypeWithKryoSerializer(MemoryLogRawAllocation.class, ProtobufSerializer.class); config.registerTypeWithKryoSerializer(SummaryDescription.class, ProtobufSerializer.class); config.registerTypeWithKryoSerializer(Summary.Value.class, ProtobufSerializer.class); - config.registerTypeWithKryoSerializer(FunctionDef.Node.class, ProtobufSerializer.class); +// NOTE(review): FunctionDef.Node registration is disabled; presumably the class is absent from the org.tensorflow:proto artifact (confirm, then delete this line): config.registerTypeWithKryoSerializer(FunctionDef.Node.class, ProtobufSerializer.class); config.registerTypeWithKryoSerializer(DeviceStepStats.class, ProtobufSerializer.class); config.registerTypeWithKryoSerializer(MemoryLogTensorOutput.class, ProtobufSerializer.class); config.registerTypeWithKryoSerializer(AttrValue.class, ProtobufSerializer.class); diff --git a/flink-tensorflow/src/test/scala/org/apache/flink/contrib/tensorflow/ml/MNIST_dense.scala
b/flink-tensorflow/src/test/scala/org/apache/flink/contrib/tensorflow/ml/MNIST_dense.scala new file mode 100644 index 0000000..f144b9d --- /dev/null +++ b/flink-tensorflow/src/test/scala/org/apache/flink/contrib/tensorflow/ml/MNIST_dense.scala @@ -0,0 +1,23 @@ +package org.apache.flink.contrib.tensorflow.ml + +import org.apache.flink.contrib.tensorflow.ml.signatures.PredictionMethod +import org.apache.flink.contrib.tensorflow.models.ModelFunction +import org.apache.flink.contrib.tensorflow.models.savedmodel.TensorFlowModel +import org.apache.flink.contrib.tensorflow.models.savedmodel.TensorFlowModel._ +import org.apache.flink.core.fs.Path + +/** + * The MNIST_dense saved model, loaded from `modelPath` with the "serve" tag. + * + * @param modelPath path to the saved model data. + */ +@SerialVersionUID(1L) +class MNIST_dense(modelPath: Path) extends TensorFlowModel[MNIST_dense] { + + override protected val loader = load(modelPath, "serve") + + /** Prediction function that returns scores for all classes, built from the current session and the + * "predict_images" signature; the `.get` presumably fails when that signature is missing (TODO confirm). + */ + def predict = ModelFunction[PredictionMethod](session(), signatureDef("predict_images").get) +} diff --git a/flink-tensorflow/src/test/scala/org/apache/flink/contrib/tensorflow/ml/PredictITCase.scala b/flink-tensorflow/src/test/scala/org/apache/flink/contrib/tensorflow/ml/PredictITCase.scala new file mode 100644 index 0000000..fdb3eb6 --- /dev/null +++ b/flink-tensorflow/src/test/scala/org/apache/flink/contrib/tensorflow/ml/PredictITCase.scala @@ -0,0 +1,97 @@ +package org.apache.flink.contrib.tensorflow.ml + +import com.twitter.bijection.Conversion._ +import org.apache.flink.api.common.functions.RichFlatMapFunction +import org.apache.flink.api.scala._ +import org.apache.flink.configuration.Configuration +import org.apache.flink.contrib.tensorflow.types.{TensorValue, TensorValueBuilder} +import org.apache.flink.contrib.tensorflow.util.{FlinkTestBase, RegistrationUtils} +import org.apache.flink.core.fs.Path +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.util.Collector
+import org.apache.flink.util.Preconditions.checkState +import org.junit.runner.RunWith +import org.scalatest.junit.JUnitRunner +import org.scalatest.{Matchers, WordSpecLike} +import org.tensorflow.contrib.scala.Arrays._ +import org.tensorflow.contrib.scala.Rank._ +import org.tensorflow.contrib.scala._ + +import resource.{managed, _} +import org.tensorflow.DataType + +@RunWith(classOf[JUnitRunner]) +class PredictITCase extends WordSpecLike + with Matchers + with FlinkTestBase { + + override val parallelism = 1 + + type ImageTensorValue = TensorValue[`2D`, Float] + type LabeledImageTensorValue = (ImageTensorValue, Int) + + + def examples(): Seq[LabeledImageTensorValue] = { + val num_samples = 5 + val samples : Array[Array[Float]] = new Array[Array[Float]](num_samples) + val labels = new Array[Int](num_samples) + labels(0) = 7 + labels(1) = 2 + labels(2) = 1 + labels(3) = 0 + labels(4) = 4 + samples(0) = Array(0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,84f,185f,159f,151f,60f,36f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,222f,254f,254f,254f,254f,241f,198f,198f,198f,198f,198f,198f,198f,198f,170f,52f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,67f,114f,72f,114f,163f,227f,254f,225f,254f,254f,254f,250f,229f,254f,254f,140f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,17f,66f,14f,67f,67f,67f,59f,21f,236f,254f,106f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,83f,253f,209f,18
f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,22f,233f,255f,83f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,129f,254f,238f,44f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,59f,249f,254f,62f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,133f,254f,187f,5f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,9f,205f,248f,58f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,126f,254f,182f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,75f,251f,240f,57f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,19f,221f,254f,166f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,3f,203f,254f,219f,35f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,38f,254f,254f,77f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,31f,224f,254f,115f,1f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,133f,254f,254f,52f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,61f,242f,254f,254f,52f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,121f,254f,254f,219f,40f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,121f,254f,207f,18f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f) + samples(1) = 
Array(0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,116f,125f,171f,255f,255f,150f,93f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,169f,253f,253f,253f,253f,253f,253f,218f,30f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,169f,253f,253f,253f,213f,142f,176f,253f,253f,122f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,52f,250f,253f,210f,32f,12f,0f,6f,206f,253f,140f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,77f,251f,210f,25f,0f,0f,0f,122f,248f,253f,65f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,31f,18f,0f,0f,0f,0f,209f,253f,253f,65f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,117f,247f,253f,198f,10f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,76f,247f,253f,231f,63f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,128f,253f,253f,144f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,176f,246f,253f,159f,12f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,25f,234f,253f,233f,35f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,198f,253f,253f,141f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,78f,248f,253f,189f,12f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,19f,200f,253f,253f,141f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,134f,253f,253f,173f,12f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,248f,253f,253f,25f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,248f,253f,253f,43f,20f,20f,20f,20f,5f,0f,5f,20f,20f,37f,150f,150f,150f,147f,10f,0f,0f,0f,0f,0f,0f,0f,0f,0f,248f,253f,253f,253f,253f,253f,253f,253f,168f,143f,166f,253f,253f,253f,253f,253f,253f,253f,123f,0f
,0f,0f,0f,0f,0f,0f,0f,0f,174f,253f,253f,253f,253f,253f,253f,253f,253f,253f,253f,253f,249f,247f,247f,169f,117f,117f,57f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,118f,123f,123f,123f,166f,253f,253f,253f,155f,123f,123f,41f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f) + samples(2) = Array(0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,38f,254f,109f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,87f,252f,82f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,135f,241f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,45f,244f,150f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,84f,254f,63f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,202f,223f,11f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,32f,254f,216f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,95f,254f,195f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,140f,254f,77f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,57f,237f,205f,8f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,124f,255f,165f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0
f,0f,0f,0f,0f,0f,0f,0f,171f,254f,81f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,24f,232f,215f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,120f,254f,159f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,151f,254f,142f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,228f,254f,66f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,61f,251f,254f,66f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,141f,254f,205f,3f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,10f,215f,254f,121f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,5f,198f,176f,10f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f) + samples(3) = 
Array(0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,11f,150f,253f,202f,31f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,37f,251f,251f,253f,107f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,21f,197f,251f,251f,253f,107f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,110f,190f,251f,251f,251f,253f,169f,109f,62f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,253f,251f,251f,251f,251f,253f,251f,251f,220f,51f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,182f,255f,253f,253f,253f,253f,234f,222f,253f,253f,253f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,63f,221f,253f,251f,251f,251f,147f,77f,62f,128f,251f,251f,105f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,32f,231f,251f,253f,251f,220f,137f,10f,0f,0f,31f,230f,251f,243f,113f,5f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,37f,251f,251f,253f,188f,20f,0f,0f,0f,0f,0f,109f,251f,253f,251f,35f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,37f,251f,251f,201f,30f,0f,0f,0f,0f,0f,0f,31f,200f,253f,251f,35f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,37f,253f,253f,0f,0f,0f,0f,0f,0f,0f,0f,32f,202f,255f,253f,164f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,140f,251f,251f,0f,0f,0f,0f,0f,0f,0f,0f,109f,251f,253f,251f,35f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,217f,251f,251f,0f,0f,0f,0f,0f,0f,21f,63f,231f,251f,253f,230f,30f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,217f,251f,251f,0f,0f,0f,0f,0f,0f,144f,251f,251f,251f,221f,61f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,217f,251f,251f,0f,0f,0f,0f,0f,182f,221f,251f,251f,251f,180f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,218f,253f,253f,73f,73f,228f,253f,253f,255f,253f,253f,253f,253f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,113f,251f,251f,253f,2
51f,251f,251f,251f,253f,251f,251f,251f,147f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,31f,230f,251f,253f,251f,251f,251f,251f,253f,230f,189f,35f,10f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,62f,142f,253f,251f,251f,251f,251f,253f,107f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,72f,174f,251f,173f,71f,72f,30f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f) + samples(4) = Array(0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,50f,224f,0f,0f,0f,0f,0f,0f,0f,70f,29f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,121f,231f,0f,0f,0f,0f,0f,0f,0f,148f,168f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,4f,195f,231f,0f,0f,0f,0f,0f,0f,0f,96f,210f,11f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,69f,252f,134f,0f,0f,0f,0f,0f,0f,0f,114f,252f,21f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,45f,236f,217f,12f,0f,0f,0f,0f,0f,0f,0f,192f,252f,21f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,168f,247f,53f,0f,0f,0f,0f,0f,0f,0f,18f,255f,253f,21f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,84f,242f,211f,0f,0f,0f,0f,0f,0f,0f,0f,141f,253f,189f,5f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,169f,252f,106f,0f,0f,0f,0f,0f,0f,0f,32f,232f,250f,66f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,15f,225f,252f,0f,0f,0f,0f,0f,0f,0f,0f,134f,252f,211f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,
0f,0f,0f,22f,252f,164f,0f,0f,0f,0f,0f,0f,0f,0f,169f,252f,167f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,9f,204f,209f,18f,0f,0f,0f,0f,0f,0f,22f,253f,253f,107f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,169f,252f,199f,85f,85f,85f,85f,129f,164f,195f,252f,252f,106f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,41f,170f,245f,252f,252f,252f,252f,232f,231f,251f,252f,252f,9f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,49f,84f,84f,84f,84f,0f,0f,161f,252f,252f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,127f,252f,252f,45f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,128f,253f,253f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,127f,252f,252f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,135f,252f,244f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,232f,236f,111f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,179f,66f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f,0f) + + val shapeArray = Array(1L, 784L) + + val tvb = TensorValueBuilder.newBuilder() + tvb.shape(shapeArray) + tvb.dataType(DataType.FLOAT) + + for (i <- 0 to (num_samples - 1)) + yield (tvb.data(samples(i)).build().asInstanceOf[ImageTensorValue], labels(i)) + } + + + "A PredictFunction" should { + "process elements" in { + val env = StreamExecutionEnvironment.getExecutionEnvironment + RegistrationUtils.registerTypes(env.getConfig) + + val model = new MNIST_dense(new Path("../models/mnist_dense")) + + + val outputs = env + .fromCollection(examples()) + .flatMap(new RichFlatMapFunction[LabeledImageTensorValue, Int] { + override def open(parameters: Configuration): Unit = model.open() + override def close(): Unit 
= model.close() + + override def flatMap(value: LabeledImageTensorValue, out: Collector[Int]): Unit = { + + for { + x <- managed(value._1.toTensor()) + y <- model.predict(x) + } { + // view the model output as a 1D float tensor so the available Array[Float] conversion applies + val o = y.taggedAs[TypedTensor[`1D`,Float]].as[Array[Float]] + + // predicted class is the index of the highest score; checkState below throws when it differs from the label (value._2) + val max_index = o.zipWithIndex.maxBy(_._1)._2 + + checkState(max_index == value._2) + + out.collect(max_index) + } + } + }) + .print() + + env.execute() + } + } +} diff --git a/models/mnist_dense/saved_model.pb b/models/mnist_dense/saved_model.pb new file mode 100644 index 0000000..09cc1c4 Binary files /dev/null and b/models/mnist_dense/saved_model.pb differ diff --git a/models/mnist_dense/variables/variables.data-00000-of-00001 b/models/mnist_dense/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000..d34bc9f Binary files /dev/null and b/models/mnist_dense/variables/variables.data-00000-of-00001 differ diff --git a/models/mnist_dense/variables/variables.index b/models/mnist_dense/variables/variables.index new file mode 100644 index 0000000..a5136fe Binary files /dev/null and b/models/mnist_dense/variables/variables.index differ