@@ -1656,25 +1656,57 @@ def plot_ATE(self, idata=None, method=None, prop_draws=100, ate_draws=300):
16561656 if method is None :
16571657 method = self .weighting_scheme
16581658
1659- def plot_weights (bins , top0 , top1 , ax ):
1659+ def plot_weights (bins , top0 , top1 , ax , color = "population" ):
1660+ colors_dict = {
1661+ "population" : ["red" , "blue" , 0.9 ],
1662+ "pseudo_population" : ["purple" , "purple" , 0.1 ],
1663+ }
1664+
16601665 ax .axhline (0 , c = "gray" , linewidth = 1 )
16611666 bars0 = ax .bar (
1662- bins [:- 1 ] + 0.025 , top0 , width = 0.04 , facecolor = "red" , alpha = 0.3
1667+ bins [:- 1 ] + 0.025 ,
1668+ top0 ,
1669+ width = 0.04 ,
1670+ facecolor = colors_dict [color ][0 ],
1671+ alpha = colors_dict [color ][2 ],
16631672 )
16641673 bars1 = ax .bar (
1665- bins [:- 1 ] + 0.025 , - top1 , width = 0.04 , facecolor = "blue" , alpha = 0.3
1674+ bins [:- 1 ] + 0.025 ,
1675+ - top1 ,
1676+ width = 0.04 ,
1677+ facecolor = colors_dict [color ][1 ],
1678+ alpha = colors_dict [color ][2 ],
16661679 )
16671680
16681681 for bars in (bars0 , bars1 ):
16691682 for bar in bars :
16701683 bar .set_edgecolor ("black" )
16711684
1672- def make_hists (idata , i , axs ):
1685+ def make_hists (idata , i , axs , method = method ):
16731686 p_i = az .extract (idata )["p" ][:, i ].values
1687+ if method == "raw" :
1688+ weight0 = 1 / (1 - p_i [self .t .flatten () == 0 ])
1689+ weight1 = 1 / (p_i [self .t .flatten () == 1 ])
1690+ elif method == "overlap" :
1691+ t = self .t .flatten ()
1692+ weight1 = (1 - p_i [t == 1 ]) * t [t == 1 ]
1693+ weight0 = p_i [t == 0 ] * (1 - t [t == 0 ])
1694+ else :
1695+ t = self .t .flatten ()
1696+ p_of_t = np .mean (t )
1697+ weight1 = p_of_t / p_i [t == 1 ]
1698+ weight0 = (1 - p_of_t ) / (1 - p_i [t == 0 ])
16741699 bins = np .arange (0.025 , 0.99 , 0.005 )
16751700 top0 , _ = np .histogram (p_i [self .t .flatten () == 0 ], bins = bins )
16761701 top1 , _ = np .histogram (p_i [self .t .flatten () == 1 ], bins = bins )
16771702 plot_weights (bins , top0 , top1 , axs [0 ])
1703+ top0 , _ = np .histogram (
1704+ p_i [self .t .flatten () == 0 ], bins = bins , weights = weight0
1705+ )
1706+ top1 , _ = np .histogram (
1707+ p_i [self .t .flatten () == 1 ], bins = bins , weights = weight1
1708+ )
1709+ plot_weights (bins , top0 , top1 , axs [0 ], color = "pseudo_population" )
16781710
16791711 mosaic = """AAAAAA
16801712 BBBBCC"""
@@ -1690,6 +1722,7 @@ def make_hists(idata, i, axs):
16901722 axs [0 ].set_title (
16911723 "Draws from the Posterior \n Propensity Scores Distribution" , fontsize = 20
16921724 )
1725+ axs [0 ].legend ()
16931726
16941727 [make_hists (idata , i , axs ) for i in range (prop_draws )]
16951728 ate_df = pd .DataFrame (
0 commit comments