5252from tensorflow_datasets .core .features import features_dict
5353from tensorflow_datasets .core .features import image_feature
5454from tensorflow_datasets .core .features import sequence_feature
55+ from tensorflow_datasets .core .features import tensor_feature
5556from tensorflow_datasets .core .features import text_feature
5657from tensorflow_datasets .core .utils import conversion_utils
5758from tensorflow_datasets .core .utils import croissant_utils
@@ -75,6 +76,51 @@ def _strip_record_set_prefix(
7576 }
7677
7778
def array_datatype_converter(
    feature: type_utils.TfdsDType | feature_lib.FeatureConnector | None,
    field: mlc.Field,
    int_dtype: type_utils.TfdsDType = np.int64,
    float_dtype: type_utils.TfdsDType = np.float32,
):
  """Wraps an array-valued Croissant field in a Sequence or Tensor feature.

  The wrapping is chosen from `field.array_shape_tuple` and `field.data_type`:

  * A shape with exactly one entry yields a single `Sequence` around
    `feature`.
  * A shape that contains `-1`, or a `data_type` that is not one of `int`,
    `float`, `bool` or `bytes`, yields one nested `Sequence` per entry of the
    shape.
  * Any other shape yields a `Tensor` with that shape and the dtype mapped
    from `field.data_type`; `feature` is not used in this case.

  Args:
    feature: The inner feature to wrap when a sequence is produced.
    field: The mlc.Field object describing the array.
    int_dtype: The dtype used when `field.data_type` is `int`. Defaults to
      np.int64.
    float_dtype: The dtype used when `field.data_type` is `float`. Defaults to
      np.float32.

  Returns:
    A sequence feature (possibly nested) or a tensor feature.
  """
  native_dtypes = {
      int: int_dtype,
      float: float_dtype,
      bool: np.bool_,
      bytes: np.str_,
  }
  # Looked up before branching so the dtype is ready for the tensor case.
  tensor_dtype = native_dtypes.get(field.data_type)
  shape = field.array_shape_tuple
  description = field.description

  # One-dimensional arrays are always a plain sequence of the inner feature.
  if len(shape) == 1:
    return sequence_feature.Sequence(feature, doc=description)

  # Fully known shape and a dtype we can map directly: emit a tensor.
  if -1 not in shape and field.data_type in native_dtypes:
    return tensor_feature.Tensor(
        shape=shape,
        dtype=tensor_dtype,
        doc=description,
    )

  # Otherwise fall back to one level of Sequence nesting per dimension.
  nested = feature
  for _ in shape:
    nested = sequence_feature.Sequence(nested, doc=description)
  return nested
122+
123+
78124def datatype_converter (
79125 field : mlc .Field ,
80126 int_dtype : type_utils .TfdsDType = np .int64 ,
@@ -133,6 +179,16 @@ def datatype_converter(
133179 else :
134180 raise ValueError (f'Unknown data type: { field_data_type } .' )
135181
182+ if feature and field .is_array :
183+ feature = array_datatype_converter (
184+ feature = feature ,
185+ field = field ,
186+ int_dtype = int_dtype ,
187+ float_dtype = float_dtype ,
188+ )
189+ # If the field is repeated, we return a sequence feature. `field.repeated` is
190+ # deprecated starting from Croissant 1.1, but we still support it for
191+ # backwards compatibility.
136192 if feature and field .repeated :
137193 feature = sequence_feature .Sequence (feature , doc = field .description )
138194 return feature
0 commit comments