|
| 1 | +# Copyright 2025 - 2025 The PyMC Labs Developers |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +""" |
| 15 | +Tests for plot utility functions |
| 16 | +""" |
| 17 | + |
| 18 | +import numpy as np |
| 19 | +import pandas as pd |
| 20 | +import pytest |
| 21 | +import xarray as xr |
| 22 | + |
| 23 | +from causalpy.plot_utils import get_hdi_to_df |
| 24 | + |
| 25 | + |
| 26 | +@pytest.mark.integration |
| 27 | +def test_get_hdi_to_df_with_coordinate_dimensions(): |
| 28 | + """ |
| 29 | + Regression test for bug where get_hdi_to_df returned string coordinate values |
| 30 | + instead of numeric HDI values when xarray had named coordinate dimensions. |
| 31 | +
|
| 32 | + This bug manifested in multi-cell synthetic control experiments where columns |
| 33 | + like 'pred_hdi_upper_94' contained the string "treated_agg" instead of |
| 34 | + numeric upper bound values. |
| 35 | +
|
| 36 | + See: https://github.com/pymc-labs/CausalPy/issues/532 |
| 37 | + """ |
| 38 | + # Create a mock xarray DataArray similar to what's produced in synthetic control |
| 39 | + # with a coordinate dimension like 'treated_units' |
| 40 | + np.random.seed(42) |
| 41 | + n_chains = 2 |
| 42 | + n_draws = 100 |
| 43 | + n_obs = 10 |
| 44 | + |
| 45 | + # Simulate posterior samples with a named coordinate |
| 46 | + data = np.random.normal(loc=5.0, scale=0.5, size=(n_chains, n_draws, n_obs)) |
| 47 | + |
| 48 | + xr_data = xr.DataArray( |
| 49 | + data, |
| 50 | + dims=["chain", "draw", "obs_ind"], |
| 51 | + coords={ |
| 52 | + "chain": np.arange(n_chains), |
| 53 | + "draw": np.arange(n_draws), |
| 54 | + "obs_ind": np.arange(n_obs), |
| 55 | + "treated_units": "treated_agg", # This coordinate caused the bug |
| 56 | + }, |
| 57 | + ) |
| 58 | + |
| 59 | + # Call get_hdi_to_df |
| 60 | + result = get_hdi_to_df(xr_data, hdi_prob=0.94) |
| 61 | + |
| 62 | + # Assertions to verify the bug is fixed |
| 63 | + assert isinstance(result, pd.DataFrame), "Result should be a DataFrame" |
| 64 | + |
| 65 | + # Check that we have exactly 2 columns (lower and higher) |
| 66 | + assert result.shape[1] == 2, f"Expected 2 columns, got {result.shape[1]}" |
| 67 | + |
| 68 | + # Check column names |
| 69 | + assert "lower" in result.columns, "Should have 'lower' column" |
| 70 | + assert "higher" in result.columns, "Should have 'higher' column" |
| 71 | + |
| 72 | + # CRITICAL: Check that columns contain numeric data, not strings |
| 73 | + assert result["lower"].dtype in [ |
| 74 | + np.float64, |
| 75 | + np.float32, |
| 76 | + ], f"'lower' column should be numeric, got {result['lower'].dtype}" |
| 77 | + assert result["higher"].dtype in [ |
| 78 | + np.float64, |
| 79 | + np.float32, |
| 80 | + ], f"'higher' column should be numeric, got {result['higher'].dtype}" |
| 81 | + |
| 82 | + # Check that no string values like 'treated_agg' appear in the data |
| 83 | + assert not (result["lower"].astype(str).str.contains("treated_agg").any()), ( |
| 84 | + "'lower' column should not contain coordinate string values" |
| 85 | + ) |
| 86 | + assert not (result["higher"].astype(str).str.contains("treated_agg").any()), ( |
| 87 | + "'higher' column should not contain coordinate string values" |
| 88 | + ) |
| 89 | + |
| 90 | + # Verify HDI ordering |
| 91 | + assert (result["lower"] <= result["higher"]).all(), ( |
| 92 | + "'lower' should be <= 'higher' for all rows" |
| 93 | + ) |
| 94 | + |
| 95 | + # Verify reasonable HDI values (should be around the mean of 5.0) |
| 96 | + assert result["lower"].min() > 3.0, "HDI lower bounds should be reasonable" |
| 97 | + assert result["higher"].max() < 7.0, "HDI upper bounds should be reasonable" |
0 commit comments