@@ -84,7 +84,7 @@ def bar_viz(
8484 df_labels : List [str ],
8585 baseline : int ,
8686 target : Optional [str ] = None ,
87- df_list : Optional [List [pd .DataFrame ]] = None
87+ df_list : Optional [List [pd .DataFrame ]] = None ,
8888) -> Figure :
8989 """
9090 Render a bar chart
@@ -119,7 +119,7 @@ def bar_viz(
119119 tools = "hover" ,
120120 x_range = list (df [baseline ].index ),
121121 y_axis_type = yscale ,
122- y_range = (min (col1_min , col2_min ) * (1 - y_inc ), max (col1_max , col2_max ) * (1 + y_inc ))
122+ y_range = (min (col1_min , col2_min ) * (1 - y_inc ), max (col1_max , col2_max ) * (1 + y_inc )),
123123 )
124124 row_names = None
125125 offset = np .linspace (- 0.08 * len (df ), 0.08 * len (df ), len (df )) if len (df ) > 1 else [0 ]
@@ -157,7 +157,7 @@ def bar_viz(
157157
158158 if show_yticks and yscale == "linear" :
159159 _format_axis (fig , 0 , df [baseline ].max (), "y" )
160-
160+
161161 df1 , df2 = df_list [0 ], df_list [1 ]
162162 if target != col and target and col in df1 .columns and col in df2 .columns :
163163 col1 , col2 = df_list [0 ][col ], df_list [1 ][col ]
@@ -166,12 +166,23 @@ def bar_viz(
166166 for names in row_names :
167167 row_avgs_1 .append (df_list [0 ][target ][col1 == names ].mean ())
168168 row_avgs_2 .append (df_list [1 ][target ][col2 == names ].mean ())
169-
169+
170170 row_avgs_1 = [0 if math .isnan (x ) else x for x in row_avgs_1 ]
171171 row_avgs_2 = [0 if math .isnan (x ) else x for x in row_avgs_2 ]
172- fig .extra_y_ranges = {"Averages" : Range1d (start = min (row_avgs_1 + row_avgs_2 ) * (1 - y_inc ), end = max (row_avgs_1 + row_avgs_2 ) * (1 + y_inc ))}
173- fig .multi_line ([row_names , row_names ], [row_avgs_1 , row_avgs_2 ], color = ['navy' , 'firebrick' ], y_range_name = "Averages" , line_width = 4 )
174- fig .add_layout (LinearAxis (y_range_name = "Averages" ), 'right' )
172+ fig .extra_y_ranges = {
173+ "Averages" : Range1d (
174+ start = min (row_avgs_1 + row_avgs_2 ) * (1 - y_inc ),
175+ end = max (row_avgs_1 + row_avgs_2 ) * (1 + y_inc ),
176+ )
177+ }
178+ fig .multi_line (
179+ [row_names , row_names ],
180+ [row_avgs_1 , row_avgs_2 ],
181+ color = ["navy" , "firebrick" ],
182+ y_range_name = "Averages" ,
183+ line_width = 4 ,
184+ )
185+ fig .add_layout (LinearAxis (y_range_name = "Averages" ), "right" )
175186 return fig
176187
177188
@@ -186,7 +197,7 @@ def hist_viz(
186197 df_labels : List [str ],
187198 orig : Optional [List [str ]] = None ,
188199 target : Optional [str ] = None ,
189- df_list : Optional [List [pd .DataFrame ]] = None
200+ df_list : Optional [List [pd .DataFrame ]] = None ,
190201) -> Figure :
191202 """
192203 Render a histogram
@@ -222,14 +233,13 @@ def hist_viz(
222233 counts_max_2 = max (counts_list [1 ])
223234
224235 y_start , y_end = min (counts_min_1 , counts_min_2 ), max (counts_max_1 , counts_max_2 )
225-
226236
227237 fig = Figure (
228238 plot_height = plot_height ,
229239 plot_width = plot_width ,
230240 title = col ,
231241 toolbar_location = None ,
232- y_axis_type = yscale
242+ y_axis_type = yscale ,
233243 )
234244 bins_list = []
235245 for i , hst in enumerate (hist ):
@@ -252,7 +262,9 @@ def hist_viz(
252262 bottom = 0 if yscale == "linear" or df .empty else counts .min () / 2
253263 if y_start is not None and y_end is not None :
254264 # fig.y_range = (y_start * (1 - y_inc), y_end * (1 + y_inc))
255- fig .extra_y_ranges = {"Counts" : Range1d (start = y_start * (1 - y_inc ), end = y_end * (1 + y_inc ))}
265+ fig .extra_y_ranges = {
266+ "Counts" : Range1d (start = y_start * (1 - y_inc ), end = y_end * (1 + y_inc ))
267+ }
256268 fig .quad (
257269 source = df ,
258270 left = "left" ,
@@ -262,7 +274,7 @@ def hist_viz(
262274 top = "freq" ,
263275 fill_color = CATEGORY10 [i ],
264276 line_color = CATEGORY10 [i ],
265- y_range_name = "Counts"
277+ y_range_name = "Counts" ,
266278 )
267279 else :
268280 fig .quad (
@@ -273,11 +285,11 @@ def hist_viz(
273285 alpha = 0.5 ,
274286 top = "freq" ,
275287 fill_color = CATEGORY10 [i ],
276- line_color = CATEGORY10 [i ]
288+ line_color = CATEGORY10 [i ],
277289 )
278290 # if col == 'LotFrontage':
279- # breakpoint()
280-
291+ # breakpoint()
292+
281293 hover = HoverTool (tooltips = tooltips , attachment = "vertical" , mode = "vline" )
282294 fig .add_tools (hover )
283295
@@ -325,9 +337,17 @@ def hist_viz(
325337 max_range = max (df1_bin_averages + df2_bin_averages )
326338 min_range = min (df1_bin_averages + df2_bin_averages )
327339
328- fig .extra_y_ranges ['Averages' ] = Range1d (start = min_range * (1 - y_inc ), end = max_range * (1 + y_inc ))
329- fig .multi_line ([bins_1 , bins_2 ], [df1_bin_averages , df2_bin_averages ], color = ['navy' , 'firebrick' ], y_range_name = "Averages" , line_width = 4 )
330- fig .add_layout (LinearAxis (y_range_name = "Averages" , axis_label = 'Bin Averages' ), 'right' )
340+ fig .extra_y_ranges ["Averages" ] = Range1d (
341+ start = min_range * (1 - y_inc ), end = max_range * (1 + y_inc )
342+ )
343+ fig .multi_line (
344+ [bins_1 , bins_2 ],
345+ [df1_bin_averages , df2_bin_averages ],
346+ color = ["navy" , "firebrick" ],
347+ y_range_name = "Averages" ,
348+ line_width = 4 ,
349+ )
350+ fig .add_layout (LinearAxis (y_range_name = "Averages" , axis_label = "Bin Averages" ), "right" )
331351 return fig
332352
333353
@@ -678,7 +698,7 @@ def format_num_stats(data: Dict[str, List[Any]]) -> Dict[str, Dict[str, List[Any
678698 descriptive = {
679699 "Mean" : data ["mean" ],
680700 "Standard Deviation" : data ["std" ],
681- "Variance" : [std ** 2 for std in data ["std" ]],
701+ "Variance" : [std ** 2 for std in data ["std" ]],
682702 "Sum" : [mean * npres for mean , npres in zip (data ["mean" ], data ["npres" ])],
683703 "Skewness" : [float (skew ) for skew in data ["skew" ]],
684704 "Kurtosis" : [float (kurt ) for kurt in data ["kurt" ]],
@@ -734,7 +754,7 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
734754 df_labels ,
735755 baseline if len (df ) > 1 else 0 ,
736756 target ,
737- df_list
757+ df_list ,
738758 )
739759 elif is_dtype (dtp , Continuous ()):
740760 if cfg .diff .density :
@@ -753,7 +773,7 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
753773 df_labels ,
754774 orig ,
755775 target ,
756- df_list
776+ df_list ,
757777 )
758778 elif is_dtype (dtp , DateTime ()):
759779 df , timeunit = data
0 commit comments