11"""
22This module implements the visualization for the plot_diff function.
33""" # pylint: disable=too-many-lines
4+ from turtle import color
45from typing import Any , Dict , List , Tuple , Optional
5-
6+ from sklearn . preprocessing import MinMaxScaler
67import math
78import numpy as np
89import pandas as pd
10+ import dask .array as da
11+ import matplotlib .pyplot as plt
912from bokeh .models import (
1013 HoverTool ,
1114 Panel ,
1215 FactorRange ,
1316)
14- from bokeh .plotting import Figure , figure
17+ from bokeh .plotting import Figure , figure , show
1518from bokeh .transform import dodge
1619from bokeh .layouts import row
20+ from bokeh .models .ranges import Range1d
21+ from bokeh .models import LinearAxis
1722
1823from ..configs import Config
1924from ..dtypes import Continuous , DateTime , Nominal , is_dtype
@@ -78,6 +83,8 @@ def bar_viz(
7883 orig : List [str ],
7984 df_labels : List [str ],
8085 baseline : int ,
86+ target : Optional [str ] = None ,
87+ df_list : Optional [List [pd .DataFrame ]] = None
8188) -> Figure :
8289 """
8390 Render a bar chart
@@ -94,6 +101,12 @@ def bar_viz(
94101 ("Source" , "@orig" ),
95102 ]
96103
104+ col1_min = df [0 ][col ].min ()
105+ col2_min = df [1 ][col ].min ()
106+ col1_max = df [0 ][col ].max ()
107+ col2_max = df [1 ][col ].max ()
108+ y_inc = 0.05
109+
97110 if show_yticks :
98111 if len (df [baseline ]) > 10 :
99112 plot_width = 28 * len (df [baseline ])
@@ -106,12 +119,15 @@ def bar_viz(
106119 tools = "hover" ,
107120 x_range = list (df [baseline ].index ),
108121 y_axis_type = yscale ,
122+ y_range = (min (col1_min , col2_min ) * (1 - y_inc ), max (col1_max , col2_max ) * (1 + y_inc ))
109123 )
110-
124+ row_names = None
111125 offset = np .linspace (- 0.08 * len (df ), 0.08 * len (df ), len (df )) if len (df ) > 1 else [0 ]
112126 for i , (nrow , data ) in enumerate (zip (nrows , df )):
113127 data ["pct" ] = data [col ] / nrow * 100
114128 data .index = [str (val ) for val in data .index ]
129+ if row_names is None :
130+ row_names = data .index
115131 data ["orig" ] = orig [i ]
116132
117133 fig .vbar (
@@ -126,7 +142,6 @@ def bar_viz(
126142 tweak_figure (fig , "bar" , show_yticks )
127143
128144 fig .yaxis .axis_label = "Count"
129-
130145 x_axis_label = ""
131146 if ttl_grps > len (df [baseline ]):
132147 x_axis_label += f"Top { len (df [baseline ])} of { ttl_grps } { col } "
@@ -142,6 +157,21 @@ def bar_viz(
142157
143158 if show_yticks and yscale == "linear" :
144159 _format_axis (fig , 0 , df [baseline ].max (), "y" )
160+
161+ df1 , df2 = df_list [0 ], df_list [1 ]
162+ if target != col and target and col in df1 .columns and col in df2 .columns :
163+ col1 , col2 = df_list [0 ][col ], df_list [1 ][col ]
164+ row_avgs_1 = []
165+ row_avgs_2 = []
166+ for names in row_names :
167+ row_avgs_1 .append (df_list [0 ][target ][col1 == names ].mean ())
168+ row_avgs_2 .append (df_list [1 ][target ][col2 == names ].mean ())
169+
170+ row_avgs_1 = [0 if math .isnan (x ) else x for x in row_avgs_1 ]
171+ row_avgs_2 = [0 if math .isnan (x ) else x for x in row_avgs_2 ]
172+ fig .extra_y_ranges = {"Averages" : Range1d (start = min (row_avgs_1 + row_avgs_2 ) * (1 - y_inc ), end = max (row_avgs_1 + row_avgs_2 ) * (1 + y_inc ))}
173+ fig .multi_line ([row_names , row_names ], [row_avgs_1 , row_avgs_2 ], color = ['navy' , 'firebrick' ], y_range_name = "Averages" , line_width = 4 )
174+ fig .add_layout (LinearAxis (y_range_name = "Averages" ), 'right' )
145175 return fig
146176
147177
@@ -155,28 +185,56 @@ def hist_viz(
155185 show_yticks : bool ,
156186 df_labels : List [str ],
157187 orig : Optional [List [str ]] = None ,
188+ target : Optional [str ] = None ,
189+ df_list : Optional [List [pd .DataFrame ]] = None
158190) -> Figure :
159191 """
160192 Render a histogram
161193 """
162194 # pylint: disable=too-many-arguments,too-many-locals
163-
164195 tooltips = [
165196 ("Bin" , "@intvl" ),
166197 ("Frequency" , "@freq" ),
167198 ("Percent" , "@pct{0.2f}%" ),
168199 ("Source" , "@orig" ),
169200 ]
201+ df1 , df2 = df_list [0 ], df_list [1 ]
202+ y_inc = 0.05
203+ tooltips = [
204+ ("Bin" , "@intvl" ),
205+ ("Frequency" , "@freq" ),
206+ ("Percent" , "@pct{0.2f}%" ),
207+ ("Source" , "@orig" ),
208+ ]
209+ fig = None
210+
211+ y_start , y_end = None , None
212+ counts_list = []
213+ if target and target != col and col in df1 .columns and col in df2 .columns :
214+ for hst in hist :
215+ counts , bins = hst
216+ counts_list .append (counts )
217+
218+ counts_min_1 = min (counts_list [0 ])
219+ counts_min_2 = min (counts_list [1 ])
220+
221+ counts_max_1 = max (counts_list [0 ])
222+ counts_max_2 = max (counts_list [1 ])
223+
224+ y_start , y_end = min (counts_min_1 , counts_min_2 ), max (counts_max_1 , counts_max_2 )
225+
226+
170227 fig = Figure (
171228 plot_height = plot_height ,
172229 plot_width = plot_width ,
173230 title = col ,
174231 toolbar_location = None ,
175- y_axis_type = yscale ,
232+ y_axis_type = yscale
176233 )
177-
234+ bins_list = []
178235 for i , hst in enumerate (hist ):
179236 counts , bins = hst
237+ bins_list .append (bins )
180238 if sum (counts ) == 0 :
181239 fig .rect (x = 0 , y = 0 , width = 0 , height = 0 )
182240 continue
@@ -192,16 +250,34 @@ def hist_viz(
192250 }
193251 )
194252 bottom = 0 if yscale == "linear" or df .empty else counts .min () / 2
195- fig .quad (
196- source = df ,
197- left = "left" ,
198- right = "right" ,
199- bottom = bottom ,
200- alpha = 0.5 ,
201- top = "freq" ,
202- fill_color = CATEGORY10 [i ],
203- line_color = CATEGORY10 [i ],
204- )
253+ if y_start is not None and y_end is not None :
254+ # fig.y_range = (y_start * (1 - y_inc), y_end * (1 + y_inc))
255+ fig .extra_y_ranges = {"Counts" : Range1d (start = y_start * (1 - y_inc ), end = y_end * (1 + y_inc ))}
256+ fig .quad (
257+ source = df ,
258+ left = "left" ,
259+ right = "right" ,
260+ bottom = bottom ,
261+ alpha = 0.5 ,
262+ top = "freq" ,
263+ fill_color = CATEGORY10 [i ],
264+ line_color = CATEGORY10 [i ],
265+ y_range_name = "Counts"
266+ )
267+ else :
268+ fig .quad (
269+ source = df ,
270+ left = "left" ,
271+ right = "right" ,
272+ bottom = bottom ,
273+ alpha = 0.5 ,
274+ top = "freq" ,
275+ fill_color = CATEGORY10 [i ],
276+ line_color = CATEGORY10 [i ]
277+ )
278+ # if col == 'LotFrontage':
279+ # breakpoint()
280+
205281 hover = HoverTool (tooltips = tooltips , attachment = "vertical" , mode = "vline" )
206282 fig .add_tools (hover )
207283
@@ -224,6 +300,34 @@ def hist_viz(
224300 fig .xaxis .axis_label = x_axis_label
225301 fig .xaxis .axis_label_standoff = 0
226302
303+ if target and target != col and col in df1 .columns and col in df2 .columns :
304+ col1 , col2 = df1 [col ], df2 [col ]
305+ source1 , source2 = col1 , col2
306+ col1 = col1 [~ np .isnan (col1 )]
307+ col2 = col2 [~ np .isnan (col2 )]
308+ num_bins1 = len (bins_list [0 ]) - 1
309+ num_bins2 = len (bins_list [1 ]) - 1
310+ bins_1 , bins_2 = bins_list [0 ], bins_list [1 ]
311+
312+ df1_source_bins_series = pd .cut (source1 , bins = bins_1 , labels = False )
313+ df1_bin_averages = [None ] * num_bins1
314+
315+ df2_source_bins_series = pd .cut (source2 , bins = bins_2 , labels = False )
316+ df2_bin_averages = [None ] * num_bins2
317+
318+ for b in range (num_bins1 ):
319+ df1_bin_averages [b ] = df1 [target ][df1_source_bins_series == b ].mean ()
320+ for b in range (num_bins2 ):
321+ df2_bin_averages [b ] = df2 [target ][df2_source_bins_series == b ].mean ()
322+
323+ df1_bin_averages = [0 if math .isnan (x ) else x for x in df1_bin_averages ]
324+ df2_bin_averages = [0 if math .isnan (x ) else x for x in df2_bin_averages ]
325+ max_range = max (df1_bin_averages + df2_bin_averages )
326+ min_range = min (df1_bin_averages + df2_bin_averages )
327+
328+ fig .extra_y_ranges ['Averages' ] = Range1d (start = min_range * (1 - y_inc ), end = max_range * (1 + y_inc ))
329+ fig .multi_line ([bins_1 , bins_2 ], [df1_bin_averages , df2_bin_averages ], color = ['navy' , 'firebrick' ], y_range_name = "Averages" , line_width = 4 )
330+ fig .add_layout (LinearAxis (y_range_name = "Averages" , axis_label = 'Bin Averages' ), 'right' )
227331 return fig
228332
229333
@@ -610,6 +714,9 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
610714 nrows = itmdt ["stats" ]["nrows" ]
611715 titles : List [str ] = []
612716
717+ df_list = itmdt .df_list
718+ target = itmdt .target
719+
613720 for col , dtp , data , orig in itmdt ["data" ]:
614721 fig = None
615722 if is_dtype (dtp , Nominal ()):
@@ -626,6 +733,8 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
626733 orig ,
627734 df_labels ,
628735 baseline if len (df ) > 1 else 0 ,
736+ target ,
737+ df_list
629738 )
630739 elif is_dtype (dtp , Continuous ()):
631740 if cfg .diff .density :
@@ -643,6 +752,8 @@ def render_comparison_grid(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
643752 False ,
644753 df_labels ,
645754 orig ,
755+ target ,
756+ df_list
646757 )
647758 elif is_dtype (dtp , DateTime ()):
648759 df , timeunit = data
@@ -760,7 +871,6 @@ def render_diff(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
760871 cfg
761872 Config instance
762873 """
763-
764874 if itmdt .visual_type == "comparison_grid" :
765875 visual_elem = render_comparison_grid (itmdt , cfg )
766876 if itmdt .visual_type == "comparison_continuous" :
0 commit comments