|
1 | | -# Copyright 2024 The PyMC Labs Developers |
| 1 | +# Copyright 2025 The PyMC Labs Developers |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
25 | 25 | from sklearn.base import RegressorMixin |
26 | 26 |
|
27 | 27 | from causalpy.custom_exceptions import BadIndexException |
28 | | -from causalpy.plot_utils import plot_xY, get_hdi_to_df |
| 28 | +from causalpy.plot_utils import get_hdi_to_df, plot_xY |
29 | 29 | from causalpy.pymc_models import PyMCModel |
30 | 30 | from causalpy.utils import round_num |
31 | 31 |
|
@@ -321,13 +321,21 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame: |
321 | 321 | .mean("sample") |
322 | 322 | .values |
323 | 323 | ) |
324 | | - pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob).set_index(pre_data.index) |
325 | | - post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df(self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob).set_index(post_data.index) |
| 324 | + pre_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df( |
| 325 | + self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob |
| 326 | + ).set_index(pre_data.index) |
| 327 | + post_data[["pred_hdi_lower", "pred_hdi_upper"]] = get_hdi_to_df( |
| 328 | + self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob |
| 329 | + ).set_index(post_data.index) |
326 | 330 |
|
327 | 331 | pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values |
328 | 332 | post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values |
329 | | - pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob).set_index(pre_data.index) |
330 | | - post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob).set_index(post_data.index) |
| 333 | + pre_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df( |
| 334 | + self.pre_impact, hdi_prob=hdi_prob |
| 335 | + ).set_index(pre_data.index) |
| 336 | + post_data[["impact_hdi_lower", "impact_hdi_upper"]] = get_hdi_to_df( |
| 337 | + self.post_impact, hdi_prob=hdi_prob |
| 338 | + ).set_index(post_data.index) |
331 | 339 |
|
332 | 340 | self.plot_data = pd.concat([pre_data, post_data]) |
333 | 341 |
|
|
0 commit comments