2222from tensorflow .python .data .ops import readers
2323from tensorflow .python .framework import dtypes
2424from tensorflow .python .framework import ops
25+ from tensorflow .python .framework import tensor_shape
2526from tensorflow .python .framework import tensor_spec
2627from tensorflow .python .framework import type_spec
2728from tensorflow .python .util import nest
@@ -38,25 +39,23 @@ class DataFrameValueSpec(type_spec.BatchableTypeSpec):
3839 def value_type (self ):
3940 return DataFrame .Value if self ._ragged_rank > 0 else ops .Tensor
4041
41- def __init__ (self , field , batch_size = None ):
42+ def __init__ (self , field ):
4243 """Constructs a type specification for a `tf.RaggedTensor`.
4344
4445 Args:
4546 field: The field definition.
46- batch_size: The batch_size of DataFrame.
4747 """
4848 if field .incomplete :
4949 raise ValueError (
5050 f'Field { field } is incomplete, please specify dtype and ragged_rank' )
5151 self ._field = field
52- self ._batch_size = batch_size
5352
5453 def _serialize (self ):
5554 return (self ._field .dtype , self ._field .ragged_rank )
5655
5756 @property
5857 def _component_specs (self ):
59- return self ._field .output_specs ( self . _batch_size )
58+ return self ._field .output_specs
6059
6160 def _to_components (self , value ):
6261 if isinstance (value , DataFrame .Value ):
@@ -80,7 +79,7 @@ def _to_legacy_output_types(self):
8079 return self ._field .output_types
8180
8281 def _to_legacy_output_shapes (self ):
83- return self ._field .output_shapes ( self . _batch_size )
82+ return self ._field .output_shapes
8483
8584 def _to_legacy_output_classes (self ):
8685 return self ._field .output_classes
@@ -110,13 +109,18 @@ def __init__(
110109 self ._batch_size = ops .convert_to_tensor (
111110 batch_size , dtype = dtypes .int64 , name = 'batch_size' )
112111 self ._fields = fields
113- self ._output_specs = {
114- f .name : (
115- DataFrameValueSpec (f , batch_size if drop_remainder else None )
116- if f .ragged_rank > 0
117- else tensor_spec .TensorSpec (
118- shape = [batch_size if drop_remainder else None ], dtype = f .dtype ))
119- for f in self ._fields }
112+ self ._output_specs = {}
113+ for f in self ._fields :
114+ item = None
115+ if f .ragged_rank > 0 :
116+ item = DataFrameValueSpec (f )
117+ else :
118+ shape = tensor_shape .vector (batch_size if drop_remainder else None )
119+ if f .shape :
120+ shape = shape .concatenate (f .shape )
121+ item = tensor_spec .TensorSpec (shape = shape , dtype = f .dtype )
122+ self ._output_specs [f .name ] = item
123+
120124 self ._field_names = nest .flatten ({f .name : f .name for f in self ._fields })
121125 self ._field_dtypes = nest .flatten ({f .name : f .dtype for f in self ._fields })
122126 self ._field_ragged_ranks = nest .flatten (
0 commit comments