2222import xarray as xr
2323from patsy import build_design_matrices , dmatrices
2424from sklearn .linear_model import LinearRegression as sk_lin_reg
25+ from matplotlib .lines import Line2D
26+
2527
2628from causalpy .data_validation import (
2729 PrePostFitDataValidator ,
@@ -1658,7 +1660,7 @@ def plot_ATE(self, idata=None, method=None, prop_draws=100, ate_draws=300):
16581660
16591661 def plot_weights (bins , top0 , top1 , ax , color = "population" ):
16601662 colors_dict = {
1661- "population" : ["red " , "blue " , 0.9 ],
1663+ "population" : ["lightcoral " , "skyblue " , 0.6 ],
16621664 "pseudo_population" : ["purple" , "purple" , 0.1 ],
16631665 }
16641666
@@ -1722,7 +1724,17 @@ def make_hists(idata, i, axs, method=method):
17221724 axs [0 ].set_title (
17231725 "Draws from the Posterior \n Propensity Scores Distribution" , fontsize = 20
17241726 )
1725- axs [0 ].legend ()
1727+ custom_lines = [
1728+ Line2D ([0 ], [0 ], color = "skyblue" , lw = 2 ),
1729+ Line2D ([0 ], [0 ], color = "lightcoral" , lw = 2 ),
1730+ Line2D ([0 ], [0 ], color = "purple" , lw = 2 ),
1731+ Line2D ([0 ], [0 ], color = "black" , lw = 2 , linestyle = "--" ),
1732+ ]
1733+
1734+ axs [0 ].legend (
1735+ custom_lines ,
1736+ ["Control PS" , "Treatment PS" , "Weighted Pseudo Population" , "Extreme PS" ],
1737+ )
17261738
17271739 [make_hists (idata , i , axs ) for i in range (prop_draws )]
17281740 ate_df = pd .DataFrame (
@@ -1734,11 +1746,16 @@ def make_hists(idata, i, axs, method=method):
17341746 label = "E(Y(1))" ,
17351747 ec = "black" ,
17361748 bins = 10 ,
1737- alpha = 0.8 ,
1738- color = "blue " ,
1749+ alpha = 0.6 ,
1750+ color = "skyblue " ,
17391751 )
17401752 axs [1 ].hist (
1741- ate_df ["Y(0)" ], label = "E(Y(0))" , ec = "black" , bins = 10 , alpha = 0.8 , color = "red"
1753+ ate_df ["Y(0)" ],
1754+ label = "E(Y(0))" ,
1755+ ec = "black" ,
1756+ bins = 10 ,
1757+ alpha = 0.6 ,
1758+ color = "lightcoral" ,
17421759 )
17431760 axs [1 ].legend ()
17441761 axs [1 ].set_title (
@@ -1811,17 +1828,24 @@ def plot_balance_ecdf(self, covariate, idata=None, weighting_scheme=None):
18111828 self .weighted_percentile (X [t == 0 ][covariate ].values , w0 , p )
18121829 for p in np .linspace (0 , 1 , 1000 )
18131830 ]
1814- axs [0 ].plot (np .linspace (0 , 1 , 1000 ), raw_trt , color = "blue" , label = "Raw Treated" )
1815- axs [0 ].plot (np .linspace (0 , 1 , 1000 ), raw_ntrt , color = "red" , label = "Raw Control" )
1831+ axs [0 ].plot (
1832+ np .linspace (0 , 1 , 1000 ), raw_trt , color = "skyblue" , label = "Raw Treated"
1833+ )
1834+ axs [0 ].plot (
1835+ np .linspace (0 , 1 , 1000 ), raw_ntrt , color = "lightcoral" , label = "Raw Control"
1836+ )
18161837 axs [0 ].set_title (f"ECDF \n Raw: { covariate } " )
18171838 axs [1 ].set_title (
18181839 f"ECDF \n Weighted { weighting_scheme } adjustment for { covariate } "
18191840 )
18201841 axs [1 ].plot (
1821- np .linspace (0 , 1 , 1000 ), w_trt , color = "blue " , label = "Reweighted Treated"
1842+ np .linspace (0 , 1 , 1000 ), w_trt , color = "skyblue " , label = "Reweighted Treated"
18221843 )
18231844 axs [1 ].plot (
1824- np .linspace (0 , 1 , 1000 ), w_ntrt , color = "red" , label = "Reweighted Control"
1845+ np .linspace (0 , 1 , 1000 ),
1846+ w_ntrt ,
1847+ color = "lightcoral" ,
1848+ label = "Reweighted Control" ,
18251849 )
18261850 axs [1 ].set_xlabel ("Quantiles" )
18271851 axs [0 ].set_xlabel ("Quantiles" )
0 commit comments