@@ -85,21 +85,26 @@ class CycleComponent(Component):
8585
8686 # Build the structural model
8787 grw = st.LevelTrendComponent(order=1, innovations_order=1)
88- cycle = st.CycleComponent('business_cycle', estimate_cycle_length=True, dampen=False)
88+ cycle = st.CycleComponent(
89+ "business_cycle", cycle_length=12, estimate_cycle_length=False, innovations=True, dampen=True
90+ )
8991 ss_mod = (grw + cycle).build()
9092
9193 # Estimate with PyMC
9294 with pm.Model(coords=ss_mod.coords) as model:
9395 P0 = pm.Deterministic('P0', pt.eye(ss_mod.k_states), dims=ss_mod.param_dims['P0'])
94- intitial_trend = pm.Normal('initial_trend', dims=ss_mod.param_dims['initial_trend'])
95- sigma_trend = pm.HalfNormal('sigma_trend', dims=ss_mod.param_dims['sigma_trend'])
9696
97- cycle_strength = pm.Normal("business_cycle", dims=ss_mod.param_dims["business_cycle"])
98- cycle_length = pm.Uniform('business_cycle_length', lower=6, upper=12)
99- sigma_cycle = pm.HalfNormal('sigma_business_cycle', sigma=1)
97+ initial_level_trend = pm.Normal('initial_level_trend', dims=ss_mod.param_dims['initial_level_trend'])
98+ sigma_level_trend = pm.HalfNormal('sigma_level_trend', dims=ss_mod.param_dims['sigma_level_trend'])
99+
100+ business_cycle = pm.Normal("business_cycle", dims=ss_mod.param_dims["business_cycle"])
101+ dampening = pm.Beta("dampening_factor_business_cycle", 2, 2)
102+ sigma_cycle = pm.HalfNormal("sigma_business_cycle", sigma=1)
100103
101104 ss_mod.build_statespace_graph(data)
102- idata = pm.sample()
105+ idata = pm.sample(
106+ nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "JAX", "gradient_backend": "JAX"}
107+ )
103108
104109 **Multivariate Example:**
105110 Model cycles for multiple economic indicators with variable-specific innovation variances:
@@ -115,26 +120,25 @@ class CycleComponent(Component):
115120 dampen=True,
116121 observed_state_names=['gdp', 'unemployment', 'inflation']
117122 )
118-
119- # Build the model
120123 ss_mod = cycle.build()
121124
122- # In PyMC model:
123125 with pm.Model(coords=ss_mod.coords) as model:
124126 P0 = pm.Deterministic("P0", pt.eye(ss_mod.k_states), dims=ss_mod.param_dims["P0"])
125127 # Initial states: shape (3, 2) for 3 variables, 2 states each
126- cycle_init = pm.Normal('business_cycle', dims=ss_mod.param_dims["business_cycle"])
128+ business_cycle = pm.Normal('business_cycle', dims=ss_mod.param_dims["business_cycle"])
127129
128130 # Dampening factor: scalar (shared across variables)
129- dampening = pm.Beta("business_cycle_dampening_factor ", 2, 2)
131+ dampening = pm.Beta("dampening_factor_business_cycle ", 2, 2)
130132
131133 # Innovation variances: shape (3,) for variable-specific variances
132134 sigma_cycle = pm.HalfNormal(
133135 "sigma_business_cycle", dims=ss_mod.param_dims["sigma_business_cycle"]
134136 )
135137
136138 ss_mod.build_statespace_graph(data)
137- idata = pm.sample()
139+ idata = pm.sample(
140+ nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "JAX", "gradient_backend": "JAX"}
141+ )
138142
139143 References
140144 ----------
0 commit comments