4040import json
4141from typing import Any
4242
43+ from etils import enp
4344from etils import epath
4445import numpy as np
4546from tensorflow_datasets .core import dataset_builder
@@ -79,8 +80,7 @@ def _strip_record_set_prefix(
7980def array_datatype_converter (
8081 feature : type_utils .TfdsDType | feature_lib .FeatureConnector | None ,
8182 field : mlc .Field ,
82- int_dtype : type_utils .TfdsDType = np .int64 ,
83- float_dtype : type_utils .TfdsDType = np .float32 ,
83+ dtype_mapping : Mapping [type_utils .TfdsDType , type_utils .TfdsDType ],
8484):
8585 """Includes the given feature in a sequence or tensor feature.
8686
@@ -91,32 +91,28 @@ def array_datatype_converter(
9191 Args:
9292 feature: The inner feature to include in a sequence or tensor feature.
9393 field: The mlc.Field object.
94- int_dtype: The dtype to use for TFDS integer features. Defaults to np.int64.
95- float_dtype: The dtype to use for TFDS float features. Defaults to
96- np.float32.
94+ dtype_mapping: A mapping of dtypes to the corresponding dtypes that will be
95+ used in TFDS.
9796
9897 Returns:
9998 A sequence or tensor feature including the inner feature.
10099 """
101- dtype_mapping = {
102- int : int_dtype ,
103- float : float_dtype ,
104- bool : np .bool_ ,
105- bytes : np .str_ ,
106- }
107- dtype = dtype_mapping .get (field .data_type , None )
100+ field_dtype = None
101+ if field .data_type in dtype_mapping :
102+ field_dtype = dtype_mapping [field .data_type ]
103+ elif enp .lazy .is_np_dtype (field .data_type ):
104+ field_dtype = field .data_type
105+
108106 if len (field .array_shape_tuple ) == 1 :
109107 return sequence_feature .Sequence (feature , doc = field .description )
110- elif (- 1 in field .array_shape_tuple ) or (
111- field .data_type not in dtype_mapping
112- ):
108+ elif (- 1 in field .array_shape_tuple ) or (field_dtype is None ):
113109 for _ in range (len (field .array_shape_tuple )):
114110 feature = sequence_feature .Sequence (feature , doc = field .description )
115111 return feature
116112 else :
117113 return tensor_feature .Tensor (
118114 shape = field .array_shape_tuple ,
119- dtype = dtype ,
115+ dtype = field_dtype ,
120116 doc = field .description ,
121117 )
122118
@@ -142,8 +138,15 @@ def datatype_converter(
142138 """
143139 if field .is_enumeration :
144140 raise NotImplementedError ('Not implemented yet.' )
141+ dtype_mapping = {
142+ bool : np .bool_ ,
143+ bytes : np .str_ ,
144+ float : float_dtype ,
145+ int : int_dtype ,
146+ }
145147
146148 field_data_type = field .data_type
149+
147150 if not field_data_type :
148151 # Fields with sub fields are of type None
149152 if field .sub_fields :
@@ -158,14 +161,12 @@ def datatype_converter(
158161 )
159162 else :
160163 feature = None
161- elif field_data_type == int :
162- feature = int_dtype
163- elif field_data_type == float :
164- feature = float_dtype
165- elif field_data_type == bool :
166- feature = np .bool_
167164 elif field_data_type == bytes :
168165 feature = text_feature .Text (doc = field .description )
166+ elif field_data_type in dtype_mapping :
167+ feature = dtype_mapping [field_data_type ]
168+ elif enp .lazy .is_np_dtype (field_data_type ):
169+ feature = field_data_type
169170 # We return a text feature for mlc.DataType.DATE features.
170171 elif field_data_type == pd .Timestamp :
171172 feature = text_feature .Text (doc = field .description )
@@ -183,8 +184,7 @@ def datatype_converter(
183184 feature = array_datatype_converter (
184185 feature = feature ,
185186 field = field ,
186- int_dtype = int_dtype ,
187- float_dtype = float_dtype ,
187+ dtype_mapping = dtype_mapping ,
188188 )
189189 # If the field is repeated, we return a sequence feature. `field.repeated` is
190190 # deprecated starting from Croissant 1.1, but we still support it for
0 commit comments