@@ -74,7 +74,7 @@ def __init__(
7474 # cumulative impact post
7575 self .post_impact_cumulative = np .cumsum (self .post_impact )
7676
77- def plot (self ):
77+ def plot (self , counterfactual_label = "Counterfactual" , ** kwargs ):
7878 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
7979
8080 ax [0 ].plot (self .datapre .index , self .pre_y , "k." )
@@ -84,7 +84,7 @@ def plot(self):
8484 ax [0 ].plot (
8585 self .datapost .index ,
8686 self .post_pred ,
87- label = "counterfactual" ,
87+ label = counterfactual_label ,
8888 ls = ":" ,
8989 c = "k" ,
9090 )
@@ -95,7 +95,7 @@ def plot(self):
9595 self .datapost .index ,
9696 self .post_impact ,
9797 "k." ,
98- label = "counterfactual" ,
98+ label = counterfactual_label ,
9999 )
100100 ax [1 ].axhline (y = 0 , c = "k" )
101101 ax [1 ].set (title = "Causal Impact" )
@@ -151,12 +151,18 @@ def plot_coeffs(self):
151151 )
152152
153153
154+ class InterruptedTimeSeries (PrePostFit ):
155+ """Interrupted time series analysis"""
156+
157+ expt_type = "Interrupted Time Series"
158+
159+
154160class SyntheticControl (PrePostFit ):
155161 """A wrapper around the PrePostFit class"""
156162
157- def plot (self , plot_predictors = False ):
163+ def plot (self , plot_predictors = False , ** kwargs ):
158164 """Plot the results"""
159- fig , ax = super ().plot ()
165+ fig , ax = super ().plot (counterfactual_label = "Synthetic control" , ** kwargs )
160166 if plot_predictors :
161167 # plot control units as well
162168 ax [0 ].plot (self .datapre .index , self .pre_X , "-" , c = [0.8 , 0.8 , 0.8 ], zorder = 1 )
0 commit comments