Skip to content

Commit a5c014f

Browse files
authored
[IO] Fix tensor shape meta-data bug for DataFrame Value. (#958)
* Revert "[IO] Add tensor shape meta-data support for ParquetDataset. (#849)" * [IO] Fix tensor shape meta-data bug for DataFrame Value. Signed-off-by: chenbangduo.cbd <chenbangduo.cbd@alibaba-inc.com>
1 parent 7ce8477 commit a5c014f

File tree

2 files changed

+28
-26
lines changed

2 files changed

+28
-26
lines changed

tensorflow/python/data/experimental/ops/dataframe.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,14 @@ def __init__(self, name, dtype=None, ragged_rank=None, shape=None):
5959
self._ragged_rank = ragged_rank
6060
if shape:
6161
shape = tensor_shape.TensorShape(shape)
62-
shape_rank = 0
63-
for _ in shape:
64-
shape_rank += 1
65-
if ragged_rank is not None and ragged_rank != shape_rank:
62+
for d in shape:
63+
if d.value is None:
64+
raise ValueError(
65+
f'Field {name} has incomplete shape: {shape}')
66+
if ragged_rank is not None and ragged_rank > 1:
6667
raise ValueError(
6768
f'Field {name} is a nested list ({ragged_rank}) '
6869
f'with shape {shape}')
69-
self._ragged_rank = shape_rank
70-
elif ragged_rank is not None:
71-
shape = tensor_shape.TensorShape([None for _ in xrange(ragged_rank)])
72-
7370
self._shape = shape
7471

7572
@property
@@ -134,16 +131,17 @@ def output_classes(self):
134131
def output_types(self):
135132
return self.map(lambda i: self._dtype if i == 0 else dtypes.int32)
136133

137-
def output_shapes(self, batch_size=None):
134+
@property
135+
def output_shapes(self):
138136
if self._shape is None:
139-
return self.map(lambda i: tensor_shape.vector(batch_size) if i == 0
140-
else tensor_shape.vector(None))
137+
return self.map(lambda _: tensor_shape.vector(None))
141138
return self.map(
142-
lambda i: tensor_shape.vector(batch_size).concatenate(self._shape) if i == 0
139+
lambda i: tensor_shape.vector(None).concatenate(self._shape) if i == 0
143140
else tensor_shape.vector(None))
144141

145-
def output_specs(self, batch_size=None):
146-
shape = tensor_shape.vector(batch_size)
142+
@property
143+
def output_specs(self):
144+
shape = tensor_shape.vector(None)
147145
if self._shape is not None:
148146
shape = shape.concatenate(self._shape)
149147
specs = [tensor_spec.TensorSpec(shape, dtype=self._dtype)]

tensorflow/python/data/experimental/ops/parquet_dataset_ops.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensorflow.python.data.ops import readers
2323
from tensorflow.python.framework import dtypes
2424
from tensorflow.python.framework import ops
25+
from tensorflow.python.framework import tensor_shape
2526
from tensorflow.python.framework import tensor_spec
2627
from tensorflow.python.framework import type_spec
2728
from tensorflow.python.util import nest
@@ -38,25 +39,23 @@ class DataFrameValueSpec(type_spec.BatchableTypeSpec):
3839
def value_type(self):
3940
return DataFrame.Value if self._ragged_rank > 0 else ops.Tensor
4041

41-
def __init__(self, field, batch_size=None):
42+
def __init__(self, field):
4243
"""Constructs a type specification for a `tf.RaggedTensor`.
4344
4445
Args:
4546
field: The field definition.
46-
batch_size: The batch_size of DataFrame.
4747
"""
4848
if field.incomplete:
4949
raise ValueError(
5050
f'Field {field} is incomplete, please specify dtype and ragged_rank')
5151
self._field = field
52-
self._batch_size = batch_size
5352

5453
def _serialize(self):
5554
return (self._field.dtype, self._field.ragged_rank)
5655

5756
@property
5857
def _component_specs(self):
59-
return self._field.output_specs(self._batch_size)
58+
return self._field.output_specs
6059

6160
def _to_components(self, value):
6261
if isinstance(value, DataFrame.Value):
@@ -80,7 +79,7 @@ def _to_legacy_output_types(self):
8079
return self._field.output_types
8180

8281
def _to_legacy_output_shapes(self):
83-
return self._field.output_shapes(self._batch_size)
82+
return self._field.output_shapes
8483

8584
def _to_legacy_output_classes(self):
8685
return self._field.output_classes
@@ -110,13 +109,18 @@ def __init__(
110109
self._batch_size = ops.convert_to_tensor(
111110
batch_size, dtype=dtypes.int64, name='batch_size')
112111
self._fields = fields
113-
self._output_specs = {
114-
f.name: (
115-
DataFrameValueSpec(f, batch_size if drop_remainder else None)
116-
if f.ragged_rank > 0
117-
else tensor_spec.TensorSpec(
118-
shape=[batch_size if drop_remainder else None], dtype=f.dtype))
119-
for f in self._fields}
112+
self._output_specs = {}
113+
for f in self._fields:
114+
item = None
115+
if f.ragged_rank > 0:
116+
item = DataFrameValueSpec(f)
117+
else:
118+
shape = tensor_shape.vector(batch_size if drop_remainder else None)
119+
if f.shape:
120+
shape = shape.concatenate(f.shape)
121+
item = tensor_spec.TensorSpec(shape=shape, dtype=f.dtype)
122+
self._output_specs[f.name] = item
123+
120124
self._field_names = nest.flatten({f.name: f.name for f in self._fields})
121125
self._field_dtypes = nest.flatten({f.name: f.dtype for f in self._fields})
122126
self._field_ragged_ranks = nest.flatten(

0 commit comments

Comments
 (0)