1818from pymc_experimental .statespace .utils .constants import (
1919 ALL_STATE_AUX_DIM ,
2020 ALL_STATE_DIM ,
21+ AR_PARAM_DIM ,
2122 LONG_MATRIX_NAMES ,
2223 OBS_STATE_DIM ,
2324 POSITION_DERIVATIVE_NAMES ,
25+ TIME_DIM ,
2426)
2527
2628_log = logging .getLogger ("pymc.experimental.statespace" )
@@ -786,7 +788,7 @@ def populate_component_properties(self):
786788 self .state_names = [name for name , mask in zip (name_slice , self ._order_mask ) if mask ]
787789 self .param_dims = {"initial_trend" : ("trend_state" ,)}
788790 self .coords = {"trend_state" : self .state_names }
789- self .param_info = {"initial_trend" : {"shape" : (self .k_states ,), "constraints" : " None" }}
791+ self .param_info = {"initial_trend" : {"shape" : (self .k_states ,), "constraints" : None }}
790792
791793 if self .k_posdef > 0 :
792794 self .param_names += ["sigma_trend" ]
@@ -871,7 +873,11 @@ def populate_component_properties(self):
871873 self .param_names = [f"sigma_{ self .name } " ]
872874 self .param_dims = {f"sigma_{ self .name } " : (OBS_STATE_DIM ,)}
873875 self .param_info = {
874- f"sigma_{ self .name } " : {"shape" : (1 ,), "constraints" : "Positive" , "dims" : "None" }
876+ f"sigma_{ self .name } " : {
877+ "shape" : (1 ,),
878+ "constraints" : "Positive" ,
879+ "dims" : (OBS_STATE_DIM ,),
880+ }
875881 }
876882
877883 def make_symbolic_graph (self ) -> None :
@@ -959,11 +965,15 @@ def populate_component_properties(self):
959965 self .state_names = [f"L{ i + 1 } .data" for i in range (self .k_states )]
960966 self .shock_names = [f"{ self .name } _innovation" ]
961967 self .param_names = ["ar_params" , "sigma_ar" ]
962- self .param_dims = {"ar_params" : ("ar_lags" ,)}
963- self .coords = {"ar_lags" : self .ar_lags }
968+ self .param_dims = {"ar_params" : (AR_PARAM_DIM ,)}
969+ self .coords = {AR_PARAM_DIM : self .ar_lags . tolist () }
964970
965971 self .param_info = {
966- "ar_params" : {"shape" : (self .k_states ,), "constraints" : "None" , "dims" : "(ar_lags, )" },
972+ "ar_params" : {
973+ "shape" : (self .k_states ,),
974+ "constraints" : None ,
975+ "dims" : (AR_PARAM_DIM ,),
976+ },
967977 "sigma_ar" : {"shape" : (1 ,), "constraints" : "Positive" , "dims" : None },
968978 }
969979
@@ -1133,19 +1143,19 @@ def populate_component_properties(self):
11331143 self .param_info = {
11341144 f"{ self .name } _coefs" : {
11351145 "shape" : (self .k_states ,),
1136- "constraints" : " None" ,
1137- "dims" : f"( { self .name } _state, )" ,
1146+ "constraints" : None ,
1147+ "dims" : ( f" { self .name } _state" ,) ,
11381148 }
11391149 }
1140- self .param_dims = {f"{ self .name } _coefs" : (f"{ self .name } _periods " ,)}
1150+ self .param_dims = {f"{ self .name } _coefs" : (f"{ self .name } _state " ,)}
11411151 self .coords = {f"{ self .name } _state" : self .state_names }
11421152
11431153 if self .innovations :
11441154 self .param_names += [f"sigma_{ self .name } " ]
11451155 self .param_info [f"sigma_{ self .name } " ] = {
11461156 "shape" : (1 ,),
11471157 "constraints" : "Positive" ,
1148- "dims" : " None" ,
1158+ "dims" : None ,
11491159 }
11501160 self .shock_names = [f"{ self .name } " ]
11511161
@@ -1270,27 +1280,27 @@ def populate_component_properties(self):
12701280 self .state_names = [f"{ self .name } _{ f } _{ i } " for i in range (self .n ) for f in ["Cos" , "Sin" ]]
12711281 self .param_names = [f"{ self .name } " ]
12721282
1273- self .param_dims = {self .name : (f"{ self .name } _initial_state " ,)}
1283+ self .param_dims = {self .name : (f"{ self .name } _state " ,)}
12741284 self .param_info = {
12751285 f"{ self .name } " : {
12761286 "shape" : (self .k_states - int (self .last_state_not_identified ),),
1277- "constraints" : " None" ,
1278- "dims" : f"( { self .name } _initial_state, )" ,
1287+ "constraints" : None ,
1288+ "dims" : ( f" { self .name } _state" ,) ,
12791289 }
12801290 }
12811291
12821292 init_state_idx = np .arange (self .k_states , dtype = int )
12831293 if self .last_state_not_identified :
12841294 init_state_idx = init_state_idx [:- 1 ]
1285- self .coords = {f"{ self .name } _initial_state " : [self .state_names [i ] for i in init_state_idx ]}
1295+ self .coords = {f"{ self .name } _state " : [self .state_names [i ] for i in init_state_idx ]}
12861296
12871297 if self .innovations :
12881298 self .shock_names = self .state_names .copy ()
12891299 self .param_names += [f"sigma_{ self .name } " ]
12901300 self .param_info [f"sigma_{ self .name } " ] = {
12911301 "shape" : (1 ,),
12921302 "constraints" : "Positive" ,
1293- "dims" : " None" ,
1303+ "dims" : None ,
12941304 }
12951305
12961306
@@ -1421,10 +1431,12 @@ def __init__(
14211431 def make_symbolic_graph (self ) -> None :
14221432 self .ssm ["design" , 0 , slice (0 , self .k_states , 2 )] = 1
14231433 self .ssm ["selection" , :, :] = np .eye (self .k_states )
1434+ self .param_dims = {self .name : (f"{ self .name } _state" ,)}
1435+ self .coords = {f"{ self .name } _state" : self .state_names }
14241436
1425- init_state = self .make_and_register_variable (f"{ self .name } " , shape = (1 ,))
1437+ init_state = self .make_and_register_variable (f"{ self .name } " , shape = (self . k_states ,))
14261438
1427- self .ssm ["initial_state" , 0 ] = init_state
1439+ self .ssm ["initial_state" , : ] = init_state
14281440
14291441 if self .estimate_cycle_length :
14301442 lamb = self .make_and_register_variable (f"{ self .name } _length" , shape = (1 ,))
@@ -1440,18 +1452,18 @@ def make_symbolic_graph(self) -> None:
14401452 self .ssm ["transition" , :, :] = T
14411453
14421454 if self .innovations :
1443- sigma_season = self .make_and_register_variable (f"sigma_{ self .name } " , shape = (1 ,))
1444- self .ssm ["state_cov" , :, :] = pt .eye (self .k_posdef ) * sigma_season
1455+ sigma_cycle = self .make_and_register_variable (f"sigma_{ self .name } " , shape = (1 ,))
1456+ self .ssm ["state_cov" , :, :] = pt .eye (self .k_posdef ) * sigma_cycle
14451457
14461458 def populate_component_properties (self ):
14471459 self .state_names = [f"{ self .name } _{ f } " for f in ["Sin" , "Cos" ]]
14481460 self .param_names = [f"{ self .name } " ]
14491461
14501462 self .param_info = {
14511463 f"{ self .name } " : {
1452- "shape" : (1 ,),
1453- "constraints" : " None" ,
1454- "dims" : None ,
1464+ "shape" : (2 ,),
1465+ "constraints" : None ,
1466+ "dims" : ( f" { self . name } _state" ,) ,
14551467 }
14561468 }
14571469
@@ -1476,7 +1488,7 @@ def populate_component_properties(self):
14761488 self .param_info [f"sigma_{ self .name } " ] = {
14771489 "shape" : (1 ,),
14781490 "constraints" : "Positive" ,
1479- "dims" : " None" ,
1491+ "dims" : None ,
14801492 }
14811493 self .shock_names = self .state_names .copy ()
14821494
@@ -1551,15 +1563,16 @@ def populate_component_properties(self) -> None:
15511563
15521564 self .param_names = [f"beta_{ self .name } " , f"data_{ self .name } " ]
15531565 self .param_dims = {
1554- f"beta_{ self .name } " : "exog_state" ,
1555- f"data_{ self .name } " : ("time" , "exog_state" ),
1566+ f"beta_{ self .name } " : ( "exog_state" ,) ,
1567+ f"data_{ self .name } " : (TIME_DIM , "exog_state" ),
15561568 }
1569+
15571570 self .param_info = {
1558- f"beta_{ self .name } " : {"shape" : (1 ,), "constraints" : " None" , "dims" : ("exog_state" ,)},
1571+ f"beta_{ self .name } " : {"shape" : (1 ,), "constraints" : None , "dims" : ("exog_state" ,)},
15591572 f"data_{ self .name } " : {
15601573 "shape" : (None , self .k_states ),
1561- "constraints" : " None" ,
1562- "dims" : ("time" , "exog_state" ),
1574+ "constraints" : None ,
1575+ "dims" : (TIME_DIM , "exog_state" ),
15631576 },
15641577 }
15651578 self .coords = {f"exog_state" : self .state_names }
0 commit comments