@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16- #include " tensorflow/core/framework/resource_mgr.h"
17- #include " tensorflow/core/framework/resource_op_kernel.h"
18-
19- #include " rapidjson/document.h"
20- #include " rapidjson/pointer.h"
21-
2216#include " api/Compiler.hh"
2317#include " api/DataFile.hh"
2418#include " api/Generic.hh"
2519#include " api/Stream.hh"
2620#include " api/Validator.hh"
21+ #include " rapidjson/document.h"
22+ #include " rapidjson/pointer.h"
23+ #include " tensorflow/core/framework/op_kernel.h"
2724
2825namespace tensorflow {
2926namespace data {
@@ -33,7 +30,6 @@ class DecodeJSONOp : public OpKernel {
3330 public:
3431 explicit DecodeJSONOp (OpKernelConstruction* context) : OpKernel(context) {
3532 env_ = context->env ();
36- OP_REQUIRES_OK (context, context->GetAttr (" shapes" , &shapes_));
3733 }
3834
3935 void Compute (OpKernelContext* context) override {
@@ -45,28 +41,26 @@ class DecodeJSONOp : public OpKernel {
4541 const Tensor* names_tensor;
4642 OP_REQUIRES_OK (context, context->input (" names" , &names_tensor));
4743
48- OP_REQUIRES (context, (names_tensor->NumElements () == shapes_.size ()),
49- errors::InvalidArgument (
50- " shapes and names should have same number: " ,
51- shapes_.size (), " vs. " , names_tensor->NumElements ()));
44+ OP_REQUIRES (
45+ context, (names_tensor->NumElements () == context->num_outputs ()),
46+ errors::InvalidArgument (" names should have same number as outputs: " ,
47+ names_tensor->NumElements (), " vs. " ,
48+ context->num_outputs ()));
5249 rapidjson::Document d;
5350 d.Parse (input.c_str ());
5451 OP_REQUIRES (context, d.IsObject (),
5552 errors::InvalidArgument (" not a valid JSON object" ));
56- for (size_t i = 0 ; i < shapes_.size (); i++) {
57- Tensor* value_tensor;
58- OP_REQUIRES_OK (context,
59- context->allocate_output (i, shapes_[i], &value_tensor));
53+ for (size_t i = 0 ; i < names_tensor->NumElements (); i++) {
6054 rapidjson::Value* entry =
6155 rapidjson::Pointer (names_tensor->flat <tstring>()(i).c_str ()).Get (d);
6256 OP_REQUIRES (context, (entry != nullptr ),
6357 errors::InvalidArgument (" no value for " ,
6458 names_tensor->flat <tstring>()(i)));
59+ Tensor* value_tensor;
6560 if (entry->IsArray ()) {
66- OP_REQUIRES (context, entry->Size () == value_tensor->NumElements (),
67- errors::InvalidArgument (
68- " number of elements in JSON does not match spec: " ,
69- entry->Size (), " vs. " , value_tensor->NumElements ()));
61+ OP_REQUIRES_OK (context,
62+ context->allocate_output (i, TensorShape ({entry->Size ()}),
63+ &value_tensor));
7064
7165 switch (value_tensor->dtype ()) {
7266 case DT_INT32:
@@ -103,21 +97,23 @@ class DecodeJSONOp : public OpKernel {
10397 }
10498
10599 } else {
100+ OP_REQUIRES_OK (context, context->allocate_output (i, TensorShape ({1 }),
101+ &value_tensor));
106102 switch (value_tensor->dtype ()) {
107103 case DT_INT32:
108- value_tensor->scalar <int32>()() = entry->GetInt ();
104+ value_tensor->flat <int32>()(0 ) = entry->GetInt ();
109105 break ;
110106 case DT_INT64:
111- value_tensor->scalar <int64>()() = entry->GetInt64 ();
107+ value_tensor->flat <int64>()(0 ) = entry->GetInt64 ();
112108 break ;
113109 case DT_FLOAT:
114- value_tensor->scalar <float >()() = entry->GetDouble ();
110+ value_tensor->flat <float >()(0 ) = entry->GetDouble ();
115111 break ;
116112 case DT_DOUBLE:
117- value_tensor->scalar <double >()() = entry->GetDouble ();
113+ value_tensor->flat <double >()(0 ) = entry->GetDouble ();
118114 break ;
119115 case DT_STRING:
120- value_tensor->scalar <tstring>()() = entry->GetString ();
116+ value_tensor->flat <tstring>()(0 ) = entry->GetString ();
121117 break ;
122118 default :
123119 OP_REQUIRES (
@@ -133,7 +129,6 @@ class DecodeJSONOp : public OpKernel {
133129 private:
134130 mutable mutex mu_;
135131 Env* env_ TF_GUARDED_BY (mu_);
136- std::vector<TensorShape> shapes_ TF_GUARDED_BY (mu_);
137132};
138133
139134class DecodeAvroOp : public OpKernel {
0 commit comments