Skip to content

Commit 2d06be5

Browse files
authored
Fix bug in SyntheticControl.get_plot_data (#540)
* initial bug fix + regression test
* massive simplification of the fix
1 parent a6ddda9 commit 2d06be5

File tree

3 files changed

+112
-8
lines changed

3 files changed

+112
-8
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ doctest:
1616
pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
1717

1818
test:
19-
pytest
19+
python -m pytest
2020

2121
uml:
2222
pyreverse -o png causalpy --output-directory docs/source/_static --ignore tests

causalpy/plot_utils.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,17 @@ def get_hdi_to_df(
9393
:param hdi_prob:
9494
The size of the HDI, default is 0.94
9595
"""
96-
hdi = (
97-
az.hdi(x, hdi_prob=hdi_prob)
98-
.to_dataframe()
99-
.unstack(level="hdi")
100-
.droplevel(0, axis=1)
101-
)
102-
return hdi
96+
hdi_result = az.hdi(x, hdi_prob=hdi_prob)
97+
98+
# Get the data variable name (typically 'mu' or 'x')
99+
# We select only the data variable column to exclude coordinates like 'treated_units'
100+
data_var = list(hdi_result.data_vars)[0]
101+
102+
# Convert to DataFrame, select only the data variable column, then unstack
103+
# This prevents coordinate values (like 'treated_agg') from appearing as columns
104+
hdi_df = hdi_result[data_var].to_dataframe()[[data_var]].unstack(level="hdi")
105+
106+
# Remove the top level of column MultiIndex to get just 'lower' and 'higher'
107+
hdi_df.columns = hdi_df.columns.droplevel(0)
108+
109+
return hdi_df

causalpy/tests/test_plot_utils.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2025 - 2025 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Tests for plot utility functions
16+
"""
17+
18+
import numpy as np
19+
import pandas as pd
20+
import pytest
21+
import xarray as xr
22+
23+
from causalpy.plot_utils import get_hdi_to_df
24+
25+
26+
@pytest.mark.integration
def test_get_hdi_to_df_with_coordinate_dimensions():
    """
    Regression test: ``get_hdi_to_df`` must return numeric HDI bounds when the
    input carries an extra string-valued coordinate.

    The fixture below is a DataArray with ``chain``/``draw``/``obs_ind``
    dimensions plus a scalar, non-dimension coordinate
    (``treated_units="treated_agg"``). The bug being guarded against returned
    that coordinate's string value in place of the numeric HDI values.

    This bug manifested in multi-cell synthetic control experiments where columns
    like 'pred_hdi_upper_94' contained the string "treated_agg" instead of
    numeric upper bound values.

    See: https://github.com/pymc-labs/CausalPy/issues/532
    """
    # Draw from a local Generator instead of calling np.random.seed, so this
    # test does not reseed NumPy's global RNG for tests that run after it.
    rng = np.random.default_rng(42)
    n_chains, n_draws, n_obs = 2, 100, 10

    # Simulated posterior samples centred on 5.0
    posterior = rng.normal(loc=5.0, scale=0.5, size=(n_chains, n_draws, n_obs))

    xr_data = xr.DataArray(
        posterior,
        dims=["chain", "draw", "obs_ind"],
        coords={
            "chain": np.arange(n_chains),
            "draw": np.arange(n_draws),
            "obs_ind": np.arange(n_obs),
            "treated_units": "treated_agg",  # This coordinate caused the bug
        },
    )

    result = get_hdi_to_df(xr_data, hdi_prob=0.94)

    assert isinstance(result, pd.DataFrame), "Result should be a DataFrame"

    # Exactly the two HDI bound columns, and nothing leaked from coordinates
    assert result.shape[1] == 2, f"Expected 2 columns, got {result.shape[1]}"

    bounds = ("lower", "higher")

    for col in bounds:
        assert col in result.columns, f"Should have '{col}' column"

    # CRITICAL: the bound columns must hold floats, not strings
    for col in bounds:
        assert pd.api.types.is_float_dtype(result[col]), (
            f"'{col}' column should be numeric, got {result[col].dtype}"
        )

    # Belt and braces: the coordinate value itself must not appear anywhere
    for col in bounds:
        assert not result[col].astype(str).str.contains("treated_agg").any(), (
            f"'{col}' column should not contain coordinate string values"
        )

    # Verify HDI ordering
    assert (result["lower"] <= result["higher"]).all(), (
        "'lower' should be <= 'higher' for all rows"
    )

    # Verify reasonable HDI values (should be around the mean of 5.0)
    assert result["lower"].min() > 3.0, "HDI lower bounds should be reasonable"
    assert result["higher"].max() < 7.0, "HDI upper bounds should be reasonable"

0 commit comments

Comments
 (0)