3030 RegressionKinkDataValidator ,
3131 PrePostNEGDDataValidator ,
3232 IVDataValidator ,
33- PropensityDataValidator
33+ PropensityDataValidator ,
3434)
3535from causalpy .plot_utils import plot_xY
3636from causalpy .utils import round_num
@@ -1491,7 +1491,7 @@ class InversePropensityWeighting(ExperimentalDesign, PropensityDataValidator):
14911491 :param outcome_variable
14921492 A string denoting the outcome variable in datq to be reweighted
14931493 :param weighting_scheme:
1494- A string denoting which weighting scheme to use among: 'raw', 'robust',
1494+ A string denoting which weighting scheme to use among: 'raw', 'robust',
14951495 'doubly robust'
14961496 :param model:
14971497 A PyMC model
@@ -1543,17 +1543,15 @@ def __init__(
15431543
15441544 COORDS = {"obs_ind" : list (range (self .X .shape [0 ])), "coeffs" : self .labels }
15451545 self .coords = COORDS
1546- self .model .fit (
1547- X = self .X , t = self .t , coords = COORDS
1548- )
1546+ self .model .fit (X = self .X , t = self .t , coords = COORDS )
15491547
15501548 def make_robust_adjustments (self , ps ):
15511549 X = pd .DataFrame (self .X , columns = self .labels )
1552- X ['ps' ] = ps
1550+ X ["ps" ] = ps
15531551 X [self .outcome_variable ] = self .y
15541552 t = self .t .flatten ()
15551553 p_of_t = np .mean (t )
1556- X ["i_ps" ] = np .where (t == 1 , (p_of_t / X ["ps" ]), (1 - p_of_t ) / (1 - X ["ps" ]))
1554+ X ["i_ps" ] = np .where (t == 1 , (p_of_t / X ["ps" ]), (1 - p_of_t ) / (1 - X ["ps" ]))
15571555 n_ntrt = X [t == 0 ].shape [0 ]
15581556 n_trt = X [t == 1 ].shape [0 ]
15591557 outcome_trt = X [t == 1 ][self .outcome_variable ]
@@ -1564,10 +1562,9 @@ def make_robust_adjustments(self, ps):
15641562 weighted_outcome0 = outcome_ntrt * i_propensity0
15651563 return weighted_outcome0 , weighted_outcome1 , n_ntrt , n_trt
15661564
1567-
15681565 def make_raw_adjustments (self , ps ):
15691566 X = pd .DataFrame (self .X , columns = self .labels )
1570- X ['ps' ] = ps
1567+ X ["ps" ] = ps
15711568 X [self .outcome_variable ] = self .y
15721569 t = self .t .flatten ()
15731570 X ["ps" ] = np .where (t , X ["ps" ], 1 - X ["ps" ])
@@ -1580,28 +1577,26 @@ def make_raw_adjustments(self, ps):
15801577 weighted_outcome1 = outcome_trt * i_propensity1
15811578 weighted_outcome0 = outcome_ntrt * i_propensity0
15821579 return weighted_outcome0 , weighted_outcome1 , n_ntrt , n_trt
1583-
15841580
15851581 def make_overlap_adjustments (self , ps ):
15861582 X = pd .DataFrame (self .X , columns = self .labels )
1587- X ['ps' ] = ps
1583+ X ["ps" ] = ps
15881584 X [self .outcome_variable ] = self .y
15891585 t = self .t .flatten ()
1590- X ["i_ps" ] = np .where (t , (1 - X ["ps" ])* t , X ["ps" ]* ( 1 - t ))
1591- n_ntrt = (1 - t [t == 0 ])* X [t == 0 ][' i_ps' ]
1592- n_trt = t [t == 1 ]* X [t == 1 ][' i_ps' ]
1586+ X ["i_ps" ] = np .where (t , (1 - X ["ps" ]) * t , X ["ps" ] * ( 1 - t ))
1587+ n_ntrt = (1 - t [t == 0 ]) * X [t == 0 ][" i_ps" ]
1588+ n_trt = t [t == 1 ] * X [t == 1 ][" i_ps" ]
15931589 outcome_trt = X [t == 1 ][self .outcome_variable ]
15941590 outcome_ntrt = X [t == 0 ][self .outcome_variable ]
15951591 i_propensity0 = X [t == 0 ]["i_ps" ]
15961592 i_propensity1 = X [t == 1 ]["i_ps" ]
1597- weighted_outcome1 = t [t == 1 ]* outcome_trt * i_propensity1
1598- weighted_outcome0 = (1 - t [t == 0 ])* outcome_ntrt * i_propensity0
1593+ weighted_outcome1 = t [t == 1 ] * outcome_trt * i_propensity1
1594+ weighted_outcome0 = (1 - t [t == 0 ]) * outcome_ntrt * i_propensity0
15991595 return weighted_outcome0 , weighted_outcome1 , n_ntrt , n_trt
16001596
1601-
16021597 def make_doubly_robust_adjustment (self , ps ):
16031598 X = pd .DataFrame (self .X , columns = self .labels )
1604- X ['ps' ] = ps
1599+ X ["ps" ] = ps
16051600 t = self .t .flatten ()
16061601 m0 = sk_lin_reg ().fit (X [t == 0 ].astype (float ), self .y [t == 0 ])
16071602 m1 = sk_lin_reg ().fit (X [t == 1 ].astype (float ), self .y [t == 1 ])
@@ -1611,50 +1606,72 @@ def make_doubly_robust_adjustment(self, ps):
16111606 weighted_outcome0 = (1 - t ) * (self .y - m0_pred ) / (1 - X ["ps" ]) + m0_pred
16121607 weighted_outcome1 = t * (self .y - m1_pred ) / X ["ps" ] + m1_pred
16131608 return weighted_outcome0 , weighted_outcome1 , None , None
1614-
1609+
16151610 def get_ate (self , i , idata , method = "doubly_robust" ):
16161611 ### Post processing the sample posterior distribution for propensity scores
16171612 ### One sample at a time.
16181613 ps = idata ["posterior" ]["p" ].stack (z = ("chain" , "draw" ))[:, i ].values
16191614 if method == "robust" :
1620- weighted_outcome_ntrt , weighted_outcome_trt , n_ntrt , n_trt = self .make_robust_adjustments (ps )
1615+ (
1616+ weighted_outcome_ntrt ,
1617+ weighted_outcome_trt ,
1618+ n_ntrt ,
1619+ n_trt ,
1620+ ) = self .make_robust_adjustments (ps )
16211621 ntrt = weighted_outcome_ntrt .sum () / n_ntrt
16221622 trt = weighted_outcome_trt .sum () / n_trt
16231623 elif method == "raw" :
1624- weighted_outcome_ntrt , weighted_outcome_trt , n_ntrt , n_trt = self .make_raw_adjustments (ps )
1624+ (
1625+ weighted_outcome_ntrt ,
1626+ weighted_outcome_trt ,
1627+ n_ntrt ,
1628+ n_trt ,
1629+ ) = self .make_raw_adjustments (ps )
16251630 ntrt = weighted_outcome_ntrt .sum () / n_ntrt
16261631 trt = weighted_outcome_trt .sum () / n_trt
16271632 elif method == "overlap" :
1628- weighted_outcome_ntrt , weighted_outcome_trt , n_ntrt , n_trt = self .make_overlap_adjustments (ps )
1629- ntrt = np .sum (weighted_outcome_ntrt ) / np .sum (n_ntrt )
1633+ (
1634+ weighted_outcome_ntrt ,
1635+ weighted_outcome_trt ,
1636+ n_ntrt ,
1637+ n_trt ,
1638+ ) = self .make_overlap_adjustments (ps )
1639+ ntrt = np .sum (weighted_outcome_ntrt ) / np .sum (n_ntrt )
16301640 trt = np .sum (weighted_outcome_trt ) / np .sum (n_trt )
16311641 else :
1632- weighted_outcome_ntrt , weighted_outcome_trt , n_ntrt , n_trt = self .make_doubly_robust_adjustment (
1633- ps
1634- )
1642+ (
1643+ weighted_outcome_ntrt ,
1644+ weighted_outcome_trt ,
1645+ n_ntrt ,
1646+ n_trt ,
1647+ ) = self .make_doubly_robust_adjustment (ps )
16351648 trt = np .mean (weighted_outcome_trt )
16361649 ntrt = np .mean (weighted_outcome_ntrt )
16371650 ate = trt - ntrt
16381651 return [ate , trt , ntrt ]
1639-
1652+
16401653 def plot_ATE (self , idata = None , method = None , prop_draws = 100 , ate_draws = 300 ):
16411654 if idata is None :
16421655 idata = self .idata
1643- if method is None :
1656+ if method is None :
16441657 method = self .weighting_scheme
1645-
1658+
16461659 def plot_weights (bins , top0 , top1 , ax ):
16471660 ax .axhline (0 , c = "gray" , linewidth = 1 )
1648- bars0 = ax .bar (bins [:- 1 ] + 0.025 , top0 , width = 0.04 , facecolor = "red" , alpha = 0.3 )
1649- bars1 = ax .bar (bins [:- 1 ] + 0.025 , - top1 , width = 0.04 , facecolor = "blue" , alpha = 0.3 )
1661+ bars0 = ax .bar (
1662+ bins [:- 1 ] + 0.025 , top0 , width = 0.04 , facecolor = "red" , alpha = 0.3
1663+ )
1664+ bars1 = ax .bar (
1665+ bins [:- 1 ] + 0.025 , - top1 , width = 0.04 , facecolor = "blue" , alpha = 0.3
1666+ )
16501667
16511668 for bars in (bars0 , bars1 ):
16521669 for bar in bars :
16531670 bar .set_edgecolor ("black" )
16541671
16551672 def make_hists (idata , i , axs ):
1656- p_i = az .extract (idata )['p' ][:, i ].values
1657- bins = np .arange (0.025 , .99 , 0.005 )
1673+ p_i = az .extract (idata )["p" ][:, i ].values
1674+ bins = np .arange (0.025 , 0 .99 , 0.005 )
16581675 top0 , _ = np .histogram (p_i [self .t .flatten () == 0 ], bins = bins )
16591676 top1 , _ = np .histogram (p_i [self .t .flatten () == 1 ], bins = bins )
16601677 plot_weights (bins , top0 , top1 , axs [0 ])
@@ -1664,68 +1681,114 @@ def make_hists(idata, i, axs):
16641681
16651682 fig , axs = plt .subplot_mosaic (mosaic , figsize = (20 , 13 ))
16661683 axs = [axs [k ] for k in axs .keys ()]
1667- axs [0 ].axvline (0.1 , linestyle = '--' , label = 'Low Extreme Propensity Scores' , color = 'black' )
1668- axs [0 ].axvline (0.9 , linestyle = '--' , label = 'Hi Extreme Propensity Scores' , color = 'black' )
1669- axs [0 ].set_title ("Draws from the Posterior \n Propensity Scores Distribution" , fontsize = 20 )
1670-
1671- [make_hists (idata , i , axs ) for i in range (prop_draws )];
1672- ate_df = pd .DataFrame ([self .get_ate (i , idata , method = method ) for i in range (ate_draws )], columns = ['ATE' , 'Y(1)' , 'Y(0)' ])
1673- axs [1 ].hist (ate_df ['Y(1)' ], label = 'E(Y(1))' , ec = 'black' , bins = 10 , alpha = 0.8 , color = 'blue' );
1674- axs [1 ].hist (ate_df ['Y(0)' ], label = 'E(Y(0))' , ec = 'black' , bins = 10 , alpha = 0.8 , color = 'red' );
1684+ axs [0 ].axvline (
1685+ 0.1 , linestyle = "--" , label = "Low Extreme Propensity Scores" , color = "black"
1686+ )
1687+ axs [0 ].axvline (
1688+ 0.9 , linestyle = "--" , label = "Hi Extreme Propensity Scores" , color = "black"
1689+ )
1690+ axs [0 ].set_title (
1691+ "Draws from the Posterior \n Propensity Scores Distribution" , fontsize = 20
1692+ )
1693+
1694+ [make_hists (idata , i , axs ) for i in range (prop_draws )]
1695+ ate_df = pd .DataFrame (
1696+ [self .get_ate (i , idata , method = method ) for i in range (ate_draws )],
1697+ columns = ["ATE" , "Y(1)" , "Y(0)" ],
1698+ )
1699+ axs [1 ].hist (
1700+ ate_df ["Y(1)" ],
1701+ label = "E(Y(1))" ,
1702+ ec = "black" ,
1703+ bins = 10 ,
1704+ alpha = 0.8 ,
1705+ color = "blue" ,
1706+ )
1707+ axs [1 ].hist (
1708+ ate_df ["Y(0)" ], label = "E(Y(0))" , ec = "black" , bins = 10 , alpha = 0.8 , color = "red"
1709+ )
16751710 axs [1 ].legend ()
1676- axs [1 ].set_title (f'The Outcomes \n Under the { method } re-weighting scheme' , fontsize = 20 )
1677- axs [2 ].hist (ate_df ['ATE' ], label = 'ATE' , ec = 'black' , bins = 10 , color = 'slateblue' , alpha = 0.6 );
1678- axs [2 ].axvline (ate_df ['ATE' ].mean (), label = 'E(ATE)' )
1711+ axs [1 ].set_title (
1712+ f"The Outcomes \n Under the { method } re-weighting scheme" , fontsize = 20
1713+ )
1714+ axs [2 ].hist (
1715+ ate_df ["ATE" ],
1716+ label = "ATE" ,
1717+ ec = "black" ,
1718+ bins = 10 ,
1719+ color = "slateblue" ,
1720+ alpha = 0.6 ,
1721+ )
1722+ axs [2 ].axvline (ate_df ["ATE" ].mean (), label = "E(ATE)" )
16791723 axs [2 ].legend ()
1680- axs [2 ].set_title ("Average Treatment Effect" , fontsize = 20 );
1681-
1724+ axs [2 ].set_title ("Average Treatment Effect" , fontsize = 20 )
16821725
16831726 def weighted_percentile (self , data , weights , perc ):
16841727 """
16851728 perc : percentile in [0-1]!
16861729 """
16871730 ix = np .argsort (data )
1688- data = data [ix ] # sort data
1689- weights = weights [ix ] # sort weights
1690- cdf = (np .cumsum (weights ) - 0.5 * weights ) / np .sum (weights ) # 'like' a CDF function
1731+ data = data [ix ] # sort data
1732+ weights = weights [ix ] # sort weights
1733+ cdf = (np .cumsum (weights ) - 0.5 * weights ) / np .sum (
1734+ weights
1735+ ) # 'like' a CDF function
16911736 return np .interp (perc , cdf , data )
1692-
1737+
16931738 def plot_balance_ecdf (self , covariate , idata = None , weighting_scheme = None ):
16941739 if idata is None :
16951740 idata = self .idata
1696- if weighting_scheme is None :
1741+ if weighting_scheme is None :
16971742 weighting_scheme = self .weighting_scheme
1698-
1699- ps = az .extract (idata )['p' ].mean (dim = ' sample' ).values
1743+
1744+ ps = az .extract (idata )["p" ].mean (dim = " sample" ).values
17001745 X = pd .DataFrame (self .X , columns = self .labels )
1701- X ['ps' ] = ps
1746+ X ["ps" ] = ps
17021747 t = self .t .flatten ()
1703- if weighting_scheme == ' raw' :
1748+ if weighting_scheme == " raw" :
17041749 w1 = 1 / ps [t == 1 ]
1705- w0 = 1 / (1 - ps [t == 0 ])
1706- elif weighting_scheme == ' robust' :
1750+ w0 = 1 / (1 - ps [t == 0 ])
1751+ elif weighting_scheme == " robust" :
17071752 p_of_t = np .mean (t )
1708- w1 = p_of_t / (ps [t == 1 ])
1753+ w1 = p_of_t / (ps [t == 1 ])
17091754 w0 = (1 - p_of_t ) / (1 - ps [t == 0 ])
17101755 else :
1711- w1 = (1 - ps [t == 1 ])* t [t == 1 ]
1712- w0 = ( ps [t == 0 ]* ( 1 - t [t == 0 ]) )
1756+ w1 = (1 - ps [t == 1 ]) * t [t == 1 ]
1757+ w0 = ps [t == 0 ] * ( 1 - t [t == 0 ] )
17131758 fig , axs = plt .subplots (1 , 2 , figsize = (20 , 6 ))
1714- raw_trt = [self .weighted_percentile (X [t == 1 ][covariate ].values , np .ones (len (X [t == 1 ])), p ) for p in np .linspace (0 , 1 , 1000 )]
1715- raw_ntrt = [self .weighted_percentile (X [t == 0 ][covariate ].values , np .ones (len (X [t == 0 ])), p ) for p in np .linspace (0 , 1 , 1000 )]
1716- w_trt = [self .weighted_percentile (X [t == 1 ][covariate ].values , w1 , p ) for p in np .linspace (0 , 1 , 1000 )]
1717- w_ntrt = [self .weighted_percentile (X [t == 0 ][covariate ].values , w0 , p ) for p in np .linspace (0 , 1 , 1000 )]
1718- axs [0 ].plot (np .linspace (0 , 1 , 1000 ), raw_trt , color = 'blue' , label = 'Raw Treated' )
1719- axs [0 ].plot (np .linspace (0 , 1 , 1000 ), raw_ntrt , color = 'red' , label = 'Raw Control' )
1759+ raw_trt = [
1760+ self .weighted_percentile (
1761+ X [t == 1 ][covariate ].values , np .ones (len (X [t == 1 ])), p
1762+ )
1763+ for p in np .linspace (0 , 1 , 1000 )
1764+ ]
1765+ raw_ntrt = [
1766+ self .weighted_percentile (
1767+ X [t == 0 ][covariate ].values , np .ones (len (X [t == 0 ])), p
1768+ )
1769+ for p in np .linspace (0 , 1 , 1000 )
1770+ ]
1771+ w_trt = [
1772+ self .weighted_percentile (X [t == 1 ][covariate ].values , w1 , p )
1773+ for p in np .linspace (0 , 1 , 1000 )
1774+ ]
1775+ w_ntrt = [
1776+ self .weighted_percentile (X [t == 0 ][covariate ].values , w0 , p )
1777+ for p in np .linspace (0 , 1 , 1000 )
1778+ ]
1779+ axs [0 ].plot (np .linspace (0 , 1 , 1000 ), raw_trt , color = "blue" , label = "Raw Treated" )
1780+ axs [0 ].plot (np .linspace (0 , 1 , 1000 ), raw_ntrt , color = "red" , label = "Raw Control" )
17201781 axs [0 ].set_title (f"ECDF \n Raw: { covariate } " )
1721- axs [1 ].set_title (f"ECDF \n Weighted { weighting_scheme } adjustment for { covariate } " )
1722- axs [1 ].plot (np .linspace (0 , 1 , 1000 ), w_trt , color = 'blue' , label = 'Reweighted Treated' )
1723- axs [1 ].plot (np .linspace (0 , 1 , 1000 ), w_ntrt , color = 'red' , label = 'Reweighted Control' )
1782+ axs [1 ].set_title (
1783+ f"ECDF \n Weighted { weighting_scheme } adjustment for { covariate } "
1784+ )
1785+ axs [1 ].plot (
1786+ np .linspace (0 , 1 , 1000 ), w_trt , color = "blue" , label = "Reweighted Treated"
1787+ )
1788+ axs [1 ].plot (
1789+ np .linspace (0 , 1 , 1000 ), w_ntrt , color = "red" , label = "Reweighted Control"
1790+ )
17241791 axs [1 ].set_xlabel ("Quantiles" )
17251792 axs [0 ].set_xlabel ("Quantiles" )
17261793 axs [1 ].legend ()
17271794 axs [0 ].legend ()
1728-
1729-
1730-
1731-
0 commit comments