2626from causalpy .custom_exceptions import BadIndexException # NOQA
2727from causalpy .custom_exceptions import DataException , FormulaException
2828from causalpy .plot_utils import plot_xY
29- from causalpy .utils import _is_variable_dummy_coded , _series_has_2_levels
29+ from causalpy .utils import _is_variable_dummy_coded
3030
3131LEGEND_FONT_SIZE = 12
3232az .style .use ("arviz-darkgrid" )
@@ -978,7 +978,8 @@ class PrePostNEGD(ExperimentalDesign):
978978 :param formula:
979979 A statistical model formula
980980 :param group_variable_name:
981- Name of the column in data for the group variable
981+ Name of the column in data for the group variable, should be either
982+ binary or boolean
982983 :param pretreatment_variable_name:
983984 Name of the column in data for the pretreatment variable
984985 :param model:
@@ -1058,17 +1059,19 @@ def __init__(
10581059 self .group_variable_name : np .zeros (self .pred_xi .shape ),
10591060 }
10601061 )
1061- (new_x ,) = build_design_matrices ([self ._x_design_info ], x_pred_untreated )
1062- self .pred_untreated = self .model .predict (X = np .asarray (new_x ))
1062+ (new_x_untreated ,) = build_design_matrices (
1063+ [self ._x_design_info ], x_pred_untreated
1064+ )
1065+ self .pred_untreated = self .model .predict (X = np .asarray (new_x_untreated ))
10631066 # treated
1064- x_pred_untreated = pd .DataFrame (
1067+ x_pred_treated = pd .DataFrame (
10651068 {
10661069 self .pretreatment_variable_name : self .pred_xi ,
10671070 self .group_variable_name : np .ones (self .pred_xi .shape ),
10681071 }
10691072 )
1070- (new_x ,) = build_design_matrices ([self ._x_design_info ], x_pred_untreated )
1071- self .pred_treated = self .model .predict (X = np .asarray (new_x ))
1073+ (new_x_treated ,) = build_design_matrices ([self ._x_design_info ], x_pred_treated )
1074+ self .pred_treated = self .model .predict (X = np .asarray (new_x_treated ))
10721075
10731076 # Evaluate causal impact as equal to the trestment effect
10741077 self .causal_impact = self .idata .posterior ["beta" ].sel (
@@ -1079,7 +1082,7 @@ def __init__(
10791082
10801083 def _input_validation (self ) -> None :
10811084 """Validate the input data and model formula for correctness"""
1082- if not _series_has_2_levels (self .data [self .group_variable_name ]):
1085+ if not _is_variable_dummy_coded (self .data [self .group_variable_name ]):
10831086 raise DataException (
10841087 f"""
10851088 There must be 2 levels of the grouping variable
@@ -1165,7 +1168,7 @@ def _get_treatment_effect_coeff(self) -> str:
11651168 then we want `C(group)[T.1]`.
11661169 """
11671170 for label in self .labels :
1168- if ("group" in label ) & (":" not in label ):
1171+ if (self . group_variable_name in label ) & (":" not in label ):
11691172 return label
11701173
11711174 raise NameError ("Unable to find coefficient name for the treatment effect" )
0 commit comments