Skip to content

Commit 523d190

Browse files
Update interpreter_step.py
1 parent e877583 commit 523d190

File tree

1 file changed

+41
-8
lines changed

1 file changed

+41
-8
lines changed

python/src/lazylearn/ingestion/ingestion_pipeline_steps/interpreter_step.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import numpy as np
22
import pandas as pd
3-
from pandas import Series
3+
from pandas import DataFrame, Series
44
from pipeline.pipeline import IngestionPipeline
55
from tqdm import tqdm
66

77

88
class ColumnTypeInterpreter:
9+
def __int__(self):
10+
self.df: DataFrame = None
11+
912
def apply(self, pipeline: IngestionPipeline):
1013
"""
1114
This method is responsible for inferring the
@@ -38,6 +41,8 @@ def analyze_column(self, column: Series):
3841

3942
if self.categorical_test(values):
4043
column_type = "categorical"
44+
elif self.numeric_test(types) and self.id_check(types, values):
45+
column_type = "id"
4146
elif self.numeric_test(types):
4247
column_type = "numeric"
4348

@@ -77,17 +82,45 @@ def numeric_test(types: list):
7782
:param types: list of type objects
7883
:return: True if column is numeric, False otherwise
7984
"""
80-
return all([item == float or item == int for item in set(types)])
85+
return all(
86+
[item == float or item == int for item in set(types) if item is not None]
87+
)
8188

8289
@staticmethod
8390
def string_test(types: set):
8491
raise NotImplementedError
8592

8693
def datetime_check(self, column: Series):
87-
if self.df[column.name].dtype.type == np.datetime64:
88-
return True
89-
try:
90-
self.df[column.name] = pd.to_datetime(self.df[column.name])
94+
"""
95+
96+
:param column:
97+
:return:
98+
"""
99+
col_name = str(column.name)
100+
101+
# if type of column is actually datetime
102+
if self.df[col_name].dtype.type == np.datetime64:
91103
return True
92-
except Exception as e: # noqa
93-
return False
104+
105+
# if date or time is in column name and can be cast as date
106+
if "date" in col_name.lower() or "time" in col_name.lower():
107+
try:
108+
self.df[col_name] = pd.to_datetime(self.df[col_name])
109+
return True
110+
except Exception as e: # noqa
111+
pass
112+
113+
# if format of values look like dates
114+
115+
return False
116+
117+
def id_check(self, types, values):
118+
"""
119+
120+
:param types:
121+
:param values:
122+
:return:
123+
"""
124+
return all([item == int for item in set(types) if item is not None]) and len(
125+
set(values)
126+
) == len(self.df)

0 commit comments

Comments
 (0)