@@ -134,7 +134,7 @@ def _input_validation(self, data, treatment_time):
134134 "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
135135 )
136136
137- def plot (self ):
137+ def plot (self , counterfactual_label = "Counterfactual" , ** kwargs ):
138138 """Plot the results"""
139139 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
140140
@@ -161,7 +161,7 @@ def plot(self):
161161 plot_hdi_kwargs = {"color" : "C1" },
162162 )
163163 handles .append ((h_line , h_patch ))
164- labels .append ("Synthetic control" )
164+ labels .append (counterfactual_label )
165165
166166 ax [0 ].plot (self .datapost .index , self .post_y , "k." )
167167 # Shaded causal effect
@@ -243,14 +243,20 @@ def summary(self):
243243 self .print_coefficients ()
244244
245245
246+ class InterruptedTimeSeries (PrePostFit ):
247+ """Interrupted time series analysis"""
248+
249+ expt_type = "Interrupted Time Series"
250+
251+
246252class SyntheticControl (PrePostFit ):
247253 """A wrapper around the PrePostFit class"""
248254
249255 expt_type = "Synthetic Control"
250256
251- def plot (self , plot_predictors = False ):
257+ def plot (self , plot_predictors = False , ** kwargs ):
252258 """Plot the results"""
253- fig , ax = super ().plot ()
259+ fig , ax = super ().plot (counterfactual_label = "Synthetic control" , ** kwargs )
254260 if plot_predictors :
255261 # plot control units as well
256262 ax [0 ].plot (self .datapre .index , self .pre_X , "-" , c = [0.8 , 0.8 , 0.8 ], zorder = 1 )
0 commit comments