@@ -547,32 +547,29 @@ def _gen_model_t(self):
547547 def _gen_model_final (self ):
548548 return StatsModelsLinearRegression (fit_intercept = False )
549549
550- def _gen_ortho_learner_model_nuisance (self , n_periods ):
550+ def _gen_ortho_learner_model_nuisance (self ):
551551 return _DynamicModelNuisance (
552552 model_t = self ._gen_model_t (),
553553 model_y = self ._gen_model_y (),
554- n_periods = n_periods )
554+ n_periods = self . _n_periods )
555555
556- def _gen_ortho_learner_model_final (self , n_periods ):
556+ def _gen_ortho_learner_model_final (self ):
557557 wrapped_final_model = _DynamicFinalWrapper (
558558 StatsModelsLinearRegression (fit_intercept = False ),
559559 fit_cate_intercept = self .fit_cate_intercept ,
560560 featurizer = self .featurizer ,
561561 use_weight_trick = False )
562- return _LinearDynamicModelFinal (wrapped_final_model , n_periods = n_periods )
562+ return _LinearDynamicModelFinal (wrapped_final_model , n_periods = self . _n_periods )
563563
564564 def _prefit (self , Y , T , * args , groups = None , only_final = False , ** kwargs ):
565+ # we need to set the number of periods before calling super()._prefit, since that will generate the
566+ # final and nuisance models, which need to have self._n_periods set
565567 u_periods = np .unique (np .unique (groups , return_counts = True )[1 ])
566568 if len (u_periods ) > 1 :
567569 raise AttributeError (
568570 "Imbalanced panel. Method currently expects only panels with equal number of periods. Pad your data" )
569571 self ._n_periods = u_periods [0 ]
570- # generate an instance of the final model
571- self ._ortho_learner_model_final = self ._gen_ortho_learner_model_final (self ._n_periods )
572- if not only_final :
573- # generate an instance of the nuisance model
574- self ._ortho_learner_model_nuisance = self ._gen_ortho_learner_model_nuisance (self ._n_periods )
575- TreatmentExpansionMixin ._prefit (self , Y , T , * args , ** kwargs )
572+ super ()._prefit (Y , T , * args , ** kwargs )
576573
577574 def _postfit (self , Y , T , * args , ** kwargs ):
578575 super ()._postfit (Y , T , * args , ** kwargs )
0 commit comments