Skip to content

Commit fb1282d

Browse files
authored
support partially-known and unknown shape specification in decode_json (#996)
This PR tries to address the issue in #918 by adding support for partially-known and unknown shape specification in decode_json. In this PR, the shapes are no longer passed to the C++ kernels. Instead, the C++ kernels always render a 1-D output tensor first, with a follow-up tf.reshape to fix up the shape later. This PR fixes #918. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
1 parent 717a0d6 commit fb1282d

File tree

4 files changed

+43
-40
lines changed

4 files changed

+43
-40
lines changed

tensorflow_io/core/kernels/serialization_kernels.cc

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "tensorflow/core/framework/resource_mgr.h"
17-
#include "tensorflow/core/framework/resource_op_kernel.h"
18-
19-
#include "rapidjson/document.h"
20-
#include "rapidjson/pointer.h"
21-
2216
#include "api/Compiler.hh"
2317
#include "api/DataFile.hh"
2418
#include "api/Generic.hh"
2519
#include "api/Stream.hh"
2620
#include "api/Validator.hh"
21+
#include "rapidjson/document.h"
22+
#include "rapidjson/pointer.h"
23+
#include "tensorflow/core/framework/op_kernel.h"
2724

2825
namespace tensorflow {
2926
namespace data {
@@ -33,7 +30,6 @@ class DecodeJSONOp : public OpKernel {
3330
public:
3431
explicit DecodeJSONOp(OpKernelConstruction* context) : OpKernel(context) {
3532
env_ = context->env();
36-
OP_REQUIRES_OK(context, context->GetAttr("shapes", &shapes_));
3733
}
3834

3935
void Compute(OpKernelContext* context) override {
@@ -45,28 +41,26 @@ class DecodeJSONOp : public OpKernel {
4541
const Tensor* names_tensor;
4642
OP_REQUIRES_OK(context, context->input("names", &names_tensor));
4743

48-
OP_REQUIRES(context, (names_tensor->NumElements() == shapes_.size()),
49-
errors::InvalidArgument(
50-
"shapes and names should have same number: ",
51-
shapes_.size(), " vs. ", names_tensor->NumElements()));
44+
OP_REQUIRES(
45+
context, (names_tensor->NumElements() == context->num_outputs()),
46+
errors::InvalidArgument("names should have same number as outputs: ",
47+
names_tensor->NumElements(), " vs. ",
48+
context->num_outputs()));
5249
rapidjson::Document d;
5350
d.Parse(input.c_str());
5451
OP_REQUIRES(context, d.IsObject(),
5552
errors::InvalidArgument("not a valid JSON object"));
56-
for (size_t i = 0; i < shapes_.size(); i++) {
57-
Tensor* value_tensor;
58-
OP_REQUIRES_OK(context,
59-
context->allocate_output(i, shapes_[i], &value_tensor));
53+
for (size_t i = 0; i < names_tensor->NumElements(); i++) {
6054
rapidjson::Value* entry =
6155
rapidjson::Pointer(names_tensor->flat<tstring>()(i).c_str()).Get(d);
6256
OP_REQUIRES(context, (entry != nullptr),
6357
errors::InvalidArgument("no value for ",
6458
names_tensor->flat<tstring>()(i)));
59+
Tensor* value_tensor;
6560
if (entry->IsArray()) {
66-
OP_REQUIRES(context, entry->Size() == value_tensor->NumElements(),
67-
errors::InvalidArgument(
68-
"number of elements in JSON does not match spec: ",
69-
entry->Size(), " vs. ", value_tensor->NumElements()));
61+
OP_REQUIRES_OK(context,
62+
context->allocate_output(i, TensorShape({entry->Size()}),
63+
&value_tensor));
7064

7165
switch (value_tensor->dtype()) {
7266
case DT_INT32:
@@ -103,21 +97,23 @@ class DecodeJSONOp : public OpKernel {
10397
}
10498

10599
} else {
100+
OP_REQUIRES_OK(context, context->allocate_output(i, TensorShape({1}),
101+
&value_tensor));
106102
switch (value_tensor->dtype()) {
107103
case DT_INT32:
108-
value_tensor->scalar<int32>()() = entry->GetInt();
104+
value_tensor->flat<int32>()(0) = entry->GetInt();
109105
break;
110106
case DT_INT64:
111-
value_tensor->scalar<int64>()() = entry->GetInt64();
107+
value_tensor->flat<int64>()(0) = entry->GetInt64();
112108
break;
113109
case DT_FLOAT:
114-
value_tensor->scalar<float>()() = entry->GetDouble();
110+
value_tensor->flat<float>()(0) = entry->GetDouble();
115111
break;
116112
case DT_DOUBLE:
117-
value_tensor->scalar<double>()() = entry->GetDouble();
113+
value_tensor->flat<double>()(0) = entry->GetDouble();
118114
break;
119115
case DT_STRING:
120-
value_tensor->scalar<tstring>()() = entry->GetString();
116+
value_tensor->flat<tstring>()(0) = entry->GetString();
121117
break;
122118
default:
123119
OP_REQUIRES(
@@ -133,7 +129,6 @@ class DecodeJSONOp : public OpKernel {
133129
private:
134130
mutable mutex mu_;
135131
Env* env_ TF_GUARDED_BY(mu_);
136-
std::vector<TensorShape> shapes_ TF_GUARDED_BY(mu_);
137132
};
138133

139134
class DecodeAvroOp : public OpKernel {

tensorflow_io/core/ops/serialization_ops.cc

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,13 @@ REGISTER_OP("IO>DecodeJSON")
2525
.Input("input: string")
2626
.Input("names: string")
2727
.Output("value: dtypes")
28-
.Attr("shapes: list(shape)")
2928
.Attr("dtypes: list(type)")
3029
.SetShapeFn([](shape_inference::InferenceContext* c) {
3130
// TODO: support batch (1-D) input
3231
shape_inference::ShapeHandle unused;
3332
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &unused));
34-
std::vector<TensorShape> shapes;
35-
TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
36-
if (shapes.size() != c->num_outputs()) {
37-
return errors::InvalidArgument(
38-
"shapes and types should be the same: ", shapes.size(), " vs. ",
39-
c->num_outputs());
40-
}
41-
for (size_t i = 0; i < shapes.size(); ++i) {
42-
shape_inference::ShapeHandle shape;
43-
TF_RETURN_IF_ERROR(
44-
c->MakeShapeFromPartialTensorShape(shapes[i], &shape));
45-
c->set_output(static_cast<int64>(i), shape);
33+
for (size_t i = 0; i < c->num_outputs(); ++i) {
34+
c->set_output(static_cast<int64>(i), c->MakeShape({c->UnknownDim()}));
4635
}
4736
return Status::OK();
4837
});

tensorflow_io/core/python/experimental/serialization_ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,14 @@ def decode_json(data, specs, name=None):
6767
named_spec(named)
6868
named = tf.nest.flatten(named)
6969
names = [e.named() for e in named]
70-
shapes = [e.shape for e in named]
70+
shapes = [
71+
tf.constant([-1 if d is None else d for d in e.shape.as_list()], tf.int32)
72+
for e in named
73+
]
7174
dtypes = [e.dtype for e in named]
7275

73-
values = core_ops.io_decode_json(data, names, shapes, dtypes, name=name)
76+
values = core_ops.io_decode_json(data, names, dtypes, name=name)
77+
values = [tf.reshape(value, shape) for value, shape in zip(values, shapes)]
7478
return tf.nest.pack_sequence_as(specs, values)
7579

7680

tests/test_serialization_eager.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Test Serialization"""
1616

1717
import os
18+
import json
1819
import numpy as np
1920

2021
import pytest
@@ -185,3 +186,17 @@ def test_serialization_decode_in_dataset(
185186
for v, r in zip(tf.nest.flatten(value), tf.nest.flatten(returned))
186187
]
187188
)
189+
190+
191+
def test_json_partial_shape():
192+
"""Test case for partial shape GitHub 918."""
193+
r = json.dumps({"foo": [1, 2, 3, 4, 5]})
194+
195+
@tf.function(autograph=False)
196+
def parse_json(json_text):
197+
specs = {"foo": tf.TensorSpec(tf.TensorShape([None]), tf.int32)}
198+
parsed = tfio.experimental.serialization.decode_json(json_text, specs)
199+
return parsed["foo"]
200+
201+
v = parse_json(r)
202+
assert np.array_equal(v, [1, 2, 3, 4, 5])

0 commit comments

Comments
 (0)