|
1 | 1 | import numpy as np |
2 | 2 | import pandas as pd |
3 | | -from pandas import Series |
| 3 | +from pandas import DataFrame, Series |
4 | 4 | from pipeline.pipeline import IngestionPipeline |
5 | 5 | from tqdm import tqdm |
6 | 6 |
|
7 | 7 |
|
8 | 8 | class ColumnTypeInterpreter: |
| 9 | + def __int__(self): |
| 10 | + self.df: DataFrame = None |
| 11 | + |
9 | 12 | def apply(self, pipeline: IngestionPipeline): |
10 | 13 | """ |
11 | 14 | This method is responsible for inferring the |
@@ -38,6 +41,8 @@ def analyze_column(self, column: Series): |
38 | 41 |
|
39 | 42 | if self.categorical_test(values): |
40 | 43 | column_type = "categorical" |
| 44 | + elif self.numeric_test(types) and self.id_check(types, values): |
| 45 | + column_type = "id" |
41 | 46 | elif self.numeric_test(types): |
42 | 47 | column_type = "numeric" |
43 | 48 |
|
@@ -77,17 +82,45 @@ def numeric_test(types: list): |
77 | 82 | :param types: list of type objects |
78 | 83 | :return: True if column is numeric, False otherwise |
79 | 84 | """ |
80 | | - return all([item == float or item == int for item in set(types)]) |
| 85 | + return all( |
| 86 | + [item == float or item == int for item in set(types) if item is not None] |
| 87 | + ) |
81 | 88 |
|
82 | 89 | @staticmethod |
83 | 90 | def string_test(types: set): |
84 | 91 | raise NotImplementedError |
85 | 92 |
|
86 | 93 | def datetime_check(self, column: Series): |
87 | | - if self.df[column.name].dtype.type == np.datetime64: |
88 | | - return True |
89 | | - try: |
90 | | - self.df[column.name] = pd.to_datetime(self.df[column.name]) |
| 94 | + """ |
| 95 | +
|
| 96 | + :param column: |
| 97 | + :return: |
| 98 | + """ |
| 99 | + col_name = str(column.name) |
| 100 | + |
| 101 | + # if type of column is actually datetime |
| 102 | + if self.df[col_name].dtype.type == np.datetime64: |
91 | 103 | return True |
92 | | - except Exception as e: # noqa |
93 | | - return False |
| 104 | + |
| 105 | + # if date or time is in column name and can be cast as date |
| 106 | + if "date" in col_name.lower() or "time" in col_name.lower(): |
| 107 | + try: |
| 108 | + self.df[col_name] = pd.to_datetime(self.df[col_name]) |
| 109 | + return True |
| 110 | + except Exception as e: # noqa |
| 111 | + pass |
| 112 | + |
| 113 | + # if format of values look like dates |
| 114 | + |
| 115 | + return False |
| 116 | + |
| 117 | + def id_check(self, types, values): |
| 118 | + """ |
| 119 | +
|
| 120 | + :param types: |
| 121 | + :param values: |
| 122 | + :return: |
| 123 | + """ |
| 124 | + return all([item == int for item in set(types) if item is not None]) and len( |
| 125 | + set(values) |
| 126 | + ) == len(self.df) |
0 commit comments