@@ -497,3 +497,114 @@ def fit(self, X, t, coords):
497497 )
498498 )
499499 return self .idata
500+
501+
502+ class InterventionTimeEstimator (PyMCModel ):
503+ r"""
504+ Custom PyMC model to estimate the time an intervetnion took place.
505+
506+ defines the PyMC model :
507+
508+ .. math::
509+ \alpha &\sim \mathrm{Normal}(0, 1) \\
510+ \beta &\sim \mathrm{Normal}(0, 1) \\
511+ s(t) &= \gamma_{i(t)} \quad \textrm{with} \quad \gamma_{k \in [0, ..., n_{seasons}-1]} \sim \mathrm{Normal}(0, 1)\\
512+ base_{\mu}(t) &= \alpha + \beta \cdot t + s_t\\
513+ \\
514+ \tau &\sim \mathrm{Uniform}(0, 1) \\
515+ w(t) &= sigmoid(t-\tau) \\
516+ \\
517+ level &\sim \mathrm{Normal}(0, 1) \\
518+ trend &\sim \mathrm{Normal}(0, 1) \\
519+ A &\sim \mathrm{Normal}(0, 1) \\
520+ \lambda &\sim \mathrm{HalfNormal}(0, 1) \\
521+ impulse(t) &= A \cdot exp(-\lambda \cdot |t-\tau|) \\
522+ intervention(t) &= level + trend \cdot (t-\tau) + impulse_t\\
523+ \\
524+ \sigma &\sim \mathrm{Normal}(0, 1) \\
525+ \mu(t) &= base_{\mu}(t) + w(t) \cdot intervention(t) \\
526+ \\
527+ y(t) &\sim \mathrm{Normal}(\mu (t), \sigma)
528+
529+ Example
530+ --------
531+ >>> import causalpy as cp
532+ >>> import numpy as np
533+ >>> from causalpy.pymc_models import InterventionTimeEstimator
534+ >>> df = cp.load("its")
535+ >>> y = df["y"].values
536+ >>> t = df["t"].values
537+ >>> coords = {"sseasons" = range(12)} # The data is monthly
538+ >>> estimator = InterventionTimeEstimator()
539+ >>> # We are trying to capture an impulse in the number of death per month due to Covid.
540+ >>> estimator.fit(
541+ ... t,
542+ ... y,
543+ ... coords,
544+ ... effect=["impulse"])
545+ Inference data...
546+ """
547+
548+ def build_model (self , t , y , coords , effect , span , grain_season ):
549+ """
550+ Defines the PyMC model
551+
552+ :param t: An array of values representing the time over which y is spread
553+ :param y: An array of values representing our outcome y
554+ :param coords: A dictionary with the coordinate names for our instruments
555+ """
556+
557+ with self :
558+ self .add_coords (coords )
559+
560+ if span is None :
561+ span = (t .min (), t .max ())
562+
563+ # --- Priors ---
564+ switchpoint = pm .Uniform ("switchpoint" , lower = span [0 ], upper = span [1 ])
565+ alpha = pm .Normal (name = "alpha" , mu = 0 , sigma = 10 )
566+ beta = pm .Normal (name = "beta" , mu = 0 , sigma = 10 )
567+ seasons = 0
568+ if "seasons" in coords and len (coords ["seasons" ]) > 0 :
569+ season_idx = np .arange (len (y )) // grain_season % len (coords ["seasons" ])
570+ season_effect = pm .Normal ("season" , mu = 0 , sigma = 1 , dims = "seasons" )
571+ seasons = season_effect [season_idx ]
572+
573+ # --- Intervention effect ---
574+ level = trend = impulse = 0
575+
576+ if "level" in effect :
577+ level = pm .Normal ("level" , mu = 0 , sigma = 10 )
578+
579+ if "trend" in effect :
580+ trend = pm .Normal ("trend" , mu = 0 , sigma = 10 )
581+
582+ if "impulse" in effect :
583+ impulse_amplitude = pm .Normal ("impulse_amplitude" , mu = 0 , sigma = 1 )
584+ decay_rate = pm .HalfNormal ("decay_rate" , sigma = 1 )
585+ impulse = impulse_amplitude * pm .math .exp (
586+ - decay_rate * abs (t - switchpoint )
587+ )
588+
589+ # --- Parameterization ---
590+ weight = pm .math .sigmoid (t - switchpoint )
591+ # Compute and store the modelled time series
592+ mu_ts = pm .Deterministic (name = "mu_ts" , var = alpha + beta * t + seasons )
593+ # Compute and store the modelled intervention effect
594+ mu_in = pm .Deterministic (
595+ name = "mu_in" , var = level + trend * (t - switchpoint ) + impulse
596+ )
597+ # Compute and store the the sum of the intervention and the time series
598+ mu = pm .Deterministic ("mu" , mu_ts + weight * mu_in )
599+
600+ # --- Likelihood ---
601+ pm .Normal ("y_hat" , mu = mu , sigma = 2 , observed = y )
602+
603+ def fit (self , t , y , coords , effect = [], span = None , grain_season = 1 , n = 1000 ):
604+ """
605+ Draw samples from posterior distribution
606+ """
607+ self .build_model (t , y , coords , effect , span , grain_season )
608+ with self :
609+ self .idata = pm .sample (n , ** self .sample_kwargs )
610+ return self .idata
0 commit comments