diff --git a/causalpy/data/simulate_data.py b/causalpy/data/simulate_data.py index 715a7b63..b22d51b0 100644 --- a/causalpy/data/simulate_data.py +++ b/causalpy/data/simulate_data.py @@ -15,21 +15,26 @@ Functions that generate data sets used in examples """ +from typing import Any + import numpy as np import pandas as pd from scipy.stats import dirichlet, gamma, norm, uniform from statsmodels.nonparametric.smoothers_lowess import lowess -default_lowess_kwargs = {"frac": 0.2, "it": 0} -RANDOM_SEED = 8927 -rng = np.random.default_rng(RANDOM_SEED) +default_lowess_kwargs: dict[str, float] = {"frac": 0.2, "it": 0} +RANDOM_SEED: int = 8927 +rng: np.random.Generator = np.random.default_rng(RANDOM_SEED) def _smoothed_gaussian_random_walk( - gaussian_random_walk_mu, gaussian_random_walk_sigma, N, lowess_kwargs -): + gaussian_random_walk_mu: float, + gaussian_random_walk_sigma: float, + N: int, + lowess_kwargs: dict[str, Any], +) -> tuple[np.ndarray, np.ndarray]: """ - Generates Gaussian random walk data and applies LOWESS + Generates Gaussian random walk data and applies LOWESS. :param gaussian_random_walk_mu: Mean of the random walk @@ -48,12 +53,12 @@ def _smoothed_gaussian_random_walk( def generate_synthetic_control_data( - N=100, - treatment_time=70, - grw_mu=0.25, - grw_sigma=1, - lowess_kwargs=default_lowess_kwargs, -): + N: int = 100, + treatment_time: int = 70, + grw_mu: float = 0.25, + grw_sigma: float = 1, + lowess_kwargs: dict[str, Any] | None = None, +) -> tuple[pd.DataFrame, np.ndarray]: """ Generates data for synthetic control example. @@ -73,6 +78,8 @@ def generate_synthetic_control_data( >>> from causalpy.data.simulate_data import generate_synthetic_control_data >>> df, weightings_true = generate_synthetic_control_data(treatment_time=70) """ + if lowess_kwargs is None: + lowess_kwargs = default_lowess_kwargs # 1. Generate non-treated variables df = pd.DataFrame( @@ -108,8 +115,12 @@ def generate_synthetic_control_data( def generate_time_series_data( - N=100, treatment_time=70, beta_temp=-1, beta_linear=0.5, beta_intercept=3 -): + N: int = 100, + treatment_time: int = 70, + beta_temp: float = -1, + beta_linear: float = 0.5, + beta_intercept: float = 3, +) -> pd.DataFrame: """ Generates interrupted time series example data @@ -155,7 +166,7 @@ def generate_time_series_data( return df -def generate_time_series_data_seasonal(treatment_time): +def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataFrame: """ Generates 10 years of monthly data with seasonality """ @@ -183,7 +194,9 @@ def generate_time_series_data_seasonal(treatment_time): return df -def generate_time_series_data_simple(treatment_time, slope=0.0): +def generate_time_series_data_simple( + treatment_time: pd.Timestamp, slope: float = 0.0 +) -> pd.DataFrame: """Generate simple interrupted time series data, with no seasonality or temporal structure. """ @@ -205,7 +218,7 @@ def generate_time_series_data_simple(treatment_time, slope=0.0): return df -def generate_did(): +def generate_did() -> pd.DataFrame: """ Generate Difference in Differences data @@ -223,8 +236,14 @@ def generate_did(): # local functions def outcome( - t, control_intercept, treat_intercept_delta, trend, Δ, group, post_treatment - ): + t: np.ndarray, + control_intercept: float, + treat_intercept_delta: float, + trend: float, + Δ: float, + group: np.ndarray, + post_treatment: np.ndarray, + ) -> np.ndarray: """Compute the outcome of each unit""" return ( control_intercept @@ -257,8 +276,8 @@ def outcome( def generate_regression_discontinuity_data( - N=100, true_causal_impact=0.5, true_treatment_threshold=0.0 -): + N: int = 100, true_causal_impact: float = 0.5, true_treatment_threshold: float = 0.0 +) -> pd.DataFrame: """ Generate regression discontinuity example data @@ -272,12 +291,12 @@ def generate_regression_discontinuity_data( ... ) # doctest: +SKIP """ - def is_treated(x): + def is_treated(x: np.ndarray) -> np.ndarray: """Check if x was treated""" return np.greater_equal(x, true_treatment_threshold) - def impact(x): - """Assign true_causal_impact to all treaated entries""" + def impact(x: np.ndarray) -> np.ndarray: + """Assign true_causal_impact to all treated entries""" y = np.zeros(len(x)) y[is_treated(x)] = true_causal_impact return y @@ -289,8 +308,11 @@ def impact(x): def generate_ancova_data( - N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1 -): + N: int = 200, + pre_treatment_means: np.ndarray = np.array([10, 12]), + treatment_effect: float = 2, + sigma: float = 1, +) -> pd.DataFrame: """ Generate ANCOVA example data @@ -310,7 +332,7 @@ def generate_ancova_data( return df -def generate_geolift_data(): +def generate_geolift_data() -> pd.DataFrame: """Generate synthetic data for a geolift example. This will consists of 6 untreated countries. The treated unit `Denmark` is a weighted combination of the untreated units. We additionally specify a treatment effect which takes effect after the @@ -360,7 +382,7 @@ def generate_geolift_data(): return df -def generate_multicell_geolift_data(): +def generate_multicell_geolift_data() -> pd.DataFrame: """Generate synthetic data for a geolift example. This will consists of 6 untreated countries. The treated unit `Denmark` is a weighted combination of the untreated units. We additionally specify a treatment effect which takes effect after the @@ -422,7 +444,9 @@ def generate_multicell_geolift_data(): # ----------------- -def generate_seasonality(n=12, amplitude=1, length_scale=0.5): +def generate_seasonality( + n: int = 12, amplitude: float = 1, length_scale: float = 0.5 +) -> np.ndarray: """Generate monthly seasonality by sampling from a Gaussian process with a Gaussian kernel, using numpy code""" # Generate the covariance matrix @@ -436,14 +460,26 @@ def generate_seasonality(n=12, amplitude=1, length_scale=0.5): return seasonality -def periodic_kernel(x1, x2, period=1, length_scale=1, amplitude=1): +def periodic_kernel( + x1: np.ndarray, + x2: np.ndarray, + period: float = 1, + length_scale: float = 1, + amplitude: float = 1, +) -> np.ndarray: """Generate a periodic kernel for gaussian process""" return amplitude**2 * np.exp( -2 * np.sin(np.pi * np.abs(x1 - x2) / period) ** 2 / length_scale**2 ) -def create_series(n=52, amplitude=1, length_scale=2, n_years=4, intercept=3): +def create_series( + n: int = 52, + amplitude: float = 1, + length_scale: float = 2, + n_years: int = 4, + intercept: float = 3, +) -> np.ndarray: """ Returns numpy tile with generated seasonality data repeated over multiple years diff --git a/causalpy/tests/test_synthetic_data.py b/causalpy/tests/test_synthetic_data.py index 1093802a..3477ce0f 100644 --- a/causalpy/tests/test_synthetic_data.py +++ b/causalpy/tests/test_synthetic_data.py @@ -39,3 +39,43 @@ def test_generate_geolift_data(): df = generate_geolift_data() assert isinstance(df, pd.DataFrame) assert np.all(df >= 0), "Found negative values in dataset" + + +def test_generate_regression_discontinuity_data(): + """ + Test the generate_regression_discontinuity_data function. + """ + from causalpy.data.simulate_data import generate_regression_discontinuity_data + + df = generate_regression_discontinuity_data() + assert isinstance(df, pd.DataFrame) + assert "x" in df.columns + assert "y" in df.columns + assert "treated" in df.columns + assert len(df) == 100 # default N value + assert df["treated"].dtype == bool or df["treated"].dtype == np.bool_ + + # Test with custom parameters + df_custom = generate_regression_discontinuity_data( + N=50, true_causal_impact=1.0, true_treatment_threshold=0.5 + ) + assert len(df_custom) == 50 + + +def test_generate_synthetic_control_data(): + """ + Test the generate_synthetic_control_data function. + """ + from causalpy.data.simulate_data import generate_synthetic_control_data + + # Test with default parameters (lowess_kwargs=None) + df, weightings = generate_synthetic_control_data() + assert isinstance(df, pd.DataFrame) + assert isinstance(weightings, np.ndarray) + assert len(df) == 100 # default N value + + # Test with explicit lowess_kwargs + df_custom, weightings_custom = generate_synthetic_control_data( + N=50, lowess_kwargs={"frac": 0.3, "it": 5} + ) + assert len(df_custom) == 50 diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 8734d55d..26433625 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 95.8% + interrogate: 95.9% @@ -12,8 +12,8 @@ interrogate interrogate - 95.8% - 95.8% + 95.9% + 95.9%