66
77from typing import List
88
9+ import numpy as np
10+ import pandas as pd
11+
912from ads .common .decorator .runtime_dependency import OptionalDependency
13+ from ads .feature_store .common .enums import FeatureType
1014
1115try :
1216 from pyspark .sql .types import *
@@ -25,65 +29,153 @@ def map_spark_type_to_feature_type(spark_type):
2529 :return:
2630 """
2731 spark_type_to_feature_type = {
28- StringType (): "string" ,
29- IntegerType (): "integer" ,
30- FloatType (): "float" ,
31- DoubleType (): "double" ,
32- BooleanType (): "boolean" ,
33- DateType (): "date" ,
34- TimestampType (): "timestamp" ,
35- DecimalType (): "decimal" ,
36- BinaryType (): "binary" ,
37- ArrayType (StringType ()): "array" ,
38- MapType (StringType (), StringType ()): "map" ,
39- StructType (): "struct" ,
40- ByteType (): "byte" ,
41- ShortType (): "short" ,
42- LongType (): "long" ,
32+ StringType (): FeatureType .STRING ,
33+ IntegerType (): FeatureType .INTEGER ,
34+ ShortType (): FeatureType .SHORT ,
35+ LongType (): FeatureType .LONG ,
36+ FloatType (): FeatureType .FLOAT ,
37+ DoubleType (): FeatureType .DOUBLE ,
38+ BooleanType (): FeatureType .BOOLEAN ,
39+ DateType (): FeatureType .DATE ,
40+ TimestampType (): FeatureType .TIMESTAMP ,
41+ BinaryType (): FeatureType .BINARY ,
42+ ByteType (): FeatureType .BYTE ,
43+ ArrayType (StringType ()): FeatureType .STRING_ARRAY ,
44+ ArrayType (IntegerType ()): FeatureType .INTEGER_ARRAY ,
45+ ArrayType (LongType ()): FeatureType .LONG_ARRAY ,
46+ ArrayType (FloatType ()): FeatureType .FLOAT_ARRAY ,
47+ ArrayType (DoubleType ()): FeatureType .DOUBLE_ARRAY ,
48+ ArrayType (BinaryType ()): FeatureType .BINARY_ARRAY ,
49+ ArrayType (DateType ()): FeatureType .DATE_ARRAY ,
50+ ArrayType (TimestampType ()): FeatureType .TIMESTAMP_ARRAY ,
51+ ArrayType (ByteType ()): FeatureType .BYTE_ARRAY ,
52+ ArrayType (BooleanType ()): FeatureType .BOOLEAN_ARRAY ,
53+ ArrayType (ShortType ()): FeatureType .SHORT_ARRAY ,
54+ MapType (StringType (), StringType ()): FeatureType .STRING_STRING_MAP ,
55+ MapType (StringType (), IntegerType ()): FeatureType .STRING_INTEGER_MAP ,
56+ MapType (StringType (), ShortType ()): FeatureType .STRING_SHORT_MAP ,
57+ MapType (StringType (), LongType ()): FeatureType .STRING_LONG_MAP ,
58+ MapType (StringType (), FloatType ()): FeatureType .STRING_FLOAT_MAP ,
59+ MapType (StringType (), DoubleType ()): FeatureType .STRING_DOUBLE_MAP ,
60+ MapType (StringType (), TimestampType ()): FeatureType .STRING_TIMESTAMP_MAP ,
61+ MapType (StringType (), DateType ()): FeatureType .STRING_DATE_MAP ,
62+ MapType (StringType (), BinaryType ()): FeatureType .STRING_BINARY_MAP ,
63+ MapType (StringType (), ByteType ()): FeatureType .STRING_BYTE_MAP ,
64+ MapType (StringType (), BooleanType ()): FeatureType .STRING_BOOLEAN_MAP ,
4365 }
44-
45- return spark_type_to_feature_type .get (spark_type ).upper ()
66+ if spark_type in spark_type_to_feature_type :
67+ return spark_type_to_feature_type .get (spark_type )
68+ else :
69+ return FeatureType .UNKNOWN
70+
71+
72+ def map_pandas_type_to_feature_type (feature_name , values ):
73+ pandas_type = str (values .dtype )
74+ inferred_dtype = FeatureType .UNKNOWN
75+ if pandas_type is "object" :
76+ for row in values :
77+ if isinstance (row , (list , np .ndarray )):
78+ raise TypeError (f"object of type { type (row )} not supported" )
79+ pandas_basic_type = type (row ).__name__
80+ current_dtype = map_pandas_basic_type_to_feature_type (pandas_basic_type )
81+ if inferred_dtype is FeatureType .UNKNOWN :
82+ inferred_dtype = current_dtype
83+ else :
84+ if (
85+ current_dtype != inferred_dtype
86+ and current_dtype is not FeatureType .UNKNOWN
87+ ):
88+ raise TypeError (
89+ f"Input feature '{ feature_name } ' has mixed types, { current_dtype } and { inferred_dtype } . "
90+ f"That is not allowed. "
91+ )
92+ else :
93+ inferred_dtype = map_pandas_basic_type_to_feature_type (pandas_type )
94+ if inferred_dtype is FeatureType .UNKNOWN :
95+ raise TypeError (
96+ f"Input feature '{ feature_name } ' has type { str (pandas_type )} which is not supported"
97+ )
98+ else :
99+ return inferred_dtype
46100
47101
48- def map_pandas_type_to_feature_type (pandas_type ):
102+ def map_pandas_basic_type_to_feature_type (pandas_type ):
49103 """Returns the feature type corresponding to pandas_type
50104 :param pandas_type:
51105 :return:
52106 """
107+ # TODO uint64 with bigger number cant be mapped to LongType
53108 pandas_type_to_feature_type = {
54- "object" : "string" ,
55- "int64" : "integer" ,
56- "float64" : "float" ,
57- "bool" : "boolean" ,
109+ "str" : FeatureType .STRING ,
110+ "string" : FeatureType .STRING ,
111+ "int" : FeatureType .INTEGER ,
112+ "int8" : FeatureType .INTEGER ,
113+ "int16" : FeatureType .INTEGER ,
114+ "int32" : FeatureType .LONG ,
115+ "int64" : FeatureType .LONG ,
116+ "uint8" : FeatureType .INTEGER ,
117+ "uint16" : FeatureType .INTEGER ,
118+ "uint32" : FeatureType .LONG ,
119+ "uint64" : FeatureType .LONG ,
120+ "float" : FeatureType .FLOAT ,
121+ "float16" : FeatureType .FLOAT ,
122+ "float32" : FeatureType .DOUBLE ,
123+ "float64" : FeatureType .DOUBLE ,
124+ "datetime64[ns]" : FeatureType .TIMESTAMP ,
125+ "datetime64[ns, UTC]" : FeatureType .TIMESTAMP ,
126+ "timedelta64[ns]" : FeatureType .LONG ,
127+ "bool" : FeatureType .BOOLEAN ,
128+ "Decimal" : FeatureType .DECIMAL ,
129+ "date" : FeatureType .DATE ,
58130 }
59-
60- return pandas_type_to_feature_type .get (pandas_type ).upper ()
131+ if pandas_type in pandas_type_to_feature_type :
132+ return pandas_type_to_feature_type .get (pandas_type )
133+ return FeatureType .UNKNOWN
61134
62135
63136def map_feature_type_to_spark_type (feature_type ):
64137 """Returns the Spark Type for a particular feature type.
65138 :param feature_type:
66139 :return: Spark Type
67140 """
141+ feature_type_in = FeatureType (feature_type )
68142 spark_types = {
69- "string" : StringType (),
70- "integer" : IntegerType (),
71- "float" : FloatType (),
72- "double" : DoubleType (),
73- "boolean" : BooleanType (),
74- "date" : DateType (),
75- "timestamp" : TimestampType (),
76- "decimal" : DecimalType (),
77- "binary" : BinaryType (),
78- "array" : ArrayType (StringType ()),
79- "map" : MapType (StringType (), StringType ()),
80- "struct" : StructType (),
81- "byte" : ByteType (),
82- "short" : ShortType (),
83- "long" : LongType (),
143+ FeatureType .STRING : StringType (),
144+ FeatureType .SHORT : ShortType (),
145+ FeatureType .INTEGER : IntegerType (),
146+ FeatureType .LONG : LongType (),
147+ FeatureType .FLOAT : FloatType (),
148+ FeatureType .DOUBLE : DoubleType (),
149+ FeatureType .BOOLEAN : BooleanType (),
150+ FeatureType .DATE : DateType (),
151+ FeatureType .TIMESTAMP : TimestampType (),
152+ FeatureType .DECIMAL : DecimalType (),
153+ FeatureType .BINARY : BinaryType (),
154+ FeatureType .STRING_ARRAY : ArrayType (StringType ()),
155+ FeatureType .INTEGER_ARRAY : ArrayType (IntegerType ()),
156+ FeatureType .SHORT_ARRAY : ArrayType (ShortType ()),
157+ FeatureType .LONG_ARRAY : ArrayType (LongType ()),
158+ FeatureType .FLOAT_ARRAY : ArrayType (FloatType ()),
159+ FeatureType .DOUBLE_ARRAY : ArrayType (DoubleType ()),
160+ FeatureType .BINARY_ARRAY : ArrayType (BinaryType ()),
161+ FeatureType .DATE_ARRAY : ArrayType (DateType ()),
162+ FeatureType .BOOLEAN_ARRAY : ArrayType (BooleanType ()),
163+ FeatureType .TIMESTAMP_ARRAY : ArrayType (TimestampType ()),
164+ FeatureType .STRING_STRING_MAP : MapType (StringType (), StringType ()),
165+ FeatureType .STRING_INTEGER_MAP : MapType (StringType (), IntegerType ()),
166+ FeatureType .STRING_SHORT_MAP : MapType (StringType (), ShortType ()),
167+ FeatureType .STRING_LONG_MAP : MapType (StringType (), LongType ()),
168+ FeatureType .STRING_FLOAT_MAP : MapType (StringType (), FloatType ()),
169+ FeatureType .STRING_DOUBLE_MAP : MapType (StringType (), DoubleType ()),
170+ FeatureType .STRING_DATE_MAP : MapType (StringType (), DateType ()),
171+ FeatureType .STRING_TIMESTAMP_MAP : MapType (StringType (), TimestampType ()),
172+ FeatureType .STRING_BOOLEAN_MAP : MapType (StringType (), BooleanType ()),
173+ FeatureType .BYTE : ByteType (),
84174 }
85-
86- return spark_types .get (feature_type .lower (), None )
175+ if feature_type_in in spark_types :
176+ return spark_types .get (feature_type_in )
177+ else :
178+ return "UNKNOWN"
87179
88180
89181def get_raw_data_source_schema (raw_feature_details : List [dict ]):
@@ -94,6 +186,7 @@ def get_raw_data_source_schema(raw_feature_details: List[dict]):
94186
95187 Returns:
96188 StructType: Spark schema.
189+ :param raw_feature_details:
97190 """
98191 # Initialize the schema
99192 features_schema = StructType ()
@@ -113,3 +206,40 @@ def get_raw_data_source_schema(raw_feature_details: List[dict]):
113206 features_schema .add (feature_name , feature_type , is_nullable )
114207
115208 return features_schema
209+
210+
211+ def map_feature_type_to_pandas (feature_type ):
212+ feature_type_in = FeatureType (feature_type )
213+ supported_feature_type = {
214+ FeatureType .STRING : str ,
215+ FeatureType .LONG : "int64" ,
216+ FeatureType .DOUBLE : "float64" ,
217+ FeatureType .TIMESTAMP : "datetime64[ns]" ,
218+ FeatureType .BOOLEAN : "bool" ,
219+ FeatureType .FLOAT : "float32" ,
220+ FeatureType .INTEGER : "int32" ,
221+ FeatureType .DECIMAL : "object" ,
222+ FeatureType .DATE : "object" ,
223+ }
224+ if feature_type_in in supported_feature_type :
225+ return supported_feature_type .get (feature_type_in )
226+ else :
227+ raise TypeError (f"Feature Type { feature_type } is not supported for pandas" )
228+
229+
230+ def convert_pandas_datatype_with_schema (
231+ raw_feature_details : List [dict ], input_df : pd .DataFrame
232+ ):
233+ feature_detail_map = {}
234+ for feature_details in raw_feature_details :
235+ feature_detail_map [feature_details .get ("name" )] = feature_details
236+ for column in input_df .columns :
237+ if column in feature_detail_map .keys ():
238+ feature_details = feature_detail_map [column ]
239+ feature_type = feature_details .get ("featureType" )
240+ pandas_type = map_feature_type_to_pandas (feature_type )
241+ input_df [column ] = (
242+ input_df [column ]
243+ .astype (pandas_type )
244+ .where (pd .notnull (input_df [column ]), None )
245+ )
0 commit comments