@@ -32,8 +32,14 @@ def __init__(self, dataset_info, name="historical_data"):
3232 self .dataset_info = dataset_info
3333 self .target_category_columns = dataset_info .target_category_columns
3434 self .target_column_name = dataset_info .target_column
35- self .dt_column_name = dataset_info .datetime_column .name if dataset_info .datetime_column else None
36- self .dt_column_format = dataset_info .datetime_column .format if dataset_info .datetime_column else None
35+ self .dt_column_name = (
36+ dataset_info .datetime_column .name if dataset_info .datetime_column else None
37+ )
38+ self .dt_column_format = (
39+ dataset_info .datetime_column .format
40+ if dataset_info .datetime_column
41+ else None
42+ )
3743 self .preprocessing = dataset_info .preprocessing
3844
3945 def run (self , data ):
@@ -58,6 +64,7 @@ def run(self, data):
5864 if self .dt_column_name :
5965 clean_df = self ._format_datetime_col (clean_df )
6066 clean_df = self ._set_multi_index (clean_df )
67+ clean_df = self ._fill_na (clean_df ) if not self .dt_column_name else clean_df
6168
6269 if self .preprocessing and self .preprocessing .enabled :
6370 if self .name == "historical_data" :
@@ -67,7 +74,9 @@ def run(self, data):
6774 except Exception as e :
6875 logger .debug (f"Missing value imputation failed with { e .args } " )
6976 else :
70- logger .info ("Skipping missing value imputation because it is disabled" )
77+ logger .info (
78+ "Skipping missing value imputation because it is disabled"
79+ )
7180 if self .preprocessing .steps .outlier_treatment :
7281 try :
7382 clean_df = self ._outlier_treatment (clean_df )
@@ -78,7 +87,9 @@ def run(self, data):
7887 elif self .name == "additional_data" :
7988 clean_df = self ._missing_value_imputation_add (clean_df )
8089 else :
81- logger .info ("Skipping all preprocessing steps because preprocessing is disabled" )
90+ logger .info (
91+ "Skipping all preprocessing steps because preprocessing is disabled"
92+ )
8293 return clean_df
8394
8495 def _remove_trailing_whitespace (self , df ):
@@ -96,7 +107,14 @@ def _set_series_id_column(self, df):
96107 merged_values = df [DataColumns .Series ].unique ().tolist ()
97108 if self .target_category_columns :
98109 for value in merged_values :
99- self ._target_category_columns_map [value ] = df [df [DataColumns .Series ] == value ][self .target_category_columns ].drop_duplicates ().iloc [0 ].to_dict ()
110+ self ._target_category_columns_map [value ] = (
111+ df [df [DataColumns .Series ] == value ][
112+ self .target_category_columns
113+ ]
114+ .drop_duplicates ()
115+ .iloc [0 ]
116+ .to_dict ()
117+ )
100118
101119 if self .target_category_columns != [DataColumns .Series ]:
102120 df = df .drop (self .target_category_columns , axis = 1 )
@@ -127,7 +145,9 @@ def _set_multi_index(self, df):
127145 """
128146 if self .dt_column_name :
129147 df = df .set_index ([self .dt_column_name , DataColumns .Series ])
130- return df .sort_values ([self .dt_column_name , DataColumns .Series ], ascending = True )
148+ return df .sort_values (
149+ [self .dt_column_name , DataColumns .Series ], ascending = True
150+ )
131151 return df .set_index ([df .index , DataColumns .Series ])
132152
133153 def _missing_value_imputation_hist (self , df ):
@@ -225,5 +245,10 @@ def _check_historical_dataset(self, df):
225245
226246 }
227247 """
248+
228249 def get_target_category_columns_map (self ):
229- return self ._target_category_columns_map
250+ return self ._target_category_columns_map
251+
252+ def _fill_na (self , df : pd .DataFrame , na_value = 0 ) -> pd .DataFrame :
253+ """Fill nans in dataframe"""
254+ return df .fillna (value = na_value )
0 commit comments