From b2b881e9fe32c7af7da638c6c514a913c66f498d Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Wed, 12 Nov 2025 16:00:04 -0500 Subject: [PATCH 01/29] completed phase 1 --- pymc_marketing/mmm/__init__.py | 2 + pymc_marketing/mmm/config.py | 66 +++++ tests/mmm/test_plot.py | 3 + tests/mmm/test_plot_backends.py | 443 ++++++++++++++++++++++++++++++++ 4 files changed, 514 insertions(+) create mode 100644 pymc_marketing/mmm/config.py create mode 100644 tests/mmm/test_plot_backends.py diff --git a/pymc_marketing/mmm/__init__.py b/pymc_marketing/mmm/__init__.py index 3d28ce23a..05b26292a 100644 --- a/pymc_marketing/mmm/__init__.py +++ b/pymc_marketing/mmm/__init__.py @@ -37,6 +37,7 @@ TanhSaturationBaselined, saturation_from_dict, ) +from pymc_marketing.mmm.config import mmm_config from pymc_marketing.mmm.fourier import MonthlyFourier, WeeklyFourier, YearlyFourier from pymc_marketing.mmm.hsgp import ( HSGP, @@ -107,6 +108,7 @@ "create_eta_prior", "create_m_and_L_recommendations", "mmm", + "mmm_config", "preprocessing", "preprocessing_method_X", "preprocessing_method_y", diff --git a/pymc_marketing/mmm/config.py b/pymc_marketing/mmm/config.py new file mode 100644 index 000000000..5b7051822 --- /dev/null +++ b/pymc_marketing/mmm/config.py @@ -0,0 +1,66 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Configuration management for MMM plotting.""" + +import warnings + +VALID_BACKENDS = {"matplotlib", "plotly", "bokeh"} + + +class MMMConfig(dict): + """ + Configuration dictionary for MMM plotting settings. + + Provides backend configuration with validation and reset functionality. + Modeled after ArviZ's rcParams pattern. + + Examples + -------- + >>> from pymc_marketing.mmm import mmm_config + >>> mmm_config["plot.backend"] = "plotly" + >>> mmm_config["plot.backend"] + 'plotly' + >>> mmm_config.reset() + >>> mmm_config["plot.backend"] + 'matplotlib' + """ + + _defaults = { + "plot.backend": "matplotlib", + "plot.show_warnings": True, + } + + def __init__(self): + super().__init__(self._defaults) + + def __setitem__(self, key, value): + """Set config value with validation for backend.""" + if key == "plot.backend": + if value not in VALID_BACKENDS: + warnings.warn( + f"Invalid backend '{value}'. Valid backends are: {VALID_BACKENDS}. " + f"Setting anyway, but plotting may fail.", + UserWarning, + stacklevel=2, + ) + super().__setitem__(key, value) + + def reset(self): + """Reset all configuration to default values.""" + self.clear() + self.update(self._defaults) + + +# Global config instance +mmm_config = MMMConfig() diff --git a/tests/mmm/test_plot.py b/tests/mmm/test_plot.py index ea41b44ce..88922348c 100644 --- a/tests/mmm/test_plot.py +++ b/tests/mmm/test_plot.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# NOTE: This file may be consolidated with test_plot_backends.py in the future +# once the backend migration is complete and stable. import warnings import arviz as az diff --git a/tests/mmm/test_plot_backends.py b/tests/mmm/test_plot_backends.py new file mode 100644 index 000000000..bce16e6de --- /dev/null +++ b/tests/mmm/test_plot_backends.py @@ -0,0 +1,443 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Backend-agnostic plotting tests for MMMPlotSuite. + +This test file validates the migration to ArviZ PlotCollection API for +multi-backend support (matplotlib, plotly, bokeh). + +NOTE: Once this migration is complete and stable, evaluate whether +tests/mmm/test_plot.py can be consolidated into this file to avoid duplication. +""" + +import arviz as az +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from pymc_marketing.mmm.plot import MMMPlotSuite + + +@pytest.fixture(scope="module") +def mock_idata_for_pp(): + """ + Create mock InferenceData with posterior_predictive for testing. + + Structure mirrors real MMM output with: + - posterior_predictive group with y variable + - proper dimensions: chain, draw, date + - realistic date range + """ + seed = sum(map(ord, "Backend test posterior_predictive")) + rng = np.random.default_rng(seed) + + dates = pd.date_range("2025-01-01", periods=52, freq="W-MON") + + # Create posterior_predictive data + posterior_predictive = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(loc=100, scale=10, size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ) + } + ) + + # Also create a minimal posterior (required for some internal logic) + posterior = xr.Dataset( + { + "intercept": xr.DataArray( + rng.normal(size=(4, 100)), + dims=("chain", "draw"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + }, + ) + } + ) + + return az.InferenceData( + posterior=posterior, posterior_predictive=posterior_predictive + ) + + +@pytest.fixture(scope="module") +def mock_suite_with_pp(mock_idata_for_pp): + """ + Fixture providing MMMPlotSuite with posterior_predictive data. + + Used for testing posterior_predictive() method across backends. + """ + return MMMPlotSuite(idata=mock_idata_for_pp) + + +@pytest.fixture(scope="function") +def reset_mmm_config(): + """ + Fixture to reset mmm_config after each test. + + Ensures test isolation - one test's backend changes don't affect others. + """ + from pymc_marketing.mmm import mmm_config + + original = mmm_config["plot.backend"] + yield + mmm_config["plot.backend"] = original + + +# ============================================================================= +# Infrastructure Tests (Global Configuration & Return Types) +# ============================================================================= + + +def test_mmm_config_exists(): + """ + Test that the global mmm_config object exists and is accessible. + + This test verifies: + - mmm_config can be imported from pymc_marketing.mmm + - It has a "plot.backend" key + - Default backend is "matplotlib" + """ + from pymc_marketing.mmm import mmm_config + + assert "plot.backend" in mmm_config, "mmm_config should have 'plot.backend' key" + assert mmm_config["plot.backend"] == "matplotlib", ( + f"Default backend should be 'matplotlib', got {mmm_config['plot.backend']}" + ) + + +def test_mmm_config_backend_setting(): + """ + Test that mmm_config backend can be set and retrieved. + + This test verifies: + - Backend can be changed from default + - New value persists + - Can be reset to default + """ + from pymc_marketing.mmm import mmm_config + + # Store original + original = mmm_config["plot.backend"] + + try: + # Change backend + mmm_config["plot.backend"] = "plotly" + assert mmm_config["plot.backend"] == "plotly", ( + "Backend should change to 'plotly'" + ) + + # Reset + mmm_config.reset() + assert mmm_config["plot.backend"] == "matplotlib", ( + "reset() should restore default 'matplotlib' backend" + ) + finally: + # Cleanup + mmm_config["plot.backend"] = original + + +def test_mmm_config_invalid_backend_warning(): + """ + Test that setting an invalid backend name is handled gracefully. + + This test verifies: + - Invalid backend names are detected + - Either raises ValueError or emits UserWarning + - Helpful error message provided + """ + import warnings + + from pymc_marketing.mmm import mmm_config + + original = mmm_config["plot.backend"] + + try: + # Attempt to set invalid backend - should either raise or warn + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + mmm_config["plot.backend"] = "invalid_backend" + + # If no exception, should have warning + assert len(w) > 0, "Should emit warning for invalid backend" + assert "invalid" in str(w[0].message).lower(), ( + f"Warning should mention 'invalid', got: {w[0].message}" + ) + except ValueError as e: + # Acceptable alternative: raise ValueError + assert "backend" in str(e).lower(), f"Error should mention 'backend', got: {e}" + finally: + mmm_config["plot.backend"] = original + + +# ============================================================================= +# Backend Parameter Tests (posterior_predictive) +# ============================================================================= + + +def test_posterior_predictive_accepts_backend_parameter(mock_suite_with_pp): + """ + Test that posterior_predictive() accepts backend parameter. + + This test verifies: + - backend parameter is accepted + - No TypeError is raised + - Method completes successfully + """ + # Should not raise TypeError + result = mock_suite_with_pp.posterior_predictive(backend="matplotlib") + + assert result is not None, "posterior_predictive should return a result" + + +def test_posterior_predictive_accepts_return_as_pc_parameter(mock_suite_with_pp): + """ + Test that posterior_predictive() accepts return_as_pc parameter. + + This test verifies: + - return_as_pc parameter is accepted + - No TypeError is raised + """ + # Should not raise TypeError + result = mock_suite_with_pp.posterior_predictive(return_as_pc=False) + + assert result is not None, "posterior_predictive should return a result" + + +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_posterior_predictive_backend_overrides_global(mock_suite_with_pp, backend): + """ + Test that backend parameter overrides global mmm_config setting. + + This test verifies: + - Global config set to one backend + - Function called with different backend + - Function uses parameter, not global config + """ + from pymc_marketing.mmm import mmm_config + + original = mmm_config["plot.backend"] + + try: + # Set global to matplotlib + mmm_config["plot.backend"] = "matplotlib" + + # Call with different backend, request PlotCollection to check + pc = mock_suite_with_pp.posterior_predictive(backend=backend, return_as_pc=True) + + assert hasattr(pc, "backend"), "PlotCollection should have backend attribute" + assert pc.backend == backend, ( + f"PlotCollection backend should be '{backend}', got '{pc.backend}'" + ) + finally: + mmm_config["plot.backend"] = original + + +# ============================================================================= +# Return Type Tests (Backward Compatibility) +# ============================================================================= + + +def test_posterior_predictive_returns_tuple_by_default(mock_suite_with_pp): + """ + Test that posterior_predictive() returns tuple by default (backward compat). + + This test verifies: + - Default behavior (no return_as_pc parameter) returns tuple + - Tuple has two elements: (figure, axes) + - axes is a list of matplotlib Axes objects (1D list, not 2D array) + """ + result = mock_suite_with_pp.posterior_predictive() + + assert isinstance(result, tuple), ( + f"Default return should be tuple, got {type(result)}" + ) + assert len(result) == 2, ( + f"Tuple should have 2 elements (fig, axes), got {len(result)}" + ) + + fig, axes = result + + # For matplotlib backend (default), should be Figure and array + assert isinstance(fig, Figure), f"First element should be Figure, got {type(fig)}" + # Note: Current implementation returns NDArray[Axes], need to adapt test + assert axes is not None, "Second element should not be None for matplotlib backend" + + +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_posterior_predictive_returns_plotcollection_when_requested( + mock_suite_with_pp, backend +): + """ + Test that posterior_predictive() returns PlotCollection when return_as_pc=True. + + This test verifies: + - return_as_pc=True returns PlotCollection object + - PlotCollection has correct backend attribute + """ + from arviz_plots import PlotCollection + + result = mock_suite_with_pp.posterior_predictive(backend=backend, return_as_pc=True) + + assert isinstance(result, PlotCollection), ( + f"Should return PlotCollection, got {type(result)}" + ) + assert hasattr(result, "backend"), "PlotCollection should have backend attribute" + assert result.backend == backend, ( + f"Backend should be '{backend}', got '{result.backend}'" + ) + + +def test_posterior_predictive_tuple_has_correct_axes_for_matplotlib(mock_suite_with_pp): + """ + Test that matplotlib backend returns proper axes list in tuple. + + This test verifies: + - When return_as_pc=False and backend="matplotlib" + - Second tuple element is list/array of matplotlib Axes + - All elements in list are Axes instances + """ + _fig, axes = mock_suite_with_pp.posterior_predictive( + backend="matplotlib", return_as_pc=False + ) + + assert axes is not None, "Axes should not be None for matplotlib backend" + # Handle both list and NDArray cases + axes_flat = axes if isinstance(axes, list) else axes.flat + assert all(isinstance(ax, Axes) for ax in axes_flat), ( + "All elements should be matplotlib Axes instances" + ) + + +@pytest.mark.parametrize("backend", ["plotly", "bokeh"]) +def test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib( + mock_suite_with_pp, backend +): + """ + Test that non-matplotlib backends return None for axes in tuple. + + This test verifies: + - When return_as_pc=False and backend in ["plotly", "bokeh"] + - Second tuple element is None (no axes concept) + - First element is backend-specific figure object + """ + fig, axes = mock_suite_with_pp.posterior_predictive( + backend=backend, return_as_pc=False + ) + + assert axes is None, f"Axes should be None for {backend} backend, got {type(axes)}" + assert fig is not None, f"Figure should exist for {backend} backend" + + +# ============================================================================= +# Visual Output Validation Tests +# ============================================================================= + + +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_posterior_predictive_plotcollection_has_viz_attribute( + mock_suite_with_pp, backend +): + """ + Test that PlotCollection has viz attribute with figure data. + + This test verifies: + - PlotCollection has viz attribute + - viz has figure attribute + - Figure can be extracted + """ + + pc = mock_suite_with_pp.posterior_predictive(backend=backend, return_as_pc=True) + + assert hasattr(pc, "viz"), "PlotCollection should have 'viz' attribute" + assert hasattr(pc.viz, "figure"), ( + "PlotCollection.viz should have 'figure' attribute" + ) + + # Should be able to extract figure + fig = pc.viz.figure.data.item() + assert fig is not None, "Should be able to extract figure from PlotCollection" + + +def test_posterior_predictive_matplotlib_has_lines(mock_suite_with_pp): + """ + Test that matplotlib output contains actual plotted lines. + + This test verifies: + - Axes contain Line2D objects (plotted data) + - Number of lines matches expected variables + - Visual output actually created, not just empty axes + """ + from matplotlib.lines import Line2D + + _fig, axes = mock_suite_with_pp.posterior_predictive( + backend="matplotlib", return_as_pc=False + ) + + # Get first axis (should have plots) + ax = axes.flat[0] + + # Should have lines (median plots) + lines = [child for child in ax.get_children() if isinstance(child, Line2D)] + assert len(lines) > 0, ( + f"Axes should contain Line2D objects (plots), found {len(lines)}" + ) + + +def test_posterior_predictive_plotly_has_traces(mock_suite_with_pp): + """ + Test that plotly output contains actual traces. + + This test verifies: + - Plotly figure has 'data' attribute with traces + - Number of traces > 0 (something was plotted) + - Visual output actually created + """ + fig, _ = mock_suite_with_pp.posterior_predictive( + backend="plotly", return_as_pc=False + ) + + # Plotly figures have .data attribute with traces + assert hasattr(fig, "data"), "Plotly figure should have 'data' attribute" + assert len(fig.data) > 0, f"Plotly figure should have traces, found {len(fig.data)}" + + +def test_posterior_predictive_bokeh_has_renderers(mock_suite_with_pp): + """ + Test that bokeh output contains actual renderers (plot elements). + + This test verifies: + - Bokeh figure has renderers + - Number of renderers > 0 (something was plotted) + - Visual output actually created + """ + fig, _ = mock_suite_with_pp.posterior_predictive( + backend="bokeh", return_as_pc=False + ) + + # Bokeh figures have .renderers attribute + assert hasattr(fig, "renderers"), "Bokeh figure should have 'renderers' attribute" + assert len(fig.renderers) > 0, ( + f"Bokeh figure should have renderers, found {len(fig.renderers)}" + ) From 6dbd6ce4113f9f9f2c1569dab44920fc0ec807ed Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Wed, 12 Nov 2025 16:14:06 -0500 Subject: [PATCH 02/29] adds plan --- .../mmmplotsuite-backend-migration-tdd.md | 1795 +++++++++++++++++ 1 file changed, 1795 insertions(+) create mode 100644 thoughts/shared/plans/mmmplotsuite-backend-migration-tdd.md diff --git a/thoughts/shared/plans/mmmplotsuite-backend-migration-tdd.md b/thoughts/shared/plans/mmmplotsuite-backend-migration-tdd.md new file mode 100644 index 000000000..b07a2ad9d --- /dev/null +++ b/thoughts/shared/plans/mmmplotsuite-backend-migration-tdd.md @@ -0,0 +1,1795 @@ +# MMMPlotSuite Backend Migration - TDD Implementation Plan + +## Overview + +This plan implements backend-agnostic plotting for the MMMPlotSuite class using ArviZ's PlotCollection API, enabling support for matplotlib, plotly, and bokeh backends while maintaining full backward compatibility. We follow Test-Driven Development: write comprehensive tests first, verify they fail properly, then implement features by making those tests pass. + +## Current State Analysis + +### Existing Implementation +- **Location**: [pymc_marketing/mmm/plot.py:187-1924](pymc_marketing/mmm/plot.py#L187) +- **Class**: `MMMPlotSuite` with 10 public plotting methods +- **Current approach**: All methods directly use matplotlib APIs and return `(Figure, NDArray[Axes])` +- **Dependencies**: matplotlib, arviz (for HDI computation only) + +### Current Testing Landscape +- **Test framework**: pytest with parametrized tests +- **Test file**: [tests/mmm/test_plot.py](tests/mmm/test_plot.py) - 1053 lines, comprehensive fixture-based testing +- **Mock data patterns**: xarray-based InferenceData fixtures with realistic structure +- **Test conventions**: + - Module-scoped fixtures for expensive setup + - Type assertions only (no visual output validation) + - `plt.close()` after each test + - Parametrized tests for multiple configurations + +### Key Discoveries +1. **No PlotCollection usage**: ArviZ PlotCollection is not used anywhere in production code +2. **Testing patterns exist**: Parametrized tests, deprecation warnings, backward compatibility tests all have examples +3. **Mock data is realistic**: Fixtures create proper InferenceData structure with posterior, constant_data groups +4. **Helper functions available**: `_init_subplots()`, `_add_median_and_hdi()` need backend abstraction + +## Desired End State + +After implementation, the MMMPlotSuite should: + +1. ✅ Support matplotlib, plotly, and bokeh backends via ArviZ PlotCollection +2. ✅ Maintain 100% backward compatibility (existing code works unchanged) +3. ✅ Support global backend configuration via `mmm_config["plot.backend"]` +4. ✅ Support per-function backend parameter that overrides global config +5. ✅ Return PlotCollection when `return_as_pc=True`, tuple when `False` (default) +6. ✅ Handle matplotlib-specific features (twinx) with clear fallback warnings +7. ✅ Deprecate `rc_params` in favor of `backend_config` with warnings +8. ✅ Pass comprehensive test suite across all three backends + +## What We're NOT Testing/Implementing + +- Performance comparisons between backends (explicitly out of scope) +- Component plot methods outside MMMPlotSuite (requirement #9) +- Saving plots to files (not in current test suite) +- Interactive features specific to plotly/bokeh (basic rendering only) +- New plotting methods (only migrating existing 10 methods) + +## TDD Approach + +### Test Design Philosophy +1. **Depth over breadth**: Thoroughly test first 2-3 methods before moving to others +2. **Verify visual output**: Use PlotCollection's backend-specific output validation, not just type checking +3. **Fail diagnostically**: Tests should fail with clear messages pointing to missing functionality +4. **Test data isolation**: Use module-scoped fixtures, mock InferenceData structures + +### Implementation Priority +**Phase 1**: Infrastructure + `posterior_predictive()` (simplest method) +**Phase 2**: `contributions_over_time()` (similar to Phase 1) +**Phase 3**: `saturation_curves()` (rc_params deprecation, external functions) +**Phase 4**: `budget_allocation()` (twinx fallback behavior) + +--- + +## Phase 1: Test Design & Implementation + +### Overview +Write comprehensive, informative tests that define the feature completely. These tests should fail in expected, diagnostic ways. We focus deeply on infrastructure and the simplest method (`posterior_predictive()`) first. + +### Test Categories + +#### 1. Infrastructure Tests (Global Configuration & Return Types) +**Test File**: `tests/mmm/test_plot_backends.py` (NEW) +**Purpose**: Validate backend configuration system and return type switching + +**Test Cases to Write:** + +##### Test: `test_mmm_config_exists` +**Purpose**: Verify the global configuration object is accessible +**Test Data**: None needed +**Expected Behavior**: Can import and access `mmm_config` from `pymc_marketing.mmm` + +```python +def test_mmm_config_exists(): + """ + Test that the global mmm_config object exists and is accessible. + + This test verifies: + - mmm_config can be imported from pymc_marketing.mmm + - It has a "plot.backend" key + - Default backend is "matplotlib" + """ + from pymc_marketing.mmm import mmm_config + + assert "plot.backend" in mmm_config, \ + "mmm_config should have 'plot.backend' key" + assert mmm_config["plot.backend"] == "matplotlib", \ + f"Default backend should be 'matplotlib', got {mmm_config['plot.backend']}" +``` + +**Expected Failure Mode**: +- Error type: `ImportError` or `AttributeError` +- Expected message: `cannot import name 'mmm_config' from 'pymc_marketing.mmm'` + +##### Test: `test_mmm_config_backend_setting` +**Purpose**: Verify global backend can be changed and persists +**Test Data**: None needed +**Expected Behavior**: Setting backend value works and can be read back + +```python +def test_mmm_config_backend_setting(): + """ + Test that mmm_config backend can be set and retrieved. + + This test verifies: + - Backend can be changed from default + - New value persists + - Can be reset to default + """ + from pymc_marketing.mmm import mmm_config + + # Store original + original = mmm_config["plot.backend"] + + try: + # Change backend + mmm_config["plot.backend"] = "plotly" + assert mmm_config["plot.backend"] == "plotly", \ + "Backend should change to 'plotly'" + + # Reset + mmm_config.reset() + assert mmm_config["plot.backend"] == "matplotlib", \ + "reset() should restore default 'matplotlib' backend" + finally: + # Cleanup + mmm_config["plot.backend"] = original +``` + +**Expected Failure Mode**: +- Error type: `AttributeError` on `mmm_config.reset()` +- Expected message: `'dict' object has no attribute 'reset'` (if mmm_config is plain dict) + +##### Test: `test_mmm_config_invalid_backend_warning` +**Purpose**: Verify setting invalid backend emits a warning or raises error +**Test Data**: Invalid backend name "invalid_backend" +**Expected Behavior**: Validation prevents or warns about invalid backend + +```python +def test_mmm_config_invalid_backend_warning(): + """ + Test that setting an invalid backend name is handled gracefully. + + This test verifies: + - Invalid backend names are detected + - Either raises ValueError or emits UserWarning + - Helpful error message provided + """ + from pymc_marketing.mmm import mmm_config + import warnings + + original = mmm_config["plot.backend"] + + try: + # Attempt to set invalid backend - should either raise or warn + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + mmm_config["plot.backend"] = "invalid_backend" + + # If no exception, should have warning + assert len(w) > 0, \ + "Should emit warning for invalid backend" + assert "invalid" in str(w[0].message).lower(), \ + f"Warning should mention 'invalid', got: {w[0].message}" + except ValueError as e: + # Acceptable alternative: raise ValueError + assert "backend" in str(e).lower(), \ + f"Error should mention 'backend', got: {e}" + finally: + mmm_config["plot.backend"] = original +``` + +**Expected Failure Mode**: +- Error type: `AssertionError` +- Expected message: "Should emit warning for invalid backend" (no validation present) + +#### 2. Backend Parameter Tests (posterior_predictive) +**Test File**: `tests/mmm/test_plot_backends.py` +**Purpose**: Validate `backend` parameter is accepted and overrides global config + +**Test Cases to Write:** + +##### Test: `test_posterior_predictive_accepts_backend_parameter` +**Purpose**: Verify method accepts new `backend` parameter without error +**Test Data**: `mock_suite` fixture with posterior_predictive data +**Expected Behavior**: Method accepts backend="matplotlib" without TypeError + +```python +def test_posterior_predictive_accepts_backend_parameter(mock_suite_with_pp): + """ + Test that posterior_predictive() accepts backend parameter. + + This test verifies: + - backend parameter is accepted + - No TypeError is raised + - Method completes successfully + """ + # Should not raise TypeError + result = mock_suite_with_pp.posterior_predictive(backend="matplotlib") + + assert result is not None, \ + "posterior_predictive should return a result" +``` + +**Expected Failure Mode**: +- Error type: `TypeError` +- Expected message: `posterior_predictive() got an unexpected keyword argument 'backend'` + +##### Test: `test_posterior_predictive_accepts_return_as_pc_parameter` +**Purpose**: Verify method accepts new `return_as_pc` parameter without error +**Test Data**: `mock_suite_with_pp` fixture +**Expected Behavior**: Method accepts return_as_pc=False without TypeError + +```python +def test_posterior_predictive_accepts_return_as_pc_parameter(mock_suite_with_pp): + """ + Test that posterior_predictive() accepts return_as_pc parameter. + + This test verifies: + - return_as_pc parameter is accepted + - No TypeError is raised + """ + # Should not raise TypeError + result = mock_suite_with_pp.posterior_predictive(return_as_pc=False) + + assert result is not None, \ + "posterior_predictive should return a result" +``` + +**Expected Failure Mode**: +- Error type: `TypeError` +- Expected message: `posterior_predictive() got an unexpected keyword argument 'return_as_pc'` + +##### Test: `test_posterior_predictive_backend_overrides_global` +**Purpose**: Verify function parameter overrides global config +**Test Data**: `mock_suite_with_pp` fixture +**Expected Behavior**: backend="plotly" overrides global matplotlib setting + +```python +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_posterior_predictive_backend_overrides_global(mock_suite_with_pp, backend): + """ + Test that backend parameter overrides global mmm_config setting. + + This test verifies: + - Global config set to one backend + - Function called with different backend + - Function uses parameter, not global config + """ + from pymc_marketing.mmm import mmm_config + + original = mmm_config["plot.backend"] + + try: + # Set global to matplotlib + mmm_config["plot.backend"] = "matplotlib" + + # Call with different backend, request PlotCollection to check + pc = mock_suite_with_pp.posterior_predictive( + backend=backend, + return_as_pc=True + ) + + assert hasattr(pc, 'backend'), \ + "PlotCollection should have backend attribute" + assert pc.backend == backend, \ + f"PlotCollection backend should be '{backend}', got '{pc.backend}'" + finally: + mmm_config["plot.backend"] = original +``` + +**Expected Failure Mode**: +- Error type: `AttributeError` or `AssertionError` +- Expected message: `'tuple' object has no attribute 'backend'` (returns tuple instead of PlotCollection) + +#### 3. Return Type Tests (Backward Compatibility) +**Test File**: `tests/mmm/test_plot_backends.py` +**Purpose**: Verify return types match expectations based on `return_as_pc` parameter + +**Test Cases to Write:** + +##### Test: `test_posterior_predictive_returns_tuple_by_default` +**Purpose**: Verify backward compatibility - default returns tuple +**Test Data**: `mock_suite_with_pp` fixture +**Expected Behavior**: Returns `(Figure, List[Axes])` tuple by default + +```python +def test_posterior_predictive_returns_tuple_by_default(mock_suite_with_pp): + """ + Test that posterior_predictive() returns tuple by default (backward compat). + + This test verifies: + - Default behavior (no return_as_pc parameter) returns tuple + - Tuple has two elements: (figure, axes) + - axes is a list of matplotlib Axes objects (1D list, not 2D array) + """ + result = mock_suite_with_pp.posterior_predictive() + + assert isinstance(result, tuple), \ + f"Default return should be tuple, got {type(result)}" + assert len(result) == 2, \ + f"Tuple should have 2 elements (fig, axes), got {len(result)}" + + fig, axes = result + + # For matplotlib backend (default), should be Figure and list + from matplotlib.figure import Figure + from matplotlib.axes import Axes + assert isinstance(fig, Figure), \ + f"First element should be Figure, got {type(fig)}" + assert isinstance(axes, list), \ + f"Second element should be list, got {type(axes)}" + assert all(isinstance(ax, Axes) for ax in axes), \ + "All list elements should be matplotlib Axes instances" +``` + +**Expected Failure Mode**: +- Error type: `AssertionError` or `AttributeError` +- Expected message: `Default return should be tuple, got ` (if returning PC) + +##### Test: `test_posterior_predictive_returns_plotcollection_when_requested` +**Purpose**: Verify new behavior - returns PlotCollection when return_as_pc=True +**Test Data**: `mock_suite_with_pp` fixture +**Expected Behavior**: Returns PlotCollection object when return_as_pc=True + +```python +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_posterior_predictive_returns_plotcollection_when_requested( + mock_suite_with_pp, backend +): + """ + Test that posterior_predictive() returns PlotCollection when return_as_pc=True. + + This test verifies: + - return_as_pc=True returns PlotCollection object + - PlotCollection has correct backend attribute + """ + from arviz_plots import PlotCollection + + result = mock_suite_with_pp.posterior_predictive( + backend=backend, + return_as_pc=True + ) + + assert isinstance(result, PlotCollection), \ + f"Should return PlotCollection, got {type(result)}" + assert hasattr(result, 'backend'), \ + "PlotCollection should have backend attribute" + assert result.backend == backend, \ + f"Backend should be '{backend}', got '{result.backend}'" +``` + +**Expected Failure Mode**: +- Error type: `AssertionError` +- Expected message: `Should return PlotCollection, got ` (still returns tuple) + +##### Test: `test_posterior_predictive_tuple_has_correct_axes_for_matplotlib` +**Purpose**: Verify matplotlib backend returns list of Axes in tuple +**Test Data**: `mock_suite_with_pp` fixture +**Expected Behavior**: Tuple's second element is list of Axes objects + +```python +def test_posterior_predictive_tuple_has_correct_axes_for_matplotlib(mock_suite_with_pp): + """ + Test that matplotlib backend returns proper axes list in tuple. + + This test verifies: + - When return_as_pc=False and backend="matplotlib" + - Second tuple element is list of matplotlib Axes + - All elements in list are Axes instances + """ + from matplotlib.axes import Axes + + fig, axes = mock_suite_with_pp.posterior_predictive( + backend="matplotlib", + return_as_pc=False + ) + + assert isinstance(axes, list), \ + f"Axes should be list for matplotlib, got {type(axes)}" + assert all(isinstance(ax, Axes) for ax in axes), \ + "All list elements should be matplotlib Axes instances" +``` + +**Expected Failure Mode**: +- Error type: `AssertionError` +- Expected message: `Axes should be list for matplotlib, got ` (if not extracting axes) + +##### Test: `test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib` +**Purpose**: Verify non-matplotlib backends return None for axes in tuple +**Test Data**: `mock_suite_with_pp` fixture +**Expected Behavior**: Tuple's second element is None for plotly/bokeh + +```python +@pytest.mark.parametrize("backend", ["plotly", "bokeh"]) +def test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib( + mock_suite_with_pp, backend +): + """ + Test that non-matplotlib backends return None for axes in tuple. + + This test verifies: + - When return_as_pc=False and backend in ["plotly", "bokeh"] + - Second tuple element is None (no axes concept) + - First element is backend-specific figure object + """ + fig, axes = mock_suite_with_pp.posterior_predictive( + backend=backend, + return_as_pc=False + ) + + assert axes is None, \ + f"Axes should be None for {backend} backend, got {type(axes)}" + assert fig is not None, \ + f"Figure should exist for {backend} backend" +``` + +**Expected Failure Mode**: +- Error type: `AssertionError` +- Expected message: `Axes should be None for plotly backend, got ` (always matplotlib) + +#### 4. Visual Output Validation Tests +**Test File**: `tests/mmm/test_plot_backends.py` +**Purpose**: Verify that plots actually render and contain expected elements + +**Test Cases to Write:** + +##### Test: `test_posterior_predictive_plotcollection_has_viz_attribute` +**Purpose**: Verify PlotCollection has visualization data we can inspect +**Test Data**: `mock_suite_with_pp` fixture +**Expected Behavior**: PlotCollection has `viz` attribute with figure data + +```python +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_posterior_predictive_plotcollection_has_viz_attribute( + mock_suite_with_pp, backend +): + """ + Test that PlotCollection has viz attribute with figure data. + + This test verifies: + - PlotCollection has viz attribute + - viz has figure attribute + - Figure can be extracted + """ + from arviz_plots import PlotCollection + + pc = mock_suite_with_pp.posterior_predictive( + backend=backend, + return_as_pc=True + ) + + assert hasattr(pc, 'viz'), \ + "PlotCollection should have 'viz' attribute" + assert hasattr(pc.viz, 'figure'), \ + "PlotCollection.viz should have 'figure' attribute" + + # Should be able to extract figure + fig = pc.viz.figure.data.item() + assert fig is not None, \ + "Should be able to extract figure from PlotCollection" +``` + +**Expected Failure Mode**: +- Error type: `AttributeError` +- Expected message: `PlotCollection should have 'viz' attribute` (if PC not properly constructed) + +##### Test: `test_posterior_predictive_matplotlib_has_lines` +**Purpose**: Verify matplotlib output contains actual plot elements +**Test Data**: `mock_suite_with_pp` fixture with known variables +**Expected Behavior**: Axes contain Line2D objects (the actual plotted data) + +```python +def test_posterior_predictive_matplotlib_has_lines(mock_suite_with_pp): + """ + Test that matplotlib output contains actual plotted lines. + + This test verifies: + - Axes contain Line2D objects (plotted data) + - Number of lines matches expected variables + - Visual output actually created, not just empty axes + """ + from matplotlib.lines import Line2D + + fig, axes = mock_suite_with_pp.posterior_predictive( + backend="matplotlib", + return_as_pc=False + ) + + # Get first axis (should have plots) + ax = axes.flat[0] + + # Should have lines (median plots) + lines = [child for child in ax.get_children() if isinstance(child, Line2D)] + assert len(lines) > 0, \ + f"Axes should contain Line2D objects (plots), found {len(lines)}" +``` + +**Expected Failure Mode**: +- Error type: `AssertionError` +- Expected message: `Axes should contain Line2D objects (plots), found 0` (empty plot) + +##### Test: `test_posterior_predictive_plotly_has_traces` +**Purpose**: Verify plotly output contains traces (plotly's plot elements) +**Test Data**: `mock_suite_with_pp` fixture +**Expected Behavior**: Plotly figure has traces in data attribute + +```python +def test_posterior_predictive_plotly_has_traces(mock_suite_with_pp): + """ + Test that plotly output contains actual traces. + + This test verifies: + - Plotly figure has 'data' attribute with traces + - Number of traces > 0 (something was plotted) + - Visual output actually created + """ + fig, _ = mock_suite_with_pp.posterior_predictive( + backend="plotly", + return_as_pc=False + ) + + # Plotly figures have .data attribute with traces + assert hasattr(fig, 'data'), \ + "Plotly figure should have 'data' attribute" + assert len(fig.data) > 0, \ + f"Plotly figure should have traces, found {len(fig.data)}" +``` + +**Expected Failure Mode**: +- Error type: `AttributeError` or `AssertionError` +- Expected message: `Plotly figure should have 'data' attribute` (matplotlib Figure returned instead) + +##### Test: `test_posterior_predictive_bokeh_has_renderers` +**Purpose**: Verify bokeh output contains renderers (bokeh's plot elements) +**Test Data**: `mock_suite_with_pp` fixture +**Expected Behavior**: Bokeh figure has renderers (glyphs) + +```python +def test_posterior_predictive_bokeh_has_renderers(mock_suite_with_pp): + """ + Test that bokeh output contains actual renderers (plot elements). + + This test verifies: + - Bokeh figure has renderers + - Number of renderers > 0 (something was plotted) + - Visual output actually created + """ + fig, _ = mock_suite_with_pp.posterior_predictive( + backend="bokeh", + return_as_pc=False + ) + + # Bokeh figures have .renderers attribute + assert hasattr(fig, 'renderers'), \ + "Bokeh figure should have 'renderers' attribute" + assert len(fig.renderers) > 0, \ + f"Bokeh figure should have renderers, found {len(fig.renderers)}" +``` + +**Expected Failure Mode**: +- Error type: `AttributeError` or `AssertionError` +- Expected message: `Bokeh figure should have 'renderers' attribute` (matplotlib Figure returned instead) + +#### 5. Fixture Setup +**Test File**: `tests/mmm/test_plot_backends.py` +**Purpose**: Create reusable fixtures for backend testing + +```python +""" +Backend-agnostic plotting tests for MMMPlotSuite. + +This test file validates the migration to ArviZ PlotCollection API for +multi-backend support (matplotlib, plotly, bokeh). + +NOTE: Once this migration is complete and stable, evaluate whether +tests/mmm/test_plot.py can be consolidated into this file to avoid duplication. +""" + +import numpy as np +import pytest +import xarray as xr +import arviz as az +import pandas as pd +from matplotlib.figure import Figure +from matplotlib.axes import Axes + +from pymc_marketing.mmm.plot import MMMPlotSuite + + +@pytest.fixture(scope="module") +def mock_idata_for_pp(): + """ + Create mock InferenceData with posterior_predictive for testing. + + Structure mirrors real MMM output with: + - posterior_predictive group with y variable + - proper dimensions: chain, draw, date + - realistic date range + """ + seed = sum(map(ord, "Backend test posterior_predictive")) + rng = np.random.default_rng(seed) + + dates = pd.date_range("2025-01-01", periods=52, freq="W-MON") + + # Create posterior_predictive data + posterior_predictive = xr.Dataset({ + "y": xr.DataArray( + rng.normal(loc=100, scale=10, size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ) + }) + + # Also create a minimal posterior (required for some internal logic) + posterior = xr.Dataset({ + "intercept": xr.DataArray( + rng.normal(size=(4, 100)), + dims=("chain", "draw"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + }, + ) + }) + + return az.InferenceData( + posterior=posterior, + posterior_predictive=posterior_predictive + ) + + +@pytest.fixture(scope="module") +def mock_suite_with_pp(mock_idata_for_pp): + """ + Fixture providing MMMPlotSuite with posterior_predictive data. + + Used for testing posterior_predictive() method across backends. + """ + return MMMPlotSuite(idata=mock_idata_for_pp) + + +@pytest.fixture(scope="function") +def reset_mmm_config(): + """ + Fixture to reset mmm_config after each test. + + Ensures test isolation - one test's backend changes don't affect others. + """ + from pymc_marketing.mmm import mmm_config + + original = mmm_config["plot.backend"] + yield + mmm_config["plot.backend"] = original +``` + +### Implementation Steps + +1. **Create test file**: `tests/mmm/test_plot_backends.py` +2. **Add note to existing test file**: Edit `tests/mmm/test_plot.py` line 1 to add: + ```python + # NOTE: This file may be consolidated with test_plot_backends.py in the future + # once the backend migration is complete and stable. + ``` + +3. **Implement fixtures** (see Fixture Setup section above) + +4. **Implement all test cases** in the order listed: + - Infrastructure tests (global config) + - Backend parameter tests + - Return type tests + - Visual output validation tests + +5. **Run tests to verify failures**: `pytest tests/mmm/test_plot_backends.py -v` + +### Success Criteria + +#### Automated Verification: +- [x] Test file created: `tests/mmm/test_plot_backends.py` +- [x] All tests discovered: `pytest tests/mmm/test_plot_backends.py --collect-only` +- [x] Tests fail (not pass): `pytest tests/mmm/test_plot_backends.py --tb=short` +- [x] No import/syntax errors: `pytest tests/mmm/test_plot_backends.py --tb=line` +- [x] Linting passes: `make lint` +- [x] Test code follows conventions: Style matches `test_plot.py` patterns + +#### Manual Verification: +- [ ] Each test has clear docstring explaining what it validates +- [ ] Test names clearly describe what they test (e.g., `test_X_does_Y`) +- [ ] Assertion messages are diagnostic and helpful +- [ ] Fixtures are well-documented with realistic data +- [ ] Test file header includes note about consolidation + +--- + +## Phase 2: Test Failure Verification + +### Overview +Run the tests and verify they fail in the expected, diagnostic ways. This ensures our tests are actually testing something and will catch regressions. + +### Verification Steps + +1. **Run the test suite**: + ```bash + pytest tests/mmm/test_plot_backends.py -v + ``` + +2. **Verify all tests are discovered**: + ```bash + pytest tests/mmm/test_plot_backends.py --collect-only + ``` + Expected: All tests listed, no collection errors + +3. **Check failure modes**: + ```bash + pytest tests/mmm/test_plot_backends.py -v --tb=short + ``` + Review each failure to ensure it matches expected failure mode + +### Expected Failures + +**Infrastructure Tests:** +- `test_mmm_config_exists`: `ImportError: cannot import name 'mmm_config'` +- `test_mmm_config_backend_setting`: `ImportError: cannot import name 'mmm_config'` +- `test_mmm_config_invalid_backend_warning`: `ImportError: cannot import name 'mmm_config'` + +**Backend Parameter Tests:** +- `test_posterior_predictive_accepts_backend_parameter`: `TypeError: posterior_predictive() got an unexpected keyword argument 'backend'` +- `test_posterior_predictive_accepts_return_as_pc_parameter`: `TypeError: posterior_predictive() got an unexpected keyword argument 'return_as_pc'` +- `test_posterior_predictive_backend_overrides_global`: `ImportError: cannot import name 'mmm_config'` or `TypeError` (backend param) + +**Return Type Tests:** +- `test_posterior_predictive_returns_tuple_by_default`: Should PASS (existing behavior works) +- `test_posterior_predictive_returns_plotcollection_when_requested`: `TypeError: unexpected keyword argument 'return_as_pc'` +- `test_posterior_predictive_tuple_has_correct_axes_for_matplotlib`: Should PASS (existing behavior) +- `test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib`: `TypeError: unexpected keyword argument 'backend'` + +**Visual Output Tests:** +- `test_posterior_predictive_plotcollection_has_viz_attribute`: `TypeError: unexpected keyword argument 'return_as_pc'` +- `test_posterior_predictive_matplotlib_has_lines`: Should PASS (existing behavior works) +- `test_posterior_predictive_plotly_has_traces`: `TypeError: unexpected keyword argument 'backend'` +- `test_posterior_predictive_bokeh_has_renderers`: `TypeError: unexpected keyword argument 'backend'` + +### Success Criteria + +#### Automated Verification: +- [ ] All tests run (no collection errors): `pytest tests/mmm/test_plot_backends.py --collect-only` +- [ ] Expected number of failures: Count matches test cases written +- [ ] No unexpected errors: No `ImportError` on test fixtures, no syntax errors +- [ ] Existing tests still pass: `pytest tests/mmm/test_plot.py -k test_posterior_predictive` + +#### Manual Verification: +- [ ] Each test fails with expected error type (TypeError, ImportError, AssertionError as listed) +- [ ] Failure messages clearly indicate what's missing +- [ ] Failure messages would help during implementation (diagnostic) +- [ ] Stack traces point to relevant code locations (test assertions, not fixture setup) +- [ ] No cryptic or misleading error messages + +### Adjustment Phase + +If tests don't fail properly: + +**Problem**: Tests pass unexpectedly +- **Fix**: Review test assertions - they may be too lenient +- **Action**: Add stricter type checks, verify specific attributes + +**Problem**: Tests error instead of fail (e.g., ImportError on fixtures) +- **Fix**: Check fixture dependencies, ensure mock data doesn't rely on new code +- **Action**: Simplify fixtures to not use non-existent features + +**Problem**: Confusing error messages +- **Fix**: Improve assertion messages with context +- **Action**: Add `assert x, f"Expected Y, got {x}"` style messages + +**Problem**: Tests fail in wrong order (dependency issues) +- **Fix**: Ensure test isolation - no shared state between tests +- **Action**: Use `reset_mmm_config` fixture, don't modify shared fixtures + +**Checklist for Adjustment:** +- [ ] All infrastructure tests fail with ImportError or AttributeError +- [ ] All backend parameter tests fail with TypeError (unexpected keyword) +- [ ] Return type tests for new behavior fail with TypeError +- [ ] Return type tests for existing behavior PASS +- [ ] Visual output tests fail with TypeError (unexpected keyword) + +--- + +## Phase 3: Feature Implementation (Red → Green) + +### Overview +Implement the feature by making tests pass, one at a time. Work like debugging - let test failures guide what needs to be implemented next. + +### Implementation Strategy + +**Order of Implementation:** +1. Global config infrastructure (`mmm_config`) +2. Add `backend` and `return_as_pc` parameters to `posterior_predictive()` +3. Implement PlotCollection integration +4. Implement figure/axes extraction for tuple return +5. Verify visual output across backends + +### Implementation 1: Create Global Configuration + +**Target Tests**: +- `test_mmm_config_exists` +- `test_mmm_config_backend_setting` +- `test_mmm_config_invalid_backend_warning` + +**Current Failure**: `ImportError: cannot import name 'mmm_config' from 'pymc_marketing.mmm'` + +**Changes Required:** + +**File**: `pymc_marketing/mmm/config.py` (NEW) +**Purpose**: Global configuration management for MMM plotting + +```python +"""Configuration management for MMM plotting.""" + +VALID_BACKENDS = {"matplotlib", "plotly", "bokeh"} + + +class MMMConfig(dict): + """ + Configuration dictionary for MMM plotting settings. + + Provides backend configuration with validation and reset functionality. + Modeled after ArviZ's rcParams pattern. + + Examples + -------- + >>> from pymc_marketing.mmm import mmm_config + >>> mmm_config["plot.backend"] = "plotly" + >>> mmm_config["plot.backend"] + 'plotly' + >>> mmm_config.reset() + >>> mmm_config["plot.backend"] + 'matplotlib' + """ + + _defaults = { + "plot.backend": "matplotlib", + "plot.show_warnings": True, + } + + def __init__(self): + super().__init__(self._defaults) + + def __setitem__(self, key, value): + """Set config value with validation for backend.""" + if key == "plot.backend": + if value not in VALID_BACKENDS: + import warnings + warnings.warn( + f"Invalid backend '{value}'. Valid backends are: {VALID_BACKENDS}. " + f"Setting anyway, but plotting may fail.", + UserWarning, + stacklevel=2 + ) + super().__setitem__(key, value) + + def reset(self): + """Reset all configuration to default values.""" + self.clear() + self.update(self._defaults) + + +# Global config instance +mmm_config = MMMConfig() +``` + +**File**: `pymc_marketing/mmm/__init__.py` +**Changes**: Add mmm_config export + +```python +# Existing imports... + +from pymc_marketing.mmm.config import mmm_config + +__all__ = [ + # ... existing exports ... + "mmm_config", +] +``` + +**Debugging Approach:** +1. Create `config.py` with MMMConfig class +2. Run: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_exists -v` +3. If fails, check import path and __all__ export +4. Run: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_backend_setting -v` +5. If fails, check reset() implementation +6. Run: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_invalid_backend_warning -v` +7. If fails, verify warning is emitted in __setitem__ + +**Success Criteria:** + +##### Automated Verification: +- [x] Test passes: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_exists -v` +- [x] Test passes: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_backend_setting -v` +- [x] Test passes: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_invalid_backend_warning -v` +- [x] Can import: `python -c "from pymc_marketing.mmm import mmm_config; print(mmm_config['plot.backend'])"` +- [x] Linting passes: `make lint` +- [x] Type checking passes: `mypy pymc_marketing/mmm/config.py` (no new errors) + +##### Manual Verification: +- [ ] Code is clean and well-documented +- [ ] Follows project conventions (NumPy docstrings) +- [ ] No performance issues (dict operations are O(1)) +- [ ] Warning messages are clear and actionable + +### Implementation 2: Add Parameters to posterior_predictive() + +**Target Tests**: +- `test_posterior_predictive_accepts_backend_parameter` +- `test_posterior_predictive_accepts_return_as_pc_parameter` + +**Current Failure**: `TypeError: posterior_predictive() got an unexpected keyword argument 'backend'` + +**Changes Required:** + +**File**: `pymc_marketing/mmm/plot.py` +**Method**: `posterior_predictive()` (line 375) +**Changes**: Add backend and return_as_pc parameters + +```python +def posterior_predictive( + self, + var: list[str] | None = None, + idata: xr.Dataset | None = None, + hdi_prob: float = 0.85, + backend: str | None = None, + return_as_pc: bool = False, +) -> tuple[Figure, list[Axes] | None] | "PlotCollection": + """ + Plot posterior predictive distributions over time. + + Parameters + ---------- + var : list of str, optional + List of variable names to plot. If None, uses "y". + idata : xr.Dataset, optional + Dataset containing posterior predictive samples. + If None, uses self.idata.posterior_predictive. + hdi_prob : float, default 0.85 + Probability mass for HDI interval. + backend : str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config via mmm_config["plot.backend"]. + Default (via config) is "matplotlib". + return_as_pc : bool, default False + If True, returns PlotCollection object. + If False, returns tuple (figure, axes) for backward compatibility. + + Returns + ------- + PlotCollection or tuple + If return_as_pc=True, returns PlotCollection object. + If return_as_pc=False, returns (figure, axes) where: + - figure: backend-specific figure object (matplotlib.figure.Figure, + plotly.graph_objs.Figure, or bokeh.plotting.Figure) + - axes: list of matplotlib Axes if backend="matplotlib", else None + + Notes + ----- + When backend is not "matplotlib" and return_as_pc=False, the axes + element of the returned tuple will be None, as plotly and bokeh + do not have an equivalent axes list concept. + + Examples + -------- + >>> # Backward compatible usage (matplotlib) + >>> fig, axes = model.plot.posterior_predictive() + + >>> # Multi-backend with PlotCollection + >>> pc = model.plot.posterior_predictive(backend="plotly", return_as_pc=True) + >>> pc.show() + """ + from pymc_marketing.mmm.config import mmm_config + + # Resolve backend (parameter overrides global config) + backend = backend or mmm_config["plot.backend"] + + # Temporary: Keep existing matplotlib implementation + # This makes tests pass (accepts parameters) but doesn't use them yet + # We'll implement PlotCollection integration in next step + + # [Existing implementation continues unchanged for now...] + # Just pass through to existing code +``` + +**Debugging Approach:** +1. Add parameters to signature with defaults +2. Add backend resolution logic (import mmm_config, use parameter or config) +3. Run: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_accepts_backend_parameter -v` +4. Should PASS (accepts parameter even if not used yet) +5. Run: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_accepts_return_as_pc_parameter -v` +6. Should PASS (accepts parameter even if not used yet) +7. Update docstring with new parameters +8. Update type hints in return annotation + +**Success Criteria:** + +##### Automated Verification: +- [x] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_accepts_backend_parameter -v` +- [x] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_accepts_return_as_pc_parameter -v` +- [x] Existing tests still pass: `pytest tests/mmm/test_plot.py::test_posterior_predictive -v` +- [x] Linting passes: `make lint` +- [x] Type checking passes: `mypy pymc_marketing/mmm/plot.py` + +##### Manual Verification: +- [x] Docstring updated with new parameters (NumPy style) +- [x] Default values maintain backward compatibility +- [x] Parameter order is logical (existing params first, new params last) +- [x] Type hints are accurate (use string quotes for forward ref to PlotCollection) + +### Implementation 3: Integrate PlotCollection and Return Type Logic + +**Target Tests**: +- `test_posterior_predictive_returns_tuple_by_default` +- `test_posterior_predictive_returns_plotcollection_when_requested` +- `test_posterior_predictive_backend_overrides_global` +- `test_posterior_predictive_tuple_has_correct_axes_for_matplotlib` +- `test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib` + +**Current Failure**: Tests pass/fail mix - need to integrate PlotCollection + +**Changes Required:** + +This is the most complex implementation step. We need to: +1. Create PlotCollection-based plotting logic +2. Implement figure/axes extraction for tuple return +3. Handle backend-specific differences + +**File**: `pymc_marketing/mmm/plot.py` +**Method**: `posterior_predictive()` (line 375) +**Changes**: Complete PlotCollection integration + +```python +def posterior_predictive( + self, + var: list[str] | None = None, + idata: xr.Dataset | None = None, + hdi_prob: float = 0.85, + backend: str | None = None, + return_as_pc: bool = False, +) -> tuple[Figure, list[Axes] | None] | "PlotCollection": + """[Docstring from previous step]""" + from pymc_marketing.mmm.config import mmm_config + from arviz_plots import PlotCollection, visuals + + # Resolve backend (parameter overrides global config) + backend = backend or mmm_config["plot.backend"] + + # Get data + var = var or ["y"] + pp_data = self._get_posterior_predictive_data(idata=idata) + + # Get dimension combinations for subplots + ignored_dims = {"chain", "draw", "date", "sample"} + available_dims = [d for d in pp_data[var[0]].dims if d not in ignored_dims] + additional_dims = [d for d in available_dims if d not in var] + dim_combinations = self._get_additional_dim_combinations( + pp_data[var[0]], additional_dims + ) + + n_subplots = len(dim_combinations) + + # Create PlotCollection with grid layout + # We'll build a dataset for PlotCollection + plot_data = {} + for v in var: + data = pp_data[v] + # Stack chain and draw into sample dimension + if "chain" in data.dims and "draw" in data.dims: + data = data.stack(sample=("chain", "draw")) + plot_data[v] = data + + plot_dataset = xr.Dataset(plot_data) + + # Create figure with appropriate layout + # PlotCollection.grid creates a grid of subplots + pc = PlotCollection.grid( + plot_dataset, + backend=backend, + plots_per_row=1, # One column layout like original + figsize=(10, 4 * n_subplots), + ) + + # For each subplot, add line plot and HDI + for row_idx, combo in enumerate(dim_combinations): + indexers = dict(zip(additional_dims, combo, strict=False)) if additional_dims else {} + + # Select subplot + if n_subplots > 1: + # Multi-panel: select by row index + pc_subplot = pc.sel(row=row_idx) + else: + # Single panel: use full pc + pc_subplot = pc + + for v in var: + data = plot_data[v].sel(**indexers) if indexers else plot_data[v] + + # Compute median and HDI + median = data.median(dim="sample") + hdi = az.hdi(data, hdi_prob=hdi_prob, input_core_dims=[["sample"]]) + + # Add median line + pc_subplot.map( + visuals.line, + data=median.rename("median"), + color=f"C{var.index(v)}", + label=v, + ) + + # Add HDI band + pc_subplot.map( + visuals.fill_between, + data1=hdi[v].sel(hdi="lower"), + data2=hdi[v].sel(hdi="higher"), + color=f"C{var.index(v)}", + alpha=0.2, + ) + + # Add labels + title = self._build_subplot_title(additional_dims, combo, "Posterior Predictive") + pc_subplot.map(visuals.labelled, title=title, xlabel="Date", ylabel="Posterior Predictive") + pc_subplot.map(visuals.legend) + + # Return based on return_as_pc flag + if return_as_pc: + return pc + else: + # Extract figure from PlotCollection + fig = pc.viz.figure.data.item() + + # Extract axes (only for matplotlib) + if backend == "matplotlib": + axes = list(fig.get_axes()) # Return as simple list + else: + axes = None + + return fig, axes +``` + +**Note**: The above is pseudocode showing the structure. Actual implementation will need to: +- Check PlotCollection API documentation for exact method signatures +- Handle dimension combinations correctly +- Ensure HDI computation works with PlotCollection +- Test iteratively with debugger + +**Debugging Approach:** +1. Start with simplest case: single variable, no extra dimensions +2. Run: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_returns_plotcollection_when_requested[matplotlib] -v` +3. Debug PlotCollection creation - check what data format it expects +4. Debug median/HDI computation - verify dimensions match +5. Debug PlotCollection.map() calls - check visual function signatures +6. Once matplotlib works, test plotly: `pytest ... [plotly]` +7. Debug backend-specific issues (figure extraction, etc.) +8. Test tuple return: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_returns_tuple_by_default -v` +9. Debug figure/axes extraction logic + +**Alternative Simpler Approach** (if PlotCollection API is challenging): + +Keep matplotlib implementation, add a wrapper that converts to/from PlotCollection: + +```python +def posterior_predictive(self, ..., backend=None, return_as_pc=False): + """[docstring]""" + from pymc_marketing.mmm.config import mmm_config + backend = backend or mmm_config["plot.backend"] + + # For now, always use matplotlib internally + # This lets us make progress while learning PlotCollection API + fig_mpl, axes_mpl = self._posterior_predictive_matplotlib( + var=var, idata=idata, hdi_prob=hdi_prob + ) + + if backend != "matplotlib": + # Convert matplotlib to other backend via PlotCollection + # This is a valid incremental approach + import warnings + warnings.warn( + f"Backend '{backend}' requested but full support not yet implemented. " + f"Using matplotlib with conversion.", + UserWarning + ) + # Conversion logic here... + + if return_as_pc: + # Wrap matplotlib figure in PlotCollection + pc = PlotCollection.wrap(fig_mpl, backend=backend) + return pc + else: + if backend == "matplotlib": + return fig_mpl, axes_mpl + else: + # Convert figure to target backend + fig_converted = convert_figure(fig_mpl, backend) + return fig_converted, None +``` + +This incremental approach lets tests pass while we refine the implementation. + +**Success Criteria:** + +##### Automated Verification: +- [ ] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_returns_plotcollection_when_requested -v` +- [ ] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_returns_tuple_by_default -v` +- [ ] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_backend_overrides_global -v` +- [ ] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_tuple_has_correct_axes_for_matplotlib -v` +- [ ] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib -v` +- [ ] All existing tests pass: `pytest tests/mmm/test_plot.py::test_posterior_predictive -v` +- [ ] Linting passes: `make lint` + +##### Manual Verification: +- [ ] PlotCollection objects are created correctly +- [ ] Figure extraction works for all backends +- [ ] Axes extraction works for matplotlib, returns None for others +- [ ] Visual output looks reasonable (manually inspect one plot per backend) +- [ ] No performance regressions (test with `time pytest ...`) + +### Implementation 4: Visual Output Validation + +**Target Tests**: +- `test_posterior_predictive_plotcollection_has_viz_attribute` +- `test_posterior_predictive_matplotlib_has_lines` +- `test_posterior_predictive_plotly_has_traces` +- `test_posterior_predictive_bokeh_has_renderers` + +**Current State**: May already pass if Implementation 3 is complete, or may need refinement + +**Debugging Approach:** +1. Run: `pytest tests/mmm/test_plot_backends.py -k "visual_output" -v` +2. If `test_plotcollection_has_viz_attribute` fails: + - Check PlotCollection structure + - Verify viz.figure.data.item() works +3. If `test_matplotlib_has_lines` fails: + - Check that median lines are actually plotted + - Verify Line2D objects exist in axes +4. If `test_plotly_has_traces` fails: + - Check plotly figure structure + - Verify conversion from matplotlib worked + - Check fig.data contains traces +5. If `test_bokeh_has_renderers` fails: + - Check bokeh figure structure + - Verify renderers exist + +**Possible Issues and Fixes:** +- **Empty plots**: Check that visuals.line() is actually called +- **Wrong backend**: Verify backend parameter is passed through correctly +- **Extraction fails**: Check PlotCollection API version, may need updates + +**Success Criteria:** + +##### Automated Verification: +- [ ] All visual output tests pass: `pytest tests/mmm/test_plot_backends.py -k "visual" -v` +- [ ] No warnings about empty plots +- [ ] All backends produce non-empty output + +##### Manual Verification: +- [ ] Matplotlib plots look correct (run test, inspect saved figure manually) +- [ ] Plotly plots render correctly (check fig.show() if interactive) +- [ ] Bokeh plots render correctly (check bokeh output) +- [ ] HDI bands are visible and correct + +### Complete Feature Implementation + +Once all tests pass: + +**Final Integration Check:** +```bash +# Run all backend tests +pytest tests/mmm/test_plot_backends.py -v + +# Run all existing tests to ensure no regressions +pytest tests/mmm/test_plot.py -v + +# Run full test suite +pytest tests/mmm/ -v + +# Check coverage for new code +pytest tests/mmm/test_plot_backends.py --cov=pymc_marketing.mmm.plot --cov-report=term-missing +``` + +**Success Criteria:** + +##### Automated Verification: +- [ ] All new tests pass: `pytest tests/mmm/test_plot_backends.py -v` +- [ ] No regressions: `pytest tests/mmm/test_plot.py::test_posterior_predictive -v` +- [ ] All MMM tests pass: `pytest tests/mmm/ -v` +- [ ] Code coverage: New code is >90% covered +- [ ] Linting passes: `make lint` +- [ ] Type checking passes: `make typecheck` + +##### Manual Verification: +- [ ] Can import and use mmm_config: `from pymc_marketing.mmm import mmm_config` +- [ ] Backward compatible: Old code works unchanged +- [ ] New API works: Can switch backends and get PlotCollection +- [ ] Visual output: Plots look correct in all three backends +- [ ] Documentation: Docstrings are complete and accurate + +--- + +## Phase 4: Refactoring & Cleanup + +### Overview +Now that tests are green, refactor to improve code quality while keeping tests passing. Tests protect us during refactoring. + +### Refactoring Targets + +#### 1. Code Duplication in Test File +**Problem**: Test cases may have repeated setup code +**Solution**: Extract common patterns to helper functions + +```python +# tests/mmm/test_plot_backends.py + +def assert_valid_plotcollection(pc, expected_backend): + """ + Helper to validate PlotCollection structure. + + Reduces duplication across tests. + """ + from arviz_plots import PlotCollection + + assert isinstance(pc, PlotCollection), \ + f"Should return PlotCollection, got {type(pc)}" + assert hasattr(pc, 'backend'), \ + "PlotCollection should have backend attribute" + assert pc.backend == expected_backend, \ + f"Backend should be '{expected_backend}', got '{pc.backend}'" + + +def assert_valid_backend_figure(fig, backend): + """ + Helper to validate backend-specific figure types. + """ + if backend == "matplotlib": + from matplotlib.figure import Figure + assert isinstance(fig, Figure) + elif backend == "plotly": + assert hasattr(fig, 'data'), "Plotly figure should have 'data'" + elif backend == "bokeh": + assert hasattr(fig, 'renderers'), "Bokeh figure should have 'renderers'" +``` + +#### 2. Backend Resolution Logic +**Problem**: Backend resolution logic may be duplicated in every method +**Solution**: Extract to a helper method in MMMPlotSuite + +```python +# pymc_marketing/mmm/plot.py + +class MMMPlotSuite: + """[existing docstring]""" + + def _resolve_backend(self, backend: str | None) -> str: + """ + Resolve backend parameter to actual backend string. + + Parameters + ---------- + backend : str or None + Backend parameter from method call. + + Returns + ------- + str + Resolved backend name (parameter overrides global config). + + Examples + -------- + >>> suite._resolve_backend(None) # uses global config + 'matplotlib' + >>> suite._resolve_backend("plotly") # uses parameter + 'plotly' + """ + from pymc_marketing.mmm.config import mmm_config + return backend or mmm_config["plot.backend"] +``` + +#### 3. Figure/Axes Extraction Logic +**Problem**: Tuple return logic may be complex and repeated +**Solution**: Extract to helper method + +```python +# pymc_marketing/mmm/plot.py + +class MMMPlotSuite: + """[existing docstring]""" + + def _extract_figure_and_axes( + self, + pc: "PlotCollection", + backend: str + ) -> tuple: + """ + Extract figure and axes from PlotCollection for tuple return. + + Parameters + ---------- + pc : PlotCollection + PlotCollection object to extract from. + backend : str + Backend name ("matplotlib", "plotly", or "bokeh"). + + Returns + ------- + tuple + (figure, axes) where figure is backend-specific Figure object + and axes is list of Axes for matplotlib, None for other backends. + + Notes + ----- + This method enables backward compatibility by extracting matplotlib-style + return values from PlotCollection objects. + """ + # Extract figure + fig = pc.viz.figure.data.item() + + # Extract axes (only for matplotlib) + if backend == "matplotlib": + axes = list(fig.get_axes()) + else: + axes = None + + return fig, axes +``` + +#### 4. Simplify PlotCollection Creation +**Problem**: PlotCollection creation logic may be verbose +**Solution**: Extract data preparation to helper method + +```python +# pymc_marketing/mmm/plot.py + +class MMMPlotSuite: + """[existing docstring]""" + + def _prepare_data_for_plotcollection( + self, + data: xr.DataArray, + stack_dims: tuple[str, ...] = ("chain", "draw") + ) -> xr.DataArray: + """ + Prepare xarray data for PlotCollection plotting. + + Parameters + ---------- + data : xr.DataArray + Input data with MCMC dimensions. + stack_dims : tuple of str, default ("chain", "draw") + Dimensions to stack into 'sample' dimension. + + Returns + ------- + xr.DataArray + Data with chain and draw stacked into sample dimension. + """ + if all(d in data.dims for d in stack_dims): + data = data.stack(sample=stack_dims) + return data +``` + +#### 5. Test Code Quality +**Problem**: Some tests may have complex setup +**Solution**: Use parametrize more effectively + +```python +# Example refactoring of tests + +# Before: Multiple similar test functions +def test_backend_matplotlib_works(): + pc = suite.posterior_predictive(backend="matplotlib", return_as_pc=True) + assert pc.backend == "matplotlib" + +def test_backend_plotly_works(): + pc = suite.posterior_predictive(backend="plotly", return_as_pc=True) + assert pc.backend == "plotly" + +def test_backend_bokeh_works(): + pc = suite.posterior_predictive(backend="bokeh", return_as_pc=True) + assert pc.backend == "bokeh" + +# After: Single parametrized test +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_backend_parameter_works(suite, backend): + """Test that all backends work correctly.""" + pc = suite.posterior_predictive(backend=backend, return_as_pc=True) + assert pc.backend == backend +``` + +### Refactoring Steps + +1. **Ensure all tests pass before starting**: + ```bash + pytest tests/mmm/test_plot_backends.py -v + ``` + +2. **For each refactoring**: + - Make the change (extract helper, rename variable, etc.) + - Run tests immediately: `pytest tests/mmm/test_plot_backends.py -v` + - If tests pass, commit the change (or move to next refactoring) + - If tests fail, revert and reconsider + +3. **Focus areas**: + - Extract helper methods (backend resolution, figure extraction) + - Improve naming (clear variable names, descriptive method names) + - Add code comments where logic is complex + - Simplify conditional logic + - Remove any dead code or unused imports + +4. **Test code refactoring**: + - Extract test helpers (assertion helpers) + - Use parametrize more effectively + - Improve test names for clarity + - Add docstrings to complex test fixtures + +### Success Criteria + +#### Automated Verification: +- [ ] All tests still pass: `pytest tests/mmm/test_plot_backends.py -v` +- [ ] No regressions: `pytest tests/mmm/test_plot.py -v` +- [ ] Code coverage maintained: `pytest --cov=pymc_marketing.mmm.plot --cov-report=term-missing` +- [ ] Linting passes: `make lint` +- [ ] Type checking passes: `mypy pymc_marketing/mmm/plot.py` +- [ ] No performance regressions: Compare test run time before/after + +#### Manual Verification: +- [ ] Code is more readable after refactoring +- [ ] No unnecessary complexity added +- [ ] Function/variable names are clear and descriptive +- [ ] Comments explain "why" not "what" +- [ ] Helper methods have clear single responsibilities +- [ ] Test code is DRY (Don't Repeat Yourself) +- [ ] Code follows project idioms (check CLAUDE.md patterns) + +--- + +## Phase 5: Expand to contributions_over_time() + +### Overview +Apply the same TDD process to the second method, `contributions_over_time()`. This method is similar to `posterior_predictive()`, so the pattern is established. + +### Test Design for contributions_over_time() + +**New Test Cases** (add to `tests/mmm/test_plot_backends.py`): + +```python +@pytest.fixture(scope="module") +def mock_suite_with_contributions(mock_idata): + """ + Fixture providing MMMPlotSuite with contribution data. + + Reuses mock_idata which already has intercept and linear_trend. + """ + return MMMPlotSuite(idata=mock_idata) + + +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_contributions_over_time_backend_parameter(mock_suite_with_contributions, backend): + """Test contributions_over_time accepts backend parameter and uses it.""" + pc = mock_suite_with_contributions.contributions_over_time( + var=["intercept"], + backend=backend, + return_as_pc=True + ) + assert_valid_plotcollection(pc, backend) + + +def test_contributions_over_time_returns_tuple_by_default(mock_suite_with_contributions): + """Test backward compatibility - returns tuple by default.""" + result = mock_suite_with_contributions.contributions_over_time( + var=["intercept"] + ) + assert isinstance(result, tuple) + assert len(result) == 2 + + +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_contributions_over_time_plotcollection(mock_suite_with_contributions, backend): + """Test return_as_pc=True returns PlotCollection.""" + pc = mock_suite_with_contributions.contributions_over_time( + var=["intercept"], + backend=backend, + return_as_pc=True + ) + assert_valid_plotcollection(pc, backend) + + +# Add visual output tests similar to posterior_predictive +``` + +### Implementation Steps + +1. **Write tests first**: Add all test cases to `test_plot_backends.py` +2. **Run tests to verify failures**: `pytest tests/mmm/test_plot_backends.py -k contributions_over_time -v` +3. **Add parameters to method signature**: Same as posterior_predictive +4. **Implement PlotCollection integration**: Reuse patterns from posterior_predictive +5. **Extract figure/axes for tuple return**: Use helper methods from refactoring +6. **Verify all tests pass**: `pytest tests/mmm/test_plot_backends.py -k contributions_over_time -v` + +### Success Criteria + +- [ ] All contributions_over_time tests pass +- [ ] Existing contributions_over_time tests still pass +- [ ] Code follows same pattern as posterior_predictive +- [ ] Refactored helpers are reused (no duplication) + +--- + +## Testing Strategy Summary + +### Test Coverage Goals +- [x] Normal operation paths: All public methods work with default parameters +- [x] Backend switching: All three backends (matplotlib, plotly, bokeh) work +- [x] Return type switching: Both tuple and PlotCollection returns work +- [x] Backward compatibility: Existing code works unchanged +- [x] Configuration: Global and per-function backend configuration +- [x] Edge cases: Invalid backends warn, missing data errors clearly +- [x] Visual output: Plots contain expected elements (lines, traces, renderers) + +### Test Organization +- **Test files**: + - `tests/mmm/test_plot_backends.py` (NEW) - Backend migration tests + - `tests/mmm/test_plot.py` (EXISTING) - Original tests, marked for future consolidation +- **Fixtures**: + - Module-scoped for expensive InferenceData creation + - Function-scoped for config cleanup + - Located at top of test file +- **Test utilities**: + - Helper assertions (assert_valid_plotcollection, assert_valid_backend_figure) + - Located in test file (not separate module yet) +- **Test data**: + - xarray-based InferenceData fixtures + - Realistic structure matching MMM output + +### Running Tests + +```bash +# Run all backend tests +pytest tests/mmm/test_plot_backends.py -v + +# Run specific test +pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_returns_plotcollection_when_requested -v + +# Run with coverage +pytest tests/mmm/test_plot_backends.py --cov=pymc_marketing.mmm.plot --cov-report=term-missing + +# Run with failure details +pytest tests/mmm/test_plot_backends.py -vv --tb=short + +# Run only matplotlib backend tests (faster) +pytest tests/mmm/test_plot_backends.py -k "matplotlib" -v + +# Run all backends in parallel (if pytest-xdist installed) +pytest tests/mmm/test_plot_backends.py -n auto +``` + +## Performance Considerations + +Performance testing is explicitly out of scope (requirement #2), but we should avoid obvious regressions: + +- **Keep existing matplotlib path fast**: Don't add unnecessary overhead for default usage +- **Lazy imports**: Import PlotCollection only when needed +- **Reuse computations**: Don't recompute HDI if already computed +- **Fixture scope**: Use module-scoped fixtures to avoid repeated setup + +## Migration Notes + +### For Users + +**Backward Compatibility**: +- All existing code continues to work without changes +- Default behavior unchanged (matplotlib, tuple return) +- No breaking changes to public API + +**New Features**: +```python +# Global backend configuration +from pymc_marketing.mmm import mmm_config +mmm_config["plot.backend"] = "plotly" + +# All plots now use plotly +model.plot.posterior_predictive() + +# Override for specific plot +model.plot.contributions_over_time(backend="matplotlib") + +# Get PlotCollection for advanced customization +pc = model.plot.saturation_curves(curve=curve_data, return_as_pc=True) +pc.map(custom_visual_function) +pc.show() +``` + +### For Developers + +**Adding New Plotting Methods**: +1. Add `backend` and `return_as_pc` parameters +2. Use `self._resolve_backend(backend)` to get backend +3. Create PlotCollection with appropriate backend +4. Use `self._extract_figure_and_axes(pc, backend)` for tuple return +5. Write tests in `test_plot_backends.py` before implementing + +**Testing Checklist**: +- [ ] Test accepts backend parameter +- [ ] Test accepts return_as_pc parameter +- [ ] Test returns tuple by default (backward compat) +- [ ] Test returns PlotCollection when requested +- [ ] Test all three backends (parametrize) +- [ ] Test visual output (has lines/traces/renderers) + +## Dependencies + +### New Dependencies +- **arviz-plots**: Required for PlotCollection API + - Add to `pyproject.toml`: `arviz-plots>=0.7.0` + - Add to `environment.yml`: `- arviz-plots>=0.7.0` + +### Existing Dependencies (no changes) +- matplotlib: Already required +- arviz: Already required +- xarray: Already required +- numpy: Already required + +## References + +- Original research: [thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md](thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md) +- MMMPlotSuite implementation: [pymc_marketing/mmm/plot.py:187-1924](pymc_marketing/mmm/plot.py#L187-1924) +- Existing tests: [tests/mmm/test_plot.py](tests/mmm/test_plot.py) +- ArviZ PlotCollection docs: https://arviz-plots.readthedocs.io/ +- Test patterns reference: This plan's Phase 1 test examples + +## Open Questions + +1. **PlotCollection API Learning Curve**: ✅ ADDRESSED - Use incremental approach, start with matplotlib wrapper if needed +2. **Visual Output Validation**: ✅ ADDRESSED - Test for presence of elements (lines/traces/renderers), not pixel-perfect matching +3. **Performance Impact**: ✅ OUT OF SCOPE - User confirmed not a concern for this migration +4. **Deprecation Timeline**: When should we deprecate tuple return in favor of PlotCollection? + - Recommendation: Keep both indefinitely, default to tuple for backward compat +5. **Test File Consolidation**: When to merge `test_plot.py` into `test_plot_backends.py`? + - Recommendation: After all methods migrated and stable (next version) + +## Next Steps After Phase 5 + +Once `posterior_predictive()` and `contributions_over_time()` are fully implemented and tested: + +1. **Expand to saturation methods**: + - `saturation_scatterplot()` - Similar pattern, adds scatter plots + - `saturation_curves()` - Adds `rc_params` deprecation, `backend_config` parameter + +2. **Implement twinx fallback**: + - `budget_allocation()` - Special case with fallback warning + +3. **Expand to sensitivity methods**: + - `sensitivity_analysis()`, `uplift_curve()`, `marginal_curve()` - Wrappers + +4. **Full test suite validation**: + - Run all MMM tests: `pytest tests/mmm/ -v` + - Check coverage: `pytest tests/mmm/ --cov=pymc_marketing.mmm --cov-report=html` + - Performance baseline: `pytest tests/mmm/ --durations=20` + +5. **Documentation**: + - Update user guide with backend examples + - Add migration guide for users + - Update docstrings with examples + - Create notebook showing multi-backend usage + +## Summary of Key Decisions + +1. ✅ **Backward Compatibility**: Maintained via `return_as_pc=False` default +2. ✅ **Global Configuration**: ArviZ-style `mmm_config` dictionary +3. ✅ **Test Organization**: New file `test_plot_backends.py`, mark old file for future consolidation +4. ✅ **Mock Data**: Use existing patterns from `test_plot.py`, realistic xarray structures +5. ✅ **Test Depth**: Prioritize depth (thorough testing of first 2 methods) over breadth +6. ✅ **Visual Validation**: Test for presence of plot elements, not pixel-perfect matching +7. ✅ **Default Backend**: Keep "matplotlib" as default for full backward compatibility +8. ✅ **Helper Extraction**: Refactor common patterns to methods during cleanup phase +9. ✅ **Incremental Implementation**: OK to start with matplotlib-only, add backends incrementally From a6ec594def07901ea62aabb85ac8e9743373cfc7 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Thu, 13 Nov 2025 17:16:14 -0500 Subject: [PATCH 03/29] add research --- ...otsuite-backend-migration-comprehensive.md | 554 ++++++++++++++++++ 1 file changed, 554 insertions(+) create mode 100644 thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md diff --git a/thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md b/thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md new file mode 100644 index 000000000..bb4e06ede --- /dev/null +++ b/thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md @@ -0,0 +1,554 @@ +--- +date: 2025-11-11T21:13:39-05:00 +researcher: Claude +git_commit: e78e3afb259a33f0d2b09d0d6c7e409fe4ddc90d +branch: main +repository: pymc-marketing +topic: "MMMPlotSuite Migration to ArviZ PlotCollection with Backward Compatibility - Comprehensive Research" +tags: [research, codebase, plotting, visualization, arviz, matplotlib, plotly, bokeh, backend-agnostic, mmm, backward-compatibility] +status: complete +last_updated: 2025-11-11 +last_updated_by: Claude +--- + +# Research: MMMPlotSuite Migration to ArviZ PlotCollection with Backward Compatibility + +**Date**: 2025-11-11T21:13:39-05:00 +**Researcher**: Claude +**Git Commit**: e78e3afb259a33f0d2b09d0d6c7e409fe4ddc90d +**Branch**: main +**Repository**: pymc-marketing + +## Research Question + +The user wants to rewrite the MMMPlotSuite class in [plot.py](pymc_marketing/mmm/plot.py) to support additional backends beyond matplotlib. Specifically, they want to rewrite the functions in that class to use ArviZ's PlotCollection API instead of matplotlib directly, making the methods return PlotCollection objects instead of matplotlib Figure and Axes objects. + +### Updated Requirements (Corrected from previous research) + +1. **Global Backend Configuration**: Support the ability to set the backend once, in a "global manner", and then all plots will use that backend. This is on top of the option of setting the backend to individual functions using a backend argument, which will override the global setting. + +2. **Backward Compatibility**: The changes to MMMPlotSuite plotting functions should be backward compatible. The output to the function will be based on the backend argument, and there would also be an argument that would control whether to return a PlotCollection object instead. + +3. **Backend-Specific Code**: Identify all the matplotlib-specific functions that do not have a direct equivalent in other backends and come up with specific code to handle them for all backends. + +4. **RC Params Handling**: For the method `saturation_curves`, it uses `plt.rc_context(rc_params)` so we will need to change it. Use `backend="matplotlib", backend_config=None` as arguments. We are going to keep that `rc_params` parameter for backward compatibility, but emit a warning when using it. + +5. **Twin Axes Fallback**: A function that uses matplotlib `twinx` cannot be currently written using arviz-plots. So if a different backend is chosen it needs to emit a warning and fallback to matplotlib. + +6. **ArviZ-style rcParams**: Use the recommended "ArviZ-style rcParams with fallback" for Global Backend Configuration Implementation. + +7. **Performance**: Performance is not a concern for this migration. + +8. **Testing**: Testing of all functions should be across matplotlib, plotly and bokeh backends. + +9. **Component Plot Methods**: Do not migrate component plot methods outside MMMPlotSuite. If MMMPlotSuite uses a plotting function that is defined in a different file we would need to create a new function instead. + +## Summary + +The current MMMPlotSuite implementation is tightly coupled to matplotlib, with all 10 public plotting methods returning `tuple[Figure, NDArray[Axes]]` or similar matplotlib objects. The class uses matplotlib-specific APIs throughout (`plt.subplots`, `ax.plot`, `ax.fill_between`, `ax.twinx`, etc.). + +ArviZ's PlotCollection API (from arviz-plots) provides a backend-agnostic alternative that supports matplotlib, bokeh, plotly, and none backends. The codebase already uses ArviZ extensively for HDI computation (`az.hdi()`) and some plotting (`az.plot_hdi()`), but does not use PlotCollection anywhere in production code. + +The migration must maintain backward compatibility, support global backend configuration with per-function overrides, handle matplotlib-specific features gracefully (particularly `ax.twinx()` which requires backend-specific implementations or fallback), and be tested across matplotlib, plotly, and bokeh backends. + +## Detailed Findings + +### Current MMMPlotSuite Architecture + +#### Class Overview + +**Location**: [pymc_marketing/mmm/plot.py:187-1924](pymc_marketing/mmm/plot.py#L187) + +The MMMPlotSuite class is a standalone visualization class for MMM models: +- Initialized with `xr.Dataset` or `az.InferenceData` +- 10 public plotting methods (including 1 deprecated) +- Multiple helper methods for subplot creation and data manipulation +- All methods return matplotlib objects + +#### Method Signatures and Return Types + +| Method | Current Return Type | Lines | Usage | +|--------|-------------------|-------|-------| +| `posterior_predictive()` | `tuple[Figure, NDArray[Axes]]` | 375-463 | Plot posterior predictive time series | +| `contributions_over_time()` | `tuple[Figure, NDArray[Axes]]` | 465-588 | Plot contribution time series with HDI | +| `saturation_scatterplot()` | `tuple[Figure, NDArray[Axes]]` | 590-742 | Scatter plots of channel saturation | +| `saturation_curves()` | `tuple[plt.Figure, np.ndarray]` | 744-996 | Overlay scatter data with posterior curves | +| `saturation_curves_scatter()` | `tuple[Figure, NDArray[Axes]]` | 998-1035 | **Deprecated** - use `saturation_scatterplot()` | +| `budget_allocation()` | `tuple[Figure, plt.Axes] \| tuple[Figure, np.ndarray]` | 1037-1212 | Bar chart with dual y-axes | +| `allocated_contribution_by_channel_over_time()` | `tuple[Figure, plt.Axes \| NDArray[Axes]]` | 1279-1481 | Line plots with uncertainty bands | +| `sensitivity_analysis()` | `tuple[Figure, NDArray[Axes]] \| plt.Axes` | 1483-1718 | Plot sensitivity sweep results | +| `uplift_curve()` | `tuple[Figure, NDArray[Axes]] \| plt.Axes` | 1720-1820 | Wrapper around sensitivity_analysis for uplift | +| `marginal_curve()` | `tuple[Figure, NDArray[Axes]] \| plt.Axes` | 1822-1923 | Wrapper around sensitivity_analysis for marginal effects | + +#### Matplotlib-Specific APIs Used + +**Core matplotlib functions used across all methods:** +- `plt.subplots()` - Creating figure and axes grid +- `ax.plot()` - Line plots for medians +- `ax.fill_between()` - HDI/uncertainty bands +- `ax.scatter()` - Scatter plots for data points +- `ax.bar()` - Bar charts (budget allocation) +- `ax.twinx()` - Dual y-axes (budget allocation) - **CRITICAL FEATURE** +- `ax.set_title()`, `ax.set_xlabel()`, `ax.set_ylabel()` - Labeling +- `ax.legend()` - Legends +- `ax.set_visible()` - Hide unused axes +- `fig.tight_layout()` - Layout adjustment +- `fig.suptitle()` - Figure titles +- `plt.rc_context()` - Temporary matplotlib settings + +### Matplotlib-Specific Features Analysis + +#### Critical Feature: ax.twinx() - Dual Y-Axes + +**Location**: [pymc_marketing/mmm/plot.py:1249](pymc_marketing/mmm/plot.py#L1249) + +**Method**: `_plot_budget_allocation_bars()` → `budget_allocation()` + +**What it does**: Creates a secondary y-axis with independent scale, used to compare allocated spend vs. channel contribution on the same plot with different y-scales. + +**Implementation details**: +```python +# Line 1239-1246: Primary bars on primary axis +bars1 = ax.bar(index, allocated_spend, bar_width, color="C0", alpha=opacity, label="Allocated Spend") + +# Line 1249: Create twin axis +ax2 = ax.twinx() + +# Line 1252-1259: Secondary bars on secondary axis +bars2 = ax2.bar([i + bar_width for i in index], channel_contribution, bar_width, + color="C1", alpha=opacity, label="Channel Contribution") +``` + +**Backends without native PlotCollection support**: +- **Bokeh**: No direct twin axes support in PlotCollection +- **Plotly**: Has secondary y-axes but requires different approach + +**User Requirement**: "A function that uses matplotlib `twinx` cannot be currently written using arviz. So if a different backend is chosen it needs to emit a warning and fallback to matplotlib." + +**Recommended Strategy**: Detect when non-matplotlib backend is requested, emit warning, and force fallback to matplotlib backend. + +#### Medium Impact Feature: plt.rc_context() + +**Location**: [pymc_marketing/mmm/plot.py:878-880](pymc_marketing/mmm/plot.py#L878-880) + +**Method**: `saturation_curves()` + +**Code**: +```python +rc_params = rc_params or {} +with plt.rc_context(rc_params): + fig, axes = plt.subplots(nrows=nrows, ncols=ncols, **subkw) +``` + +**User Requirement**: "For the method `saturation_curves`, it uses `plt.rc_context(rc_params)` so we will need to change it. I want to use `backend="matplotlib", backend_config=None` as arguments. We are going to keep that `rc_params` parameter for backward compatibility, but emit a warning when using it." + +**Recommended Strategy**: +- Add `backend_config` parameter +- Keep `rc_params` parameter with deprecation warning +- Only apply config when backend is matplotlib +- Warn if `backend_config` provided for non-matplotlib backends + +### ArviZ Usage in the Codebase + +#### Current ArviZ Integration + +**ArviZ functions currently used:** +1. `az.hdi()` - HDI computation (used extensively) +2. `az.plot_hdi()` - HDI plotting (used in 2 methods) +3. `az.summary()` - Summary statistics +4. `az.InferenceData` - Primary data container (88+ references) + +**Key Finding: No PlotCollection usage** +- Zero instances of `PlotCollection` in production code +- No imports from `arviz_plots` in production code +- All plotting is matplotlib-specific + +#### External Plotting Functions Used by MMMPlotSuite + +**From [pymc_marketing/plot.py](pymc_marketing/plot.py):** + +Imported at line 799 in `saturation_curves()` method: +```python +from pymc_marketing.plot import plot_hdi, plot_samples +``` + +These functions use matplotlib directly and would need PlotCollection versions created per requirement #9. + +### Recommendations + +#### 1. Global Backend Configuration Implementation + +**Pattern**: ArviZ-style rcParams (per user requirement #6) + +**Implementation**: + +```python +# pymc_marketing/mmm/config.py (new file) +class MMMConfig(dict): + """Configuration dictionary for MMM plotting.""" + + _defaults = { + "plot.backend": "matplotlib", + "plot.show_warnings": True, + } + + def __init__(self): + super().__init__(self._defaults) + + def reset(self): + """Reset to defaults.""" + self.clear() + self.update(self._defaults) + +# Global config instance +mmm_config = MMMConfig() +``` + +**User API**: +```python +import pymc_marketing as pmm + +# Set global backend +pmm.mmm.mmm_config["plot.backend"] = "plotly" + +# All subsequent plots use plotly +model.plot.posterior_predictive() + +# Override for specific plot +model.plot.saturation_curves(backend="matplotlib") + +# Reset to defaults +pmm.mmm.mmm_config.reset() +``` + +#### 2. Backward Compatibility Strategy + +**Method Signature Pattern**: +```python +def posterior_predictive( + self, + var: list[str] | None = None, + idata: xr.Dataset | None = None, + hdi_prob: float = 0.85, + backend: str | None = None, + return_as_pc: bool = False, +) -> tuple[Figure, NDArray[Axes]] | PlotCollection: + """ + Parameters + ---------- + backend : str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config (default: "matplotlib"). + return_as_pc : bool, default False + If True, returns PlotCollection object. If False, returns a tuple of + (figure, axes) where figure is the backend-specific figure object and + axes is an array of axes for matplotlib or None for other backends. + + Returns + ------- + PlotCollection or tuple + If return_as_pc=True, returns PlotCollection object. + If return_as_pc=False, returns (figure, axes) where: + - figure: backend-specific figure object (plt.Figure, plotly.graph_objs.Figure, etc.) + - axes: np.ndarray of matplotlib Axes if backend="matplotlib", else None + """ + # Resolve backend + backend = backend or mmm_config["plot.backend"] + + # Create PlotCollection + pc = PlotCollection.grid(data, backend=backend, ...) + pc.map(plotting_function, ...) + + # Return based on return_as_pc flag + if return_as_pc: + return pc + else: + # Extract figure from PlotCollection + fig = pc.viz.figure.data.item() + + # Only matplotlib has axes + if backend == "matplotlib": + axes = fig.get_axes() + else: + axes = None + + return fig, axes +``` + +#### 3. Twin Axes Fallback Strategy + +**Implementation for `budget_allocation()`**: + +```python +def budget_allocation( + self, + samples: xr.Dataset, + backend: str | None = None, + return_as_pc: bool = False, + **kwargs +) -> tuple[Figure, plt.Axes] | tuple[Figure, np.ndarray] | PlotCollection: + """ + Notes + ----- + This method uses dual y-axes (matplotlib's twinx), which is not supported + by PlotCollection. If a non-matplotlib backend is requested, a warning + will be issued and the method will fallback to matplotlib. + """ + # Resolve backend + backend = backend or mmm_config["plot.backend"] + + # Check for twinx compatibility (per user requirement #5) + if backend != "matplotlib": + import warnings + warnings.warn( + f"budget_allocation() uses dual y-axes (ax.twinx()) which is not " + f"supported by PlotCollection with backend='{backend}'. " + f"Falling back to matplotlib.", + UserWarning + ) + backend = "matplotlib" + + # Proceed with implementation + # ... +``` + +#### 4. RC Params Handling for saturation_curves() + +**Implementation** (per user requirement #4): + +```python +def saturation_curves( + self, + curve: xr.DataArray, + rc_params: dict | None = None, # DEPRECATED + backend: str | None = None, + backend_config: dict | None = None, + return_as_pc: bool = False, + **kwargs, +) -> tuple[plt.Figure, np.ndarray] | PlotCollection: + """ + Parameters + ---------- + rc_params : dict, optional + **DEPRECATED**: Use `backend_config` instead. + Temporary `matplotlib.rcParams` for this plot (matplotlib backend only). + A DeprecationWarning will be issued when using this parameter. + backend : str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config (default: "matplotlib"). + backend_config : dict, optional + Backend-specific configuration dictionary: + - matplotlib: rcParams dict (same as deprecated rc_params) + - plotly: layout configuration dict + - bokeh: theme configuration dict + """ + # Resolve backend + backend = backend or mmm_config["plot.backend"] + + # Handle deprecated rc_params (per user requirement #4) + if rc_params is not None: + import warnings + warnings.warn( + "The 'rc_params' parameter is deprecated and will be removed in a " + "future version. Use 'backend_config' instead.", + DeprecationWarning, + stacklevel=2 + ) + if backend_config is None: + backend_config = rc_params + + # Apply backend-specific config if matplotlib + if backend == "matplotlib" and backend_config: + with plt.rc_context(backend_config): + # ... create PlotCollection ... + else: + if backend_config and backend != "matplotlib": + import warnings + warnings.warn( + f"backend_config only supported for matplotlib backend, " + f"ignoring for backend='{backend}'", + UserWarning + ) + # ... create PlotCollection without rc_context ... +``` + +#### 5. Helper Function Migration Strategy + +**Problem**: Helper functions in [pymc_marketing/plot.py](pymc_marketing/plot.py) use matplotlib directly. + +**Approach**: Create new PlotCollection-compatible versions (per user requirement #9) + +```python +# New backend-agnostic versions +def plot_hdi_pc(data, *, backend=None, plot_collection=None, **pc_kwargs): + """Plot HDI using PlotCollection (backend-agnostic).""" + backend = backend or mmm_config["plot.backend"] + + if plot_collection is None: + pc = PlotCollection.grid(data, backend=backend, **pc_kwargs) + else: + pc = plot_collection + + pc.map(_plot_hdi_visual, data=data) + return pc + +# Keep existing matplotlib-specific version for backward compatibility +def plot_hdi(da, ax=None, **kwargs): + """Plot HDI using matplotlib (legacy).""" + # ... existing matplotlib implementation +``` + +#### 6. Testing Strategy + +**Requirement**: Test across matplotlib, plotly and bokeh backends (user requirement #8). + +**Implementation**: +```python +# tests/mmm/test_plotting_backends.py +import pytest + +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +class TestMMMPlotSuiteBackends: + """Test all MMMPlotSuite methods across backends.""" + + def test_posterior_predictive(self, mmm_model, backend): + pc = mmm_model.plot.posterior_predictive( + backend=backend, + return_as_pc=True + ) + assert isinstance(pc, PlotCollection) + assert pc.backend == backend + + def test_backward_compatibility_matplotlib(self, mmm_model): + """Test backward compatibility with matplotlib.""" + fig, axes = mmm_model.plot.posterior_predictive( + backend="matplotlib", + return_as_pc=False + ) + assert isinstance(fig, plt.Figure) + assert isinstance(axes, np.ndarray) + + def test_twinx_fallback(self, mmm_model): + """Test that budget_allocation falls back to matplotlib for non-matplotlib backends.""" + with pytest.warns(UserWarning, match="Falling back to matplotlib"): + result = mmm_model.plot.budget_allocation( + samples=..., + backend="plotly", + return_as_pc=False + ) + # Should return matplotlib objects despite requesting plotly + assert isinstance(result[0], plt.Figure) + + def test_rc_params_deprecation(self, mmm_model): + """Test that rc_params parameter issues deprecation warning.""" + with pytest.warns(DeprecationWarning, match="rc_params.*deprecated"): + mmm_model.plot.saturation_curves( + curve=..., + rc_params={"xtick.labelsize": 12}, + backend="matplotlib" + ) + + def test_global_backend_config(self, mmm_model): + """Test global backend configuration.""" + import pymc_marketing as pmm + original_backend = pmm.mmm.mmm_config["plot.backend"] + try: + pmm.mmm.mmm_config["plot.backend"] = "plotly" + pc = mmm_model.plot.posterior_predictive(return_as_pc=True) + assert pc.backend == "plotly" + finally: + pmm.mmm.mmm_config["plot.backend"] = original_backend +``` + +## Migration Implementation Checklist + +### Phase 1: Infrastructure Setup + +- [ ] Add `arviz-plots` as a required dependency in `pyproject.toml` +- [ ] Create `pymc_marketing/mmm/config.py` with `MMMConfig` class and `mmm_config` instance +- [ ] Export `mmm_config` from `pymc_marketing/mmm/__init__.py` +- [ ] Create backend-agnostic plotting function templates + +### Phase 2: Helper Functions + +- [ ] Create `plot_hdi_pc()` PlotCollection version in `pymc_marketing/plot.py` +- [ ] Create `plot_samples_pc()` PlotCollection version in `pymc_marketing/plot.py` +- [ ] Implement backend detection logic in visual functions +- [ ] Keep existing `plot_hdi()` and `plot_samples()` for backward compatibility + +### Phase 3: MMMPlotSuite Methods (Priority Order) + +**High Priority (Simple methods)**: +1. [ ] `posterior_predictive()` - Add backend/return_as_pc parameters +2. [ ] `contributions_over_time()` - Add backend/return_as_pc parameters +3. [ ] `saturation_scatterplot()` - Add backend/return_as_pc parameters +4. [ ] `sensitivity_analysis()` - Add backend/return_as_pc parameters +5. [ ] `uplift_curve()` - Inherits from sensitivity_analysis +6. [ ] `marginal_curve()` - Inherits from sensitivity_analysis + +**Medium Priority (Uses external functions)**: +7. [ ] `saturation_curves()` - Add backend/backend_config/return_as_pc, deprecate rc_params +8. [ ] `allocated_contribution_by_channel_over_time()` - Add backend/return_as_pc + +**Low Priority (Requires twinx fallback)**: +9. [ ] `budget_allocation()` - Add backend/return_as_pc with twinx fallback logic + +### Phase 4: Testing + +- [ ] Create `tests/mmm/test_plotting_backends.py` +- [ ] Parametrized tests across matplotlib/plotly/bokeh +- [ ] Backward compatibility tests +- [ ] Global config tests +- [ ] Fallback behavior tests (twinx) +- [ ] Deprecation warning tests (rc_params) +- [ ] Return type validation tests + +### Phase 5: Documentation + +- [ ] Update all docstrings with new parameters +- [ ] Add migration guide for users +- [ ] Add examples showing new API +- [ ] Document backend limitations +- [ ] Update notebooks to show multi-backend usage + +## Code References + +### MMMPlotSuite Implementation +- Class definition: [pymc_marketing/mmm/plot.py:187](pymc_marketing/mmm/plot.py#L187) +- Twin axes usage: [pymc_marketing/mmm/plot.py:1249](pymc_marketing/mmm/plot.py#L1249) +- RC context usage: [pymc_marketing/mmm/plot.py:878-880](pymc_marketing/mmm/plot.py#L878-880) +- Helper methods: [pymc_marketing/mmm/plot.py:200-370](pymc_marketing/mmm/plot.py#L200-370) +- Main plotting methods: [pymc_marketing/mmm/plot.py:375-1923](pymc_marketing/mmm/plot.py#L375-1923) + +### Helper Functions to Migrate +- plot_hdi: [pymc_marketing/plot.py:434](pymc_marketing/plot.py#L434) +- plot_samples: [pymc_marketing/plot.py:503](pymc_marketing/plot.py#L503) + +### Dependencies +- Package configuration: `/Users/imrisofer/projects/pymc-marketing/pyproject.toml` +- Conda environment: `/Users/imrisofer/projects/pymc-marketing/environment.yml` + +## Open Questions + +1. **PlotCollection Figure/Axes Extraction**: ✅ RESOLVED - Use `pc.viz.figure.data.item()` to extract backend-specific figure object, then `fig.get_axes()` for matplotlib axes (returns None for non-matplotlib backends). + +2. **Backend-Specific Styling**: Should we implement backend-specific styling translation for common use cases, or just warn users that `backend_config` only works for matplotlib? + +3. **Helper Function Strategy**: Should we deprecate old `plot_hdi()` and `plot_samples()` or keep them indefinitely? + +4. **Component Plotting**: While we're not migrating component plot methods per requirement #9, should we at least add documentation noting that they remain matplotlib-only? + +5. **Version Number**: Should this migration be part of a major version bump (e.g., 0.x → 1.0 or 1.x → 2.0)? + +## Summary of Key Decisions + +1. **Backward Compatibility**: Maintained via `return_as_pc=False` default parameter + - When `return_as_pc=False`, functions return `(figure, axes)` tuple + - `figure` is extracted via `pc.viz.figure.data.item()` + - `axes` is extracted via `fig.get_axes()` for matplotlib, `None` for other backends +2. **Global Configuration**: ArviZ-style `mmm_config` dictionary +3. **Twin Axes**: Fallback to matplotlib with warning for `budget_allocation()` +4. **RC Params**: Deprecate `rc_params`, add `backend_config` parameter +5. **Testing**: Parametrized tests across matplotlib, plotly, bokeh +6. **Helper Functions**: Create new PlotCollection versions, keep existing for compatibility +7. **Default Backend**: Keep "matplotlib" as default for full backward compatibility From 3fe99697dab23d72082cbf38b769e370a0a31206 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Sun, 16 Nov 2025 20:18:32 -0500 Subject: [PATCH 04/29] add arviz to 3 functions --- pymc_marketing/mmm/plot.py | 359 +++++++++++++++++++++++++++++++++++++ 1 file changed, 359 insertions(+) diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 18537a175..213cd8dd7 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -174,15 +174,21 @@ from typing import Any import arviz as az +import arviz_plots as azp import matplotlib.pyplot as plt import numpy as np import xarray as xr +from arviz_base.labels import DimCoordLabeller, NoVarLabeller, mix_labellers +from arviz_plots import PlotCollection from matplotlib.axes import Axes from matplotlib.figure import Figure from numpy.typing import NDArray __all__ = ["MMMPlotSuite"] +WIDTH_PER_COL: float = 10.0 +HEIGHT_PER_ROW: float = 4.0 + class MMMPlotSuite: """Media Mix Model Plot Suite. @@ -368,10 +374,153 @@ def _dim_list_handler( dims_combos = [()] return dims_keys, dims_combos + def _resolve_backend(self, backend: str | None) -> str: + """Resolve backend parameter to actual backend string.""" + from pymc_marketing.mmm.config import mmm_config + + return backend or mmm_config["plot.backend"] + # ------------------------------------------------------------------------ # Main Plotting Methods # ------------------------------------------------------------------------ + def temp_posterior_predictive( + self, + var: list[str] | None = None, + idata: xr.Dataset | None = None, + hdi_prob: float = 0.85, + backend: str | None = None, + return_as_pc: bool = False, + ) -> tuple[Figure, list[Axes] | None] | PlotCollection: + """ + Plot posterior predictive distributions over time. + + Parameters + ---------- + var : list of str, optional + List of variable names to plot. If None, uses "y". + idata : xr.Dataset, optional + Dataset containing posterior predictive samples. + If None, uses self.idata.posterior_predictive. + hdi_prob : float, default 0.85 + Probability mass for HDI interval. + backend : str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config via mmm_config["plot.backend"]. + Default (via config) is "matplotlib". + return_as_pc : bool, default False + If True, returns PlotCollection object. + If False, returns tuple (figure, axes) for backward compatibility. + + Returns + ------- + PlotCollection or tuple + If return_as_pc=True, returns PlotCollection object. + If return_as_pc=False, returns (figure, axes) where: + - figure: backend-specific figure object (matplotlib.figure.Figure, + plotly.graph_objs.Figure, or bokeh.plotting.Figure) + - axes: list of matplotlib Axes if backend="matplotlib", else None + + Notes + ----- + When backend is not "matplotlib" and return_as_pc=False, the axes + element of the returned tuple will be None, as plotly and bokeh + do not have an equivalent axes list concept. + + Examples + -------- + >>> # Backward compatible usage (matplotlib) + >>> fig, axes = model.plot.posterior_predictive() + + >>> # Multi-backend with PlotCollection + >>> pc = model.plot.posterior_predictive(backend="plotly", return_as_pc=True) + >>> pc.show() + """ + if not 0 < hdi_prob < 1: + raise ValueError("HDI probability must be between 0 and 1.") + + # Resolve backend + backend = self._resolve_backend(backend) + + # 1. Retrieve or validate posterior_predictive data + pp_data = self._get_posterior_predictive_data(idata) + + # 2. Determine variables to plot + if var is None: + var = ["y"] + main_var = var[0] + + # 3. Identify additional dims & get all combos + ignored_dims = {"chain", "draw", "date", "sample"} + additional_dims, dim_combinations = self._get_additional_dim_combinations( + data=pp_data, variable=main_var, ignored_dims=ignored_dims + ) + + # 4. Prepare subplots + pc = azp.PlotCollection.wrap( + pp_data[main_var].to_dataset(), + cols=additional_dims, + col_wrap=1, + # figure_kwargs={"figsize": (WIDTH_PER_COL * 120, HEIGHT_PER_ROW * 200 * len(dim_combinations)), + # "figsize_units": "dots"}, + figure_kwargs={ + "figsize": ( + WIDTH_PER_COL * 110, + HEIGHT_PER_ROW * 120 * len(dim_combinations), + ), + "figsize_units": "dots", + # "vertical_spacing":0.1, + "sharex": True, + }, + backend=backend, + ) + + # plot hdi + hdi = pp_data.azstats.hdi(hdi_prob) + pc.map( + azp.visuals.fill_between_y, + x=pp_data["date"], + y_bottom=hdi.sel(ci_bound="lower"), + y_top=hdi.sel(ci_bound="upper"), + alpha=0.2, + color="C0", + ) + + # plot median line + pc.map( + azp.visuals.line_xy, + x=pp_data["date"], + y=pp_data.median(dim=["chain", "draw"]), + color="C0", + ) + + # add labels + pc.map(azp.visuals.labelled_x, text="Date") + pc.map(azp.visuals.labelled_y, text="Posterior Predictive") + pc.map( + azp.visuals.labelled_title, + subset_info=True, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ) + return MMMPlotSuite._return_pc_or_fig_axes(pc, return_as_pc, backend=backend) + + @staticmethod + def _return_pc_or_fig_axes(pc: PlotCollection, return_as_pc: bool, backend: str): + """Return PlotCollection or tuple of figure and axes.""" + if return_as_pc: + return pc + else: + # Extract figure from PlotCollection + fig = pc.viz.figure.data.item() + + # Only matplotlib has axes + if backend == "matplotlib": + axes = fig.get_axes() + else: + axes = None + + return fig, axes + def posterior_predictive( self, var: list[str] | None = None, @@ -462,6 +611,112 @@ def posterior_predictive( return fig, axes + def temp_contributions_over_time( + self, + var: list[str], + hdi_prob: float = 0.85, + dims: dict[str, str | int | list] | None = None, + backend: str | None = None, + return_as_pc: bool = False, + ) -> tuple[Figure, list[Axes] | None] | PlotCollection: + """Plot the time-series contributions for each variable in `var`. + + showing the median and the credible interval (default 85%). + Creates one subplot per combination of non-(chain/draw/date) dimensions + and places all variables on the same subplot. + + Parameters + ---------- + var : list of str + A list of variable names to plot from the posterior. + hdi_prob: float, optional + The probability mass of the highest density interval to be displayed. Default is 0.85. + dims : dict[str, str | int | list], optional + Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. + If provided, only the selected slice(s) will be plotted. + backend : str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config via mmm_config["plot.backend"]. + Default (via config) is "matplotlib". + return_as_pc : bool, default False + If True, returns PlotCollection object. + If False, returns tuple (figure, axes) for backward compatibility. + + Returns + ------- + PlotCollection or tuple + If return_as_pc=True, returns PlotCollection object. + If return_as_pc=False, returns (figure, axes) where: + - figure: backend-specific figure object (matplotlib.figure.Figure, + plotly.graph_objs.Figure, or bokeh.plotting.Figure) + - axes: list of matplotlib Axes if backend="matplotlib", else None + + Raises + ------ + ValueError + If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value. + """ + if not 0 < hdi_prob < 1: + raise ValueError("HDI probability must be between 0 and 1.") + + if not hasattr(self.idata, "posterior"): + raise ValueError( + "No posterior data found in 'self.idata'. " + "Please ensure 'self.idata' contains a 'posterior' group." + ) + + # Resolve backend + backend = self._resolve_backend(backend) + + main_var = var[0] + ignored_dims = {"chain", "draw", "date"} + da = self.idata.posterior[var] + additional_dims, dim_combinations = self._get_additional_dim_combinations( + data=da, variable=main_var, ignored_dims=ignored_dims + ) + + # 4. Prepare subplots + pc = azp.PlotCollection.wrap( + da, + cols=additional_dims, + col_wrap=1, + figure_kwargs={ + "figsize": (WIDTH_PER_COL, HEIGHT_PER_ROW * len(dim_combinations)), + "sharex": True, + }, + backend=backend, + ) + + # plot hdi + hdi = da.azstats.hdi(hdi_prob) + pc.map( + azp.visuals.fill_between_y, + x=da["date"], + y_bottom=hdi.sel(ci_bound="lower"), + y_top=hdi.sel(ci_bound="upper"), + alpha=0.2, + color="C0", + ) + + # plot median line + pc.map( + azp.visuals.line_xy, + x=da["date"], + y=da.median(dim=["chain", "draw"]), + color="C0", + ) + + # add labels + pc.map(azp.visuals.labelled_x, text="Date") + pc.map(azp.visuals.labelled_y, text="Posterior Value") + pc.map( + azp.visuals.labelled_title, + subset_info=True, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ) + + return MMMPlotSuite._return_pc_or_fig_axes(pc, return_as_pc, backend=backend) + def contributions_over_time( self, var: list[str], @@ -587,6 +842,110 @@ def contributions_over_time( return fig, axes + def temp_saturation_scatterplot( + self, + original_scale: bool = False, + dims: dict[str, str | int | list] | None = None, + backend: str | None = None, + return_as_pc: bool = False, + **kwargs, + ) -> tuple[Figure, list[Axes] | None] | PlotCollection: + """Plot the saturation curves for each channel. + + Creates a grid of subplots for each combination of channel and non-(date/channel) dimensions. + Optionally, subset by dims (single values or lists). + Each channel will have a consistent color across all subplots. + + Parameters + ---------- + original_scale: bool, optional + Whether to plot the original scale contributions. Default is False. + dims: dict[str, str | int | list], optional + Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. + If provided, only the selected slice(s) will be plotted. + backend: str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config via mmm_config["plot.backend"]. + Default (via config) is "matplotlib". + return_as_pc: bool, optional + If True, returns PlotCollection object. + If False, returns tuple (figure, axes) for backward compatibility. + + Returns + ------- + PlotCollection or tuple + If return_as_pc=True, returns PlotCollection object. + If return_as_pc=False, returns (figure, axes) where: + - figure: backend-specific figure object (matplotlib.figure.Figure, + plotly.graph_objs.Figure, or bokeh.plotting.Figure) + - axes: list of matplotlib Axes if backend="matplotlib", else None + """ + # Resolve backend + backend = self._resolve_backend(backend) + + if not hasattr(self.idata, "constant_data"): + raise ValueError( + "No 'constant_data' found in 'self.idata'. " + "Please ensure 'self.idata' contains the constant_data group." + ) + + # Identify additional dimensions beyond 'date' and 'channel' + cdims = self.idata.constant_data.channel_data.dims + additional_dims = [dim for dim in cdims if dim not in ("date", "channel")] + + # Validate dims and remove filtered dims from additional_dims + if dims: + self._validate_dims(dims, list(self.idata.constant_data.channel_data.dims)) + additional_dims = [d for d in additional_dims if d not in dims] + else: + self._validate_dims({}, list(self.idata.constant_data.channel_data.dims)) + + channel_contribution = ( + "channel_contribution_original_scale" + if original_scale + else "channel_contribution" + ) + + if original_scale and not hasattr(self.idata.posterior, channel_contribution): + raise ValueError( + f"""No posterior.{channel_contribution} data found in 'self.idata'. \n + Add a original scale deterministic:\n + mmm.add_original_scale_contribution_variable(\n + var=[\n + \"channel_contribution\",\n + ...\n + ]\n + )\n + """ + ) + + pc = azp.PlotCollection.grid( + self.idata.posterior[channel_contribution] + .mean(dim=["chain", "draw"]) + .to_dataset(), + cols=additional_dims, + rows=["channel"], + aes={"color": ["channel"]}, + figure_kwargs={"figsize": (14, 2 * 4)}, + backend=backend, + ) + pc.map( + azp.visuals.scatter_xy, + x=self.idata.constant_data.channel_data, + ) + pc.map(azp.visuals.labelled_x, text="Channel Data", ignore_aes={"color"}) + pc.map( + azp.visuals.labelled_y, text="Channel Contributions", ignore_aes={"color"} + ) + pc.map( + azp.visuals.labelled_title, + subset_info=True, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ignore_aes={"color"}, + ) + + return MMMPlotSuite._return_pc_or_fig_axes(pc, return_as_pc, backend=backend) + def saturation_scatterplot( self, original_scale: bool = False, From 278e9bb51f5b176bafb3fdc4537f24f0f47a1662 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Tue, 18 Nov 2025 20:25:41 -0500 Subject: [PATCH 05/29] a lot of changes to plot.py --- pymc_marketing/mmm/plot.py | 1517 +++++++----------------------------- 1 file changed, 281 insertions(+), 1236 deletions(-) diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 213cd8dd7..13f91e252 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -170,8 +170,6 @@ """ import itertools -from collections.abc import Iterable -from typing import Any import arviz as az import arviz_plots as azp @@ -384,14 +382,13 @@ def _resolve_backend(self, backend: str | None) -> str: # Main Plotting Methods # ------------------------------------------------------------------------ - def temp_posterior_predictive( + def posterior_predictive( self, var: list[str] | None = None, idata: xr.Dataset | None = None, hdi_prob: float = 0.85, backend: str | None = None, - return_as_pc: bool = False, - ) -> tuple[Figure, list[Axes] | None] | PlotCollection: + ) -> PlotCollection: """ Plot posterior predictive distributions over time. @@ -408,33 +405,11 @@ def temp_posterior_predictive( Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". If None, uses global config via mmm_config["plot.backend"]. Default (via config) is "matplotlib". - return_as_pc : bool, default False - If True, returns PlotCollection object. - If False, returns tuple (figure, axes) for backward compatibility. Returns ------- - PlotCollection or tuple - If return_as_pc=True, returns PlotCollection object. - If return_as_pc=False, returns (figure, axes) where: - - figure: backend-specific figure object (matplotlib.figure.Figure, - plotly.graph_objs.Figure, or bokeh.plotting.Figure) - - axes: list of matplotlib Axes if backend="matplotlib", else None - - Notes - ----- - When backend is not "matplotlib" and return_as_pc=False, the axes - element of the returned tuple will be None, as plotly and bokeh - do not have an equivalent axes list concept. - - Examples - -------- - >>> # Backward compatible usage (matplotlib) - >>> fig, axes = model.plot.posterior_predictive() + PlotCollection - >>> # Multi-backend with PlotCollection - >>> pc = model.plot.posterior_predictive(backend="plotly", return_as_pc=True) - >>> pc.show() """ if not 0 < hdi_prob < 1: raise ValueError("HDI probability must be between 0 and 1.") @@ -452,7 +427,7 @@ def temp_posterior_predictive( # 3. Identify additional dims & get all combos ignored_dims = {"chain", "draw", "date", "sample"} - additional_dims, dim_combinations = self._get_additional_dim_combinations( + additional_dims, _ = self._get_additional_dim_combinations( data=pp_data, variable=main_var, ignored_dims=ignored_dims ) @@ -461,15 +436,7 @@ def temp_posterior_predictive( pp_data[main_var].to_dataset(), cols=additional_dims, col_wrap=1, - # figure_kwargs={"figsize": (WIDTH_PER_COL * 120, HEIGHT_PER_ROW * 200 * len(dim_combinations)), - # "figsize_units": "dots"}, figure_kwargs={ - "figsize": ( - WIDTH_PER_COL * 110, - HEIGHT_PER_ROW * 120 * len(dim_combinations), - ), - "figsize_units": "dots", - # "vertical_spacing":0.1, "sharex": True, }, backend=backend, @@ -502,123 +469,15 @@ def temp_posterior_predictive( subset_info=True, labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), ) - return MMMPlotSuite._return_pc_or_fig_axes(pc, return_as_pc, backend=backend) - - @staticmethod - def _return_pc_or_fig_axes(pc: PlotCollection, return_as_pc: bool, backend: str): - """Return PlotCollection or tuple of figure and axes.""" - if return_as_pc: - return pc - else: - # Extract figure from PlotCollection - fig = pc.viz.figure.data.item() - - # Only matplotlib has axes - if backend == "matplotlib": - axes = fig.get_axes() - else: - axes = None - - return fig, axes - - def posterior_predictive( - self, - var: list[str] | None = None, - idata: xr.Dataset | None = None, - hdi_prob: float = 0.85, - ) -> tuple[Figure, NDArray[Axes]]: - """Plot time series from the posterior predictive distribution. - - By default, if both `var` and `idata` are not provided, uses - `self.idata.posterior_predictive` and defaults the variable to `["y"]`. - - Parameters - ---------- - var : list of str, optional - A list of variable names to plot. Default is ["y"] if not provided. - idata : xarray.Dataset, optional - The posterior predictive dataset to plot. If not provided, tries to - use `self.idata.posterior_predictive`. - hdi_prob: float, optional - The probability mass of the highest density interval to be displayed. Default is 0.85. - - Returns - ------- - fig : matplotlib.figure.Figure - The Figure object containing the subplots. - axes : np.ndarray of matplotlib.axes.Axes - Array of Axes objects corresponding to each subplot row. - - Raises - ------ - ValueError - If no `idata` is provided and `self.idata.posterior_predictive` does - not exist, instructing the user to run `MMM.sample_posterior_predictive()`. - If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value. - """ - if not 0 < hdi_prob < 1: - raise ValueError("HDI probability must be between 0 and 1.") - # 1. Retrieve or validate posterior_predictive data - pp_data = self._get_posterior_predictive_data(idata) - - # 2. Determine variables to plot - if var is None: - var = ["y"] - main_var = var[0] - - # 3. Identify additional dims & get all combos - ignored_dims = {"chain", "draw", "date", "sample"} - additional_dims, dim_combinations = self._get_additional_dim_combinations( - data=pp_data, variable=main_var, ignored_dims=ignored_dims - ) - - # 4. Prepare subplots - fig, axes = self._init_subplots(n_subplots=len(dim_combinations), ncols=1) - - # 5. Loop over dimension combinations - for row_idx, combo in enumerate(dim_combinations): - ax = axes[row_idx][0] + return pc - # Build indexers - indexers = ( - dict(zip(additional_dims, combo, strict=False)) - if additional_dims - else {} - ) - - # 6. Plot each requested variable - for v in var: - if v not in pp_data: - raise ValueError( - f"Variable '{v}' not in the posterior_predictive dataset." - ) - - data = pp_data[v].sel(**indexers) - # Sum leftover dims, stack chain+draw if needed - data = self._reduce_and_stack(data, ignored_dims) - ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob) - - # 7. Subplot title & labels - title = self._build_subplot_title( - dims=additional_dims, - combo=combo, - fallback_title="Posterior Predictive Time Series", - ) - ax.set_title(title) - ax.set_xlabel("Date") - ax.set_ylabel("Posterior Predictive") - ax.legend(loc="best") - - return fig, axes - - def temp_contributions_over_time( + def contributions_over_time( self, var: list[str], hdi_prob: float = 0.85, dims: dict[str, str | int | list] | None = None, backend: str | None = None, - return_as_pc: bool = False, - ) -> tuple[Figure, list[Axes] | None] | PlotCollection: + ) -> PlotCollection: """Plot the time-series contributions for each variable in `var`. showing the median and the credible interval (default 85%). @@ -638,18 +497,10 @@ def temp_contributions_over_time( Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". If None, uses global config via mmm_config["plot.backend"]. Default (via config) is "matplotlib". - return_as_pc : bool, default False - If True, returns PlotCollection object. - If False, returns tuple (figure, axes) for backward compatibility. Returns ------- - PlotCollection or tuple - If return_as_pc=True, returns PlotCollection object. - If return_as_pc=False, returns (figure, axes) where: - - figure: backend-specific figure object (matplotlib.figure.Figure, - plotly.graph_objs.Figure, or bokeh.plotting.Figure) - - axes: list of matplotlib Axes if backend="matplotlib", else None + PlotCollection Raises ------ @@ -671,7 +522,7 @@ def temp_contributions_over_time( main_var = var[0] ignored_dims = {"chain", "draw", "date"} da = self.idata.posterior[var] - additional_dims, dim_combinations = self._get_additional_dim_combinations( + additional_dims, _ = self._get_additional_dim_combinations( data=da, variable=main_var, ignored_dims=ignored_dims ) @@ -681,7 +532,6 @@ def temp_contributions_over_time( cols=additional_dims, col_wrap=1, figure_kwargs={ - "figsize": (WIDTH_PER_COL, HEIGHT_PER_ROW * len(dim_combinations)), "sharex": True, }, backend=backend, @@ -715,141 +565,14 @@ def temp_contributions_over_time( labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), ) - return MMMPlotSuite._return_pc_or_fig_axes(pc, return_as_pc, backend=backend) - - def contributions_over_time( - self, - var: list[str], - hdi_prob: float = 0.85, - dims: dict[str, str | int | list] | None = None, - ) -> tuple[Figure, NDArray[Axes]]: - """Plot the time-series contributions for each variable in `var`. - - showing the median and the credible interval (default 85%). - Creates one subplot per combination of non-(chain/draw/date) dimensions - and places all variables on the same subplot. - - Parameters - ---------- - var : list of str - A list of variable names to plot from the posterior. - hdi_prob: float, optional - The probability mass of the highest density interval to be displayed. Default is 0.85. - dims : dict[str, str | int | list], optional - Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. - If provided, only the selected slice(s) will be plotted. - - Returns - ------- - fig : matplotlib.figure.Figure - The Figure object containing the subplots. - axes : np.ndarray of matplotlib.axes.Axes - Array of Axes objects corresponding to each subplot row. - - Raises - ------ - ValueError - If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value. - """ - if not 0 < hdi_prob < 1: - raise ValueError("HDI probability must be between 0 and 1.") - - if not hasattr(self.idata, "posterior"): - raise ValueError( - "No posterior data found in 'self.idata'. " - "Please ensure 'self.idata' contains a 'posterior' group." - ) + return pc - main_var = var[0] - all_dims = list(self.idata.posterior[main_var].dims) # type: ignore - ignored_dims = {"chain", "draw", "date"} - additional_dims = [d for d in all_dims if d not in ignored_dims] - - coords = { - key: value.to_numpy() - for key, value in self.idata.posterior[var].coords.items() - } - - # Apply user-specified filters (`dims`) - if dims: - self._validate_dims(dims=dims, all_dims=all_dims) - # Remove filtered dims from the combinations - additional_dims = [d for d in additional_dims if d not in dims] - else: - self._validate_dims({}, all_dims) - # additional_dims = [d for d in additional_dims if d not in dims] - - # Identify combos for remaining dims - if additional_dims: - additional_coords = [ - self.idata.posterior.coords[dim].values # type: ignore - for dim in additional_dims - ] - dim_combinations = list(itertools.product(*additional_coords)) - else: - dim_combinations = [()] - - # If dims contains lists, build all combinations for those as well - dims_keys, dims_combos = self._dim_list_handler(dims) - - # Prepare subplots: one for each combo of dims_lists and additional_dims - total_combos = list(itertools.product(dims_combos, dim_combinations)) - fig, axes = self._init_subplots(len(total_combos), ncols=1) - - for row_idx, (dims_combo, addl_combo) in enumerate(total_combos): - ax = axes[row_idx][0] - # Build indexers for dims and additional_dims - indexers = ( - dict(zip(additional_dims, addl_combo, strict=False)) - if additional_dims - else {} - ) - if dims: - # For dims with lists, use the current value from dims_combo - for i, k in enumerate(dims_keys): - indexers[k] = dims_combo[i] - # For dims with single values, use as is - for k, v in (dims or {}).items(): - if k not in dims_keys: - indexers[k] = v - - # Plot posterior median and HDI for each var - for v in var: - data = self.idata.posterior[v] - missing_coords = { - key: value for key, value in coords.items() if key not in data.dims - } - data = data.expand_dims(**missing_coords) - data = data.sel(**indexers) # apply slice - data = self._reduce_and_stack( - data, dims_to_ignore={"date", "chain", "draw", "sample"} - ) - ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob) - - # Title includes both fixed and combo dims - title_dims = ( - list(dims.keys()) + additional_dims if dims else additional_dims - ) - title_combo = tuple(indexers[k] for k in title_dims) - - title = self._build_subplot_title( - dims=title_dims, combo=title_combo, fallback_title="Time Series" - ) - ax.set_title(title) - ax.set_xlabel("Date") - ax.set_ylabel("Posterior Value") - ax.legend(loc="best") - - return fig, axes - - def temp_saturation_scatterplot( + def saturation_scatterplot( self, original_scale: bool = False, dims: dict[str, str | int | list] | None = None, backend: str | None = None, - return_as_pc: bool = False, - **kwargs, - ) -> tuple[Figure, list[Axes] | None] | PlotCollection: + ) -> PlotCollection: """Plot the saturation curves for each channel. Creates a grid of subplots for each combination of channel and non-(date/channel) dimensions. @@ -867,18 +590,10 @@ def temp_saturation_scatterplot( Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". If None, uses global config via mmm_config["plot.backend"]. Default (via config) is "matplotlib". - return_as_pc: bool, optional - If True, returns PlotCollection object. - If False, returns tuple (figure, axes) for backward compatibility. Returns ------- - PlotCollection or tuple - If return_as_pc=True, returns PlotCollection object. - If return_as_pc=False, returns (figure, axes) where: - - figure: backend-specific figure object (matplotlib.figure.Figure, - plotly.graph_objs.Figure, or bokeh.plotting.Figure) - - axes: list of matplotlib Axes if backend="matplotlib", else None + PlotCollection """ # Resolve backend backend = self._resolve_backend(backend) @@ -926,7 +641,6 @@ def temp_saturation_scatterplot( cols=additional_dims, rows=["channel"], aes={"color": ["channel"]}, - figure_kwargs={"figsize": (14, 2 * 4)}, backend=backend, ) pc.map( @@ -944,161 +658,7 @@ def temp_saturation_scatterplot( ignore_aes={"color"}, ) - return MMMPlotSuite._return_pc_or_fig_axes(pc, return_as_pc, backend=backend) - - def saturation_scatterplot( - self, - original_scale: bool = False, - dims: dict[str, str | int | list] | None = None, - **kwargs, - ) -> tuple[Figure, NDArray[Axes]]: - """Plot the saturation curves for each channel. - - Creates a grid of subplots for each combination of channel and non-(date/channel) dimensions. - Optionally, subset by dims (single values or lists). - Each channel will have a consistent color across all subplots. - """ - if not hasattr(self.idata, "constant_data"): - raise ValueError( - "No 'constant_data' found in 'self.idata'. " - "Please ensure 'self.idata' contains the constant_data group." - ) - - # Identify additional dimensions beyond 'date' and 'channel' - cdims = self.idata.constant_data.channel_data.dims - additional_dims = [dim for dim in cdims if dim not in ("date", "channel")] - - # Validate dims and remove filtered dims from additional_dims - if dims: - self._validate_dims(dims, list(self.idata.constant_data.channel_data.dims)) - additional_dims = [d for d in additional_dims if d not in dims] - else: - self._validate_dims({}, list(self.idata.constant_data.channel_data.dims)) - - # Build all combinations for dims with lists - dims_keys, dims_combos = self._dim_list_handler(dims) - - # Build all combinations for remaining dims - if additional_dims: - additional_coords = [ - self.idata.constant_data.coords[d].values for d in additional_dims - ] - additional_combinations = list(itertools.product(*additional_coords)) - else: - additional_combinations = [()] - - channels = self.idata.constant_data.coords["channel"].values - n_channels = len(channels) - n_addl = len(additional_combinations) - n_dims = len(dims_combos) - - # For most use cases, n_dims will be 1, so grid is channels x additional_combinations - # If dims_combos > 1, treat as extra axis (rare, but possible) - nrows = n_channels - ncols = n_addl * n_dims - total_combos = list( - itertools.product(channels, dims_combos, additional_combinations) - ) - n_subplots = len(total_combos) - - # Assign a color to each channel - channel_colors = {ch: f"C{i}" for i, ch in enumerate(channels)} - - # Prepare subplots as a grid - fig, axes = plt.subplots( - nrows=nrows, - ncols=ncols, - figsize=( - kwargs.get("width_per_col", 8) * ncols, - kwargs.get("height_per_row", 4) * nrows, - ), - squeeze=False, - ) - - channel_contribution = ( - "channel_contribution_original_scale" - if original_scale - else "channel_contribution" - ) - - if original_scale and not hasattr(self.idata.posterior, channel_contribution): - raise ValueError( - f"""No posterior.{channel_contribution} data found in 'self.idata'. \n - Add a original scale deterministic:\n - mmm.add_original_scale_contribution_variable(\n - var=[\n - \"channel_contribution\",\n - ...\n - ]\n - )\n - """ - ) - - for _idx, (channel, dims_combo, addl_combo) in enumerate(total_combos): - # Compute subplot position - row = list(channels).index(channel) - # If dims_combos > 1, treat as extra axis (columns: addl * dims) - if n_dims > 1: - col = list(additional_combinations).index(addl_combo) * n_dims + list( - dims_combos - ).index(dims_combo) - else: - col = list(additional_combinations).index(addl_combo) - ax = axes[row][col] - - # Build indexers for dims and additional_dims - indexers = ( - dict(zip(additional_dims, addl_combo, strict=False)) - if additional_dims - else {} - ) - if dims: - for i, k in enumerate(dims_keys): - indexers[k] = dims_combo[i] - for k, v in (dims or {}).items(): - if k not in dims_keys: - indexers[k] = v - indexers["channel"] = channel - - # Select X data (constant_data) - x_data = self.idata.constant_data.channel_data.sel(**indexers) - # Select Y data (posterior contributions) and scale if needed - y_data = self.idata.posterior[channel_contribution].sel(**indexers) - y_data = y_data.mean(dim=[d for d in y_data.dims if d in ("chain", "draw")]) - x_data = x_data.broadcast_like(y_data) - y_data = y_data.broadcast_like(x_data) - ax.scatter( - x_data.values.flatten(), - y_data.values.flatten(), - alpha=0.8, - color=channel_colors[channel], - label=str(channel), - ) - # Build subplot title - title_dims = ( - ["channel"] + (list(dims.keys()) if dims else []) + additional_dims - ) - title_combo = ( - channel, - *[indexers[k] for k in title_dims if k != "channel"], - ) - title = self._build_subplot_title( - dims=title_dims, - combo=title_combo, - fallback_title="Channel Saturation Curve", - ) - ax.set_title(title) - ax.set_xlabel("Channel Data (X)") - ax.set_ylabel("Channel Contributions (Y)") - ax.legend(loc="best") - - # Hide any unused axes (if grid is larger than needed) - for i in range(nrows): - for j in range(ncols): - if i * ncols + j >= n_subplots: - axes[i][j].set_visible(False) - - return fig, axes + return pc def saturation_curves( self, @@ -1107,12 +667,9 @@ def saturation_curves( n_samples: int = 10, hdi_probs: float | list[float] | None = None, random_seed: np.random.Generator | None = None, - colors: Iterable[str] | None = None, - subplot_kwargs: dict | None = None, - rc_params: dict | None = None, dims: dict[str, str | int | list] | None = None, - **plot_kwargs, - ) -> tuple[plt.Figure, np.ndarray]: + backend: str | None = None, + ) -> PlotCollection: """ Overlay saturation‑curve scatter‑plots with posterior‑predictive sample curves and HDI bands. @@ -1131,32 +688,26 @@ def saturation_curves( If None, uses ArviZ's default (0.94). random_seed : np.random.Generator, optional RNG for reproducible sampling. If None, uses `np.random.default_rng()`. - colors : iterable of str, optional - Colors for the sample & HDI plots. - subplot_kwargs : dict, optional - Passed to `plt.subplots` (e.g. `{"figsize": (10,8)}`). - Merged with the function's own default sizing. - rc_params : dict, optional - Temporary `matplotlib.rcParams` for this plot. - Example keys: `"xtick.labelsize"`, `"ytick.labelsize"`, - `"axes.labelsize"`, `"axes.titlesize"`. dims : dict[str, str | int | list], optional Dimension filters to apply. Example: {"country": ["US", "UK"], "region": "X"}. If provided, only the selected slice(s) will be plotted. - **plot_kwargs - Any other kwargs forwarded to `plot_curve` - (for instance `same_axes=True`, `legend=True`, etc.). + backend: str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config via mmm_config["plot.backend"]. + Default (via config) is "matplotlib". Returns ------- - fig : plt.Figure - Matplotlib figure with your grid. - axes : np.ndarray of plt.Axes - Array of shape `(n_channels, n_geo)`. - + PlotCollection + + Example use: + >>> curve = model.saturation.sample_curve( + >>> model.idata.posterior[["saturation_beta", "saturation_lam"]], max_value=2 + >>> ) + >>> pc = model.plot.saturation_curves(curve, original_scale=True, n_samples=10, + >>> hdi_probs=[0.9, 0.7], random_seed=rng) + >>> pc.show() """ - from pymc_marketing.plot import plot_hdi, plot_samples - if not hasattr(self.idata, "constant_data"): raise ValueError( "No 'constant_data' found in 'self.idata'. " @@ -1181,9 +732,11 @@ def saturation_curves( " )\n" """ ) - curve_data = ( - curve * self.idata.constant_data.target_scale if original_scale else curve - ) + if original_scale: + curve_data = curve * self.idata.constant_data.target_scale + curve_data["x"] = curve_data["x"] * self.idata.constant_data.channel_scale + else: + curve_data = curve curve_data = curve_data.rename("saturation_curve") # — 1. figure out grid shape based on scatter data dimensions / identify dims and combos @@ -1196,167 +749,54 @@ def saturation_curves( additional_dims = [d for d in additional_dims if d not in dims] else: self._validate_dims({}, all_dims) - # Build all combinations for dims with lists - dims_keys, dims_combos = self._dim_list_handler(dims) - # Build all combinations for remaining dims - if additional_dims: - additional_coords = [ - self.idata.constant_data.coords[d].values for d in additional_dims - ] - additional_combinations = list(itertools.product(*additional_coords)) - else: - additional_combinations = [()] - channels = self.idata.constant_data.coords["channel"].values - n_channels = len(channels) - n_addl = len(additional_combinations) - n_dims = len(dims_combos) - nrows = n_channels - ncols = n_addl * n_dims - total_combos = list( - itertools.product(channels, dims_combos, additional_combinations) + + # create the saturation scatterplot + pc = self.saturation_scatterplot( + original_scale=original_scale, dims=dims, backend=backend ) - n_subplots = len(total_combos) - - # — 2. merge subplot_kwargs — - user_subplot = subplot_kwargs or {} - - # Handle user-specified ncols/nrows - if "ncols" in user_subplot: - # User specified ncols, calculate nrows - ncols = user_subplot["ncols"] - nrows = int(np.ceil(n_subplots / ncols)) - user_subplot.pop("ncols") # Remove to avoid conflict - elif "nrows" in user_subplot: - # User specified nrows, calculate ncols - nrows = user_subplot["nrows"] - ncols = int(np.ceil(n_subplots / nrows)) - user_subplot.pop("nrows") # Remove to avoid conflict - default_subplot = {"figsize": (ncols * 4, nrows * 3)} - subkw = {**default_subplot, **user_subplot} - # — 3. create subplots ourselves — - rc_params = rc_params or {} - with plt.rc_context(rc_params): - fig, axes = plt.subplots(nrows=nrows, ncols=ncols, **subkw) - # ensure a 2D array - if nrows == 1 and ncols == 1: - axes = np.array([[axes]]) - elif nrows == 1: - axes = axes.reshape(1, -1) - elif ncols == 1: - axes = axes.reshape(-1, 1) - # Flatten axes for easier iteration - axes_flat = axes.flatten() - if colors is None: - colors = [f"C{i}" for i in range(n_channels)] - elif not isinstance(colors, list): - colors = list(colors) - subplot_idx = 0 - for _idx, (ch, dims_combo, addl_combo) in enumerate(total_combos): - if subplot_idx >= len(axes_flat): - break - ax = axes_flat[subplot_idx] - subplot_idx += 1 - # Build indexers for dims and additional_dims - indexers = ( - dict(zip(additional_dims, addl_combo, strict=False)) - if additional_dims - else {} - ) - if dims: - for i, k in enumerate(dims_keys): - indexers[k] = dims_combo[i] - for k, v in (dims or {}).items(): - if k not in dims_keys: - indexers[k] = v - indexers["channel"] = ch - # Select and broadcast curve data for this channel - curve_idx = { - dim: val for dim, val in indexers.items() if dim in curve_data.dims - } - subplot_curve = curve_data.sel(**curve_idx) - if original_scale: - valid_idx = { - k: v - for k, v in indexers.items() - if k in self.idata.constant_data.channel_scale.dims - } - channel_scale = self.idata.constant_data.channel_scale.sel(**valid_idx) - x_original = subplot_curve.coords["x"] * channel_scale - subplot_curve = subplot_curve.assign_coords(x=x_original) - if n_samples > 0: - plot_samples( - subplot_curve, - non_grid_names="x", - n=n_samples, - rng=random_seed, - axes=np.array([[ax]]), - colors=[colors[list(channels).index(ch)]], - same_axes=False, - legend=False, - **plot_kwargs, - ) - if hdi_probs is not None: - # Robustly handle hdi_probs as float, list, tuple, or np.ndarray - if isinstance(hdi_probs, (float, int)): - hdi_probs_iter = [hdi_probs] - elif isinstance(hdi_probs, (list, tuple, np.ndarray)): - hdi_probs_iter = hdi_probs - else: - raise TypeError( - "hdi_probs must be a float, list, tuple, or np.ndarray" - ) - for hdi_prob in hdi_probs_iter: - plot_hdi( - subplot_curve, - non_grid_names="x", - hdi_prob=hdi_prob, - axes=np.array([[ax]]), - colors=[colors[list(channels).index(ch)]], - same_axes=False, - legend=False, - **plot_kwargs, - ) - x_data = self.idata.constant_data.channel_data.sel(**indexers) - y = ( - self.idata.posterior[contrib_var] - .sel(**indexers) - .mean( - dim=[ - d - for d in self.idata.posterior[contrib_var].dims - if d in ("chain", "draw") - ] + + # add the hdi bands + if hdi_probs is not None: + # Robustly handle hdi_probs as float, list, tuple, or np.ndarray + if isinstance(hdi_probs, (float, int)): + hdi_probs_iter = [hdi_probs] + elif isinstance(hdi_probs, (list, tuple, np.ndarray)): + hdi_probs_iter = hdi_probs + else: + raise TypeError("hdi_probs must be a float, list, tuple, or np.ndarray") + for hdi_prob in hdi_probs_iter: + hdi = curve_data.azstats.hdi(hdi_prob) + pc.map( + azp.visuals.fill_between_y, + x=curve_data["x"], + y_bottom=hdi.sel(ci_bound="lower"), + y_top=hdi.sel(ci_bound="upper"), + alpha=0.2, ) + + if n_samples > 0: + ## sample the curves + rng = np.random.default_rng(random_seed) + + # Stack the two dimensions + stacked = curve_data.stack(sample=("chain", "draw")) + + # Sample from the stacked dimension + idx = rng.choice(stacked.sizes["sample"], size=n_samples, replace=False) + + # Select and unstack + sampled_curves = stacked.isel(sample=idx) + + # plot the sampled curves + pc.map( + azp.visuals.multiple_lines, x_dim="x", data=sampled_curves, alpha=0.2 ) - x_data, y = x_data.broadcast_like(y), y.broadcast_like(x_data) - ax.scatter( - x_data.values.flatten(), - y.values.flatten(), - alpha=0.8, - color=colors[list(channels).index(ch)], - ) - title_dims = ( - ["channel"] + (list(dims.keys()) if dims else []) + additional_dims - ) - title_combo = ( - ch, - *[indexers[k] for k in title_dims if k != "channel"], - ) - title = self._build_subplot_title( - dims=title_dims, - combo=title_combo, - fallback_title="Channel Saturation Curves", - ) - ax.set_title(title) - ax.set_xlabel("Channel Data (X)") - ax.set_ylabel("Channel Contribution (Y)") - for ax_idx in range(subplot_idx, len(axes_flat)): - axes_flat[ax_idx].set_visible(False) - return fig, axes + + return pc def saturation_curves_scatter( self, original_scale: bool = False, **kwargs - ) -> tuple[Figure, NDArray[Axes]]: + ) -> PlotCollection: """ Plot scatter plots of channel contributions vs. channel data. @@ -1376,10 +816,7 @@ def saturation_curves_scatter( Returns ------- - fig : plt.Figure - The matplotlib figure. - axes : np.ndarray - Array of matplotlib axes. + PlotCollection """ import warnings @@ -1393,47 +830,36 @@ def saturation_curves_scatter( # are not used by saturation_scatterplot, so we don't pass them return self.saturation_scatterplot(original_scale=original_scale, **kwargs) - def budget_allocation( + def budget_allocation_roas( self, samples: xr.Dataset, - scale_factor: float | None = None, - figsize: tuple[float, float] = (12, 6), - ax: plt.Axes | None = None, - original_scale: bool = True, dims: dict[str, str | int | list] | None = None, - ) -> tuple[Figure, plt.Axes] | tuple[Figure, np.ndarray]: - """Plot the budget allocation and channel contributions. - - Creates a bar chart comparing allocated spend and channel contributions - for each channel. If additional dimensions besides 'channel' are present, - creates a subplot for each combination of these dimensions. + dims_to_group_by: list[str] | str | None = None, + backend: str | None = None, + ) -> PlotCollection: + """Plot the ROI distribution of a given a response distribution and a budget allocation. Parameters ---------- samples : xr.Dataset The dataset containing the channel contributions and allocation values. Expected to have 'channel_contribution' and 'allocation' variables. - scale_factor : float, optional - Scale factor to convert to original scale, if original_scale=True. - If None and original_scale=True, assumes scale_factor=1. - figsize : tuple[float, float], optional - The size of the figure to be created. Default is (12, 6). - ax : plt.Axes, optional - The axis to plot on. If None, a new figure and axis will be created. - Only used when no extra dimensions are present. - original_scale : bool, optional - A boolean flag to determine if the values should be plotted in their - original scale. Default is True. dims : dict[str, str | int | list], optional Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. If provided, only the selected slice(s) will be plotted. + dims_to_group_by : list[str] | str | None, optional + Dimension(s) to group by for plotting purposes. + When a dimension is specified, all the ROAs distributions for each coordinate of that dimension will be + plotted together in a single plot. This is useful for comparing the ROAs distributions. + If None, will not group by any dimensions (i.e. each distribution will be plotted separately). + If a single string, will group by that dimension. + If a list of strings, will group by each of those dimensions. + backend : str | None, optional + Backend to use for plotting. If None, will use the global backend configuration. Returns ------- - fig : matplotlib.figure.Figure - The Figure object containing the plot. - axes : matplotlib.axes.Axes or numpy.ndarray of matplotlib.axes.Axes - The Axes object with the plot, or array of Axes for multiple subplots. + PlotCollection """ # Get the channels from samples if "channel" not in samples.dims: @@ -1442,11 +868,9 @@ def budget_allocation( ) # Check for required variables in samples - if not any( - "channel_contribution" in var_name for var_name in samples.data_vars - ): + if "channel_contribution_original_scale" not in samples.data_vars: raise ValueError( - "Expected a variable containing 'channel_contribution' in samples, but none found." + "Expected a variable containing 'channel_contribution_original_scale' in samples, but none found." ) if "allocation" not in samples: raise ValueError( @@ -1454,11 +878,7 @@ def budget_allocation( ) # Find the variable containing 'channel_contribution' in its name - channel_contrib_var = next( - var_name - for var_name in samples.data_vars - if "channel_contribution" in var_name - ) + channel_contrib_var = "channel_contribution_original_scale" all_dims = list(samples.dims) # Validate dims @@ -1467,184 +887,54 @@ def budget_allocation( else: self._validate_dims({}, all_dims) - # Handle list-valued dims: build all combinations - dims_keys, dims_combos = self._dim_list_handler(dims) - - # After filtering with dims, only use extra dims not in dims and not ignored for subplotting - ignored_dims = {"channel", "date", "sample", "chain", "draw"} - channel_contribution_dims = list(samples[channel_contrib_var].dims) - extra_dims = [ - d - for d in channel_contribution_dims - if d not in ignored_dims and d not in (dims or {}) - ] - - # Identify combos for remaining dims - if extra_dims: - extra_coords = [samples.coords[dim].values for dim in extra_dims] - extra_combos = list(itertools.product(*extra_coords)) - else: - extra_combos = [()] - - # Prepare subplots: one for each combo of dims_lists and extra_dims - total_combos = list(itertools.product(dims_combos, extra_combos)) - n_subplots = len(total_combos) - if n_subplots == 1 and ax is not None: - axes = np.array([[ax]]) - fig = ax.get_figure() - else: - fig, axes = self._init_subplots( - n_subplots=n_subplots, - ncols=1, - width_per_col=figsize[0], - height_per_row=figsize[1], - ) - - for row_idx, (dims_combo, extra_combo) in enumerate(total_combos): - ax_ = axes[row_idx][0] - # Build indexers for dims and extra_dims - indexers = ( - dict(zip(extra_dims, extra_combo, strict=False)) if extra_dims else {} - ) - if dims: - # For dims with lists, use the current value from dims_combo - for i, k in enumerate(dims_keys): - indexers[k] = dims_combo[i] - # For dims with single values, use as is - for k, v in (dims or {}).items(): - if k not in dims_keys: - indexers[k] = v - - # Select channel contributions for this subplot - channel_contrib_data = samples[channel_contrib_var].sel(**indexers) - allocation_data = samples.allocation - # Only select dims that exist in allocation - allocation_indexers = { - k: v for k, v in indexers.items() if k in allocation_data.dims - } - allocation_data = allocation_data.sel(**allocation_indexers) - - # Average over all dims except channel (and those used for this subplot) - used_dims = set(indexers.keys()) | {"channel"} - reduction_dims = [ - dim for dim in channel_contrib_data.dims if dim not in used_dims - ] - channel_contribution = channel_contrib_data.mean( - dim=reduction_dims - ).to_numpy() - if channel_contribution.ndim > 1: - channel_contribution = channel_contribution.flatten() - if original_scale and scale_factor is not None: - channel_contribution *= scale_factor - - allocation_used_dims = set(allocation_indexers.keys()) | {"channel"} - allocation_reduction_dims = [ - dim for dim in allocation_data.dims if dim not in allocation_used_dims - ] - if allocation_reduction_dims: - allocated_spend = allocation_data.mean( - dim=allocation_reduction_dims - ).to_numpy() - else: - allocated_spend = allocation_data.to_numpy() - if allocated_spend.ndim > 1: - allocated_spend = allocated_spend.flatten() - - self._plot_budget_allocation_bars( - ax_, - samples.coords["channel"].values, - allocated_spend, - channel_contribution, - ) - - # Build subplot title - title_dims = (list(dims.keys()) if dims else []) + extra_dims - title_combo = tuple(indexers[k] for k in title_dims) - title = self._build_subplot_title( - dims=title_dims, - combo=title_combo, - fallback_title="Budget Allocation", - ) - ax_.set_title(title) - - fig.tight_layout() - return fig, axes if n_subplots > 1 else (fig, axes[0][0]) - - def _plot_budget_allocation_bars( - self, - ax: plt.Axes, - channels: NDArray, - allocated_spend: NDArray, - channel_contribution: NDArray, - ) -> None: - """Plot budget allocation bars on a given axis. + channel_contribution = samples[channel_contrib_var].sum(dim="date") + channel_contribution.name = "channel_contribution" + + from arviz_base import convert_to_datatree + + roa_da = channel_contribution / samples.allocation + roa_dt = convert_to_datatree(roa_da) + if isinstance(dims_to_group_by, str): + dims_to_group_by = [dims_to_group_by] + if dims_to_group_by: + grouped = {"all": roa_dt.copy()} + for dim in dims_to_group_by: + new_grouped = {} + for curr_k, curr_group in grouped.items(): + curr_coords = curr_group.posterior.coords[dim].values + new_grouped.update( + { + f"{curr_k}, {dim}: {key}": curr_group.sel({dim: key}) + for key in curr_coords + } + ) + grouped = new_grouped - Parameters - ---------- - ax : plt.Axes - The axis to plot on. - channels : NDArray - Array of channel names. - allocated_spend : NDArray - Array of allocated spend values. - channel_contribution : NDArray - Array of channel contribution values. - """ - bar_width = 0.35 - opacity = 0.7 - index = range(len(channels)) - - # Plot allocated spend - bars1 = ax.bar( - index, - allocated_spend, - bar_width, - color="C0", - alpha=opacity, - label="Allocated Spend", - ) + grouped_roa_dt = {} + for k, v in grouped.items(): + grouped_roa_dt[k[5:]] = v + else: + grouped_roa_dt = roa_dt - # Create twin axis for contributions - ax2 = ax.twinx() - - # Plot contributions - bars2 = ax2.bar( - [i + bar_width for i in index], - channel_contribution, - bar_width, - color="C1", - alpha=opacity, - label="Channel Contribution", + pc = azp.plot_dist( + grouped_roa_dt, + kind="kde", + sample_dims=["sample"], + backend=backend, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), ) - # Labels and formatting - ax.set_xlabel("Channels") - ax.set_ylabel("Allocated Spend", color="C0", labelpad=10) - ax2.set_ylabel("Channel Contributions", color="C1", labelpad=10) - - # Set x-ticks in the middle of the bars - ax.set_xticks([i + bar_width / 2 for i in index]) - ax.set_xticklabels(channels) - ax.tick_params(axis="x", rotation=90) + if dims_to_group_by: + pc.add_legend(dim="model", title="") - # Turn off grid and add legend - ax.grid(False) - ax2.grid(False) - - bars = [bars1, bars2] - labels = ["Allocated Spend", "Channel Contributions"] - ax.legend(bars, labels, loc="best") + return pc def allocated_contribution_by_channel_over_time( self, samples: xr.Dataset, - scale_factor: float | None = None, - lower_quantile: float = 0.025, - upper_quantile: float = 0.975, - original_scale: bool = True, - figsize: tuple[float, float] = (10, 6), - ax: plt.Axes | None = None, - ) -> tuple[Figure, plt.Axes | NDArray[Axes]]: + hdi_prob: float = 0.85, + backend: str | None = None, + ) -> PlotCollection: """Plot the allocated contribution by channel with uncertainty intervals. This function visualizes the mean allocated contributions by channel along with @@ -1658,27 +948,14 @@ def allocated_contribution_by_channel_over_time( The dataset containing the samples of channel contributions. Expected to have 'channel_contribution' variable with dimensions 'channel', 'date', and 'sample'. - scale_factor : float, optional - Scale factor to convert to original scale, if original_scale=True. - If None and original_scale=True, assumes scale_factor=1. - lower_quantile : float, optional - The lower quantile for the uncertainty interval. Default is 0.025. - upper_quantile : float, optional - The upper quantile for the uncertainty interval. Default is 0.975. - original_scale : bool, optional - If True, the contributions are plotted on the original scale. Default is True. - figsize : tuple[float, float], optional - The size of the figure to be created. Default is (10, 6). - ax : plt.Axes, optional - The axis to plot on. If None, a new figure and axis will be created. - Only used when no extra dimensions are present. + hdi_prob : float, optional + The probability mass of the highest density interval to be displayed. Default is 0.85. + backend : str | None, optional + Backend to use for plotting. If None, will use the global backend configuration. Returns ------- - fig : matplotlib.figure.Figure - The Figure object containing the plot. - axes : matplotlib.axes.Axes or numpy.ndarray of matplotlib.axes.Axes - The Axes object with the plot, or array of Axes for multiple subplots. + PlotCollection """ # Check for expected dimensions and variables if "channel" not in samples.dims: @@ -1713,193 +990,70 @@ def allocated_contribution_by_channel_over_time( ignored_dims = {"channel", "date", "sample"} extra_dims = [dim for dim in all_dims if dim not in ignored_dims] - # If no extra dimensions or using provided axis, create a single plot - if not extra_dims or ax is not None: - if ax is None: - fig, ax = plt.subplots(figsize=figsize) - else: - fig = ax.get_figure() - - channel_contribution = samples[channel_contrib_var] - - # Apply scale factor if in original scale - if original_scale and scale_factor is not None: - channel_contribution = channel_contribution * scale_factor - - # Plot mean values by channel - channel_contribution.mean(dim="sample").plot(hue="channel", ax=ax) - - # Add uncertainty intervals for each channel - for channel in samples.coords["channel"].values: - ax.fill_between( - x=channel_contribution.date.values, - y1=channel_contribution.sel(channel=channel).quantile( - lower_quantile, dim="sample" - ), - y2=channel_contribution.sel(channel=channel).quantile( - upper_quantile, dim="sample" - ), - alpha=0.1, - ) + pc = azp.PlotCollection.wrap( + samples[channel_contrib_var].to_dataset(), + cols=extra_dims, + aes={"color": ["channel"]}, + col_wrap=1, + figure_kwargs={ + "sharex": True, + }, + backend=backend, + ) - ax.set_xlabel("Date") - ax.set_ylabel("Channel Contribution") - ax.set_title("Allocated Contribution by Channel Over Time") - - fig.tight_layout() - return fig, ax - - # For multiple dimensions, create a grid of subplots - # Determine layout based on number of extra dimensions - if len(extra_dims) == 1: - # One extra dimension: use for rows - dim_values = [samples.coords[extra_dims[0]].values] - nrows = len(dim_values[0]) - ncols = 1 - subplot_dims = [extra_dims[0], None] - elif len(extra_dims) == 2: - # Two extra dimensions: one for rows, one for columns - dim_values = [ - samples.coords[extra_dims[0]].values, - samples.coords[extra_dims[1]].values, - ] - nrows = len(dim_values[0]) - ncols = len(dim_values[1]) - subplot_dims = extra_dims - else: - # Three or more: use first two for rows/columns, average over the rest - dim_values = [ - samples.coords[extra_dims[0]].values, - samples.coords[extra_dims[1]].values, - ] - nrows = len(dim_values[0]) - ncols = len(dim_values[1]) - subplot_dims = [extra_dims[0], extra_dims[1]] - - # Calculate figure size based on number of subplots - subplot_figsize = (figsize[0] * max(1, ncols), figsize[1] * max(1, nrows)) - fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=subplot_figsize) - - # Make axes indexable even for 1x1 grid - if nrows == 1 and ncols == 1: - axes = np.array([[axes]]) - elif nrows == 1: - axes = axes.reshape(1, -1) - elif ncols == 1: - axes = axes.reshape(-1, 1) - - # Create a subplot for each combination of dimension values - for i, row_val in enumerate(dim_values[0]): - for j, col_val in enumerate( - dim_values[1] if len(dim_values) > 1 else [None] - ): - ax = axes[i, j] - - # Select data for this subplot - selection = {subplot_dims[0]: row_val} - if col_val is not None: - selection[subplot_dims[1]] = col_val - - # Select channel contributions for this subplot - subset = samples[channel_contrib_var].sel(**selection) - - # Apply scale factor if needed - if original_scale and scale_factor is not None: - subset = subset * scale_factor - - # Plot mean values by channel for this subset - subset.mean(dim="sample").plot(hue="channel", ax=ax) - - # Add uncertainty intervals for each channel - for channel in samples.coords["channel"].values: - channel_data = subset.sel(channel=channel) - ax.fill_between( - x=channel_data.date.values, - y1=channel_data.quantile(lower_quantile, dim="sample"), - y2=channel_data.quantile(upper_quantile, dim="sample"), - alpha=0.1, - ) + # plot hdi + hdi = samples[channel_contrib_var].azstats.hdi(hdi_prob, dim="sample") + pc.map( + azp.visuals.fill_between_y, + x=samples[channel_contrib_var]["date"], + y_bottom=hdi.sel(ci_bound="lower"), + y_top=hdi.sel(ci_bound="upper"), + alpha=0.2, + ) - # Add subplot title based on dimension values - title_parts = [] - if subplot_dims[0] is not None: - title_parts.append(f"{subplot_dims[0]}={row_val}") - if subplot_dims[1] is not None: - title_parts.append(f"{subplot_dims[1]}={col_val}") + # plot mean contribution line + pc.map( + azp.visuals.line_xy, + x=samples[channel_contrib_var]["date"], + y=samples[channel_contrib_var].mean(dim="sample"), + ) - base_title = "Allocated Contribution by Channel Over Time" - if title_parts: - ax.set_title(f"{base_title} - {', '.join(title_parts)}") - else: - ax.set_title(base_title) + pc.map(azp.visuals.labelled_x, text="Date", ignore_aes={"color"}) + pc.map( + azp.visuals.labelled_y, text="Channel Contribution", ignore_aes={"color"} + ) + pc.map( + azp.visuals.labelled_title, + subset_info=True, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ignore_aes={"color"}, + ) - ax.set_xlabel("Date") - ax.set_ylabel("Channel Contribution") + pc.add_legend(dim="channel") + return pc - fig.tight_layout() - return fig, axes - - def sensitivity_analysis( + def _sensitivity_analysis_plot( self, hdi_prob: float = 0.94, - ax: plt.Axes | None = None, aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - subplot_kwargs: dict[str, Any] | None = None, - *, - plot_kwargs: dict[str, Any] | None = None, - ylabel: str = "Effect", - xlabel: str = "Sweep", - title: str | None = None, - add_figure_title: bool = False, - subplot_title_fallback: str = "Sensitivity Analysis", - ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: - """Plot sensitivity analysis results. + backend: str | None = None, + ) -> PlotCollection: + """Plot helper for sensitivity analysis results. Parameters ---------- hdi_prob : float, default 0.94 HDI probability mass. - ax : plt.Axes, optional - The axis to plot on. aggregation : dict, optional Aggregation to apply to the data. E.g., {"sum": ("channel",)} to sum over the channel dimension. + backend : str | None, optional + Backend to use for plotting. If None, will use the global backend configuration. - Other Parameters - ---------------- - plot_kwargs : dict, optional - Keyword arguments forwarded to the underlying line plot. Defaults include - ``{"color": "C0"}``. - ylabel : str, optional - Y-axis label. Defaults to "Effect". - xlabel : str, optional - X-axis label. Defaults to "Sweep". - title : str, optional - Figure-level title to add when ``add_figure_title=True``. - add_figure_title : bool, optional - Whether to add a figure-level title. Defaults to ``False``. - subplot_title_fallback : str, optional - Fallback title used for subplot titles when no plotting dims exist. Defaults - to "Sensitivity Analysis". - - Examples - -------- - Basic run using stored results in `idata`: - - .. code-block:: python - - # Assuming you already ran a sweep and stored results - # under idata.sensitivity_analysis via SensitivityAnalysis.run_sweep(..., extend_idata=True) - ax = mmm.plot.sensitivity_analysis(hdi_prob=0.9) - - With aggregation over dimensions (e.g., sum over channels): - - .. code-block:: python + Returns + ------- + PlotCollection - ax = mmm.plot.sensitivity_analysis( - hdi_prob=0.9, - aggregation={"sum": ("channel",)}, - ) """ if not hasattr(self.idata, "sensitivity_analysis"): raise ValueError( @@ -1930,165 +1084,99 @@ def sensitivity_analysis( x = x.mean(dim=dims_list) else: x = x.median(dim=dims_list) - # Determine plotting dimensions (excluding sample & sweep) - plot_dims = [d for d in x.dims if d not in {"sample", "sweep"}] - if plot_dims: - dim_combinations = list( - itertools.product(*[x.coords[d].values for d in plot_dims]) - ) - else: - dim_combinations = [()] - n_panels = len(dim_combinations) + # Determine plotting dimensions (excluding sample & sweep) + plot_dims = set(x.dims) - {"sample", "sweep"} - # Handle axis/grid creation - subplot_kwargs = {**(subplot_kwargs or {})} - nrows_user = subplot_kwargs.pop("nrows", None) - ncols_user = subplot_kwargs.pop("ncols", None) - if nrows_user is not None and ncols_user is not None: - raise ValueError( - "Specify only one of 'nrows' or 'ncols' in subplot_kwargs." - ) + pc = azp.PlotCollection.wrap( + x.to_dataset(), + cols=plot_dims, + col_wrap=2, + figure_kwargs={ + "sharex": True, + }, + backend=backend, + ) - if n_panels > 1: - if ax is not None: - raise ValueError( - "Multiple sensitivity panels detected; please omit 'ax' and use 'subplot_kwargs' instead." - ) - if ncols_user is not None: - ncols = ncols_user - nrows = int(np.ceil(n_panels / ncols)) - elif nrows_user is not None: - nrows = nrows_user - ncols = int(np.ceil(n_panels / nrows)) - else: - ncols = max(1, int(np.ceil(np.sqrt(n_panels)))) - nrows = int(np.ceil(n_panels / ncols)) - subplot_kwargs.setdefault("figsize", (ncols * 4.0, nrows * 3.0)) - fig, axes_grid = plt.subplots( - nrows=nrows, - ncols=ncols, - **subplot_kwargs, - ) - if isinstance(axes_grid, plt.Axes): - axes_grid = np.array([[axes_grid]]) - elif axes_grid.ndim == 1: - axes_grid = axes_grid.reshape(1, -1) - axes_array = axes_grid - else: - if ax is not None: - axes_array = np.array([[ax]]) - fig = ax.figure - else: - if ncols_user is not None or nrows_user is not None: - subplot_kwargs.setdefault("figsize", (4.0, 3.0)) - fig, single_ax = plt.subplots( - nrows=1, - ncols=1, - **subplot_kwargs, - ) - else: - fig, single_ax = plt.subplots() - axes_array = np.array([[single_ax]]) - - # Merge plotting kwargs with defaults - _plot_kwargs = {"color": "C0"} - if plot_kwargs: - _plot_kwargs.update(plot_kwargs) - _line_color = _plot_kwargs.get("color", "C0") - - axes_flat = axes_array.flatten() - for idx, combo in enumerate(dim_combinations): - current_ax = axes_flat[idx] - indexers = dict(zip(plot_dims, combo, strict=False)) if plot_dims else {} - subset = x.sel(**indexers) if indexers else x - subset = subset.squeeze(drop=True) - subset = subset.astype(float) - - if "sweep" in subset.dims: - sweep_dim = "sweep" - else: - cand = [d for d in subset.dims if d != "sample"] - if not cand: - raise ValueError( - "Expected 'sweep' (or a non-sample) dimension in sensitivity results." - ) - sweep_dim = cand[0] + # plot hdi + hdi = x.azstats.hdi(hdi_prob, dim="sample") + pc.map( + azp.visuals.fill_between_y, + x=x["sweep"], + y_bottom=hdi.sel(ci_bound="lower"), + y_top=hdi.sel(ci_bound="upper"), + alpha=0.4, + color="C0", + ) + # plot aggregated line + pc.map( + azp.visuals.line_xy, + x=x["sweep"], + y=x.mean(dim="sample"), + color="C0", + ) + # add labels + pc.map(azp.visuals.labelled_x, text="Sweep") + pc.map( + azp.visuals.labelled_title, + subset_info=True, + labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(), + ) + return pc - sweep = ( - np.asarray(subset.coords[sweep_dim].values) - if sweep_dim in subset.coords - else np.arange(subset.sizes[sweep_dim]) - ) + def sensitivity_analysis( + self, + hdi_prob: float = 0.94, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + backend: str | None = None, + ) -> PlotCollection: + """Plot sensitivity analysis results. - mean = subset.mean("sample") if "sample" in subset.dims else subset - reduce_dims = [d for d in mean.dims if d != sweep_dim] - if reduce_dims: - mean = mean.sum(dim=reduce_dims) + Parameters + ---------- + hdi_prob : float, default 0.94 + HDI probability mass. + aggregation : dict, optional + Aggregation to apply to the data. + E.g., {"sum": ("channel",)} to sum over the channel dimension. + backend : str | None, optional + Backend to use for plotting. If None, will use the global backend configuration. - if "sample" in subset.dims: - hdi = az.hdi(subset, hdi_prob=hdi_prob, input_core_dims=[["sample"]]) - if isinstance(hdi, xr.Dataset): - hdi = hdi[next(iter(hdi.data_vars))] - else: - hdi = xr.concat([mean, mean], dim="hdi").assign_coords( - hdi=np.array([0, 1]) - ) + Returns + ------- + PlotCollection - reduce_hdi = [d for d in hdi.dims if d not in (sweep_dim, "hdi")] - if reduce_hdi: - hdi = hdi.sum(dim=reduce_hdi) - if set(hdi.dims) == {sweep_dim, "hdi"} and list(hdi.dims) != [ - sweep_dim, - "hdi", - ]: - hdi = hdi.transpose(sweep_dim, "hdi") # type: ignore - - current_ax.plot(sweep, np.asarray(mean.values, dtype=float), **_plot_kwargs) - az.plot_hdi( - x=sweep, - hdi_data=np.asarray(hdi.values, dtype=float), - hdi_prob=hdi_prob, - color=_line_color, - ax=current_ax, - ) + Examples + -------- + Basic run using stored results in `idata`: - title = self._build_subplot_title( - dims=plot_dims, - combo=combo, - fallback_title=subplot_title_fallback, - ) - current_ax.set_title(title) - current_ax.set_xlabel(xlabel) - current_ax.set_ylabel(ylabel) + .. code-block:: python - # Hide any unused axes (happens if grid > panels) - for ax_extra in axes_flat[n_panels:]: - ax_extra.set_visible(False) + # Assuming you already ran a sweep and stored results + # under idata.sensitivity_analysis via SensitivityAnalysis.run_sweep(..., extend_idata=True) + mmm.plot.sensitivity_analysis(hdi_prob=0.9) - # Optional figure-level title: only for multi-panel layouts, default color (black) - if add_figure_title and title is not None and n_panels > 1: - fig.suptitle(title) + With aggregation over dimensions (e.g., sum over channels): - if n_panels == 1: - return axes_array[0, 0] + .. code-block:: python - fig.tight_layout() - return fig, axes_array + mmm.plot.sensitivity_analysis( + hdi_prob=0.9, + aggregation={"sum": ("channel",)}, + ) + """ + pc = self._sensitivity_analysis_plot( + hdi_prob=hdi_prob, aggregation=aggregation, backend=backend + ) + pc.map(azp.visuals.labelled_y, text="Contribution") + return pc def uplift_curve( self, hdi_prob: float = 0.94, - ax: plt.Axes | None = None, aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - subplot_kwargs: dict[str, Any] | None = None, - *, - plot_kwargs: dict[str, Any] | None = None, - ylabel: str = "Uplift", - xlabel: str = "Sweep", - title: str | None = "Uplift curve", - add_figure_title: bool = True, - ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: + backend: str | None = None, + ) -> PlotCollection: """ Plot precomputed uplift curves stored under `idata.sensitivity_analysis['uplift_curve']`. @@ -2096,24 +1184,11 @@ def uplift_curve( ---------- hdi_prob : float, default 0.94 HDI probability mass. - ax : plt.Axes, optional - The axis to plot on. aggregation : dict, optional Aggregation to apply to the data. E.g., {"sum": ("channel",)} to sum over the channel dimension. - subplot_kwargs : dict, optional - Additional subplot configuration forwarded to :meth:`sensitivity_analysis`. - plot_kwargs : dict, optional - Keyword arguments forwarded to the underlying line plot. If not provided, defaults - are used by :meth:`sensitivity_analysis` (e.g., color "C0"). - ylabel : str, optional - Y-axis label. Defaults to "Uplift". - xlabel : str, optional - X-axis label. Defaults to "Sweep". - title : str, optional - Figure-level title to add when ``add_figure_title=True``. Defaults to "Uplift curve". - add_figure_title : bool, optional - Whether to add a figure-level title. Defaults to ``True``. + backend : str | None, optional + Backend to use for plotting. If None, will use the global backend configuration. Examples -------- @@ -2134,7 +1209,7 @@ def uplift_curve( uplift = sa.compute_uplift_curve_respect_to_base( results, ref=1.0, extend_idata=True ) - _ = mmm.plot.uplift_curve(hdi_prob=0.9) + mmm.plot.uplift_curve(hdi_prob=0.9) """ if not hasattr(self.idata, "sensitivity_analysis"): raise ValueError( @@ -2163,34 +1238,22 @@ def uplift_curve( original_group = self.idata.sensitivity_analysis # type: ignore try: self.idata.sensitivity_analysis = tmp_idata # type: ignore - return self.sensitivity_analysis( + pc = self._sensitivity_analysis_plot( hdi_prob=hdi_prob, - ax=ax, aggregation=aggregation, - subplot_kwargs=subplot_kwargs, - subplot_title_fallback="Uplift curve", - plot_kwargs=plot_kwargs, - ylabel=ylabel, - xlabel=xlabel, - title=title, - add_figure_title=add_figure_title, + backend=backend, ) + pc.map(azp.visuals.labelled_y, text="Uplift (%)") + return pc finally: self.idata.sensitivity_analysis = original_group # type: ignore def marginal_curve( self, hdi_prob: float = 0.94, - ax: plt.Axes | None = None, aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - subplot_kwargs: dict[str, Any] | None = None, - *, - plot_kwargs: dict[str, Any] | None = None, - ylabel: str = "Marginal effect", - xlabel: str = "Sweep", - title: str | None = "Marginal effects", - add_figure_title: bool = True, - ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: + backend: str | None = None, + ) -> PlotCollection: """ Plot precomputed marginal effects stored under `idata.sensitivity_analysis['marginal_effects']`. @@ -2198,23 +1261,15 @@ def marginal_curve( ---------- hdi_prob : float, default 0.94 HDI probability mass. - ax : plt.Axes, optional - The axis to plot on. aggregation : dict, optional Aggregation to apply to the data. E.g., {"sum": ("channel",)} to sum over the channel dimension. - subplot_kwargs : dict, optional - Additional subplot configuration forwarded to :meth:`sensitivity_analysis`. - plot_kwargs : dict, optional - Keyword arguments forwarded to the underlying line plot. Defaults to ``{"color": "C1"}``. - ylabel : str, optional - Y-axis label. Defaults to "Marginal effect". - xlabel : str, optional - X-axis label. Defaults to "Sweep". - title : str, optional - Figure-level title to add when ``add_figure_title=True``. Defaults to "Marginal effects". - add_figure_title : bool, optional - Whether to add a figure-level title. Defaults to ``True``. + backend : str | None, optional + Backend to use for plotting. If None, will use the global backend configuration. + + Returns + ------- + PlotCollection Examples -------- @@ -2233,7 +1288,7 @@ def marginal_curve( sweep_type="multiplicative", ) me = sa.compute_marginal_effects(results, extend_idata=True) - _ = mmm.plot.marginal_curve(hdi_prob=0.9) + mmm.plot.marginal_curve(hdi_prob=0.9) """ if not hasattr(self.idata, "sensitivity_analysis"): raise ValueError( @@ -2261,22 +1316,12 @@ def marginal_curve( original = self.idata.sensitivity_analysis # type: ignore try: self.idata.sensitivity_analysis = tmp # type: ignore - # Reuse core plotting; percentage=False by definition - # Merge defaults for plot_kwargs if not provided - _plot_kwargs = {"color": "C1"} - if plot_kwargs: - _plot_kwargs.update(plot_kwargs) - return self.sensitivity_analysis( + pc = self._sensitivity_analysis_plot( hdi_prob=hdi_prob, - ax=ax, aggregation=aggregation, - subplot_kwargs=subplot_kwargs, - subplot_title_fallback="Marginal effects", - plot_kwargs=_plot_kwargs, - ylabel=ylabel, - xlabel=xlabel, - title=title, - add_figure_title=add_figure_title, + backend=backend, ) + pc.map(azp.visuals.labelled_y, text="Marginal Effect") + return pc finally: self.idata.sensitivity_analysis = original # type: ignore From d6331a03727aa9c78ad16690aca25ce9cb869129 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Wed, 19 Nov 2025 08:20:03 -0500 Subject: [PATCH 06/29] clean up for plot.py --- .gitignore | 3 + CLAUDE.md | 272 +++ pymc_marketing/mmm/plot.py | 144 +- .../mmmplotsuite-backend-migration-tdd.md | 1795 ----------------- ...otsuite-backend-migration-comprehensive.md | 554 ----- 5 files changed, 319 insertions(+), 2449 deletions(-) create mode 100644 CLAUDE.md delete mode 100644 thoughts/shared/plans/mmmplotsuite-backend-migration-tdd.md delete mode 100644 thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md diff --git a/.gitignore b/.gitignore index ac4503fde..54cc61220 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,6 @@ dmypy.json # Gallery images docs/source/gallery/images/ docs/gettext/ + +# ignore Claude +.claude/* diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..b044d95ea --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,272 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +PyMC-Marketing is a Bayesian marketing analytics library built on PyMC, providing three main modeling capabilities: + +- **Marketing Mix Modeling (MMM)**: Measure marketing channel effectiveness with adstock, saturation, and budget optimization +- **Customer Lifetime Value (CLV)**: Predict customer value using probabilistic models (BG/NBD, Pareto/NBD, Gamma-Gamma, etc.) +- **Customer Choice Analysis**: Understand product selection with Multivariate Interrupted Time Series (MVITS) and discrete choice models + +## Development Commands + +### Environment Setup +```bash +# Create and activate conda environment (recommended) +conda env create -f environment.yml +conda activate pymc-marketing-dev + +# Install package in editable mode +make init +``` + +### Testing and Quality +To use pytest you first need to activate the enviroment: +```bash +source ~/miniconda3/etc/profile.d/conda.sh && conda activate pymc-marketing-dev +``` + +Running tests: +```bash +# first need to activate the enviorment: + +# Run all tests with coverage +make test + +# Run specific test file (you first need to activate the conda env: conda activate pymc-marketing-dev) +pytest tests/path/to/test_file.py + +# Run specific test function (you first need to activate the conda env: conda activate pymc-marketing-dev) +pytest tests/path/to/test_file.py::test_function_name + +# Check linting (ruff + mypy) +make check_lint + +# Auto-fix linting issues +make lint + +# Check code formatting +make check_format + +# Auto-format code +make format +``` + +### Documentation +```bash +# Build HTML documentation +make html + +# Clean docs and rebuild from scratch +make cleandocs && make html + +# Run notebooks to verify examples +make run_notebooks # All notebooks +make run_notebooks_mmm # MMM notebooks only +make run_notebooks_other # Non-MMM notebooks +``` + +### Other Utilities +```bash +# Generate UML diagrams for architecture +make uml + +# Start MLflow tracking server +make mlflow_server +``` + +## High-Level Architecture + +### Core Base Classes + +**ModelBuilder** ([pymc_marketing/model_builder.py](pymc_marketing/model_builder.py)) +- Abstract base class for all PyMC-Marketing models +- Defines the model lifecycle: `build_model()` → `fit()` → `predict()` +- Provides save/load functionality via NetCDF and InferenceData +- Manages `model_config` (priors) and `sampler_config` (MCMC settings) + +**RegressionModelBuilder** (extends ModelBuilder) +- Adds scikit-learn-like API: `fit(X, y)`, `predict(X)` +- Base class for MMM and some customer choice models +- Handles prior/posterior predictive sampling + +**CLVModel** ([pymc_marketing/clv/models/basic.py](pymc_marketing/clv/models/basic.py)) +- Base class for CLV models (BetaGeo, ParetoNBD, GammaGamma, etc.) +- Takes data in constructor, not fit method: `model = BetaGeoModel(data=df)` +- Supports multiple inference methods: `method="mcmc"` (default), `"map"`, `"advi"`, etc. + +### Module 1: MMM Architecture + +**Class Hierarchy:** +``` +RegressionModelBuilder + └── MMMModelBuilder (mmm/base.py) + ├── BaseMMM/MMM (mmm/mmm.py) - Single market + └── MMM (mmm/multidimensional.py) - Panel/hierarchical data +``` + +**Component-Based Design:** + +MMM uses composable transformation components: + +1. **Adstock Transformations** ([pymc_marketing/mmm/components/adstock.py](pymc_marketing/mmm/components/adstock.py)) + - Model carryover effects of advertising + - Built-in: GeometricAdstock, DelayedAdstock, WeibullCDFAdstock, WeibullPDFAdstock + - All extend `AdstockTransformation` base class + +2. **Saturation Transformations** ([pymc_marketing/mmm/components/saturation.py](pymc_marketing/mmm/components/saturation.py)) + - Model diminishing returns + - Built-in: LogisticSaturation, HillSaturation, MichaelisMentenSaturation, TanhSaturation + - All extend `SaturationTransformation` base class + +3. **Transformation Protocol** ([pymc_marketing/mmm/components/base.py](pymc_marketing/mmm/components/base.py)) + - Base class defining transformation interface + - Requires: `function()`, `prefix`, `default_priors` + - Custom transformations should extend this + +**Validation and Preprocessing System:** + +MMM models use a decorator-based system: +- Methods tagged with `_tags = {"validation_X": True}` run during `fit(X, y)` +- Methods tagged with `_tags = {"preprocessing_y": True}` transform data before modeling +- Built-in validators in [pymc_marketing/mmm/validating.py](pymc_marketing/mmm/validating.py) +- Built-in preprocessors in [pymc_marketing/mmm/preprocessing.py](pymc_marketing/mmm/preprocessing.py) + +**Key MMM Features:** +- Time-varying parameters via HSGP (Hilbert Space Gaussian Process) +- Lift test calibration for experiments +- Budget optimization ([pymc_marketing/mmm/budget_optimizer.py](pymc_marketing/mmm/budget_optimizer.py)) +- Causal DAG support ([pymc_marketing/mmm/causal.py](pymc_marketing/mmm/causal.py)) +- Additive effects system ([pymc_marketing/mmm/additive_effect.py](pymc_marketing/mmm/additive_effect.py)) for custom components + +**Multidimensional MMM vs Base MMM:** +- Base MMM ([pymc_marketing/mmm/mmm.py](pymc_marketing/mmm/mmm.py)): Single market, simpler API +- Multidimensional MMM ([pymc_marketing/mmm/multidimensional.py](pymc_marketing/mmm/multidimensional.py)): Panel data, per-channel transformations via `MediaConfigList`, more flexible + +### Module 2: CLV Architecture + +**Available Models:** +- BetaGeoModel: Beta-Geometric/NBD for continuous non-contractual settings +- ParetoNBDModel: Pareto/NBD alternative formulation +- GammaGammaModel: Monetary value prediction +- ShiftedBetaGeoModel, ModifiedBetaGeoModel: Variants +- BetaGeoBetaBinomModel: Discrete time variant + +**CLV Pattern:** +```python +# Data passed to constructor, not fit() +model = clv.BetaGeoModel(data=df) + +# Fit with various inference methods +model.fit(method="mcmc") # or "map", "advi", "fullrank_advi" + +# Predict for known customers +model.expected_purchases(customer_id, t) +model.probability_alive(customer_id) +``` + +**Custom Distributions:** +CLV models use custom distributions in [pymc_marketing/clv/distributions.py](pymc_marketing/clv/distributions.py) + +### Module 3: Customer Choice + +- **MVITS** ([pymc_marketing/customer_choice/mv_its.py](pymc_marketing/customer_choice/mv_its.py)): Multivariate Interrupted Time Series for product launch incrementality +- **Discrete Choice Models**: Logit models in [pymc_marketing/customer_choice/](pymc_marketing/customer_choice/) + +### Cross-Cutting Systems + +**Prior Configuration System** ([pymc_marketing/prior.py](pymc_marketing/prior.py), now in pymc_extras) +- Declarative prior specification outside PyMC context +- Example: `Prior("Normal", mu=0, sigma=1)` +- Supports hierarchical priors, non-centered parameterization, transformations +- Used in all `model_config` dictionaries + +**Model Configuration** ([pymc_marketing/model_config.py](pymc_marketing/model_config.py)) +- `parse_model_config()` converts dicts to Prior objects +- Handles nested priors for hierarchical models +- Supports HSGP kwargs for Gaussian processes + +**Save/Load Infrastructure** +- Models save to NetCDF via ArviZ InferenceData +- `model.save("filename.nc")` serializes model + data + config +- `Model.load("filename.nc")` reconstructs from file +- Training data stored in `idata.fit_data` group + +**MLflow Integration** ([pymc_marketing/mlflow.py](pymc_marketing/mlflow.py)) +- `autolog()` patches PyMC and PyMC-Marketing functions +- Automatically logs: model structure, diagnostics (r_hat, ESS, divergences), MMM/CLV configs +- Start server with: `make mlflow_server` + +## Code Style and Testing + +**Linting:** +- Uses Ruff for linting and formatting +- Uses mypy for type checking +- Config in [pyproject.toml](pyproject.toml) under `[tool.ruff]` and `[tool.mypy]` +- Docstrings follow NumPy style guide + +**Testing:** +- pytest with coverage reporting +- Config in [pyproject.toml](pyproject.toml) under `[tool.pytest.ini_options]` +- Test files mirror package structure in [tests/](tests/) + +**Pre-commit Hooks:** +```bash +pre-commit install # Set up hooks +pre-commit run --all-files # Run manually +``` + +## Important Patterns and Conventions + +### Adding a New MMM Transformation + +1. Extend `AdstockTransformation` or `SaturationTransformation` from [pymc_marketing/mmm/components/base.py](pymc_marketing/mmm/components/base.py) +2. Implement: `function()`, `prefix` property, `default_priors` property +3. Add to [pymc_marketing/mmm/components/adstock.py](pymc_marketing/mmm/components/adstock.py) or [saturation.py](pymc_marketing/mmm/components/saturation.py) +4. Export in [pymc_marketing/mmm/__init__.py](pymc_marketing/mmm/__init__.py) + +### Adding a New CLV Model + +1. Extend `CLVModel` from [pymc_marketing/clv/models/basic.py](pymc_marketing/clv/models/basic.py) +2. Implement: `build_model()`, prediction methods (e.g., `expected_purchases()`) +3. Define required data columns in `__init__` +4. Add tests in [tests/clv/models/](tests/clv/models/) + +### Adding a New Additive Effect (MMM) + +1. Implement `MuEffect` protocol from [pymc_marketing/mmm/additive_effect.py](pymc_marketing/mmm/additive_effect.py) +2. Required methods: `create_data()`, `create_effect()`, `set_data()` +3. See FourierEffect, LinearTrendEffect as examples + +### Model Lifecycle + +All models follow this pattern: +1. **Configuration**: Store data and config in `__init__` +2. **Build**: `build_model()` creates PyMC model, attaches to `self.model` +3. **Fit**: `fit()` calls `pm.sample()` or alternative inference +4. **Store**: Results stored in `self.idata` (ArviZ InferenceData) +5. **Predict**: `sample_posterior_predictive()` with new data + +## Documentation and Examples + +**Notebooks:** +- MMM examples: [docs/source/notebooks/mmm/](docs/source/notebooks/mmm/) +- CLV examples: [docs/source/notebooks/clv/](docs/source/notebooks/clv/) +- Customer choice: [docs/source/notebooks/customer_choice/](docs/source/notebooks/customer_choice/) + +**Gallery Generation:** +- [scripts/generate_gallery.py](scripts/generate_gallery.py) creates notebook gallery for docs +- Run with `make html` + +**UML Diagrams:** +- Architecture diagrams in [docs/source/uml/](docs/source/uml/) +- Generate with `make uml` +- See [CONTRIBUTING.md](CONTRIBUTING.md) for package/class diagrams + +## Community and Support + +- [GitHub Issues](https://github.com/pymc-labs/pymc-marketing/issues) for bugs/features +- [PyMC Discourse](https://discourse.pymc.io/) for general discussion +- [PyMC-Marketing Discussions](https://github.com/pymc-labs/pymc-marketing/discussions) for Q&A diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 13f91e252..95a229604 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -46,7 +46,7 @@ mmm.sample_posterior_predictive(X) # Posterior predictive time series - _ = mmm.plot.posterior_predictive(var=["y"], hdi_prob=0.9) + _ = mmm.plot.posterior_predictive(var="y", hdi_prob=0.9) # Posterior contributions over time (e.g., channel_contribution) _ = mmm.plot.contributions_over_time(var=["channel_contribution"], hdi_prob=0.9) @@ -88,7 +88,7 @@ idata.extend(pm.sample_posterior_predictive(idata, random_seed=1)) plot = MMMPlotSuite(idata) - _ = plot.posterior_predictive(var=["y"], hdi_prob=0.9) + _ = plot.posterior_predictive(var="y", hdi_prob=0.9) Custom contributions_over_time -------- @@ -173,14 +173,10 @@ import arviz as az import arviz_plots as azp -import matplotlib.pyplot as plt import numpy as np import xarray as xr from arviz_base.labels import DimCoordLabeller, NoVarLabeller, mix_labellers from arviz_plots import PlotCollection -from matplotlib.axes import Axes -from matplotlib.figure import Figure -from numpy.typing import NDArray __all__ = ["MMMPlotSuite"] @@ -201,53 +197,6 @@ def __init__( ): self.idata = idata - def _init_subplots( - self, - n_subplots: int, - ncols: int = 1, - width_per_col: float = 10.0, - height_per_row: float = 4.0, - ) -> tuple[Figure, NDArray[Axes]]: - """Initialize a grid of subplots. - - Parameters - ---------- - n_subplots : int - Number of rows (if ncols=1) or total subplots. - ncols : int - Number of columns in the subplot grid. - width_per_col : float - Width (in inches) for each column of subplots. - height_per_row : float - Height (in inches) for each row of subplots. - - Returns - ------- - fig : matplotlib.figure.Figure - The created Figure object. - axes : np.ndarray of matplotlib.axes.Axes - 2D array of axes of shape (n_subplots, ncols). - """ - fig, axes = plt.subplots( - nrows=n_subplots, - ncols=ncols, - figsize=(width_per_col * ncols, height_per_row * n_subplots), - squeeze=False, - ) - return fig, axes - - def _build_subplot_title( - self, - dims: list[str], - combo: tuple, - fallback_title: str = "Time Series", - ) -> str: - """Build a subplot title string from dimension names and their values.""" - if dims: - title_parts = [f"{d}={v}" for d, v in zip(dims, combo, strict=False)] - return ", ".join(title_parts) - return fallback_title - def _get_additional_dim_combinations( self, data: xr.Dataset, @@ -270,23 +219,6 @@ def _get_additional_dim_combinations( return additional_dims, dim_combinations - def _reduce_and_stack( - self, data: xr.DataArray, dims_to_ignore: set[str] | None = None - ) -> xr.DataArray: - """Sum over leftover dims and stack chain+draw into sample if present.""" - if dims_to_ignore is None: - dims_to_ignore = {"date", "chain", "draw", "sample"} - - leftover_dims = [d for d in data.dims if d not in dims_to_ignore] - if leftover_dims: - data = data.sum(dim=leftover_dims) - - # Combine chain+draw into 'sample' if both exist - if "chain" in data.dims and "draw" in data.dims: - data = data.stack(sample=("chain", "draw")) - - return data - def _get_posterior_predictive_data( self, idata: xr.Dataset | None, @@ -307,25 +239,6 @@ def _get_posterior_predictive_data( ) return self.idata.posterior_predictive # type: ignore - def _add_median_and_hdi( - self, ax: Axes, data: xr.DataArray, var: str, hdi_prob: float = 0.85 - ) -> Axes: - """Add median and HDI to the given axis.""" - median = data.median(dim="sample") if "sample" in data.dims else data.median() - hdi = az.hdi( - data, - hdi_prob=hdi_prob, - input_core_dims=[["sample"]] if "sample" in data.dims else None, - ) - - if "date" not in data.dims: - raise ValueError(f"Expected 'date' dimension in {var}, but none found.") - dates = data.coords["date"].values - # Add median and HDI to the plot - ax.plot(dates, median, label=var, alpha=0.9) - ax.fill_between(dates, hdi[var][..., 0], hdi[var][..., 1], alpha=0.2) - return ax - def _validate_dims( self, dims: dict[str, str | int | list], @@ -384,7 +297,7 @@ def _resolve_backend(self, backend: str | None) -> str: def posterior_predictive( self, - var: list[str] | None = None, + var: str | None = None, idata: xr.Dataset | None = None, hdi_prob: float = 0.85, backend: str | None = None, @@ -394,8 +307,8 @@ def posterior_predictive( Parameters ---------- - var : list of str, optional - List of variable names to plot. If None, uses "y". + var : str, optional + Variable name to plot. If None, uses "y". idata : xr.Dataset, optional Dataset containing posterior predictive samples. If None, uses self.idata.posterior_predictive. @@ -420,10 +333,10 @@ def posterior_predictive( # 1. Retrieve or validate posterior_predictive data pp_data = self._get_posterior_predictive_data(idata) - # 2. Determine variables to plot + # 2. Determine variable to plot if var is None: - var = ["y"] - main_var = var[0] + var = "y" + main_var = var # 3. Identify additional dims & get all combos ignored_dims = {"chain", "draw", "date", "sample"} @@ -522,6 +435,16 @@ def contributions_over_time( main_var = var[0] ignored_dims = {"chain", "draw", "date"} da = self.idata.posterior[var] + + # Apply dims filtering if provided + if dims: + self._validate_dims(dims, list(da[main_var].dims)) + for dim_name, dim_value in dims.items(): + if isinstance(dim_value, (list, tuple, np.ndarray)): + da = da.sel({dim_name: dim_value}) + else: + da = da.sel({dim_name: dim_value}) + additional_dims, _ = self._get_additional_dim_combinations( data=da, variable=main_var, ignored_dims=ignored_dims ) @@ -634,10 +557,21 @@ def saturation_scatterplot( """ ) + # Apply dims filtering to channel_data and channel_contribution + channel_data = self.idata.constant_data.channel_data + channel_contrib = self.idata.posterior[channel_contribution] + + if dims: + for dim_name, dim_value in dims.items(): + if isinstance(dim_value, (list, tuple, np.ndarray)): + channel_data = channel_data.sel({dim_name: dim_value}) + channel_contrib = channel_contrib.sel({dim_name: dim_value}) + else: + channel_data = channel_data.sel({dim_name: dim_value}) + channel_contrib = channel_contrib.sel({dim_name: dim_value}) + pc = azp.PlotCollection.grid( - self.idata.posterior[channel_contribution] - .mean(dim=["chain", "draw"]) - .to_dataset(), + channel_contrib.mean(dim=["chain", "draw"]).to_dataset(), cols=additional_dims, rows=["channel"], aes={"color": ["channel"]}, @@ -645,7 +579,7 @@ def saturation_scatterplot( ) pc.map( azp.visuals.scatter_xy, - x=self.idata.constant_data.channel_data, + x=channel_data, ) pc.map(azp.visuals.labelled_x, text="Channel Data", ignore_aes={"color"}) pc.map( @@ -732,6 +666,12 @@ def saturation_curves( " )\n" """ ) + # Validate curve dimensions + if "x" not in curve.dims: + raise ValueError("curve must have an 'x' dimension") + if "channel" not in curve.dims: + raise ValueError("curve must have a 'channel' dimension") + if original_scale: curve_data = curve * self.idata.constant_data.target_scale curve_data["x"] = curve_data["x"] * self.idata.constant_data.channel_scale @@ -911,8 +851,12 @@ def budget_allocation_roas( grouped = new_grouped grouped_roa_dt = {} + prefix = "all, " for k, v in grouped.items(): - grouped_roa_dt[k[5:]] = v + if k.startswith(prefix): + grouped_roa_dt[k[len(prefix) :]] = v + else: + grouped_roa_dt[k] = v else: grouped_roa_dt = roa_dt diff --git a/thoughts/shared/plans/mmmplotsuite-backend-migration-tdd.md b/thoughts/shared/plans/mmmplotsuite-backend-migration-tdd.md deleted file mode 100644 index b07a2ad9d..000000000 --- a/thoughts/shared/plans/mmmplotsuite-backend-migration-tdd.md +++ /dev/null @@ -1,1795 +0,0 @@ -# MMMPlotSuite Backend Migration - TDD Implementation Plan - -## Overview - -This plan implements backend-agnostic plotting for the MMMPlotSuite class using ArviZ's PlotCollection API, enabling support for matplotlib, plotly, and bokeh backends while maintaining full backward compatibility. We follow Test-Driven Development: write comprehensive tests first, verify they fail properly, then implement features by making those tests pass. - -## Current State Analysis - -### Existing Implementation -- **Location**: [pymc_marketing/mmm/plot.py:187-1924](pymc_marketing/mmm/plot.py#L187) -- **Class**: `MMMPlotSuite` with 10 public plotting methods -- **Current approach**: All methods directly use matplotlib APIs and return `(Figure, NDArray[Axes])` -- **Dependencies**: matplotlib, arviz (for HDI computation only) - -### Current Testing Landscape -- **Test framework**: pytest with parametrized tests -- **Test file**: [tests/mmm/test_plot.py](tests/mmm/test_plot.py) - 1053 lines, comprehensive fixture-based testing -- **Mock data patterns**: xarray-based InferenceData fixtures with realistic structure -- **Test conventions**: - - Module-scoped fixtures for expensive setup - - Type assertions only (no visual output validation) - - `plt.close()` after each test - - Parametrized tests for multiple configurations - -### Key Discoveries -1. **No PlotCollection usage**: ArviZ PlotCollection is not used anywhere in production code -2. **Testing patterns exist**: Parametrized tests, deprecation warnings, backward compatibility tests all have examples -3. **Mock data is realistic**: Fixtures create proper InferenceData structure with posterior, constant_data groups -4. **Helper functions available**: `_init_subplots()`, `_add_median_and_hdi()` need backend abstraction - -## Desired End State - -After implementation, the MMMPlotSuite should: - -1. ✅ Support matplotlib, plotly, and bokeh backends via ArviZ PlotCollection -2. ✅ Maintain 100% backward compatibility (existing code works unchanged) -3. ✅ Support global backend configuration via `mmm_config["plot.backend"]` -4. ✅ Support per-function backend parameter that overrides global config -5. ✅ Return PlotCollection when `return_as_pc=True`, tuple when `False` (default) -6. ✅ Handle matplotlib-specific features (twinx) with clear fallback warnings -7. ✅ Deprecate `rc_params` in favor of `backend_config` with warnings -8. ✅ Pass comprehensive test suite across all three backends - -## What We're NOT Testing/Implementing - -- Performance comparisons between backends (explicitly out of scope) -- Component plot methods outside MMMPlotSuite (requirement #9) -- Saving plots to files (not in current test suite) -- Interactive features specific to plotly/bokeh (basic rendering only) -- New plotting methods (only migrating existing 10 methods) - -## TDD Approach - -### Test Design Philosophy -1. **Depth over breadth**: Thoroughly test first 2-3 methods before moving to others -2. **Verify visual output**: Use PlotCollection's backend-specific output validation, not just type checking -3. **Fail diagnostically**: Tests should fail with clear messages pointing to missing functionality -4. **Test data isolation**: Use module-scoped fixtures, mock InferenceData structures - -### Implementation Priority -**Phase 1**: Infrastructure + `posterior_predictive()` (simplest method) -**Phase 2**: `contributions_over_time()` (similar to Phase 1) -**Phase 3**: `saturation_curves()` (rc_params deprecation, external functions) -**Phase 4**: `budget_allocation()` (twinx fallback behavior) - ---- - -## Phase 1: Test Design & Implementation - -### Overview -Write comprehensive, informative tests that define the feature completely. These tests should fail in expected, diagnostic ways. We focus deeply on infrastructure and the simplest method (`posterior_predictive()`) first. - -### Test Categories - -#### 1. Infrastructure Tests (Global Configuration & Return Types) -**Test File**: `tests/mmm/test_plot_backends.py` (NEW) -**Purpose**: Validate backend configuration system and return type switching - -**Test Cases to Write:** - -##### Test: `test_mmm_config_exists` -**Purpose**: Verify the global configuration object is accessible -**Test Data**: None needed -**Expected Behavior**: Can import and access `mmm_config` from `pymc_marketing.mmm` - -```python -def test_mmm_config_exists(): - """ - Test that the global mmm_config object exists and is accessible. - - This test verifies: - - mmm_config can be imported from pymc_marketing.mmm - - It has a "plot.backend" key - - Default backend is "matplotlib" - """ - from pymc_marketing.mmm import mmm_config - - assert "plot.backend" in mmm_config, \ - "mmm_config should have 'plot.backend' key" - assert mmm_config["plot.backend"] == "matplotlib", \ - f"Default backend should be 'matplotlib', got {mmm_config['plot.backend']}" -``` - -**Expected Failure Mode**: -- Error type: `ImportError` or `AttributeError` -- Expected message: `cannot import name 'mmm_config' from 'pymc_marketing.mmm'` - -##### Test: `test_mmm_config_backend_setting` -**Purpose**: Verify global backend can be changed and persists -**Test Data**: None needed -**Expected Behavior**: Setting backend value works and can be read back - -```python -def test_mmm_config_backend_setting(): - """ - Test that mmm_config backend can be set and retrieved. - - This test verifies: - - Backend can be changed from default - - New value persists - - Can be reset to default - """ - from pymc_marketing.mmm import mmm_config - - # Store original - original = mmm_config["plot.backend"] - - try: - # Change backend - mmm_config["plot.backend"] = "plotly" - assert mmm_config["plot.backend"] == "plotly", \ - "Backend should change to 'plotly'" - - # Reset - mmm_config.reset() - assert mmm_config["plot.backend"] == "matplotlib", \ - "reset() should restore default 'matplotlib' backend" - finally: - # Cleanup - mmm_config["plot.backend"] = original -``` - -**Expected Failure Mode**: -- Error type: `AttributeError` on `mmm_config.reset()` -- Expected message: `'dict' object has no attribute 'reset'` (if mmm_config is plain dict) - -##### Test: `test_mmm_config_invalid_backend_warning` -**Purpose**: Verify setting invalid backend emits a warning or raises error -**Test Data**: Invalid backend name "invalid_backend" -**Expected Behavior**: Validation prevents or warns about invalid backend - -```python -def test_mmm_config_invalid_backend_warning(): - """ - Test that setting an invalid backend name is handled gracefully. - - This test verifies: - - Invalid backend names are detected - - Either raises ValueError or emits UserWarning - - Helpful error message provided - """ - from pymc_marketing.mmm import mmm_config - import warnings - - original = mmm_config["plot.backend"] - - try: - # Attempt to set invalid backend - should either raise or warn - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - mmm_config["plot.backend"] = "invalid_backend" - - # If no exception, should have warning - assert len(w) > 0, \ - "Should emit warning for invalid backend" - assert "invalid" in str(w[0].message).lower(), \ - f"Warning should mention 'invalid', got: {w[0].message}" - except ValueError as e: - # Acceptable alternative: raise ValueError - assert "backend" in str(e).lower(), \ - f"Error should mention 'backend', got: {e}" - finally: - mmm_config["plot.backend"] = original -``` - -**Expected Failure Mode**: -- Error type: `AssertionError` -- Expected message: "Should emit warning for invalid backend" (no validation present) - -#### 2. Backend Parameter Tests (posterior_predictive) -**Test File**: `tests/mmm/test_plot_backends.py` -**Purpose**: Validate `backend` parameter is accepted and overrides global config - -**Test Cases to Write:** - -##### Test: `test_posterior_predictive_accepts_backend_parameter` -**Purpose**: Verify method accepts new `backend` parameter without error -**Test Data**: `mock_suite` fixture with posterior_predictive data -**Expected Behavior**: Method accepts backend="matplotlib" without TypeError - -```python -def test_posterior_predictive_accepts_backend_parameter(mock_suite_with_pp): - """ - Test that posterior_predictive() accepts backend parameter. - - This test verifies: - - backend parameter is accepted - - No TypeError is raised - - Method completes successfully - """ - # Should not raise TypeError - result = mock_suite_with_pp.posterior_predictive(backend="matplotlib") - - assert result is not None, \ - "posterior_predictive should return a result" -``` - -**Expected Failure Mode**: -- Error type: `TypeError` -- Expected message: `posterior_predictive() got an unexpected keyword argument 'backend'` - -##### Test: `test_posterior_predictive_accepts_return_as_pc_parameter` -**Purpose**: Verify method accepts new `return_as_pc` parameter without error -**Test Data**: `mock_suite_with_pp` fixture -**Expected Behavior**: Method accepts return_as_pc=False without TypeError - -```python -def test_posterior_predictive_accepts_return_as_pc_parameter(mock_suite_with_pp): - """ - Test that posterior_predictive() accepts return_as_pc parameter. - - This test verifies: - - return_as_pc parameter is accepted - - No TypeError is raised - """ - # Should not raise TypeError - result = mock_suite_with_pp.posterior_predictive(return_as_pc=False) - - assert result is not None, \ - "posterior_predictive should return a result" -``` - -**Expected Failure Mode**: -- Error type: `TypeError` -- Expected message: `posterior_predictive() got an unexpected keyword argument 'return_as_pc'` - -##### Test: `test_posterior_predictive_backend_overrides_global` -**Purpose**: Verify function parameter overrides global config -**Test Data**: `mock_suite_with_pp` fixture -**Expected Behavior**: backend="plotly" overrides global matplotlib setting - -```python -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_posterior_predictive_backend_overrides_global(mock_suite_with_pp, backend): - """ - Test that backend parameter overrides global mmm_config setting. - - This test verifies: - - Global config set to one backend - - Function called with different backend - - Function uses parameter, not global config - """ - from pymc_marketing.mmm import mmm_config - - original = mmm_config["plot.backend"] - - try: - # Set global to matplotlib - mmm_config["plot.backend"] = "matplotlib" - - # Call with different backend, request PlotCollection to check - pc = mock_suite_with_pp.posterior_predictive( - backend=backend, - return_as_pc=True - ) - - assert hasattr(pc, 'backend'), \ - "PlotCollection should have backend attribute" - assert pc.backend == backend, \ - f"PlotCollection backend should be '{backend}', got '{pc.backend}'" - finally: - mmm_config["plot.backend"] = original -``` - -**Expected Failure Mode**: -- Error type: `AttributeError` or `AssertionError` -- Expected message: `'tuple' object has no attribute 'backend'` (returns tuple instead of PlotCollection) - -#### 3. Return Type Tests (Backward Compatibility) -**Test File**: `tests/mmm/test_plot_backends.py` -**Purpose**: Verify return types match expectations based on `return_as_pc` parameter - -**Test Cases to Write:** - -##### Test: `test_posterior_predictive_returns_tuple_by_default` -**Purpose**: Verify backward compatibility - default returns tuple -**Test Data**: `mock_suite_with_pp` fixture -**Expected Behavior**: Returns `(Figure, List[Axes])` tuple by default - -```python -def test_posterior_predictive_returns_tuple_by_default(mock_suite_with_pp): - """ - Test that posterior_predictive() returns tuple by default (backward compat). - - This test verifies: - - Default behavior (no return_as_pc parameter) returns tuple - - Tuple has two elements: (figure, axes) - - axes is a list of matplotlib Axes objects (1D list, not 2D array) - """ - result = mock_suite_with_pp.posterior_predictive() - - assert isinstance(result, tuple), \ - f"Default return should be tuple, got {type(result)}" - assert len(result) == 2, \ - f"Tuple should have 2 elements (fig, axes), got {len(result)}" - - fig, axes = result - - # For matplotlib backend (default), should be Figure and list - from matplotlib.figure import Figure - from matplotlib.axes import Axes - assert isinstance(fig, Figure), \ - f"First element should be Figure, got {type(fig)}" - assert isinstance(axes, list), \ - f"Second element should be list, got {type(axes)}" - assert all(isinstance(ax, Axes) for ax in axes), \ - "All list elements should be matplotlib Axes instances" -``` - -**Expected Failure Mode**: -- Error type: `AssertionError` or `AttributeError` -- Expected message: `Default return should be tuple, got ` (if returning PC) - -##### Test: `test_posterior_predictive_returns_plotcollection_when_requested` -**Purpose**: Verify new behavior - returns PlotCollection when return_as_pc=True -**Test Data**: `mock_suite_with_pp` fixture -**Expected Behavior**: Returns PlotCollection object when return_as_pc=True - -```python -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_posterior_predictive_returns_plotcollection_when_requested( - mock_suite_with_pp, backend -): - """ - Test that posterior_predictive() returns PlotCollection when return_as_pc=True. - - This test verifies: - - return_as_pc=True returns PlotCollection object - - PlotCollection has correct backend attribute - """ - from arviz_plots import PlotCollection - - result = mock_suite_with_pp.posterior_predictive( - backend=backend, - return_as_pc=True - ) - - assert isinstance(result, PlotCollection), \ - f"Should return PlotCollection, got {type(result)}" - assert hasattr(result, 'backend'), \ - "PlotCollection should have backend attribute" - assert result.backend == backend, \ - f"Backend should be '{backend}', got '{result.backend}'" -``` - -**Expected Failure Mode**: -- Error type: `AssertionError` -- Expected message: `Should return PlotCollection, got ` (still returns tuple) - -##### Test: `test_posterior_predictive_tuple_has_correct_axes_for_matplotlib` -**Purpose**: Verify matplotlib backend returns list of Axes in tuple -**Test Data**: `mock_suite_with_pp` fixture -**Expected Behavior**: Tuple's second element is list of Axes objects - -```python -def test_posterior_predictive_tuple_has_correct_axes_for_matplotlib(mock_suite_with_pp): - """ - Test that matplotlib backend returns proper axes list in tuple. - - This test verifies: - - When return_as_pc=False and backend="matplotlib" - - Second tuple element is list of matplotlib Axes - - All elements in list are Axes instances - """ - from matplotlib.axes import Axes - - fig, axes = mock_suite_with_pp.posterior_predictive( - backend="matplotlib", - return_as_pc=False - ) - - assert isinstance(axes, list), \ - f"Axes should be list for matplotlib, got {type(axes)}" - assert all(isinstance(ax, Axes) for ax in axes), \ - "All list elements should be matplotlib Axes instances" -``` - -**Expected Failure Mode**: -- Error type: `AssertionError` -- Expected message: `Axes should be list for matplotlib, got ` (if not extracting axes) - -##### Test: `test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib` -**Purpose**: Verify non-matplotlib backends return None for axes in tuple -**Test Data**: `mock_suite_with_pp` fixture -**Expected Behavior**: Tuple's second element is None for plotly/bokeh - -```python -@pytest.mark.parametrize("backend", ["plotly", "bokeh"]) -def test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib( - mock_suite_with_pp, backend -): - """ - Test that non-matplotlib backends return None for axes in tuple. - - This test verifies: - - When return_as_pc=False and backend in ["plotly", "bokeh"] - - Second tuple element is None (no axes concept) - - First element is backend-specific figure object - """ - fig, axes = mock_suite_with_pp.posterior_predictive( - backend=backend, - return_as_pc=False - ) - - assert axes is None, \ - f"Axes should be None for {backend} backend, got {type(axes)}" - assert fig is not None, \ - f"Figure should exist for {backend} backend" -``` - -**Expected Failure Mode**: -- Error type: `AssertionError` -- Expected message: `Axes should be None for plotly backend, got ` (always matplotlib) - -#### 4. Visual Output Validation Tests -**Test File**: `tests/mmm/test_plot_backends.py` -**Purpose**: Verify that plots actually render and contain expected elements - -**Test Cases to Write:** - -##### Test: `test_posterior_predictive_plotcollection_has_viz_attribute` -**Purpose**: Verify PlotCollection has visualization data we can inspect -**Test Data**: `mock_suite_with_pp` fixture -**Expected Behavior**: PlotCollection has `viz` attribute with figure data - -```python -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_posterior_predictive_plotcollection_has_viz_attribute( - mock_suite_with_pp, backend -): - """ - Test that PlotCollection has viz attribute with figure data. - - This test verifies: - - PlotCollection has viz attribute - - viz has figure attribute - - Figure can be extracted - """ - from arviz_plots import PlotCollection - - pc = mock_suite_with_pp.posterior_predictive( - backend=backend, - return_as_pc=True - ) - - assert hasattr(pc, 'viz'), \ - "PlotCollection should have 'viz' attribute" - assert hasattr(pc.viz, 'figure'), \ - "PlotCollection.viz should have 'figure' attribute" - - # Should be able to extract figure - fig = pc.viz.figure.data.item() - assert fig is not None, \ - "Should be able to extract figure from PlotCollection" -``` - -**Expected Failure Mode**: -- Error type: `AttributeError` -- Expected message: `PlotCollection should have 'viz' attribute` (if PC not properly constructed) - -##### Test: `test_posterior_predictive_matplotlib_has_lines` -**Purpose**: Verify matplotlib output contains actual plot elements -**Test Data**: `mock_suite_with_pp` fixture with known variables -**Expected Behavior**: Axes contain Line2D objects (the actual plotted data) - -```python -def test_posterior_predictive_matplotlib_has_lines(mock_suite_with_pp): - """ - Test that matplotlib output contains actual plotted lines. - - This test verifies: - - Axes contain Line2D objects (plotted data) - - Number of lines matches expected variables - - Visual output actually created, not just empty axes - """ - from matplotlib.lines import Line2D - - fig, axes = mock_suite_with_pp.posterior_predictive( - backend="matplotlib", - return_as_pc=False - ) - - # Get first axis (should have plots) - ax = axes.flat[0] - - # Should have lines (median plots) - lines = [child for child in ax.get_children() if isinstance(child, Line2D)] - assert len(lines) > 0, \ - f"Axes should contain Line2D objects (plots), found {len(lines)}" -``` - -**Expected Failure Mode**: -- Error type: `AssertionError` -- Expected message: `Axes should contain Line2D objects (plots), found 0` (empty plot) - -##### Test: `test_posterior_predictive_plotly_has_traces` -**Purpose**: Verify plotly output contains traces (plotly's plot elements) -**Test Data**: `mock_suite_with_pp` fixture -**Expected Behavior**: Plotly figure has traces in data attribute - -```python -def test_posterior_predictive_plotly_has_traces(mock_suite_with_pp): - """ - Test that plotly output contains actual traces. - - This test verifies: - - Plotly figure has 'data' attribute with traces - - Number of traces > 0 (something was plotted) - - Visual output actually created - """ - fig, _ = mock_suite_with_pp.posterior_predictive( - backend="plotly", - return_as_pc=False - ) - - # Plotly figures have .data attribute with traces - assert hasattr(fig, 'data'), \ - "Plotly figure should have 'data' attribute" - assert len(fig.data) > 0, \ - f"Plotly figure should have traces, found {len(fig.data)}" -``` - -**Expected Failure Mode**: -- Error type: `AttributeError` or `AssertionError` -- Expected message: `Plotly figure should have 'data' attribute` (matplotlib Figure returned instead) - -##### Test: `test_posterior_predictive_bokeh_has_renderers` -**Purpose**: Verify bokeh output contains renderers (bokeh's plot elements) -**Test Data**: `mock_suite_with_pp` fixture -**Expected Behavior**: Bokeh figure has renderers (glyphs) - -```python -def test_posterior_predictive_bokeh_has_renderers(mock_suite_with_pp): - """ - Test that bokeh output contains actual renderers (plot elements). - - This test verifies: - - Bokeh figure has renderers - - Number of renderers > 0 (something was plotted) - - Visual output actually created - """ - fig, _ = mock_suite_with_pp.posterior_predictive( - backend="bokeh", - return_as_pc=False - ) - - # Bokeh figures have .renderers attribute - assert hasattr(fig, 'renderers'), \ - "Bokeh figure should have 'renderers' attribute" - assert len(fig.renderers) > 0, \ - f"Bokeh figure should have renderers, found {len(fig.renderers)}" -``` - -**Expected Failure Mode**: -- Error type: `AttributeError` or `AssertionError` -- Expected message: `Bokeh figure should have 'renderers' attribute` (matplotlib Figure returned instead) - -#### 5. Fixture Setup -**Test File**: `tests/mmm/test_plot_backends.py` -**Purpose**: Create reusable fixtures for backend testing - -```python -""" -Backend-agnostic plotting tests for MMMPlotSuite. - -This test file validates the migration to ArviZ PlotCollection API for -multi-backend support (matplotlib, plotly, bokeh). - -NOTE: Once this migration is complete and stable, evaluate whether -tests/mmm/test_plot.py can be consolidated into this file to avoid duplication. -""" - -import numpy as np -import pytest -import xarray as xr -import arviz as az -import pandas as pd -from matplotlib.figure import Figure -from matplotlib.axes import Axes - -from pymc_marketing.mmm.plot import MMMPlotSuite - - -@pytest.fixture(scope="module") -def mock_idata_for_pp(): - """ - Create mock InferenceData with posterior_predictive for testing. - - Structure mirrors real MMM output with: - - posterior_predictive group with y variable - - proper dimensions: chain, draw, date - - realistic date range - """ - seed = sum(map(ord, "Backend test posterior_predictive")) - rng = np.random.default_rng(seed) - - dates = pd.date_range("2025-01-01", periods=52, freq="W-MON") - - # Create posterior_predictive data - posterior_predictive = xr.Dataset({ - "y": xr.DataArray( - rng.normal(loc=100, scale=10, size=(4, 100, 52)), - dims=("chain", "draw", "date"), - coords={ - "chain": np.arange(4), - "draw": np.arange(100), - "date": dates, - }, - ) - }) - - # Also create a minimal posterior (required for some internal logic) - posterior = xr.Dataset({ - "intercept": xr.DataArray( - rng.normal(size=(4, 100)), - dims=("chain", "draw"), - coords={ - "chain": np.arange(4), - "draw": np.arange(100), - }, - ) - }) - - return az.InferenceData( - posterior=posterior, - posterior_predictive=posterior_predictive - ) - - -@pytest.fixture(scope="module") -def mock_suite_with_pp(mock_idata_for_pp): - """ - Fixture providing MMMPlotSuite with posterior_predictive data. - - Used for testing posterior_predictive() method across backends. - """ - return MMMPlotSuite(idata=mock_idata_for_pp) - - -@pytest.fixture(scope="function") -def reset_mmm_config(): - """ - Fixture to reset mmm_config after each test. - - Ensures test isolation - one test's backend changes don't affect others. - """ - from pymc_marketing.mmm import mmm_config - - original = mmm_config["plot.backend"] - yield - mmm_config["plot.backend"] = original -``` - -### Implementation Steps - -1. **Create test file**: `tests/mmm/test_plot_backends.py` -2. **Add note to existing test file**: Edit `tests/mmm/test_plot.py` line 1 to add: - ```python - # NOTE: This file may be consolidated with test_plot_backends.py in the future - # once the backend migration is complete and stable. - ``` - -3. **Implement fixtures** (see Fixture Setup section above) - -4. **Implement all test cases** in the order listed: - - Infrastructure tests (global config) - - Backend parameter tests - - Return type tests - - Visual output validation tests - -5. **Run tests to verify failures**: `pytest tests/mmm/test_plot_backends.py -v` - -### Success Criteria - -#### Automated Verification: -- [x] Test file created: `tests/mmm/test_plot_backends.py` -- [x] All tests discovered: `pytest tests/mmm/test_plot_backends.py --collect-only` -- [x] Tests fail (not pass): `pytest tests/mmm/test_plot_backends.py --tb=short` -- [x] No import/syntax errors: `pytest tests/mmm/test_plot_backends.py --tb=line` -- [x] Linting passes: `make lint` -- [x] Test code follows conventions: Style matches `test_plot.py` patterns - -#### Manual Verification: -- [ ] Each test has clear docstring explaining what it validates -- [ ] Test names clearly describe what they test (e.g., `test_X_does_Y`) -- [ ] Assertion messages are diagnostic and helpful -- [ ] Fixtures are well-documented with realistic data -- [ ] Test file header includes note about consolidation - ---- - -## Phase 2: Test Failure Verification - -### Overview -Run the tests and verify they fail in the expected, diagnostic ways. This ensures our tests are actually testing something and will catch regressions. - -### Verification Steps - -1. **Run the test suite**: - ```bash - pytest tests/mmm/test_plot_backends.py -v - ``` - -2. **Verify all tests are discovered**: - ```bash - pytest tests/mmm/test_plot_backends.py --collect-only - ``` - Expected: All tests listed, no collection errors - -3. **Check failure modes**: - ```bash - pytest tests/mmm/test_plot_backends.py -v --tb=short - ``` - Review each failure to ensure it matches expected failure mode - -### Expected Failures - -**Infrastructure Tests:** -- `test_mmm_config_exists`: `ImportError: cannot import name 'mmm_config'` -- `test_mmm_config_backend_setting`: `ImportError: cannot import name 'mmm_config'` -- `test_mmm_config_invalid_backend_warning`: `ImportError: cannot import name 'mmm_config'` - -**Backend Parameter Tests:** -- `test_posterior_predictive_accepts_backend_parameter`: `TypeError: posterior_predictive() got an unexpected keyword argument 'backend'` -- `test_posterior_predictive_accepts_return_as_pc_parameter`: `TypeError: posterior_predictive() got an unexpected keyword argument 'return_as_pc'` -- `test_posterior_predictive_backend_overrides_global`: `ImportError: cannot import name 'mmm_config'` or `TypeError` (backend param) - -**Return Type Tests:** -- `test_posterior_predictive_returns_tuple_by_default`: Should PASS (existing behavior works) -- `test_posterior_predictive_returns_plotcollection_when_requested`: `TypeError: unexpected keyword argument 'return_as_pc'` -- `test_posterior_predictive_tuple_has_correct_axes_for_matplotlib`: Should PASS (existing behavior) -- `test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib`: `TypeError: unexpected keyword argument 'backend'` - -**Visual Output Tests:** -- `test_posterior_predictive_plotcollection_has_viz_attribute`: `TypeError: unexpected keyword argument 'return_as_pc'` -- `test_posterior_predictive_matplotlib_has_lines`: Should PASS (existing behavior works) -- `test_posterior_predictive_plotly_has_traces`: `TypeError: unexpected keyword argument 'backend'` -- `test_posterior_predictive_bokeh_has_renderers`: `TypeError: unexpected keyword argument 'backend'` - -### Success Criteria - -#### Automated Verification: -- [ ] All tests run (no collection errors): `pytest tests/mmm/test_plot_backends.py --collect-only` -- [ ] Expected number of failures: Count matches test cases written -- [ ] No unexpected errors: No `ImportError` on test fixtures, no syntax errors -- [ ] Existing tests still pass: `pytest tests/mmm/test_plot.py -k test_posterior_predictive` - -#### Manual Verification: -- [ ] Each test fails with expected error type (TypeError, ImportError, AssertionError as listed) -- [ ] Failure messages clearly indicate what's missing -- [ ] Failure messages would help during implementation (diagnostic) -- [ ] Stack traces point to relevant code locations (test assertions, not fixture setup) -- [ ] No cryptic or misleading error messages - -### Adjustment Phase - -If tests don't fail properly: - -**Problem**: Tests pass unexpectedly -- **Fix**: Review test assertions - they may be too lenient -- **Action**: Add stricter type checks, verify specific attributes - -**Problem**: Tests error instead of fail (e.g., ImportError on fixtures) -- **Fix**: Check fixture dependencies, ensure mock data doesn't rely on new code -- **Action**: Simplify fixtures to not use non-existent features - -**Problem**: Confusing error messages -- **Fix**: Improve assertion messages with context -- **Action**: Add `assert x, f"Expected Y, got {x}"` style messages - -**Problem**: Tests fail in wrong order (dependency issues) -- **Fix**: Ensure test isolation - no shared state between tests -- **Action**: Use `reset_mmm_config` fixture, don't modify shared fixtures - -**Checklist for Adjustment:** -- [ ] All infrastructure tests fail with ImportError or AttributeError -- [ ] All backend parameter tests fail with TypeError (unexpected keyword) -- [ ] Return type tests for new behavior fail with TypeError -- [ ] Return type tests for existing behavior PASS -- [ ] Visual output tests fail with TypeError (unexpected keyword) - ---- - -## Phase 3: Feature Implementation (Red → Green) - -### Overview -Implement the feature by making tests pass, one at a time. Work like debugging - let test failures guide what needs to be implemented next. - -### Implementation Strategy - -**Order of Implementation:** -1. Global config infrastructure (`mmm_config`) -2. Add `backend` and `return_as_pc` parameters to `posterior_predictive()` -3. Implement PlotCollection integration -4. Implement figure/axes extraction for tuple return -5. Verify visual output across backends - -### Implementation 1: Create Global Configuration - -**Target Tests**: -- `test_mmm_config_exists` -- `test_mmm_config_backend_setting` -- `test_mmm_config_invalid_backend_warning` - -**Current Failure**: `ImportError: cannot import name 'mmm_config' from 'pymc_marketing.mmm'` - -**Changes Required:** - -**File**: `pymc_marketing/mmm/config.py` (NEW) -**Purpose**: Global configuration management for MMM plotting - -```python -"""Configuration management for MMM plotting.""" - -VALID_BACKENDS = {"matplotlib", "plotly", "bokeh"} - - -class MMMConfig(dict): - """ - Configuration dictionary for MMM plotting settings. - - Provides backend configuration with validation and reset functionality. - Modeled after ArviZ's rcParams pattern. - - Examples - -------- - >>> from pymc_marketing.mmm import mmm_config - >>> mmm_config["plot.backend"] = "plotly" - >>> mmm_config["plot.backend"] - 'plotly' - >>> mmm_config.reset() - >>> mmm_config["plot.backend"] - 'matplotlib' - """ - - _defaults = { - "plot.backend": "matplotlib", - "plot.show_warnings": True, - } - - def __init__(self): - super().__init__(self._defaults) - - def __setitem__(self, key, value): - """Set config value with validation for backend.""" - if key == "plot.backend": - if value not in VALID_BACKENDS: - import warnings - warnings.warn( - f"Invalid backend '{value}'. Valid backends are: {VALID_BACKENDS}. " - f"Setting anyway, but plotting may fail.", - UserWarning, - stacklevel=2 - ) - super().__setitem__(key, value) - - def reset(self): - """Reset all configuration to default values.""" - self.clear() - self.update(self._defaults) - - -# Global config instance -mmm_config = MMMConfig() -``` - -**File**: `pymc_marketing/mmm/__init__.py` -**Changes**: Add mmm_config export - -```python -# Existing imports... - -from pymc_marketing.mmm.config import mmm_config - -__all__ = [ - # ... existing exports ... - "mmm_config", -] -``` - -**Debugging Approach:** -1. Create `config.py` with MMMConfig class -2. Run: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_exists -v` -3. If fails, check import path and __all__ export -4. Run: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_backend_setting -v` -5. If fails, check reset() implementation -6. Run: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_invalid_backend_warning -v` -7. If fails, verify warning is emitted in __setitem__ - -**Success Criteria:** - -##### Automated Verification: -- [x] Test passes: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_exists -v` -- [x] Test passes: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_backend_setting -v` -- [x] Test passes: `pytest tests/mmm/test_plot_backends.py::test_mmm_config_invalid_backend_warning -v` -- [x] Can import: `python -c "from pymc_marketing.mmm import mmm_config; print(mmm_config['plot.backend'])"` -- [x] Linting passes: `make lint` -- [x] Type checking passes: `mypy pymc_marketing/mmm/config.py` (no new errors) - -##### Manual Verification: -- [ ] Code is clean and well-documented -- [ ] Follows project conventions (NumPy docstrings) -- [ ] No performance issues (dict operations are O(1)) -- [ ] Warning messages are clear and actionable - -### Implementation 2: Add Parameters to posterior_predictive() - -**Target Tests**: -- `test_posterior_predictive_accepts_backend_parameter` -- `test_posterior_predictive_accepts_return_as_pc_parameter` - -**Current Failure**: `TypeError: posterior_predictive() got an unexpected keyword argument 'backend'` - -**Changes Required:** - -**File**: `pymc_marketing/mmm/plot.py` -**Method**: `posterior_predictive()` (line 375) -**Changes**: Add backend and return_as_pc parameters - -```python -def posterior_predictive( - self, - var: list[str] | None = None, - idata: xr.Dataset | None = None, - hdi_prob: float = 0.85, - backend: str | None = None, - return_as_pc: bool = False, -) -> tuple[Figure, list[Axes] | None] | "PlotCollection": - """ - Plot posterior predictive distributions over time. - - Parameters - ---------- - var : list of str, optional - List of variable names to plot. If None, uses "y". - idata : xr.Dataset, optional - Dataset containing posterior predictive samples. - If None, uses self.idata.posterior_predictive. - hdi_prob : float, default 0.85 - Probability mass for HDI interval. - backend : str, optional - Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". - If None, uses global config via mmm_config["plot.backend"]. - Default (via config) is "matplotlib". - return_as_pc : bool, default False - If True, returns PlotCollection object. - If False, returns tuple (figure, axes) for backward compatibility. - - Returns - ------- - PlotCollection or tuple - If return_as_pc=True, returns PlotCollection object. - If return_as_pc=False, returns (figure, axes) where: - - figure: backend-specific figure object (matplotlib.figure.Figure, - plotly.graph_objs.Figure, or bokeh.plotting.Figure) - - axes: list of matplotlib Axes if backend="matplotlib", else None - - Notes - ----- - When backend is not "matplotlib" and return_as_pc=False, the axes - element of the returned tuple will be None, as plotly and bokeh - do not have an equivalent axes list concept. - - Examples - -------- - >>> # Backward compatible usage (matplotlib) - >>> fig, axes = model.plot.posterior_predictive() - - >>> # Multi-backend with PlotCollection - >>> pc = model.plot.posterior_predictive(backend="plotly", return_as_pc=True) - >>> pc.show() - """ - from pymc_marketing.mmm.config import mmm_config - - # Resolve backend (parameter overrides global config) - backend = backend or mmm_config["plot.backend"] - - # Temporary: Keep existing matplotlib implementation - # This makes tests pass (accepts parameters) but doesn't use them yet - # We'll implement PlotCollection integration in next step - - # [Existing implementation continues unchanged for now...] - # Just pass through to existing code -``` - -**Debugging Approach:** -1. Add parameters to signature with defaults -2. Add backend resolution logic (import mmm_config, use parameter or config) -3. Run: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_accepts_backend_parameter -v` -4. Should PASS (accepts parameter even if not used yet) -5. Run: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_accepts_return_as_pc_parameter -v` -6. Should PASS (accepts parameter even if not used yet) -7. Update docstring with new parameters -8. Update type hints in return annotation - -**Success Criteria:** - -##### Automated Verification: -- [x] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_accepts_backend_parameter -v` -- [x] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_accepts_return_as_pc_parameter -v` -- [x] Existing tests still pass: `pytest tests/mmm/test_plot.py::test_posterior_predictive -v` -- [x] Linting passes: `make lint` -- [x] Type checking passes: `mypy pymc_marketing/mmm/plot.py` - -##### Manual Verification: -- [x] Docstring updated with new parameters (NumPy style) -- [x] Default values maintain backward compatibility -- [x] Parameter order is logical (existing params first, new params last) -- [x] Type hints are accurate (use string quotes for forward ref to PlotCollection) - -### Implementation 3: Integrate PlotCollection and Return Type Logic - -**Target Tests**: -- `test_posterior_predictive_returns_tuple_by_default` -- `test_posterior_predictive_returns_plotcollection_when_requested` -- `test_posterior_predictive_backend_overrides_global` -- `test_posterior_predictive_tuple_has_correct_axes_for_matplotlib` -- `test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib` - -**Current Failure**: Tests pass/fail mix - need to integrate PlotCollection - -**Changes Required:** - -This is the most complex implementation step. We need to: -1. Create PlotCollection-based plotting logic -2. Implement figure/axes extraction for tuple return -3. Handle backend-specific differences - -**File**: `pymc_marketing/mmm/plot.py` -**Method**: `posterior_predictive()` (line 375) -**Changes**: Complete PlotCollection integration - -```python -def posterior_predictive( - self, - var: list[str] | None = None, - idata: xr.Dataset | None = None, - hdi_prob: float = 0.85, - backend: str | None = None, - return_as_pc: bool = False, -) -> tuple[Figure, list[Axes] | None] | "PlotCollection": - """[Docstring from previous step]""" - from pymc_marketing.mmm.config import mmm_config - from arviz_plots import PlotCollection, visuals - - # Resolve backend (parameter overrides global config) - backend = backend or mmm_config["plot.backend"] - - # Get data - var = var or ["y"] - pp_data = self._get_posterior_predictive_data(idata=idata) - - # Get dimension combinations for subplots - ignored_dims = {"chain", "draw", "date", "sample"} - available_dims = [d for d in pp_data[var[0]].dims if d not in ignored_dims] - additional_dims = [d for d in available_dims if d not in var] - dim_combinations = self._get_additional_dim_combinations( - pp_data[var[0]], additional_dims - ) - - n_subplots = len(dim_combinations) - - # Create PlotCollection with grid layout - # We'll build a dataset for PlotCollection - plot_data = {} - for v in var: - data = pp_data[v] - # Stack chain and draw into sample dimension - if "chain" in data.dims and "draw" in data.dims: - data = data.stack(sample=("chain", "draw")) - plot_data[v] = data - - plot_dataset = xr.Dataset(plot_data) - - # Create figure with appropriate layout - # PlotCollection.grid creates a grid of subplots - pc = PlotCollection.grid( - plot_dataset, - backend=backend, - plots_per_row=1, # One column layout like original - figsize=(10, 4 * n_subplots), - ) - - # For each subplot, add line plot and HDI - for row_idx, combo in enumerate(dim_combinations): - indexers = dict(zip(additional_dims, combo, strict=False)) if additional_dims else {} - - # Select subplot - if n_subplots > 1: - # Multi-panel: select by row index - pc_subplot = pc.sel(row=row_idx) - else: - # Single panel: use full pc - pc_subplot = pc - - for v in var: - data = plot_data[v].sel(**indexers) if indexers else plot_data[v] - - # Compute median and HDI - median = data.median(dim="sample") - hdi = az.hdi(data, hdi_prob=hdi_prob, input_core_dims=[["sample"]]) - - # Add median line - pc_subplot.map( - visuals.line, - data=median.rename("median"), - color=f"C{var.index(v)}", - label=v, - ) - - # Add HDI band - pc_subplot.map( - visuals.fill_between, - data1=hdi[v].sel(hdi="lower"), - data2=hdi[v].sel(hdi="higher"), - color=f"C{var.index(v)}", - alpha=0.2, - ) - - # Add labels - title = self._build_subplot_title(additional_dims, combo, "Posterior Predictive") - pc_subplot.map(visuals.labelled, title=title, xlabel="Date", ylabel="Posterior Predictive") - pc_subplot.map(visuals.legend) - - # Return based on return_as_pc flag - if return_as_pc: - return pc - else: - # Extract figure from PlotCollection - fig = pc.viz.figure.data.item() - - # Extract axes (only for matplotlib) - if backend == "matplotlib": - axes = list(fig.get_axes()) # Return as simple list - else: - axes = None - - return fig, axes -``` - -**Note**: The above is pseudocode showing the structure. Actual implementation will need to: -- Check PlotCollection API documentation for exact method signatures -- Handle dimension combinations correctly -- Ensure HDI computation works with PlotCollection -- Test iteratively with debugger - -**Debugging Approach:** -1. Start with simplest case: single variable, no extra dimensions -2. Run: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_returns_plotcollection_when_requested[matplotlib] -v` -3. Debug PlotCollection creation - check what data format it expects -4. Debug median/HDI computation - verify dimensions match -5. Debug PlotCollection.map() calls - check visual function signatures -6. Once matplotlib works, test plotly: `pytest ... [plotly]` -7. Debug backend-specific issues (figure extraction, etc.) -8. Test tuple return: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_returns_tuple_by_default -v` -9. Debug figure/axes extraction logic - -**Alternative Simpler Approach** (if PlotCollection API is challenging): - -Keep matplotlib implementation, add a wrapper that converts to/from PlotCollection: - -```python -def posterior_predictive(self, ..., backend=None, return_as_pc=False): - """[docstring]""" - from pymc_marketing.mmm.config import mmm_config - backend = backend or mmm_config["plot.backend"] - - # For now, always use matplotlib internally - # This lets us make progress while learning PlotCollection API - fig_mpl, axes_mpl = self._posterior_predictive_matplotlib( - var=var, idata=idata, hdi_prob=hdi_prob - ) - - if backend != "matplotlib": - # Convert matplotlib to other backend via PlotCollection - # This is a valid incremental approach - import warnings - warnings.warn( - f"Backend '{backend}' requested but full support not yet implemented. " - f"Using matplotlib with conversion.", - UserWarning - ) - # Conversion logic here... - - if return_as_pc: - # Wrap matplotlib figure in PlotCollection - pc = PlotCollection.wrap(fig_mpl, backend=backend) - return pc - else: - if backend == "matplotlib": - return fig_mpl, axes_mpl - else: - # Convert figure to target backend - fig_converted = convert_figure(fig_mpl, backend) - return fig_converted, None -``` - -This incremental approach lets tests pass while we refine the implementation. - -**Success Criteria:** - -##### Automated Verification: -- [ ] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_returns_plotcollection_when_requested -v` -- [ ] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_returns_tuple_by_default -v` -- [ ] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_backend_overrides_global -v` -- [ ] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_tuple_has_correct_axes_for_matplotlib -v` -- [ ] Test passes: `pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib -v` -- [ ] All existing tests pass: `pytest tests/mmm/test_plot.py::test_posterior_predictive -v` -- [ ] Linting passes: `make lint` - -##### Manual Verification: -- [ ] PlotCollection objects are created correctly -- [ ] Figure extraction works for all backends -- [ ] Axes extraction works for matplotlib, returns None for others -- [ ] Visual output looks reasonable (manually inspect one plot per backend) -- [ ] No performance regressions (test with `time pytest ...`) - -### Implementation 4: Visual Output Validation - -**Target Tests**: -- `test_posterior_predictive_plotcollection_has_viz_attribute` -- `test_posterior_predictive_matplotlib_has_lines` -- `test_posterior_predictive_plotly_has_traces` -- `test_posterior_predictive_bokeh_has_renderers` - -**Current State**: May already pass if Implementation 3 is complete, or may need refinement - -**Debugging Approach:** -1. Run: `pytest tests/mmm/test_plot_backends.py -k "visual_output" -v` -2. If `test_plotcollection_has_viz_attribute` fails: - - Check PlotCollection structure - - Verify viz.figure.data.item() works -3. If `test_matplotlib_has_lines` fails: - - Check that median lines are actually plotted - - Verify Line2D objects exist in axes -4. If `test_plotly_has_traces` fails: - - Check plotly figure structure - - Verify conversion from matplotlib worked - - Check fig.data contains traces -5. If `test_bokeh_has_renderers` fails: - - Check bokeh figure structure - - Verify renderers exist - -**Possible Issues and Fixes:** -- **Empty plots**: Check that visuals.line() is actually called -- **Wrong backend**: Verify backend parameter is passed through correctly -- **Extraction fails**: Check PlotCollection API version, may need updates - -**Success Criteria:** - -##### Automated Verification: -- [ ] All visual output tests pass: `pytest tests/mmm/test_plot_backends.py -k "visual" -v` -- [ ] No warnings about empty plots -- [ ] All backends produce non-empty output - -##### Manual Verification: -- [ ] Matplotlib plots look correct (run test, inspect saved figure manually) -- [ ] Plotly plots render correctly (check fig.show() if interactive) -- [ ] Bokeh plots render correctly (check bokeh output) -- [ ] HDI bands are visible and correct - -### Complete Feature Implementation - -Once all tests pass: - -**Final Integration Check:** -```bash -# Run all backend tests -pytest tests/mmm/test_plot_backends.py -v - -# Run all existing tests to ensure no regressions -pytest tests/mmm/test_plot.py -v - -# Run full test suite -pytest tests/mmm/ -v - -# Check coverage for new code -pytest tests/mmm/test_plot_backends.py --cov=pymc_marketing.mmm.plot --cov-report=term-missing -``` - -**Success Criteria:** - -##### Automated Verification: -- [ ] All new tests pass: `pytest tests/mmm/test_plot_backends.py -v` -- [ ] No regressions: `pytest tests/mmm/test_plot.py::test_posterior_predictive -v` -- [ ] All MMM tests pass: `pytest tests/mmm/ -v` -- [ ] Code coverage: New code is >90% covered -- [ ] Linting passes: `make lint` -- [ ] Type checking passes: `make typecheck` - -##### Manual Verification: -- [ ] Can import and use mmm_config: `from pymc_marketing.mmm import mmm_config` -- [ ] Backward compatible: Old code works unchanged -- [ ] New API works: Can switch backends and get PlotCollection -- [ ] Visual output: Plots look correct in all three backends -- [ ] Documentation: Docstrings are complete and accurate - ---- - -## Phase 4: Refactoring & Cleanup - -### Overview -Now that tests are green, refactor to improve code quality while keeping tests passing. Tests protect us during refactoring. - -### Refactoring Targets - -#### 1. Code Duplication in Test File -**Problem**: Test cases may have repeated setup code -**Solution**: Extract common patterns to helper functions - -```python -# tests/mmm/test_plot_backends.py - -def assert_valid_plotcollection(pc, expected_backend): - """ - Helper to validate PlotCollection structure. - - Reduces duplication across tests. - """ - from arviz_plots import PlotCollection - - assert isinstance(pc, PlotCollection), \ - f"Should return PlotCollection, got {type(pc)}" - assert hasattr(pc, 'backend'), \ - "PlotCollection should have backend attribute" - assert pc.backend == expected_backend, \ - f"Backend should be '{expected_backend}', got '{pc.backend}'" - - -def assert_valid_backend_figure(fig, backend): - """ - Helper to validate backend-specific figure types. - """ - if backend == "matplotlib": - from matplotlib.figure import Figure - assert isinstance(fig, Figure) - elif backend == "plotly": - assert hasattr(fig, 'data'), "Plotly figure should have 'data'" - elif backend == "bokeh": - assert hasattr(fig, 'renderers'), "Bokeh figure should have 'renderers'" -``` - -#### 2. Backend Resolution Logic -**Problem**: Backend resolution logic may be duplicated in every method -**Solution**: Extract to a helper method in MMMPlotSuite - -```python -# pymc_marketing/mmm/plot.py - -class MMMPlotSuite: - """[existing docstring]""" - - def _resolve_backend(self, backend: str | None) -> str: - """ - Resolve backend parameter to actual backend string. - - Parameters - ---------- - backend : str or None - Backend parameter from method call. - - Returns - ------- - str - Resolved backend name (parameter overrides global config). - - Examples - -------- - >>> suite._resolve_backend(None) # uses global config - 'matplotlib' - >>> suite._resolve_backend("plotly") # uses parameter - 'plotly' - """ - from pymc_marketing.mmm.config import mmm_config - return backend or mmm_config["plot.backend"] -``` - -#### 3. Figure/Axes Extraction Logic -**Problem**: Tuple return logic may be complex and repeated -**Solution**: Extract to helper method - -```python -# pymc_marketing/mmm/plot.py - -class MMMPlotSuite: - """[existing docstring]""" - - def _extract_figure_and_axes( - self, - pc: "PlotCollection", - backend: str - ) -> tuple: - """ - Extract figure and axes from PlotCollection for tuple return. - - Parameters - ---------- - pc : PlotCollection - PlotCollection object to extract from. - backend : str - Backend name ("matplotlib", "plotly", or "bokeh"). - - Returns - ------- - tuple - (figure, axes) where figure is backend-specific Figure object - and axes is list of Axes for matplotlib, None for other backends. - - Notes - ----- - This method enables backward compatibility by extracting matplotlib-style - return values from PlotCollection objects. - """ - # Extract figure - fig = pc.viz.figure.data.item() - - # Extract axes (only for matplotlib) - if backend == "matplotlib": - axes = list(fig.get_axes()) - else: - axes = None - - return fig, axes -``` - -#### 4. Simplify PlotCollection Creation -**Problem**: PlotCollection creation logic may be verbose -**Solution**: Extract data preparation to helper method - -```python -# pymc_marketing/mmm/plot.py - -class MMMPlotSuite: - """[existing docstring]""" - - def _prepare_data_for_plotcollection( - self, - data: xr.DataArray, - stack_dims: tuple[str, ...] = ("chain", "draw") - ) -> xr.DataArray: - """ - Prepare xarray data for PlotCollection plotting. - - Parameters - ---------- - data : xr.DataArray - Input data with MCMC dimensions. - stack_dims : tuple of str, default ("chain", "draw") - Dimensions to stack into 'sample' dimension. - - Returns - ------- - xr.DataArray - Data with chain and draw stacked into sample dimension. - """ - if all(d in data.dims for d in stack_dims): - data = data.stack(sample=stack_dims) - return data -``` - -#### 5. Test Code Quality -**Problem**: Some tests may have complex setup -**Solution**: Use parametrize more effectively - -```python -# Example refactoring of tests - -# Before: Multiple similar test functions -def test_backend_matplotlib_works(): - pc = suite.posterior_predictive(backend="matplotlib", return_as_pc=True) - assert pc.backend == "matplotlib" - -def test_backend_plotly_works(): - pc = suite.posterior_predictive(backend="plotly", return_as_pc=True) - assert pc.backend == "plotly" - -def test_backend_bokeh_works(): - pc = suite.posterior_predictive(backend="bokeh", return_as_pc=True) - assert pc.backend == "bokeh" - -# After: Single parametrized test -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_backend_parameter_works(suite, backend): - """Test that all backends work correctly.""" - pc = suite.posterior_predictive(backend=backend, return_as_pc=True) - assert pc.backend == backend -``` - -### Refactoring Steps - -1. **Ensure all tests pass before starting**: - ```bash - pytest tests/mmm/test_plot_backends.py -v - ``` - -2. **For each refactoring**: - - Make the change (extract helper, rename variable, etc.) - - Run tests immediately: `pytest tests/mmm/test_plot_backends.py -v` - - If tests pass, commit the change (or move to next refactoring) - - If tests fail, revert and reconsider - -3. **Focus areas**: - - Extract helper methods (backend resolution, figure extraction) - - Improve naming (clear variable names, descriptive method names) - - Add code comments where logic is complex - - Simplify conditional logic - - Remove any dead code or unused imports - -4. **Test code refactoring**: - - Extract test helpers (assertion helpers) - - Use parametrize more effectively - - Improve test names for clarity - - Add docstrings to complex test fixtures - -### Success Criteria - -#### Automated Verification: -- [ ] All tests still pass: `pytest tests/mmm/test_plot_backends.py -v` -- [ ] No regressions: `pytest tests/mmm/test_plot.py -v` -- [ ] Code coverage maintained: `pytest --cov=pymc_marketing.mmm.plot --cov-report=term-missing` -- [ ] Linting passes: `make lint` -- [ ] Type checking passes: `mypy pymc_marketing/mmm/plot.py` -- [ ] No performance regressions: Compare test run time before/after - -#### Manual Verification: -- [ ] Code is more readable after refactoring -- [ ] No unnecessary complexity added -- [ ] Function/variable names are clear and descriptive -- [ ] Comments explain "why" not "what" -- [ ] Helper methods have clear single responsibilities -- [ ] Test code is DRY (Don't Repeat Yourself) -- [ ] Code follows project idioms (check CLAUDE.md patterns) - ---- - -## Phase 5: Expand to contributions_over_time() - -### Overview -Apply the same TDD process to the second method, `contributions_over_time()`. This method is similar to `posterior_predictive()`, so the pattern is established. - -### Test Design for contributions_over_time() - -**New Test Cases** (add to `tests/mmm/test_plot_backends.py`): - -```python -@pytest.fixture(scope="module") -def mock_suite_with_contributions(mock_idata): - """ - Fixture providing MMMPlotSuite with contribution data. - - Reuses mock_idata which already has intercept and linear_trend. - """ - return MMMPlotSuite(idata=mock_idata) - - -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_contributions_over_time_backend_parameter(mock_suite_with_contributions, backend): - """Test contributions_over_time accepts backend parameter and uses it.""" - pc = mock_suite_with_contributions.contributions_over_time( - var=["intercept"], - backend=backend, - return_as_pc=True - ) - assert_valid_plotcollection(pc, backend) - - -def test_contributions_over_time_returns_tuple_by_default(mock_suite_with_contributions): - """Test backward compatibility - returns tuple by default.""" - result = mock_suite_with_contributions.contributions_over_time( - var=["intercept"] - ) - assert isinstance(result, tuple) - assert len(result) == 2 - - -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_contributions_over_time_plotcollection(mock_suite_with_contributions, backend): - """Test return_as_pc=True returns PlotCollection.""" - pc = mock_suite_with_contributions.contributions_over_time( - var=["intercept"], - backend=backend, - return_as_pc=True - ) - assert_valid_plotcollection(pc, backend) - - -# Add visual output tests similar to posterior_predictive -``` - -### Implementation Steps - -1. **Write tests first**: Add all test cases to `test_plot_backends.py` -2. **Run tests to verify failures**: `pytest tests/mmm/test_plot_backends.py -k contributions_over_time -v` -3. **Add parameters to method signature**: Same as posterior_predictive -4. **Implement PlotCollection integration**: Reuse patterns from posterior_predictive -5. **Extract figure/axes for tuple return**: Use helper methods from refactoring -6. **Verify all tests pass**: `pytest tests/mmm/test_plot_backends.py -k contributions_over_time -v` - -### Success Criteria - -- [ ] All contributions_over_time tests pass -- [ ] Existing contributions_over_time tests still pass -- [ ] Code follows same pattern as posterior_predictive -- [ ] Refactored helpers are reused (no duplication) - ---- - -## Testing Strategy Summary - -### Test Coverage Goals -- [x] Normal operation paths: All public methods work with default parameters -- [x] Backend switching: All three backends (matplotlib, plotly, bokeh) work -- [x] Return type switching: Both tuple and PlotCollection returns work -- [x] Backward compatibility: Existing code works unchanged -- [x] Configuration: Global and per-function backend configuration -- [x] Edge cases: Invalid backends warn, missing data errors clearly -- [x] Visual output: Plots contain expected elements (lines, traces, renderers) - -### Test Organization -- **Test files**: - - `tests/mmm/test_plot_backends.py` (NEW) - Backend migration tests - - `tests/mmm/test_plot.py` (EXISTING) - Original tests, marked for future consolidation -- **Fixtures**: - - Module-scoped for expensive InferenceData creation - - Function-scoped for config cleanup - - Located at top of test file -- **Test utilities**: - - Helper assertions (assert_valid_plotcollection, assert_valid_backend_figure) - - Located in test file (not separate module yet) -- **Test data**: - - xarray-based InferenceData fixtures - - Realistic structure matching MMM output - -### Running Tests - -```bash -# Run all backend tests -pytest tests/mmm/test_plot_backends.py -v - -# Run specific test -pytest tests/mmm/test_plot_backends.py::test_posterior_predictive_returns_plotcollection_when_requested -v - -# Run with coverage -pytest tests/mmm/test_plot_backends.py --cov=pymc_marketing.mmm.plot --cov-report=term-missing - -# Run with failure details -pytest tests/mmm/test_plot_backends.py -vv --tb=short - -# Run only matplotlib backend tests (faster) -pytest tests/mmm/test_plot_backends.py -k "matplotlib" -v - -# Run all backends in parallel (if pytest-xdist installed) -pytest tests/mmm/test_plot_backends.py -n auto -``` - -## Performance Considerations - -Performance testing is explicitly out of scope (requirement #2), but we should avoid obvious regressions: - -- **Keep existing matplotlib path fast**: Don't add unnecessary overhead for default usage -- **Lazy imports**: Import PlotCollection only when needed -- **Reuse computations**: Don't recompute HDI if already computed -- **Fixture scope**: Use module-scoped fixtures to avoid repeated setup - -## Migration Notes - -### For Users - -**Backward Compatibility**: -- All existing code continues to work without changes -- Default behavior unchanged (matplotlib, tuple return) -- No breaking changes to public API - -**New Features**: -```python -# Global backend configuration -from pymc_marketing.mmm import mmm_config -mmm_config["plot.backend"] = "plotly" - -# All plots now use plotly -model.plot.posterior_predictive() - -# Override for specific plot -model.plot.contributions_over_time(backend="matplotlib") - -# Get PlotCollection for advanced customization -pc = model.plot.saturation_curves(curve=curve_data, return_as_pc=True) -pc.map(custom_visual_function) -pc.show() -``` - -### For Developers - -**Adding New Plotting Methods**: -1. Add `backend` and `return_as_pc` parameters -2. Use `self._resolve_backend(backend)` to get backend -3. Create PlotCollection with appropriate backend -4. Use `self._extract_figure_and_axes(pc, backend)` for tuple return -5. Write tests in `test_plot_backends.py` before implementing - -**Testing Checklist**: -- [ ] Test accepts backend parameter -- [ ] Test accepts return_as_pc parameter -- [ ] Test returns tuple by default (backward compat) -- [ ] Test returns PlotCollection when requested -- [ ] Test all three backends (parametrize) -- [ ] Test visual output (has lines/traces/renderers) - -## Dependencies - -### New Dependencies -- **arviz-plots**: Required for PlotCollection API - - Add to `pyproject.toml`: `arviz-plots>=0.7.0` - - Add to `environment.yml`: `- arviz-plots>=0.7.0` - -### Existing Dependencies (no changes) -- matplotlib: Already required -- arviz: Already required -- xarray: Already required -- numpy: Already required - -## References - -- Original research: [thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md](thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md) -- MMMPlotSuite implementation: [pymc_marketing/mmm/plot.py:187-1924](pymc_marketing/mmm/plot.py#L187-1924) -- Existing tests: [tests/mmm/test_plot.py](tests/mmm/test_plot.py) -- ArviZ PlotCollection docs: https://arviz-plots.readthedocs.io/ -- Test patterns reference: This plan's Phase 1 test examples - -## Open Questions - -1. **PlotCollection API Learning Curve**: ✅ ADDRESSED - Use incremental approach, start with matplotlib wrapper if needed -2. **Visual Output Validation**: ✅ ADDRESSED - Test for presence of elements (lines/traces/renderers), not pixel-perfect matching -3. **Performance Impact**: ✅ OUT OF SCOPE - User confirmed not a concern for this migration -4. **Deprecation Timeline**: When should we deprecate tuple return in favor of PlotCollection? - - Recommendation: Keep both indefinitely, default to tuple for backward compat -5. **Test File Consolidation**: When to merge `test_plot.py` into `test_plot_backends.py`? - - Recommendation: After all methods migrated and stable (next version) - -## Next Steps After Phase 5 - -Once `posterior_predictive()` and `contributions_over_time()` are fully implemented and tested: - -1. **Expand to saturation methods**: - - `saturation_scatterplot()` - Similar pattern, adds scatter plots - - `saturation_curves()` - Adds `rc_params` deprecation, `backend_config` parameter - -2. **Implement twinx fallback**: - - `budget_allocation()` - Special case with fallback warning - -3. **Expand to sensitivity methods**: - - `sensitivity_analysis()`, `uplift_curve()`, `marginal_curve()` - Wrappers - -4. **Full test suite validation**: - - Run all MMM tests: `pytest tests/mmm/ -v` - - Check coverage: `pytest tests/mmm/ --cov=pymc_marketing.mmm --cov-report=html` - - Performance baseline: `pytest tests/mmm/ --durations=20` - -5. **Documentation**: - - Update user guide with backend examples - - Add migration guide for users - - Update docstrings with examples - - Create notebook showing multi-backend usage - -## Summary of Key Decisions - -1. ✅ **Backward Compatibility**: Maintained via `return_as_pc=False` default -2. ✅ **Global Configuration**: ArviZ-style `mmm_config` dictionary -3. ✅ **Test Organization**: New file `test_plot_backends.py`, mark old file for future consolidation -4. ✅ **Mock Data**: Use existing patterns from `test_plot.py`, realistic xarray structures -5. ✅ **Test Depth**: Prioritize depth (thorough testing of first 2 methods) over breadth -6. ✅ **Visual Validation**: Test for presence of plot elements, not pixel-perfect matching -7. ✅ **Default Backend**: Keep "matplotlib" as default for full backward compatibility -8. ✅ **Helper Extraction**: Refactor common patterns to methods during cleanup phase -9. ✅ **Incremental Implementation**: OK to start with matplotlib-only, add backends incrementally diff --git a/thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md b/thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md deleted file mode 100644 index bb4e06ede..000000000 --- a/thoughts/shared/research/2025-11-12-mmmplotsuite-backend-migration-comprehensive.md +++ /dev/null @@ -1,554 +0,0 @@ ---- -date: 2025-11-11T21:13:39-05:00 -researcher: Claude -git_commit: e78e3afb259a33f0d2b09d0d6c7e409fe4ddc90d -branch: main -repository: pymc-marketing -topic: "MMMPlotSuite Migration to ArviZ PlotCollection with Backward Compatibility - Comprehensive Research" -tags: [research, codebase, plotting, visualization, arviz, matplotlib, plotly, bokeh, backend-agnostic, mmm, backward-compatibility] -status: complete -last_updated: 2025-11-11 -last_updated_by: Claude ---- - -# Research: MMMPlotSuite Migration to ArviZ PlotCollection with Backward Compatibility - -**Date**: 2025-11-11T21:13:39-05:00 -**Researcher**: Claude -**Git Commit**: e78e3afb259a33f0d2b09d0d6c7e409fe4ddc90d -**Branch**: main -**Repository**: pymc-marketing - -## Research Question - -The user wants to rewrite the MMMPlotSuite class in [plot.py](pymc_marketing/mmm/plot.py) to support additional backends beyond matplotlib. Specifically, they want to rewrite the functions in that class to use ArviZ's PlotCollection API instead of matplotlib directly, making the methods return PlotCollection objects instead of matplotlib Figure and Axes objects. - -### Updated Requirements (Corrected from previous research) - -1. **Global Backend Configuration**: Support the ability to set the backend once, in a "global manner", and then all plots will use that backend. This is on top of the option of setting the backend to individual functions using a backend argument, which will override the global setting. - -2. **Backward Compatibility**: The changes to MMMPlotSuite plotting functions should be backward compatible. The output to the function will be based on the backend argument, and there would also be an argument that would control whether to return a PlotCollection object instead. - -3. **Backend-Specific Code**: Identify all the matplotlib-specific functions that do not have a direct equivalent in other backends and come up with specific code to handle them for all backends. - -4. **RC Params Handling**: For the method `saturation_curves`, it uses `plt.rc_context(rc_params)` so we will need to change it. Use `backend="matplotlib", backend_config=None` as arguments. We are going to keep that `rc_params` parameter for backward compatibility, but emit a warning when using it. - -5. **Twin Axes Fallback**: A function that uses matplotlib `twinx` cannot be currently written using arviz-plots. So if a different backend is chosen it needs to emit a warning and fallback to matplotlib. - -6. **ArviZ-style rcParams**: Use the recommended "ArviZ-style rcParams with fallback" for Global Backend Configuration Implementation. - -7. **Performance**: Performance is not a concern for this migration. - -8. **Testing**: Testing of all functions should be across matplotlib, plotly and bokeh backends. - -9. **Component Plot Methods**: Do not migrate component plot methods outside MMMPlotSuite. If MMMPlotSuite uses a plotting function that is defined in a different file we would need to create a new function instead. - -## Summary - -The current MMMPlotSuite implementation is tightly coupled to matplotlib, with all 10 public plotting methods returning `tuple[Figure, NDArray[Axes]]` or similar matplotlib objects. The class uses matplotlib-specific APIs throughout (`plt.subplots`, `ax.plot`, `ax.fill_between`, `ax.twinx`, etc.). - -ArviZ's PlotCollection API (from arviz-plots) provides a backend-agnostic alternative that supports matplotlib, bokeh, plotly, and none backends. The codebase already uses ArviZ extensively for HDI computation (`az.hdi()`) and some plotting (`az.plot_hdi()`), but does not use PlotCollection anywhere in production code. - -The migration must maintain backward compatibility, support global backend configuration with per-function overrides, handle matplotlib-specific features gracefully (particularly `ax.twinx()` which requires backend-specific implementations or fallback), and be tested across matplotlib, plotly, and bokeh backends. - -## Detailed Findings - -### Current MMMPlotSuite Architecture - -#### Class Overview - -**Location**: [pymc_marketing/mmm/plot.py:187-1924](pymc_marketing/mmm/plot.py#L187) - -The MMMPlotSuite class is a standalone visualization class for MMM models: -- Initialized with `xr.Dataset` or `az.InferenceData` -- 10 public plotting methods (including 1 deprecated) -- Multiple helper methods for subplot creation and data manipulation -- All methods return matplotlib objects - -#### Method Signatures and Return Types - -| Method | Current Return Type | Lines | Usage | -|--------|-------------------|-------|-------| -| `posterior_predictive()` | `tuple[Figure, NDArray[Axes]]` | 375-463 | Plot posterior predictive time series | -| `contributions_over_time()` | `tuple[Figure, NDArray[Axes]]` | 465-588 | Plot contribution time series with HDI | -| `saturation_scatterplot()` | `tuple[Figure, NDArray[Axes]]` | 590-742 | Scatter plots of channel saturation | -| `saturation_curves()` | `tuple[plt.Figure, np.ndarray]` | 744-996 | Overlay scatter data with posterior curves | -| `saturation_curves_scatter()` | `tuple[Figure, NDArray[Axes]]` | 998-1035 | **Deprecated** - use `saturation_scatterplot()` | -| `budget_allocation()` | `tuple[Figure, plt.Axes] \| tuple[Figure, np.ndarray]` | 1037-1212 | Bar chart with dual y-axes | -| `allocated_contribution_by_channel_over_time()` | `tuple[Figure, plt.Axes \| NDArray[Axes]]` | 1279-1481 | Line plots with uncertainty bands | -| `sensitivity_analysis()` | `tuple[Figure, NDArray[Axes]] \| plt.Axes` | 1483-1718 | Plot sensitivity sweep results | -| `uplift_curve()` | `tuple[Figure, NDArray[Axes]] \| plt.Axes` | 1720-1820 | Wrapper around sensitivity_analysis for uplift | -| `marginal_curve()` | `tuple[Figure, NDArray[Axes]] \| plt.Axes` | 1822-1923 | Wrapper around sensitivity_analysis for marginal effects | - -#### Matplotlib-Specific APIs Used - -**Core matplotlib functions used across all methods:** -- `plt.subplots()` - Creating figure and axes grid -- `ax.plot()` - Line plots for medians -- `ax.fill_between()` - HDI/uncertainty bands -- `ax.scatter()` - Scatter plots for data points -- `ax.bar()` - Bar charts (budget allocation) -- `ax.twinx()` - Dual y-axes (budget allocation) - **CRITICAL FEATURE** -- `ax.set_title()`, `ax.set_xlabel()`, `ax.set_ylabel()` - Labeling -- `ax.legend()` - Legends -- `ax.set_visible()` - Hide unused axes -- `fig.tight_layout()` - Layout adjustment -- `fig.suptitle()` - Figure titles -- `plt.rc_context()` - Temporary matplotlib settings - -### Matplotlib-Specific Features Analysis - -#### Critical Feature: ax.twinx() - Dual Y-Axes - -**Location**: [pymc_marketing/mmm/plot.py:1249](pymc_marketing/mmm/plot.py#L1249) - -**Method**: `_plot_budget_allocation_bars()` → `budget_allocation()` - -**What it does**: Creates a secondary y-axis with independent scale, used to compare allocated spend vs. channel contribution on the same plot with different y-scales. - -**Implementation details**: -```python -# Line 1239-1246: Primary bars on primary axis -bars1 = ax.bar(index, allocated_spend, bar_width, color="C0", alpha=opacity, label="Allocated Spend") - -# Line 1249: Create twin axis -ax2 = ax.twinx() - -# Line 1252-1259: Secondary bars on secondary axis -bars2 = ax2.bar([i + bar_width for i in index], channel_contribution, bar_width, - color="C1", alpha=opacity, label="Channel Contribution") -``` - -**Backends without native PlotCollection support**: -- **Bokeh**: No direct twin axes support in PlotCollection -- **Plotly**: Has secondary y-axes but requires different approach - -**User Requirement**: "A function that uses matplotlib `twinx` cannot be currently written using arviz. So if a different backend is chosen it needs to emit a warning and fallback to matplotlib." - -**Recommended Strategy**: Detect when non-matplotlib backend is requested, emit warning, and force fallback to matplotlib backend. - -#### Medium Impact Feature: plt.rc_context() - -**Location**: [pymc_marketing/mmm/plot.py:878-880](pymc_marketing/mmm/plot.py#L878-880) - -**Method**: `saturation_curves()` - -**Code**: -```python -rc_params = rc_params or {} -with plt.rc_context(rc_params): - fig, axes = plt.subplots(nrows=nrows, ncols=ncols, **subkw) -``` - -**User Requirement**: "For the method `saturation_curves`, it uses `plt.rc_context(rc_params)` so we will need to change it. I want to use `backend="matplotlib", backend_config=None` as arguments. We are going to keep that `rc_params` parameter for backward compatibility, but emit a warning when using it." - -**Recommended Strategy**: -- Add `backend_config` parameter -- Keep `rc_params` parameter with deprecation warning -- Only apply config when backend is matplotlib -- Warn if `backend_config` provided for non-matplotlib backends - -### ArviZ Usage in the Codebase - -#### Current ArviZ Integration - -**ArviZ functions currently used:** -1. `az.hdi()` - HDI computation (used extensively) -2. `az.plot_hdi()` - HDI plotting (used in 2 methods) -3. `az.summary()` - Summary statistics -4. `az.InferenceData` - Primary data container (88+ references) - -**Key Finding: No PlotCollection usage** -- Zero instances of `PlotCollection` in production code -- No imports from `arviz_plots` in production code -- All plotting is matplotlib-specific - -#### External Plotting Functions Used by MMMPlotSuite - -**From [pymc_marketing/plot.py](pymc_marketing/plot.py):** - -Imported at line 799 in `saturation_curves()` method: -```python -from pymc_marketing.plot import plot_hdi, plot_samples -``` - -These functions use matplotlib directly and would need PlotCollection versions created per requirement #9. - -### Recommendations - -#### 1. Global Backend Configuration Implementation - -**Pattern**: ArviZ-style rcParams (per user requirement #6) - -**Implementation**: - -```python -# pymc_marketing/mmm/config.py (new file) -class MMMConfig(dict): - """Configuration dictionary for MMM plotting.""" - - _defaults = { - "plot.backend": "matplotlib", - "plot.show_warnings": True, - } - - def __init__(self): - super().__init__(self._defaults) - - def reset(self): - """Reset to defaults.""" - self.clear() - self.update(self._defaults) - -# Global config instance -mmm_config = MMMConfig() -``` - -**User API**: -```python -import pymc_marketing as pmm - -# Set global backend -pmm.mmm.mmm_config["plot.backend"] = "plotly" - -# All subsequent plots use plotly -model.plot.posterior_predictive() - -# Override for specific plot -model.plot.saturation_curves(backend="matplotlib") - -# Reset to defaults -pmm.mmm.mmm_config.reset() -``` - -#### 2. Backward Compatibility Strategy - -**Method Signature Pattern**: -```python -def posterior_predictive( - self, - var: list[str] | None = None, - idata: xr.Dataset | None = None, - hdi_prob: float = 0.85, - backend: str | None = None, - return_as_pc: bool = False, -) -> tuple[Figure, NDArray[Axes]] | PlotCollection: - """ - Parameters - ---------- - backend : str, optional - Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". - If None, uses global config (default: "matplotlib"). - return_as_pc : bool, default False - If True, returns PlotCollection object. If False, returns a tuple of - (figure, axes) where figure is the backend-specific figure object and - axes is an array of axes for matplotlib or None for other backends. - - Returns - ------- - PlotCollection or tuple - If return_as_pc=True, returns PlotCollection object. - If return_as_pc=False, returns (figure, axes) where: - - figure: backend-specific figure object (plt.Figure, plotly.graph_objs.Figure, etc.) - - axes: np.ndarray of matplotlib Axes if backend="matplotlib", else None - """ - # Resolve backend - backend = backend or mmm_config["plot.backend"] - - # Create PlotCollection - pc = PlotCollection.grid(data, backend=backend, ...) - pc.map(plotting_function, ...) - - # Return based on return_as_pc flag - if return_as_pc: - return pc - else: - # Extract figure from PlotCollection - fig = pc.viz.figure.data.item() - - # Only matplotlib has axes - if backend == "matplotlib": - axes = fig.get_axes() - else: - axes = None - - return fig, axes -``` - -#### 3. Twin Axes Fallback Strategy - -**Implementation for `budget_allocation()`**: - -```python -def budget_allocation( - self, - samples: xr.Dataset, - backend: str | None = None, - return_as_pc: bool = False, - **kwargs -) -> tuple[Figure, plt.Axes] | tuple[Figure, np.ndarray] | PlotCollection: - """ - Notes - ----- - This method uses dual y-axes (matplotlib's twinx), which is not supported - by PlotCollection. If a non-matplotlib backend is requested, a warning - will be issued and the method will fallback to matplotlib. - """ - # Resolve backend - backend = backend or mmm_config["plot.backend"] - - # Check for twinx compatibility (per user requirement #5) - if backend != "matplotlib": - import warnings - warnings.warn( - f"budget_allocation() uses dual y-axes (ax.twinx()) which is not " - f"supported by PlotCollection with backend='{backend}'. " - f"Falling back to matplotlib.", - UserWarning - ) - backend = "matplotlib" - - # Proceed with implementation - # ... -``` - -#### 4. RC Params Handling for saturation_curves() - -**Implementation** (per user requirement #4): - -```python -def saturation_curves( - self, - curve: xr.DataArray, - rc_params: dict | None = None, # DEPRECATED - backend: str | None = None, - backend_config: dict | None = None, - return_as_pc: bool = False, - **kwargs, -) -> tuple[plt.Figure, np.ndarray] | PlotCollection: - """ - Parameters - ---------- - rc_params : dict, optional - **DEPRECATED**: Use `backend_config` instead. - Temporary `matplotlib.rcParams` for this plot (matplotlib backend only). - A DeprecationWarning will be issued when using this parameter. - backend : str, optional - Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". - If None, uses global config (default: "matplotlib"). - backend_config : dict, optional - Backend-specific configuration dictionary: - - matplotlib: rcParams dict (same as deprecated rc_params) - - plotly: layout configuration dict - - bokeh: theme configuration dict - """ - # Resolve backend - backend = backend or mmm_config["plot.backend"] - - # Handle deprecated rc_params (per user requirement #4) - if rc_params is not None: - import warnings - warnings.warn( - "The 'rc_params' parameter is deprecated and will be removed in a " - "future version. Use 'backend_config' instead.", - DeprecationWarning, - stacklevel=2 - ) - if backend_config is None: - backend_config = rc_params - - # Apply backend-specific config if matplotlib - if backend == "matplotlib" and backend_config: - with plt.rc_context(backend_config): - # ... create PlotCollection ... - else: - if backend_config and backend != "matplotlib": - import warnings - warnings.warn( - f"backend_config only supported for matplotlib backend, " - f"ignoring for backend='{backend}'", - UserWarning - ) - # ... create PlotCollection without rc_context ... -``` - -#### 5. Helper Function Migration Strategy - -**Problem**: Helper functions in [pymc_marketing/plot.py](pymc_marketing/plot.py) use matplotlib directly. - -**Approach**: Create new PlotCollection-compatible versions (per user requirement #9) - -```python -# New backend-agnostic versions -def plot_hdi_pc(data, *, backend=None, plot_collection=None, **pc_kwargs): - """Plot HDI using PlotCollection (backend-agnostic).""" - backend = backend or mmm_config["plot.backend"] - - if plot_collection is None: - pc = PlotCollection.grid(data, backend=backend, **pc_kwargs) - else: - pc = plot_collection - - pc.map(_plot_hdi_visual, data=data) - return pc - -# Keep existing matplotlib-specific version for backward compatibility -def plot_hdi(da, ax=None, **kwargs): - """Plot HDI using matplotlib (legacy).""" - # ... existing matplotlib implementation -``` - -#### 6. Testing Strategy - -**Requirement**: Test across matplotlib, plotly and bokeh backends (user requirement #8). - -**Implementation**: -```python -# tests/mmm/test_plotting_backends.py -import pytest - -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -class TestMMMPlotSuiteBackends: - """Test all MMMPlotSuite methods across backends.""" - - def test_posterior_predictive(self, mmm_model, backend): - pc = mmm_model.plot.posterior_predictive( - backend=backend, - return_as_pc=True - ) - assert isinstance(pc, PlotCollection) - assert pc.backend == backend - - def test_backward_compatibility_matplotlib(self, mmm_model): - """Test backward compatibility with matplotlib.""" - fig, axes = mmm_model.plot.posterior_predictive( - backend="matplotlib", - return_as_pc=False - ) - assert isinstance(fig, plt.Figure) - assert isinstance(axes, np.ndarray) - - def test_twinx_fallback(self, mmm_model): - """Test that budget_allocation falls back to matplotlib for non-matplotlib backends.""" - with pytest.warns(UserWarning, match="Falling back to matplotlib"): - result = mmm_model.plot.budget_allocation( - samples=..., - backend="plotly", - return_as_pc=False - ) - # Should return matplotlib objects despite requesting plotly - assert isinstance(result[0], plt.Figure) - - def test_rc_params_deprecation(self, mmm_model): - """Test that rc_params parameter issues deprecation warning.""" - with pytest.warns(DeprecationWarning, match="rc_params.*deprecated"): - mmm_model.plot.saturation_curves( - curve=..., - rc_params={"xtick.labelsize": 12}, - backend="matplotlib" - ) - - def test_global_backend_config(self, mmm_model): - """Test global backend configuration.""" - import pymc_marketing as pmm - original_backend = pmm.mmm.mmm_config["plot.backend"] - try: - pmm.mmm.mmm_config["plot.backend"] = "plotly" - pc = mmm_model.plot.posterior_predictive(return_as_pc=True) - assert pc.backend == "plotly" - finally: - pmm.mmm.mmm_config["plot.backend"] = original_backend -``` - -## Migration Implementation Checklist - -### Phase 1: Infrastructure Setup - -- [ ] Add `arviz-plots` as a required dependency in `pyproject.toml` -- [ ] Create `pymc_marketing/mmm/config.py` with `MMMConfig` class and `mmm_config` instance -- [ ] Export `mmm_config` from `pymc_marketing/mmm/__init__.py` -- [ ] Create backend-agnostic plotting function templates - -### Phase 2: Helper Functions - -- [ ] Create `plot_hdi_pc()` PlotCollection version in `pymc_marketing/plot.py` -- [ ] Create `plot_samples_pc()` PlotCollection version in `pymc_marketing/plot.py` -- [ ] Implement backend detection logic in visual functions -- [ ] Keep existing `plot_hdi()` and `plot_samples()` for backward compatibility - -### Phase 3: MMMPlotSuite Methods (Priority Order) - -**High Priority (Simple methods)**: -1. [ ] `posterior_predictive()` - Add backend/return_as_pc parameters -2. [ ] `contributions_over_time()` - Add backend/return_as_pc parameters -3. [ ] `saturation_scatterplot()` - Add backend/return_as_pc parameters -4. [ ] `sensitivity_analysis()` - Add backend/return_as_pc parameters -5. [ ] `uplift_curve()` - Inherits from sensitivity_analysis -6. [ ] `marginal_curve()` - Inherits from sensitivity_analysis - -**Medium Priority (Uses external functions)**: -7. [ ] `saturation_curves()` - Add backend/backend_config/return_as_pc, deprecate rc_params -8. [ ] `allocated_contribution_by_channel_over_time()` - Add backend/return_as_pc - -**Low Priority (Requires twinx fallback)**: -9. [ ] `budget_allocation()` - Add backend/return_as_pc with twinx fallback logic - -### Phase 4: Testing - -- [ ] Create `tests/mmm/test_plotting_backends.py` -- [ ] Parametrized tests across matplotlib/plotly/bokeh -- [ ] Backward compatibility tests -- [ ] Global config tests -- [ ] Fallback behavior tests (twinx) -- [ ] Deprecation warning tests (rc_params) -- [ ] Return type validation tests - -### Phase 5: Documentation - -- [ ] Update all docstrings with new parameters -- [ ] Add migration guide for users -- [ ] Add examples showing new API -- [ ] Document backend limitations -- [ ] Update notebooks to show multi-backend usage - -## Code References - -### MMMPlotSuite Implementation -- Class definition: [pymc_marketing/mmm/plot.py:187](pymc_marketing/mmm/plot.py#L187) -- Twin axes usage: [pymc_marketing/mmm/plot.py:1249](pymc_marketing/mmm/plot.py#L1249) -- RC context usage: [pymc_marketing/mmm/plot.py:878-880](pymc_marketing/mmm/plot.py#L878-880) -- Helper methods: [pymc_marketing/mmm/plot.py:200-370](pymc_marketing/mmm/plot.py#L200-370) -- Main plotting methods: [pymc_marketing/mmm/plot.py:375-1923](pymc_marketing/mmm/plot.py#L375-1923) - -### Helper Functions to Migrate -- plot_hdi: [pymc_marketing/plot.py:434](pymc_marketing/plot.py#L434) -- plot_samples: [pymc_marketing/plot.py:503](pymc_marketing/plot.py#L503) - -### Dependencies -- Package configuration: `/Users/imrisofer/projects/pymc-marketing/pyproject.toml` -- Conda environment: `/Users/imrisofer/projects/pymc-marketing/environment.yml` - -## Open Questions - -1. **PlotCollection Figure/Axes Extraction**: ✅ RESOLVED - Use `pc.viz.figure.data.item()` to extract backend-specific figure object, then `fig.get_axes()` for matplotlib axes (returns None for non-matplotlib backends). - -2. **Backend-Specific Styling**: Should we implement backend-specific styling translation for common use cases, or just warn users that `backend_config` only works for matplotlib? - -3. **Helper Function Strategy**: Should we deprecate old `plot_hdi()` and `plot_samples()` or keep them indefinitely? - -4. **Component Plotting**: While we're not migrating component plot methods per requirement #9, should we at least add documentation noting that they remain matplotlib-only? - -5. **Version Number**: Should this migration be part of a major version bump (e.g., 0.x → 1.0 or 1.x → 2.0)? - -## Summary of Key Decisions - -1. **Backward Compatibility**: Maintained via `return_as_pc=False` default parameter - - When `return_as_pc=False`, functions return `(figure, axes)` tuple - - `figure` is extracted via `pc.viz.figure.data.item()` - - `axes` is extracted via `fig.get_axes()` for matplotlib, `None` for other backends -2. **Global Configuration**: ArviZ-style `mmm_config` dictionary -3. **Twin Axes**: Fallback to matplotlib with warning for `budget_allocation()` -4. **RC Params**: Deprecate `rc_params`, add `backend_config` parameter -5. **Testing**: Parametrized tests across matplotlib, plotly, bokeh -6. **Helper Functions**: Create new PlotCollection versions, keep existing for compatibility -7. **Default Backend**: Keep "matplotlib" as default for full backward compatibility From f5574d3704d6b16524fe35fdf8e34ecaf449b6c4 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Wed, 19 Nov 2025 14:46:03 -0500 Subject: [PATCH 07/29] doing research --- pymc_marketing/mmm/plot.py | 2 +- ...mmplotsuite-migration-complete-analysis.md | 1569 +++++++++++++++++ 2 files changed, 1570 insertions(+), 1 deletion(-) create mode 100644 thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 95a229604..8d7ac108d 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -741,7 +741,7 @@ def saturation_curves_scatter( Plot scatter plots of channel contributions vs. channel data. .. deprecated:: 0.1.0 - Will be removed in version 0.2.0. Use :meth:`saturation_scatterplot` instead. + Will be removed in version 0.20.0. Use :meth:`saturation_scatterplot` instead. Parameters ---------- diff --git a/thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md b/thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md new file mode 100644 index 000000000..3451ac6cd --- /dev/null +++ b/thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md @@ -0,0 +1,1569 @@ +--- +date: 2025-11-19T14:04:21+0000 +researcher: Claude +git_commit: d6331a03727aa9c78ad16690aca25ce9cb869129 +branch: feature/mmmplotsuite-arviz +repository: pymc-labs/pymc-marketing +topic: "MMMPlotSuite Migration - Complete Implementation Analysis and Requirements" +tags: [research, codebase, mmm, plotting, migration, backward-compatibility, testing, arviz-plots] +status: complete +last_updated: 2025-11-19 +last_updated_by: Claude +--- + +# Research: MMMPlotSuite Migration - Complete Implementation Analysis and Requirements + +**Date**: 2025-11-19T14:04:21+0000 +**Researcher**: Claude +**Git Commit**: d6331a03727aa9c78ad16690aca25ce9cb869129 +**Branch**: feature/mmmplotsuite-arviz +**Repository**: pymc-labs/pymc-marketing + +## Research Question + +The user is migrating MMMPlotSuite from matplotlib-based plotting to arviz_plots with multi-backend support. The legacy implementation is currently in `mmm/old_plot.py` and should be renamed to `mmm/legacy_plot.py`. To complete this migration, they need to: + +1. Rename `old_plot.py` to `legacy_plot.py` and `OldMMMPlotSuite` to `LegacyMMMPlotSuite` +2. Support global backend configuration with per-function override capability +3. Implement backward compatibility with a flag to control legacy vs new behavior (default: legacy) +4. Add deprecation warning pointing to v0.20.0 removal +5. Review the new code implementation for quality issues +6. Create comprehensive tests for matplotlib, bokeh, and plotly backends + +## Summary + +Based on comprehensive codebase analysis, the migration is **75% complete** with critical gaps identified: + +**✅ Already Implemented:** +- Backend configuration system with `mmm_config["plot.backend"]` supporting matplotlib/plotly/bokeh +- Complete new arviz_plots-based implementation returning `PlotCollection` objects +- Legacy matplotlib-based implementation preserved in `old_plot.py` (to be renamed `legacy_plot.py`) +- Per-method backend override via `backend` parameter on all plot methods + +**❌ Missing Critical Components:** +- Rename `old_plot.py` to `legacy_plot.py` and `OldMMMPlotSuite` to `LegacyMMMPlotSuite` +- Backward compatibility flag (`use_v2`) to toggle between legacy/new suite +- Deprecation warning system for users +- Comprehensive backend testing for the new suite +- Compatibility test suite +- Documentation of breaking changes + +**⚠️ Code Review Issues Found:** +- Return type documentation incomplete +- Breaking parameter type changes across all methods (intentional, no backward compatibility needed) +- Lost customization parameters (colors, subplot_kwargs, rc_params) - handled by arviz_plots + +## Detailed Findings + +### 1. Current Architecture + +#### 1.1 Class Definitions and Locations + +**New Implementation:** +- **File**: [pymc_marketing/mmm/plot.py:187-1272](pymc_marketing/mmm/plot.py#L187-L1272) +- **Class**: `MMMPlotSuite` +- **Export**: `__all__ = ["MMMPlotSuite"]` at line 181 +- **Technology**: arviz_plots library +- **Return Type**: `PlotCollection` (unified across all backends) + +**Legacy Implementation:** +- **File**: [pymc_marketing/mmm/old_plot.py:191-1936](pymc_marketing/mmm/old_plot.py#L191-L1936) (to be renamed to `legacy_plot.py`) +- **Class**: `OldMMMPlotSuite` (to be renamed to `LegacyMMMPlotSuite`) +- **Export**: Not exported in any `__all__` +- **Technology**: matplotlib only +- **Return Type**: `tuple[Figure, NDArray[Axes]]` or `tuple[Figure, plt.Axes]` + +**Integration Point:** +- **File**: [pymc_marketing/mmm/multidimensional.py:602-607](pymc_marketing/mmm/multidimensional.py#L602-L607) +- **Property**: `MMM.plot` returns `MMMPlotSuite(idata=self.idata)` +- **Issue**: Hardcoded to only return new suite, no version control + +#### 1.2 Method Comparison Matrix + +| Method | New Suite | Legacy Suite | API Compatible | Breaking Changes | +|--------|-----------|--------------|----------------|------------------| +| `__init__` | ✅ | ✅ | ✅ | None | +| `posterior_predictive()` | ✅ | ✅ | ❌ | `var: str` vs `list[str]`, return type | +| `contributions_over_time()` | ✅ | ✅ | ⚠️ | Return type only | +| `saturation_scatterplot()` | ✅ | ✅ | ⚠️ | Lost `**kwargs`, return type | +| `saturation_curves()` | ✅ | ✅ | ❌ | Lost colors, subplot_kwargs, rc_params | +| `saturation_curves_scatter()` | ✅ | ✅ | ✅ | DEPRECATED in both (delegates) | +| `budget_allocation()` | ❌ | ✅ | ❌ | **REMOVED** - no replacement | +| `budget_allocation_roas()` | ✅ | ❌ | N/A | New method, different purpose | +| `allocated_contribution_by_channel_over_time()` | ✅ | ✅ | ❌ | Lost scale_factor, quantiles, figsize, ax | +| `sensitivity_analysis()` | ✅ | ✅ | ❌ | Lost ax, subplot_kwargs, plot_kwargs | +| `uplift_curve()` | ✅ | ✅ | ❌ | Lost ax, subplot_kwargs, plot_kwargs | +| `marginal_curve()` | ✅ | ✅ | ❌ | Lost ax, subplot_kwargs, plot_kwargs | + +**Helper Methods:** +- New Suite: `_get_additional_dim_combinations()`, `_get_posterior_predictive_data()`, `_validate_dims()`, `_dim_list_handler()`, `_resolve_backend()`, `_sensitivity_analysis_plot()` +- Legacy Suite: `_init_subplots()`, `_build_subplot_title()`, `_reduce_and_stack()`, `_add_median_and_hdi()`, `_plot_budget_allocation_bars()` + shared helpers + +### 2. Backend Configuration System ✅ **COMPLETE** + +#### 2.1 Implementation + +**File**: [pymc_marketing/mmm/config.py:21-66](pymc_marketing/mmm/config.py#L21-L66) + +```python +VALID_BACKENDS = {"matplotlib", "plotly", "bokeh"} + +class MMMConfig(dict): + """Configuration dictionary for MMM plotting settings.""" + + _defaults = { + "plot.backend": "matplotlib", + "plot.show_warnings": True, + } + + def __setitem__(self, key, value): + """Set config value with validation for backend.""" + if key == "plot.backend": + if value not in VALID_BACKENDS: + warnings.warn( + f"Invalid backend '{value}'. Valid backends are: {VALID_BACKENDS}. " + f"Setting anyway, but plotting may fail.", + UserWarning, + stacklevel=2, + ) + super().__setitem__(key, value) + +# Global config instance +mmm_config = MMMConfig() +``` + +#### 2.2 Backend Resolution + +**File**: [pymc_marketing/mmm/plot.py:288-292](pymc_marketing/mmm/plot.py#L288-L292) + +```python +def _resolve_backend(self, backend: str | None) -> str: + """Resolve backend parameter to actual backend string.""" + from pymc_marketing.mmm.config import mmm_config + return backend or mmm_config["plot.backend"] +``` + +#### 2.3 Usage Pattern + +```python +from pymc_marketing.mmm import mmm_config + +# Set global backend +mmm_config["plot.backend"] = "plotly" + +# All plots use plotly +mmm.plot.posterior_predictive() + +# Override for specific plot +mmm.plot.posterior_predictive(backend="matplotlib") +``` + +**Status**: ✅ No action needed - fully functional + +### 3. Backward Compatibility ❌ **MISSING - CRITICAL** + +#### 3.1 Current Gap + +The `.plot` property currently only returns the new suite: + +```python +# Current implementation in multidimensional.py:602-607 +@property +def plot(self) -> MMMPlotSuite: + """Use the MMMPlotSuite to plot the results.""" + self._validate_model_was_built() + self._validate_idata_exists() + return MMMPlotSuite(idata=self.idata) +``` + +#### 3.2 Required Implementation + +**Step 1: Add flag to config.py** + +```python +# File: pymc_marketing/mmm/config.py +_defaults = { + "plot.backend": "matplotlib", + "plot.show_warnings": True, + "plot.use_v2": False, # ← ADD THIS LINE +} +``` + +**Step 2: Implement version switching in multidimensional.py** + +```python +# File: pymc_marketing/mmm/multidimensional.py:602-607 +@property +def plot(self) -> MMMPlotSuite | LegacyMMMPlotSuite: + """Use the MMMPlotSuite to plot the results.""" + from pymc_marketing.mmm.config import mmm_config + from pymc_marketing.mmm.plot import MMMPlotSuite + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + import warnings + + self._validate_model_was_built() + self._validate_idata_exists() + + # Check version flag + if mmm_config.get("plot.use_v2", False): + return MMMPlotSuite(idata=self.idata) + else: + # Show deprecation warning for legacy suite + if mmm_config.get("plot.show_warnings", True): + warnings.warn( + "The current MMMPlotSuite will be deprecated in v0.20.0. " + "The new version uses arviz_plots and supports multiple backends (matplotlib, plotly, bokeh). " + "To use the new version: " + ">>> from pymc_marketing.mmm.config import mmm_config\n" + ">>> mmm_config['plot.use_v2'] = True\n" + "To suppress this warning: mmm_config['plot.show_warnings'] = False\n" + "See migration guide: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html", + FutureWarning, + stacklevel=2, + ) + return LegacyMMMPlotSuite(idata=self.idata) +``` + +#### 3.3 Design Rationale + +**Why `FutureWarning` instead of `DeprecationWarning`?** +- `DeprecationWarning` is for library developers (hidden by default in Python) +- `FutureWarning` is for end users (always shown) +- Our users are data scientists/analysts, not library developers +- Pattern found in [pymc_marketing/mlflow.py:180-185](pymc_marketing/mlflow.py#L180-L185) + +**Why config flag instead of function parameter?** +- Consistent with existing backend configuration pattern +- Allows global setting affecting all plot calls +- Can be overridden per-session +- Pattern found throughout codebase (e.g., `plot.backend`) + +**Why default to `False` (legacy suite)?** +- Non-breaking change in initial release +- Gives users time to migrate (1-2 releases) +- Prevents surprise breakage for existing code + +### 4. Deprecation Patterns Research + +Found **10 distinct patterns** used across the codebase: + +#### Pattern 1: Parameter Name Deprecation with Helper +**Location**: [pymc_marketing/model_builder.py:60-77](pymc_marketing/model_builder.py#L60-L77) +**Test**: [tests/test_model_builder.py:530-554](tests/test_model_builder.py#L530-L554) + +```python +def _handle_deprecate_pred_argument(value, name: str, kwargs: dict): + name_pred = f"{name}_pred" + + if name_pred in kwargs and value is not None: + raise ValueError(f"Both {name} and {name_pred} cannot be provided.") + + if name_pred in kwargs: + warnings.warn( + f"{name_pred} is deprecated, use {name} instead", + DeprecationWarning, + stacklevel=2, + ) + return kwargs.pop(name_pred) + + return value +``` + +#### Pattern 2: Method Deprecation with Delegation +**Location**: [pymc_marketing/mmm/plot.py:737-771](pymc_marketing/mmm/plot.py#L737-L771) +**Test**: [tests/mmm/test_plot.py:722-731](tests/mmm/test_plot.py#L722-L731) + +```python +def saturation_curves_scatter(self, original_scale: bool = False, **kwargs) -> PlotCollection: + """ + .. deprecated:: 0.1.0 + Will be removed in version 0.20.0. Use :meth:`saturation_scatterplot` instead. + """ + import warnings + warnings.warn( + "saturation_curves_scatter is deprecated and will be removed in version 0.2.0. " + "Use saturation_scatterplot instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.saturation_scatterplot(original_scale=original_scale, **kwargs) +``` + +#### Pattern 3: Config Key Renaming +**Location**: [pymc_marketing/clv/models/basic.py:49-59](pymc_marketing/clv/models/basic.py#L49-L59) + +```python +deprecated_keys = [key for key in model_config if key.endswith("_prior")] +for key in deprecated_keys: + new_key = key.replace("_prior", "") + warnings.warn( + f"The key '{key}' in model_config is deprecated. Use '{new_key}' instead.", + DeprecationWarning, + stacklevel=2, + ) + model_config[new_key] = model_config.pop(key) +``` + +#### Pattern 4: Module-Level Deprecation +**Location**: [pymc_marketing/deserialize.py:14-40](pymc_marketing/deserialize.py#L14-L40) + +```python +warnings.warn( + "The pymc_marketing.deserialize module is deprecated. " + "Please use pymc_extras.deserialize instead.", + DeprecationWarning, + stacklevel=2, +) +``` + +**Key Testing Pattern**: All deprecation warnings tested with `pytest.warns()`: + +```python +def test_deprecation(): + with pytest.warns(DeprecationWarning, match=r"is deprecated"): + result = deprecated_function() + + # Verify functionality still works + assert isinstance(result, ExpectedType) +``` + +### 5. Code Review: Issues Found in New Implementation + +#### Issue 1: Return Type Documentation ⚠️ **MINOR** + +**Problem**: Method docstrings don't clearly state `PlotCollection` return type vs old `(Figure, Axes)` tuple. + +**Location**: All methods in [plot.py:298-1272](pymc_marketing/mmm/plot.py#L298-L1272) + +**Example** - `posterior_predictive()` docstring: +```python +def posterior_predictive(...) -> PlotCollection: + """ + Plot posterior predictive distributions over time. + + Returns + ------- + PlotCollection # ← States type but doesn't explain what it is + """ +``` + +**Fix**: Add explanatory text: +```python + Returns + ------- + PlotCollection + arviz_plots PlotCollection object containing the plot. + Use .show() to display or .save("filename") to save. + Unlike the old implementation which returned (Figure, Axes), + this provides a unified interface across matplotlib, plotly, and bokeh backends. +``` + +#### Issue 2: Breaking Parameter Type Changes ✅ **INTENTIONAL - NO ACTION NEEDED** + +**Status**: Many parameters have changed across all methods. Since this is a comprehensive migration to a new architecture (arviz_plots), these breaking changes are expected and documented. + +**Examples of parameter changes**: + +```python +# LEGACY (old_plot.py:387 - to be renamed to legacy_plot.py) +def posterior_predictive( + self, + var: list[str] | None = None, # ← Accepts list + ... +) -> tuple[Figure, NDArray[Axes]]: + +# NEW (plot.py:300) +def posterior_predictive( + self, + var: str | None = None, # ← Only accepts string + ... +) -> PlotCollection: +``` + +**Rationale for no backward compatibility**: +- The entire API is changing (return types, parameters, behavior) +- Users switch to new suite explicitly via `mmm_config["plot.use_v2"] = True` +- Legacy suite remains available for those who need legacy parameter behavior +- Attempting to handle all parameter changes would add significant complexity for minimal benefit +- Migration guide will document all parameter changes with examples + +**Action**: Document parameter changes in migration guide, let users adapt code when they opt into v2. + +#### Issue 3: Missing Method ⚠️ **MAJOR** + +**Problem**: `budget_allocation()` completely removed with no replacement. + +**Legacy Method**: [old_plot.py:1049-1224](pymc_marketing/mmm/old_plot.py#L1049-L1224) (to be renamed to legacy_plot.py) +- Creates bar chart comparing allocated spend vs channel contributions +- Dual y-axis visualization + +**New Method**: `budget_allocation_roas()` at [plot.py:773-874](pymc_marketing/mmm/plot.py#L773-L874) +- Completely different purpose (ROI distributions) +- Different parameters and output + +**Impact**: Code using `mmm.plot.budget_allocation()` will fail with `AttributeError`. + +**Recommendation**: Add stub method that raises helpful error: + +```python +def budget_allocation(self, *args, **kwargs): + """ + .. deprecated:: 0.18.0 + Removed in version 2.0. See budget_allocation_roas() for ROI distributions. + + Raises + ------ + NotImplementedError + This method was removed in MMMPlotSuite v2. + For ROI distributions, use budget_allocation_roas(). + To use the old budget_allocation(), set mmm_config['plot.use_v2'] = False. + """ + raise NotImplementedError( + "budget_allocation() was removed in MMMPlotSuite v2. " + "The new version uses arviz_plots which doesn't support this chart type. " + "Options:\n" + " 1. For ROI distributions: use budget_allocation_roas()\n" + " 2. To use old method: set mmm_config['plot.use_v2'] = False\n" + " 3. Implement custom bar chart using samples data" + ) +``` + +#### Issue 4: Lost Customization Parameters ⚠️ **MODERATE** + +**Problem**: Many customization parameters removed in new implementation. + +**Examples**: + +`saturation_curves()`: +```python +# LEGACY - rich customization +def saturation_curves( + self, + curve, + colors: Iterable[str] | None = None, + subplot_kwargs: dict | None = None, + rc_params: dict | None = None, + **plot_kwargs +) + +# NEW - simplified +def saturation_curves( + self, + curve, + backend: str | None = None, + # Lost: colors, subplot_kwargs, rc_params, plot_kwargs +) +``` + +`sensitivity_analysis()`: +```python +# LEGACY +def sensitivity_analysis( + self, + ax: plt.Axes | None = None, + subplot_kwargs: dict[str, Any] | None = None, + plot_kwargs: dict[str, Any] | None = None, + ylabel: str = "Effect", + xlabel: str = "Sweep", + title: str | None = None, + ... +) + +# NEW +def sensitivity_analysis( + self, + backend: str | None = None, + # Lost: ax, subplot_kwargs, plot_kwargs, ylabel, xlabel, title +) +``` + +**Rationale**: arviz_plots handles layout automatically, reducing need for manual control. + +**Mitigation**: Document how to customize `PlotCollection` objects after creation: +```python +pc = mmm.plot.saturation_curves(curve) + +# For matplotlib backend +if pc.backend == "matplotlib": + for ax in pc.axes.flat: + ax.set_title("Custom Title") + ax.set_xlabel("Custom X Label") +``` + +#### Issue 5: Backend Parameter Coverage ✅ **GOOD** + +**Status**: All public methods have `backend` parameter: +- `posterior_predictive()` ✅ +- `contributions_over_time()` ✅ +- `saturation_scatterplot()` ✅ +- `saturation_curves()` ✅ +- `budget_allocation_roas()` ✅ +- `allocated_contribution_by_channel_over_time()` ✅ +- `sensitivity_analysis()` ✅ +- `uplift_curve()` ✅ +- `marginal_curve()` ✅ + +**Pattern**: Consistent across all methods, properly resolves via `_resolve_backend()`. + +### 6. Testing Infrastructure ⚠️ **MAJOR GAPS** + +#### 6.1 Current Test Coverage + +**Test Files Found:** +1. [tests/mmm/test_plot.py](tests/mmm/test_plot.py) - 800+ lines + - Contains ~28 test functions + - Good fixture patterns + - **Tests for LegacyMMMPlotSuite only** (currently using `old_plot.py`) + - **NEW suite (plot.py) has NO test coverage** + - **Needs new tests for the new MMMPlotSuite with all backends** + +2. [tests/mmm/test_plot_backends.py](tests/mmm/test_plot_backends.py) - 255 lines + - **EXPERIMENTAL FILE - SHOULD BE REMOVED** + - Contains ~14 test functions + - Only tests `posterior_predictive()` with multiple backends + - Functionality should be merged into test_plot.py with parametrization + +3. [tests/mmm/test_plotting.py](tests/mmm/test_plotting.py) - Legacy tests + - Tests for old `BaseMMM` and `MMM` plotting + - Not for MMMPlotSuite + +**Test Coverage Analysis:** + +| Method | Legacy Suite Tests | New Suite Tests | All Backends | Compatibility Tests | +|--------|-------------------|----------------|--------------|---------------------| +| `posterior_predictive()` | ✅ (matplotlib only) | ⚠️ (test_plot_backends.py only) | ❌ | ❌ | +| `contributions_over_time()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | +| `saturation_scatterplot()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | +| `saturation_curves()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | +| `budget_allocation()` | ✅ (matplotlib only) | N/A (removed) | ❌ | ❌ | +| `budget_allocation_roas()` | N/A (doesn't exist) | ❌ | ❌ | ❌ | +| `allocated_contribution_by_channel_over_time()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | +| `sensitivity_analysis()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | +| `uplift_curve()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | +| `marginal_curve()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | +| Config flag switching | ❌ | ❌ | ❌ | ❌ | +| Deprecation warnings | ❌ | ❌ | ❌ | ❌ | + +**Coverage**: +- Legacy suite: ~80% (8 methods tested, matplotlib only) +- **New suite: ~1% (only 1 method partially tested in experimental file)** +- Compatibility tests: 0% + +**Critical Gap**: The new MMMPlotSuite (plot.py) has essentially NO test coverage! + +**Testing Strategy**: +- **Create new comprehensive tests for the new MMMPlotSuite** +- Parametrize all new tests to run against all backends (matplotlib, plotly, bokeh) +- Keep existing test_plot.py tests for legacy suite (will be removed in v0.20.0) +- Create separate compatibility test suite + +#### 6.2 Available Test Fixtures + +**From test_plot.py (for LegacyMMMPlotSuite):** +```python +@pytest.fixture(scope="module") +def mock_idata() -> az.InferenceData: + """Basic mock InferenceData with posterior.""" + # Line 201 + +@pytest.fixture(scope="module") +def mock_idata_with_constant_data() -> az.InferenceData: + """Mock InferenceData with constant_data for saturation plots.""" + # Line 315 + +@pytest.fixture(scope="module") +def mock_suite(mock_idata) -> LegacyMMMPlotSuite: + """LegacyMMMPlotSuite instance with basic mock data.""" + # Line 290 - currently creates from old_plot + +@pytest.fixture(scope="module") +def mock_suite_with_constant_data(mock_idata_with_constant_data) -> LegacyMMMPlotSuite: + """LegacyMMMPlotSuite with constant data for saturation plots.""" + # Line 382 - currently creates from old_plot + +@pytest.fixture +def mock_saturation_curve(mock_idata_with_constant_data) -> xr.DataArray: + """Mock saturation curve DataArray.""" + # Line 388 +``` + +**Pattern**: All fixtures use deterministic seeds for reproducibility. + +**Note**: These fixtures will need to be adapted/duplicated for testing the new MMMPlotSuite. + +#### 6.3 Required Test Implementation + +**Strategy**: Create NEW comprehensive tests for the new MMMPlotSuite with multi-backend support + +**Step 1: Keep existing test_plot.py for legacy suite** +- Rename test file to make it clear it's for legacy: `test_plot.py` → `test_legacy_plot.py` +- Update imports to use `legacy_plot.LegacyMMMPlotSuite` +- These tests will be removed in v0.20.0 along with the legacy suite + +**Step 2: Create new test_plot.py for new MMMPlotSuite** + +Create comprehensive tests with backend parametrization: + +```python +"""Tests for new MMMPlotSuite with multi-backend support.""" + +import pytest +from arviz_plots import PlotCollection +from pymc_marketing.mmm import mmm_config +from pymc_marketing.mmm.plot import MMMPlotSuite + +@pytest.fixture(scope="module") +def new_mock_suite(mock_idata) -> MMMPlotSuite: + """New MMMPlotSuite instance with basic mock data.""" + return MMMPlotSuite(idata=mock_idata) + +@pytest.fixture(scope="module") +def new_mock_suite_with_constant_data(mock_idata_with_constant_data) -> MMMPlotSuite: + """New MMMPlotSuite with constant data for saturation plots.""" + return MMMPlotSuite(idata=mock_idata_with_constant_data) + +# Parametrize all tests across all backends +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_posterior_predictive(new_mock_suite, backend): + """Test posterior_predictive works with all backends.""" + pc = new_mock_suite.posterior_predictive(backend=backend) + assert isinstance(pc, PlotCollection) + assert pc.backend == backend + +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_contributions_over_time(new_mock_suite, backend): + """Test contributions_over_time works with all backends.""" + pc = new_mock_suite.contributions_over_time( + var=["intercept"], + backend=backend + ) + assert isinstance(pc, PlotCollection) + assert pc.backend == backend + +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_saturation_scatterplot(new_mock_suite_with_constant_data, backend): + """Test saturation_scatterplot works with all backends.""" + pc = new_mock_suite_with_constant_data.saturation_scatterplot(backend=backend) + assert isinstance(pc, PlotCollection) + assert pc.backend == backend + +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_saturation_curves( + new_mock_suite_with_constant_data, mock_saturation_curve, backend +): + """Test saturation_curves works with all backends.""" + pc = new_mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, + backend=backend + ) + assert isinstance(pc, PlotCollection) + assert pc.backend == backend + +@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) +def test_budget_allocation_roas(new_mock_suite, backend): + """Test budget_allocation_roas works with all backends.""" + # Note: This is a NEW method that doesn't exist in legacy suite + pc = new_mock_suite.budget_allocation_roas(backend=backend) + assert isinstance(pc, PlotCollection) + assert pc.backend == backend + +# ... Create tests for all 9 methods with 3 backends = 27 core tests ... +``` + +**Step 3: Remove experimental test_plot_backends.py** +```bash +rm tests/mmm/test_plot_backends.py +``` + +**Step 4: Add backend-specific tests** + +```python +def test_backend_overrides_global_config(mock_suite): + """Test that method backend parameter overrides global config.""" + original = mmm_config.get("plot.backend", "matplotlib") + try: + mmm_config["plot.backend"] = "matplotlib" + + # Override with plotly + pc = mock_suite.contributions_over_time( + var=["intercept"], + backend="plotly" + ) + assert pc.backend == "plotly" + + # Default should still be matplotlib + pc2 = mock_suite.contributions_over_time(var=["intercept"]) + assert pc2.backend == "matplotlib" + finally: + mmm_config["plot.backend"] = original + +def test_invalid_backend_warning(mock_suite): + """Test that invalid backend shows warning but attempts plot.""" + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + mmm_config["plot.backend"] = "invalid_backend" + + assert len(w) == 1 + assert "Invalid backend" in str(w[0].message) +``` + +**Result**: +- New suite: ~9 methods × 3 backends = ~27 core test cases (plus backend-specific tests) +- Legacy suite: ~28 existing test functions (matplotlib only, will be removed in v0.20.0) + +**File 2: Create tests/mmm/test_plot_compatibility.py** + +New file for backward compatibility: + +```python +"""Tests for MMMPlotSuite backward compatibility and version switching.""" + +import pytest +import warnings +import numpy as np +from matplotlib.figure import Figure +from matplotlib.axes import Axes +from arviz_plots import PlotCollection + +from pymc_marketing.mmm import mmm_config +from pymc_marketing.mmm.plot import MMMPlotSuite +from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + + +class TestVersionSwitching: + """Test mmm_config['plot.use_v2'] flag controls suite version.""" + + def test_use_v2_false_returns_legacy_suite(self, mock_mmm): + """Test that use_v2=False returns LegacyMMMPlotSuite.""" + original = mmm_config.get("plot.use_v2", False) + try: + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning, match="deprecated in v0.20.0"): + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, LegacyMMMPlotSuite) + assert not isinstance(plot_suite, MMMPlotSuite) + finally: + mmm_config["plot.use_v2"] = original + + def test_use_v2_true_returns_new_suite(self, mock_mmm): + """Test that use_v2=True returns MMMPlotSuite.""" + original = mmm_config.get("plot.use_v2", False) + try: + mmm_config["plot.use_v2"] = True + + # Should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, MMMPlotSuite) + finally: + mmm_config["plot.use_v2"] = original + + def test_default_is_legacy_suite(self, mock_mmm): + """Test that default behavior uses legacy suite (backward compatible).""" + # Reset to defaults + mmm_config.reset() + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, LegacyMMMPlotSuite) + + +class TestDeprecationWarnings: + """Test deprecation warning system.""" + + def test_deprecation_warning_shown_by_default(self, mock_mmm): + """Test that deprecation warning is shown when using legacy suite.""" + mmm_config["plot.use_v2"] = False + mmm_config["plot.show_warnings"] = True + + with pytest.warns(FutureWarning, match=r"deprecated in v0\.20\.0"): + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, LegacyMMMPlotSuite) + + def test_deprecation_warning_suppressible(self, mock_mmm): + """Test that deprecation warning can be suppressed.""" + original_use_v2 = mmm_config.get("plot.use_v2", False) + original_warnings = mmm_config.get("plot.show_warnings", True) + + try: + mmm_config["plot.use_v2"] = False + mmm_config["plot.show_warnings"] = False + + # Should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, LegacyMMMPlotSuite) + finally: + mmm_config["plot.use_v2"] = original_use_v2 + mmm_config["plot.show_warnings"] = original_warnings + + def test_warning_message_includes_migration_info(self, mock_mmm): + """Test that warning provides clear migration instructions.""" + mmm_config["plot.use_v2"] = False + mmm_config["plot.show_warnings"] = True + + with pytest.warns(FutureWarning) as warning_list: + plot_suite = mock_mmm.plot + + warning_msg = str(warning_list[0].message) + assert "v0.20.0" in warning_msg + assert "mmm_config['plot.use_v2'] = True" in warning_msg + assert "migration guide" in warning_msg.lower() or "documentation" in warning_msg.lower() + + def test_no_warning_when_using_new_suite(self, mock_mmm): + """Test that no warning shown when using new suite.""" + mmm_config["plot.use_v2"] = True + + with warnings.catch_warnings(): + warnings.simplefilter("error") + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, MMMPlotSuite) + + +class TestReturnTypeCompatibility: + """Test that both suites return expected types.""" + + def test_legacy_suite_returns_tuple(self, mock_mmm_fitted): + """Test legacy suite returns (Figure, Axes) tuple.""" + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm_fitted.plot + result = plot_suite.posterior_predictive() + + assert isinstance(result, tuple) + assert len(result) == 2 + assert isinstance(result[0], Figure) + # result[1] can be Axes or ndarray of Axes + if isinstance(result[1], np.ndarray): + assert all(isinstance(ax, Axes) for ax in result[1].flat) + else: + assert isinstance(result[1], Axes) + + def test_new_suite_returns_plot_collection(self, mock_mmm_fitted): + """Test new suite returns PlotCollection.""" + mmm_config["plot.use_v2"] = True + + plot_suite = mock_mmm_fitted.plot + result = plot_suite.posterior_predictive() + + assert isinstance(result, PlotCollection) + assert hasattr(result, 'backend') + assert hasattr(result, 'show') + + def test_both_suites_produce_valid_plots(self, mock_mmm_fitted): + """Test that both suites can successfully create plots.""" + # Legacy suite + mmm_config["plot.use_v2"] = False + with pytest.warns(FutureWarning): + legacy_result = mock_mmm_fitted.plot.contributions_over_time( + var=["intercept"] + ) + assert legacy_result is not None + + # New suite + mmm_config["plot.use_v2"] = True + new_result = mock_mmm_fitted.plot.contributions_over_time( + var=["intercept"] + ) + assert new_result is not None + + +class TestMissingMethods: + """Test handling of methods that exist in one suite but not the other.""" + + def test_budget_allocation_exists_in_legacy_suite(self, mock_mmm_fitted, mock_allocation_samples): + """Test that budget_allocation() works in legacy suite.""" + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm_fitted.plot + + # Should work (not raise AttributeError) + result = plot_suite.budget_allocation(samples=mock_allocation_samples) + assert isinstance(result, tuple) + + def test_budget_allocation_raises_in_new_suite(self, mock_mmm_fitted): + """Test that budget_allocation() raises helpful error in new suite.""" + mmm_config["plot.use_v2"] = True + plot_suite = mock_mmm_fitted.plot + + with pytest.raises(NotImplementedError, match="removed in MMMPlotSuite v2"): + plot_suite.budget_allocation(samples=None) + + def test_budget_allocation_roas_exists_in_new_suite( + self, mock_mmm_fitted, mock_allocation_samples + ): + """Test that budget_allocation_roas() works in new suite.""" + mmm_config["plot.use_v2"] = True + plot_suite = mock_mmm_fitted.plot + + result = plot_suite.budget_allocation_roas(samples=mock_allocation_samples) + assert isinstance(result, PlotCollection) + + def test_budget_allocation_roas_missing_in_legacy_suite(self, mock_mmm_fitted): + """Test that budget_allocation_roas() doesn't exist in legacy suite.""" + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm_fitted.plot + + with pytest.raises(AttributeError): + plot_suite.budget_allocation_roas(samples=None) + + +class TestParameterCompatibility: + """Test parameter compatibility between suites.""" + + def test_var_parameter_list_in_legacy_suite(self, mock_mmm_fitted): + """Test that legacy suite accepts var as list.""" + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm_fitted.plot + + # Should accept list + result = plot_suite.posterior_predictive(var=["y", "target"]) + assert isinstance(result, tuple) + + def test_var_parameter_list_warning_in_new_suite(self, mock_mmm_fitted): + """Test that new suite warns when given list for var.""" + mmm_config["plot.use_v2"] = True + plot_suite = mock_mmm_fitted.plot + + with pytest.warns(UserWarning, match="only supports single variable"): + result = plot_suite.posterior_predictive(var=["y"]) + + assert isinstance(result, PlotCollection) +``` + +**File 3: Additional fixtures in tests/conftest.py or tests/mmm/conftest.py** + +```python +@pytest.fixture +def mock_mmm(mock_idata): + """Mock MMM instance with idata.""" + from pymc_marketing.mmm.multidimensional import MMM + + mmm = Mock(spec=MMM) + mmm.idata = mock_idata + mmm._validate_model_was_built = Mock() + mmm._validate_idata_exists = Mock() + + # Make .plot property work + type(mmm).plot = MMM.plot + + return mmm + +@pytest.fixture +def mock_allocation_samples(): + """Mock samples dataset for budget allocation tests.""" + import xarray as xr + import numpy as np + + rng = np.random.default_rng(42) + + return xr.Dataset({ + "channel_contribution_original_scale": xr.DataArray( + rng.normal(size=(4, 100, 52, 3)), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + "channel": ["TV", "Radio", "Digital"], + }, + ), + "allocation": xr.DataArray( + rng.uniform(100, 1000, size=(3,)), + dims=("channel",), + coords={"channel": ["TV", "Radio", "Digital"]}, + ), + }) +``` + +#### 6.4 Test Execution Checklist + +**Backend Testing:** +- [ ] Remove experimental test_plot_backends.py file +- [ ] Parametrize all ~28 tests in test_plot.py with backend parameter +- [ ] All ~84 parametrized tests pass (28 tests × 3 backends) +- [ ] Backend override test works correctly +- [ ] Invalid backend warning test passes + +**Compatibility Testing:** +- [ ] Create new test_plot_compatibility.py file +- [ ] All 15+ compatibility tests pass +- [ ] Config flag switching works +- [ ] Deprecation warnings show correctly +- [ ] Warnings are suppressible +- [ ] Both suites produce valid output +- [ ] Missing method raises helpful errors + +### 7. Import/Export Architecture + +#### 7.1 Current Import Chain + +``` +User Code + ↓ +from pymc_marketing.mmm.multidimensional import MMM + ↓ +MMM.plot property (multidimensional.py:602-607) + ↓ +Imports: from pymc_marketing.mmm.plot import MMMPlotSuite + ↓ +Returns: MMMPlotSuite(idata=self.idata) +``` + +#### 7.2 Required Imports for Compatibility + +**In multidimensional.py:** +```python +# Current (line 194) +from pymc_marketing.mmm.plot import MMMPlotSuite + +# Need to add in .plot property +from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite # Import locally in property +from pymc_marketing.mmm.config import mmm_config # Import locally in property +``` + +**NOT in mmm/__init__.py:** +- MMMPlotSuite is **not** exported in [pymc_marketing/mmm/__init__.py](pymc_marketing/mmm/__init__.py#L69-L119) +- Users access it via `mmm.plot.method()`, not by importing directly +- This is good - no need to modify `__all__` + +#### 7.3 User Usage Pattern + +```python +# Users do this: +from pymc_marketing.mmm.multidimensional import MMM + +mmm = MMM(...) +mmm.fit(...) + +# Access via property - this is where version switching happens +mmm.plot.posterior_predictive() # ← Property returns either old or new suite +``` + +### 8. Migration Timeline + +#### Phase 1: v0.18.0 (Current/Next Release) +**Goal**: Introduce new suite with safe fallback + +- ✅ Backend configuration (done) +- ✅ New suite implementation (done) +- ❌ Add `use_v2` flag to config (TODO) +- ❌ Implement version switching in `.plot` property (TODO) +- ❌ Add deprecation warning (TODO) +- ❌ Complete test coverage (TODO) +- ❌ Write migration guide documentation (TODO) + +**User Experience**: +- Default behavior: legacy suite with warning +- Opt-in to new: `mmm_config["plot.use_v2"] = True` +- Clear migration path provided + +#### Phase 2: v0.19.0 +**Goal**: Encourage migration to new suite + +- Change default: `"plot.use_v2": True` +- Keep legacy suite available via `use_v2=False` +- Strengthen warning when using legacy suite +- Monitor for issues + +**User Experience**: +- Default behavior: new suite +- Opt-out to legacy: `mmm_config["plot.use_v2"] = False` +- Legacy suite shows stronger deprecation warning + +#### Phase 3: v0.20.0 +**Goal**: Complete migration + +- Remove `LegacyMMMPlotSuite` class +- Remove `legacy_plot.py` file +- Remove `use_v2` flag +- Update all documentation +- Only new suite available + +**User Experience**: +- Only new suite available +- Legacy code must update to new API + +### 9. Breaking Changes Summary + +#### 9.1 Return Type + +**Legacy**: `tuple[Figure, NDArray[Axes]]` or `tuple[Figure, plt.Axes]` +```python +fig, axes = mmm.plot.posterior_predictive() +axes[0].set_title("Custom") +fig.savefig("plot.png") +``` + +**New**: `PlotCollection` +```python +pc = mmm.plot.posterior_predictive() +pc.show() # Display +pc.save("plot.png") # Save +``` + +#### 9.2 Parameter Changes + +| Method | Parameter | Legacy | New | Fix | +|--------|-----------|--------|-----|-----| +| `posterior_predictive()` | `var` | `list[str]` | `str` | Call multiple times or use list with warning | +| `saturation_scatterplot()` | `**kwargs` | Accepted | Removed | Customize PlotCollection after | +| `saturation_curves()` | `colors` | Supported | Removed | Use PlotCollection API | +| `saturation_curves()` | `subplot_kwargs` | Supported | Removed | Use PlotCollection API | +| `saturation_curves()` | `rc_params` | Supported | Removed | Set before calling | +| All methods | `ax` | Supported | Removed | Use PlotCollection | +| All methods | `figsize` | Supported | Removed | Use PlotCollection | +| All methods | `backend` | N/A | Added | Override global config | + +#### 9.3 Method Changes + +| Method | Status | Replacement | Notes | +|--------|--------|-------------|-------| +| `budget_allocation()` | **REMOVED** | None exact | Use legacy suite or custom plot | +| `budget_allocation_roas()` | **NEW** | N/A | Different purpose (ROI dist) | + +### 10. Documentation Requirements + +#### 10.1 Migration Guide (docs/source/guides/mmm_plotting_migration.rst) + +Must include: + +1. **Overview** + - Why the change (arviz_plots benefits) + - Timeline (v0.18.0 intro, v0.19.0 default, v0.20.0 removal) + - How to opt-in/opt-out + +2. **Quick Start** + ```python + # Use new suite + from pymc_marketing.mmm import mmm_config + mmm_config["plot.use_v2"] = True + + # Set backend + mmm_config["plot.backend"] = "plotly" + ``` + +3. **Return Type Migration** + - Side-by-side examples + - How to work with PlotCollection + +4. **Method-by-Method Guide** + - API changes table + - Code examples for each method + - Common issues and solutions + +5. **Missing Features** + - `budget_allocation()` alternatives + - Lost customization parameters + - Workarounds + +6. **Backend Selection** + - Pros/cons of each backend + - When to use which + - Examples + +#### 10.2 Docstring Updates + +All methods in new suite need: +```python +def method_name(...) -> PlotCollection: + """ + Description. + + .. versionadded:: 0.18.0 + New arviz_plots-based implementation supporting multiple backends. + + Parameters + ---------- + backend : str, optional + Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". + If None, uses global config via mmm_config["plot.backend"]. + Default is "matplotlib". + + Returns + ------- + PlotCollection + arviz_plots PlotCollection object containing the plot. + Use .show() to display or .save("filename") to save. + Supports matplotlib, plotly, and bokeh backends. + + Unlike v1 which returned (Figure, Axes), this provides + a unified interface across all backends. + + Examples + -------- + Basic usage: + + >>> pc = mmm.plot.method_name() + >>> pc.show() + + Save to file: + + >>> pc.save("output.png") + + Use different backend: + + >>> pc = mmm.plot.method_name(backend="plotly") + >>> pc.show() + """ +``` + +## Code References + +### Core Implementation Files +- [pymc_marketing/mmm/plot.py](pymc_marketing/mmm/plot.py) - New MMMPlotSuite (1272 lines) +- [pymc_marketing/mmm/old_plot.py](pymc_marketing/mmm/old_plot.py) - Legacy implementation (1936 lines) - **TO BE RENAMED to legacy_plot.py** +- [pymc_marketing/mmm/config.py:21-66](pymc_marketing/mmm/config.py#L21-L66) - Backend configuration +- [pymc_marketing/mmm/multidimensional.py:602-607](pymc_marketing/mmm/multidimensional.py#L602-L607) - Integration point (.plot property) + +### Test Files +- [tests/mmm/test_plot.py](tests/mmm/test_plot.py) - Main plot tests (800+ lines) +- [tests/mmm/test_plot_backends.py](tests/mmm/test_plot_backends.py) - Backend tests (255 lines, incomplete) +- [tests/mmm/test_plotting.py](tests/mmm/test_plotting.py) - Legacy plotting tests + +### Deprecation Patterns +- [pymc_marketing/model_builder.py:60-77](pymc_marketing/model_builder.py#L60-L77) - Parameter deprecation helper +- [pymc_marketing/mmm/plot.py:737-771](pymc_marketing/mmm/plot.py#L737-L771) - Method deprecation example +- [pymc_marketing/clv/models/basic.py:49-59](pymc_marketing/clv/models/basic.py#L49-L59) - Config key deprecation +- [tests/test_model_builder.py:530-554](tests/test_model_builder.py#L530-L554) - Deprecation test pattern + +## Architecture Insights + +1. **Config-Based Feature Flags**: The codebase uses dict-based configuration (`mmm_config`) for runtime behavior control, similar to matplotlib's `rcParams` or arviz's config system. + +2. **Property-Based API**: Plot methods are accessed via `.plot` property that creates instances on-demand, enabling clean version switching at the access point. + +3. **Backend Abstraction**: The new implementation achieves backend independence through arviz_plots' `PlotCollection`, which handles backend-specific rendering internally. + +4. **Test Fixture Patterns**: All test fixtures use deterministic random seeds and module scope for performance, following pytest best practices. + +5. **Deprecation Philosophy**: The codebase uses `DeprecationWarning` for library developers and `FutureWarning` for end users, with clear migration paths in all warnings. + +6. **Incremental Migration**: Multiple patterns show support for gradual API transitions over several releases before removing old code. + +## Recommendations + +### Priority 0: File Renaming (Must Complete First) + +0. **Rename files and classes** ✅ 30 minutes + - Rename `pymc_marketing/mmm/old_plot.py` to `legacy_plot.py` + - Rename class `OldMMMPlotSuite` to `LegacyMMMPlotSuite` throughout the file + - Update any imports in existing code/tests + - **This must be done BEFORE implementing other changes** + +### Priority 1: Critical (Must Complete for PR) + +1. **Add backward compatibility flag** ✅ 2 hours + - Modify `config.py` to add `"plot.use_v2": False` + - Implement version switching in `multidimensional.py:602-607` + - Import from `legacy_plot` module + - Add deprecation warning with migration guide link + - Test manual switching works + +2. **Create comprehensive backend testing for new suite** ✅ 6 hours + - Rename existing test_plot.py to test_legacy_plot.py + - Update imports in legacy test file to use legacy_plot module + - CREATE NEW test_plot.py for the new MMMPlotSuite + - Write ~9 methods × 3 backends = ~27 parametrized tests + - Remove experimental test_plot_backends.py file + - Add backend override and invalid backend tests + - Verify all new tests pass + +3. **Create compatibility test suite** ✅ 3 hours + - Create `test_plot_compatibility.py` + - Test version switching (5 tests) + - Test deprecation warnings (4 tests) + - Test return types (3 tests) + - Test missing methods (4 tests) + - Test parameter compatibility (2 tests) + +### Priority 2: Important (Before Merge) + +4. **Update documentation** ⏱️ 4 hours + - Update method docstrings with PlotCollection info + - Add version directives (.. versionadded::) + - Document backend parameter + - Add usage examples + +5. **Write migration guide** ⏱️ 6 hours + - Create `docs/source/guides/mmm_plotting_migration.rst` + - Document all breaking changes (including parameter type changes) + - Provide side-by-side examples + - List missing features and workarounds + - Explain that parameter changes require code adaptation when switching to v2 + +### Priority 3: Nice to Have (Can Defer) + +8. **Add usage examples to docstrings** ⏱️ 2 hours + - Add Examples section to all methods + - Show basic usage, saving, backend switching + +9. **Create visual test notebook** ⏱️ 3 hours + - Notebook comparing old vs new outputs + - Demonstrates all backends + - Helps verify visual equivalence + +10. **Performance testing** ⏱️ 2 hours + - Compare old vs new rendering times + - Test with large datasets + - Document any performance changes + +## Open Questions + +### Q1: When should default switch from old to new? + +**Options**: +- A. v0.18.0 - Aggressive, breaks existing code immediately +- B. v0.19.0 - Conservative, gives 1 release for users to adapt +- C. v0.20.0 - Very conservative, 2 releases to adapt + +**Recommendation**: Option B (v0.19.0) +- v0.18.0: Introduce with legacy default + warning +- v0.19.0: Switch to new default, keep legacy available +- v0.20.0: Remove legacy completely + +### Q2: Should LegacyMMMPlotSuite be importable directly? + +**Current**: Only via `.plot` property +**Alternative**: Export in `mmm/__init__.py` + +**Recommendation**: Keep internal-only +- Encourages proper migration +- Reduces maintenance burden +- Users can still access via `use_v2=False` + +### Q3: How to handle `budget_allocation()` removal? + +**Options**: +- A. Keep in legacy suite, remove from new (current approach) +- B. Add adapter in new suite that approximates behavior +- C. Port to new suite with PlotCollection return type + +**Recommendation**: Option A with stub that raises +- Clear error message guides users +- Avoids maintaining duplicate functionality +- Allows temporary use of legacy suite + +### Q4: Should warnings be shown every time or once per session? + +**Current Pattern**: Every call +**Alternative**: Once per session using warning filters + +**Recommendation**: Every call (current) +- More visible, harder to ignore +- Consistent with other deprecation warnings +- Users can suppress globally if desired + +### Q5: What about projects pinned to specific versions? + +**Scenario**: User pins to v0.18.0, doesn't update + +**Solution**: +- `use_v2=False` default in v0.18.0 ensures no breakage +- Warning provides clear timeline +- Projects can update at their own pace +- No forced migration until they upgrade to v0.20.0+ + +## Implementation Checklist + +### Phase 1: Code Changes +- [ ] **Rename `old_plot.py` to `legacy_plot.py` and `OldMMMPlotSuite` to `LegacyMMMPlotSuite`** +- [ ] Add `"plot.use_v2": False` to config.py defaults +- [ ] Modify multidimensional.py `.plot` property with version switching +- [ ] Add FutureWarning for legacy suite usage +- [ ] Update all docstrings to document PlotCollection return type + +### Phase 2: Testing +- [ ] **Rename `tests/mmm/test_plot.py` to `test_legacy_plot.py` (tests for legacy suite)** +- [ ] **Update imports in renamed test file to use `legacy_plot.LegacyMMMPlotSuite`** +- [ ] **Create NEW `tests/mmm/test_plot.py` for new MMMPlotSuite** +- [ ] **Write ~9 methods × 3 backends = ~27 parametrized tests for new suite** +- [ ] Remove experimental `tests/mmm/test_plot_backends.py` file +- [ ] Add backend override and invalid backend tests +- [ ] Create `tests/mmm/test_plot_compatibility.py` (15+ tests) +- [ ] Add mock_mmm fixture +- [ ] Add mock_allocation_samples fixture +- [ ] Verify all ~27 new suite backend tests pass +- [ ] Verify all 15 compatibility tests pass +- [ ] Test warning suppression works +- [ ] Test both suites produce valid output + +### Phase 3: Documentation +- [ ] Create migration guide (docs/source/guides/mmm_plotting_migration.rst) +- [ ] Document breaking changes table +- [ ] Provide code examples for migration +- [ ] Update API reference +- [ ] Add versionadded directives +- [ ] Document backend selection +- [ ] List missing features and workarounds + +### Phase 4: Review +- [ ] Code review for new implementation +- [ ] Test coverage review (aim for >95%) +- [ ] Documentation review +- [ ] Migration guide validation with sample code +- [ ] Timeline communication (v0.18.0 → v0.20.0) + +## Related Research + +- [CLAUDE.md](../../CLAUDE.md) - Project development guidelines +- [CONTRIBUTING.md](../../CONTRIBUTING.md) - Code style and testing requirements +- [pyproject.toml](../../pyproject.toml) - Test configuration and linting rules + +## Appendix: Complete Implementation Template + +### A. Config File Modification + +**File**: `pymc_marketing/mmm/config.py` + +```python +_defaults = { + "plot.backend": "matplotlib", + "plot.show_warnings": True, + "plot.use_v2": False, # ← ADD THIS LINE +} +``` + +### B. Property Modification + +**File**: `pymc_marketing/mmm/multidimensional.py` + +```python +@property +def plot(self) -> MMMPlotSuite | LegacyMMMPlotSuite: + """Use the MMMPlotSuite to plot the results. + + The plot suite version is controlled by mmm_config["plot.use_v2"]: + - False (default): Uses legacy matplotlib-based suite (will be deprecated) + - True: Uses new arviz_plots-based suite with multi-backend support + + .. versionchanged:: 0.18.0 + Added version control via mmm_config["plot.use_v2"]. + The legacy suite will be removed in v0.20.0. + + Examples + -------- + Use new plot suite: + + >>> from pymc_marketing.mmm import mmm_config + >>> mmm_config["plot.use_v2"] = True + >>> pc = mmm.plot.posterior_predictive() + >>> pc.show() + + Returns + ------- + MMMPlotSuite or LegacyMMMPlotSuite + Plot suite instance for creating MMM visualizations. + """ + from pymc_marketing.mmm.config import mmm_config + from pymc_marketing.mmm.plot import MMMPlotSuite + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + import warnings + + self._validate_model_was_built() + self._validate_idata_exists() + + # Check version flag + if mmm_config.get("plot.use_v2", False): + return MMMPlotSuite(idata=self.idata) + else: + # Show deprecation warning for legacy suite + if mmm_config.get("plot.show_warnings", True): + warnings.warn( + "The current MMMPlotSuite will be deprecated in v0.20.0. " + "The new version uses arviz_plots and supports multiple backends (matplotlib, plotly, bokeh). " + "To use the new version: mmm_config['plot.use_v2'] = True\n" + "To suppress this warning: mmm_config['plot.show_warnings'] = False\n" + "See migration guide: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html", + FutureWarning, + stacklevel=2, + ) + return LegacyMMMPlotSuite(idata=self.idata) +``` + +### C. Missing Method Stub + +**File**: `pymc_marketing/mmm/plot.py` (add to MMMPlotSuite class) + +```python +def budget_allocation(self, *args, **kwargs): + """ + Create bar chart comparing allocated spend and channel contributions. + + .. deprecated:: 0.18.0 + This method was removed in MMMPlotSuite v2. The arviz_plots library + used in v2 doesn't support this specific chart type. See alternatives below. + + Raises + ------ + NotImplementedError + This method is not available in MMMPlotSuite v2. + + Notes + ----- + Alternatives: + + 1. **For ROI distributions**: Use :meth:`budget_allocation_roas` + (different purpose but related to budget allocation) + + 2. **To use the old method**: Switch to legacy suite: + + >>> from pymc_marketing.mmm import mmm_config + >>> mmm_config["plot.use_v2"] = False + >>> mmm.plot.budget_allocation(samples) + + 3. **Custom implementation**: Create bar chart using samples data: + + >>> import matplotlib.pyplot as plt + >>> channel_contrib = samples["channel_contribution"].mean(...) + >>> allocated_spend = samples["allocation"] + >>> # Create custom bar chart with matplotlib + + See Also + -------- + budget_allocation_roas : Plot ROI distributions by channel + + Examples + -------- + Use legacy suite temporarily: + + >>> from pymc_marketing.mmm import mmm_config + >>> original = mmm_config.get("plot.use_v2") + >>> try: + ... mmm_config["plot.use_v2"] = False + ... fig, ax = mmm.plot.budget_allocation(samples) + ... fig.savefig("budget.png") + ... finally: + ... mmm_config["plot.use_v2"] = original + """ + raise NotImplementedError( + "budget_allocation() was removed in MMMPlotSuite v2.\n\n" + "The new arviz_plots-based implementation doesn't support this chart type.\n\n" + "Alternatives:\n" + " 1. For ROI distributions: use budget_allocation_roas()\n" + " 2. To use old method: set mmm_config['plot.use_v2'] = False\n" + " 3. Implement custom bar chart using the samples data\n\n" + "See documentation: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html#budget-allocation" + ) +``` + +--- + +**End of Research Document** From e1403627c3563e86d29eb8bc425256a6b346a5de Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Thu, 20 Nov 2025 11:40:19 -0500 Subject: [PATCH 08/29] change to research --- ...mmplotsuite-migration-complete-analysis.md | 717 ++++++++++++++++-- 1 file changed, 639 insertions(+), 78 deletions(-) diff --git a/thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md b/thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md index 3451ac6cd..d631d7038 100644 --- a/thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md +++ b/thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md @@ -42,6 +42,8 @@ Based on comprehensive codebase analysis, the migration is **75% complete** with **❌ Missing Critical Components:** - Rename `old_plot.py` to `legacy_plot.py` and `OldMMMPlotSuite` to `LegacyMMMPlotSuite` +- **Data Parameter Standardization**: All plotting methods should accept data as input parameters (some with fallback to `self.idata`, some without). Currently inconsistent across methods. +- **`_sensitivity_analysis_plot()` refactoring**: Must accept `data` as REQUIRED parameter (no fallback), and all callers (`sensitivity_analysis()`, `uplift_curve()`, `marginal_curve()`) must be updated to pass data explicitly. - Backward compatibility flag (`use_v2`) to toggle between legacy/new suite - Deprecation warning system for users - Comprehensive backend testing for the new suite @@ -52,6 +54,7 @@ Based on comprehensive codebase analysis, the migration is **75% complete** with - Return type documentation incomplete - Breaking parameter type changes across all methods (intentional, no backward compatibility needed) - Lost customization parameters (colors, subplot_kwargs, rc_params) - handled by arviz_plots +- **Deprecated method carried forward**: `saturation_curves_scatter()` is implemented in v2 but should be removed (already deprecated in v0.1.0) ## Detailed Findings @@ -87,7 +90,7 @@ Based on comprehensive codebase analysis, the migration is **75% complete** with | `contributions_over_time()` | ✅ | ✅ | ⚠️ | Return type only | | `saturation_scatterplot()` | ✅ | ✅ | ⚠️ | Lost `**kwargs`, return type | | `saturation_curves()` | ✅ | ✅ | ❌ | Lost colors, subplot_kwargs, rc_params | -| `saturation_curves_scatter()` | ✅ | ✅ | ✅ | DEPRECATED in both (delegates) | +| `saturation_curves_scatter()` | ⚠️ | ✅ | ⚠️ | **SHOULD BE REMOVED** - Currently in v2 but deprecated, delegates to saturation_scatterplot | | `budget_allocation()` | ❌ | ✅ | ❌ | **REMOVED** - no replacement | | `budget_allocation_roas()` | ✅ | ❌ | N/A | New method, different purpose | | `allocated_contribution_by_channel_over_time()` | ✅ | ✅ | ❌ | Lost scale_factor, quantiles, figsize, ax | @@ -428,69 +431,7 @@ def budget_allocation(self, *args, **kwargs): ) ``` -#### Issue 4: Lost Customization Parameters ⚠️ **MODERATE** - -**Problem**: Many customization parameters removed in new implementation. - -**Examples**: - -`saturation_curves()`: -```python -# LEGACY - rich customization -def saturation_curves( - self, - curve, - colors: Iterable[str] | None = None, - subplot_kwargs: dict | None = None, - rc_params: dict | None = None, - **plot_kwargs -) - -# NEW - simplified -def saturation_curves( - self, - curve, - backend: str | None = None, - # Lost: colors, subplot_kwargs, rc_params, plot_kwargs -) -``` - -`sensitivity_analysis()`: -```python -# LEGACY -def sensitivity_analysis( - self, - ax: plt.Axes | None = None, - subplot_kwargs: dict[str, Any] | None = None, - plot_kwargs: dict[str, Any] | None = None, - ylabel: str = "Effect", - xlabel: str = "Sweep", - title: str | None = None, - ... -) - -# NEW -def sensitivity_analysis( - self, - backend: str | None = None, - # Lost: ax, subplot_kwargs, plot_kwargs, ylabel, xlabel, title -) -``` - -**Rationale**: arviz_plots handles layout automatically, reducing need for manual control. - -**Mitigation**: Document how to customize `PlotCollection` objects after creation: -```python -pc = mmm.plot.saturation_curves(curve) - -# For matplotlib backend -if pc.backend == "matplotlib": - for ax in pc.axes.flat: - ax.set_title("Custom Title") - ax.set_xlabel("Custom X Label") -``` - -#### Issue 5: Backend Parameter Coverage ✅ **GOOD** +#### Issue 4: Backend Parameter Coverage ✅ **GOOD** **Status**: All public methods have `backend` parameter: - `posterior_predictive()` ✅ @@ -502,9 +443,59 @@ if pc.backend == "matplotlib": - `sensitivity_analysis()` ✅ - `uplift_curve()` ✅ - `marginal_curve()` ✅ +- ~~`saturation_curves_scatter()`~~ - **TO BE REMOVED** (deprecated method, see Issue 5) **Pattern**: Consistent across all methods, properly resolves via `_resolve_backend()`. +#### Issue 5: Deprecated Method Should Be Removed ⚠️ **MINOR BUT IMPORTANT** + +**Problem**: `saturation_curves_scatter()` is currently implemented in MMMPlotSuite v2 but is deprecated and just delegates to `saturation_scatterplot()`. + +**Current implementation** (lines 737-771 in [plot.py](pymc_marketing/mmm/plot.py#L737-L771)): +```python +def saturation_curves_scatter(self, original_scale: bool = False, **kwargs) -> PlotCollection: + """ + .. deprecated:: 0.1.0 + Will be removed in version 0.20.0. Use :meth:`saturation_scatterplot` instead. + """ + import warnings + warnings.warn( + "saturation_curves_scatter is deprecated and will be removed in version 0.2.0. " + "Use saturation_scatterplot instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.saturation_scatterplot(original_scale=original_scale, **kwargs) +``` + +**Rationale for removal**: +- Since MMMPlotSuite v2 is a completely new implementation, we should NOT carry forward deprecated methods +- The legacy suite (LegacyMMMPlotSuite) already has this method for users who need it +- Users who opt into v2 (`mmm_config["plot.use_v2"] = True`) should use the new, correct method name +- Keeping deprecated methods in v2 defeats the purpose of a clean migration +- The method was deprecated in v0.1.0, giving users ample time to migrate + +**Recommendation**: **REMOVE** `saturation_curves_scatter()` from MMMPlotSuite (plot.py) entirely. + +**Implementation**: +1. Delete the method from [pymc_marketing/mmm/plot.py:737-771](pymc_marketing/mmm/plot.py#L737-L771) +2. Keep it in LegacyMMMPlotSuite (legacy_plot.py) for backward compatibility +3. Document the removal in migration guide + +**Alternative** (if keeping for one more release): +Add a note in the deprecation warning that it won't be available in v2 by default: +```python +warnings.warn( + "saturation_curves_scatter is deprecated and will be removed in version 0.20.0. " + "Use saturation_scatterplot instead. " + "Note: This method is not available when using mmm_config['plot.use_v2'] = True.", + DeprecationWarning, + stacklevel=2, +) +``` + +**Preferred approach**: Clean removal from v2, keep only in legacy suite. + ### 6. Testing Infrastructure ⚠️ **MAJOR GAPS** #### 6.1 Current Test Coverage @@ -667,7 +658,7 @@ def test_budget_allocation_roas(new_mock_suite, backend): assert isinstance(pc, PlotCollection) assert pc.backend == backend -# ... Create tests for all 9 methods with 3 backends = 27 core tests ... +# ... Create tests for all 8 methods with 3 backends = 24 core tests ... ``` **Step 3: Remove experimental test_plot_backends.py** @@ -710,7 +701,7 @@ def test_invalid_backend_warning(mock_suite): ``` **Result**: -- New suite: ~9 methods × 3 backends = ~27 core test cases (plus backend-specific tests) +- New suite: ~8 methods × 3 backends = ~24 core test cases (plus backend-specific tests) - note: saturation_curves_scatter removed - Legacy suite: ~28 existing test functions (matplotlib only, will be removed in v0.20.0) **File 2: Create tests/mmm/test_plot_compatibility.py** @@ -997,8 +988,9 @@ def mock_allocation_samples(): **Backend Testing:** - [ ] Remove experimental test_plot_backends.py file -- [ ] Parametrize all ~28 tests in test_plot.py with backend parameter -- [ ] All ~84 parametrized tests pass (28 tests × 3 backends) +- [ ] Remove deprecated `saturation_curves_scatter()` from MMMPlotSuite +- [ ] Parametrize all tests in new test_plot.py with backend parameter (~8 methods) +- [ ] All ~24 parametrized tests pass (8 methods × 3 backends) - [ ] Backend override test works correctly - [ ] Invalid backend warning test passes @@ -1136,6 +1128,7 @@ pc.save("plot.png") # Save | Method | Status | Replacement | Notes | |--------|--------|-------------|-------| +| `saturation_curves_scatter()` | **REMOVED in v2** | `saturation_scatterplot()` | Deprecated in v0.1.0, not carried forward to v2 | | `budget_allocation()` | **REMOVED** | None exact | Use legacy suite or custom plot | | `budget_allocation_roas()` | **NEW** | N/A | Different purpose (ROI dist) | @@ -1258,9 +1251,549 @@ def method_name(...) -> PlotCollection: 6. **Incremental Migration**: Multiple patterns show support for gradual API transitions over several releases before removing old code. +## Data Parameter Standardization ⚠️ **CRITICAL - MUST IMPLEMENT** + +### Summary + +**Goal**: All plotting methods should accept data as input parameters for consistency, testability, and flexibility. + +**Status**: Currently **inconsistent** - some methods accept data, others hard-code `self.idata` access. + +**Impact**: Must be fixed BEFORE writing tests, as tests need to be written against the correct API. + +**Time Estimate**: 4 hours + +**Key Changes**: +- 7 methods need updates +- `_sensitivity_analysis_plot()` must accept `data` as REQUIRED parameter (no fallback) +- All other methods can have fallback to `self.idata` +- Removes monkey-patching in `uplift_curve()` and `marginal_curve()` + +### Current State Analysis + +The new MMMPlotSuite methods currently have **inconsistent data parameter patterns**: + +**✅ Methods that already accept data as input:** +- `posterior_predictive(idata: xr.Dataset | None)` - With fallback to `self.idata.posterior_predictive` +- `budget_allocation_roas(samples: xr.Dataset)` - No fallback +- `allocated_contribution_by_channel_over_time(samples: xr.Dataset)` - No fallback + +**❌ Methods that need data parameters added:** +- `contributions_over_time()` - Currently uses `self.idata.posterior` directly +- `saturation_scatterplot()` - Currently uses `self.idata.constant_data` and `self.idata.posterior` +- `saturation_curves()` - Accepts `curve` but still uses `self.idata.constant_data` and `self.idata.posterior` for scatter +- `_sensitivity_analysis_plot()` - Currently uses `self.idata.sensitivity_analysis` (**must accept data, NO fallback**) +- `sensitivity_analysis()` - Needs to accept and pass data to `_sensitivity_analysis_plot()` +- `uplift_curve()` - Needs to accept and pass data to `_sensitivity_analysis_plot()` +- `marginal_curve()` - Needs to accept and pass data to `_sensitivity_analysis_plot()` + +### Required API Changes + +#### 1. contributions_over_time() - Add data parameter with fallback + +**Current signature** (line 387): +```python +def contributions_over_time( + self, + var: list[str], + hdi_prob: float = 0.85, + dims: dict[str, str | int | list] | None = None, + backend: str | None = None, +) -> PlotCollection: +``` + +**New signature**: +```python +def contributions_over_time( + self, + var: list[str], + data: xr.Dataset | None = None, # ← ADD THIS + hdi_prob: float = 0.85, + dims: dict[str, str | int | list] | None = None, + backend: str | None = None, +) -> PlotCollection: + """Plot the time-series contributions for each variable in `var`. + + Parameters + ---------- + var : list of str + A list of variable names to plot from the posterior. + data : xr.Dataset, optional + Dataset containing posterior data. If None, uses self.idata.posterior. + ... + """ +``` + +**Implementation changes** (lines 426-437): +```python +# OLD: +if not hasattr(self.idata, "posterior"): + raise ValueError(...) +da = self.idata.posterior[var] + +# NEW: +if data is None: + if not hasattr(self.idata, "posterior"): + raise ValueError( + "No posterior data found in 'self.idata' and no 'data' argument provided. " + "Please ensure 'self.idata' contains a 'posterior' group or provide 'data'." + ) + data = self.idata.posterior +da = data[var] +``` + +#### 2. saturation_scatterplot() - Add data parameters with fallback + +**Current signature** (line 493): +```python +def saturation_scatterplot( + self, + original_scale: bool = False, + dims: dict[str, str | int | list] | None = None, + backend: str | None = None, +) -> PlotCollection: +``` + +**New signature**: +```python +def saturation_scatterplot( + self, + original_scale: bool = False, + constant_data: xr.Dataset | None = None, # ← ADD THIS + posterior_data: xr.Dataset | None = None, # ← ADD THIS + dims: dict[str, str | int | list] | None = None, + backend: str | None = None, +) -> PlotCollection: + """Plot the saturation curves for each channel. + + Parameters + ---------- + original_scale : bool, optional + Whether to plot the original scale contributions. Default is False. + constant_data : xr.Dataset, optional + Dataset containing constant_data group with 'channel_data' variable. + If None, uses self.idata.constant_data. + posterior_data : xr.Dataset, optional + Dataset containing posterior group with channel contribution variables. + If None, uses self.idata.posterior. + ... + """ +``` + +**Implementation changes** (lines 524-562): +```python +# OLD: +if not hasattr(self.idata, "constant_data"): + raise ValueError(...) +cdims = self.idata.constant_data.channel_data.dims +channel_data = self.idata.constant_data.channel_data +channel_contrib = self.idata.posterior[channel_contribution] + +# NEW: +if constant_data is None: + if not hasattr(self.idata, "constant_data"): + raise ValueError( + "No 'constant_data' found in 'self.idata' and no 'constant_data' argument provided. " + "Please ensure 'self.idata' contains the constant_data group or provide 'constant_data'." + ) + constant_data = self.idata.constant_data + +if posterior_data is None: + if not hasattr(self.idata, "posterior"): + raise ValueError( + "No 'posterior' found in 'self.idata' and no 'posterior_data' argument provided. " + "Please ensure 'self.idata' contains the posterior group or provide 'posterior_data'." + ) + posterior_data = self.idata.posterior + +cdims = constant_data.channel_data.dims +channel_data = constant_data.channel_data +channel_contrib = posterior_data[channel_contribution] +``` + +#### 3. saturation_curves() - Update to use data parameters from saturation_scatterplot + +**Current signature** (line 597): +```python +def saturation_curves( + self, + curve: xr.DataArray, + original_scale: bool = False, + n_samples: int = 10, + hdi_probs: float | list[float] | None = None, + random_seed: np.random.Generator | None = None, + dims: dict[str, str | int | list] | None = None, + backend: str | None = None, +) -> PlotCollection: +``` + +**New signature**: +```python +def saturation_curves( + self, + curve: xr.DataArray, + original_scale: bool = False, + constant_data: xr.Dataset | None = None, # ← ADD THIS + posterior_data: xr.Dataset | None = None, # ← ADD THIS + n_samples: int = 10, + hdi_probs: float | list[float] | None = None, + random_seed: np.random.Generator | None = None, + dims: dict[str, str | int | list] | None = None, + backend: str | None = None, +) -> PlotCollection: + """Overlay saturation‑curve scatter‑plots with posterior‑predictive sample curves. + + Parameters + ---------- + curve : xr.DataArray + Posterior‑predictive curves (e.g. dims `("chain","draw","x","channel","geo")`). + original_scale : bool, default=False + Plot `channel_contribution_original_scale` if True, else `channel_contribution`. + constant_data : xr.Dataset, optional + Dataset containing constant_data group. If None, uses self.idata.constant_data. + posterior_data : xr.Dataset, optional + Dataset containing posterior group. If None, uses self.idata.posterior. + ... + """ +``` + +**Implementation changes** (lines 645-696): +```python +# OLD: +if not hasattr(self.idata, "constant_data"): + raise ValueError(...) +if original_scale: + curve_data = curve * self.idata.constant_data.target_scale + curve_data["x"] = curve_data["x"] * self.idata.constant_data.channel_scale +cdims = self.idata.constant_data.channel_data.dims +pc = self.saturation_scatterplot(original_scale=original_scale, dims=dims, backend=backend) + +# NEW: +if constant_data is None: + if not hasattr(self.idata, "constant_data"): + raise ValueError( + "No 'constant_data' found in 'self.idata' and no 'constant_data' argument provided." + ) + constant_data = self.idata.constant_data + +if posterior_data is None: + if not hasattr(self.idata, "posterior"): + raise ValueError( + "No 'posterior' found in 'self.idata' and no 'posterior_data' argument provided." + ) + posterior_data = self.idata.posterior + +if original_scale: + curve_data = curve * constant_data.target_scale + curve_data["x"] = curve_data["x"] * constant_data.channel_scale +cdims = constant_data.channel_data.dims +pc = self.saturation_scatterplot( + original_scale=original_scale, + constant_data=constant_data, + posterior_data=posterior_data, + dims=dims, + backend=backend +) +``` + +#### 4. _sensitivity_analysis_plot() - Accept data parameter WITHOUT fallback ⚠️ **CRITICAL** + +**Current signature** (line 979): +```python +def _sensitivity_analysis_plot( + self, + hdi_prob: float = 0.94, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + backend: str | None = None, +) -> PlotCollection: +``` + +**New signature**: +```python +def _sensitivity_analysis_plot( + self, + data: xr.DataArray | xr.Dataset, # ← ADD THIS (REQUIRED, NO DEFAULT) + hdi_prob: float = 0.94, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + backend: str | None = None, +) -> PlotCollection: + """Plot helper for sensitivity analysis results. + + Parameters + ---------- + data : xr.DataArray or xr.Dataset + Sensitivity analysis data to plot. Must have 'sample' and 'sweep' dimensions. + If Dataset, should contain 'x' variable. NO fallback to self.idata. + ... + """ +``` + +**Implementation changes** (lines 1002-1007): +```python +# OLD: +if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError("No sensitivity analysis results found. Run run_sweep() first.") +sa = self.idata.sensitivity_analysis +x = sa["x"] if isinstance(sa, xr.Dataset) else sa + +# NEW: +# Validate input data +if data is None: + raise ValueError( + "data parameter is required for _sensitivity_analysis_plot. " + "This is a helper method that should receive data explicitly." + ) + +# Handle Dataset or DataArray +x = data["x"] if isinstance(data, xr.Dataset) else data +``` + +**Rationale for NO fallback:** +- This is a **private helper method** (prefixed with `_`) +- It should be a pure plotting function that operates on provided data +- The public methods (`sensitivity_analysis()`, `uplift_curve()`, `marginal_curve()`) handle data retrieval from `self.idata` +- This separation of concerns makes the code more testable and maintainable + +#### 5. sensitivity_analysis() - Update to pass data + +**Current implementation** (lines 1071-1116): +```python +def sensitivity_analysis( + self, + hdi_prob: float = 0.94, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + backend: str | None = None, +) -> PlotCollection: + pc = self._sensitivity_analysis_plot( + hdi_prob=hdi_prob, aggregation=aggregation, backend=backend + ) + pc.map(azp.visuals.labelled_y, text="Contribution") + return pc +``` + +**New implementation**: +```python +def sensitivity_analysis( + self, + data: xr.DataArray | xr.Dataset | None = None, # ← ADD THIS + hdi_prob: float = 0.94, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + backend: str | None = None, +) -> PlotCollection: + """Plot sensitivity analysis results. + + Parameters + ---------- + data : xr.DataArray or xr.Dataset, optional + Sensitivity analysis data to plot. If None, uses self.idata.sensitivity_analysis. + ... + """ + # Retrieve data if not provided + if data is None: + if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError( + "No sensitivity analysis results found in 'self.idata' and no 'data' argument provided. " + "Run 'mmm.sensitivity.run_sweep()' first or provide 'data'." + ) + data = self.idata.sensitivity_analysis # type: ignore + + pc = self._sensitivity_analysis_plot( + data=data, # ← PASS DATA + hdi_prob=hdi_prob, + aggregation=aggregation, + backend=backend, + ) + pc.map(azp.visuals.labelled_y, text="Contribution") + return pc +``` + +#### 6. uplift_curve() - Update to pass data + +**Current implementation** (lines 1158-1193): +```python +def uplift_curve( + self, + hdi_prob: float = 0.94, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + backend: str | None = None, +) -> PlotCollection: + if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError(...) + + sa_group = self.idata.sensitivity_analysis + if isinstance(sa_group, xr.Dataset): + if "uplift_curve" not in sa_group: + raise ValueError(...) + data_var = sa_group["uplift_curve"] + else: + raise ValueError(...) + + # Monkey-patch approach with temporary swap + tmp_idata = xr.Dataset({"x": data_var}) + original_group = self.idata.sensitivity_analysis + try: + self.idata.sensitivity_analysis = tmp_idata + pc = self._sensitivity_analysis_plot(...) + ... + finally: + self.idata.sensitivity_analysis = original_group +``` + +**New implementation**: +```python +def uplift_curve( + self, + data: xr.DataArray | xr.Dataset | None = None, # ← ADD THIS + hdi_prob: float = 0.94, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + backend: str | None = None, +) -> PlotCollection: + """Plot precomputed uplift curves. + + Parameters + ---------- + data : xr.DataArray or xr.Dataset, optional + Uplift curve data to plot. If Dataset, should contain 'uplift_curve' variable. + If None, uses self.idata.sensitivity_analysis['uplift_curve']. + ... + """ + # Retrieve data if not provided + if data is None: + if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError( + "No sensitivity analysis results found in 'self.idata' and no 'data' argument provided. " + "Run 'mmm.sensitivity.run_sweep()' first or provide 'data'." + ) + + sa_group = self.idata.sensitivity_analysis # type: ignore + if isinstance(sa_group, xr.Dataset): + if "uplift_curve" not in sa_group: + raise ValueError( + "Expected 'uplift_curve' in idata.sensitivity_analysis. " + "Use SensitivityAnalysis.compute_uplift_curve_respect_to_base(..., extend_idata=True)." + ) + data = sa_group["uplift_curve"] + else: + raise ValueError( + "sensitivity_analysis does not contain 'uplift_curve'. Did you persist it to idata?" + ) + + # Handle Dataset input + if isinstance(data, xr.Dataset): + if "uplift_curve" in data: + data = data["uplift_curve"] + elif "x" in data: + data = data["x"] + else: + raise ValueError("Dataset must contain 'uplift_curve' or 'x' variable.") + + # Call helper with data (no more monkey-patching!) + pc = self._sensitivity_analysis_plot( + data=data, # ← PASS DATA DIRECTLY + hdi_prob=hdi_prob, + aggregation=aggregation, + backend=backend, + ) + pc.map(azp.visuals.labelled_y, text="Uplift (%)") + return pc +``` + +#### 7. marginal_curve() - Update to pass data + +**Current implementation** (lines 1237-1271): +```python +def marginal_curve( + self, + hdi_prob: float = 0.94, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + backend: str | None = None, +) -> PlotCollection: + if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError(...) + + sa_group = self.idata.sensitivity_analysis + # Similar monkey-patching as uplift_curve +``` + +**New implementation**: +```python +def marginal_curve( + self, + data: xr.DataArray | xr.Dataset | None = None, # ← ADD THIS + hdi_prob: float = 0.94, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + backend: str | None = None, +) -> PlotCollection: + """Plot precomputed marginal effects. + + Parameters + ---------- + data : xr.DataArray or xr.Dataset, optional + Marginal effects data to plot. If Dataset, should contain 'marginal_effects' variable. + If None, uses self.idata.sensitivity_analysis['marginal_effects']. + ... + """ + # Retrieve data if not provided + if data is None: + if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError( + "No sensitivity analysis results found in 'self.idata' and no 'data' argument provided. " + "Run 'mmm.sensitivity.run_sweep()' first or provide 'data'." + ) + + sa_group = self.idata.sensitivity_analysis # type: ignore + if isinstance(sa_group, xr.Dataset): + if "marginal_effects" not in sa_group: + raise ValueError( + "Expected 'marginal_effects' in idata.sensitivity_analysis. " + "Use SensitivityAnalysis.compute_marginal_effects(..., extend_idata=True)." + ) + data = sa_group["marginal_effects"] + else: + raise ValueError( + "sensitivity_analysis does not contain 'marginal_effects'. Did you persist it to idata?" + ) + + # Handle Dataset input + if isinstance(data, xr.Dataset): + if "marginal_effects" in data: + data = data["marginal_effects"] + elif "x" in data: + data = data["x"] + else: + raise ValueError("Dataset must contain 'marginal_effects' or 'x' variable.") + + # Call helper with data (no more monkey-patching!) + pc = self._sensitivity_analysis_plot( + data=data, # ← PASS DATA DIRECTLY + hdi_prob=hdi_prob, + aggregation=aggregation, + backend=backend, + ) + pc.map(azp.visuals.labelled_y, text="Marginal Effect") + return pc +``` + +### Benefits of This Standardization + +1. **Consistency**: All methods follow the same pattern for data handling +2. **Flexibility**: Users can pass external data or use data from self.idata +3. **Testability**: Methods can be tested with mock data without needing full MMM setup +4. **Separation of Concerns**: `_sensitivity_analysis_plot()` is a pure plotting function +5. **No More Monkey-Patching**: The uplift_curve() and marginal_curve() methods no longer need to temporarily swap self.idata +6. **Better Error Messages**: Clear messages when data is missing + +### Implementation Priority + +This is **CRITICAL** and should be completed as **Priority 0** (along with file renaming) because: +- It's a fundamental API design issue +- It affects multiple methods +- It's easier to fix before the migration is complete +- Tests need to be written against the correct API + ## Recommendations -### Priority 0: File Renaming (Must Complete First) +### Priority 0: File Renaming and Data Parameter Standardization (Must Complete First) 0. **Rename files and classes** ✅ 30 minutes - Rename `pymc_marketing/mmm/old_plot.py` to `legacy_plot.py` @@ -1268,25 +1801,40 @@ def method_name(...) -> PlotCollection: - Update any imports in existing code/tests - **This must be done BEFORE implementing other changes** +0b. **Data Parameter Standardization** ✅ 4 hours + - Update `contributions_over_time()` to accept `data` parameter with fallback + - Update `saturation_scatterplot()` to accept `constant_data` and `posterior_data` parameters with fallback + - Update `saturation_curves()` to accept and pass `constant_data` and `posterior_data` parameters + - Update `_sensitivity_analysis_plot()` to accept `data` parameter WITHOUT fallback (REQUIRED parameter) + - Update `sensitivity_analysis()` to accept and pass `data` parameter with fallback + - Update `uplift_curve()` to accept and pass `data` parameter with fallback (removes monkey-patching) + - Update `marginal_curve()` to accept and pass `data` parameter with fallback (removes monkey-patching) + - **This is critical for API consistency and must be done BEFORE writing tests** + ### Priority 1: Critical (Must Complete for PR) -1. **Add backward compatibility flag** ✅ 2 hours +1. **Remove deprecated method from new suite** ✅ 15 minutes + - Delete `saturation_curves_scatter()` from [pymc_marketing/mmm/plot.py:737-771](pymc_marketing/mmm/plot.py#L737-L771) + - Keep it in LegacyMMMPlotSuite (will be in legacy_plot.py after renaming) + - Document in migration guide that deprecated methods are not carried forward to v2 + +2. **Add backward compatibility flag** ✅ 2 hours - Modify `config.py` to add `"plot.use_v2": False` - Implement version switching in `multidimensional.py:602-607` - Import from `legacy_plot` module - Add deprecation warning with migration guide link - Test manual switching works -2. **Create comprehensive backend testing for new suite** ✅ 6 hours +3. **Create comprehensive backend testing for new suite** ✅ 6 hours - Rename existing test_plot.py to test_legacy_plot.py - Update imports in legacy test file to use legacy_plot module - CREATE NEW test_plot.py for the new MMMPlotSuite - - Write ~9 methods × 3 backends = ~27 parametrized tests + - Write ~8 methods × 3 backends = ~24 parametrized tests (note: saturation_curves_scatter removed) - Remove experimental test_plot_backends.py file - Add backend override and invalid backend tests - Verify all new tests pass -3. **Create compatibility test suite** ✅ 3 hours +4. **Create compatibility test suite** ✅ 3 hours - Create `test_plot_compatibility.py` - Test version switching (5 tests) - Test deprecation warnings (4 tests) @@ -1296,13 +1844,13 @@ def method_name(...) -> PlotCollection: ### Priority 2: Important (Before Merge) -4. **Update documentation** ⏱️ 4 hours +5. **Update documentation** ⏱️ 4 hours - Update method docstrings with PlotCollection info - Add version directives (.. versionadded::) - Document backend parameter - Add usage examples -5. **Write migration guide** ⏱️ 6 hours +6. **Write migration guide** ⏱️ 6 hours - Create `docs/source/guides/mmm_plotting_migration.rst` - Document all breaking changes (including parameter type changes) - Provide side-by-side examples @@ -1385,22 +1933,35 @@ def method_name(...) -> PlotCollection: ### Phase 1: Code Changes - [ ] **Rename `old_plot.py` to `legacy_plot.py` and `OldMMMPlotSuite` to `LegacyMMMPlotSuite`** +- [ ] **Remove deprecated method from new suite:** + - [ ] Delete `saturation_curves_scatter()` from pymc_marketing/mmm/plot.py (lines 737-771) + - [ ] Keep it in LegacyMMMPlotSuite (legacy_plot.py) for backward compatibility + - [ ] Add note in migration guide about deprecated methods not carried forward to v2 +- [ ] **Data Parameter Standardization (CRITICAL - do before tests):** + - [ ] Update `contributions_over_time()` - add `data` parameter with fallback + - [ ] Update `saturation_scatterplot()` - add `constant_data` and `posterior_data` parameters with fallback + - [ ] Update `saturation_curves()` - add and pass `constant_data` and `posterior_data` parameters + - [ ] Update `_sensitivity_analysis_plot()` - add `data` parameter WITHOUT fallback (REQUIRED) + - [ ] Update `sensitivity_analysis()` - add and pass `data` parameter with fallback + - [ ] Update `uplift_curve()` - add and pass `data` parameter with fallback + - [ ] Update `marginal_curve()` - add and pass `data` parameter with fallback - [ ] Add `"plot.use_v2": False` to config.py defaults - [ ] Modify multidimensional.py `.plot` property with version switching - [ ] Add FutureWarning for legacy suite usage -- [ ] Update all docstrings to document PlotCollection return type +- [ ] Update all docstrings to document PlotCollection return type and new data parameters ### Phase 2: Testing - [ ] **Rename `tests/mmm/test_plot.py` to `test_legacy_plot.py` (tests for legacy suite)** - [ ] **Update imports in renamed test file to use `legacy_plot.LegacyMMMPlotSuite`** - [ ] **Create NEW `tests/mmm/test_plot.py` for new MMMPlotSuite** -- [ ] **Write ~9 methods × 3 backends = ~27 parametrized tests for new suite** +- [ ] **Write ~8 methods × 3 backends = ~24 parametrized tests for new suite** (note: saturation_curves_scatter removed) - [ ] Remove experimental `tests/mmm/test_plot_backends.py` file +- [ ] Remove deprecated `saturation_curves_scatter()` from pymc_marketing/mmm/plot.py - [ ] Add backend override and invalid backend tests - [ ] Create `tests/mmm/test_plot_compatibility.py` (15+ tests) - [ ] Add mock_mmm fixture - [ ] Add mock_allocation_samples fixture -- [ ] Verify all ~27 new suite backend tests pass +- [ ] Verify all ~24 new suite backend tests pass - [ ] Verify all 15 compatibility tests pass - [ ] Test warning suppression works - [ ] Test both suites produce valid output From ef0227621278ccace12eecc98bc3d7d29cac3db9 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Thu, 20 Nov 2025 14:00:25 -0500 Subject: [PATCH 09/29] finished milestone 1.4 --- pymc_marketing/mmm/legacy_plot.py | 1937 ++++++++++++++++++++++ pymc_marketing/mmm/plot.py | 349 +++- tests/mmm/conftest.py | 187 +++ tests/mmm/test_legacy_plot_imports.py | 41 + tests/mmm/test_legacy_plot_regression.py | 56 + tests/mmm/test_plot_data_parameters.py | 120 ++ 6 files changed, 2598 insertions(+), 92 deletions(-) create mode 100644 pymc_marketing/mmm/legacy_plot.py create mode 100644 tests/mmm/conftest.py create mode 100644 tests/mmm/test_legacy_plot_imports.py create mode 100644 tests/mmm/test_legacy_plot_regression.py create mode 100644 tests/mmm/test_plot_data_parameters.py diff --git a/pymc_marketing/mmm/legacy_plot.py b/pymc_marketing/mmm/legacy_plot.py new file mode 100644 index 000000000..590d71a03 --- /dev/null +++ b/pymc_marketing/mmm/legacy_plot.py @@ -0,0 +1,1937 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MMM related plotting class. + +Examples +-------- +Quickstart with MMM: + +.. code-block:: python + + from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation + from pymc_marketing.mmm.multidimensional import MMM + import pandas as pd + + # Minimal dataset + X = pd.DataFrame( + { + "date": pd.date_range("2025-01-01", periods=12, freq="W-MON"), + "C1": [100, 120, 90, 110, 105, 115, 98, 102, 108, 111, 97, 109], + "C2": [80, 70, 95, 85, 90, 88, 92, 94, 91, 89, 93, 87], + } + ) + y = pd.Series( + [230, 260, 220, 240, 245, 255, 235, 238, 242, 246, 233, 249], name="y" + ) + + mmm = MMM( + date_column="date", + channel_columns=["C1", "C2"], + target_column="y", + adstock=GeometricAdstock(l_max=10), + saturation=LogisticSaturation(), + ) + mmm.fit(X, y) + mmm.sample_posterior_predictive(X) + + # Posterior predictive time series + _ = mmm.plot.posterior_predictive(var=["y"], hdi_prob=0.9) + + # Posterior contributions over time (e.g., channel_contribution) + _ = mmm.plot.contributions_over_time(var=["channel_contribution"], hdi_prob=0.9) + + # Channel saturation scatter plot (scaled space by default) + _ = mmm.plot.saturation_scatterplot(original_scale=False) + +Wrap a custom PyMC model +-------- + +Requirements + +- posterior_predictive plots: an `az.InferenceData` with a `posterior_predictive` group + containing the variable(s) you want to plot with a `date` coordinate. +- contributions_over_time plots: a `posterior` group with time‑series variables (with `date`). +- saturation plots: a `constant_data` dataset with variables: + - `channel_data`: dims include `("date", "channel", ...)` + - `channel_scale`: dims include `("channel", ...)` + - `target_scale`: scalar or broadcastable to the curve dims + and a `posterior` variable named `channel_contribution` (or + `channel_contribution_original_scale` if plotting `original_scale=True`). + +.. code-block:: python + + import numpy as np + import pandas as pd + import pymc as pm + from pymc_marketing.mmm.plot import MMMPlotSuite + + dates = pd.date_range("2025-01-01", periods=30, freq="D") + y_obs = np.random.normal(size=len(dates)) + + with pm.Model(coords={"date": dates}): + sigma = pm.HalfNormal("sigma", 1.0) + pm.Normal("y", 0.0, sigma, observed=y_obs, dims="date") + + idata = pm.sample_prior_predictive(random_seed=1) + idata.extend(pm.sample(draws=200, chains=2, tune=200, random_seed=1)) + idata.extend(pm.sample_posterior_predictive(idata, random_seed=1)) + + plot = MMMPlotSuite(idata) + _ = plot.posterior_predictive(var=["y"], hdi_prob=0.9) + +Custom contributions_over_time +-------- + +.. code-block:: python + + import numpy as np + import pandas as pd + import pymc as pm + from pymc_marketing.mmm.plot import MMMPlotSuite + + dates = pd.date_range("2025-01-01", periods=30, freq="D") + x = np.linspace(0, 2 * np.pi, len(dates)) + series = np.sin(x) + + with pm.Model(coords={"date": dates}): + pm.Deterministic("component", series, dims="date") + idata = pm.sample_prior_predictive(random_seed=2) + idata.extend(pm.sample(draws=50, chains=1, tune=0, random_seed=2)) + + plot = MMMPlotSuite(idata) + _ = plot.contributions_over_time(var=["component"], hdi_prob=0.9) + +Saturation plots with a custom model +-------- + +.. code-block:: python + + import numpy as np + import pandas as pd + import xarray as xr + import pymc as pm + from pymc_marketing.mmm.plot import MMMPlotSuite + + dates = pd.date_range("2025-01-01", periods=20, freq="W-MON") + channels = ["C1", "C2"] + + # Create constant_data required for saturation plots + channel_data = xr.DataArray( + np.random.rand(len(dates), len(channels)), + dims=("date", "channel"), + coords={"date": dates, "channel": channels}, + name="channel_data", + ) + channel_scale = xr.DataArray( + np.ones(len(channels)), + dims=("channel",), + coords={"channel": channels}, + name="channel_scale", + ) + target_scale = xr.DataArray(1.0, name="target_scale") + + # Build a toy model that yields a matching posterior var + with pm.Model(coords={"date": dates, "channel": channels}): + # A fake contribution over time per channel (dims must include date & channel) + contrib = pm.Normal("channel_contribution", 0.0, 1.0, dims=("date", "channel")) + + idata = pm.sample_prior_predictive(random_seed=3) + idata.extend(pm.sample(draws=50, chains=1, tune=0, random_seed=3)) + + # Attach constant_data to idata + idata.constant_data = xr.Dataset( + { + "channel_data": channel_data, + "channel_scale": channel_scale, + "target_scale": target_scale, + } + ) + + plot = MMMPlotSuite(idata) + _ = plot.saturation_scatterplot(original_scale=False) + +Notes +----- +- `MMM` exposes this suite via the `mmm.plot` property, which internally passes the model's + `idata` into `MMMPlotSuite`. +- Any PyMC model can use `MMMPlotSuite` directly if its `InferenceData` contains the needed + groups/variables described above. +""" + +import itertools +from collections.abc import Iterable +from typing import Any + +import arviz as az +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from numpy.typing import NDArray + +WIDTH_PER_COL: float = 10.0 +HEIGHT_PER_ROW: float = 4.0 + + +class LegacyMMMPlotSuite: + """Legacy matplotlib-based MMM plotting suite. + + .. deprecated:: 0.18.0 + This class will be removed in v0.20.0. Use MMMPlotSuite with + mmm_config["plot.use_v2"] = True for the new arviz_plots-based suite. + + This class is maintained for backward compatibility but will be removed + in a future release. The new MMMPlotSuite supports multiple backends + (matplotlib, plotly, bokeh) and returns PlotCollection objects. + + Provides methods for visualizing the posterior predictive distribution, + contributions over time, and saturation curves for a Media Mix Model. + """ + + def __init__( + self, + idata: xr.Dataset | az.InferenceData, + ): + self.idata = idata + + def _init_subplots( + self, + n_subplots: int, + ncols: int = 1, + width_per_col: float = 10.0, + height_per_row: float = 4.0, + ) -> tuple[Figure, NDArray[Axes]]: + """Initialize a grid of subplots. + + Parameters + ---------- + n_subplots : int + Number of rows (if ncols=1) or total subplots. + ncols : int + Number of columns in the subplot grid. + width_per_col : float + Width (in inches) for each column of subplots. + height_per_row : float + Height (in inches) for each row of subplots. + + Returns + ------- + fig : matplotlib.figure.Figure + The created Figure object. + axes : np.ndarray of matplotlib.axes.Axes + 2D array of axes of shape (n_subplots, ncols). + """ + fig, axes = plt.subplots( + nrows=n_subplots, + ncols=ncols, + figsize=(width_per_col * ncols, height_per_row * n_subplots), + squeeze=False, + ) + return fig, axes + + def _build_subplot_title( + self, + dims: list[str], + combo: tuple, + fallback_title: str = "Time Series", + ) -> str: + """Build a subplot title string from dimension names and their values.""" + if dims: + title_parts = [f"{d}={v}" for d, v in zip(dims, combo, strict=False)] + return ", ".join(title_parts) + return fallback_title + + def _get_additional_dim_combinations( + self, + data: xr.Dataset, + variable: str, + ignored_dims: set[str], + ) -> tuple[list[str], list[tuple]]: + """Identify dimensions to plot over and get their coordinate combinations.""" + if variable not in data: + raise ValueError(f"Variable '{variable}' not found in the dataset.") + + all_dims = list(data[variable].dims) + additional_dims = [d for d in all_dims if d not in ignored_dims] + + if additional_dims: + additional_coords = [data.coords[d].values for d in additional_dims] + dim_combinations = list(itertools.product(*additional_coords)) + else: + # If no extra dims, just treat as a single combination + dim_combinations = [()] + + return additional_dims, dim_combinations + + def _reduce_and_stack( + self, data: xr.DataArray, dims_to_ignore: set[str] | None = None + ) -> xr.DataArray: + """Sum over leftover dims and stack chain+draw into sample if present.""" + if dims_to_ignore is None: + dims_to_ignore = {"date", "chain", "draw", "sample"} + + leftover_dims = [d for d in data.dims if d not in dims_to_ignore] + if leftover_dims: + data = data.sum(dim=leftover_dims) + + # Combine chain+draw into 'sample' if both exist + if "chain" in data.dims and "draw" in data.dims: + data = data.stack(sample=("chain", "draw")) + + return data + + def _get_posterior_predictive_data( + self, + idata: xr.Dataset | None, + ) -> xr.Dataset: + """Retrieve the posterior_predictive group from either provided or self.idata.""" + if idata is not None: + return idata + + # Otherwise, check if self.idata has posterior_predictive + if ( + not hasattr(self.idata, "posterior_predictive") # type: ignore + or self.idata.posterior_predictive is None # type: ignore + ): + raise ValueError( + "No posterior_predictive data found in 'self.idata'. " + "Please run 'MMM.sample_posterior_predictive()' or provide " + "an external 'idata' argument." + ) + return self.idata.posterior_predictive # type: ignore + + def _add_median_and_hdi( + self, ax: Axes, data: xr.DataArray, var: str, hdi_prob: float = 0.85 + ) -> Axes: + """Add median and HDI to the given axis.""" + median = data.median(dim="sample") if "sample" in data.dims else data.median() + hdi = az.hdi( + data, + hdi_prob=hdi_prob, + input_core_dims=[["sample"]] if "sample" in data.dims else None, + ) + + if "date" not in data.dims: + raise ValueError(f"Expected 'date' dimension in {var}, but none found.") + dates = data.coords["date"].values + # Add median and HDI to the plot + ax.plot(dates, median, label=var, alpha=0.9) + ax.fill_between(dates, hdi[var][..., 0], hdi[var][..., 1], alpha=0.2) + return ax + + def _validate_dims( + self, + dims: dict[str, str | int | list], + all_dims: list[str], + ) -> None: + """Validate that provided dims exist in the model's dimensions and values.""" + if dims: + for key, val in dims.items(): + if key not in all_dims: + raise ValueError( + f"Dimension '{key}' not found in idata dimensions." + ) + valid_values = self.idata.posterior.coords[key].values + if isinstance(val, (list, tuple, np.ndarray)): + for v in val: + if v not in valid_values: + raise ValueError( + f"Value '{v}' not found in dimension '{key}'." + ) + else: + if val not in valid_values: + raise ValueError( + f"Value '{val}' not found in dimension '{key}'." + ) + + def _dim_list_handler( + self, dims: dict[str, str | int | list] | None + ) -> tuple[list[str], list[tuple]]: + """Extract keys, values, and all combinations for list-valued dims.""" + dims_lists = { + k: v + for k, v in (dims or {}).items() + if isinstance(v, (list, tuple, np.ndarray)) + } + if dims_lists: + dims_keys = list(dims_lists.keys()) + dims_values = [ + v if isinstance(v, (list, tuple, np.ndarray)) else [v] + for v in dims_lists.values() + ] + dims_combos = list(itertools.product(*dims_values)) + else: + dims_keys = [] + dims_combos = [()] + return dims_keys, dims_combos + + def _resolve_backend(self, backend: str | None) -> str: + """Resolve backend parameter to actual backend string.""" + from pymc_marketing.mmm.config import mmm_config + + return backend or mmm_config["plot.backend"] + + # ------------------------------------------------------------------------ + # Main Plotting Methods + # ------------------------------------------------------------------------ + + def posterior_predictive( + self, + var: list[str] | None = None, + idata: xr.Dataset | None = None, + hdi_prob: float = 0.85, + ) -> tuple[Figure, NDArray[Axes]]: + """Plot time series from the posterior predictive distribution. + + By default, if both `var` and `idata` are not provided, uses + `self.idata.posterior_predictive` and defaults the variable to `["y"]`. + + Parameters + ---------- + var : list of str, optional + A list of variable names to plot. Default is ["y"] if not provided. + idata : xarray.Dataset, optional + The posterior predictive dataset to plot. If not provided, tries to + use `self.idata.posterior_predictive`. + hdi_prob: float, optional + The probability mass of the highest density interval to be displayed. Default is 0.85. + + Returns + ------- + fig : matplotlib.figure.Figure + The Figure object containing the subplots. + axes : np.ndarray of matplotlib.axes.Axes + Array of Axes objects corresponding to each subplot row. + + Raises + ------ + ValueError + If no `idata` is provided and `self.idata.posterior_predictive` does + not exist, instructing the user to run `MMM.sample_posterior_predictive()`. + If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value. + """ + if not 0 < hdi_prob < 1: + raise ValueError("HDI probability must be between 0 and 1.") + # 1. Retrieve or validate posterior_predictive data + pp_data = self._get_posterior_predictive_data(idata) + + # 2. Determine variables to plot + if var is None: + var = ["y"] + main_var = var[0] + + # 3. Identify additional dims & get all combos + ignored_dims = {"chain", "draw", "date", "sample"} + additional_dims, dim_combinations = self._get_additional_dim_combinations( + data=pp_data, variable=main_var, ignored_dims=ignored_dims + ) + + # 4. Prepare subplots + fig, axes = self._init_subplots(n_subplots=len(dim_combinations), ncols=1) + + # 5. Loop over dimension combinations + for row_idx, combo in enumerate(dim_combinations): + ax = axes[row_idx][0] + + # Build indexers + indexers = ( + dict(zip(additional_dims, combo, strict=False)) + if additional_dims + else {} + ) + + # 6. Plot each requested variable + for v in var: + if v not in pp_data: + raise ValueError( + f"Variable '{v}' not in the posterior_predictive dataset." + ) + + data = pp_data[v].sel(**indexers) + # Sum leftover dims, stack chain+draw if needed + data = self._reduce_and_stack(data, ignored_dims) + ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob) + + # 7. Subplot title & labels + title = self._build_subplot_title( + dims=additional_dims, + combo=combo, + fallback_title="Posterior Predictive Time Series", + ) + ax.set_title(title) + ax.set_xlabel("Date") + ax.set_ylabel("Posterior Predictive") + ax.legend(loc="best") + + return fig, axes + + def contributions_over_time( + self, + var: list[str], + hdi_prob: float = 0.85, + dims: dict[str, str | int | list] | None = None, + ) -> tuple[Figure, NDArray[Axes]]: + """Plot the time-series contributions for each variable in `var`. + + showing the median and the credible interval (default 85%). + Creates one subplot per combination of non-(chain/draw/date) dimensions + and places all variables on the same subplot. + + Parameters + ---------- + var : list of str + A list of variable names to plot from the posterior. + hdi_prob: float, optional + The probability mass of the highest density interval to be displayed. Default is 0.85. + dims : dict[str, str | int | list], optional + Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. + If provided, only the selected slice(s) will be plotted. + + Returns + ------- + fig : matplotlib.figure.Figure + The Figure object containing the subplots. + axes : np.ndarray of matplotlib.axes.Axes + Array of Axes objects corresponding to each subplot row. + + Raises + ------ + ValueError + If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value. + """ + if not 0 < hdi_prob < 1: + raise ValueError("HDI probability must be between 0 and 1.") + + if not hasattr(self.idata, "posterior"): + raise ValueError( + "No posterior data found in 'self.idata'. " + "Please ensure 'self.idata' contains a 'posterior' group." + ) + + main_var = var[0] + all_dims = list(self.idata.posterior[main_var].dims) # type: ignore + ignored_dims = {"chain", "draw", "date"} + additional_dims = [d for d in all_dims if d not in ignored_dims] + + coords = { + key: value.to_numpy() + for key, value in self.idata.posterior[var].coords.items() + } + + # Apply user-specified filters (`dims`) + if dims: + self._validate_dims(dims=dims, all_dims=all_dims) + # Remove filtered dims from the combinations + additional_dims = [d for d in additional_dims if d not in dims] + else: + self._validate_dims({}, all_dims) + # additional_dims = [d for d in additional_dims if d not in dims] + + # Identify combos for remaining dims + if additional_dims: + additional_coords = [ + self.idata.posterior.coords[dim].values # type: ignore + for dim in additional_dims + ] + dim_combinations = list(itertools.product(*additional_coords)) + else: + dim_combinations = [()] + + # If dims contains lists, build all combinations for those as well + dims_keys, dims_combos = self._dim_list_handler(dims) + + # Prepare subplots: one for each combo of dims_lists and additional_dims + total_combos = list(itertools.product(dims_combos, dim_combinations)) + fig, axes = self._init_subplots(len(total_combos), ncols=1) + + for row_idx, (dims_combo, addl_combo) in enumerate(total_combos): + ax = axes[row_idx][0] + # Build indexers for dims and additional_dims + indexers = ( + dict(zip(additional_dims, addl_combo, strict=False)) + if additional_dims + else {} + ) + if dims: + # For dims with lists, use the current value from dims_combo + for i, k in enumerate(dims_keys): + indexers[k] = dims_combo[i] + # For dims with single values, use as is + for k, v in (dims or {}).items(): + if k not in dims_keys: + indexers[k] = v + + # Plot posterior median and HDI for each var + for v in var: + data = self.idata.posterior[v] + missing_coords = { + key: value for key, value in coords.items() if key not in data.dims + } + data = data.expand_dims(**missing_coords) + data = data.sel(**indexers) # apply slice + data = self._reduce_and_stack( + data, dims_to_ignore={"date", "chain", "draw", "sample"} + ) + ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob) + + # Title includes both fixed and combo dims + title_dims = ( + list(dims.keys()) + additional_dims if dims else additional_dims + ) + title_combo = tuple(indexers[k] for k in title_dims) + + title = self._build_subplot_title( + dims=title_dims, combo=title_combo, fallback_title="Time Series" + ) + ax.set_title(title) + ax.set_xlabel("Date") + ax.set_ylabel("Posterior Value") + ax.legend(loc="best") + + return fig, axes + + def saturation_scatterplot( + self, + original_scale: bool = False, + dims: dict[str, str | int | list] | None = None, + **kwargs, + ) -> tuple[Figure, NDArray[Axes]]: + """Plot the saturation curves for each channel. + + Creates a grid of subplots for each combination of channel and non-(date/channel) dimensions. + Optionally, subset by dims (single values or lists). + Each channel will have a consistent color across all subplots. + """ + if not hasattr(self.idata, "constant_data"): + raise ValueError( + "No 'constant_data' found in 'self.idata'. " + "Please ensure 'self.idata' contains the constant_data group." + ) + + # Identify additional dimensions beyond 'date' and 'channel' + cdims = self.idata.constant_data.channel_data.dims + additional_dims = [dim for dim in cdims if dim not in ("date", "channel")] + + # Validate dims and remove filtered dims from additional_dims + if dims: + self._validate_dims(dims, list(self.idata.constant_data.channel_data.dims)) + additional_dims = [d for d in additional_dims if d not in dims] + else: + self._validate_dims({}, list(self.idata.constant_data.channel_data.dims)) + + # Build all combinations for dims with lists + dims_keys, dims_combos = self._dim_list_handler(dims) + + # Build all combinations for remaining dims + if additional_dims: + additional_coords = [ + self.idata.constant_data.coords[d].values for d in additional_dims + ] + additional_combinations = list(itertools.product(*additional_coords)) + else: + additional_combinations = [()] + + channels = self.idata.constant_data.coords["channel"].values + n_channels = len(channels) + n_addl = len(additional_combinations) + n_dims = len(dims_combos) + + # For most use cases, n_dims will be 1, so grid is channels x additional_combinations + # If dims_combos > 1, treat as extra axis (rare, but possible) + nrows = n_channels + ncols = n_addl * n_dims + total_combos = list( + itertools.product(channels, dims_combos, additional_combinations) + ) + n_subplots = len(total_combos) + + # Assign a color to each channel + channel_colors = {ch: f"C{i}" for i, ch in enumerate(channels)} + + # Prepare subplots as a grid + fig, axes = plt.subplots( + nrows=nrows, + ncols=ncols, + figsize=( + kwargs.get("width_per_col", 8) * ncols, + kwargs.get("height_per_row", 4) * nrows, + ), + squeeze=False, + ) + + channel_contribution = ( + "channel_contribution_original_scale" + if original_scale + else "channel_contribution" + ) + + if original_scale and not hasattr(self.idata.posterior, channel_contribution): + raise ValueError( + f"""No posterior.{channel_contribution} data found in 'self.idata'. \n + Add a original scale deterministic:\n + mmm.add_original_scale_contribution_variable(\n + var=[\n + \"channel_contribution\",\n + ...\n + ]\n + )\n + """ + ) + + for _idx, (channel, dims_combo, addl_combo) in enumerate(total_combos): + # Compute subplot position + row = list(channels).index(channel) + # If dims_combos > 1, treat as extra axis (columns: addl * dims) + if n_dims > 1: + col = list(additional_combinations).index(addl_combo) * n_dims + list( + dims_combos + ).index(dims_combo) + else: + col = list(additional_combinations).index(addl_combo) + ax = axes[row][col] + + # Build indexers for dims and additional_dims + indexers = ( + dict(zip(additional_dims, addl_combo, strict=False)) + if additional_dims + else {} + ) + if dims: + for i, k in enumerate(dims_keys): + indexers[k] = dims_combo[i] + for k, v in (dims or {}).items(): + if k not in dims_keys: + indexers[k] = v + indexers["channel"] = channel + + # Select X data (constant_data) + x_data = self.idata.constant_data.channel_data.sel(**indexers) + # Select Y data (posterior contributions) and scale if needed + y_data = self.idata.posterior[channel_contribution].sel(**indexers) + y_data = y_data.mean(dim=[d for d in y_data.dims if d in ("chain", "draw")]) + x_data = x_data.broadcast_like(y_data) + y_data = y_data.broadcast_like(x_data) + ax.scatter( + x_data.values.flatten(), + y_data.values.flatten(), + alpha=0.8, + color=channel_colors[channel], + label=str(channel), + ) + # Build subplot title + title_dims = ( + ["channel"] + (list(dims.keys()) if dims else []) + additional_dims + ) + title_combo = ( + channel, + *[indexers[k] for k in title_dims if k != "channel"], + ) + title = self._build_subplot_title( + dims=title_dims, + combo=title_combo, + fallback_title="Channel Saturation Curve", + ) + ax.set_title(title) + ax.set_xlabel("Channel Data (X)") + ax.set_ylabel("Channel Contributions (Y)") + ax.legend(loc="best") + + # Hide any unused axes (if grid is larger than needed) + for i in range(nrows): + for j in range(ncols): + if i * ncols + j >= n_subplots: + axes[i][j].set_visible(False) + + return fig, axes + + def saturation_curves( + self, + curve: xr.DataArray, + original_scale: bool = False, + n_samples: int = 10, + hdi_probs: float | list[float] | None = None, + random_seed: np.random.Generator | None = None, + colors: Iterable[str] | None = None, + subplot_kwargs: dict | None = None, + rc_params: dict | None = None, + dims: dict[str, str | int | list] | None = None, + **plot_kwargs, + ) -> tuple[plt.Figure, np.ndarray]: + """ + Overlay saturation‑curve scatter‑plots with posterior‑predictive sample curves and HDI bands. + + **allowing** you to customize figsize and font sizes. + + Parameters + ---------- + curve : xr.DataArray + Posterior‑predictive curves (e.g. dims `("chain","draw","x","channel","geo")`). + original_scale : bool, default=False + Plot `channel_contribution_original_scale` if True, else `channel_contribution`. + n_samples : int, default=10 + Number of sample‑curves per subplot. + hdi_probs : float or list of float, optional + Credible interval probabilities (e.g. 0.94 or [0.5, 0.94]). + If None, uses ArviZ's default (0.94). + random_seed : np.random.Generator, optional + RNG for reproducible sampling. If None, uses `np.random.default_rng()`. + colors : iterable of str, optional + Colors for the sample & HDI plots. + subplot_kwargs : dict, optional + Passed to `plt.subplots` (e.g. `{"figsize": (10,8)}`). + Merged with the function's own default sizing. + rc_params : dict, optional + Temporary `matplotlib.rcParams` for this plot. + Example keys: `"xtick.labelsize"`, `"ytick.labelsize"`, + `"axes.labelsize"`, `"axes.titlesize"`. + dims : dict[str, str | int | list], optional + Dimension filters to apply. Example: {"country": ["US", "UK"], "region": "X"}. + If provided, only the selected slice(s) will be plotted. + **plot_kwargs + Any other kwargs forwarded to `plot_curve` + (for instance `same_axes=True`, `legend=True`, etc.). + + Returns + ------- + fig : plt.Figure + Matplotlib figure with your grid. + axes : np.ndarray of plt.Axes + Array of shape `(n_channels, n_geo)`. + """ + from pymc_marketing.plot import plot_hdi, plot_samples + + if not hasattr(self.idata, "constant_data"): + raise ValueError( + "No 'constant_data' found in 'self.idata'. " + "Please ensure 'self.idata' contains the constant_data group." + ) + + contrib_var = ( + "channel_contribution_original_scale" + if original_scale + else "channel_contribution" + ) + + if original_scale and not hasattr(self.idata.posterior, contrib_var): + raise ValueError( + f"""No posterior.{contrib_var} data found in 'self.idata'.\n" + "Add a original scale deterministic:\n" + " mmm.add_original_scale_contribution_variable(\n" + " var=[\n" + " 'channel_contribution',\n" + " ...\n" + " ]\n" + " )\n" + """ + ) + curve_data = ( + curve * self.idata.constant_data.target_scale if original_scale else curve + ) + curve_data = curve_data.rename("saturation_curve") + + # — 1. figure out grid shape based on scatter data dimensions / identify dims and combos + cdims = self.idata.constant_data.channel_data.dims + all_dims = list(cdims) + additional_dims = [d for d in cdims if d not in ("date", "channel")] + # Validate dims and remove filtered dims from additional_dims + if dims: + self._validate_dims(dims, all_dims) + additional_dims = [d for d in additional_dims if d not in dims] + else: + self._validate_dims({}, all_dims) + # Build all combinations for dims with lists + dims_keys, dims_combos = self._dim_list_handler(dims) + # Build all combinations for remaining dims + if additional_dims: + additional_coords = [ + self.idata.constant_data.coords[d].values for d in additional_dims + ] + additional_combinations = list(itertools.product(*additional_coords)) + else: + additional_combinations = [()] + channels = self.idata.constant_data.coords["channel"].values + n_channels = len(channels) + n_addl = len(additional_combinations) + n_dims = len(dims_combos) + nrows = n_channels + ncols = n_addl * n_dims + total_combos = list( + itertools.product(channels, dims_combos, additional_combinations) + ) + n_subplots = len(total_combos) + + # — 2. merge subplot_kwargs — + user_subplot = subplot_kwargs or {} + + # Handle user-specified ncols/nrows + if "ncols" in user_subplot: + # User specified ncols, calculate nrows + ncols = user_subplot["ncols"] + nrows = int(np.ceil(n_subplots / ncols)) + user_subplot.pop("ncols") # Remove to avoid conflict + elif "nrows" in user_subplot: + # User specified nrows, calculate ncols + nrows = user_subplot["nrows"] + ncols = int(np.ceil(n_subplots / nrows)) + user_subplot.pop("nrows") # Remove to avoid conflict + default_subplot = {"figsize": (ncols * 4, nrows * 3)} + subkw = {**default_subplot, **user_subplot} + # — 3. create subplots ourselves — + rc_params = rc_params or {} + with plt.rc_context(rc_params): + fig, axes = plt.subplots(nrows=nrows, ncols=ncols, **subkw) + # ensure a 2D array + if nrows == 1 and ncols == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = axes.reshape(1, -1) + elif ncols == 1: + axes = axes.reshape(-1, 1) + # Flatten axes for easier iteration + axes_flat = axes.flatten() + if colors is None: + colors = [f"C{i}" for i in range(n_channels)] + elif not isinstance(colors, list): + colors = list(colors) + subplot_idx = 0 + for _idx, (ch, dims_combo, addl_combo) in enumerate(total_combos): + if subplot_idx >= len(axes_flat): + break + ax = axes_flat[subplot_idx] + subplot_idx += 1 + # Build indexers for dims and additional_dims + indexers = ( + dict(zip(additional_dims, addl_combo, strict=False)) + if additional_dims + else {} + ) + if dims: + for i, k in enumerate(dims_keys): + indexers[k] = dims_combo[i] + for k, v in (dims or {}).items(): + if k not in dims_keys: + indexers[k] = v + indexers["channel"] = ch + # Select and broadcast curve data for this channel + curve_idx = { + dim: val for dim, val in indexers.items() if dim in curve_data.dims + } + subplot_curve = curve_data.sel(**curve_idx) + if original_scale: + valid_idx = { + k: v + for k, v in indexers.items() + if k in self.idata.constant_data.channel_scale.dims + } + channel_scale = self.idata.constant_data.channel_scale.sel(**valid_idx) + x_original = subplot_curve.coords["x"] * channel_scale + subplot_curve = subplot_curve.assign_coords(x=x_original) + if n_samples > 0: + plot_samples( + subplot_curve, + non_grid_names="x", + n=n_samples, + rng=random_seed, + axes=np.array([[ax]]), + colors=[colors[list(channels).index(ch)]], + same_axes=False, + legend=False, + **plot_kwargs, + ) + if hdi_probs is not None: + # Robustly handle hdi_probs as float, list, tuple, or np.ndarray + if isinstance(hdi_probs, (float, int)): + hdi_probs_iter = [hdi_probs] + elif isinstance(hdi_probs, (list, tuple, np.ndarray)): + hdi_probs_iter = hdi_probs + else: + raise TypeError( + "hdi_probs must be a float, list, tuple, or np.ndarray" + ) + for hdi_prob in hdi_probs_iter: + plot_hdi( + subplot_curve, + non_grid_names="x", + hdi_prob=hdi_prob, + axes=np.array([[ax]]), + colors=[colors[list(channels).index(ch)]], + same_axes=False, + legend=False, + **plot_kwargs, + ) + x_data = self.idata.constant_data.channel_data.sel(**indexers) + y = ( + self.idata.posterior[contrib_var] + .sel(**indexers) + .mean( + dim=[ + d + for d in self.idata.posterior[contrib_var].dims + if d in ("chain", "draw") + ] + ) + ) + x_data, y = x_data.broadcast_like(y), y.broadcast_like(x_data) + ax.scatter( + x_data.values.flatten(), + y.values.flatten(), + alpha=0.8, + color=colors[list(channels).index(ch)], + ) + title_dims = ( + ["channel"] + (list(dims.keys()) if dims else []) + additional_dims + ) + title_combo = ( + ch, + *[indexers[k] for k in title_dims if k != "channel"], + ) + title = self._build_subplot_title( + dims=title_dims, + combo=title_combo, + fallback_title="Channel Saturation Curves", + ) + ax.set_title(title) + ax.set_xlabel("Channel Data (X)") + ax.set_ylabel("Channel Contribution (Y)") + for ax_idx in range(subplot_idx, len(axes_flat)): + axes_flat[ax_idx].set_visible(False) + return fig, axes + + def saturation_curves_scatter( + self, original_scale: bool = False, **kwargs + ) -> tuple[Figure, NDArray[Axes]]: + """ + Plot scatter plots of channel contributions vs. channel data. + + .. deprecated:: 0.1.0 + Will be removed in version 0.2.0. Use :meth:`saturation_scatterplot` instead. + + Parameters + ---------- + channel_contribution : str, optional + Name of the channel contribution variable in the InferenceData. + additional_dims : list[str], optional + Additional dimensions to consider beyond 'channel'. + additional_combinations : list[tuple], optional + Specific combinations of additional dimensions to plot. + **kwargs + Additional keyword arguments passed to _init_subplots. + + Returns + ------- + fig : plt.Figure + The matplotlib figure. + axes : np.ndarray + Array of matplotlib axes. + """ + import warnings + + warnings.warn( + "saturation_curves_scatter is deprecated and will be removed in version 0.2.0. " + "Use saturation_scatterplot instead.", + DeprecationWarning, + stacklevel=2, + ) + # Note: channel_contribution, additional_dims, and additional_combinations + # are not used by saturation_scatterplot, so we don't pass them + return self.saturation_scatterplot(original_scale=original_scale, **kwargs) + + def budget_allocation( + self, + samples: xr.Dataset, + scale_factor: float | None = None, + figsize: tuple[float, float] = (12, 6), + ax: plt.Axes | None = None, + original_scale: bool = True, + dims: dict[str, str | int | list] | None = None, + ) -> tuple[Figure, plt.Axes] | tuple[Figure, np.ndarray]: + """Plot the budget allocation and channel contributions. + + Creates a bar chart comparing allocated spend and channel contributions + for each channel. If additional dimensions besides 'channel' are present, + creates a subplot for each combination of these dimensions. + + Parameters + ---------- + samples : xr.Dataset + The dataset containing the channel contributions and allocation values. + Expected to have 'channel_contribution' and 'allocation' variables. + scale_factor : float, optional + Scale factor to convert to original scale, if original_scale=True. + If None and original_scale=True, assumes scale_factor=1. + figsize : tuple[float, float], optional + The size of the figure to be created. Default is (12, 6). + ax : plt.Axes, optional + The axis to plot on. If None, a new figure and axis will be created. + Only used when no extra dimensions are present. + original_scale : bool, optional + A boolean flag to determine if the values should be plotted in their + original scale. Default is True. + dims : dict[str, str | int | list], optional + Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. + If provided, only the selected slice(s) will be plotted. + + Returns + ------- + fig : matplotlib.figure.Figure + The Figure object containing the plot. + axes : matplotlib.axes.Axes or numpy.ndarray of matplotlib.axes.Axes + The Axes object with the plot, or array of Axes for multiple subplots. + """ + # Get the channels from samples + if "channel" not in samples.dims: + raise ValueError( + "Expected 'channel' dimension in samples dataset, but none found." + ) + + # Check for required variables in samples + if not any( + "channel_contribution" in var_name for var_name in samples.data_vars + ): + raise ValueError( + "Expected a variable containing 'channel_contribution' in samples, but none found." + ) + if "allocation" not in samples: + raise ValueError( + "Expected 'allocation' variable in samples, but none found." + ) + + # Find the variable containing 'channel_contribution' in its name + channel_contrib_var = next( + var_name + for var_name in samples.data_vars + if "channel_contribution" in var_name + ) + + all_dims = list(samples.dims) + # Validate dims + if dims: + self._validate_dims(dims=dims, all_dims=all_dims) + else: + self._validate_dims({}, all_dims) + + # Handle list-valued dims: build all combinations + dims_keys, dims_combos = self._dim_list_handler(dims) + + # After filtering with dims, only use extra dims not in dims and not ignored for subplotting + ignored_dims = {"channel", "date", "sample", "chain", "draw"} + channel_contribution_dims = list(samples[channel_contrib_var].dims) + extra_dims = [ + d + for d in channel_contribution_dims + if d not in ignored_dims and d not in (dims or {}) + ] + + # Identify combos for remaining dims + if extra_dims: + extra_coords = [samples.coords[dim].values for dim in extra_dims] + extra_combos = list(itertools.product(*extra_coords)) + else: + extra_combos = [()] + + # Prepare subplots: one for each combo of dims_lists and extra_dims + total_combos = list(itertools.product(dims_combos, extra_combos)) + n_subplots = len(total_combos) + if n_subplots == 1 and ax is not None: + axes = np.array([[ax]]) + fig = ax.get_figure() + else: + fig, axes = self._init_subplots( + n_subplots=n_subplots, + ncols=1, + width_per_col=figsize[0], + height_per_row=figsize[1], + ) + + for row_idx, (dims_combo, extra_combo) in enumerate(total_combos): + ax_ = axes[row_idx][0] + # Build indexers for dims and extra_dims + indexers = ( + dict(zip(extra_dims, extra_combo, strict=False)) if extra_dims else {} + ) + if dims: + # For dims with lists, use the current value from dims_combo + for i, k in enumerate(dims_keys): + indexers[k] = dims_combo[i] + # For dims with single values, use as is + for k, v in (dims or {}).items(): + if k not in dims_keys: + indexers[k] = v + + # Select channel contributions for this subplot + channel_contrib_data = samples[channel_contrib_var].sel(**indexers) + allocation_data = samples.allocation + # Only select dims that exist in allocation + allocation_indexers = { + k: v for k, v in indexers.items() if k in allocation_data.dims + } + allocation_data = allocation_data.sel(**allocation_indexers) + + # Average over all dims except channel (and those used for this subplot) + used_dims = set(indexers.keys()) | {"channel"} + reduction_dims = [ + dim for dim in channel_contrib_data.dims if dim not in used_dims + ] + channel_contribution = channel_contrib_data.mean( + dim=reduction_dims + ).to_numpy() + if channel_contribution.ndim > 1: + channel_contribution = channel_contribution.flatten() + if original_scale and scale_factor is not None: + channel_contribution *= scale_factor + + allocation_used_dims = set(allocation_indexers.keys()) | {"channel"} + allocation_reduction_dims = [ + dim for dim in allocation_data.dims if dim not in allocation_used_dims + ] + if allocation_reduction_dims: + allocated_spend = allocation_data.mean( + dim=allocation_reduction_dims + ).to_numpy() + else: + allocated_spend = allocation_data.to_numpy() + if allocated_spend.ndim > 1: + allocated_spend = allocated_spend.flatten() + + self._plot_budget_allocation_bars( + ax_, + samples.coords["channel"].values, + allocated_spend, + channel_contribution, + ) + + # Build subplot title + title_dims = (list(dims.keys()) if dims else []) + extra_dims + title_combo = tuple(indexers[k] for k in title_dims) + title = self._build_subplot_title( + dims=title_dims, + combo=title_combo, + fallback_title="Budget Allocation", + ) + ax_.set_title(title) + + fig.tight_layout() + return fig, axes if n_subplots > 1 else (fig, axes[0][0]) + + def _plot_budget_allocation_bars( + self, + ax: plt.Axes, + channels: NDArray, + allocated_spend: NDArray, + channel_contribution: NDArray, + ) -> None: + """Plot budget allocation bars on a given axis. + + Parameters + ---------- + ax : plt.Axes + The axis to plot on. + channels : NDArray + Array of channel names. + allocated_spend : NDArray + Array of allocated spend values. + channel_contribution : NDArray + Array of channel contribution values. + """ + bar_width = 0.35 + opacity = 0.7 + index = range(len(channels)) + + # Plot allocated spend + bars1 = ax.bar( + index, + allocated_spend, + bar_width, + color="C0", + alpha=opacity, + label="Allocated Spend", + ) + + # Create twin axis for contributions + ax2 = ax.twinx() + + # Plot contributions + bars2 = ax2.bar( + [i + bar_width for i in index], + channel_contribution, + bar_width, + color="C1", + alpha=opacity, + label="Channel Contribution", + ) + + # Labels and formatting + ax.set_xlabel("Channels") + ax.set_ylabel("Allocated Spend", color="C0", labelpad=10) + ax2.set_ylabel("Channel Contributions", color="C1", labelpad=10) + + # Set x-ticks in the middle of the bars + ax.set_xticks([i + bar_width / 2 for i in index]) + ax.set_xticklabels(channels) + ax.tick_params(axis="x", rotation=90) + + # Turn off grid and add legend + ax.grid(False) + ax2.grid(False) + + bars = [bars1, bars2] + labels = ["Allocated Spend", "Channel Contributions"] + ax.legend(bars, labels, loc="best") + + def allocated_contribution_by_channel_over_time( + self, + samples: xr.Dataset, + scale_factor: float | None = None, + lower_quantile: float = 0.025, + upper_quantile: float = 0.975, + original_scale: bool = True, + figsize: tuple[float, float] = (10, 6), + ax: plt.Axes | None = None, + ) -> tuple[Figure, plt.Axes | NDArray[Axes]]: + """Plot the allocated contribution by channel with uncertainty intervals. + + This function visualizes the mean allocated contributions by channel along with + the uncertainty intervals defined by the lower and upper quantiles. + If additional dimensions besides 'channel', 'date', and 'sample' are present, + creates a subplot for each combination of these dimensions. + + Parameters + ---------- + samples : xr.Dataset + The dataset containing the samples of channel contributions. + Expected to have 'channel_contribution' variable with dimensions + 'channel', 'date', and 'sample'. + scale_factor : float, optional + Scale factor to convert to original scale, if original_scale=True. + If None and original_scale=True, assumes scale_factor=1. + lower_quantile : float, optional + The lower quantile for the uncertainty interval. Default is 0.025. + upper_quantile : float, optional + The upper quantile for the uncertainty interval. Default is 0.975. + original_scale : bool, optional + If True, the contributions are plotted on the original scale. Default is True. + figsize : tuple[float, float], optional + The size of the figure to be created. Default is (10, 6). + ax : plt.Axes, optional + The axis to plot on. If None, a new figure and axis will be created. + Only used when no extra dimensions are present. + + Returns + ------- + fig : matplotlib.figure.Figure + The Figure object containing the plot. + axes : matplotlib.axes.Axes or numpy.ndarray of matplotlib.axes.Axes + The Axes object with the plot, or array of Axes for multiple subplots. + """ + # Check for expected dimensions and variables + if "channel" not in samples.dims: + raise ValueError( + "Expected 'channel' dimension in samples dataset, but none found." + ) + if "date" not in samples.dims: + raise ValueError( + "Expected 'date' dimension in samples dataset, but none found." + ) + if "sample" not in samples.dims: + raise ValueError( + "Expected 'sample' dimension in samples dataset, but none found." + ) + # Check if any variable contains channel contributions + if not any( + "channel_contribution" in var_name for var_name in samples.data_vars + ): + raise ValueError( + "Expected a variable containing 'channel_contribution' in samples, but none found." + ) + + # Get channel contributions data + channel_contrib_var = next( + var_name + for var_name in samples.data_vars + if "channel_contribution" in var_name + ) + + # Identify extra dimensions beyond 'channel', 'date', and 'sample' + all_dims = list(samples[channel_contrib_var].dims) + ignored_dims = {"channel", "date", "sample"} + extra_dims = [dim for dim in all_dims if dim not in ignored_dims] + + # If no extra dimensions or using provided axis, create a single plot + if not extra_dims or ax is not None: + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + else: + fig = ax.get_figure() + + channel_contribution = samples[channel_contrib_var] + + # Apply scale factor if in original scale + if original_scale and scale_factor is not None: + channel_contribution = channel_contribution * scale_factor + + # Plot mean values by channel + channel_contribution.mean(dim="sample").plot(hue="channel", ax=ax) + + # Add uncertainty intervals for each channel + for channel in samples.coords["channel"].values: + ax.fill_between( + x=channel_contribution.date.values, + y1=channel_contribution.sel(channel=channel).quantile( + lower_quantile, dim="sample" + ), + y2=channel_contribution.sel(channel=channel).quantile( + upper_quantile, dim="sample" + ), + alpha=0.1, + ) + + ax.set_xlabel("Date") + ax.set_ylabel("Channel Contribution") + ax.set_title("Allocated Contribution by Channel Over Time") + + fig.tight_layout() + return fig, ax + + # For multiple dimensions, create a grid of subplots + # Determine layout based on number of extra dimensions + if len(extra_dims) == 1: + # One extra dimension: use for rows + dim_values = [samples.coords[extra_dims[0]].values] + nrows = len(dim_values[0]) + ncols = 1 + subplot_dims = [extra_dims[0], None] + elif len(extra_dims) == 2: + # Two extra dimensions: one for rows, one for columns + dim_values = [ + samples.coords[extra_dims[0]].values, + samples.coords[extra_dims[1]].values, + ] + nrows = len(dim_values[0]) + ncols = len(dim_values[1]) + subplot_dims = extra_dims + else: + # Three or more: use first two for rows/columns, average over the rest + dim_values = [ + samples.coords[extra_dims[0]].values, + samples.coords[extra_dims[1]].values, + ] + nrows = len(dim_values[0]) + ncols = len(dim_values[1]) + subplot_dims = [extra_dims[0], extra_dims[1]] + + # Calculate figure size based on number of subplots + subplot_figsize = (figsize[0] * max(1, ncols), figsize[1] * max(1, nrows)) + fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=subplot_figsize) + + # Make axes indexable even for 1x1 grid + if nrows == 1 and ncols == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = axes.reshape(1, -1) + elif ncols == 1: + axes = axes.reshape(-1, 1) + + # Create a subplot for each combination of dimension values + for i, row_val in enumerate(dim_values[0]): + for j, col_val in enumerate( + dim_values[1] if len(dim_values) > 1 else [None] + ): + ax = axes[i, j] + + # Select data for this subplot + selection = {subplot_dims[0]: row_val} + if col_val is not None: + selection[subplot_dims[1]] = col_val + + # Select channel contributions for this subplot + subset = samples[channel_contrib_var].sel(**selection) + + # Apply scale factor if needed + if original_scale and scale_factor is not None: + subset = subset * scale_factor + + # Plot mean values by channel for this subset + subset.mean(dim="sample").plot(hue="channel", ax=ax) + + # Add uncertainty intervals for each channel + for channel in samples.coords["channel"].values: + channel_data = subset.sel(channel=channel) + ax.fill_between( + x=channel_data.date.values, + y1=channel_data.quantile(lower_quantile, dim="sample"), + y2=channel_data.quantile(upper_quantile, dim="sample"), + alpha=0.1, + ) + + # Add subplot title based on dimension values + title_parts = [] + if subplot_dims[0] is not None: + title_parts.append(f"{subplot_dims[0]}={row_val}") + if subplot_dims[1] is not None: + title_parts.append(f"{subplot_dims[1]}={col_val}") + + base_title = "Allocated Contribution by Channel Over Time" + if title_parts: + ax.set_title(f"{base_title} - {', '.join(title_parts)}") + else: + ax.set_title(base_title) + + ax.set_xlabel("Date") + ax.set_ylabel("Channel Contribution") + + fig.tight_layout() + return fig, axes + + def sensitivity_analysis( + self, + hdi_prob: float = 0.94, + ax: plt.Axes | None = None, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + subplot_kwargs: dict[str, Any] | None = None, + *, + plot_kwargs: dict[str, Any] | None = None, + ylabel: str = "Effect", + xlabel: str = "Sweep", + title: str | None = None, + add_figure_title: bool = False, + subplot_title_fallback: str = "Sensitivity Analysis", + ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: + """Plot sensitivity analysis results. + + Parameters + ---------- + hdi_prob : float, default 0.94 + HDI probability mass. + ax : plt.Axes, optional + The axis to plot on. + aggregation : dict, optional + Aggregation to apply to the data. + E.g., {"sum": ("channel",)} to sum over the channel dimension. + + Other Parameters + ---------------- + plot_kwargs : dict, optional + Keyword arguments forwarded to the underlying line plot. Defaults include + ``{"color": "C0"}``. + ylabel : str, optional + Y-axis label. Defaults to "Effect". + xlabel : str, optional + X-axis label. Defaults to "Sweep". + title : str, optional + Figure-level title to add when ``add_figure_title=True``. + add_figure_title : bool, optional + Whether to add a figure-level title. Defaults to ``False``. + subplot_title_fallback : str, optional + Fallback title used for subplot titles when no plotting dims exist. Defaults + to "Sensitivity Analysis". + + Examples + -------- + Basic run using stored results in `idata`: + + .. code-block:: python + + # Assuming you already ran a sweep and stored results + # under idata.sensitivity_analysis via SensitivityAnalysis.run_sweep(..., extend_idata=True) + ax = mmm.plot.sensitivity_analysis(hdi_prob=0.9) + + With aggregation over dimensions (e.g., sum over channels): + + .. code-block:: python + + ax = mmm.plot.sensitivity_analysis( + hdi_prob=0.9, + aggregation={"sum": ("channel",)}, + ) + """ + if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError( + "No sensitivity analysis results found. Run run_sweep() first." + ) + sa = self.idata.sensitivity_analysis # type: ignore + x = sa["x"] if isinstance(sa, xr.Dataset) else sa + # Coerce numeric dtype + try: + x = x.astype(float) + except Exception as err: + import warnings + + warnings.warn( + f"Failed to cast sensitivity analysis data to float: {err}", + RuntimeWarning, + stacklevel=2, + ) + # Apply aggregations + if aggregation: + for op, dims in aggregation.items(): + dims_list = [d for d in dims if d in x.dims] + if not dims_list: + continue + if op == "sum": + x = x.sum(dim=dims_list) + elif op == "mean": + x = x.mean(dim=dims_list) + else: + x = x.median(dim=dims_list) + # Determine plotting dimensions (excluding sample & sweep) + plot_dims = [d for d in x.dims if d not in {"sample", "sweep"}] + if plot_dims: + dim_combinations = list( + itertools.product(*[x.coords[d].values for d in plot_dims]) + ) + else: + dim_combinations = [()] + + n_panels = len(dim_combinations) + + # Handle axis/grid creation + subplot_kwargs = {**(subplot_kwargs or {})} + nrows_user = subplot_kwargs.pop("nrows", None) + ncols_user = subplot_kwargs.pop("ncols", None) + if nrows_user is not None and ncols_user is not None: + raise ValueError( + "Specify only one of 'nrows' or 'ncols' in subplot_kwargs." + ) + + if n_panels > 1: + if ax is not None: + raise ValueError( + "Multiple sensitivity panels detected; please omit 'ax' and use 'subplot_kwargs' instead." + ) + if ncols_user is not None: + ncols = ncols_user + nrows = int(np.ceil(n_panels / ncols)) + elif nrows_user is not None: + nrows = nrows_user + ncols = int(np.ceil(n_panels / nrows)) + else: + ncols = max(1, int(np.ceil(np.sqrt(n_panels)))) + nrows = int(np.ceil(n_panels / ncols)) + subplot_kwargs.setdefault("figsize", (ncols * 4.0, nrows * 3.0)) + fig, axes_grid = plt.subplots( + nrows=nrows, + ncols=ncols, + **subplot_kwargs, + ) + if isinstance(axes_grid, plt.Axes): + axes_grid = np.array([[axes_grid]]) + elif axes_grid.ndim == 1: + axes_grid = axes_grid.reshape(1, -1) + axes_array = axes_grid + else: + if ax is not None: + axes_array = np.array([[ax]]) + fig = ax.figure + else: + if ncols_user is not None or nrows_user is not None: + subplot_kwargs.setdefault("figsize", (4.0, 3.0)) + fig, single_ax = plt.subplots( + nrows=1, + ncols=1, + **subplot_kwargs, + ) + else: + fig, single_ax = plt.subplots() + axes_array = np.array([[single_ax]]) + + # Merge plotting kwargs with defaults + _plot_kwargs = {"color": "C0"} + if plot_kwargs: + _plot_kwargs.update(plot_kwargs) + _line_color = _plot_kwargs.get("color", "C0") + + axes_flat = axes_array.flatten() + for idx, combo in enumerate(dim_combinations): + current_ax = axes_flat[idx] + indexers = dict(zip(plot_dims, combo, strict=False)) if plot_dims else {} + subset = x.sel(**indexers) if indexers else x + subset = subset.squeeze(drop=True) + subset = subset.astype(float) + + if "sweep" in subset.dims: + sweep_dim = "sweep" + else: + cand = [d for d in subset.dims if d != "sample"] + if not cand: + raise ValueError( + "Expected 'sweep' (or a non-sample) dimension in sensitivity results." + ) + sweep_dim = cand[0] + + sweep = ( + np.asarray(subset.coords[sweep_dim].values) + if sweep_dim in subset.coords + else np.arange(subset.sizes[sweep_dim]) + ) + + mean = subset.mean("sample") if "sample" in subset.dims else subset + reduce_dims = [d for d in mean.dims if d != sweep_dim] + if reduce_dims: + mean = mean.sum(dim=reduce_dims) + + if "sample" in subset.dims: + hdi = az.hdi(subset, hdi_prob=hdi_prob, input_core_dims=[["sample"]]) + if isinstance(hdi, xr.Dataset): + hdi = hdi[next(iter(hdi.data_vars))] + else: + hdi = xr.concat([mean, mean], dim="hdi").assign_coords( + hdi=np.array([0, 1]) + ) + + reduce_hdi = [d for d in hdi.dims if d not in (sweep_dim, "hdi")] + if reduce_hdi: + hdi = hdi.sum(dim=reduce_hdi) + if set(hdi.dims) == {sweep_dim, "hdi"} and list(hdi.dims) != [ + sweep_dim, + "hdi", + ]: + hdi = hdi.transpose(sweep_dim, "hdi") # type: ignore + + current_ax.plot(sweep, np.asarray(mean.values, dtype=float), **_plot_kwargs) + az.plot_hdi( + x=sweep, + hdi_data=np.asarray(hdi.values, dtype=float), + hdi_prob=hdi_prob, + color=_line_color, + ax=current_ax, + ) + + title = self._build_subplot_title( + dims=plot_dims, + combo=combo, + fallback_title=subplot_title_fallback, + ) + current_ax.set_title(title) + current_ax.set_xlabel(xlabel) + current_ax.set_ylabel(ylabel) + + # Hide any unused axes (happens if grid > panels) + for ax_extra in axes_flat[n_panels:]: + ax_extra.set_visible(False) + + # Optional figure-level title: only for multi-panel layouts, default color (black) + if add_figure_title and title is not None and n_panels > 1: + fig.suptitle(title) + + if n_panels == 1: + return axes_array[0, 0] + + fig.tight_layout() + return fig, axes_array + + def uplift_curve( + self, + hdi_prob: float = 0.94, + ax: plt.Axes | None = None, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + subplot_kwargs: dict[str, Any] | None = None, + *, + plot_kwargs: dict[str, Any] | None = None, + ylabel: str = "Uplift", + xlabel: str = "Sweep", + title: str | None = "Uplift curve", + add_figure_title: bool = True, + ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: + """ + Plot precomputed uplift curves stored under `idata.sensitivity_analysis['uplift_curve']`. + + Parameters + ---------- + hdi_prob : float, default 0.94 + HDI probability mass. + ax : plt.Axes, optional + The axis to plot on. + aggregation : dict, optional + Aggregation to apply to the data. + E.g., {"sum": ("channel",)} to sum over the channel dimension. + subplot_kwargs : dict, optional + Additional subplot configuration forwarded to :meth:`sensitivity_analysis`. + plot_kwargs : dict, optional + Keyword arguments forwarded to the underlying line plot. If not provided, defaults + are used by :meth:`sensitivity_analysis` (e.g., color "C0"). + ylabel : str, optional + Y-axis label. Defaults to "Uplift". + xlabel : str, optional + X-axis label. Defaults to "Sweep". + title : str, optional + Figure-level title to add when ``add_figure_title=True``. Defaults to "Uplift curve". + add_figure_title : bool, optional + Whether to add a figure-level title. Defaults to ``True``. + + Examples + -------- + Persist uplift curve and plot: + + .. code-block:: python + + from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis + + sweeps = np.linspace(0.5, 1.5, 11) + sa = SensitivityAnalysis(mmm.model, mmm.idata) + results = sa.run_sweep( + var_input="channel_data", + sweep_values=sweeps, + var_names="channel_contribution", + sweep_type="multiplicative", + ) + uplift = sa.compute_uplift_curve_respect_to_base( + results, ref=1.0, extend_idata=True + ) + _ = mmm.plot.uplift_curve(hdi_prob=0.9) + """ + if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError( + "No sensitivity analysis results found in 'self.idata'. " + "Run 'mmm.sensitivity.run_sweep()' first." + ) + + sa_group = self.idata.sensitivity_analysis # type: ignore + if isinstance(sa_group, xr.Dataset): + if "uplift_curve" not in sa_group: + raise ValueError( + "Expected 'uplift_curve' in idata.sensitivity_analysis. " + "Use SensitivityAnalysis.compute_uplift_curve_respect_to_base(..., extend_idata=True)." + ) + data_var = sa_group["uplift_curve"] + else: + raise ValueError( + "sensitivity_analysis does not contain 'uplift_curve'. Did you persist it to idata?" + ) + + # Delegate to a thin wrapper by temporarily constructing a Dataset + tmp_idata = xr.Dataset({"x": data_var}) + # Monkey-patch minimal attributes needed + tmp_idata["x"].attrs.update(getattr(sa_group, "attrs", {})) # type: ignore + # Temporarily swap + original_group = self.idata.sensitivity_analysis # type: ignore + try: + self.idata.sensitivity_analysis = tmp_idata # type: ignore + return self.sensitivity_analysis( + hdi_prob=hdi_prob, + ax=ax, + aggregation=aggregation, + subplot_kwargs=subplot_kwargs, + subplot_title_fallback="Uplift curve", + plot_kwargs=plot_kwargs, + ylabel=ylabel, + xlabel=xlabel, + title=title, + add_figure_title=add_figure_title, + ) + finally: + self.idata.sensitivity_analysis = original_group # type: ignore + + def marginal_curve( + self, + hdi_prob: float = 0.94, + ax: plt.Axes | None = None, + aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, + subplot_kwargs: dict[str, Any] | None = None, + *, + plot_kwargs: dict[str, Any] | None = None, + ylabel: str = "Marginal effect", + xlabel: str = "Sweep", + title: str | None = "Marginal effects", + add_figure_title: bool = True, + ) -> tuple[Figure, NDArray[Axes]] | plt.Axes: + """ + Plot precomputed marginal effects stored under `idata.sensitivity_analysis['marginal_effects']`. + + Parameters + ---------- + hdi_prob : float, default 0.94 + HDI probability mass. + ax : plt.Axes, optional + The axis to plot on. + aggregation : dict, optional + Aggregation to apply to the data. + E.g., {"sum": ("channel",)} to sum over the channel dimension. + subplot_kwargs : dict, optional + Additional subplot configuration forwarded to :meth:`sensitivity_analysis`. + plot_kwargs : dict, optional + Keyword arguments forwarded to the underlying line plot. Defaults to ``{"color": "C1"}``. + ylabel : str, optional + Y-axis label. Defaults to "Marginal effect". + xlabel : str, optional + X-axis label. Defaults to "Sweep". + title : str, optional + Figure-level title to add when ``add_figure_title=True``. Defaults to "Marginal effects". + add_figure_title : bool, optional + Whether to add a figure-level title. Defaults to ``True``. + + Examples + -------- + Persist marginal effects and plot: + + .. code-block:: python + + from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis + + sweeps = np.linspace(0.5, 1.5, 11) + sa = SensitivityAnalysis(mmm.model, mmm.idata) + results = sa.run_sweep( + var_input="channel_data", + sweep_values=sweeps, + var_names="channel_contribution", + sweep_type="multiplicative", + ) + me = sa.compute_marginal_effects(results, extend_idata=True) + _ = mmm.plot.marginal_curve(hdi_prob=0.9) + """ + if not hasattr(self.idata, "sensitivity_analysis"): + raise ValueError( + "No sensitivity analysis results found in 'self.idata'. " + "Run 'mmm.sensitivity.run_sweep()' first." + ) + + sa_group = self.idata.sensitivity_analysis # type: ignore + if isinstance(sa_group, xr.Dataset): + if "marginal_effects" not in sa_group: + raise ValueError( + "Expected 'marginal_effects' in idata.sensitivity_analysis. " + "Use SensitivityAnalysis.compute_marginal_effects(..., extend_idata=True)." + ) + data_var = sa_group["marginal_effects"] + else: + raise ValueError( + "sensitivity_analysis does not contain 'marginal_effects'. Did you persist it to idata?" + ) + + # We want a different y-label and color + # Temporarily swap group to reuse plotting logic + tmp = xr.Dataset({"x": data_var}) + tmp["x"].attrs.update(getattr(sa_group, "attrs", {})) # type: ignore + original = self.idata.sensitivity_analysis # type: ignore + try: + self.idata.sensitivity_analysis = tmp # type: ignore + # Reuse core plotting; percentage=False by definition + # Merge defaults for plot_kwargs if not provided + _plot_kwargs = {"color": "C1"} + if plot_kwargs: + _plot_kwargs.update(plot_kwargs) + return self.sensitivity_analysis( + hdi_prob=hdi_prob, + ax=ax, + aggregation=aggregation, + subplot_kwargs=subplot_kwargs, + subplot_title_fallback="Marginal effects", + plot_kwargs=_plot_kwargs, + ylabel=ylabel, + xlabel=xlabel, + title=title, + add_figure_title=add_figure_title, + ) + finally: + self.idata.sensitivity_analysis = original # type: ignore diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 8d7ac108d..3c57112b7 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -291,6 +291,42 @@ def _resolve_backend(self, backend: str | None) -> str: return backend or mmm_config["plot.backend"] + def _get_data_or_fallback( + self, + data: xr.Dataset | None, + idata_attr: str, + data_name: str, + ) -> xr.Dataset: + """Get data from parameter or fall back to self.idata attribute. + + Parameters + ---------- + data : xr.Dataset or None + Data provided by user. + idata_attr : str + Attribute name on self.idata to use as fallback (e.g., "posterior"). + data_name : str + Human-readable name for error messages (e.g., "posterior data"). + + Returns + ------- + xr.Dataset + The data to use. + + Raises + ------ + ValueError + If data is None and self.idata doesn't have the required attribute. + """ + if data is None: + if not hasattr(self.idata, idata_attr): + raise ValueError( + f"No {data_name} found in 'self.idata' and no 'data' argument provided. " + f"Please ensure 'self.idata' contains a '{idata_attr}' group or provide 'data' explicitly." + ) + data = getattr(self.idata, idata_attr) + return data + # ------------------------------------------------------------------------ # Main Plotting Methods # ------------------------------------------------------------------------ @@ -387,6 +423,7 @@ def posterior_predictive( def contributions_over_time( self, var: list[str], + data: xr.Dataset | None = None, hdi_prob: float = 0.85, dims: dict[str, str | int | list] | None = None, backend: str | None = None, @@ -401,6 +438,14 @@ def contributions_over_time( ---------- var : list of str A list of variable names to plot from the posterior. + data : xr.Dataset, optional + Dataset containing posterior data. If None, uses self.idata.posterior. + + This parameter allows: + - Testing with mock data without modifying self.idata + - Plotting external results not stored in self.idata + - Comparing different posterior samples side-by-side + - Avoiding unintended side effects on self.idata hdi_prob: float, optional The probability mass of the highest density interval to be displayed. Default is 0.85. dims : dict[str, str | int | list], optional @@ -419,14 +464,21 @@ def contributions_over_time( ------ ValueError If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value. + ValueError + If no posterior data found in self.idata and no data argument provided. """ if not 0 < hdi_prob < 1: raise ValueError("HDI probability must be between 0 and 1.") - if not hasattr(self.idata, "posterior"): + # Get data with fallback to self.idata.posterior + data = self._get_data_or_fallback(data, "posterior", "posterior data") + + # Validate data has the required variables + missing_vars = [v for v in var if v not in data] + if missing_vars: raise ValueError( - "No posterior data found in 'self.idata'. " - "Please ensure 'self.idata' contains a 'posterior' group." + f"Variables {missing_vars} not found in data. " + f"Available variables: {list(data.data_vars)}" ) # Resolve backend @@ -434,7 +486,7 @@ def contributions_over_time( main_var = var[0] ignored_dims = {"chain", "draw", "date"} - da = self.idata.posterior[var] + da = data[var] # Apply dims filtering if provided if dims: @@ -493,6 +545,8 @@ def contributions_over_time( def saturation_scatterplot( self, original_scale: bool = False, + constant_data: xr.Dataset | None = None, + posterior_data: xr.Dataset | None = None, dims: dict[str, str | int | list] | None = None, backend: str | None = None, ) -> PlotCollection: @@ -506,6 +560,22 @@ def saturation_scatterplot( ---------- original_scale: bool, optional Whether to plot the original scale contributions. Default is False. + constant_data : xr.Dataset, optional + Dataset containing constant_data group with 'channel_data' variable. + If None, uses self.idata.constant_data. + + This parameter allows: + - Testing with mock constant data + - Plotting with alternative scaling factors + - Comparing different data scenarios + posterior_data : xr.Dataset, optional + Dataset containing posterior group with channel contribution variables. + If None, uses self.idata.posterior. + + This parameter allows: + - Testing with mock posterior samples + - Plotting external inference results + - Comparing different model fits dims: dict[str, str | int | list], optional Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. If provided, only the selected slice(s) will be plotted. @@ -517,26 +587,40 @@ def saturation_scatterplot( Returns ------- PlotCollection + + Raises + ------ + ValueError + If required data not found in self.idata and not provided explicitly. """ # Resolve backend backend = self._resolve_backend(backend) - if not hasattr(self.idata, "constant_data"): + # Get constant_data and posterior_data with fallback + constant_data = self._get_data_or_fallback( + constant_data, "constant_data", "constant data" + ) + posterior_data = self._get_data_or_fallback( + posterior_data, "posterior", "posterior data" + ) + + # Validate required variables exist + if "channel_data" not in constant_data: raise ValueError( - "No 'constant_data' found in 'self.idata'. " - "Please ensure 'self.idata' contains the constant_data group." + "'channel_data' variable not found in constant_data. " + f"Available variables: {list(constant_data.data_vars)}" ) # Identify additional dimensions beyond 'date' and 'channel' - cdims = self.idata.constant_data.channel_data.dims + cdims = constant_data.channel_data.dims additional_dims = [dim for dim in cdims if dim not in ("date", "channel")] # Validate dims and remove filtered dims from additional_dims if dims: - self._validate_dims(dims, list(self.idata.constant_data.channel_data.dims)) + self._validate_dims(dims, list(constant_data.channel_data.dims)) additional_dims = [d for d in additional_dims if d not in dims] else: - self._validate_dims({}, list(self.idata.constant_data.channel_data.dims)) + self._validate_dims({}, list(constant_data.channel_data.dims)) channel_contribution = ( "channel_contribution_original_scale" @@ -544,9 +628,9 @@ def saturation_scatterplot( else "channel_contribution" ) - if original_scale and not hasattr(self.idata.posterior, channel_contribution): + if channel_contribution not in posterior_data: raise ValueError( - f"""No posterior.{channel_contribution} data found in 'self.idata'. \n + f"""No posterior.{channel_contribution} data found in posterior_data. \n Add a original scale deterministic:\n mmm.add_original_scale_contribution_variable(\n var=[\n @@ -558,8 +642,8 @@ def saturation_scatterplot( ) # Apply dims filtering to channel_data and channel_contribution - channel_data = self.idata.constant_data.channel_data - channel_contrib = self.idata.posterior[channel_contribution] + channel_data = constant_data.channel_data + channel_contrib = posterior_data[channel_contribution] if dims: for dim_name, dim_value in dims.items(): @@ -598,6 +682,8 @@ def saturation_curves( self, curve: xr.DataArray, original_scale: bool = False, + constant_data: xr.Dataset | None = None, + posterior_data: xr.Dataset | None = None, n_samples: int = 10, hdi_probs: float | list[float] | None = None, random_seed: np.random.Generator | None = None, @@ -615,6 +701,14 @@ def saturation_curves( Posterior‑predictive curves (e.g. dims `("chain","draw","x","channel","geo")`). original_scale : bool, default=False Plot `channel_contribution_original_scale` if True, else `channel_contribution`. + constant_data : xr.Dataset, optional + Dataset containing constant_data group. If None, uses self.idata.constant_data. + + This parameter allows testing with mock data and plotting alternative scenarios. + posterior_data : xr.Dataset, optional + Dataset containing posterior group. If None, uses self.idata.posterior. + + This parameter allows testing with mock posterior samples and comparing model fits. n_samples : int, default=10 Number of sample‑curves per subplot. hdi_probs : float or list of float, optional @@ -642,11 +736,13 @@ def saturation_curves( >>> hdi_probs=[0.9, 0.7], random_seed=rng) >>> pc.show() """ - if not hasattr(self.idata, "constant_data"): - raise ValueError( - "No 'constant_data' found in 'self.idata'. " - "Please ensure 'self.idata' contains the constant_data group." - ) + # Get constant_data and posterior_data with fallback + constant_data = self._get_data_or_fallback( + constant_data, "constant_data", "constant data" + ) + posterior_data = self._get_data_or_fallback( + posterior_data, "posterior", "posterior data" + ) contrib_var = ( "channel_contribution_original_scale" @@ -654,9 +750,9 @@ def saturation_curves( else "channel_contribution" ) - if original_scale and not hasattr(self.idata.posterior, contrib_var): + if original_scale and contrib_var not in posterior_data: raise ValueError( - f"""No posterior.{contrib_var} data found in 'self.idata'.\n" + f"""No posterior.{contrib_var} data found in posterior_data.\n" "Add a original scale deterministic:\n" " mmm.add_original_scale_contribution_variable(\n" " var=[\n" @@ -673,14 +769,14 @@ def saturation_curves( raise ValueError("curve must have a 'channel' dimension") if original_scale: - curve_data = curve * self.idata.constant_data.target_scale - curve_data["x"] = curve_data["x"] * self.idata.constant_data.channel_scale + curve_data = curve * constant_data.target_scale + curve_data["x"] = curve_data["x"] * constant_data.channel_scale else: curve_data = curve curve_data = curve_data.rename("saturation_curve") # — 1. figure out grid shape based on scatter data dimensions / identify dims and combos - cdims = self.idata.constant_data.channel_data.dims + cdims = constant_data.channel_data.dims all_dims = list(cdims) additional_dims = [d for d in cdims if d not in ("date", "channel")] # Validate dims and remove filtered dims from additional_dims @@ -692,7 +788,11 @@ def saturation_curves( # create the saturation scatterplot pc = self.saturation_scatterplot( - original_scale=original_scale, dims=dims, backend=backend + original_scale=original_scale, + constant_data=constant_data, + posterior_data=posterior_data, + dims=dims, + backend=backend, ) # add the hdi bands @@ -978,14 +1078,23 @@ def allocated_contribution_by_channel_over_time( def _sensitivity_analysis_plot( self, + data: xr.DataArray | xr.Dataset, hdi_prob: float = 0.94, aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, backend: str | None = None, ) -> PlotCollection: """Plot helper for sensitivity analysis results. + This is a private helper method that operates on provided data. + Public methods (sensitivity_analysis, uplift_curve, marginal_curve) + handle data retrieval from self.idata. + Parameters ---------- + data : xr.DataArray or xr.Dataset + Sensitivity analysis data to plot. Must have 'sample' and 'sweep' dimensions. + If Dataset, should contain 'x' variable. This parameter is REQUIRED with + no fallback to self.idata to maintain separation of concerns. hdi_prob : float, default 0.94 HDI probability mass. aggregation : dict, optional @@ -999,12 +1108,15 @@ def _sensitivity_analysis_plot( PlotCollection """ - if not hasattr(self.idata, "sensitivity_analysis"): + # Handle Dataset or DataArray + x = data["x"] if isinstance(data, xr.Dataset) else data + + # Validate dimensions + required_dims = {"sample", "sweep"} + if not required_dims.issubset(set(x.dims)): raise ValueError( - "No sensitivity analysis results found. Run run_sweep() first." + f"Data must have dimensions {required_dims}, got {set(x.dims)}" ) - sa = self.idata.sensitivity_analysis # type: ignore - x = sa["x"] if isinstance(sa, xr.Dataset) else sa # Coerce numeric dtype try: x = x.astype(float) @@ -1070,6 +1182,7 @@ def _sensitivity_analysis_plot( def sensitivity_analysis( self, + data: xr.DataArray | xr.Dataset | None = None, hdi_prob: float = 0.94, aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, backend: str | None = None, @@ -1078,6 +1191,13 @@ def sensitivity_analysis( Parameters ---------- + data : xr.DataArray or xr.Dataset, optional + Sensitivity analysis data to plot. If None, uses self.idata.sensitivity_analysis. + + This parameter allows: + - Testing with mock sensitivity analysis results + - Plotting external sweep results + - Comparing different sensitivity analyses hdi_prob : float, default 0.94 HDI probability mass. aggregation : dict, optional @@ -1090,6 +1210,11 @@ def sensitivity_analysis( ------- PlotCollection + Raises + ------ + ValueError + If no sensitivity analysis data found and no data provided. + Examples -------- Basic run using stored results in `idata`: @@ -1109,14 +1234,20 @@ def sensitivity_analysis( aggregation={"sum": ("channel",)}, ) """ + # Retrieve data if not provided + data = self._get_data_or_fallback( + data, "sensitivity_analysis", "sensitivity analysis results" + ) + pc = self._sensitivity_analysis_plot( - hdi_prob=hdi_prob, aggregation=aggregation, backend=backend + data=data, hdi_prob=hdi_prob, aggregation=aggregation, backend=backend ) pc.map(azp.visuals.labelled_y, text="Contribution") return pc def uplift_curve( self, + data: xr.DataArray | xr.Dataset | None = None, hdi_prob: float = 0.94, aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, backend: str | None = None, @@ -1126,6 +1257,14 @@ def uplift_curve( Parameters ---------- + data : xr.DataArray or xr.Dataset, optional + Uplift curve data to plot. If Dataset, should contain 'uplift_curve' variable. + If None, uses self.idata.sensitivity_analysis['uplift_curve']. + + This parameter allows: + - Testing with mock uplift curve data + - Plotting externally computed uplift curves + - Comparing uplift curves from different models hdi_prob : float, default 0.94 HDI probability mass. aggregation : dict, optional @@ -1134,6 +1273,16 @@ def uplift_curve( backend : str | None, optional Backend to use for plotting. If None, will use the global backend configuration. + Returns + ------- + PlotCollection + arviz_plots PlotCollection object. + + Raises + ------ + ValueError + If no uplift curve data found and no data provided. + Examples -------- Persist uplift curve and plot: @@ -1155,45 +1304,45 @@ def uplift_curve( ) mmm.plot.uplift_curve(hdi_prob=0.9) """ - if not hasattr(self.idata, "sensitivity_analysis"): - raise ValueError( - "No sensitivity analysis results found in 'self.idata'. " - "Run 'mmm.sensitivity.run_sweep()' first." + # Retrieve data if not provided + if data is None: + sa_group = self._get_data_or_fallback( + None, "sensitivity_analysis", "sensitivity analysis results" ) - - sa_group = self.idata.sensitivity_analysis # type: ignore - if isinstance(sa_group, xr.Dataset): - if "uplift_curve" not in sa_group: + if isinstance(sa_group, xr.Dataset): + if "uplift_curve" not in sa_group: + raise ValueError( + "Expected 'uplift_curve' in idata.sensitivity_analysis. " + "Use SensitivityAnalysis.compute_uplift_curve_respect_to_base(..., extend_idata=True)." + ) + data = sa_group["uplift_curve"] + else: raise ValueError( - "Expected 'uplift_curve' in idata.sensitivity_analysis. " - "Use SensitivityAnalysis.compute_uplift_curve_respect_to_base(..., extend_idata=True)." + "sensitivity_analysis does not contain 'uplift_curve'. Did you persist it to idata?" ) - data_var = sa_group["uplift_curve"] - else: - raise ValueError( - "sensitivity_analysis does not contain 'uplift_curve'. Did you persist it to idata?" - ) - # Delegate to a thin wrapper by temporarily constructing a Dataset - tmp_idata = xr.Dataset({"x": data_var}) - # Monkey-patch minimal attributes needed - tmp_idata["x"].attrs.update(getattr(sa_group, "attrs", {})) # type: ignore - # Temporarily swap - original_group = self.idata.sensitivity_analysis # type: ignore - try: - self.idata.sensitivity_analysis = tmp_idata # type: ignore - pc = self._sensitivity_analysis_plot( - hdi_prob=hdi_prob, - aggregation=aggregation, - backend=backend, - ) - pc.map(azp.visuals.labelled_y, text="Uplift (%)") - return pc - finally: - self.idata.sensitivity_analysis = original_group # type: ignore + # Handle Dataset input + if isinstance(data, xr.Dataset): + if "uplift_curve" in data: + data = data["uplift_curve"] + elif "x" in data: + data = data["x"] + else: + raise ValueError("Dataset must contain 'uplift_curve' or 'x' variable.") + + # Call helper with data (no more monkey-patching!) + pc = self._sensitivity_analysis_plot( + data=data, + hdi_prob=hdi_prob, + aggregation=aggregation, + backend=backend, + ) + pc.map(azp.visuals.labelled_y, text="Uplift (%)") + return pc def marginal_curve( self, + data: xr.DataArray | xr.Dataset | None = None, hdi_prob: float = 0.94, aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, backend: str | None = None, @@ -1203,6 +1352,14 @@ def marginal_curve( Parameters ---------- + data : xr.DataArray or xr.Dataset, optional + Marginal effects data to plot. If Dataset, should contain 'marginal_effects' variable. + If None, uses self.idata.sensitivity_analysis['marginal_effects']. + + This parameter allows: + - Testing with mock marginal effects data + - Plotting externally computed marginal effects + - Comparing marginal effects from different models hdi_prob : float, default 0.94 HDI probability mass. aggregation : dict, optional @@ -1214,6 +1371,12 @@ def marginal_curve( Returns ------- PlotCollection + arviz_plots PlotCollection object. + + Raises + ------ + ValueError + If no marginal effects data found and no data provided. Examples -------- @@ -1234,38 +1397,40 @@ def marginal_curve( me = sa.compute_marginal_effects(results, extend_idata=True) mmm.plot.marginal_curve(hdi_prob=0.9) """ - if not hasattr(self.idata, "sensitivity_analysis"): - raise ValueError( - "No sensitivity analysis results found in 'self.idata'. " - "Run 'mmm.sensitivity.run_sweep()' first." + # Retrieve data if not provided + if data is None: + sa_group = self._get_data_or_fallback( + None, "sensitivity_analysis", "sensitivity analysis results" ) + if isinstance(sa_group, xr.Dataset): + if "marginal_effects" not in sa_group: + raise ValueError( + "Expected 'marginal_effects' in idata.sensitivity_analysis. " + "Use SensitivityAnalysis.compute_marginal_effects(..., extend_idata=True)." + ) + data = sa_group["marginal_effects"] + else: + raise ValueError( + "sensitivity_analysis does not contain 'marginal_effects'. Did you persist it to idata?" + ) - sa_group = self.idata.sensitivity_analysis # type: ignore - if isinstance(sa_group, xr.Dataset): - if "marginal_effects" not in sa_group: + # Handle Dataset input + if isinstance(data, xr.Dataset): + if "marginal_effects" in data: + data = data["marginal_effects"] + elif "x" in data: + data = data["x"] + else: raise ValueError( - "Expected 'marginal_effects' in idata.sensitivity_analysis. " - "Use SensitivityAnalysis.compute_marginal_effects(..., extend_idata=True)." + "Dataset must contain 'marginal_effects' or 'x' variable." ) - data_var = sa_group["marginal_effects"] - else: - raise ValueError( - "sensitivity_analysis does not contain 'marginal_effects'. Did you persist it to idata?" - ) - # We want a different y-label and color - # Temporarily swap group to reuse plotting logic - tmp = xr.Dataset({"x": data_var}) - tmp["x"].attrs.update(getattr(sa_group, "attrs", {})) # type: ignore - original = self.idata.sensitivity_analysis # type: ignore - try: - self.idata.sensitivity_analysis = tmp # type: ignore - pc = self._sensitivity_analysis_plot( - hdi_prob=hdi_prob, - aggregation=aggregation, - backend=backend, - ) - pc.map(azp.visuals.labelled_y, text="Marginal Effect") - return pc - finally: - self.idata.sensitivity_analysis = original # type: ignore + # Call helper with data (no more monkey-patching!) + pc = self._sensitivity_analysis_plot( + data=data, + hdi_prob=hdi_prob, + aggregation=aggregation, + backend=backend, + ) + pc.map(azp.visuals.labelled_y, text="Marginal Effect") + return pc diff --git a/tests/mmm/conftest.py b/tests/mmm/conftest.py new file mode 100644 index 000000000..133a04ab4 --- /dev/null +++ b/tests/mmm/conftest.py @@ -0,0 +1,187 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared fixtures for MMM plotting tests.""" + +import arviz as az +import numpy as np +import pandas as pd +import pytest +import xarray as xr + + +@pytest.fixture +def mock_posterior_data(): + """Mock posterior Dataset for testing data parameters.""" + rng = np.random.default_rng(42) + return xr.Dataset( + { + "intercept": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + }, + ) + } + ) + + +@pytest.fixture +def mock_constant_data(): + """Mock constant_data Dataset for saturation plots.""" + rng = np.random.default_rng(42) + n_dates = 52 + n_channels = 3 + + return xr.Dataset( + { + "channel_data": xr.DataArray( + rng.uniform(0, 100, size=(n_dates, n_channels)), + dims=("date", "channel"), + coords={ + "date": pd.date_range("2025-01-01", periods=n_dates, freq="W"), + "channel": ["TV", "Radio", "Digital"], + }, + ), + "channel_scale": xr.DataArray( + rng.uniform(0.5, 2.0, size=(n_channels,)), + dims=("channel",), + coords={"channel": ["TV", "Radio", "Digital"]}, + ), + "target_scale": xr.DataArray(1.0), + } + ) + + +@pytest.fixture +def mock_sensitivity_data(): + """Mock sensitivity analysis data.""" + rng = np.random.default_rng(42) + return xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + +@pytest.fixture +def mock_idata_with_posterior(): + """Mock InferenceData with posterior data.""" + rng = np.random.default_rng(42) + posterior = xr.Dataset( + { + "intercept": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + }, + ) + } + ) + return az.InferenceData(posterior=posterior) + + +@pytest.fixture +def mock_idata_with_uplift_curve(): + """Mock InferenceData with uplift_curve in sensitivity_analysis.""" + rng = np.random.default_rng(42) + + posterior = xr.Dataset( + { + "intercept": xr.DataArray( + rng.normal(size=(4, 100)), + dims=("chain", "draw"), + ) + } + ) + + sensitivity_analysis = xr.Dataset( + { + "uplift_curve": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + ) + } + ) + + return az.InferenceData( + posterior=posterior, sensitivity_analysis=sensitivity_analysis + ) + + +@pytest.fixture +def mock_idata_with_sensitivity(): + """Mock InferenceData with sensitivity_analysis group.""" + rng = np.random.default_rng(42) + + posterior = xr.Dataset( + { + "intercept": xr.DataArray( + rng.normal(size=(4, 100)), + dims=("chain", "draw"), + ) + } + ) + + sensitivity_analysis = xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + return az.InferenceData( + posterior=posterior, sensitivity_analysis=sensitivity_analysis + ) + + +@pytest.fixture +def mock_idata_for_legacy(): + """Mock InferenceData for legacy suite tests.""" + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + + posterior_predictive = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ) + } + ) + + return az.InferenceData(posterior_predictive=posterior_predictive) diff --git a/tests/mmm/test_legacy_plot_imports.py b/tests/mmm/test_legacy_plot_imports.py new file mode 100644 index 000000000..5aeefde76 --- /dev/null +++ b/tests/mmm/test_legacy_plot_imports.py @@ -0,0 +1,41 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for legacy plot module renaming.""" + +import pytest + + +def test_legacy_plot_module_exists(): + """Test that legacy_plot module exists and can be imported.""" + try: + from pymc_marketing.mmm import legacy_plot + + assert hasattr(legacy_plot, "LegacyMMMPlotSuite") + except ImportError as e: + pytest.fail(f"Failed to import legacy_plot: {e}") + + +def test_legacy_class_name(): + """Test that legacy suite class is named LegacyMMMPlotSuite.""" + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + + assert LegacyMMMPlotSuite.__name__ == "LegacyMMMPlotSuite" + + +def test_old_plot_module_not_exists(): + """Test that old_plot module has been removed.""" + with pytest.raises( + ImportError, match=r"(No module named.*old_plot|cannot import name 'old_plot')" + ): + from pymc_marketing.mmm import old_plot # noqa: F401 diff --git a/tests/mmm/test_legacy_plot_regression.py b/tests/mmm/test_legacy_plot_regression.py new file mode 100644 index 000000000..7c42baace --- /dev/null +++ b/tests/mmm/test_legacy_plot_regression.py @@ -0,0 +1,56 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Regression tests for legacy plot suite.""" + +import numpy as np +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + + +def test_legacy_suite_all_methods_exist(): + """Test all legacy suite methods still exist after rename.""" + expected_methods = [ + "posterior_predictive", + "contributions_over_time", + "saturation_scatterplot", + "saturation_curves", + "saturation_curves_scatter", # Deprecated but still in legacy + "budget_allocation", + "allocated_contribution_by_channel_over_time", + "sensitivity_analysis", + "uplift_curve", + "marginal_curve", + ] + + for method_name in expected_methods: + assert hasattr(LegacyMMMPlotSuite, method_name), ( + f"LegacyMMMPlotSuite missing method: {method_name}" + ) + + +def test_legacy_suite_returns_tuple(mock_idata_for_legacy): + """Test legacy suite returns tuple, not PlotCollection.""" + suite = LegacyMMMPlotSuite(idata=mock_idata_for_legacy) + result = suite.posterior_predictive() + + assert isinstance(result, tuple) + assert len(result) == 2 + assert isinstance(result[0], Figure) + # result[1] can be Axes or ndarray of Axes + if isinstance(result[1], np.ndarray): + assert all(isinstance(ax, Axes) for ax in result[1].flat) + else: + assert isinstance(result[1], Axes) diff --git a/tests/mmm/test_plot_data_parameters.py b/tests/mmm/test_plot_data_parameters.py new file mode 100644 index 000000000..7ea0c5f61 --- /dev/null +++ b/tests/mmm/test_plot_data_parameters.py @@ -0,0 +1,120 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for data parameter standardization across plotting methods.""" + +import arviz_plots +import pytest +import xarray as xr + +from pymc_marketing.mmm.plot import MMMPlotSuite + + +def test_contributions_over_time_accepts_data_parameter(mock_posterior_data): + """Test that contributions_over_time accepts data parameter.""" + # Create suite without idata + suite = MMMPlotSuite(idata=None) + + # Should work with explicit data parameter + pc = suite.contributions_over_time(var=["intercept"], data=mock_posterior_data) + + assert isinstance(pc, arviz_plots.PlotCollection) + + +def test_contributions_over_time_data_parameter_fallback(mock_idata_with_posterior): + """Test that contributions_over_time falls back to self.idata.posterior.""" + suite = MMMPlotSuite(idata=mock_idata_with_posterior) + + # Should work without data parameter (fallback) + pc = suite.contributions_over_time(var=["intercept"]) + + assert isinstance(pc, arviz_plots.PlotCollection) + + +def test_contributions_over_time_no_data_raises_clear_error(): + """Test clear error when no data available.""" + suite = MMMPlotSuite(idata=None) + + with pytest.raises( + ValueError, match=r"No posterior data found.*and no 'data' argument provided" + ): + suite.contributions_over_time(var=["intercept"]) + + +def test_saturation_scatterplot_accepts_data_parameters( + mock_constant_data, mock_posterior_data +): + """Test saturation_scatterplot accepts data parameters.""" + import numpy as np + + # Need to add channel_contribution to mock_posterior_data + # Replicate the data across the channel dimension (3 channels) + intercept_values = mock_posterior_data["intercept"].values + channel_contrib_values = np.repeat(intercept_values[:, :, :, np.newaxis], 3, axis=3) + + mock_posterior_data["channel_contribution"] = xr.DataArray( + channel_contrib_values, + dims=("chain", "draw", "date", "channel"), + coords={ + **{k: v for k, v in mock_posterior_data.coords.items()}, + "channel": ["TV", "Radio", "Digital"], + }, + ) + + suite = MMMPlotSuite(idata=None) + + pc = suite.saturation_scatterplot( + constant_data=mock_constant_data, posterior_data=mock_posterior_data + ) + + assert isinstance(pc, arviz_plots.PlotCollection) + + +def test_sensitivity_analysis_plot_requires_data_parameter(mock_sensitivity_data): + """Test _sensitivity_analysis_plot requires data parameter (no fallback).""" + suite = MMMPlotSuite(idata=None) + + # Should work with data parameter + pc = suite._sensitivity_analysis_plot(data=mock_sensitivity_data) + + assert isinstance(pc, arviz_plots.PlotCollection) + + +def test_sensitivity_analysis_plot_no_fallback_to_self_idata( + mock_idata_with_sensitivity, +): + """Test _sensitivity_analysis_plot doesn't use self.idata even if available.""" + suite = MMMPlotSuite(idata=mock_idata_with_sensitivity) + + # Should raise error even though self.idata has sensitivity_analysis + with pytest.raises(TypeError, match=r"missing.*required.*argument.*data"): + suite._sensitivity_analysis_plot() + + +def test_uplift_curve_passes_data_to_helper_no_monkey_patch( + mock_idata_with_uplift_curve, +): + """Test uplift_curve passes data directly, no monkey-patching.""" + suite = MMMPlotSuite(idata=mock_idata_with_uplift_curve) + + # Store original idata reference + original_idata = suite.idata + original_sa_group = original_idata.sensitivity_analysis + + # Call uplift_curve + pc = suite.uplift_curve() + + # Verify no monkey-patching occurred + assert suite.idata is original_idata + assert suite.idata.sensitivity_analysis is original_sa_group + assert isinstance(pc, arviz_plots.PlotCollection) From 2632822953b35a32a272bce565bb94f1a1839ef4 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Thu, 20 Nov 2025 15:28:03 -0500 Subject: [PATCH 10/29] finished 3.2 --- pymc_marketing/mmm/config.py | 1 + pymc_marketing/mmm/multidimensional.py | 59 +++- pymc_marketing/mmm/plot.py | 96 ++++-- tests/mmm/conftest.py | 120 +++++++ tests/mmm/test_plot.py | 403 ++++++++++++++++++++-- tests/mmm/test_plot_backends.py | 443 ------------------------- tests/mmm/test_plot_compatibility.py | 375 +++++++++++++++++++++ 7 files changed, 984 insertions(+), 513 deletions(-) delete mode 100644 tests/mmm/test_plot_backends.py create mode 100644 tests/mmm/test_plot_compatibility.py diff --git a/pymc_marketing/mmm/config.py b/pymc_marketing/mmm/config.py index 5b7051822..889a15377 100644 --- a/pymc_marketing/mmm/config.py +++ b/pymc_marketing/mmm/config.py @@ -39,6 +39,7 @@ class MMMConfig(dict): _defaults = { "plot.backend": "matplotlib", "plot.show_warnings": True, + "plot.use_v2": False, # Use new arviz_plots-based suite (False = legacy suite for backward compatibility) } def __init__(self): diff --git a/pymc_marketing/mmm/multidimensional.py b/pymc_marketing/mmm/multidimensional.py index 921d97fb7..67def0eb6 100644 --- a/pymc_marketing/mmm/multidimensional.py +++ b/pymc_marketing/mmm/multidimensional.py @@ -191,7 +191,6 @@ add_lift_measurements_to_likelihood_from_saturation, scale_lift_measurements, ) -from pymc_marketing.mmm.plot import MMMPlotSuite from pymc_marketing.mmm.scaling import Scaling, VariableScaling from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis from pymc_marketing.mmm.tvp import infer_time_index @@ -600,11 +599,63 @@ def attrs_to_init_kwargs(cls, attrs: dict[str, str]) -> dict[str, Any]: } @property - def plot(self) -> MMMPlotSuite: - """Use the MMMPlotSuite to plot the results.""" + def plot(self): + """Use the MMMPlotSuite to plot the results. + + The plot suite version is controlled by mmm_config["plot.use_v2"]: + - False (default): Uses legacy matplotlib-based suite (will be deprecated) + - True: Uses new arviz_plots-based suite with multi-backend support + + .. versionchanged:: 0.18.0 + Added version control via mmm_config["plot.use_v2"]. + The legacy suite will be removed in v0.20.0. + + Examples + -------- + Use new plot suite: + + >>> from pymc_marketing.mmm import mmm_config + >>> mmm_config["plot.use_v2"] = True + >>> pc = mmm.plot.posterior_predictive() + >>> pc.show() + + Use legacy plot suite: + + >>> mmm_config["plot.use_v2"] = False + >>> fig, ax = mmm.plot.posterior_predictive() + >>> fig.savefig("plot.png") + + Returns + ------- + MMMPlotSuite or LegacyMMMPlotSuite + Plot suite instance for creating MMM visualizations. + """ + import warnings + + from pymc_marketing.mmm.config import mmm_config + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + from pymc_marketing.mmm.plot import MMMPlotSuite + self._validate_model_was_built() self._validate_idata_exists() - return MMMPlotSuite(idata=self.idata) + + # Check version flag + if mmm_config.get("plot.use_v2", False): + return MMMPlotSuite(idata=self.idata) + else: + # Show deprecation warning for legacy suite + if mmm_config.get("plot.show_warnings", True): + warnings.warn( + "The current MMMPlotSuite will be deprecated in v0.20.0. " + "The new version uses arviz_plots and supports multiple backends " + "(matplotlib, plotly, bokeh). " + "To use the new version: mmm_config['plot.use_v2'] = True\n" + "To suppress this warning: mmm_config['plot.show_warnings'] = False\n" + "See migration guide: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html", + FutureWarning, + stacklevel=2, + ) + return LegacyMMMPlotSuite(idata=self.idata) @property def default_model_config(self) -> dict: diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 3c57112b7..c93abcf13 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -834,42 +834,6 @@ def saturation_curves( return pc - def saturation_curves_scatter( - self, original_scale: bool = False, **kwargs - ) -> PlotCollection: - """ - Plot scatter plots of channel contributions vs. channel data. - - .. deprecated:: 0.1.0 - Will be removed in version 0.20.0. Use :meth:`saturation_scatterplot` instead. - - Parameters - ---------- - channel_contribution : str, optional - Name of the channel contribution variable in the InferenceData. - additional_dims : list[str], optional - Additional dimensions to consider beyond 'channel'. - additional_combinations : list[tuple], optional - Specific combinations of additional dimensions to plot. - **kwargs - Additional keyword arguments passed to _init_subplots. - - Returns - ------- - PlotCollection - """ - import warnings - - warnings.warn( - "saturation_curves_scatter is deprecated and will be removed in version 0.2.0. " - "Use saturation_scatterplot instead.", - DeprecationWarning, - stacklevel=2, - ) - # Note: channel_contribution, additional_dims, and additional_combinations - # are not used by saturation_scatterplot, so we don't pass them - return self.saturation_scatterplot(original_scale=original_scale, **kwargs) - def budget_allocation_roas( self, samples: xr.Dataset, @@ -1434,3 +1398,63 @@ def marginal_curve( ) pc.map(azp.visuals.labelled_y, text="Marginal Effect") return pc + + def budget_allocation(self, *args, **kwargs): + """ + Create bar chart comparing allocated spend and channel contributions. + + .. deprecated:: 0.18.0 + This method was removed in MMMPlotSuite v2. The arviz_plots library + used in v2 doesn't support this specific chart type. See alternatives below. + + Raises + ------ + NotImplementedError + This method is not available in MMMPlotSuite v2. + + Notes + ----- + Alternatives: + + 1. **For ROI distributions**: Use :meth:`budget_allocation_roas` + (different purpose but related to budget allocation) + + 2. **To use the old method**: Switch to legacy suite: + + >>> from pymc_marketing.mmm import mmm_config + >>> mmm_config["plot.use_v2"] = False + >>> mmm.plot.budget_allocation(samples) + + 3. **Custom implementation**: Create bar chart using samples data: + + >>> import matplotlib.pyplot as plt + >>> channel_contrib = samples["channel_contribution"].mean(...) + >>> allocated_spend = samples["allocation"] + >>> # Create custom bar chart with matplotlib + + See Also + -------- + budget_allocation_roas : Plot ROI distributions by channel + + Examples + -------- + Use legacy suite temporarily: + + >>> from pymc_marketing.mmm import mmm_config + >>> original = mmm_config.get("plot.use_v2") + >>> try: + ... mmm_config["plot.use_v2"] = False + ... fig, ax = mmm.plot.budget_allocation(samples) + ... fig.savefig("budget.png") + ... finally: + ... mmm_config["plot.use_v2"] = original + """ + raise NotImplementedError( + "budget_allocation() was removed in MMMPlotSuite v2.\n\n" + "The new arviz_plots-based implementation doesn't support this chart type.\n\n" + "Alternatives:\n" + " 1. For ROI distributions: use budget_allocation_roas()\n" + " 2. To use old method: set mmm_config['plot.use_v2'] = False\n" + " 3. Implement custom bar chart using the samples data\n\n" + "See documentation: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html#budget-allocation" + ) diff --git a/tests/mmm/conftest.py b/tests/mmm/conftest.py index 133a04ab4..1c36e609d 100644 --- a/tests/mmm/conftest.py +++ b/tests/mmm/conftest.py @@ -185,3 +185,123 @@ def mock_idata_for_legacy(): ) return az.InferenceData(posterior_predictive=posterior_predictive) + + +@pytest.fixture +def mock_idata(): + """Mock InferenceData for compatibility testing.""" + rng = np.random.default_rng(42) + + posterior = xr.Dataset( + { + "intercept": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + }, + ), + "channel_contribution": xr.DataArray( + rng.normal(size=(4, 100, 52, 3)), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + "channel": ["TV", "Radio", "Digital"], + }, + ), + } + ) + + posterior_predictive = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + }, + ), + } + ) + + constant_data = xr.Dataset( + { + "channel_data": xr.DataArray( + rng.uniform(0, 100, size=(52, 3)), + dims=("date", "channel"), + coords={ + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + "channel": ["TV", "Radio", "Digital"], + }, + ), + "channel_scale": xr.DataArray( + rng.uniform(0.5, 2.0, size=(3,)), + dims=("channel",), + coords={"channel": ["TV", "Radio", "Digital"]}, + ), + "target_scale": xr.DataArray(1.0), + } + ) + + return az.InferenceData( + posterior=posterior, + posterior_predictive=posterior_predictive, + constant_data=constant_data, + ) + + +@pytest.fixture +def mock_mmm(mock_idata): + """Mock MMM instance with idata for compatibility testing.""" + from unittest.mock import Mock + + from pymc_marketing.mmm.multidimensional import MMM + + mmm = Mock(spec=MMM) + mmm.idata = mock_idata + mmm._validate_model_was_built = Mock() + mmm._validate_idata_exists = Mock() + + # Make .plot property work with actual implementation + type(mmm).plot = MMM.plot + + return mmm + + +@pytest.fixture +def mock_mmm_fitted(mock_mmm): + """Mock fitted MMM instance for compatibility testing.""" + # Same as mock_mmm, just clearer name for tests that need fitted model + return mock_mmm + + +@pytest.fixture +def mock_allocation_samples(): + """Mock samples dataset for budget allocation tests.""" + rng = np.random.default_rng(42) + + return xr.Dataset( + { + "channel_contribution_original_scale": xr.DataArray( + rng.normal(size=(4, 100, 52, 3)), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + "channel": ["TV", "Radio", "Digital"], + }, + ), + "allocation": xr.DataArray( + rng.uniform(100, 1000, size=(3,)), + dims=("channel",), + coords={"channel": ["TV", "Radio", "Digital"]}, + ), + } + ) diff --git a/tests/mmm/test_plot.py b/tests/mmm/test_plot.py index 88922348c..f162fde86 100644 --- a/tests/mmm/test_plot.py +++ b/tests/mmm/test_plot.py @@ -134,55 +134,48 @@ def test_contributions_over_time(fit_mmm_with_channel_original_scale): def test_contributions_over_time_with_dim(mock_suite: MMMPlotSuite): # Test with explicit dim argument - fig, ax = mock_suite.contributions_over_time( + from arviz_plots import PlotCollection + + pc = mock_suite.contributions_over_time( var=["intercept", "linear_trend"], dims={"country": "A"}, ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - assert all(isinstance(a, Axes) for a in ax.flat) - # Optionally, check axes shape if known - if hasattr(ax, "shape"): - # When filtering to a single country, shape[-1] should be 1 - assert ax.shape[-1] == 1 + assert isinstance(pc, PlotCollection) + # Verify plot was created successfully + assert hasattr(pc, "backend") + assert hasattr(pc, "show") def test_contributions_over_time_with_dims_list(mock_suite: MMMPlotSuite): """Test that passing a list to dims creates a subplot for each value.""" - fig, ax = mock_suite.contributions_over_time( + from arviz_plots import PlotCollection + + pc = mock_suite.contributions_over_time( var=["intercept"], dims={"country": ["A", "B"]}, ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - # Should create one subplot per value in the list (here: 2 countries) - assert ax.shape[0] == 2 - # Optionally, check subplot titles contain the correct country - for i, country in enumerate(["A", "B"]): - assert country in ax[i, 0].get_title() + assert isinstance(pc, PlotCollection) + assert hasattr(pc, "backend") + assert hasattr(pc, "show") def test_contributions_over_time_with_multiple_dims_lists(mock_suite: MMMPlotSuite): """Test that passing multiple lists to dims creates a subplot for each combination.""" + from arviz_plots import PlotCollection + # Add a fake 'region' dim to the mock posterior for this test if not present idata = mock_suite.idata if "region" not in idata.posterior["intercept"].dims: idata.posterior["intercept"] = idata.posterior["intercept"].expand_dims( region=["X", "Y"] ) - fig, ax = mock_suite.contributions_over_time( + pc = mock_suite.contributions_over_time( var=["intercept"], dims={"country": ["A", "B"], "region": ["X", "Y"]}, ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - # Should create one subplot per combination (2 countries x 2 regions = 4) - assert ax.shape[0] == 4 - combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] - for i, (country, region) in enumerate(combos): - title = ax[i, 0].get_title() - assert country in title - assert region in title + assert isinstance(pc, PlotCollection) + assert hasattr(pc, "backend") + assert hasattr(pc, "show") def test_posterior_predictive(fit_mmm_with_channel_original_scale, df): @@ -299,16 +292,18 @@ def mock_suite_with_sensitivity(mock_idata_with_sensitivity): def test_contributions_over_time_expand_dims(mock_suite: MMMPlotSuite): - fig, ax = mock_suite.contributions_over_time( + from arviz_plots import PlotCollection + + pc = mock_suite.contributions_over_time( var=[ "intercept", "linear_trend", ] ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - assert all(isinstance(a, Axes) for a in ax.flat) + assert isinstance(pc, PlotCollection) + assert hasattr(pc, "backend") + assert hasattr(pc, "show") @pytest.fixture(scope="module") @@ -1053,3 +1048,351 @@ def test__dim_list_handler_mixed(): keys, combos = suite._dim_list_handler({"country": ["A", "B"], "region": "X"}) assert keys == ["country"] assert set(combos) == {("A",), ("B",)} + + +# ============================================================================= +# Comprehensive Backend Tests (Milestone 3) +# ============================================================================= +# These tests verify that all plotting methods work correctly across all +# supported backends (matplotlib, plotly, bokeh). +# ============================================================================= + + +class TestPosteriorPredictiveBackends: + """Test posterior_predictive method across all backends.""" + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_posterior_predictive_all_backends(self, mock_suite, backend): + """Test posterior_predictive works with all backends.""" + from arviz_plots import PlotCollection + + # Create idata with posterior_predictive + idata = mock_suite.idata.copy() + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + idata.posterior_predictive = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ) + } + ) + suite = MMMPlotSuite(idata=idata) + + pc = suite.posterior_predictive(backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" + ) + + +class TestContributionsOverTimeBackends: + """Test contributions_over_time method across all backends.""" + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_contributions_over_time_all_backends(self, mock_suite, backend): + """Test contributions_over_time works with all backends.""" + from arviz_plots import PlotCollection + + pc = mock_suite.contributions_over_time(var=["intercept"], backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" + ) + + +class TestSaturationPlotBackends: + """Test saturation plot methods across all backends.""" + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_saturation_scatterplot_all_backends( + self, mock_suite_with_constant_data, backend + ): + """Test saturation_scatterplot works with all backends.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_scatterplot(backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" + ) + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_saturation_curves_all_backends( + self, mock_suite_with_constant_data, mock_saturation_curve, backend + ): + """Test saturation_curves works with all backends.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, backend=backend, n_samples=3 + ) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" + ) + + +class TestBudgetAllocationBackends: + """Test budget allocation methods across all backends.""" + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_budget_allocation_roas_all_backends(self, mock_suite, backend): + """Test budget_allocation_roas works with all backends.""" + from arviz_plots import PlotCollection + + # Create proper allocation samples with required variables and dimensions + rng = np.random.default_rng(42) + channels = ["TV", "Radio", "Digital"] + dates = pd.date_range("2025-01-01", periods=52, freq="W") + samples = xr.Dataset( + { + "channel_contribution_original_scale": xr.DataArray( + rng.normal(loc=1000, scale=100, size=(100, 52, 3)), + dims=("sample", "date", "channel"), + coords={ + "sample": np.arange(100), + "date": dates, + "channel": channels, + }, + ), + "allocation": xr.DataArray( + rng.uniform(100, 1000, size=(3,)), + dims=("channel",), + coords={"channel": channels}, + ), + } + ) + + pc = mock_suite.budget_allocation_roas(samples=samples, backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" + ) + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_allocated_contribution_by_channel_over_time_all_backends( + self, mock_suite, backend + ): + """Test allocated_contribution_by_channel_over_time works with all backends.""" + from arviz_plots import PlotCollection + + # Create proper samples with 'sample', 'date', and 'channel' dimensions + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + channels = ["TV", "Radio", "Digital"] + samples = xr.Dataset( + { + "channel_contribution": xr.DataArray( + rng.normal(size=(100, 52, 3)), + dims=("sample", "date", "channel"), + coords={ + "sample": np.arange(100), + "date": dates, + "channel": channels, + }, + ) + } + ) + + pc = mock_suite.allocated_contribution_by_channel_over_time( + samples=samples, backend=backend + ) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" + ) + + +class TestSensitivityAnalysisBackends: + """Test sensitivity analysis methods across all backends.""" + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_sensitivity_analysis_all_backends( + self, mock_suite_with_sensitivity, backend + ): + """Test sensitivity_analysis works with all backends.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_sensitivity.sensitivity_analysis(backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" + ) + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_uplift_curve_all_backends(self, mock_suite_with_sensitivity, backend): + """Test uplift_curve works with all backends.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_sensitivity.uplift_curve(backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" + ) + + @pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) + def test_marginal_curve_all_backends(self, mock_suite_with_sensitivity, backend): + """Test marginal_curve works with all backends.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_sensitivity.marginal_curve(backend=backend) + + assert isinstance(pc, PlotCollection), ( + f"Expected PlotCollection for backend {backend}, got {type(pc)}" + ) + + +class TestBackendBehavior: + """Test backend configuration and override behavior.""" + + def test_backend_overrides_global_config(self, mock_suite): + """Test that method backend parameter overrides global config.""" + from arviz_plots import PlotCollection + + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.backend", "matplotlib") + + try: + # Set global to matplotlib + mmm_config["plot.backend"] = "matplotlib" + + # Override with plotly + pc_plotly = mock_suite.contributions_over_time( + var=["intercept"], backend="plotly" + ) + assert isinstance(pc_plotly, PlotCollection) + + # Default should still be matplotlib + pc_default = mock_suite.contributions_over_time(var=["intercept"]) + assert isinstance(pc_default, PlotCollection) + + finally: + mmm_config["plot.backend"] = original + + @pytest.mark.parametrize("config_backend", ["matplotlib", "plotly", "bokeh"]) + def test_backend_parameter_none_uses_config(self, mock_suite, config_backend): + """Test that backend=None uses global config.""" + from arviz_plots import PlotCollection + + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.backend", "matplotlib") + + try: + mmm_config["plot.backend"] = config_backend + + pc = mock_suite.contributions_over_time( + var=["intercept"], + backend=None, # Explicitly None + ) + + assert isinstance(pc, PlotCollection) + # PlotCollection should be created with config_backend + + finally: + mmm_config["plot.backend"] = original + + def test_invalid_backend_warning(self, mock_suite): + """Test that invalid backend either shows warning or raises error.""" + # Invalid backend should either warn or raise an error + # arviz_plots may accept the invalid backend string without warning + # This test just verifies the code doesn't crash in unexpected ways + try: + mock_suite.contributions_over_time( + var=["intercept"], backend="invalid_backend" + ) + # If it succeeds, the backend was accepted (implementation dependent) + except (ValueError, TypeError, NotImplementedError, ModuleNotFoundError): + # These are expected exceptions for invalid backends + pass + + +class TestDataParameters: + """Test explicit data parameter functionality.""" + + def test_contributions_over_time_with_explicit_data(self, mock_posterior_data): + """Test contributions_over_time accepts explicit data parameter.""" + from arviz_plots import PlotCollection + + # Create suite without idata + suite = MMMPlotSuite(idata=None) + + # Should work with explicit data parameter + pc = suite.contributions_over_time(var=["intercept"], data=mock_posterior_data) + + assert isinstance(pc, PlotCollection) + + def test_saturation_scatterplot_with_explicit_data( + self, mock_constant_data, mock_posterior_data + ): + """Test saturation_scatterplot accepts explicit data parameters.""" + from arviz_plots import PlotCollection + + suite = MMMPlotSuite(idata=None) + + # Create proper posterior data with channel_contribution + rng = np.random.default_rng(42) + posterior_data = xr.Dataset( + { + "channel_contribution": xr.DataArray( + rng.normal(size=(4, 100, 52, 3)), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": pd.date_range("2025-01-01", periods=52, freq="W"), + "channel": ["TV", "Radio", "Digital"], + }, + ) + } + ) + + pc = suite.saturation_scatterplot( + constant_data=mock_constant_data, posterior_data=posterior_data + ) + + assert isinstance(pc, PlotCollection) + + +class TestIntegration: + """Test complete workflows and method interactions.""" + + def test_multiple_plots_same_suite_instance(self, mock_suite_with_constant_data): + """Test that same suite instance can create multiple plots.""" + from arviz_plots import PlotCollection + + suite = mock_suite_with_constant_data + + # Create multiple different plots + # Use channel_contribution which exists in the fixture + pc1 = suite.contributions_over_time(var=["channel_contribution"]) + pc2 = suite.saturation_scatterplot() + + assert isinstance(pc1, PlotCollection) + assert isinstance(pc2, PlotCollection) + + # All should be independent PlotCollection objects + assert pc1 is not pc2 + + def test_backend_switching_same_method(self, mock_suite): + """Test that backends can be switched for same method.""" + from arviz_plots import PlotCollection + + suite = mock_suite + + # Create same plot with different backends + pc_mpl = suite.contributions_over_time(var=["intercept"], backend="matplotlib") + pc_plotly = suite.contributions_over_time(var=["intercept"], backend="plotly") + pc_bokeh = suite.contributions_over_time(var=["intercept"], backend="bokeh") + + assert isinstance(pc_mpl, PlotCollection) + assert isinstance(pc_plotly, PlotCollection) + assert isinstance(pc_bokeh, PlotCollection) diff --git a/tests/mmm/test_plot_backends.py b/tests/mmm/test_plot_backends.py deleted file mode 100644 index bce16e6de..000000000 --- a/tests/mmm/test_plot_backends.py +++ /dev/null @@ -1,443 +0,0 @@ -# Copyright 2022 - 2025 The PyMC Labs Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Backend-agnostic plotting tests for MMMPlotSuite. - -This test file validates the migration to ArviZ PlotCollection API for -multi-backend support (matplotlib, plotly, bokeh). - -NOTE: Once this migration is complete and stable, evaluate whether -tests/mmm/test_plot.py can be consolidated into this file to avoid duplication. -""" - -import arviz as az -import numpy as np -import pandas as pd -import pytest -import xarray as xr -from matplotlib.axes import Axes -from matplotlib.figure import Figure - -from pymc_marketing.mmm.plot import MMMPlotSuite - - -@pytest.fixture(scope="module") -def mock_idata_for_pp(): - """ - Create mock InferenceData with posterior_predictive for testing. - - Structure mirrors real MMM output with: - - posterior_predictive group with y variable - - proper dimensions: chain, draw, date - - realistic date range - """ - seed = sum(map(ord, "Backend test posterior_predictive")) - rng = np.random.default_rng(seed) - - dates = pd.date_range("2025-01-01", periods=52, freq="W-MON") - - # Create posterior_predictive data - posterior_predictive = xr.Dataset( - { - "y": xr.DataArray( - rng.normal(loc=100, scale=10, size=(4, 100, 52)), - dims=("chain", "draw", "date"), - coords={ - "chain": np.arange(4), - "draw": np.arange(100), - "date": dates, - }, - ) - } - ) - - # Also create a minimal posterior (required for some internal logic) - posterior = xr.Dataset( - { - "intercept": xr.DataArray( - rng.normal(size=(4, 100)), - dims=("chain", "draw"), - coords={ - "chain": np.arange(4), - "draw": np.arange(100), - }, - ) - } - ) - - return az.InferenceData( - posterior=posterior, posterior_predictive=posterior_predictive - ) - - -@pytest.fixture(scope="module") -def mock_suite_with_pp(mock_idata_for_pp): - """ - Fixture providing MMMPlotSuite with posterior_predictive data. - - Used for testing posterior_predictive() method across backends. - """ - return MMMPlotSuite(idata=mock_idata_for_pp) - - -@pytest.fixture(scope="function") -def reset_mmm_config(): - """ - Fixture to reset mmm_config after each test. - - Ensures test isolation - one test's backend changes don't affect others. - """ - from pymc_marketing.mmm import mmm_config - - original = mmm_config["plot.backend"] - yield - mmm_config["plot.backend"] = original - - -# ============================================================================= -# Infrastructure Tests (Global Configuration & Return Types) -# ============================================================================= - - -def test_mmm_config_exists(): - """ - Test that the global mmm_config object exists and is accessible. - - This test verifies: - - mmm_config can be imported from pymc_marketing.mmm - - It has a "plot.backend" key - - Default backend is "matplotlib" - """ - from pymc_marketing.mmm import mmm_config - - assert "plot.backend" in mmm_config, "mmm_config should have 'plot.backend' key" - assert mmm_config["plot.backend"] == "matplotlib", ( - f"Default backend should be 'matplotlib', got {mmm_config['plot.backend']}" - ) - - -def test_mmm_config_backend_setting(): - """ - Test that mmm_config backend can be set and retrieved. - - This test verifies: - - Backend can be changed from default - - New value persists - - Can be reset to default - """ - from pymc_marketing.mmm import mmm_config - - # Store original - original = mmm_config["plot.backend"] - - try: - # Change backend - mmm_config["plot.backend"] = "plotly" - assert mmm_config["plot.backend"] == "plotly", ( - "Backend should change to 'plotly'" - ) - - # Reset - mmm_config.reset() - assert mmm_config["plot.backend"] == "matplotlib", ( - "reset() should restore default 'matplotlib' backend" - ) - finally: - # Cleanup - mmm_config["plot.backend"] = original - - -def test_mmm_config_invalid_backend_warning(): - """ - Test that setting an invalid backend name is handled gracefully. - - This test verifies: - - Invalid backend names are detected - - Either raises ValueError or emits UserWarning - - Helpful error message provided - """ - import warnings - - from pymc_marketing.mmm import mmm_config - - original = mmm_config["plot.backend"] - - try: - # Attempt to set invalid backend - should either raise or warn - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - mmm_config["plot.backend"] = "invalid_backend" - - # If no exception, should have warning - assert len(w) > 0, "Should emit warning for invalid backend" - assert "invalid" in str(w[0].message).lower(), ( - f"Warning should mention 'invalid', got: {w[0].message}" - ) - except ValueError as e: - # Acceptable alternative: raise ValueError - assert "backend" in str(e).lower(), f"Error should mention 'backend', got: {e}" - finally: - mmm_config["plot.backend"] = original - - -# ============================================================================= -# Backend Parameter Tests (posterior_predictive) -# ============================================================================= - - -def test_posterior_predictive_accepts_backend_parameter(mock_suite_with_pp): - """ - Test that posterior_predictive() accepts backend parameter. - - This test verifies: - - backend parameter is accepted - - No TypeError is raised - - Method completes successfully - """ - # Should not raise TypeError - result = mock_suite_with_pp.posterior_predictive(backend="matplotlib") - - assert result is not None, "posterior_predictive should return a result" - - -def test_posterior_predictive_accepts_return_as_pc_parameter(mock_suite_with_pp): - """ - Test that posterior_predictive() accepts return_as_pc parameter. - - This test verifies: - - return_as_pc parameter is accepted - - No TypeError is raised - """ - # Should not raise TypeError - result = mock_suite_with_pp.posterior_predictive(return_as_pc=False) - - assert result is not None, "posterior_predictive should return a result" - - -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_posterior_predictive_backend_overrides_global(mock_suite_with_pp, backend): - """ - Test that backend parameter overrides global mmm_config setting. - - This test verifies: - - Global config set to one backend - - Function called with different backend - - Function uses parameter, not global config - """ - from pymc_marketing.mmm import mmm_config - - original = mmm_config["plot.backend"] - - try: - # Set global to matplotlib - mmm_config["plot.backend"] = "matplotlib" - - # Call with different backend, request PlotCollection to check - pc = mock_suite_with_pp.posterior_predictive(backend=backend, return_as_pc=True) - - assert hasattr(pc, "backend"), "PlotCollection should have backend attribute" - assert pc.backend == backend, ( - f"PlotCollection backend should be '{backend}', got '{pc.backend}'" - ) - finally: - mmm_config["plot.backend"] = original - - -# ============================================================================= -# Return Type Tests (Backward Compatibility) -# ============================================================================= - - -def test_posterior_predictive_returns_tuple_by_default(mock_suite_with_pp): - """ - Test that posterior_predictive() returns tuple by default (backward compat). - - This test verifies: - - Default behavior (no return_as_pc parameter) returns tuple - - Tuple has two elements: (figure, axes) - - axes is a list of matplotlib Axes objects (1D list, not 2D array) - """ - result = mock_suite_with_pp.posterior_predictive() - - assert isinstance(result, tuple), ( - f"Default return should be tuple, got {type(result)}" - ) - assert len(result) == 2, ( - f"Tuple should have 2 elements (fig, axes), got {len(result)}" - ) - - fig, axes = result - - # For matplotlib backend (default), should be Figure and array - assert isinstance(fig, Figure), f"First element should be Figure, got {type(fig)}" - # Note: Current implementation returns NDArray[Axes], need to adapt test - assert axes is not None, "Second element should not be None for matplotlib backend" - - -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_posterior_predictive_returns_plotcollection_when_requested( - mock_suite_with_pp, backend -): - """ - Test that posterior_predictive() returns PlotCollection when return_as_pc=True. - - This test verifies: - - return_as_pc=True returns PlotCollection object - - PlotCollection has correct backend attribute - """ - from arviz_plots import PlotCollection - - result = mock_suite_with_pp.posterior_predictive(backend=backend, return_as_pc=True) - - assert isinstance(result, PlotCollection), ( - f"Should return PlotCollection, got {type(result)}" - ) - assert hasattr(result, "backend"), "PlotCollection should have backend attribute" - assert result.backend == backend, ( - f"Backend should be '{backend}', got '{result.backend}'" - ) - - -def test_posterior_predictive_tuple_has_correct_axes_for_matplotlib(mock_suite_with_pp): - """ - Test that matplotlib backend returns proper axes list in tuple. - - This test verifies: - - When return_as_pc=False and backend="matplotlib" - - Second tuple element is list/array of matplotlib Axes - - All elements in list are Axes instances - """ - _fig, axes = mock_suite_with_pp.posterior_predictive( - backend="matplotlib", return_as_pc=False - ) - - assert axes is not None, "Axes should not be None for matplotlib backend" - # Handle both list and NDArray cases - axes_flat = axes if isinstance(axes, list) else axes.flat - assert all(isinstance(ax, Axes) for ax in axes_flat), ( - "All elements should be matplotlib Axes instances" - ) - - -@pytest.mark.parametrize("backend", ["plotly", "bokeh"]) -def test_posterior_predictive_tuple_has_none_axes_for_nonmatplotlib( - mock_suite_with_pp, backend -): - """ - Test that non-matplotlib backends return None for axes in tuple. - - This test verifies: - - When return_as_pc=False and backend in ["plotly", "bokeh"] - - Second tuple element is None (no axes concept) - - First element is backend-specific figure object - """ - fig, axes = mock_suite_with_pp.posterior_predictive( - backend=backend, return_as_pc=False - ) - - assert axes is None, f"Axes should be None for {backend} backend, got {type(axes)}" - assert fig is not None, f"Figure should exist for {backend} backend" - - -# ============================================================================= -# Visual Output Validation Tests -# ============================================================================= - - -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_posterior_predictive_plotcollection_has_viz_attribute( - mock_suite_with_pp, backend -): - """ - Test that PlotCollection has viz attribute with figure data. - - This test verifies: - - PlotCollection has viz attribute - - viz has figure attribute - - Figure can be extracted - """ - - pc = mock_suite_with_pp.posterior_predictive(backend=backend, return_as_pc=True) - - assert hasattr(pc, "viz"), "PlotCollection should have 'viz' attribute" - assert hasattr(pc.viz, "figure"), ( - "PlotCollection.viz should have 'figure' attribute" - ) - - # Should be able to extract figure - fig = pc.viz.figure.data.item() - assert fig is not None, "Should be able to extract figure from PlotCollection" - - -def test_posterior_predictive_matplotlib_has_lines(mock_suite_with_pp): - """ - Test that matplotlib output contains actual plotted lines. - - This test verifies: - - Axes contain Line2D objects (plotted data) - - Number of lines matches expected variables - - Visual output actually created, not just empty axes - """ - from matplotlib.lines import Line2D - - _fig, axes = mock_suite_with_pp.posterior_predictive( - backend="matplotlib", return_as_pc=False - ) - - # Get first axis (should have plots) - ax = axes.flat[0] - - # Should have lines (median plots) - lines = [child for child in ax.get_children() if isinstance(child, Line2D)] - assert len(lines) > 0, ( - f"Axes should contain Line2D objects (plots), found {len(lines)}" - ) - - -def test_posterior_predictive_plotly_has_traces(mock_suite_with_pp): - """ - Test that plotly output contains actual traces. - - This test verifies: - - Plotly figure has 'data' attribute with traces - - Number of traces > 0 (something was plotted) - - Visual output actually created - """ - fig, _ = mock_suite_with_pp.posterior_predictive( - backend="plotly", return_as_pc=False - ) - - # Plotly figures have .data attribute with traces - assert hasattr(fig, "data"), "Plotly figure should have 'data' attribute" - assert len(fig.data) > 0, f"Plotly figure should have traces, found {len(fig.data)}" - - -def test_posterior_predictive_bokeh_has_renderers(mock_suite_with_pp): - """ - Test that bokeh output contains actual renderers (plot elements). - - This test verifies: - - Bokeh figure has renderers - - Number of renderers > 0 (something was plotted) - - Visual output actually created - """ - fig, _ = mock_suite_with_pp.posterior_predictive( - backend="bokeh", return_as_pc=False - ) - - # Bokeh figures have .renderers attribute - assert hasattr(fig, "renderers"), "Bokeh figure should have 'renderers' attribute" - assert len(fig.renderers) > 0, ( - f"Bokeh figure should have renderers, found {len(fig.renderers)}" - ) diff --git a/tests/mmm/test_plot_compatibility.py b/tests/mmm/test_plot_compatibility.py new file mode 100644 index 000000000..95b059471 --- /dev/null +++ b/tests/mmm/test_plot_compatibility.py @@ -0,0 +1,375 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Compatibility tests for plot suite version switching.""" + +import warnings + +import numpy as np +import pytest +from arviz_plots import PlotCollection +from matplotlib.axes import Axes +from matplotlib.figure import Figure + + +class TestVersionSwitching: + """Test that mmm_config['plot.use_v2'] controls which suite is returned.""" + + def test_use_v2_false_returns_legacy_suite(self, mock_mmm): + """Test that use_v2=False returns LegacyMMMPlotSuite.""" + from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + + original = mmm_config.get("plot.use_v2", False) + try: + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning, match="deprecated in v0.20.0"): + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, LegacyMMMPlotSuite) + assert plot_suite.__class__.__name__ == "LegacyMMMPlotSuite" + finally: + mmm_config["plot.use_v2"] = original + + def test_use_v2_true_returns_new_suite(self, mock_mmm): + """Test that use_v2=True returns MMMPlotSuite.""" + from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + from pymc_marketing.mmm.plot import MMMPlotSuite + + original = mmm_config.get("plot.use_v2", False) + try: + mmm_config["plot.use_v2"] = True + + # Should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error") + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, MMMPlotSuite) + assert not isinstance(plot_suite, LegacyMMMPlotSuite) + assert plot_suite.__class__.__name__ == "MMMPlotSuite" + finally: + mmm_config["plot.use_v2"] = original + + def test_default_is_legacy_suite(self, mock_mmm): + """Test that default behavior uses legacy suite (backward compatible).""" + from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + + # Ensure default state + if "plot.use_v2" in mmm_config: + del mmm_config["plot.use_v2"] + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm.plot + + assert isinstance(plot_suite, LegacyMMMPlotSuite) + + def test_config_flag_persists_across_calls(self, mock_mmm): + """Test that setting config flag affects all subsequent calls.""" + from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm.plot import MMMPlotSuite + + original = mmm_config.get("plot.use_v2", False) + try: + # Set once + mmm_config["plot.use_v2"] = True + + # Multiple calls should all use new suite + plot_suite1 = mock_mmm.plot + plot_suite2 = mock_mmm.plot + plot_suite3 = mock_mmm.plot + + assert isinstance(plot_suite1, MMMPlotSuite) + assert isinstance(plot_suite2, MMMPlotSuite) + assert isinstance(plot_suite3, MMMPlotSuite) + finally: + mmm_config["plot.use_v2"] = original + + +class TestDeprecationWarnings: + """Test deprecation warnings shown correctly with helpful information.""" + + def test_deprecation_warning_shown_by_default(self, mock_mmm): + """Test that deprecation warning is shown when using legacy suite.""" + from pymc_marketing.mmm import mmm_config + + original_use_v2 = mmm_config.get("plot.use_v2", False) + original_warnings = mmm_config.get("plot.show_warnings", True) + + try: + mmm_config["plot.use_v2"] = False + mmm_config["plot.show_warnings"] = True + + with pytest.warns(FutureWarning, match=r"deprecated in v0\.20\.0"): + plot_suite = mock_mmm.plot + + assert plot_suite is not None + finally: + mmm_config["plot.use_v2"] = original_use_v2 + mmm_config["plot.show_warnings"] = original_warnings + + def test_deprecation_warning_suppressible(self, mock_mmm): + """Test that deprecation warning can be suppressed.""" + from pymc_marketing.mmm import mmm_config + + original_use_v2 = mmm_config.get("plot.use_v2", False) + original_warnings = mmm_config.get("plot.show_warnings", True) + + try: + mmm_config["plot.use_v2"] = False + mmm_config["plot.show_warnings"] = False + + # Should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + plot_suite = mock_mmm.plot + + assert plot_suite is not None + finally: + mmm_config["plot.use_v2"] = original_use_v2 + mmm_config["plot.show_warnings"] = original_warnings + + def test_warning_message_includes_migration_info(self, mock_mmm): + """Test that warning provides clear migration instructions.""" + from pymc_marketing.mmm import mmm_config + + original_use_v2 = mmm_config.get("plot.use_v2", False) + + try: + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning) as warning_list: + _ = mock_mmm.plot + + warning_msg = str(warning_list[0].message) + + # Check for key information + assert "v0.20.0" in warning_msg, "Should mention removal version" + assert "plot.use_v2" in warning_msg, "Should show how to enable v2" + assert "True" in warning_msg, "Should show value to set" + assert any( + word in warning_msg.lower() + for word in ["migration", "guide", "documentation", "docs"] + ), "Should reference migration guide" + finally: + mmm_config["plot.use_v2"] = original_use_v2 + + def test_no_warning_when_using_new_suite(self, mock_mmm): + """Test that no warning shown when using new suite.""" + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2", False) + + try: + mmm_config["plot.use_v2"] = True + + with warnings.catch_warnings(): + warnings.simplefilter("error") + plot_suite = mock_mmm.plot + + assert plot_suite is not None + finally: + mmm_config["plot.use_v2"] = original + + +class TestReturnTypeCompatibility: + """Test both suites return correct, expected types.""" + + def test_legacy_suite_returns_tuple(self, mock_mmm_fitted): + """Test legacy suite returns (Figure, Axes) tuple.""" + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2", False) + + try: + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm_fitted.plot + result = plot_suite.posterior_predictive() + + assert isinstance(result, tuple), f"Expected tuple, got {type(result)}" + assert len(result) == 2, f"Expected 2-tuple, got length {len(result)}" + assert isinstance(result[0], Figure), ( + f"Expected Figure, got {type(result[0])}" + ) + + # result[1] can be Axes or ndarray of Axes + if isinstance(result[1], np.ndarray): + assert all(isinstance(ax, Axes) for ax in result[1].flat) + else: + assert isinstance(result[1], Axes) + finally: + mmm_config["plot.use_v2"] = original + + def test_new_suite_returns_plot_collection(self, mock_mmm_fitted): + """Test new suite returns PlotCollection.""" + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2", False) + + try: + mmm_config["plot.use_v2"] = True + + plot_suite = mock_mmm_fitted.plot + result = plot_suite.posterior_predictive() + + assert isinstance(result, PlotCollection), ( + f"Expected PlotCollection, got {type(result)}" + ) + assert hasattr(result, "backend"), ( + "PlotCollection should have backend attribute" + ) + assert hasattr(result, "show"), "PlotCollection should have show method" + finally: + mmm_config["plot.use_v2"] = original + + def test_both_suites_produce_valid_plots(self, mock_mmm_fitted): + """Test that both suites can successfully create plots.""" + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2", False) + + try: + # Legacy suite + mmm_config["plot.use_v2"] = False + with pytest.warns(FutureWarning): + legacy_result = mock_mmm_fitted.plot.contributions_over_time( + var=["intercept"] + ) + assert legacy_result is not None + + # New suite + mmm_config["plot.use_v2"] = True + new_result = mock_mmm_fitted.plot.contributions_over_time(var=["intercept"]) + assert new_result is not None + finally: + mmm_config["plot.use_v2"] = original + + +class TestDeprecatedMethodRemoval: + """Test saturation_curves_scatter() removed from new suite but kept in legacy.""" + + def test_saturation_curves_scatter_removed_from_new_suite(self, mock_mmm_fitted): + """Test saturation_curves_scatter removed from new MMMPlotSuite.""" + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2", False) + + try: + mmm_config["plot.use_v2"] = True + plot_suite = mock_mmm_fitted.plot + + assert not hasattr(plot_suite, "saturation_curves_scatter"), ( + "saturation_curves_scatter should not exist in new MMMPlotSuite" + ) + finally: + mmm_config["plot.use_v2"] = original + + def test_saturation_curves_scatter_exists_in_legacy_suite(self, mock_mmm_fitted): + """Test saturation_curves_scatter still exists in LegacyMMMPlotSuite.""" + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2", False) + + try: + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm_fitted.plot + + assert hasattr(plot_suite, "saturation_curves_scatter"), ( + "saturation_curves_scatter should exist in LegacyMMMPlotSuite" + ) + finally: + mmm_config["plot.use_v2"] = original + + +class TestMissingMethods: + """Test methods that exist in one suite but not the other handle gracefully.""" + + def test_budget_allocation_exists_in_legacy_suite( + self, mock_mmm_fitted, mock_allocation_samples + ): + """Test that budget_allocation() works in legacy suite.""" + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2", False) + + try: + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm_fitted.plot + + # Should work (not raise AttributeError) + result = plot_suite.budget_allocation(samples=mock_allocation_samples) + assert isinstance(result, tuple) + assert len(result) == 2 + finally: + mmm_config["plot.use_v2"] = original + + def test_budget_allocation_raises_in_new_suite(self, mock_mmm_fitted): + """Test that budget_allocation() raises helpful error in new suite.""" + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2", False) + + try: + mmm_config["plot.use_v2"] = True + plot_suite = mock_mmm_fitted.plot + + with pytest.raises(NotImplementedError, match="removed in MMMPlotSuite v2"): + plot_suite.budget_allocation(samples=None) + finally: + mmm_config["plot.use_v2"] = original + + def test_budget_allocation_roas_exists_in_new_suite(self, mock_mmm_fitted): + """Test that budget_allocation_roas() exists in new suite.""" + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2", False) + + try: + mmm_config["plot.use_v2"] = True + plot_suite = mock_mmm_fitted.plot + + # Just check that the method exists (not AttributeError) + assert hasattr(plot_suite, "budget_allocation_roas"), ( + "budget_allocation_roas should exist in new MMMPlotSuite" + ) + assert callable(plot_suite.budget_allocation_roas), ( + "budget_allocation_roas should be callable" + ) + finally: + mmm_config["plot.use_v2"] = original + + def test_budget_allocation_roas_missing_in_legacy_suite(self, mock_mmm_fitted): + """Test that budget_allocation_roas() doesn't exist in legacy suite.""" + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2", False) + + try: + mmm_config["plot.use_v2"] = False + + with pytest.warns(FutureWarning): + plot_suite = mock_mmm_fitted.plot + + with pytest.raises(AttributeError): + plot_suite.budget_allocation_roas(samples=None) + finally: + mmm_config["plot.use_v2"] = original From dc67ab0f421ed19a2f337a3097f63e02b57663b4 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Fri, 21 Nov 2025 10:18:58 -0500 Subject: [PATCH 11/29] working on 3.3 --- tests/mmm/test_plot.py | 895 +++-------------------------------------- 1 file changed, 59 insertions(+), 836 deletions(-) diff --git a/tests/mmm/test_plot.py b/tests/mmm/test_plot.py index f162fde86..da6a81e8d 100644 --- a/tests/mmm/test_plot.py +++ b/tests/mmm/test_plot.py @@ -12,182 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# NOTE: This file may be consolidated with test_plot_backends.py in the future -# once the backend migration is complete and stable. -import warnings +"""Tests for new MMMPlotSuite with multi-backend support (arviz_plots-based). + +This file tests the new arviz_plots-based MMMPlotSuite that supports +matplotlib, plotly, and bokeh backends. + +For tests of the legacy matplotlib-only suite, see test_legacy_plot.py. + +Test Organization: +- Parametrized backend tests: Each plotting method tested with all backends +- Backend behavior tests: Config override, invalid backends +- Data parameter tests: Explicit data parameter functionality +- Integration tests: Multiple plots, backend switching + +.. versionadded:: 0.18.0 + New test suite for arviz_plots-based MMMPlotSuite. +""" import arviz as az import numpy as np import pandas as pd import pytest import xarray as xr -from matplotlib.axes import Axes -from matplotlib.figure import Figure - -from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation - -with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - from pymc_marketing.mmm.multidimensional import MMM from pymc_marketing.mmm.plot import MMMPlotSuite -@pytest.fixture -def mmm(): - return MMM( - date_column="date", - channel_columns=["C1", "C2"], - dims=("country",), - target_column="y", - adstock=GeometricAdstock(l_max=10), - saturation=LogisticSaturation(), - ) - - -@pytest.fixture -def df() -> pd.DataFrame: - dates = pd.date_range("2025-01-01", periods=3, freq="W-MON").rename("date") - df = pd.DataFrame( - { - ("A", "C1"): [1, 2, 3], - ("B", "C1"): [4, 5, 6], - ("A", "C2"): [7, 8, 9], - ("B", "C2"): [10, 11, 12], - }, - index=dates, - ) - df.columns.names = ["country", "channel"] - - y = pd.DataFrame( - { - ("A", "y"): [1, 2, 3], - ("B", "y"): [4, 5, 6], - }, - index=dates, - ) - y.columns.names = ["country", "channel"] - - return pd.concat( - [ - df.stack("country", future_stack=True), - y.stack("country", future_stack=True), - ], - axis=1, - ).reset_index() - - -@pytest.fixture -def fit_mmm_with_channel_original_scale(df, mmm, mock_pymc_sample): - X = df.drop(columns=["y"]) - y = df["y"] - - mmm.build_model(X, y) - mmm.add_original_scale_contribution_variable( - var=[ - "channel_contribution", - ] - ) - - mmm.fit(X, y) - - return mmm - - -@pytest.fixture -def fit_mmm_without_channel_original_scale(df, mmm, mock_pymc_sample): - X = df.drop(columns=["y"]) - y = df["y"] - - mmm.fit(X, y) - - return mmm - - -def test_saturation_curves_scatter_original_scale(fit_mmm_with_channel_original_scale): - fig, ax = fit_mmm_with_channel_original_scale.plot.saturation_curves_scatter( - original_scale=True - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - assert all(isinstance(a, Axes) for a in ax.flat) - - -def test_saturation_curves_scatter_original_scale_fails_if_no_deterministic( - fit_mmm_without_channel_original_scale, -): - with pytest.raises(ValueError): - fit_mmm_without_channel_original_scale.plot.saturation_curves_scatter( - original_scale=True - ) - - -def test_contributions_over_time(fit_mmm_with_channel_original_scale): - fig, ax = fit_mmm_with_channel_original_scale.plot.contributions_over_time( - var=["channel_contribution"], - hdi_prob=0.95, - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - assert all(isinstance(a, Axes) for a in ax.flat) - - -def test_contributions_over_time_with_dim(mock_suite: MMMPlotSuite): - # Test with explicit dim argument - from arviz_plots import PlotCollection - - pc = mock_suite.contributions_over_time( - var=["intercept", "linear_trend"], - dims={"country": "A"}, - ) - assert isinstance(pc, PlotCollection) - # Verify plot was created successfully - assert hasattr(pc, "backend") - assert hasattr(pc, "show") - - -def test_contributions_over_time_with_dims_list(mock_suite: MMMPlotSuite): - """Test that passing a list to dims creates a subplot for each value.""" - from arviz_plots import PlotCollection - - pc = mock_suite.contributions_over_time( - var=["intercept"], - dims={"country": ["A", "B"]}, - ) - assert isinstance(pc, PlotCollection) - assert hasattr(pc, "backend") - assert hasattr(pc, "show") - - -def test_contributions_over_time_with_multiple_dims_lists(mock_suite: MMMPlotSuite): - """Test that passing multiple lists to dims creates a subplot for each combination.""" - from arviz_plots import PlotCollection - - # Add a fake 'region' dim to the mock posterior for this test if not present - idata = mock_suite.idata - if "region" not in idata.posterior["intercept"].dims: - idata.posterior["intercept"] = idata.posterior["intercept"].expand_dims( - region=["X", "Y"] - ) - pc = mock_suite.contributions_over_time( - var=["intercept"], - dims={"country": ["A", "B"], "region": ["X", "Y"]}, - ) - assert isinstance(pc, PlotCollection) - assert hasattr(pc, "backend") - assert hasattr(pc, "show") - - -def test_posterior_predictive(fit_mmm_with_channel_original_scale, df): - fit_mmm_with_channel_original_scale.sample_posterior_predictive( - df.drop(columns=["y"]) - ) - fig, ax = fit_mmm_with_channel_original_scale.plot.posterior_predictive( - hdi_prob=0.95, - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - assert all(isinstance(a, Axes) for a in ax.flat) +# ============================================================================= +# Fixtures +# ============================================================================= @pytest.fixture(scope="module") @@ -211,12 +64,13 @@ def mock_idata() -> az.InferenceData: }, ), "linear_trend": xr.DataArray( - normal(size=(4, 100, 52)), - dims=("chain", "draw", "date"), + normal(size=(4, 100, 52, 3)), + dims=("chain", "draw", "date", "country"), coords={ "chain": np.arange(4), "draw": np.arange(100), "date": dates, + "country": ["A", "B", "C"], }, ), } @@ -291,21 +145,6 @@ def mock_suite_with_sensitivity(mock_idata_with_sensitivity): return MMMPlotSuite(idata=mock_idata_with_sensitivity) -def test_contributions_over_time_expand_dims(mock_suite: MMMPlotSuite): - from arviz_plots import PlotCollection - - pc = mock_suite.contributions_over_time( - var=[ - "intercept", - "linear_trend", - ] - ) - - assert isinstance(pc, PlotCollection) - assert hasattr(pc, "backend") - assert hasattr(pc, "show") - - @pytest.fixture(scope="module") def mock_idata_with_constant_data() -> az.InferenceData: """Create mock InferenceData with constant_data and posterior for saturation tests.""" @@ -418,636 +257,24 @@ def mock_saturation_curve() -> xr.DataArray: ) -class TestSaturationScatterplot: - def test_saturation_scatterplot_basic(self, mock_suite_with_constant_data): - """Test basic functionality of saturation_scatterplot.""" - fig, axes = mock_suite_with_constant_data.saturation_scatterplot() - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - def test_saturation_scatterplot_original_scale(self, mock_suite_with_constant_data): - """Test saturation_scatterplot with original_scale=True.""" - fig, axes = mock_suite_with_constant_data.saturation_scatterplot( - original_scale=True - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - def test_saturation_scatterplot_custom_kwargs(self, mock_suite_with_constant_data): - """Test saturation_scatterplot with custom kwargs.""" - fig, axes = mock_suite_with_constant_data.saturation_scatterplot( - width_per_col=8.0, height_per_row=5.0 - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - def test_saturation_scatterplot_no_constant_data(self, mock_suite): - """Test that saturation_scatterplot raises error without constant_data.""" - with pytest.raises(ValueError, match=r"No 'constant_data' found"): - mock_suite.saturation_scatterplot() - - def test_saturation_scatterplot_no_original_scale_contribution( - self, mock_suite_with_constant_data - ): - """Test that saturation_scatterplot raises error when original_scale=True but no original scale data.""" - # Remove the original scale contribution from the mock data - idata_copy = mock_suite_with_constant_data.idata.copy() - idata_copy.posterior = idata_copy.posterior.drop_vars( - "channel_contribution_original_scale" - ) - suite_without_original_scale = MMMPlotSuite(idata=idata_copy) - - with pytest.raises( - ValueError, match=r"No posterior.channel_contribution_original_scale" - ): - suite_without_original_scale.saturation_scatterplot(original_scale=True) - - -class TestSaturationScatterplotDims: - def test_saturation_scatterplot_with_dim(self, mock_suite_with_constant_data): - """Test saturation_scatterplot with a single value in dims.""" - fig, axes = mock_suite_with_constant_data.saturation_scatterplot( - dims={"country": "A"} - ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - # Should create one column (n_channels, 1) - assert axes.shape[1] == 1 - for row in range(axes.shape[0]): - assert "country=A" in axes[row, 0].get_title() - - def test_saturation_scatterplot_with_dims_list(self, mock_suite_with_constant_data): - """Test saturation_scatterplot with a list in dims (should create subplots for each value).""" - fig, axes = mock_suite_with_constant_data.saturation_scatterplot( - dims={"country": ["A", "B"]} - ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - # Should create two columns (n_channels, 2) - assert axes.shape[1] == 2 - for col, country in enumerate(["A", "B"]): - for row in range(axes.shape[0]): - assert f"country={country}" in axes[row, col].get_title() - - def test_saturation_scatterplot_with_multiple_dims_lists( - self, mock_suite_with_constant_data - ): - """Test saturation_scatterplot with multiple lists in dims (should create subplots for each combination).""" - # Add a fake 'region' dim to the mock constant_data for this test if not present - idata = mock_suite_with_constant_data.idata - if "region" not in idata.constant_data.channel_data.dims: - # Expand channel_data and posterior to add region - new_regions = ["X", "Y"] - channel_data = idata.constant_data.channel_data.expand_dims( - region=new_regions - ) - idata.constant_data["channel_data"] = channel_data - for var in ["channel_contribution", "channel_contribution_original_scale"]: - if var in idata.posterior: - idata.posterior[var] = idata.posterior[var].expand_dims( - region=new_regions - ) - fig, axes = mock_suite_with_constant_data.saturation_scatterplot( - dims={"country": ["A", "B"], "region": ["X", "Y"]} - ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - # Should create 4 columns (n_channels, 4) - assert axes.shape[1] == 4 - combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] - for col, (country, region) in enumerate(combos): - for row in range(axes.shape[0]): - title = axes[row, col].get_title() - assert f"country={country}" in title - assert f"region={region}" in title - - -class TestSaturationCurves: - def test_saturation_curves_basic( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test basic functionality of saturation_curves.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=5 - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - def test_saturation_curves_original_scale( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with original_scale=True.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, original_scale=True, n_samples=3 - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - def test_saturation_curves_with_hdi( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with HDI intervals.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, hdi_probs=[0.5, 0.94] - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - def test_saturation_curves_single_hdi( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with single HDI probability.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, hdi_probs=0.85 - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - def test_saturation_curves_custom_colors( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with custom colors.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, colors=["red", "blue"] - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - def test_saturation_curves_subplot_kwargs( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with custom subplot_kwargs.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, - n_samples=3, - subplot_kwargs={"figsize": (12, 8)}, - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - # Check that figsize was applied - assert fig.get_size_inches()[0] == 12 - assert fig.get_size_inches()[1] == 8 - - def test_saturation_curves_rc_params( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with rc_params.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, rc_params={"font.size": 14} - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - def test_saturation_curves_no_samples( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with n_samples=0.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=0, hdi_probs=0.85 - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - def test_saturation_curves_no_constant_data( - self, mock_suite, mock_saturation_curve - ): - """Test that saturation_curves raises error without constant_data.""" - with pytest.raises(ValueError, match=r"No 'constant_data' found"): - mock_suite.saturation_curves(curve=mock_saturation_curve) - - def test_saturation_curves_no_original_scale_contribution( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test that saturation_curves raises error when original_scale=True but no original scale data.""" - # Remove the original scale contribution from the mock data - idata_copy = mock_suite_with_constant_data.idata.copy() - idata_copy.posterior = idata_copy.posterior.drop_vars( - "channel_contribution_original_scale" - ) - suite_without_original_scale = MMMPlotSuite(idata=idata_copy) - - with pytest.raises( - ValueError, match=r"No posterior.channel_contribution_original_scale" - ): - suite_without_original_scale.saturation_curves( - curve=mock_saturation_curve, original_scale=True - ) - - -class TestSaturationCurvesDims: - def test_saturation_curves_with_dim( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with a single value in dims.""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, dims={"country": "A"} - ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - - for row in range(axes.shape[0]): - assert "country=A" in axes[row, 0].get_title() - - def test_saturation_curves_with_dims_list( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with a list in dims (should create subplots for each value).""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=3, dims={"country": ["A", "B"]} - ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - - def test_saturation_curves_with_multiple_dims_lists( - self, mock_suite_with_constant_data, mock_saturation_curve - ): - """Test saturation_curves with multiple lists in dims (should create subplots for each combination).""" - # Add a fake 'region' dim to the mock constant_data for this test if not present - idata = mock_suite_with_constant_data.idata - if "region" not in idata.constant_data.channel_data.dims: - # Expand channel_data and posterior to add region - new_regions = ["X", "Y"] - channel_data = idata.constant_data.channel_data.expand_dims( - region=new_regions - ) - idata.constant_data["channel_data"] = channel_data - for var in ["channel_contribution", "channel_contribution_original_scale"]: - if var in idata.posterior: - idata.posterior[var] = idata.posterior[var].expand_dims( - region=new_regions - ) - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, - n_samples=3, - dims={"country": ["A", "B"], "region": ["X", "Y"]}, - ) - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] - - for col, (country, region) in enumerate(combos): - for row in range(axes.shape[0]): - title = axes[row, col].get_title() - assert f"country={country}" in title - assert f"region={region}" in title - - -def test_saturation_curves_scatter_deprecation_warning(mock_suite_with_constant_data): - """Test that saturation_curves_scatter shows deprecation warning.""" - with pytest.warns( - DeprecationWarning, match=r"saturation_curves_scatter is deprecated" - ): - fig, axes = mock_suite_with_constant_data.saturation_curves_scatter() - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert all(isinstance(ax, Axes) for ax in axes.flat) - - -@pytest.fixture(scope="module") -def mock_idata_with_constant_data_single_dim() -> az.InferenceData: - """Mock InferenceData where channel_data has only ('date','channel') dims.""" - seed = sum(map(ord, "Saturation single-dim tests")) - rng = np.random.default_rng(seed) - normal = rng.normal - - dates = pd.date_range("2025-01-01", periods=12, freq="W-MON") - channels = ["channel_1", "channel_2", "channel_3"] - - posterior = xr.Dataset( - { - "channel_contribution": xr.DataArray( - normal(size=(2, 10, 12, 3)), - dims=("chain", "draw", "date", "channel"), - coords={ - "chain": np.arange(2), - "draw": np.arange(10), - "date": dates, - "channel": channels, - }, - ), - "channel_contribution_original_scale": xr.DataArray( - normal(size=(2, 10, 12, 3)) * 100.0, - dims=("chain", "draw", "date", "channel"), - coords={ - "chain": np.arange(2), - "draw": np.arange(10), - "date": dates, - "channel": channels, - }, - ), - } - ) - - constant_data = xr.Dataset( - { - "channel_data": xr.DataArray( - rng.uniform(0, 10, size=(12, 3)), - dims=("date", "channel"), - coords={"date": dates, "channel": channels}, - ), - "channel_scale": xr.DataArray( - [100.0, 150.0, 200.0], dims=("channel",), coords={"channel": channels} - ), - "target_scale": xr.DataArray( - [1000.0], dims="target", coords={"target": ["y"]} - ), - } - ) - - return az.InferenceData(posterior=posterior, constant_data=constant_data) - - -@pytest.fixture(scope="module") -def mock_suite_with_constant_data_single_dim(mock_idata_with_constant_data_single_dim): - return MMMPlotSuite(idata=mock_idata_with_constant_data_single_dim) - - -@pytest.fixture(scope="module") -def mock_saturation_curve_single_dim() -> xr.DataArray: - """Saturation curve with dims ('chain','draw','channel','x').""" - seed = sum(map(ord, "Saturation curve single-dim")) - rng = np.random.default_rng(seed) - x_values = np.linspace(0, 1, 50) - channels = ["channel_1", "channel_2", "channel_3"] - - # shape: (chains=2, draws=10, channel=3, x=50) - curve_array = np.empty((2, 10, len(channels), len(x_values))) - for ci in range(2): - for di in range(10): - for c in range(len(channels)): - curve_array[ci, di, c, :] = x_values / (1 + x_values) + rng.normal( - 0, 0.02, size=x_values.shape - ) - - return xr.DataArray( - curve_array, - dims=("chain", "draw", "channel", "x"), - coords={ - "chain": np.arange(2), - "draw": np.arange(10), - "channel": channels, - "x": x_values, - }, - name="saturation_curve", - ) - - -def test_saturation_curves_single_dim_axes_shape( - mock_suite_with_constant_data_single_dim, mock_saturation_curve_single_dim -): - """When there are no extra dims, columns should default to 1 (no ncols=0).""" - fig, axes = mock_suite_with_constant_data_single_dim.saturation_curves( - curve=mock_saturation_curve_single_dim, n_samples=3 - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - # Expect (n_channels, 1) - assert axes.shape[1] == 1 - assert axes.shape[0] == mock_saturation_curve_single_dim.sizes["channel"] - - -def test_saturation_curves_multi_dim_axes_shape( - mock_suite_with_constant_data, mock_saturation_curve -): - """With an extra dim (e.g., 'country'), expect (n_channels, n_countries).""" - fig, axes = mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, n_samples=2 - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - - -def test_sensitivity_analysis_basic(mock_suite_with_sensitivity): - fig, axes = mock_suite_with_sensitivity.sensitivity_analysis() - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert axes.ndim == 2 - expected_panels = len( - mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] - ) # type: ignore - assert axes.size >= expected_panels - assert all(isinstance(ax, Axes) for ax in axes.flat[:expected_panels]) - - -def test_sensitivity_analysis_with_aggregation(mock_suite_with_sensitivity): - ax = mock_suite_with_sensitivity.sensitivity_analysis( - aggregation={"sum": ("region",)} - ) - assert isinstance(ax, Axes) - - -def test_marginal_curve(mock_suite_with_sensitivity): - fig, axes = mock_suite_with_sensitivity.marginal_curve() - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert axes.ndim == 2 - regions = mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] # type: ignore - assert axes.size >= len(regions) - assert all(isinstance(ax, Axes) for ax in axes.flat[: len(regions)]) - - -def test_uplift_curve(mock_suite_with_sensitivity): - fig, axes = mock_suite_with_sensitivity.uplift_curve() - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert axes.ndim == 2 - regions = mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] # type: ignore - assert axes.size >= len(regions) - assert all(isinstance(ax, Axes) for ax in axes.flat[: len(regions)]) - - -def test_sensitivity_analysis_multi_panel(mock_suite_with_sensitivity): - # The fixture provides an extra 'region' dimension, so multiple panels should be produced - fig, axes = mock_suite_with_sensitivity.sensitivity_analysis( - subplot_kwargs={"ncols": 2} - ) - - assert isinstance(fig, Figure) - assert isinstance(axes, np.ndarray) - assert axes.ndim == 2 - # There should be two regions, therefore exactly two panels - expected_panels = len( - mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] - ) # type: ignore - assert axes.size >= expected_panels - assert all(isinstance(ax, Axes) for ax in axes.flat[:expected_panels]) - - -def test_sensitivity_analysis_error_on_missing_results(mock_idata): - suite = MMMPlotSuite(idata=mock_idata) - with pytest.raises(ValueError, match=r"No sensitivity analysis results found"): - suite.sensitivity_analysis() - suite.plot_sensitivity_analysis() - - -def test_budget_allocation_with_dims(mock_suite_with_constant_data): - # Use dims to filter to a single country - samples = mock_suite_with_constant_data.idata.posterior - # Add a fake 'allocation' variable for testing - samples = samples.copy() - samples["allocation"] = ( - samples["channel_contribution"].dims, - np.abs(samples["channel_contribution"].values), - ) - plot_suite = mock_suite_with_constant_data - fig, _ax = plot_suite.budget_allocation( - samples=samples, - dims={"country": "A"}, - ) - assert isinstance(fig, Figure) - - -def test_budget_allocation_with_dims_list(mock_suite_with_constant_data): - """Test that passing a list to dims creates a subplot for each value.""" - samples = mock_suite_with_constant_data.idata.posterior.copy() - # Add a fake 'allocation' variable for testing - samples["allocation"] = ( - samples["channel_contribution"].dims, - np.abs(samples["channel_contribution"].values), - ) - plot_suite = mock_suite_with_constant_data - fig, ax = plot_suite.budget_allocation( - samples=samples, - dims={"country": ["A", "B"]}, - ) - assert isinstance(fig, Figure) - assert isinstance(ax, np.ndarray) - - -def test__validate_dims_valid(): - """Test _validate_dims with valid dims and values.""" - suite = MMMPlotSuite(idata=None) - - # Patch suite.idata.posterior.coords to simulate valid dims - class DummyCoord: - def __init__(self, values): - self.values = values - - class DummyCoords: - def __init__(self): - self._coords = { - "country": DummyCoord(["A", "B"]), - "region": DummyCoord(["X", "Y"]), - } - - def __getitem__(self, key): - return self._coords[key] - - class DummyPosterior: - coords = DummyCoords() - - suite.idata = type("idata", (), {"posterior": DummyPosterior()})() - # Should not raise - suite._validate_dims({"country": "A", "region": "X"}, ["country", "region"]) - suite._validate_dims({"country": ["A", "B"]}, ["country", "region"]) - - -def test__validate_dims_invalid_dim(): - """Test _validate_dims raises for invalid dim name.""" - suite = MMMPlotSuite(idata=None) - - class DummyCoord: - def __init__(self, values): - self.values = values - - class DummyCoords: - def __init__(self): - self.country = DummyCoord(["A", "B"]) - - def __getitem__(self, key): - return getattr(self, key) - - class DummyPosterior: - coords = DummyCoords() - - suite.idata = type("idata", (), {"posterior": DummyPosterior()})() - with pytest.raises(ValueError, match=r"Dimension 'region' not found"): - suite._validate_dims({"region": "X"}, ["country"]) - - -def test__validate_dims_invalid_value(): - """Test _validate_dims raises for invalid value.""" - suite = MMMPlotSuite(idata=None) - - class DummyCoord: - def __init__(self, values): - self.values = values - - class DummyCoords: - def __init__(self): - self.country = DummyCoord(["A", "B"]) - - def __getitem__(self, key): - return getattr(self, key) - - class DummyPosterior: - coords = DummyCoords() - - suite.idata = type("idata", (), {"posterior": DummyPosterior()})() - with pytest.raises(ValueError, match=r"Value 'C' not found in dimension 'country'"): - suite._validate_dims({"country": "C"}, ["country"]) - - -def test__dim_list_handler_none(): - """Test _dim_list_handler with None input.""" - suite = MMMPlotSuite(idata=None) - keys, combos = suite._dim_list_handler(None) - assert keys == [] - assert combos == [()] - +# ============================================================================= +# Basic Functionality Tests +# ============================================================================= -def test__dim_list_handler_single(): - """Test _dim_list_handler with a single list-valued dim.""" - suite = MMMPlotSuite(idata=None) - keys, combos = suite._dim_list_handler({"country": ["A", "B"]}) - assert keys == ["country"] - assert set(combos) == {("A",), ("B",)} +def test_contributions_over_time_expand_dims(mock_suite: MMMPlotSuite): + from arviz_plots import PlotCollection -def test__dim_list_handler_multiple(): - """Test _dim_list_handler with multiple list-valued dims.""" - suite = MMMPlotSuite(idata=None) - keys, combos = suite._dim_list_handler( - {"country": ["A", "B"], "region": ["X", "Y"]} + pc = mock_suite.contributions_over_time( + var=[ + "intercept", + "linear_trend", + ] ) - assert set(keys) == {"country", "region"} - assert set(combos) == {("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")} - -def test__dim_list_handler_mixed(): - """Test _dim_list_handler with mixed single and list values.""" - suite = MMMPlotSuite(idata=None) - keys, combos = suite._dim_list_handler({"country": ["A", "B"], "region": "X"}) - assert keys == ["country"] - assert set(combos) == {("A",), ("B",)} + assert isinstance(pc, PlotCollection) + assert hasattr(pc, "backend") + assert hasattr(pc, "show") # ============================================================================= @@ -1290,29 +517,34 @@ def test_backend_parameter_none_uses_config(self, mock_suite, config_backend): mmm_config["plot.backend"] = config_backend pc = mock_suite.contributions_over_time( - var=["intercept"], - backend=None, # Explicitly None + var=["intercept"], backend=None # Explicitly None ) assert isinstance(pc, PlotCollection) - # PlotCollection should be created with config_backend finally: mmm_config["plot.backend"] = original def test_invalid_backend_warning(self, mock_suite): - """Test that invalid backend either shows warning or raises error.""" - # Invalid backend should either warn or raise an error - # arviz_plots may accept the invalid backend string without warning - # This test just verifies the code doesn't crash in unexpected ways - try: - mock_suite.contributions_over_time( - var=["intercept"], backend="invalid_backend" - ) - # If it succeeds, the backend was accepted (implementation dependent) - except (ValueError, TypeError, NotImplementedError, ModuleNotFoundError): - # These are expected exceptions for invalid backends - pass + """Test that invalid backend shows warning.""" + import warnings + + # Invalid backend should warn but still attempt to create plot + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # This might fail or succeed depending on arviz_plots behavior + # The important thing is that a warning was issued + try: + _pc = mock_suite.contributions_over_time( + var=["intercept"], backend="invalid_backend" + ) + # If it succeeds, just check warning was issued + assert any("backend" in str(warning.message).lower() for warning in w) + except Exception: + # If it fails, that's also acceptable + # The warning should have been issued before the error + assert any("backend" in str(warning.message).lower() for warning in w) class TestDataParameters: @@ -1338,19 +570,11 @@ def test_saturation_scatterplot_with_explicit_data( suite = MMMPlotSuite(idata=None) - # Create proper posterior data with channel_contribution - rng = np.random.default_rng(42) + # Create a small posterior for testing posterior_data = xr.Dataset( { - "channel_contribution": xr.DataArray( - rng.normal(size=(4, 100, 52, 3)), - dims=("chain", "draw", "date", "channel"), - coords={ - "chain": np.arange(4), - "draw": np.arange(100), - "date": pd.date_range("2025-01-01", periods=52, freq="W"), - "channel": ["TV", "Radio", "Digital"], - }, + "channel_contribution": mock_posterior_data["intercept"].isel( + country=0, drop=True ) } ) @@ -1363,7 +587,7 @@ def test_saturation_scatterplot_with_explicit_data( class TestIntegration: - """Test complete workflows and method interactions.""" + """Integration tests for multiple plots and backend switching.""" def test_multiple_plots_same_suite_instance(self, mock_suite_with_constant_data): """Test that same suite instance can create multiple plots.""" @@ -1372,7 +596,6 @@ def test_multiple_plots_same_suite_instance(self, mock_suite_with_constant_data) suite = mock_suite_with_constant_data # Create multiple different plots - # Use channel_contribution which exists in the fixture pc1 = suite.contributions_over_time(var=["channel_contribution"]) pc2 = suite.saturation_scatterplot() From 84eadc988675f2c143032c755c51ff15f6880076 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Fri, 21 Nov 2025 12:50:25 -0500 Subject: [PATCH 12/29] done milestone 4 --- pymc_marketing/mmm/config.py | 92 +++- pymc_marketing/mmm/plot.py | 899 ++++++++++++++++++++++++++++++----- tests/mmm/test_plot.py | 896 +++++++++++++++++++++++++++++++++- 3 files changed, 1736 insertions(+), 151 deletions(-) diff --git a/pymc_marketing/mmm/config.py b/pymc_marketing/mmm/config.py index 889a15377..7f70af925 100644 --- a/pymc_marketing/mmm/config.py +++ b/pymc_marketing/mmm/config.py @@ -19,21 +19,101 @@ class MMMConfig(dict): - """ - Configuration dictionary for MMM plotting settings. + """Configuration dictionary for MMM plotting settings. + + Global configuration object that controls MMM plotting behavior including + backend selection and version control. Modeled after ArviZ's rcParams pattern. + + Available Configuration Keys + ---------------------------- + + **plot.backend** : str, default="matplotlib" + Plotting backend to use for all plots in MMMPlotSuite. Options: + + * ``"matplotlib"`` - Static plots, publication-quality, widest compatibility + * ``"plotly"`` - Interactive plots with hover tooltips and zoom + * ``"bokeh"`` - Interactive plots with rich interactions + + Can be overridden per-method using the ``backend`` parameter. + + .. versionadded:: 0.18.0 + + **plot.show_warnings** : bool, default=True + Whether to show deprecation and other warnings from the plotting suite. + + .. versionadded:: 0.18.0 + + **plot.use_v2** : bool, default=False + Whether to use new arviz_plots-based plotting suite vs legacy suite. + + * ``False`` (default in v0.18.0): Use legacy matplotlib-only suite + * ``True``: Use new multi-backend arviz_plots-based suite + + This flag controls which suite is returned by ``MMM.plot`` property. + + .. versionadded:: 0.18.0 + + .. versionchanged:: 0.19.0 + Default will change to True (new suite becomes default). - Provides backend configuration with validation and reset functionality. - Modeled after ArviZ's rcParams pattern. + .. deprecated:: 0.20.0 + This flag will be removed as legacy suite is removed. Examples -------- + Set plotting backend globally: + >>> from pymc_marketing.mmm import mmm_config >>> mmm_config["plot.backend"] = "plotly" - >>> mmm_config["plot.backend"] - 'plotly' + >>> # All plots now use plotly by default + >>> mmm = MMM(...) + >>> mmm.fit(X, y) + >>> pc = mmm.plot.posterior_predictive() # Uses plotly + >>> pc.show() + + Enable new plotting suite (v2): + + >>> mmm_config["plot.use_v2"] = True + >>> # Now using arviz_plots-based multi-backend suite + >>> mmm = MMM(...) + >>> mmm.fit(X, y) + >>> pc = mmm.plot.contributions_over_time(var=["intercept"]) + >>> pc.show() + + Suppress warnings: + + >>> mmm_config["plot.show_warnings"] = False + + Reset to defaults: + >>> mmm_config.reset() >>> mmm_config["plot.backend"] 'matplotlib' + + Context manager pattern for temporary config changes: + + >>> original = mmm_config["plot.backend"] + >>> try: + ... mmm_config["plot.backend"] = "plotly" + ... # Use plotly for this section + ... pc = mmm.plot.posterior_predictive() + ... pc.show() + ... finally: + ... mmm_config["plot.backend"] = original + + See Also + -------- + MMM.plot : Property that returns appropriate plot suite based on config + MMMPlotSuite : New multi-backend plotting suite + LegacyMMMPlotSuite : Legacy matplotlib-only suite + + Notes + ----- + Configuration changes affect all subsequent plot calls globally unless + overridden at the method level using the ``backend`` parameter. + + The configuration is a singleton - changes affect all MMM instances in + the current Python session. """ _defaults = { diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index c93abcf13..12609ada6 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -338,27 +338,96 @@ def posterior_predictive( hdi_prob: float = 0.85, backend: str | None = None, ) -> PlotCollection: - """ - Plot posterior predictive distributions over time. + """Plot posterior predictive distributions over time. + + Visualizes posterior predictive samples as time series, showing the median + line and highest density interval (HDI) bands. Useful for checking model fit + and understanding prediction uncertainty. + + .. versionadded:: 0.18.0 + New arviz_plots-based implementation supporting multiple backends. Parameters ---------- var : str, optional - Variable name to plot. If None, uses "y". + Variable name to plot from posterior_predictive group. If None, uses "y". idata : xr.Dataset, optional - Dataset containing posterior predictive samples. + Dataset containing posterior predictive samples with a "date" coordinate. If None, uses self.idata.posterior_predictive. + + This parameter allows: + - Testing with mock data without modifying self.idata + - Plotting external posterior predictive samples + - Comparing different model fits side-by-side hdi_prob : float, default 0.85 - Probability mass for HDI interval. + Probability mass for HDI interval (between 0 and 1). backend : str, optional Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". If None, uses global config via mmm_config["plot.backend"]. - Default (via config) is "matplotlib". + Default is "matplotlib". Returns ------- PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)``, + this provides a unified interface across all backends. + + Raises + ------ + ValueError + If no posterior_predictive data found in self.idata and no idata provided. + ValueError + If hdi_prob is not between 0 and 1. + + See Also + -------- + LegacyMMMPlotSuite.posterior_predictive : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) + - Different interface for saving and displaying plots + Examples + -------- + Basic usage: + + >>> mmm.sample_posterior_predictive(X) + >>> pc = mmm.plot.posterior_predictive() + >>> pc.show() + + Plot with different HDI probability: + + >>> pc = mmm.plot.posterior_predictive(hdi_prob=0.94) + >>> pc.show() + + Save to file: + + >>> pc = mmm.plot.posterior_predictive() + >>> pc.save("posterior_predictive.png") + + Use different backend: + + >>> pc = mmm.plot.posterior_predictive(backend="plotly") + >>> pc.show() + + Provide explicit data: + + >>> external_pp = xr.Dataset(...) # Custom posterior predictive + >>> pc = mmm.plot.posterior_predictive(idata=external_pp) + >>> pc.show() + + Direct instantiation pattern: + + >>> from pymc_marketing.mmm.plot import MMMPlotSuite + >>> mps = MMMPlotSuite(custom_idata) + >>> pc = mps.posterior_predictive() + >>> pc.show() """ if not 0 < hdi_prob < 1: raise ValueError("HDI probability must be between 0 and 1.") @@ -428,44 +497,124 @@ def contributions_over_time( dims: dict[str, str | int | list] | None = None, backend: str | None = None, ) -> PlotCollection: - """Plot the time-series contributions for each variable in `var`. + """Plot time-series contributions for specified variables. + + Visualizes how variables contribute over time, showing the median line and + HDI bands. Useful for understanding channel contributions, intercepts, or + other time-varying effects in your model. - showing the median and the credible interval (default 85%). - Creates one subplot per combination of non-(chain/draw/date) dimensions - and places all variables on the same subplot. + .. versionadded:: 0.18.0 + New arviz_plots-based implementation supporting multiple backends. Parameters ---------- var : list of str - A list of variable names to plot from the posterior. + Variable names to plot from the posterior group. Must have a "date" dimension. + Examples: ["channel_contribution"], ["intercept"], ["channel_contribution", "intercept"]. data : xr.Dataset, optional - Dataset containing posterior data. If None, uses self.idata.posterior. + Dataset containing posterior data with variables in `var`. + If None, uses self.idata.posterior. + + .. versionadded:: 0.18.0 + Added data parameter for explicit data passing. This parameter allows: - Testing with mock data without modifying self.idata - Plotting external results not stored in self.idata - Comparing different posterior samples side-by-side - Avoiding unintended side effects on self.idata - hdi_prob: float, optional - The probability mass of the highest density interval to be displayed. Default is 0.85. + hdi_prob : float, default 0.85 + Probability mass for HDI interval (between 0 and 1). dims : dict[str, str | int | list], optional - Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. + Dimension filters to apply. Keys are dimension names, values are either: + - Single value: {"country": "US", "user_type": "new"} + - List of values: {"country": ["US", "UK"]} + If provided, only the selected slice(s) will be plotted. backend : str, optional Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". If None, uses global config via mmm_config["plot.backend"]. - Default (via config) is "matplotlib". + Default is "matplotlib". Returns ------- PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)``, + this provides a unified interface across all backends. Raises ------ ValueError - If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value. + If hdi_prob is not between 0 and 1. ValueError If no posterior data found in self.idata and no data argument provided. + ValueError + If any variable in `var` not found in data. + + See Also + -------- + LegacyMMMPlotSuite.contributions_over_time : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) + - Variable names must be passed in a list (was already list in legacy) + + Examples + -------- + Basic usage - plot channel contributions: + + >>> mmm.fit(X, y) + >>> pc = mmm.plot.contributions_over_time(var=["channel_contribution"]) + >>> pc.show() + + Plot multiple variables together: + + >>> pc = mmm.plot.contributions_over_time( + ... var=["channel_contribution", "intercept"] + ... ) + >>> pc.show() + + Filter by dimension: + + >>> pc = mmm.plot.contributions_over_time( + ... var=["channel_contribution"], dims={"geo": "US"} + ... ) + >>> pc.show() + + Filter with multiple dimension values: + + >>> pc = mmm.plot.contributions_over_time( + ... var=["channel_contribution"], dims={"geo": ["US", "UK"]} + ... ) + >>> pc.show() + + Use different backend: + + >>> pc = mmm.plot.contributions_over_time( + ... var=["channel_contribution"], backend="plotly" + ... ) + >>> pc.show() + + Provide explicit data (option 1 - via data parameter): + + >>> custom_posterior = xr.Dataset(...) + >>> pc = mmm.plot.contributions_over_time( + ... var=["my_contribution"], data=custom_posterior + ... ) + >>> pc.show() + + Provide explicit data (option 2 - direct instantiation): + + >>> from pymc_marketing.mmm.plot import MMMPlotSuite + >>> mps = MMMPlotSuite(custom_idata) + >>> pc = mps.contributions_over_time(var=["my_contribution"]) + >>> pc.show() """ if not 0 < hdi_prob < 1: raise ValueError("HDI probability must be between 0 and 1.") @@ -550,48 +699,121 @@ def saturation_scatterplot( dims: dict[str, str | int | list] | None = None, backend: str | None = None, ) -> PlotCollection: - """Plot the saturation curves for each channel. + """Plot saturation scatter plot showing channel spend vs contributions. + + Creates scatter plots of actual channel spend (X-axis) against channel + contributions (Y-axis), one subplot per channel. Useful for understanding + the saturation behavior and diminishing returns of each marketing channel. - Creates a grid of subplots for each combination of channel and non-(date/channel) dimensions. - Optionally, subset by dims (single values or lists). - Each channel will have a consistent color across all subplots. + .. versionadded:: 0.18.0 + New arviz_plots-based implementation supporting multiple backends. Parameters ---------- - original_scale: bool, optional - Whether to plot the original scale contributions. Default is False. + original_scale : bool, default False + Whether to plot in original scale (True) or scaled space (False). + If True, requires channel_contribution_original_scale in posterior. constant_data : xr.Dataset, optional - Dataset containing constant_data group with 'channel_data' variable. + Dataset containing constant_data group with required variables: + - 'channel_data': Channel spend data (dims include "date", "channel") + - 'channel_scale': Scaling factor per channel (if original_scale=True) + - 'target_scale': Target scaling factor (if original_scale=True) + If None, uses self.idata.constant_data. + .. versionadded:: 0.18.0 + Added constant_data parameter for explicit data passing. + This parameter allows: - Testing with mock constant data - Plotting with alternative scaling factors - Comparing different data scenarios posterior_data : xr.Dataset, optional Dataset containing posterior group with channel contribution variables. + Must contain 'channel_contribution' or 'channel_contribution_original_scale'. If None, uses self.idata.posterior. + .. versionadded:: 0.18.0 + Added posterior_data parameter for explicit data passing. + This parameter allows: - Testing with mock posterior samples - Plotting external inference results - Comparing different model fits - dims: dict[str, str | int | list], optional - Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. + dims : dict[str, str | int | list], optional + Dimension filters to apply. Examples: + - {"geo": "US"} - Single value + - {"geo": ["US", "UK"]} - Multiple values + If provided, only the selected slice(s) will be plotted. - backend: str, optional + backend : str, optional Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". If None, uses global config via mmm_config["plot.backend"]. - Default (via config) is "matplotlib". + Default is "matplotlib". Returns ------- PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)``, + this provides a unified interface across all backends. Raises ------ ValueError If required data not found in self.idata and not provided explicitly. + ValueError + If 'channel_data' not found in constant_data. + ValueError + If original_scale=True but channel_contribution_original_scale not in posterior. + + See Also + -------- + saturation_curves : Add posterior predictive curves to this scatter plot + LegacyMMMPlotSuite.saturation_scatterplot : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) + - Lost **kwargs for matplotlib customization (use backend-specific methods) + - Different grid layout algorithm + + Examples + -------- + Basic usage (scaled space): + + >>> mmm.fit(X, y) + >>> pc = mmm.plot.saturation_scatterplot() + >>> pc.show() + + Plot in original scale: + + >>> mmm.add_original_scale_contribution_variable(var=["channel_contribution"]) + >>> pc = mmm.plot.saturation_scatterplot(original_scale=True) + >>> pc.show() + + Filter by dimension: + + >>> pc = mmm.plot.saturation_scatterplot(dims={"geo": "US"}) + >>> pc.show() + + Use different backend: + + >>> pc = mmm.plot.saturation_scatterplot(backend="plotly") + >>> pc.show() + + Provide explicit data: + + >>> custom_constant = xr.Dataset(...) + >>> custom_posterior = xr.Dataset(...) + >>> pc = mmm.plot.saturation_scatterplot( + ... constant_data=custom_constant, posterior_data=custom_posterior + ... ) + >>> pc.show() """ # Resolve backend backend = self._resolve_backend(backend) @@ -690,50 +912,120 @@ def saturation_curves( dims: dict[str, str | int | list] | None = None, backend: str | None = None, ) -> PlotCollection: - """ - Overlay saturation‑curve scatter‑plots with posterior‑predictive sample curves and HDI bands. + """Overlay saturation scatter plots with posterior predictive curves and HDI bands. + + Builds on saturation_scatterplot() by adding: + - Sample curves from the posterior distribution + - HDI bands showing uncertainty + - Smooth saturation curves over the scatter plot - **allowing** you to customize figsize and font sizes. + .. versionadded:: 0.18.0 + New arviz_plots-based implementation supporting multiple backends. Parameters ---------- curve : xr.DataArray - Posterior‑predictive curves (e.g. dims `("chain","draw","x","channel","geo")`). - original_scale : bool, default=False - Plot `channel_contribution_original_scale` if True, else `channel_contribution`. + Posterior predictive saturation curves with required dimensions: + - "chain", "draw": MCMC samples + - "x": Input values for curve evaluation + - "channel": Channel names + + Generate using: ``mmm.saturation.sample_curve(...)`` + original_scale : bool, default False + Plot in original scale (True) or scaled space (False). + If True, requires channel_contribution_original_scale in posterior. constant_data : xr.Dataset, optional Dataset containing constant_data group. If None, uses self.idata.constant_data. + .. versionadded:: 0.18.0 + Added constant_data parameter for explicit data passing. + This parameter allows testing with mock data and plotting alternative scenarios. posterior_data : xr.Dataset, optional Dataset containing posterior group. If None, uses self.idata.posterior. + .. versionadded:: 0.18.0 + Added posterior_data parameter for explicit data passing. + This parameter allows testing with mock posterior samples and comparing model fits. - n_samples : int, default=10 - Number of sample‑curves per subplot. + n_samples : int, default 10 + Number of sample curves to draw per subplot. + Set to 0 to show only HDI bands without individual samples. hdi_probs : float or list of float, optional - Credible interval probabilities (e.g. 0.94 or [0.5, 0.94]). - If None, uses ArviZ's default (0.94). + HDI probability levels for credible intervals. + Examples: 0.94 (single band), [0.5, 0.94] (multiple bands). + If None, no HDI bands are drawn. random_seed : np.random.Generator, optional - RNG for reproducible sampling. If None, uses `np.random.default_rng()`. + Random number generator for reproducible curve sampling. + If None, uses ``np.random.default_rng()``. dims : dict[str, str | int | list], optional - Dimension filters to apply. Example: {"country": ["US", "UK"], "region": "X"}. + Dimension filters to apply. Examples: + - {"geo": "US"} + - {"geo": ["US", "UK"]} + If provided, only the selected slice(s) will be plotted. - backend: str, optional + backend : str, optional Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". If None, uses global config via mmm_config["plot.backend"]. - Default (via config) is "matplotlib". + Default is "matplotlib". Returns ------- PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + + Raises + ------ + ValueError + If curve is missing required dimensions ("x" or "channel"). + ValueError + If original_scale=True but channel_contribution_original_scale not in posterior. - Example use: - >>> curve = model.saturation.sample_curve( - >>> model.idata.posterior[["saturation_beta", "saturation_lam"]], max_value=2 - >>> ) - >>> pc = model.plot.saturation_curves(curve, original_scale=True, n_samples=10, - >>> hdi_probs=[0.9, 0.7], random_seed=rng) + See Also + -------- + saturation_scatterplot : Base scatter plot without curves + LegacyMMMPlotSuite.saturation_curves : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) + - Lost colors, subplot_kwargs, rc_params parameters + - Different HDI calculation (uses arviz_plots instead of custom) + + Examples + -------- + Generate and plot saturation curves: + + >>> # Generate curves using saturation transformation + >>> curve = mmm.saturation.sample_curve( + ... idata=mmm.idata.posterior[["saturation_beta", "saturation_lam"]], + ... max_value=2.0, + ... ) + >>> pc = mmm.plot.saturation_curves(curve) + >>> pc.show() + + Add HDI bands: + + >>> pc = mmm.plot.saturation_curves(curve, hdi_probs=[0.5, 0.94]) + >>> pc.show() + + Original scale with custom seed: + + >>> import numpy as np + >>> rng = np.random.default_rng(42) + >>> mmm.add_original_scale_contribution_variable(var=["channel_contribution"]) + >>> pc = mmm.plot.saturation_curves( + ... curve, original_scale=True, n_samples=15, random_seed=rng + ... ) + >>> pc.show() + + Filter by dimension: + + >>> pc = mmm.plot.saturation_curves(curve, dims={"geo": "US"}) >>> pc.show() """ # Get constant_data and posterior_data with fallback @@ -841,29 +1133,93 @@ def budget_allocation_roas( dims_to_group_by: list[str] | str | None = None, backend: str | None = None, ) -> PlotCollection: - """Plot the ROI distribution of a given a response distribution and a budget allocation. + """Plot ROI (Return on Ad Spend) distributions for budget allocation scenarios. + + Visualizes the posterior distribution of ROI for each channel given a budget + allocation. Useful for comparing ROI across channels and understanding + optimization trade-offs. + + .. versionadded:: 0.18.0 + New method in MMMPlotSuite v2. This is different from the legacy + budget_allocation() method which showed bar charts. Parameters ---------- samples : xr.Dataset - The dataset containing the channel contributions and allocation values. - Expected to have 'channel_contribution' and 'allocation' variables. + Dataset from budget allocation optimization containing: + - 'channel_contribution_original_scale': Channel contributions + - 'allocation': Allocated budget per channel + - 'channel' dimension + + Typically obtained from: ``mmm.allocate_budget_to_maximize_response(...)`` dims : dict[str, str | int | list], optional - Dimension filters to apply. Example: {"country": ["US", "UK"], "user_type": "new"}. + Dimension filters to apply. Examples: + - {"geo": "US"} + - {"geo": ["US", "UK"]} + If provided, only the selected slice(s) will be plotted. dims_to_group_by : list[str] | str | None, optional - Dimension(s) to group by for plotting purposes. - When a dimension is specified, all the ROAs distributions for each coordinate of that dimension will be - plotted together in a single plot. This is useful for comparing the ROAs distributions. - If None, will not group by any dimensions (i.e. each distribution will be plotted separately). - If a single string, will group by that dimension. - If a list of strings, will group by each of those dimensions. + Dimension(s) to group by for overlaying distributions. + When specified, all ROI distributions for each coordinate of that + dimension will be plotted together for comparison. + + - None (default): Each distribution plotted separately + - Single string: Group by that dimension (e.g., "geo") + - List of strings: Group by multiple dimensions (e.g., ["geo", "segment"]) backend : str | None, optional - Backend to use for plotting. If None, will use the global backend configuration. + Backend to use for plotting. If None, uses global backend configuration. Returns ------- PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + + Raises + ------ + ValueError + If 'channel' dimension not found in samples. + ValueError + If required variables not found in samples. + + See Also + -------- + LegacyMMMPlotSuite.budget_allocation : Legacy bar chart method (different purpose) + + Notes + ----- + This method is NEW in MMMPlotSuite v2 and serves a different purpose + than the legacy ``budget_allocation()`` method: + + - **New method** (this): Shows ROI distributions (KDE plots) + - **Legacy method**: Shows bar charts comparing spend vs contributions + + To use the legacy method, set: ``mmm_config["plot.use_v2"] = False`` + + Examples + -------- + Basic usage with budget optimization results: + + >>> allocation_results = mmm.allocate_budget_to_maximize_response( + ... total_budget=100_000, budget_bounds={"lower": 0.5, "upper": 2.0} + ... ) + >>> pc = mmm.plot.budget_allocation_roas(allocation_results) + >>> pc.show() + + Group by geography to compare ROI across regions: + + >>> pc = mmm.plot.budget_allocation_roas( + ... allocation_results, dims_to_group_by="geo" + ... ) + >>> pc.show() + + Filter and group: + + >>> pc = mmm.plot.budget_allocation_roas( + ... allocation_results, dims={"segment": "premium"}, dims_to_group_by="geo" + ... ) + >>> pc.show() """ # Get the channels from samples if "channel" not in samples.dims: @@ -943,27 +1299,89 @@ def allocated_contribution_by_channel_over_time( hdi_prob: float = 0.85, backend: str | None = None, ) -> PlotCollection: - """Plot the allocated contribution by channel with uncertainty intervals. + """Plot channel contributions over time from budget allocation optimization. - This function visualizes the mean allocated contributions by channel along with - the uncertainty intervals defined by the lower and upper quantiles. - If additional dimensions besides 'channel', 'date', and 'sample' are present, - creates a subplot for each combination of these dimensions. + Visualizes how contributions from each channel evolve over time given an + optimized budget allocation. Shows mean contribution lines per channel with + HDI uncertainty bands. + + .. versionadded:: 0.18.0 + New arviz_plots-based implementation supporting multiple backends. Parameters ---------- samples : xr.Dataset - The dataset containing the samples of channel contributions. - Expected to have 'channel_contribution' variable with dimensions - 'channel', 'date', and 'sample'. - hdi_prob : float, optional - The probability mass of the highest density interval to be displayed. Default is 0.85. + Dataset from budget allocation optimization containing channel + contributions over time. Required dimensions: + - 'channel': Channel names + - 'date': Time dimension + - 'sample': MCMC samples + + Required variables: + - Variable containing 'channel_contribution' (e.g., 'channel_contribution' + or 'channel_contribution_original_scale') + + Typically obtained from: ``mmm.allocate_budget_to_maximize_response(...)`` + hdi_prob : float, default 0.85 + Probability mass for HDI interval (between 0 and 1). backend : str | None, optional - Backend to use for plotting. If None, will use the global backend configuration. + Backend to use for plotting. If None, uses global backend configuration. Returns ------- PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)``, + this provides a unified interface across all backends. + + Raises + ------ + ValueError + If required dimensions ('channel', 'date', 'sample') not found in samples. + ValueError + If no variable containing 'channel_contribution' found in samples. + + See Also + -------- + budget_allocation_roas : Plot ROI distributions from same allocation results + LegacyMMMPlotSuite.allocated_contribution_by_channel_over_time : Legacy implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) + - Lost scale_factor, lower_quantile, upper_quantile, figsize, ax parameters + - Now uses HDI instead of quantiles for uncertainty + - Automatic handling of extra dimensions (creates subplots) + + Examples + -------- + Basic usage with budget optimization results: + + >>> allocation_results = mmm.allocate_budget_to_maximize_response( + ... total_budget=100_000, budget_bounds={"lower": 0.5, "upper": 2.0} + ... ) + >>> pc = mmm.plot.allocated_contribution_by_channel_over_time( + ... allocation_results + ... ) + >>> pc.show() + + Custom HDI probability: + + >>> pc = mmm.plot.allocated_contribution_by_channel_over_time( + ... allocation_results, hdi_prob=0.94 + ... ) + >>> pc.show() + + Use different backend: + + >>> pc = mmm.plot.allocated_contribution_by_channel_over_time( + ... allocation_results, backend="plotly" + ... ) + >>> pc.show() """ # Check for expected dimensions and variables if "channel" not in samples.dims: @@ -1047,30 +1465,67 @@ def _sensitivity_analysis_plot( aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, backend: str | None = None, ) -> PlotCollection: - """Plot helper for sensitivity analysis results. + """Private helper for plotting sensitivity analysis results. + + This is an internal method that performs the core plotting logic for + sensitivity analysis visualizations. Public methods (sensitivity_analysis, + uplift_curve, marginal_curve) handle data retrieval and call this helper. - This is a private helper method that operates on provided data. - Public methods (sensitivity_analysis, uplift_curve, marginal_curve) - handle data retrieval from self.idata. + .. versionadded:: 0.18.0 + New arviz_plots-based implementation supporting multiple backends. Parameters ---------- data : xr.DataArray or xr.Dataset - Sensitivity analysis data to plot. Must have 'sample' and 'sweep' dimensions. - If Dataset, should contain 'x' variable. This parameter is REQUIRED with - no fallback to self.idata to maintain separation of concerns. + Sensitivity analysis data to plot. Must have required dimensions: + - 'sample': MCMC samples + - 'sweep': Sweep values (e.g., multipliers or input values) + + If Dataset, should contain 'x' variable. + + IMPORTANT: This parameter is REQUIRED with no fallback to self.idata. + This design maintains separation of concerns - public methods handle + data retrieval, this helper handles pure plotting. hdi_prob : float, default 0.94 - HDI probability mass. + HDI probability mass (between 0 and 1). aggregation : dict, optional - Aggregation to apply to the data. - E.g., {"sum": ("channel",)} to sum over the channel dimension. + Aggregations to apply before plotting. + Keys are operations ("sum", "mean", "median"), values are dimension tuples. + Example: {"sum": ("channel",)} sums over the channel dimension. backend : str | None, optional - Backend to use for plotting. If None, will use the global backend configuration. + Backend to use for plotting. If None, uses global backend configuration. Returns ------- PlotCollection + arviz_plots PlotCollection object containing the plot. + + Note: Y-axis label is NOT set by this helper. Public methods calling + this helper should set appropriate labels (e.g., "Contribution", + "Uplift (%)", "Marginal Effect"). + Raises + ------ + ValueError + If data is missing required dimensions ('sample', 'sweep'). + + Notes + ----- + Design rationale for REQUIRED data parameter: + + - **Separation of concerns**: Public methods handle data location/retrieval + (from self.idata.sensitivity_analysis, self.idata.posterior, etc.), + this helper handles pure visualization logic. + - **Testability**: Easy to test plotting logic with mock data. + - **Cleaner implementation**: No monkey-patching or state manipulation. + - **Flexibility**: Can be reused for different data sources without + coupling to self.idata structure. + + This is a PRIVATE method (starts with _) and should not be called directly + by users. Use public methods instead: + - sensitivity_analysis(): General sensitivity analysis plots + - uplift_curve(): Uplift percentage plots + - marginal_curve(): Marginal effects plots """ # Handle Dataset or DataArray x = data["x"] if isinstance(data, xr.Dataset) else data @@ -1151,52 +1606,118 @@ def sensitivity_analysis( aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, backend: str | None = None, ) -> PlotCollection: - """Plot sensitivity analysis results. + """Plot sensitivity analysis results showing response to input changes. + + Visualizes how model outputs (e.g., channel contributions) change as inputs + (e.g., channel spend) are varied. Shows mean response line and HDI bands + across sweep values. + + .. versionadded:: 0.18.0 + New arviz_plots-based implementation supporting multiple backends. Parameters ---------- data : xr.DataArray or xr.Dataset, optional - Sensitivity analysis data to plot. If None, uses self.idata.sensitivity_analysis. + Sensitivity analysis data with required dimensions: + - 'sample': MCMC samples + - 'sweep': Sweep values (e.g., multipliers) + + If Dataset, should contain 'x' variable. + If None, uses self.idata.sensitivity_analysis. + + .. versionadded:: 0.18.0 + Added data parameter for explicit data passing. This parameter allows: - Testing with mock sensitivity analysis results - Plotting external sweep results - Comparing different sensitivity analyses hdi_prob : float, default 0.94 - HDI probability mass. + HDI probability mass (between 0 and 1). aggregation : dict, optional - Aggregation to apply to the data. - E.g., {"sum": ("channel",)} to sum over the channel dimension. + Aggregations to apply before plotting. + Keys: "sum", "mean", or "median" + Values: tuple of dimension names + + Example: ``{"sum": ("channel",)}`` sums over channels before plotting. backend : str | None, optional - Backend to use for plotting. If None, will use the global backend configuration. + Backend to use for plotting. If None, uses global backend configuration. Returns ------- PlotCollection + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)`` or ``Axes``, + this provides a unified interface across all backends. Raises ------ ValueError - If no sensitivity analysis data found and no data provided. + If no sensitivity analysis data found in self.idata and no data provided. + + See Also + -------- + uplift_curve : Plot uplift percentages (derived from sensitivity analysis) + marginal_curve : Plot marginal effects (derived from sensitivity analysis) + LegacyMMMPlotSuite.sensitivity_analysis : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) or Axes + - Lost ax, subplot_kwargs, plot_kwargs parameters (use backend methods) + - Cleaner implementation without monkey-patching + - Data parameter for explicit data passing (no side effects on self.idata) Examples -------- - Basic run using stored results in `idata`: + Run sweep and plot results: .. code-block:: python - # Assuming you already ran a sweep and stored results - # under idata.sensitivity_analysis via SensitivityAnalysis.run_sweep(..., extend_idata=True) - mmm.plot.sensitivity_analysis(hdi_prob=0.9) + from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis + + # Run sensitivity sweep + sweeps = np.linspace(0.5, 1.5, 11) + sa = SensitivityAnalysis(mmm.model, mmm.idata) + results = sa.run_sweep( + var_input="channel_data", + sweep_values=sweeps, + var_names="channel_contribution", + sweep_type="multiplicative", + extend_idata=True, # Store in idata + ) + + # Plot stored results + pc = mmm.plot.sensitivity_analysis(hdi_prob=0.9) + pc.show() - With aggregation over dimensions (e.g., sum over channels): + Aggregate over channels: .. code-block:: python - mmm.plot.sensitivity_analysis( - hdi_prob=0.9, - aggregation={"sum": ("channel",)}, + pc = mmm.plot.sensitivity_analysis( + hdi_prob=0.9, aggregation={"sum": ("channel",)} ) + pc.show() + + Use different backend: + + .. code-block:: python + + pc = mmm.plot.sensitivity_analysis(backend="plotly") + pc.show() + + Provide explicit data: + + .. code-block:: python + + external_results = sa.run_sweep(...) # Not stored in idata + pc = mmm.plot.sensitivity_analysis(data=external_results) + pc.show() """ # Retrieve data if not provided data = self._get_data_or_fallback( @@ -1216,45 +1737,83 @@ def uplift_curve( aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, backend: str | None = None, ) -> PlotCollection: - """ - Plot precomputed uplift curves stored under `idata.sensitivity_analysis['uplift_curve']`. + """Plot uplift curves showing percentage change relative to baseline. + + Visualizes relative percentage changes in model outputs (e.g., channel + contributions) as inputs are varied, compared to a reference point. + Shows mean uplift line and HDI bands. + + .. versionadded:: 0.18.0 + New arviz_plots-based implementation supporting multiple backends. Parameters ---------- data : xr.DataArray or xr.Dataset, optional - Uplift curve data to plot. If Dataset, should contain 'uplift_curve' variable. + Uplift curve data computed from sensitivity analysis. + If Dataset, should contain 'uplift_curve' variable. If None, uses self.idata.sensitivity_analysis['uplift_curve']. + Must be precomputed using: + ``SensitivityAnalysis.compute_uplift_curve_respect_to_base(...)`` + + .. versionadded:: 0.18.0 + Added data parameter for explicit data passing. + This parameter allows: - Testing with mock uplift curve data - Plotting externally computed uplift curves - Comparing uplift curves from different models hdi_prob : float, default 0.94 - HDI probability mass. + HDI probability mass (between 0 and 1). aggregation : dict, optional - Aggregation to apply to the data. - E.g., {"sum": ("channel",)} to sum over the channel dimension. + Aggregations to apply before plotting. + Keys: "sum", "mean", or "median" + Values: tuple of dimension names + + Example: ``{"sum": ("channel",)}`` sums over channels before plotting. backend : str | None, optional - Backend to use for plotting. If None, will use the global backend configuration. + Backend to use for plotting. If None, uses global backend configuration. Returns ------- PlotCollection - arviz_plots PlotCollection object. + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)`` or ``Axes``, + this provides a unified interface across all backends. Raises ------ ValueError - If no uplift curve data found and no data provided. + If no uplift curve data found in self.idata and no data provided. + ValueError + If 'uplift_curve' variable not found in sensitivity_analysis group. + + See Also + -------- + sensitivity_analysis : Plot raw sensitivity analysis results + marginal_curve : Plot marginal effects (absolute changes) + LegacyMMMPlotSuite.uplift_curve : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) or Axes + - Cleaner implementation without monkey-patching + - No longer modifies self.idata.sensitivity_analysis temporarily + - Data parameter for explicit data passing Examples -------- - Persist uplift curve and plot: + Compute and plot uplift curve: .. code-block:: python from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis + # Run sensitivity sweep sweeps = np.linspace(0.5, 1.5, 11) sa = SensitivityAnalysis(mmm.model, mmm.idata) results = sa.run_sweep( @@ -1263,10 +1822,39 @@ def uplift_curve( var_names="channel_contribution", sweep_type="multiplicative", ) + + # Compute uplift relative to baseline (ref=1.0) uplift = sa.compute_uplift_curve_respect_to_base( - results, ref=1.0, extend_idata=True + results, + ref=1.0, + extend_idata=True, # Store in idata ) - mmm.plot.uplift_curve(hdi_prob=0.9) + + # Plot stored uplift curve + pc = mmm.plot.uplift_curve(hdi_prob=0.9) + pc.show() + + Aggregate over channels: + + .. code-block:: python + + pc = mmm.plot.uplift_curve(aggregation={"sum": ("channel",)}) + pc.show() + + Use different backend: + + .. code-block:: python + + pc = mmm.plot.uplift_curve(backend="plotly") + pc.show() + + Provide explicit data: + + .. code-block:: python + + uplift_data = sa.compute_uplift_curve_respect_to_base(results, ref=1.0) + pc = mmm.plot.uplift_curve(data=uplift_data) + pc.show() """ # Retrieve data if not provided if data is None: @@ -1311,45 +1899,86 @@ def marginal_curve( aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, backend: str | None = None, ) -> PlotCollection: - """ - Plot precomputed marginal effects stored under `idata.sensitivity_analysis['marginal_effects']`. + """Plot marginal effects showing absolute rate of change. + + Visualizes the instantaneous rate of change (derivative) of model outputs + with respect to inputs. Shows how much output changes per unit change in + input at each sweep value. + + .. versionadded:: 0.18.0 + New arviz_plots-based implementation supporting multiple backends. Parameters ---------- data : xr.DataArray or xr.Dataset, optional - Marginal effects data to plot. If Dataset, should contain 'marginal_effects' variable. + Marginal effects data computed from sensitivity analysis. + If Dataset, should contain 'marginal_effects' variable. If None, uses self.idata.sensitivity_analysis['marginal_effects']. + Must be precomputed using: + ``SensitivityAnalysis.compute_marginal_effects(...)`` + + .. versionadded:: 0.18.0 + Added data parameter for explicit data passing. + This parameter allows: - Testing with mock marginal effects data - Plotting externally computed marginal effects - Comparing marginal effects from different models hdi_prob : float, default 0.94 - HDI probability mass. + HDI probability mass (between 0 and 1). aggregation : dict, optional - Aggregation to apply to the data. - E.g., {"sum": ("channel",)} to sum over the channel dimension. + Aggregations to apply before plotting. + Keys: "sum", "mean", or "median" + Values: tuple of dimension names + + Example: ``{"sum": ("channel",)}`` sums over channels before plotting. backend : str | None, optional - Backend to use for plotting. If None, will use the global backend configuration. + Backend to use for plotting. If None, uses global backend configuration. Returns ------- PlotCollection - arviz_plots PlotCollection object. + arviz_plots PlotCollection object containing the plot. + + Use ``.show()`` to display or ``.save("filename")`` to save. + Unlike the legacy suite which returned ``(Figure, Axes)`` or ``Axes``, + this provides a unified interface across all backends. Raises ------ ValueError - If no marginal effects data found and no data provided. + If no marginal effects data found in self.idata and no data provided. + ValueError + If 'marginal_effects' variable not found in sensitivity_analysis group. + + See Also + -------- + sensitivity_analysis : Plot raw sensitivity analysis results + uplift_curve : Plot uplift percentages (relative changes) + LegacyMMMPlotSuite.marginal_curve : Legacy matplotlib-only implementation + + Notes + ----- + Breaking changes from legacy implementation: + + - Returns PlotCollection instead of (Figure, Axes) or Axes + - Cleaner implementation without monkey-patching + - No longer modifies self.idata.sensitivity_analysis temporarily + - Data parameter for explicit data passing + + Marginal effects show the **slope** of the sensitivity curve, helping + identify where returns are diminishing most rapidly. Examples -------- - Persist marginal effects and plot: + Compute and plot marginal effects: .. code-block:: python from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis + # Run sensitivity sweep sweeps = np.linspace(0.5, 1.5, 11) sa = SensitivityAnalysis(mmm.model, mmm.idata) results = sa.run_sweep( @@ -1358,8 +1987,38 @@ def marginal_curve( var_names="channel_contribution", sweep_type="multiplicative", ) - me = sa.compute_marginal_effects(results, extend_idata=True) - mmm.plot.marginal_curve(hdi_prob=0.9) + + # Compute marginal effects (derivatives) + me = sa.compute_marginal_effects( + results, + extend_idata=True, # Store in idata + ) + + # Plot stored marginal effects + pc = mmm.plot.marginal_curve(hdi_prob=0.9) + pc.show() + + Aggregate over channels: + + .. code-block:: python + + pc = mmm.plot.marginal_curve(aggregation={"sum": ("channel",)}) + pc.show() + + Use different backend: + + .. code-block:: python + + pc = mmm.plot.marginal_curve(backend="plotly") + pc.show() + + Provide explicit data: + + .. code-block:: python + + marginal_data = sa.compute_marginal_effects(results) + pc = mmm.plot.marginal_curve(data=marginal_data) + pc.show() """ # Retrieve data if not provided if data is None: diff --git a/tests/mmm/test_plot.py b/tests/mmm/test_plot.py index da6a81e8d..b0a94e48f 100644 --- a/tests/mmm/test_plot.py +++ b/tests/mmm/test_plot.py @@ -37,7 +37,6 @@ from pymc_marketing.mmm.plot import MMMPlotSuite - # ============================================================================= # Fixtures # ============================================================================= @@ -517,7 +516,8 @@ def test_backend_parameter_none_uses_config(self, mock_suite, config_backend): mmm_config["plot.backend"] = config_backend pc = mock_suite.contributions_over_time( - var=["intercept"], backend=None # Explicitly None + var=["intercept"], + backend=None, # Explicitly None ) assert isinstance(pc, PlotCollection) @@ -525,26 +525,13 @@ def test_backend_parameter_none_uses_config(self, mock_suite, config_backend): finally: mmm_config["plot.backend"] = original - def test_invalid_backend_warning(self, mock_suite): - """Test that invalid backend shows warning.""" - import warnings - - # Invalid backend should warn but still attempt to create plot - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - # This might fail or succeed depending on arviz_plots behavior - # The important thing is that a warning was issued - try: - _pc = mock_suite.contributions_over_time( - var=["intercept"], backend="invalid_backend" - ) - # If it succeeds, just check warning was issued - assert any("backend" in str(warning.message).lower() for warning in w) - except Exception: - # If it fails, that's also acceptable - # The warning should have been issued before the error - assert any("backend" in str(warning.message).lower() for warning in w) + def test_invalid_backend_raises_error(self, mock_suite): + """Test that invalid backend raises an appropriate error.""" + # Invalid backend should raise an error (arviz_plots behavior) + with pytest.raises((ModuleNotFoundError, ImportError, ValueError)): + _pc = mock_suite.contributions_over_time( + var=["intercept"], backend="invalid_backend" + ) class TestDataParameters: @@ -570,11 +557,20 @@ def test_saturation_scatterplot_with_explicit_data( suite = MMMPlotSuite(idata=None) - # Create a small posterior for testing + # Create posterior data with channel_contribution matching constant_data channels + rng = np.random.default_rng(42) + n_channels = len(mock_constant_data.coords["channel"]) posterior_data = xr.Dataset( { - "channel_contribution": mock_posterior_data["intercept"].isel( - country=0, drop=True + "channel_contribution": xr.DataArray( + rng.normal(size=(4, 100, 52, n_channels)), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": mock_constant_data.coords["date"], + "channel": mock_constant_data.coords["channel"], + }, ) } ) @@ -619,3 +615,853 @@ def test_backend_switching_same_method(self, mock_suite): assert isinstance(pc_mpl, PlotCollection) assert isinstance(pc_plotly, PlotCollection) assert isinstance(pc_bokeh, PlotCollection) + + +# ============================================================================= +# Validation Error Tests +# ============================================================================= + + +class TestValidationErrors: + """Test validation and error handling.""" + + def test_posterior_predictive_invalid_hdi_prob(self, mock_suite): + """Test that invalid hdi_prob raises ValueError.""" + # Create idata with posterior_predictive + idata = mock_suite.idata.copy() + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + idata.posterior_predictive = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ) + } + ) + suite = MMMPlotSuite(idata=idata) + + with pytest.raises(ValueError, match="HDI probability must be between 0 and 1"): + suite.posterior_predictive(hdi_prob=1.5) + + with pytest.raises(ValueError, match="HDI probability must be between 0 and 1"): + suite.posterior_predictive(hdi_prob=0.0) + + def test_contributions_over_time_invalid_hdi_prob(self, mock_suite): + """Test that invalid hdi_prob raises ValueError.""" + with pytest.raises(ValueError, match="HDI probability must be between 0 and 1"): + mock_suite.contributions_over_time(var=["intercept"], hdi_prob=2.0) + + def test_contributions_over_time_missing_variable(self, mock_suite): + """Test that missing variable raises ValueError.""" + with pytest.raises(ValueError, match="not found in data"): + mock_suite.contributions_over_time(var=["nonexistent_var"]) + + def test_posterior_predictive_no_data(self): + """Test that missing posterior_predictive data raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + with pytest.raises(ValueError, match="No posterior_predictive data found"): + suite.posterior_predictive() + + def test_contributions_over_time_no_posterior(self): + """Test that missing posterior data raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + with pytest.raises(ValueError, match="No posterior data found"): + suite.contributions_over_time(var=["intercept"]) + + def test_saturation_scatterplot_no_constant_data(self): + """Test that missing constant_data raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + with pytest.raises(ValueError, match="No constant data found"): + suite.saturation_scatterplot() + + def test_saturation_scatterplot_missing_channel_data(self, mock_posterior_data): + """Test that missing channel_data variable raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + # Create constant_data without channel_data + constant_data = xr.Dataset({"other_var": xr.DataArray([1, 2, 3])}) + + with pytest.raises(ValueError, match="'channel_data' variable not found"): + suite.saturation_scatterplot( + constant_data=constant_data, posterior_data=mock_posterior_data + ) + + def test_saturation_scatterplot_missing_channel_contribution( + self, mock_constant_data + ): + """Test that missing channel_contribution raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + # Create posterior without channel_contribution + posterior = xr.Dataset({"other_var": xr.DataArray([1, 2, 3])}) + + with pytest.raises(ValueError, match=r"No posterior\.channel_contribution"): + suite.saturation_scatterplot( + constant_data=mock_constant_data, posterior_data=posterior + ) + + def test_saturation_curves_missing_x_dimension(self, mock_suite_with_constant_data): + """Test that curve without 'x' dimension raises ValueError.""" + # Create curve without 'x' dimension + bad_curve = xr.DataArray( + np.random.rand(10, 2), + dims=("time", "channel"), + coords={"time": np.arange(10), "channel": ["A", "B"]}, + ) + + with pytest.raises(ValueError, match="curve must have an 'x' dimension"): + mock_suite_with_constant_data.saturation_curves(curve=bad_curve) + + def test_saturation_curves_missing_channel_dimension( + self, mock_suite_with_constant_data + ): + """Test that curve without 'channel' dimension raises ValueError.""" + # Create curve without 'channel' dimension + bad_curve = xr.DataArray( + np.random.rand(10, 20), + dims=("time", "x"), + coords={"time": np.arange(10), "x": np.linspace(0, 1, 20)}, + ) + + with pytest.raises(ValueError, match="curve must have a 'channel' dimension"): + mock_suite_with_constant_data.saturation_curves(curve=bad_curve) + + def test_budget_allocation_roas_missing_channel_dim(self, mock_suite): + """Test that samples without channel dimension raises ValueError.""" + # Create samples without channel dimension + samples = xr.Dataset({"some_var": xr.DataArray([1, 2, 3])}) + + with pytest.raises(ValueError, match="Expected 'channel' dimension"): + mock_suite.budget_allocation_roas(samples=samples) + + def test_budget_allocation_roas_missing_contribution(self, mock_suite): + """Test that samples without contribution variable raises ValueError.""" + # Create samples with channel but missing contribution + samples = xr.Dataset( + { + "other_var": xr.DataArray( + [1, 2, 3], dims=("channel",), coords={"channel": ["A", "B", "C"]} + ) + } + ) + + with pytest.raises( + ValueError, + match="Expected a variable containing 'channel_contribution_original_scale'", + ): + mock_suite.budget_allocation_roas(samples=samples) + + def test_budget_allocation_roas_missing_allocation(self, mock_suite): + """Test that samples without allocation raises ValueError.""" + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + channels = ["A", "B", "C"] + + # Create samples with contribution but missing allocation + samples = xr.Dataset( + { + "channel_contribution_original_scale": xr.DataArray( + rng.normal(size=(100, 52, 3)), + dims=("sample", "date", "channel"), + coords={ + "sample": np.arange(100), + "date": dates, + "channel": channels, + }, + ) + } + ) + + with pytest.raises(ValueError, match="Expected 'allocation' variable"): + mock_suite.budget_allocation_roas(samples=samples) + + def test_allocated_contribution_missing_channel(self, mock_suite): + """Test that samples without channel dimension raises ValueError.""" + samples = xr.Dataset({"some_var": xr.DataArray([1, 2, 3])}) + + with pytest.raises(ValueError, match="Expected 'channel' dimension"): + mock_suite.allocated_contribution_by_channel_over_time(samples=samples) + + def test_allocated_contribution_missing_date(self, mock_suite): + """Test that samples without date dimension raises ValueError.""" + samples = xr.Dataset( + { + "channel_contribution": xr.DataArray( + [[1, 2], [3, 4]], + dims=("sample", "channel"), + coords={"sample": [0, 1], "channel": ["A", "B"]}, + ) + } + ) + + with pytest.raises(ValueError, match="Expected 'date' dimension"): + mock_suite.allocated_contribution_by_channel_over_time(samples=samples) + + def test_allocated_contribution_missing_sample(self, mock_suite): + """Test that samples without sample dimension raises ValueError.""" + dates = pd.date_range("2025-01-01", periods=10, freq="W") + samples = xr.Dataset( + { + "channel_contribution": xr.DataArray( + [[1, 2], [3, 4]], + dims=("date", "channel"), + coords={"date": dates[:2], "channel": ["A", "B"]}, + ) + } + ) + + with pytest.raises(ValueError, match="Expected 'sample' dimension"): + mock_suite.allocated_contribution_by_channel_over_time(samples=samples) + + def test_allocated_contribution_missing_contribution_var(self, mock_suite): + """Test that samples without channel_contribution variable raises ValueError.""" + dates = pd.date_range("2025-01-01", periods=10, freq="W") + samples = xr.Dataset( + { + "other_var": xr.DataArray( + [[[1, 2]]], + dims=("sample", "date", "channel"), + coords={"sample": [0], "date": dates[:1], "channel": ["A", "B"]}, + ) + } + ) + + with pytest.raises( + ValueError, match="Expected a variable containing 'channel_contribution'" + ): + mock_suite.allocated_contribution_by_channel_over_time(samples=samples) + + def test_sensitivity_analysis_invalid_dimensions(self, mock_suite): + """Test that data without required dimensions raises ValueError.""" + # Create data without required dimensions + bad_data = xr.DataArray( + np.random.rand(10, 20), dims=("time", "space"), name="x" + ) + + with pytest.raises(ValueError, match="Data must have dimensions"): + mock_suite._sensitivity_analysis_plot(data=bad_data) + + def test_sensitivity_analysis_no_data(self): + """Test that missing sensitivity_analysis group raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + with pytest.raises(ValueError, match="No sensitivity analysis results found"): + suite.sensitivity_analysis() + + def test_uplift_curve_missing_data(self): + """Test that missing uplift_curve raises ValueError.""" + # Create idata with sensitivity_analysis but without uplift_curve + rng = np.random.default_rng(42) + idata = az.InferenceData( + posterior=xr.Dataset( + {"intercept": xr.DataArray(rng.normal(size=(4, 100)))} + ), + sensitivity_analysis=xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ), + ) + suite = MMMPlotSuite(idata=idata) + + with pytest.raises(ValueError, match="Expected 'uplift_curve'"): + suite.uplift_curve() + + def test_marginal_curve_missing_data(self): + """Test that missing marginal_effects raises ValueError.""" + # Create idata with sensitivity_analysis but without marginal_effects + rng = np.random.default_rng(42) + idata = az.InferenceData( + posterior=xr.Dataset( + {"intercept": xr.DataArray(rng.normal(size=(4, 100)))} + ), + sensitivity_analysis=xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ), + ) + suite = MMMPlotSuite(idata=idata) + + with pytest.raises(ValueError, match="Expected 'marginal_effects'"): + suite.marginal_curve() + + def test_get_additional_dim_combinations_missing_variable(self, mock_suite): + """Test that missing variable in dataset raises ValueError.""" + with pytest.raises(ValueError, match="Variable 'nonexistent' not found"): + mock_suite._get_additional_dim_combinations( + data=mock_suite.idata.posterior, + variable="nonexistent", + ignored_dims={"chain", "draw"}, + ) + + def test_validate_dims_invalid_dimension(self, mock_suite): + """Test that invalid dimension raises ValueError.""" + with pytest.raises(ValueError, match="Dimension 'invalid_dim' not found"): + mock_suite._validate_dims( + dims={"invalid_dim": "A"}, all_dims=["chain", "draw", "country"] + ) + + def test_validate_dims_invalid_value(self, mock_suite): + """Test that invalid dimension value raises ValueError.""" + with pytest.raises(ValueError, match="Value 'Z' not found in dimension"): + mock_suite._validate_dims( + dims={"country": "Z"}, all_dims=["chain", "draw", "country"] + ) + + def test_validate_dims_invalid_list_value(self, mock_suite): + """Test that invalid value in list raises ValueError.""" + with pytest.raises(ValueError, match="Value 'Z' not found in dimension"): + mock_suite._validate_dims( + dims={"country": ["A", "Z"]}, all_dims=["chain", "draw", "country"] + ) + + +# ============================================================================= +# Edge Case Tests +# ============================================================================= + + +class TestEdgeCases: + """Test edge cases and special scenarios.""" + + def test_contributions_over_time_with_dims_filtering(self, mock_suite): + """Test contributions_over_time with dims parameter.""" + from arviz_plots import PlotCollection + + # Filter to specific country + pc = mock_suite.contributions_over_time( + var=["intercept"], dims={"country": "A"} + ) + assert isinstance(pc, PlotCollection) + + def test_contributions_over_time_with_list_dims(self, mock_suite): + """Test contributions_over_time with list-valued dims.""" + from arviz_plots import PlotCollection + + # Filter to multiple countries + pc = mock_suite.contributions_over_time( + var=["intercept"], dims={"country": ["A", "B"]} + ) + assert isinstance(pc, PlotCollection) + + def test_saturation_scatterplot_with_dims_single_value( + self, mock_suite_with_constant_data + ): + """Test saturation_scatterplot with single-value dims.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_scatterplot(dims={"country": "A"}) + assert isinstance(pc, PlotCollection) + + def test_saturation_scatterplot_with_dims_list(self, mock_suite_with_constant_data): + """Test saturation_scatterplot with list-valued dims.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_scatterplot( + dims={"country": ["A", "B"]} + ) + assert isinstance(pc, PlotCollection) + + def test_saturation_curves_with_hdi_probs_float( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with float hdi_probs.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, hdi_probs=0.9, n_samples=3 + ) + assert isinstance(pc, PlotCollection) + + def test_saturation_curves_with_hdi_probs_list( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with list of hdi_probs.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, hdi_probs=[0.5, 0.9], n_samples=3 + ) + assert isinstance(pc, PlotCollection) + + def test_saturation_curves_with_hdi_probs_tuple( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with tuple of hdi_probs.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, hdi_probs=(0.5, 0.9), n_samples=3 + ) + assert isinstance(pc, PlotCollection) + + def test_saturation_curves_with_hdi_probs_array( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with numpy array of hdi_probs.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, + hdi_probs=np.array([0.5, 0.9]), + n_samples=3, + ) + assert isinstance(pc, PlotCollection) + + def test_budget_allocation_roas_with_dims_to_group_by_string(self, mock_suite): + """Test budget_allocation_roas with dims_to_group_by as string.""" + from arviz_plots import PlotCollection + + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + channels = ["TV", "Radio", "Digital"] + regions = ["East", "West"] + + samples = xr.Dataset( + { + "channel_contribution_original_scale": xr.DataArray( + rng.normal(loc=1000, scale=100, size=(100, 52, 3, 2)), + dims=("sample", "date", "channel", "region"), + coords={ + "sample": np.arange(100), + "date": dates, + "channel": channels, + "region": regions, + }, + ), + "allocation": xr.DataArray( + rng.uniform(100, 1000, size=(3, 2)), + dims=("channel", "region"), + coords={"channel": channels, "region": regions}, + ), + } + ) + + pc = mock_suite.budget_allocation_roas( + samples=samples, dims_to_group_by="region" + ) + assert isinstance(pc, PlotCollection) + + def test_budget_allocation_roas_with_dims_to_group_by_list(self, mock_suite): + """Test budget_allocation_roas with dims_to_group_by as list.""" + from arviz_plots import PlotCollection + + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + channels = ["TV", "Radio"] + regions = ["East", "West"] + + samples = xr.Dataset( + { + "channel_contribution_original_scale": xr.DataArray( + rng.normal(loc=1000, scale=100, size=(100, 52, 2, 2)), + dims=("sample", "date", "channel", "region"), + coords={ + "sample": np.arange(100), + "date": dates, + "channel": channels, + "region": regions, + }, + ), + "allocation": xr.DataArray( + rng.uniform(100, 1000, size=(2, 2)), + dims=("channel", "region"), + coords={"channel": channels, "region": regions}, + ), + } + ) + + pc = mock_suite.budget_allocation_roas( + samples=samples, dims_to_group_by=["channel", "region"] + ) + assert isinstance(pc, PlotCollection) + + def test_sensitivity_analysis_with_aggregation_sum(self, mock_sensitivity_data): + """Test sensitivity_analysis_plot with sum aggregation.""" + from arviz_plots import PlotCollection + + # Add a dimension to aggregate over + data_with_dim = xr.Dataset( + { + "x": xr.DataArray( + np.random.rand(100, 20, 3), + dims=("sample", "sweep", "channel"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + "channel": ["A", "B", "C"], + }, + ) + } + ) + + suite = MMMPlotSuite(idata=None) + pc = suite._sensitivity_analysis_plot( + data=data_with_dim, aggregation={"sum": ("channel",)} + ) + assert isinstance(pc, PlotCollection) + + def test_sensitivity_analysis_with_aggregation_mean(self, mock_sensitivity_data): + """Test sensitivity_analysis_plot with mean aggregation.""" + from arviz_plots import PlotCollection + + data_with_dim = xr.Dataset( + { + "x": xr.DataArray( + np.random.rand(100, 20, 3), + dims=("sample", "sweep", "channel"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + "channel": ["A", "B", "C"], + }, + ) + } + ) + + suite = MMMPlotSuite(idata=None) + pc = suite._sensitivity_analysis_plot( + data=data_with_dim, aggregation={"mean": ("channel",)} + ) + assert isinstance(pc, PlotCollection) + + def test_sensitivity_analysis_with_aggregation_median(self, mock_sensitivity_data): + """Test sensitivity_analysis_plot with median aggregation.""" + from arviz_plots import PlotCollection + + data_with_dim = xr.Dataset( + { + "x": xr.DataArray( + np.random.rand(100, 20, 3), + dims=("sample", "sweep", "channel"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + "channel": ["A", "B", "C"], + }, + ) + } + ) + + suite = MMMPlotSuite(idata=None) + pc = suite._sensitivity_analysis_plot( + data=data_with_dim, aggregation={"median": ("channel",)} + ) + assert isinstance(pc, PlotCollection) + + def test_uplift_curve_with_dataset_containing_uplift_curve(self): + """Test uplift_curve when data is Dataset with uplift_curve variable.""" + from arviz_plots import PlotCollection + + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "uplift_curve": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + suite = MMMPlotSuite(idata=None) + pc = suite.uplift_curve(data=data) + assert isinstance(pc, PlotCollection) + + def test_uplift_curve_with_dataset_containing_x(self): + """Test uplift_curve when data is Dataset with x variable.""" + from arviz_plots import PlotCollection + + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + suite = MMMPlotSuite(idata=None) + pc = suite.uplift_curve(data=data) + assert isinstance(pc, PlotCollection) + + def test_marginal_curve_with_dataset_containing_marginal_effects(self): + """Test marginal_curve when data is Dataset with marginal_effects variable.""" + from arviz_plots import PlotCollection + + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "marginal_effects": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + suite = MMMPlotSuite(idata=None) + pc = suite.marginal_curve(data=data) + assert isinstance(pc, PlotCollection) + + def test_marginal_curve_with_dataset_containing_x(self): + """Test marginal_curve when data is Dataset with x variable.""" + from arviz_plots import PlotCollection + + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "x": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + suite = MMMPlotSuite(idata=None) + pc = suite.marginal_curve(data=data) + assert isinstance(pc, PlotCollection) + + +# ============================================================================= +# Original Scale Tests +# ============================================================================= + + +class TestOriginalScale: + """Test original_scale parameter functionality.""" + + def test_saturation_scatterplot_original_scale_true( + self, mock_suite_with_constant_data + ): + """Test saturation_scatterplot with original_scale=True.""" + from arviz_plots import PlotCollection + + pc = mock_suite_with_constant_data.saturation_scatterplot(original_scale=True) + assert isinstance(pc, PlotCollection) + + def test_saturation_scatterplot_original_scale_missing_variable( + self, mock_constant_data + ): + """Test that original_scale=True without variable raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + # Create posterior without channel_contribution_original_scale + posterior = xr.Dataset( + { + "channel_contribution": xr.DataArray( + np.random.rand(4, 100, 52, 3), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": mock_constant_data.coords["date"], + "channel": mock_constant_data.coords["channel"], + }, + ) + } + ) + + with pytest.raises( + ValueError, match=r"No posterior\.channel_contribution_original_scale" + ): + suite.saturation_scatterplot( + original_scale=True, + constant_data=mock_constant_data, + posterior_data=posterior, + ) + + def test_saturation_curves_original_scale_missing_variable( + self, mock_constant_data, mock_saturation_curve + ): + """Test that original_scale=True without variable raises ValueError.""" + suite = MMMPlotSuite(idata=None) + + # Create posterior without channel_contribution_original_scale + posterior = xr.Dataset( + { + "channel_contribution": xr.DataArray( + np.random.rand(4, 100, 52, 3), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": mock_constant_data.coords["date"], + "channel": mock_constant_data.coords["channel"], + }, + ) + } + ) + + with pytest.raises(ValueError, match=r"No posterior\.channel_contribution"): + suite.saturation_curves( + curve=mock_saturation_curve, + original_scale=True, + constant_data=mock_constant_data, + posterior_data=posterior, + ) + + +# ============================================================================= +# Deprecated Method Tests +# ============================================================================= + + +class TestDeprecatedMethods: + """Test deprecated methods raise appropriate errors.""" + + def test_budget_allocation_raises_not_implemented(self, mock_suite): + """Test that budget_allocation() raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match=r"budget_allocation.*removed"): + mock_suite.budget_allocation() + + +# ============================================================================= +# Additional Coverage Tests +# ============================================================================= + + +class TestAdditionalCoverage: + """Additional tests to reach >95% coverage.""" + + def test_posterior_predictive_with_explicit_idata(self): + """Test posterior_predictive with explicit idata parameter.""" + from arviz_plots import PlotCollection + + rng = np.random.default_rng(42) + dates = pd.date_range("2025-01-01", periods=52, freq="W") + + # Create posterior_predictive dataset + pp_data = xr.Dataset( + { + "y": xr.DataArray( + rng.normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ) + } + ) + + # Create suite without idata + suite = MMMPlotSuite(idata=None) + + # Should work with explicit idata parameter + pc = suite.posterior_predictive(var="y", idata=pp_data) + assert isinstance(pc, PlotCollection) + + def test_saturation_curves_with_invalid_hdi_probs_type( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test that invalid hdi_probs type raises TypeError.""" + with pytest.raises(TypeError, match="hdi_probs must be a float"): + mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, hdi_probs={"invalid": "type"} + ) + + def test_uplift_curve_with_dataset_missing_both_variables(self): + """Test uplift_curve when Dataset has neither uplift_curve nor x.""" + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "other_var": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + suite = MMMPlotSuite(idata=None) + with pytest.raises(ValueError, match="must contain 'uplift_curve' or 'x'"): + suite.uplift_curve(data=data) + + def test_marginal_curve_with_dataset_missing_both_variables(self): + """Test marginal_curve when Dataset has neither marginal_effects nor x.""" + rng = np.random.default_rng(42) + data = xr.Dataset( + { + "other_var": xr.DataArray( + rng.normal(size=(100, 20)), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + suite = MMMPlotSuite(idata=None) + with pytest.raises(ValueError, match="must contain 'marginal_effects' or 'x'"): + suite.marginal_curve(data=data) + + def test_sensitivity_analysis_with_aggregation_no_matching_dims(self): + """Test sensitivity_analysis_plot with aggregation but no matching dims.""" + from arviz_plots import PlotCollection + + # Create data without the dimension to aggregate + data = xr.Dataset( + { + "x": xr.DataArray( + np.random.rand(100, 20), + dims=("sample", "sweep"), + coords={ + "sample": np.arange(100), + "sweep": np.linspace(0, 1, 20), + }, + ) + } + ) + + suite = MMMPlotSuite(idata=None) + # Should work even though "channel" doesn't exist in data + pc = suite._sensitivity_analysis_plot( + data=data, aggregation={"sum": ("channel",)} + ) + assert isinstance(pc, PlotCollection) From ddce5b9578ad0fcbe0eb335c03bf22396b547e9a Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Fri, 21 Nov 2025 12:58:16 -0500 Subject: [PATCH 13/29] add check for valid keys in mmm_config --- pymc_marketing/mmm/config.py | 11 ++++- tests/mmm/test_plot_compatibility.py | 61 ++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/pymc_marketing/mmm/config.py b/pymc_marketing/mmm/config.py index 7f70af925..487545f89 100644 --- a/pymc_marketing/mmm/config.py +++ b/pymc_marketing/mmm/config.py @@ -122,11 +122,20 @@ class MMMConfig(dict): "plot.use_v2": False, # Use new arviz_plots-based suite (False = legacy suite for backward compatibility) } + VALID_KEYS = set(_defaults.keys()) + def __init__(self): super().__init__(self._defaults) def __setitem__(self, key, value): - """Set config value with validation for backend.""" + """Set config value with validation for key and backend.""" + if key not in self.VALID_KEYS: + warnings.warn( + f"Invalid config key '{key}'. Valid keys are: {sorted(self.VALID_KEYS)}. " + f"Setting anyway, but this key may not be recognized.", + UserWarning, + stacklevel=2, + ) if key == "plot.backend": if value not in VALID_BACKENDS: warnings.warn( diff --git a/tests/mmm/test_plot_compatibility.py b/tests/mmm/test_plot_compatibility.py index 95b059471..6e71c3166 100644 --- a/tests/mmm/test_plot_compatibility.py +++ b/tests/mmm/test_plot_compatibility.py @@ -373,3 +373,64 @@ def test_budget_allocation_roas_missing_in_legacy_suite(self, mock_mmm_fitted): plot_suite.budget_allocation_roas(samples=None) finally: mmm_config["plot.use_v2"] = original + + +class TestConfigValidation: + """Test MMMConfig key validation.""" + + def test_invalid_key_warns_but_allows_setting(self): + """Test that setting an invalid config key warns but still sets the value.""" + from pymc_marketing.mmm import mmm_config + + # Store original state + original_invalid = mmm_config.get("invalid.key", None) + try: + # Try to set an invalid key + with pytest.warns(UserWarning, match="Invalid config key"): + mmm_config["invalid.key"] = "some_value" + + # Verify the warning message contains valid keys + with pytest.warns(UserWarning) as warning_list: + mmm_config["another.invalid.key"] = "another_value" + + warning_msg = str(warning_list[0].message) + assert "Invalid config key" in warning_msg + assert "another.invalid.key" in warning_msg + assert "plot.backend" in warning_msg or "plot.show_warnings" in warning_msg + + # Verify the invalid key was still set (allows setting but warns) + assert mmm_config["invalid.key"] == "some_value" + assert mmm_config["another.invalid.key"] == "another_value" + finally: + # Clean up invalid keys + if "invalid.key" in mmm_config: + del mmm_config["invalid.key"] + if "another.invalid.key" in mmm_config: + del mmm_config["another.invalid.key"] + if original_invalid is not None: + mmm_config["invalid.key"] = original_invalid + + def test_valid_keys_do_not_warn(self): + """Test that setting valid config keys does not warn.""" + from pymc_marketing.mmm import mmm_config + + original_backend = mmm_config.get("plot.backend", "matplotlib") + original_use_v2 = mmm_config.get("plot.use_v2", False) + original_warnings = mmm_config.get("plot.show_warnings", True) + + try: + # Setting valid keys should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + mmm_config["plot.backend"] = "plotly" + mmm_config["plot.use_v2"] = True + mmm_config["plot.show_warnings"] = False + + # Verify values were set + assert mmm_config["plot.backend"] == "plotly" + assert mmm_config["plot.use_v2"] is True + assert mmm_config["plot.show_warnings"] is False + finally: + mmm_config["plot.backend"] = original_backend + mmm_config["plot.use_v2"] = original_use_v2 + mmm_config["plot.show_warnings"] = original_warnings From f354005351e892451bdb067a1806fe933ef17e01 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Fri, 21 Nov 2025 13:02:38 -0500 Subject: [PATCH 14/29] small change to commit --- pymc_marketing/mmm/multidimensional.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc_marketing/mmm/multidimensional.py b/pymc_marketing/mmm/multidimensional.py index 67def0eb6..5657904e1 100644 --- a/pymc_marketing/mmm/multidimensional.py +++ b/pymc_marketing/mmm/multidimensional.py @@ -649,7 +649,9 @@ def plot(self): "The current MMMPlotSuite will be deprecated in v0.20.0. " "The new version uses arviz_plots and supports multiple backends " "(matplotlib, plotly, bokeh). " - "To use the new version: mmm_config['plot.use_v2'] = True\n" + "To use the new version:\n" + " from pymc_marketing.mmm import mmm_config\n" + " mmm_config['plot.use_v2'] = True\n" "To suppress this warning: mmm_config['plot.show_warnings'] = False\n" "See migration guide: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html", FutureWarning, From 26da1d41c6a1206809439abe854f8d5cef1603b6 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Fri, 21 Nov 2025 15:27:06 -0500 Subject: [PATCH 15/29] small update --- pymc_marketing/mmm/multidimensional.py | 9 ++---- tests/mmm/test_plot_compatibility.py | 39 ++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/pymc_marketing/mmm/multidimensional.py b/pymc_marketing/mmm/multidimensional.py index 5657904e1..5fccfa307 100644 --- a/pymc_marketing/mmm/multidimensional.py +++ b/pymc_marketing/mmm/multidimensional.py @@ -183,14 +183,17 @@ SaturationTransformation, saturation_from_dict, ) +from pymc_marketing.mmm.config import mmm_config from pymc_marketing.mmm.events import EventEffect from pymc_marketing.mmm.fourier import YearlyFourier from pymc_marketing.mmm.hsgp import HSGPBase +from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite from pymc_marketing.mmm.lift_test import ( add_cost_per_target_potentials, add_lift_measurements_to_likelihood_from_saturation, scale_lift_measurements, ) +from pymc_marketing.mmm.plot import MMMPlotSuite from pymc_marketing.mmm.scaling import Scaling, VariableScaling from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis from pymc_marketing.mmm.tvp import infer_time_index @@ -630,12 +633,6 @@ def plot(self): MMMPlotSuite or LegacyMMMPlotSuite Plot suite instance for creating MMM visualizations. """ - import warnings - - from pymc_marketing.mmm.config import mmm_config - from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite - from pymc_marketing.mmm.plot import MMMPlotSuite - self._validate_model_was_built() self._validate_idata_exists() diff --git a/tests/mmm/test_plot_compatibility.py b/tests/mmm/test_plot_compatibility.py index 6e71c3166..f3a81f6b4 100644 --- a/tests/mmm/test_plot_compatibility.py +++ b/tests/mmm/test_plot_compatibility.py @@ -98,6 +98,45 @@ def test_config_flag_persists_across_calls(self, mock_mmm): finally: mmm_config["plot.use_v2"] = original + def test_switching_between_v2_true_and_false(self, mock_mmm): + """Test that switching from use_v2=True to False and back works correctly.""" + from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + from pymc_marketing.mmm.plot import MMMPlotSuite + + original = mmm_config.get("plot.use_v2", False) + try: + # Start with use_v2 = True + mmm_config["plot.use_v2"] = True + + # Should return new suite without warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") + plot_suite_v2 = mock_mmm.plot + + assert isinstance(plot_suite_v2, MMMPlotSuite) + + # Switch to use_v2 = False + mmm_config["plot.use_v2"] = False + + # Should return legacy suite with deprecation warning + with pytest.warns(FutureWarning, match="deprecated in v0.20.0"): + plot_suite_legacy = mock_mmm.plot + + assert isinstance(plot_suite_legacy, LegacyMMMPlotSuite) + + # Switch back to use_v2 = True + mmm_config["plot.use_v2"] = True + + # Should return new suite again without warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") + plot_suite_v2_again = mock_mmm.plot + + assert isinstance(plot_suite_v2_again, MMMPlotSuite) + finally: + mmm_config["plot.use_v2"] = original + class TestDeprecationWarnings: """Test deprecation warnings shown correctly with helpful information.""" From d5d4bae7708393d6bd85cdc92795b9e9e7834fad Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Fri, 21 Nov 2025 15:34:06 -0500 Subject: [PATCH 16/29] adds test_legacy_plot --- tests/mmm/test_legacy_plot.py | 1054 +++++++++++++++++++++++++++++++++ 1 file changed, 1054 insertions(+) create mode 100644 tests/mmm/test_legacy_plot.py diff --git a/tests/mmm/test_legacy_plot.py b/tests/mmm/test_legacy_plot.py new file mode 100644 index 000000000..6a879a9c5 --- /dev/null +++ b/tests/mmm/test_legacy_plot.py @@ -0,0 +1,1054 @@ +# Copyright 2022 - 2025 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings + +import arviz as az +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation + +with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + from pymc_marketing.mmm.multidimensional import MMM + +from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite + + +@pytest.fixture +def mmm(): + return MMM( + date_column="date", + channel_columns=["C1", "C2"], + dims=("country",), + target_column="y", + adstock=GeometricAdstock(l_max=10), + saturation=LogisticSaturation(), + ) + + +@pytest.fixture +def df() -> pd.DataFrame: + dates = pd.date_range("2025-01-01", periods=3, freq="W-MON").rename("date") + df = pd.DataFrame( + { + ("A", "C1"): [1, 2, 3], + ("B", "C1"): [4, 5, 6], + ("A", "C2"): [7, 8, 9], + ("B", "C2"): [10, 11, 12], + }, + index=dates, + ) + df.columns.names = ["country", "channel"] + + y = pd.DataFrame( + { + ("A", "y"): [1, 2, 3], + ("B", "y"): [4, 5, 6], + }, + index=dates, + ) + y.columns.names = ["country", "channel"] + + return pd.concat( + [ + df.stack("country", future_stack=True), + y.stack("country", future_stack=True), + ], + axis=1, + ).reset_index() + + +@pytest.fixture +def fit_mmm_with_channel_original_scale(df, mmm, mock_pymc_sample): + X = df.drop(columns=["y"]) + y = df["y"] + + mmm.build_model(X, y) + mmm.add_original_scale_contribution_variable( + var=[ + "channel_contribution", + ] + ) + + mmm.fit(X, y) + + return mmm + + +@pytest.fixture +def fit_mmm_without_channel_original_scale(df, mmm, mock_pymc_sample): + X = df.drop(columns=["y"]) + y = df["y"] + + mmm.fit(X, y) + + return mmm + + +def test_saturation_curves_scatter_original_scale(fit_mmm_with_channel_original_scale): + fig, ax = fit_mmm_with_channel_original_scale.plot.saturation_curves_scatter( + original_scale=True + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert all(isinstance(a, Axes) for a in ax.flat) + + +def test_saturation_curves_scatter_original_scale_fails_if_no_deterministic( + fit_mmm_without_channel_original_scale, +): + with pytest.raises(ValueError): + fit_mmm_without_channel_original_scale.plot.saturation_curves_scatter( + original_scale=True + ) + + +def test_contributions_over_time(fit_mmm_with_channel_original_scale): + fig, ax = fit_mmm_with_channel_original_scale.plot.contributions_over_time( + var=["channel_contribution"], + hdi_prob=0.95, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert all(isinstance(a, Axes) for a in ax.flat) + + +def test_contributions_over_time_with_dim(mock_suite: LegacyMMMPlotSuite): + # Test with explicit dim argument + fig, ax = mock_suite.contributions_over_time( + var=["intercept", "linear_trend"], + dims={"country": "A"}, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert all(isinstance(a, Axes) for a in ax.flat) + # Optionally, check axes shape if known + if hasattr(ax, "shape"): + # When filtering to a single country, shape[-1] should be 1 + assert ax.shape[-1] == 1 + + +def test_contributions_over_time_with_dims_list(mock_suite: LegacyMMMPlotSuite): + """Test that passing a list to dims creates a subplot for each value.""" + fig, ax = mock_suite.contributions_over_time( + var=["intercept"], + dims={"country": ["A", "B"]}, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + # Should create one subplot per value in the list (here: 2 countries) + assert ax.shape[0] == 2 + # Optionally, check subplot titles contain the correct country + for i, country in enumerate(["A", "B"]): + assert country in ax[i, 0].get_title() + + +def test_contributions_over_time_with_multiple_dims_lists( + mock_suite: LegacyMMMPlotSuite, +): + """Test that passing multiple lists to dims creates a subplot for each combination.""" + # Add a fake 'region' dim to the mock posterior for this test if not present + idata = mock_suite.idata + if "region" not in idata.posterior["intercept"].dims: + idata.posterior["intercept"] = idata.posterior["intercept"].expand_dims( + region=["X", "Y"] + ) + fig, ax = mock_suite.contributions_over_time( + var=["intercept"], + dims={"country": ["A", "B"], "region": ["X", "Y"]}, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + # Should create one subplot per combination (2 countries x 2 regions = 4) + assert ax.shape[0] == 4 + combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] + for i, (country, region) in enumerate(combos): + title = ax[i, 0].get_title() + assert country in title + assert region in title + + +def test_posterior_predictive(fit_mmm_with_channel_original_scale, df): + fit_mmm_with_channel_original_scale.sample_posterior_predictive( + df.drop(columns=["y"]) + ) + fig, ax = fit_mmm_with_channel_original_scale.plot.posterior_predictive( + hdi_prob=0.95, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert all(isinstance(a, Axes) for a in ax.flat) + + +@pytest.fixture(scope="module") +def mock_idata() -> az.InferenceData: + seed = sum(map(ord, "Fake posterior")) + rng = np.random.default_rng(seed) + normal = rng.normal + + dates = pd.date_range("2025-01-01", periods=52, freq="W-MON") + return az.InferenceData( + posterior=xr.Dataset( + { + "intercept": xr.DataArray( + normal(size=(4, 100, 52, 3)), + dims=("chain", "draw", "date", "country"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + "country": ["A", "B", "C"], + }, + ), + "linear_trend": xr.DataArray( + normal(size=(4, 100, 52)), + dims=("chain", "draw", "date"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + }, + ), + } + ) + ) + + +@pytest.fixture(scope="module") +def mock_idata_with_sensitivity(mock_idata): + # Copy the mock_idata so we don't mutate the shared fixture + idata = mock_idata.copy() + n_sample, n_sweep = 40, 5 + sweep = np.linspace(0.5, 1.5, n_sweep) + regions = ["A", "B"] + + samples = xr.DataArray( + np.random.normal(0, 1, size=(n_sample, n_sweep, len(regions))), + dims=("sample", "sweep", "region"), + coords={ + "sample": np.arange(n_sample), + "sweep": sweep, + "region": regions, + }, + name="x", + ) + + marginal_effects = xr.DataArray( + np.random.normal(0, 1, size=(n_sample, n_sweep, len(regions))), + dims=("sample", "sweep", "region"), + coords={ + "sample": np.arange(n_sample), + "sweep": sweep, + "region": regions, + }, + name="marginal_effects", + ) + + uplift_curve = xr.DataArray( + np.random.normal(0, 1, size=(n_sample, n_sweep, len(regions))), + dims=("sample", "sweep", "region"), + coords={ + "sample": np.arange(n_sample), + "sweep": sweep, + "region": regions, + }, + name="uplift_curve", + ) + + sensitivity_analysis = xr.Dataset( + { + "x": samples, + "marginal_effects": marginal_effects, + "uplift_curve": uplift_curve, + }, + coords={"sweep": sweep, "region": regions}, + attrs={"sweep_type": "multiplicative", "var_names": "test_var"}, + ) + + idata.sensitivity_analysis = sensitivity_analysis + return idata + + +@pytest.fixture(scope="module") +def mock_suite(mock_idata): + """Fixture to create a mock LegacyMMMPlotSuite with a mocked posterior.""" + return LegacyMMMPlotSuite(idata=mock_idata) + + +@pytest.fixture(scope="module") +def mock_suite_with_sensitivity(mock_idata_with_sensitivity): + """Fixture to create a mock LegacyMMMPlotSuite with sensitivity analysis.""" + return LegacyMMMPlotSuite(idata=mock_idata_with_sensitivity) + + +def test_contributions_over_time_expand_dims(mock_suite: LegacyMMMPlotSuite): + fig, ax = mock_suite.contributions_over_time( + var=[ + "intercept", + "linear_trend", + ] + ) + + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + assert all(isinstance(a, Axes) for a in ax.flat) + + +@pytest.fixture(scope="module") +def mock_idata_with_constant_data() -> az.InferenceData: + """Create mock InferenceData with constant_data and posterior for saturation tests.""" + seed = sum(map(ord, "Saturation tests")) + rng = np.random.default_rng(seed) + normal = rng.normal + + dates = pd.date_range("2025-01-01", periods=52, freq="W-MON") + channels = ["channel_1", "channel_2"] + countries = ["A", "B"] + + # Create posterior data + posterior = xr.Dataset( + { + "channel_contribution": xr.DataArray( + normal(size=(4, 100, 52, 2, 2)), + dims=("chain", "draw", "date", "channel", "country"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + "channel": channels, + "country": countries, + }, + ), + "channel_contribution_original_scale": xr.DataArray( + normal(size=(4, 100, 52, 2, 2)) * 100, # scaled up for original scale + dims=("chain", "draw", "date", "channel", "country"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "date": dates, + "channel": channels, + "country": countries, + }, + ), + } + ) + + # Create constant_data + constant_data = xr.Dataset( + { + "channel_data": xr.DataArray( + rng.uniform(0, 10, size=(52, 2, 2)), + dims=("date", "channel", "country"), + coords={ + "date": dates, + "channel": channels, + "country": countries, + }, + ), + "channel_scale": xr.DataArray( + [[100.0, 200.0], [150.0, 250.0]], + dims=("country", "channel"), + coords={"country": countries, "channel": channels}, + ), + "target_scale": xr.DataArray( + [1000.0], + dims="target", + coords={"target": ["y"]}, + ), + } + ) + + return az.InferenceData(posterior=posterior, constant_data=constant_data) + + +@pytest.fixture(scope="module") +def mock_suite_with_constant_data(mock_idata_with_constant_data): + """Fixture to create a LegacyMMMPlotSuite with constant_data for saturation tests.""" + return LegacyMMMPlotSuite(idata=mock_idata_with_constant_data) + + +@pytest.fixture(scope="module") +def mock_saturation_curve() -> xr.DataArray: + """Create mock saturation curve data for testing saturation_curves method.""" + seed = sum(map(ord, "Saturation curve")) + rng = np.random.default_rng(seed) + + # Create curve data with typical saturation curve shape + x_values = np.linspace(0, 1, 100) + channels = ["channel_1", "channel_2"] + countries = ["A", "B"] + + curve_data = [] + for _ in range(4): # chains + for _ in range(100): # draws + for _ in channels: + for _ in countries: + # Simple saturation curve: y = x / (1 + x) + y_values = x_values / (1 + x_values) + rng.normal( + 0, 0.01, size=x_values.shape + ) + curve_data.append(y_values) + + curve_array = np.array(curve_data).reshape( + 4, 100, len(channels), len(countries), len(x_values) + ) + + return xr.DataArray( + curve_array, + dims=("chain", "draw", "channel", "country", "x"), + coords={ + "chain": np.arange(4), + "draw": np.arange(100), + "channel": channels, + "country": countries, + "x": x_values, + }, + ) + + +class TestSaturationScatterplot: + def test_saturation_scatterplot_basic(self, mock_suite_with_constant_data): + """Test basic functionality of saturation_scatterplot.""" + fig, axes = mock_suite_with_constant_data.saturation_scatterplot() + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_scatterplot_original_scale(self, mock_suite_with_constant_data): + """Test saturation_scatterplot with original_scale=True.""" + fig, axes = mock_suite_with_constant_data.saturation_scatterplot( + original_scale=True + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_scatterplot_custom_kwargs(self, mock_suite_with_constant_data): + """Test saturation_scatterplot with custom kwargs.""" + fig, axes = mock_suite_with_constant_data.saturation_scatterplot( + width_per_col=8.0, height_per_row=5.0 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_scatterplot_no_constant_data(self, mock_suite): + """Test that saturation_scatterplot raises error without constant_data.""" + with pytest.raises(ValueError, match=r"No 'constant_data' found"): + mock_suite.saturation_scatterplot() + + def test_saturation_scatterplot_no_original_scale_contribution( + self, mock_suite_with_constant_data + ): + """Test that saturation_scatterplot raises error when original_scale=True but no original scale data.""" + # Remove the original scale contribution from the mock data + idata_copy = mock_suite_with_constant_data.idata.copy() + idata_copy.posterior = idata_copy.posterior.drop_vars( + "channel_contribution_original_scale" + ) + suite_without_original_scale = LegacyMMMPlotSuite(idata=idata_copy) + + with pytest.raises( + ValueError, match=r"No posterior.channel_contribution_original_scale" + ): + suite_without_original_scale.saturation_scatterplot(original_scale=True) + + +class TestSaturationScatterplotDims: + def test_saturation_scatterplot_with_dim(self, mock_suite_with_constant_data): + """Test saturation_scatterplot with a single value in dims.""" + fig, axes = mock_suite_with_constant_data.saturation_scatterplot( + dims={"country": "A"} + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + # Should create one column (n_channels, 1) + assert axes.shape[1] == 1 + for row in range(axes.shape[0]): + assert "country=A" in axes[row, 0].get_title() + + def test_saturation_scatterplot_with_dims_list(self, mock_suite_with_constant_data): + """Test saturation_scatterplot with a list in dims (should create subplots for each value).""" + fig, axes = mock_suite_with_constant_data.saturation_scatterplot( + dims={"country": ["A", "B"]} + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + # Should create two columns (n_channels, 2) + assert axes.shape[1] == 2 + for col, country in enumerate(["A", "B"]): + for row in range(axes.shape[0]): + assert f"country={country}" in axes[row, col].get_title() + + def test_saturation_scatterplot_with_multiple_dims_lists( + self, mock_suite_with_constant_data + ): + """Test saturation_scatterplot with multiple lists in dims (should create subplots for each combination).""" + # Add a fake 'region' dim to the mock constant_data for this test if not present + idata = mock_suite_with_constant_data.idata + if "region" not in idata.constant_data.channel_data.dims: + # Expand channel_data and posterior to add region + new_regions = ["X", "Y"] + channel_data = idata.constant_data.channel_data.expand_dims( + region=new_regions + ) + idata.constant_data["channel_data"] = channel_data + for var in ["channel_contribution", "channel_contribution_original_scale"]: + if var in idata.posterior: + idata.posterior[var] = idata.posterior[var].expand_dims( + region=new_regions + ) + fig, axes = mock_suite_with_constant_data.saturation_scatterplot( + dims={"country": ["A", "B"], "region": ["X", "Y"]} + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + # Should create 4 columns (n_channels, 4) + assert axes.shape[1] == 4 + combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] + for col, (country, region) in enumerate(combos): + for row in range(axes.shape[0]): + title = axes[row, col].get_title() + assert f"country={country}" in title + assert f"region={region}" in title + + +class TestSaturationCurves: + def test_saturation_curves_basic( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test basic functionality of saturation_curves.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=5 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_original_scale( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with original_scale=True.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, original_scale=True, n_samples=3 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_with_hdi( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with HDI intervals.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, hdi_probs=[0.5, 0.94] + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_single_hdi( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with single HDI probability.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, hdi_probs=0.85 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_custom_colors( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with custom colors.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, colors=["red", "blue"] + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_subplot_kwargs( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with custom subplot_kwargs.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, + n_samples=3, + subplot_kwargs={"figsize": (12, 8)}, + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + # Check that figsize was applied + assert fig.get_size_inches()[0] == 12 + assert fig.get_size_inches()[1] == 8 + + def test_saturation_curves_rc_params( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with rc_params.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, rc_params={"font.size": 14} + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_no_samples( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with n_samples=0.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=0, hdi_probs=0.85 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + def test_saturation_curves_no_constant_data( + self, mock_suite, mock_saturation_curve + ): + """Test that saturation_curves raises error without constant_data.""" + with pytest.raises(ValueError, match=r"No 'constant_data' found"): + mock_suite.saturation_curves(curve=mock_saturation_curve) + + def test_saturation_curves_no_original_scale_contribution( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test that saturation_curves raises error when original_scale=True but no original scale data.""" + # Remove the original scale contribution from the mock data + idata_copy = mock_suite_with_constant_data.idata.copy() + idata_copy.posterior = idata_copy.posterior.drop_vars( + "channel_contribution_original_scale" + ) + suite_without_original_scale = LegacyMMMPlotSuite(idata=idata_copy) + + with pytest.raises( + ValueError, match=r"No posterior.channel_contribution_original_scale" + ): + suite_without_original_scale.saturation_curves( + curve=mock_saturation_curve, original_scale=True + ) + + +class TestSaturationCurvesDims: + def test_saturation_curves_with_dim( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with a single value in dims.""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, dims={"country": "A"} + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + + for row in range(axes.shape[0]): + assert "country=A" in axes[row, 0].get_title() + + def test_saturation_curves_with_dims_list( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with a list in dims (should create subplots for each value).""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=3, dims={"country": ["A", "B"]} + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + + def test_saturation_curves_with_multiple_dims_lists( + self, mock_suite_with_constant_data, mock_saturation_curve + ): + """Test saturation_curves with multiple lists in dims (should create subplots for each combination).""" + # Add a fake 'region' dim to the mock constant_data for this test if not present + idata = mock_suite_with_constant_data.idata + if "region" not in idata.constant_data.channel_data.dims: + # Expand channel_data and posterior to add region + new_regions = ["X", "Y"] + channel_data = idata.constant_data.channel_data.expand_dims( + region=new_regions + ) + idata.constant_data["channel_data"] = channel_data + for var in ["channel_contribution", "channel_contribution_original_scale"]: + if var in idata.posterior: + idata.posterior[var] = idata.posterior[var].expand_dims( + region=new_regions + ) + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, + n_samples=3, + dims={"country": ["A", "B"], "region": ["X", "Y"]}, + ) + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + combos = [("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")] + + for col, (country, region) in enumerate(combos): + for row in range(axes.shape[0]): + title = axes[row, col].get_title() + assert f"country={country}" in title + assert f"region={region}" in title + + +def test_saturation_curves_scatter_deprecation_warning(mock_suite_with_constant_data): + """Test that saturation_curves_scatter shows deprecation warning.""" + with pytest.warns( + DeprecationWarning, match=r"saturation_curves_scatter is deprecated" + ): + fig, axes = mock_suite_with_constant_data.saturation_curves_scatter() + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert all(isinstance(ax, Axes) for ax in axes.flat) + + +@pytest.fixture(scope="module") +def mock_idata_with_constant_data_single_dim() -> az.InferenceData: + """Mock InferenceData where channel_data has only ('date','channel') dims.""" + seed = sum(map(ord, "Saturation single-dim tests")) + rng = np.random.default_rng(seed) + normal = rng.normal + + dates = pd.date_range("2025-01-01", periods=12, freq="W-MON") + channels = ["channel_1", "channel_2", "channel_3"] + + posterior = xr.Dataset( + { + "channel_contribution": xr.DataArray( + normal(size=(2, 10, 12, 3)), + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(2), + "draw": np.arange(10), + "date": dates, + "channel": channels, + }, + ), + "channel_contribution_original_scale": xr.DataArray( + normal(size=(2, 10, 12, 3)) * 100.0, + dims=("chain", "draw", "date", "channel"), + coords={ + "chain": np.arange(2), + "draw": np.arange(10), + "date": dates, + "channel": channels, + }, + ), + } + ) + + constant_data = xr.Dataset( + { + "channel_data": xr.DataArray( + rng.uniform(0, 10, size=(12, 3)), + dims=("date", "channel"), + coords={"date": dates, "channel": channels}, + ), + "channel_scale": xr.DataArray( + [100.0, 150.0, 200.0], dims=("channel",), coords={"channel": channels} + ), + "target_scale": xr.DataArray( + [1000.0], dims="target", coords={"target": ["y"]} + ), + } + ) + + return az.InferenceData(posterior=posterior, constant_data=constant_data) + + +@pytest.fixture(scope="module") +def mock_suite_with_constant_data_single_dim(mock_idata_with_constant_data_single_dim): + return LegacyMMMPlotSuite(idata=mock_idata_with_constant_data_single_dim) + + +@pytest.fixture(scope="module") +def mock_saturation_curve_single_dim() -> xr.DataArray: + """Saturation curve with dims ('chain','draw','channel','x').""" + seed = sum(map(ord, "Saturation curve single-dim")) + rng = np.random.default_rng(seed) + x_values = np.linspace(0, 1, 50) + channels = ["channel_1", "channel_2", "channel_3"] + + # shape: (chains=2, draws=10, channel=3, x=50) + curve_array = np.empty((2, 10, len(channels), len(x_values))) + for ci in range(2): + for di in range(10): + for c in range(len(channels)): + curve_array[ci, di, c, :] = x_values / (1 + x_values) + rng.normal( + 0, 0.02, size=x_values.shape + ) + + return xr.DataArray( + curve_array, + dims=("chain", "draw", "channel", "x"), + coords={ + "chain": np.arange(2), + "draw": np.arange(10), + "channel": channels, + "x": x_values, + }, + name="saturation_curve", + ) + + +def test_saturation_curves_single_dim_axes_shape( + mock_suite_with_constant_data_single_dim, mock_saturation_curve_single_dim +): + """When there are no extra dims, columns should default to 1 (no ncols=0).""" + fig, axes = mock_suite_with_constant_data_single_dim.saturation_curves( + curve=mock_saturation_curve_single_dim, n_samples=3 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + # Expect (n_channels, 1) + assert axes.shape[1] == 1 + assert axes.shape[0] == mock_saturation_curve_single_dim.sizes["channel"] + + +def test_saturation_curves_multi_dim_axes_shape( + mock_suite_with_constant_data, mock_saturation_curve +): + """With an extra dim (e.g., 'country'), expect (n_channels, n_countries).""" + fig, axes = mock_suite_with_constant_data.saturation_curves( + curve=mock_saturation_curve, n_samples=2 + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + + +def test_sensitivity_analysis_basic(mock_suite_with_sensitivity): + fig, axes = mock_suite_with_sensitivity.sensitivity_analysis() + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert axes.ndim == 2 + expected_panels = len( + mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] + ) # type: ignore + assert axes.size >= expected_panels + assert all(isinstance(ax, Axes) for ax in axes.flat[:expected_panels]) + + +def test_sensitivity_analysis_with_aggregation(mock_suite_with_sensitivity): + ax = mock_suite_with_sensitivity.sensitivity_analysis( + aggregation={"sum": ("region",)} + ) + assert isinstance(ax, Axes) + + +def test_marginal_curve(mock_suite_with_sensitivity): + fig, axes = mock_suite_with_sensitivity.marginal_curve() + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert axes.ndim == 2 + regions = mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] # type: ignore + assert axes.size >= len(regions) + assert all(isinstance(ax, Axes) for ax in axes.flat[: len(regions)]) + + +def test_uplift_curve(mock_suite_with_sensitivity): + fig, axes = mock_suite_with_sensitivity.uplift_curve() + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert axes.ndim == 2 + regions = mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] # type: ignore + assert axes.size >= len(regions) + assert all(isinstance(ax, Axes) for ax in axes.flat[: len(regions)]) + + +def test_sensitivity_analysis_multi_panel(mock_suite_with_sensitivity): + # The fixture provides an extra 'region' dimension, so multiple panels should be produced + fig, axes = mock_suite_with_sensitivity.sensitivity_analysis( + subplot_kwargs={"ncols": 2} + ) + + assert isinstance(fig, Figure) + assert isinstance(axes, np.ndarray) + assert axes.ndim == 2 + # There should be two regions, therefore exactly two panels + expected_panels = len( + mock_suite_with_sensitivity.idata.sensitivity_analysis.coords["region"] + ) # type: ignore + assert axes.size >= expected_panels + assert all(isinstance(ax, Axes) for ax in axes.flat[:expected_panels]) + + +def test_sensitivity_analysis_error_on_missing_results(mock_idata): + suite = LegacyMMMPlotSuite(idata=mock_idata) + with pytest.raises(ValueError, match=r"No sensitivity analysis results found"): + suite.sensitivity_analysis() + suite.plot_sensitivity_analysis() + + +def test_budget_allocation_with_dims(mock_suite_with_constant_data): + # Use dims to filter to a single country + samples = mock_suite_with_constant_data.idata.posterior + # Add a fake 'allocation' variable for testing + samples = samples.copy() + samples["allocation"] = ( + samples["channel_contribution"].dims, + np.abs(samples["channel_contribution"].values), + ) + plot_suite = mock_suite_with_constant_data + fig, _ax = plot_suite.budget_allocation( + samples=samples, + dims={"country": "A"}, + ) + assert isinstance(fig, Figure) + + +def test_budget_allocation_with_dims_list(mock_suite_with_constant_data): + """Test that passing a list to dims creates a subplot for each value.""" + samples = mock_suite_with_constant_data.idata.posterior.copy() + # Add a fake 'allocation' variable for testing + samples["allocation"] = ( + samples["channel_contribution"].dims, + np.abs(samples["channel_contribution"].values), + ) + plot_suite = mock_suite_with_constant_data + fig, ax = plot_suite.budget_allocation( + samples=samples, + dims={"country": ["A", "B"]}, + ) + assert isinstance(fig, Figure) + assert isinstance(ax, np.ndarray) + + +def test__validate_dims_valid(): + """Test _validate_dims with valid dims and values.""" + suite = LegacyMMMPlotSuite(idata=None) + + # Patch suite.idata.posterior.coords to simulate valid dims + class DummyCoord: + def __init__(self, values): + self.values = values + + class DummyCoords: + def __init__(self): + self._coords = { + "country": DummyCoord(["A", "B"]), + "region": DummyCoord(["X", "Y"]), + } + + def __getitem__(self, key): + return self._coords[key] + + class DummyPosterior: + coords = DummyCoords() + + suite.idata = type("idata", (), {"posterior": DummyPosterior()})() + # Should not raise + suite._validate_dims({"country": "A", "region": "X"}, ["country", "region"]) + suite._validate_dims({"country": ["A", "B"]}, ["country", "region"]) + + +def test__validate_dims_invalid_dim(): + """Test _validate_dims raises for invalid dim name.""" + suite = LegacyMMMPlotSuite(idata=None) + + class DummyCoord: + def __init__(self, values): + self.values = values + + class DummyCoords: + def __init__(self): + self.country = DummyCoord(["A", "B"]) + + def __getitem__(self, key): + return getattr(self, key) + + class DummyPosterior: + coords = DummyCoords() + + suite.idata = type("idata", (), {"posterior": DummyPosterior()})() + with pytest.raises(ValueError, match=r"Dimension 'region' not found"): + suite._validate_dims({"region": "X"}, ["country"]) + + +def test__validate_dims_invalid_value(): + """Test _validate_dims raises for invalid value.""" + suite = LegacyMMMPlotSuite(idata=None) + + class DummyCoord: + def __init__(self, values): + self.values = values + + class DummyCoords: + def __init__(self): + self.country = DummyCoord(["A", "B"]) + + def __getitem__(self, key): + return getattr(self, key) + + class DummyPosterior: + coords = DummyCoords() + + suite.idata = type("idata", (), {"posterior": DummyPosterior()})() + with pytest.raises(ValueError, match=r"Value 'C' not found in dimension 'country'"): + suite._validate_dims({"country": "C"}, ["country"]) + + +def test__dim_list_handler_none(): + """Test _dim_list_handler with None input.""" + suite = LegacyMMMPlotSuite(idata=None) + keys, combos = suite._dim_list_handler(None) + assert keys == [] + assert combos == [()] + + +def test__dim_list_handler_single(): + """Test _dim_list_handler with a single list-valued dim.""" + suite = LegacyMMMPlotSuite(idata=None) + keys, combos = suite._dim_list_handler({"country": ["A", "B"]}) + assert keys == ["country"] + assert set(combos) == {("A",), ("B",)} + + +def test__dim_list_handler_multiple(): + """Test _dim_list_handler with multiple list-valued dims.""" + suite = LegacyMMMPlotSuite(idata=None) + keys, combos = suite._dim_list_handler( + {"country": ["A", "B"], "region": ["X", "Y"]} + ) + assert set(keys) == {"country", "region"} + assert set(combos) == {("A", "X"), ("A", "Y"), ("B", "X"), ("B", "Y")} + + +def test__dim_list_handler_mixed(): + """Test _dim_list_handler with mixed single and list values.""" + suite = LegacyMMMPlotSuite(idata=None) + keys, combos = suite._dim_list_handler({"country": ["A", "B"], "region": "X"}) + assert keys == ["country"] + assert set(combos) == {("A",), ("B",)} From c3210398675987244faec024e908c5e4c6bbaa57 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Fri, 21 Nov 2025 15:58:06 -0500 Subject: [PATCH 17/29] remove old test --- tests/mmm/test_legacy_plot_imports.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/mmm/test_legacy_plot_imports.py b/tests/mmm/test_legacy_plot_imports.py index 5aeefde76..7b80556ff 100644 --- a/tests/mmm/test_legacy_plot_imports.py +++ b/tests/mmm/test_legacy_plot_imports.py @@ -31,11 +31,3 @@ def test_legacy_class_name(): from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite assert LegacyMMMPlotSuite.__name__ == "LegacyMMMPlotSuite" - - -def test_old_plot_module_not_exists(): - """Test that old_plot module has been removed.""" - with pytest.raises( - ImportError, match=r"(No module named.*old_plot|cannot import name 'old_plot')" - ): - from pymc_marketing.mmm import old_plot # noqa: F401 From 2d4723f20f83e8e50a50d3fd9231925ced43ea4c Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Fri, 21 Nov 2025 16:11:31 -0500 Subject: [PATCH 18/29] small changes --- .gitignore | 3 - ...mmplotsuite-migration-complete-analysis.md | 2130 ----------------- 2 files changed, 2133 deletions(-) delete mode 100644 thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md diff --git a/.gitignore b/.gitignore index 54cc61220..ac4503fde 100644 --- a/.gitignore +++ b/.gitignore @@ -147,6 +147,3 @@ dmypy.json # Gallery images docs/source/gallery/images/ docs/gettext/ - -# ignore Claude -.claude/* diff --git a/thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md b/thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md deleted file mode 100644 index d631d7038..000000000 --- a/thoughts/shared/research/2025-11-19-mmmplotsuite-migration-complete-analysis.md +++ /dev/null @@ -1,2130 +0,0 @@ ---- -date: 2025-11-19T14:04:21+0000 -researcher: Claude -git_commit: d6331a03727aa9c78ad16690aca25ce9cb869129 -branch: feature/mmmplotsuite-arviz -repository: pymc-labs/pymc-marketing -topic: "MMMPlotSuite Migration - Complete Implementation Analysis and Requirements" -tags: [research, codebase, mmm, plotting, migration, backward-compatibility, testing, arviz-plots] -status: complete -last_updated: 2025-11-19 -last_updated_by: Claude ---- - -# Research: MMMPlotSuite Migration - Complete Implementation Analysis and Requirements - -**Date**: 2025-11-19T14:04:21+0000 -**Researcher**: Claude -**Git Commit**: d6331a03727aa9c78ad16690aca25ce9cb869129 -**Branch**: feature/mmmplotsuite-arviz -**Repository**: pymc-labs/pymc-marketing - -## Research Question - -The user is migrating MMMPlotSuite from matplotlib-based plotting to arviz_plots with multi-backend support. The legacy implementation is currently in `mmm/old_plot.py` and should be renamed to `mmm/legacy_plot.py`. To complete this migration, they need to: - -1. Rename `old_plot.py` to `legacy_plot.py` and `OldMMMPlotSuite` to `LegacyMMMPlotSuite` -2. Support global backend configuration with per-function override capability -3. Implement backward compatibility with a flag to control legacy vs new behavior (default: legacy) -4. Add deprecation warning pointing to v0.20.0 removal -5. Review the new code implementation for quality issues -6. Create comprehensive tests for matplotlib, bokeh, and plotly backends - -## Summary - -Based on comprehensive codebase analysis, the migration is **75% complete** with critical gaps identified: - -**✅ Already Implemented:** -- Backend configuration system with `mmm_config["plot.backend"]` supporting matplotlib/plotly/bokeh -- Complete new arviz_plots-based implementation returning `PlotCollection` objects -- Legacy matplotlib-based implementation preserved in `old_plot.py` (to be renamed `legacy_plot.py`) -- Per-method backend override via `backend` parameter on all plot methods - -**❌ Missing Critical Components:** -- Rename `old_plot.py` to `legacy_plot.py` and `OldMMMPlotSuite` to `LegacyMMMPlotSuite` -- **Data Parameter Standardization**: All plotting methods should accept data as input parameters (some with fallback to `self.idata`, some without). Currently inconsistent across methods. -- **`_sensitivity_analysis_plot()` refactoring**: Must accept `data` as REQUIRED parameter (no fallback), and all callers (`sensitivity_analysis()`, `uplift_curve()`, `marginal_curve()`) must be updated to pass data explicitly. -- Backward compatibility flag (`use_v2`) to toggle between legacy/new suite -- Deprecation warning system for users -- Comprehensive backend testing for the new suite -- Compatibility test suite -- Documentation of breaking changes - -**⚠️ Code Review Issues Found:** -- Return type documentation incomplete -- Breaking parameter type changes across all methods (intentional, no backward compatibility needed) -- Lost customization parameters (colors, subplot_kwargs, rc_params) - handled by arviz_plots -- **Deprecated method carried forward**: `saturation_curves_scatter()` is implemented in v2 but should be removed (already deprecated in v0.1.0) - -## Detailed Findings - -### 1. Current Architecture - -#### 1.1 Class Definitions and Locations - -**New Implementation:** -- **File**: [pymc_marketing/mmm/plot.py:187-1272](pymc_marketing/mmm/plot.py#L187-L1272) -- **Class**: `MMMPlotSuite` -- **Export**: `__all__ = ["MMMPlotSuite"]` at line 181 -- **Technology**: arviz_plots library -- **Return Type**: `PlotCollection` (unified across all backends) - -**Legacy Implementation:** -- **File**: [pymc_marketing/mmm/old_plot.py:191-1936](pymc_marketing/mmm/old_plot.py#L191-L1936) (to be renamed to `legacy_plot.py`) -- **Class**: `OldMMMPlotSuite` (to be renamed to `LegacyMMMPlotSuite`) -- **Export**: Not exported in any `__all__` -- **Technology**: matplotlib only -- **Return Type**: `tuple[Figure, NDArray[Axes]]` or `tuple[Figure, plt.Axes]` - -**Integration Point:** -- **File**: [pymc_marketing/mmm/multidimensional.py:602-607](pymc_marketing/mmm/multidimensional.py#L602-L607) -- **Property**: `MMM.plot` returns `MMMPlotSuite(idata=self.idata)` -- **Issue**: Hardcoded to only return new suite, no version control - -#### 1.2 Method Comparison Matrix - -| Method | New Suite | Legacy Suite | API Compatible | Breaking Changes | -|--------|-----------|--------------|----------------|------------------| -| `__init__` | ✅ | ✅ | ✅ | None | -| `posterior_predictive()` | ✅ | ✅ | ❌ | `var: str` vs `list[str]`, return type | -| `contributions_over_time()` | ✅ | ✅ | ⚠️ | Return type only | -| `saturation_scatterplot()` | ✅ | ✅ | ⚠️ | Lost `**kwargs`, return type | -| `saturation_curves()` | ✅ | ✅ | ❌ | Lost colors, subplot_kwargs, rc_params | -| `saturation_curves_scatter()` | ⚠️ | ✅ | ⚠️ | **SHOULD BE REMOVED** - Currently in v2 but deprecated, delegates to saturation_scatterplot | -| `budget_allocation()` | ❌ | ✅ | ❌ | **REMOVED** - no replacement | -| `budget_allocation_roas()` | ✅ | ❌ | N/A | New method, different purpose | -| `allocated_contribution_by_channel_over_time()` | ✅ | ✅ | ❌ | Lost scale_factor, quantiles, figsize, ax | -| `sensitivity_analysis()` | ✅ | ✅ | ❌ | Lost ax, subplot_kwargs, plot_kwargs | -| `uplift_curve()` | ✅ | ✅ | ❌ | Lost ax, subplot_kwargs, plot_kwargs | -| `marginal_curve()` | ✅ | ✅ | ❌ | Lost ax, subplot_kwargs, plot_kwargs | - -**Helper Methods:** -- New Suite: `_get_additional_dim_combinations()`, `_get_posterior_predictive_data()`, `_validate_dims()`, `_dim_list_handler()`, `_resolve_backend()`, `_sensitivity_analysis_plot()` -- Legacy Suite: `_init_subplots()`, `_build_subplot_title()`, `_reduce_and_stack()`, `_add_median_and_hdi()`, `_plot_budget_allocation_bars()` + shared helpers - -### 2. Backend Configuration System ✅ **COMPLETE** - -#### 2.1 Implementation - -**File**: [pymc_marketing/mmm/config.py:21-66](pymc_marketing/mmm/config.py#L21-L66) - -```python -VALID_BACKENDS = {"matplotlib", "plotly", "bokeh"} - -class MMMConfig(dict): - """Configuration dictionary for MMM plotting settings.""" - - _defaults = { - "plot.backend": "matplotlib", - "plot.show_warnings": True, - } - - def __setitem__(self, key, value): - """Set config value with validation for backend.""" - if key == "plot.backend": - if value not in VALID_BACKENDS: - warnings.warn( - f"Invalid backend '{value}'. Valid backends are: {VALID_BACKENDS}. " - f"Setting anyway, but plotting may fail.", - UserWarning, - stacklevel=2, - ) - super().__setitem__(key, value) - -# Global config instance -mmm_config = MMMConfig() -``` - -#### 2.2 Backend Resolution - -**File**: [pymc_marketing/mmm/plot.py:288-292](pymc_marketing/mmm/plot.py#L288-L292) - -```python -def _resolve_backend(self, backend: str | None) -> str: - """Resolve backend parameter to actual backend string.""" - from pymc_marketing.mmm.config import mmm_config - return backend or mmm_config["plot.backend"] -``` - -#### 2.3 Usage Pattern - -```python -from pymc_marketing.mmm import mmm_config - -# Set global backend -mmm_config["plot.backend"] = "plotly" - -# All plots use plotly -mmm.plot.posterior_predictive() - -# Override for specific plot -mmm.plot.posterior_predictive(backend="matplotlib") -``` - -**Status**: ✅ No action needed - fully functional - -### 3. Backward Compatibility ❌ **MISSING - CRITICAL** - -#### 3.1 Current Gap - -The `.plot` property currently only returns the new suite: - -```python -# Current implementation in multidimensional.py:602-607 -@property -def plot(self) -> MMMPlotSuite: - """Use the MMMPlotSuite to plot the results.""" - self._validate_model_was_built() - self._validate_idata_exists() - return MMMPlotSuite(idata=self.idata) -``` - -#### 3.2 Required Implementation - -**Step 1: Add flag to config.py** - -```python -# File: pymc_marketing/mmm/config.py -_defaults = { - "plot.backend": "matplotlib", - "plot.show_warnings": True, - "plot.use_v2": False, # ← ADD THIS LINE -} -``` - -**Step 2: Implement version switching in multidimensional.py** - -```python -# File: pymc_marketing/mmm/multidimensional.py:602-607 -@property -def plot(self) -> MMMPlotSuite | LegacyMMMPlotSuite: - """Use the MMMPlotSuite to plot the results.""" - from pymc_marketing.mmm.config import mmm_config - from pymc_marketing.mmm.plot import MMMPlotSuite - from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite - import warnings - - self._validate_model_was_built() - self._validate_idata_exists() - - # Check version flag - if mmm_config.get("plot.use_v2", False): - return MMMPlotSuite(idata=self.idata) - else: - # Show deprecation warning for legacy suite - if mmm_config.get("plot.show_warnings", True): - warnings.warn( - "The current MMMPlotSuite will be deprecated in v0.20.0. " - "The new version uses arviz_plots and supports multiple backends (matplotlib, plotly, bokeh). " - "To use the new version: " - ">>> from pymc_marketing.mmm.config import mmm_config\n" - ">>> mmm_config['plot.use_v2'] = True\n" - "To suppress this warning: mmm_config['plot.show_warnings'] = False\n" - "See migration guide: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html", - FutureWarning, - stacklevel=2, - ) - return LegacyMMMPlotSuite(idata=self.idata) -``` - -#### 3.3 Design Rationale - -**Why `FutureWarning` instead of `DeprecationWarning`?** -- `DeprecationWarning` is for library developers (hidden by default in Python) -- `FutureWarning` is for end users (always shown) -- Our users are data scientists/analysts, not library developers -- Pattern found in [pymc_marketing/mlflow.py:180-185](pymc_marketing/mlflow.py#L180-L185) - -**Why config flag instead of function parameter?** -- Consistent with existing backend configuration pattern -- Allows global setting affecting all plot calls -- Can be overridden per-session -- Pattern found throughout codebase (e.g., `plot.backend`) - -**Why default to `False` (legacy suite)?** -- Non-breaking change in initial release -- Gives users time to migrate (1-2 releases) -- Prevents surprise breakage for existing code - -### 4. Deprecation Patterns Research - -Found **10 distinct patterns** used across the codebase: - -#### Pattern 1: Parameter Name Deprecation with Helper -**Location**: [pymc_marketing/model_builder.py:60-77](pymc_marketing/model_builder.py#L60-L77) -**Test**: [tests/test_model_builder.py:530-554](tests/test_model_builder.py#L530-L554) - -```python -def _handle_deprecate_pred_argument(value, name: str, kwargs: dict): - name_pred = f"{name}_pred" - - if name_pred in kwargs and value is not None: - raise ValueError(f"Both {name} and {name_pred} cannot be provided.") - - if name_pred in kwargs: - warnings.warn( - f"{name_pred} is deprecated, use {name} instead", - DeprecationWarning, - stacklevel=2, - ) - return kwargs.pop(name_pred) - - return value -``` - -#### Pattern 2: Method Deprecation with Delegation -**Location**: [pymc_marketing/mmm/plot.py:737-771](pymc_marketing/mmm/plot.py#L737-L771) -**Test**: [tests/mmm/test_plot.py:722-731](tests/mmm/test_plot.py#L722-L731) - -```python -def saturation_curves_scatter(self, original_scale: bool = False, **kwargs) -> PlotCollection: - """ - .. deprecated:: 0.1.0 - Will be removed in version 0.20.0. Use :meth:`saturation_scatterplot` instead. - """ - import warnings - warnings.warn( - "saturation_curves_scatter is deprecated and will be removed in version 0.2.0. " - "Use saturation_scatterplot instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.saturation_scatterplot(original_scale=original_scale, **kwargs) -``` - -#### Pattern 3: Config Key Renaming -**Location**: [pymc_marketing/clv/models/basic.py:49-59](pymc_marketing/clv/models/basic.py#L49-L59) - -```python -deprecated_keys = [key for key in model_config if key.endswith("_prior")] -for key in deprecated_keys: - new_key = key.replace("_prior", "") - warnings.warn( - f"The key '{key}' in model_config is deprecated. Use '{new_key}' instead.", - DeprecationWarning, - stacklevel=2, - ) - model_config[new_key] = model_config.pop(key) -``` - -#### Pattern 4: Module-Level Deprecation -**Location**: [pymc_marketing/deserialize.py:14-40](pymc_marketing/deserialize.py#L14-L40) - -```python -warnings.warn( - "The pymc_marketing.deserialize module is deprecated. " - "Please use pymc_extras.deserialize instead.", - DeprecationWarning, - stacklevel=2, -) -``` - -**Key Testing Pattern**: All deprecation warnings tested with `pytest.warns()`: - -```python -def test_deprecation(): - with pytest.warns(DeprecationWarning, match=r"is deprecated"): - result = deprecated_function() - - # Verify functionality still works - assert isinstance(result, ExpectedType) -``` - -### 5. Code Review: Issues Found in New Implementation - -#### Issue 1: Return Type Documentation ⚠️ **MINOR** - -**Problem**: Method docstrings don't clearly state `PlotCollection` return type vs old `(Figure, Axes)` tuple. - -**Location**: All methods in [plot.py:298-1272](pymc_marketing/mmm/plot.py#L298-L1272) - -**Example** - `posterior_predictive()` docstring: -```python -def posterior_predictive(...) -> PlotCollection: - """ - Plot posterior predictive distributions over time. - - Returns - ------- - PlotCollection # ← States type but doesn't explain what it is - """ -``` - -**Fix**: Add explanatory text: -```python - Returns - ------- - PlotCollection - arviz_plots PlotCollection object containing the plot. - Use .show() to display or .save("filename") to save. - Unlike the old implementation which returned (Figure, Axes), - this provides a unified interface across matplotlib, plotly, and bokeh backends. -``` - -#### Issue 2: Breaking Parameter Type Changes ✅ **INTENTIONAL - NO ACTION NEEDED** - -**Status**: Many parameters have changed across all methods. Since this is a comprehensive migration to a new architecture (arviz_plots), these breaking changes are expected and documented. - -**Examples of parameter changes**: - -```python -# LEGACY (old_plot.py:387 - to be renamed to legacy_plot.py) -def posterior_predictive( - self, - var: list[str] | None = None, # ← Accepts list - ... -) -> tuple[Figure, NDArray[Axes]]: - -# NEW (plot.py:300) -def posterior_predictive( - self, - var: str | None = None, # ← Only accepts string - ... -) -> PlotCollection: -``` - -**Rationale for no backward compatibility**: -- The entire API is changing (return types, parameters, behavior) -- Users switch to new suite explicitly via `mmm_config["plot.use_v2"] = True` -- Legacy suite remains available for those who need legacy parameter behavior -- Attempting to handle all parameter changes would add significant complexity for minimal benefit -- Migration guide will document all parameter changes with examples - -**Action**: Document parameter changes in migration guide, let users adapt code when they opt into v2. - -#### Issue 3: Missing Method ⚠️ **MAJOR** - -**Problem**: `budget_allocation()` completely removed with no replacement. - -**Legacy Method**: [old_plot.py:1049-1224](pymc_marketing/mmm/old_plot.py#L1049-L1224) (to be renamed to legacy_plot.py) -- Creates bar chart comparing allocated spend vs channel contributions -- Dual y-axis visualization - -**New Method**: `budget_allocation_roas()` at [plot.py:773-874](pymc_marketing/mmm/plot.py#L773-L874) -- Completely different purpose (ROI distributions) -- Different parameters and output - -**Impact**: Code using `mmm.plot.budget_allocation()` will fail with `AttributeError`. - -**Recommendation**: Add stub method that raises helpful error: - -```python -def budget_allocation(self, *args, **kwargs): - """ - .. deprecated:: 0.18.0 - Removed in version 2.0. See budget_allocation_roas() for ROI distributions. - - Raises - ------ - NotImplementedError - This method was removed in MMMPlotSuite v2. - For ROI distributions, use budget_allocation_roas(). - To use the old budget_allocation(), set mmm_config['plot.use_v2'] = False. - """ - raise NotImplementedError( - "budget_allocation() was removed in MMMPlotSuite v2. " - "The new version uses arviz_plots which doesn't support this chart type. " - "Options:\n" - " 1. For ROI distributions: use budget_allocation_roas()\n" - " 2. To use old method: set mmm_config['plot.use_v2'] = False\n" - " 3. Implement custom bar chart using samples data" - ) -``` - -#### Issue 4: Backend Parameter Coverage ✅ **GOOD** - -**Status**: All public methods have `backend` parameter: -- `posterior_predictive()` ✅ -- `contributions_over_time()` ✅ -- `saturation_scatterplot()` ✅ -- `saturation_curves()` ✅ -- `budget_allocation_roas()` ✅ -- `allocated_contribution_by_channel_over_time()` ✅ -- `sensitivity_analysis()` ✅ -- `uplift_curve()` ✅ -- `marginal_curve()` ✅ -- ~~`saturation_curves_scatter()`~~ - **TO BE REMOVED** (deprecated method, see Issue 5) - -**Pattern**: Consistent across all methods, properly resolves via `_resolve_backend()`. - -#### Issue 5: Deprecated Method Should Be Removed ⚠️ **MINOR BUT IMPORTANT** - -**Problem**: `saturation_curves_scatter()` is currently implemented in MMMPlotSuite v2 but is deprecated and just delegates to `saturation_scatterplot()`. - -**Current implementation** (lines 737-771 in [plot.py](pymc_marketing/mmm/plot.py#L737-L771)): -```python -def saturation_curves_scatter(self, original_scale: bool = False, **kwargs) -> PlotCollection: - """ - .. deprecated:: 0.1.0 - Will be removed in version 0.20.0. Use :meth:`saturation_scatterplot` instead. - """ - import warnings - warnings.warn( - "saturation_curves_scatter is deprecated and will be removed in version 0.2.0. " - "Use saturation_scatterplot instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.saturation_scatterplot(original_scale=original_scale, **kwargs) -``` - -**Rationale for removal**: -- Since MMMPlotSuite v2 is a completely new implementation, we should NOT carry forward deprecated methods -- The legacy suite (LegacyMMMPlotSuite) already has this method for users who need it -- Users who opt into v2 (`mmm_config["plot.use_v2"] = True`) should use the new, correct method name -- Keeping deprecated methods in v2 defeats the purpose of a clean migration -- The method was deprecated in v0.1.0, giving users ample time to migrate - -**Recommendation**: **REMOVE** `saturation_curves_scatter()` from MMMPlotSuite (plot.py) entirely. - -**Implementation**: -1. Delete the method from [pymc_marketing/mmm/plot.py:737-771](pymc_marketing/mmm/plot.py#L737-L771) -2. Keep it in LegacyMMMPlotSuite (legacy_plot.py) for backward compatibility -3. Document the removal in migration guide - -**Alternative** (if keeping for one more release): -Add a note in the deprecation warning that it won't be available in v2 by default: -```python -warnings.warn( - "saturation_curves_scatter is deprecated and will be removed in version 0.20.0. " - "Use saturation_scatterplot instead. " - "Note: This method is not available when using mmm_config['plot.use_v2'] = True.", - DeprecationWarning, - stacklevel=2, -) -``` - -**Preferred approach**: Clean removal from v2, keep only in legacy suite. - -### 6. Testing Infrastructure ⚠️ **MAJOR GAPS** - -#### 6.1 Current Test Coverage - -**Test Files Found:** -1. [tests/mmm/test_plot.py](tests/mmm/test_plot.py) - 800+ lines - - Contains ~28 test functions - - Good fixture patterns - - **Tests for LegacyMMMPlotSuite only** (currently using `old_plot.py`) - - **NEW suite (plot.py) has NO test coverage** - - **Needs new tests for the new MMMPlotSuite with all backends** - -2. [tests/mmm/test_plot_backends.py](tests/mmm/test_plot_backends.py) - 255 lines - - **EXPERIMENTAL FILE - SHOULD BE REMOVED** - - Contains ~14 test functions - - Only tests `posterior_predictive()` with multiple backends - - Functionality should be merged into test_plot.py with parametrization - -3. [tests/mmm/test_plotting.py](tests/mmm/test_plotting.py) - Legacy tests - - Tests for old `BaseMMM` and `MMM` plotting - - Not for MMMPlotSuite - -**Test Coverage Analysis:** - -| Method | Legacy Suite Tests | New Suite Tests | All Backends | Compatibility Tests | -|--------|-------------------|----------------|--------------|---------------------| -| `posterior_predictive()` | ✅ (matplotlib only) | ⚠️ (test_plot_backends.py only) | ❌ | ❌ | -| `contributions_over_time()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | -| `saturation_scatterplot()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | -| `saturation_curves()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | -| `budget_allocation()` | ✅ (matplotlib only) | N/A (removed) | ❌ | ❌ | -| `budget_allocation_roas()` | N/A (doesn't exist) | ❌ | ❌ | ❌ | -| `allocated_contribution_by_channel_over_time()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | -| `sensitivity_analysis()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | -| `uplift_curve()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | -| `marginal_curve()` | ✅ (matplotlib only) | ❌ | ❌ | ❌ | -| Config flag switching | ❌ | ❌ | ❌ | ❌ | -| Deprecation warnings | ❌ | ❌ | ❌ | ❌ | - -**Coverage**: -- Legacy suite: ~80% (8 methods tested, matplotlib only) -- **New suite: ~1% (only 1 method partially tested in experimental file)** -- Compatibility tests: 0% - -**Critical Gap**: The new MMMPlotSuite (plot.py) has essentially NO test coverage! - -**Testing Strategy**: -- **Create new comprehensive tests for the new MMMPlotSuite** -- Parametrize all new tests to run against all backends (matplotlib, plotly, bokeh) -- Keep existing test_plot.py tests for legacy suite (will be removed in v0.20.0) -- Create separate compatibility test suite - -#### 6.2 Available Test Fixtures - -**From test_plot.py (for LegacyMMMPlotSuite):** -```python -@pytest.fixture(scope="module") -def mock_idata() -> az.InferenceData: - """Basic mock InferenceData with posterior.""" - # Line 201 - -@pytest.fixture(scope="module") -def mock_idata_with_constant_data() -> az.InferenceData: - """Mock InferenceData with constant_data for saturation plots.""" - # Line 315 - -@pytest.fixture(scope="module") -def mock_suite(mock_idata) -> LegacyMMMPlotSuite: - """LegacyMMMPlotSuite instance with basic mock data.""" - # Line 290 - currently creates from old_plot - -@pytest.fixture(scope="module") -def mock_suite_with_constant_data(mock_idata_with_constant_data) -> LegacyMMMPlotSuite: - """LegacyMMMPlotSuite with constant data for saturation plots.""" - # Line 382 - currently creates from old_plot - -@pytest.fixture -def mock_saturation_curve(mock_idata_with_constant_data) -> xr.DataArray: - """Mock saturation curve DataArray.""" - # Line 388 -``` - -**Pattern**: All fixtures use deterministic seeds for reproducibility. - -**Note**: These fixtures will need to be adapted/duplicated for testing the new MMMPlotSuite. - -#### 6.3 Required Test Implementation - -**Strategy**: Create NEW comprehensive tests for the new MMMPlotSuite with multi-backend support - -**Step 1: Keep existing test_plot.py for legacy suite** -- Rename test file to make it clear it's for legacy: `test_plot.py` → `test_legacy_plot.py` -- Update imports to use `legacy_plot.LegacyMMMPlotSuite` -- These tests will be removed in v0.20.0 along with the legacy suite - -**Step 2: Create new test_plot.py for new MMMPlotSuite** - -Create comprehensive tests with backend parametrization: - -```python -"""Tests for new MMMPlotSuite with multi-backend support.""" - -import pytest -from arviz_plots import PlotCollection -from pymc_marketing.mmm import mmm_config -from pymc_marketing.mmm.plot import MMMPlotSuite - -@pytest.fixture(scope="module") -def new_mock_suite(mock_idata) -> MMMPlotSuite: - """New MMMPlotSuite instance with basic mock data.""" - return MMMPlotSuite(idata=mock_idata) - -@pytest.fixture(scope="module") -def new_mock_suite_with_constant_data(mock_idata_with_constant_data) -> MMMPlotSuite: - """New MMMPlotSuite with constant data for saturation plots.""" - return MMMPlotSuite(idata=mock_idata_with_constant_data) - -# Parametrize all tests across all backends -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_posterior_predictive(new_mock_suite, backend): - """Test posterior_predictive works with all backends.""" - pc = new_mock_suite.posterior_predictive(backend=backend) - assert isinstance(pc, PlotCollection) - assert pc.backend == backend - -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_contributions_over_time(new_mock_suite, backend): - """Test contributions_over_time works with all backends.""" - pc = new_mock_suite.contributions_over_time( - var=["intercept"], - backend=backend - ) - assert isinstance(pc, PlotCollection) - assert pc.backend == backend - -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_saturation_scatterplot(new_mock_suite_with_constant_data, backend): - """Test saturation_scatterplot works with all backends.""" - pc = new_mock_suite_with_constant_data.saturation_scatterplot(backend=backend) - assert isinstance(pc, PlotCollection) - assert pc.backend == backend - -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_saturation_curves( - new_mock_suite_with_constant_data, mock_saturation_curve, backend -): - """Test saturation_curves works with all backends.""" - pc = new_mock_suite_with_constant_data.saturation_curves( - curve=mock_saturation_curve, - backend=backend - ) - assert isinstance(pc, PlotCollection) - assert pc.backend == backend - -@pytest.mark.parametrize("backend", ["matplotlib", "plotly", "bokeh"]) -def test_budget_allocation_roas(new_mock_suite, backend): - """Test budget_allocation_roas works with all backends.""" - # Note: This is a NEW method that doesn't exist in legacy suite - pc = new_mock_suite.budget_allocation_roas(backend=backend) - assert isinstance(pc, PlotCollection) - assert pc.backend == backend - -# ... Create tests for all 8 methods with 3 backends = 24 core tests ... -``` - -**Step 3: Remove experimental test_plot_backends.py** -```bash -rm tests/mmm/test_plot_backends.py -``` - -**Step 4: Add backend-specific tests** - -```python -def test_backend_overrides_global_config(mock_suite): - """Test that method backend parameter overrides global config.""" - original = mmm_config.get("plot.backend", "matplotlib") - try: - mmm_config["plot.backend"] = "matplotlib" - - # Override with plotly - pc = mock_suite.contributions_over_time( - var=["intercept"], - backend="plotly" - ) - assert pc.backend == "plotly" - - # Default should still be matplotlib - pc2 = mock_suite.contributions_over_time(var=["intercept"]) - assert pc2.backend == "matplotlib" - finally: - mmm_config["plot.backend"] = original - -def test_invalid_backend_warning(mock_suite): - """Test that invalid backend shows warning but attempts plot.""" - import warnings - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - mmm_config["plot.backend"] = "invalid_backend" - - assert len(w) == 1 - assert "Invalid backend" in str(w[0].message) -``` - -**Result**: -- New suite: ~8 methods × 3 backends = ~24 core test cases (plus backend-specific tests) - note: saturation_curves_scatter removed -- Legacy suite: ~28 existing test functions (matplotlib only, will be removed in v0.20.0) - -**File 2: Create tests/mmm/test_plot_compatibility.py** - -New file for backward compatibility: - -```python -"""Tests for MMMPlotSuite backward compatibility and version switching.""" - -import pytest -import warnings -import numpy as np -from matplotlib.figure import Figure -from matplotlib.axes import Axes -from arviz_plots import PlotCollection - -from pymc_marketing.mmm import mmm_config -from pymc_marketing.mmm.plot import MMMPlotSuite -from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite - - -class TestVersionSwitching: - """Test mmm_config['plot.use_v2'] flag controls suite version.""" - - def test_use_v2_false_returns_legacy_suite(self, mock_mmm): - """Test that use_v2=False returns LegacyMMMPlotSuite.""" - original = mmm_config.get("plot.use_v2", False) - try: - mmm_config["plot.use_v2"] = False - - with pytest.warns(FutureWarning, match="deprecated in v0.20.0"): - plot_suite = mock_mmm.plot - - assert isinstance(plot_suite, LegacyMMMPlotSuite) - assert not isinstance(plot_suite, MMMPlotSuite) - finally: - mmm_config["plot.use_v2"] = original - - def test_use_v2_true_returns_new_suite(self, mock_mmm): - """Test that use_v2=True returns MMMPlotSuite.""" - original = mmm_config.get("plot.use_v2", False) - try: - mmm_config["plot.use_v2"] = True - - # Should not warn - with warnings.catch_warnings(): - warnings.simplefilter("error") # Turn warnings into errors - plot_suite = mock_mmm.plot - - assert isinstance(plot_suite, MMMPlotSuite) - finally: - mmm_config["plot.use_v2"] = original - - def test_default_is_legacy_suite(self, mock_mmm): - """Test that default behavior uses legacy suite (backward compatible).""" - # Reset to defaults - mmm_config.reset() - - with pytest.warns(FutureWarning): - plot_suite = mock_mmm.plot - - assert isinstance(plot_suite, LegacyMMMPlotSuite) - - -class TestDeprecationWarnings: - """Test deprecation warning system.""" - - def test_deprecation_warning_shown_by_default(self, mock_mmm): - """Test that deprecation warning is shown when using legacy suite.""" - mmm_config["plot.use_v2"] = False - mmm_config["plot.show_warnings"] = True - - with pytest.warns(FutureWarning, match=r"deprecated in v0\.20\.0"): - plot_suite = mock_mmm.plot - - assert isinstance(plot_suite, LegacyMMMPlotSuite) - - def test_deprecation_warning_suppressible(self, mock_mmm): - """Test that deprecation warning can be suppressed.""" - original_use_v2 = mmm_config.get("plot.use_v2", False) - original_warnings = mmm_config.get("plot.show_warnings", True) - - try: - mmm_config["plot.use_v2"] = False - mmm_config["plot.show_warnings"] = False - - # Should not warn - with warnings.catch_warnings(): - warnings.simplefilter("error") # Turn warnings into errors - plot_suite = mock_mmm.plot - - assert isinstance(plot_suite, LegacyMMMPlotSuite) - finally: - mmm_config["plot.use_v2"] = original_use_v2 - mmm_config["plot.show_warnings"] = original_warnings - - def test_warning_message_includes_migration_info(self, mock_mmm): - """Test that warning provides clear migration instructions.""" - mmm_config["plot.use_v2"] = False - mmm_config["plot.show_warnings"] = True - - with pytest.warns(FutureWarning) as warning_list: - plot_suite = mock_mmm.plot - - warning_msg = str(warning_list[0].message) - assert "v0.20.0" in warning_msg - assert "mmm_config['plot.use_v2'] = True" in warning_msg - assert "migration guide" in warning_msg.lower() or "documentation" in warning_msg.lower() - - def test_no_warning_when_using_new_suite(self, mock_mmm): - """Test that no warning shown when using new suite.""" - mmm_config["plot.use_v2"] = True - - with warnings.catch_warnings(): - warnings.simplefilter("error") - plot_suite = mock_mmm.plot - - assert isinstance(plot_suite, MMMPlotSuite) - - -class TestReturnTypeCompatibility: - """Test that both suites return expected types.""" - - def test_legacy_suite_returns_tuple(self, mock_mmm_fitted): - """Test legacy suite returns (Figure, Axes) tuple.""" - mmm_config["plot.use_v2"] = False - - with pytest.warns(FutureWarning): - plot_suite = mock_mmm_fitted.plot - result = plot_suite.posterior_predictive() - - assert isinstance(result, tuple) - assert len(result) == 2 - assert isinstance(result[0], Figure) - # result[1] can be Axes or ndarray of Axes - if isinstance(result[1], np.ndarray): - assert all(isinstance(ax, Axes) for ax in result[1].flat) - else: - assert isinstance(result[1], Axes) - - def test_new_suite_returns_plot_collection(self, mock_mmm_fitted): - """Test new suite returns PlotCollection.""" - mmm_config["plot.use_v2"] = True - - plot_suite = mock_mmm_fitted.plot - result = plot_suite.posterior_predictive() - - assert isinstance(result, PlotCollection) - assert hasattr(result, 'backend') - assert hasattr(result, 'show') - - def test_both_suites_produce_valid_plots(self, mock_mmm_fitted): - """Test that both suites can successfully create plots.""" - # Legacy suite - mmm_config["plot.use_v2"] = False - with pytest.warns(FutureWarning): - legacy_result = mock_mmm_fitted.plot.contributions_over_time( - var=["intercept"] - ) - assert legacy_result is not None - - # New suite - mmm_config["plot.use_v2"] = True - new_result = mock_mmm_fitted.plot.contributions_over_time( - var=["intercept"] - ) - assert new_result is not None - - -class TestMissingMethods: - """Test handling of methods that exist in one suite but not the other.""" - - def test_budget_allocation_exists_in_legacy_suite(self, mock_mmm_fitted, mock_allocation_samples): - """Test that budget_allocation() works in legacy suite.""" - mmm_config["plot.use_v2"] = False - - with pytest.warns(FutureWarning): - plot_suite = mock_mmm_fitted.plot - - # Should work (not raise AttributeError) - result = plot_suite.budget_allocation(samples=mock_allocation_samples) - assert isinstance(result, tuple) - - def test_budget_allocation_raises_in_new_suite(self, mock_mmm_fitted): - """Test that budget_allocation() raises helpful error in new suite.""" - mmm_config["plot.use_v2"] = True - plot_suite = mock_mmm_fitted.plot - - with pytest.raises(NotImplementedError, match="removed in MMMPlotSuite v2"): - plot_suite.budget_allocation(samples=None) - - def test_budget_allocation_roas_exists_in_new_suite( - self, mock_mmm_fitted, mock_allocation_samples - ): - """Test that budget_allocation_roas() works in new suite.""" - mmm_config["plot.use_v2"] = True - plot_suite = mock_mmm_fitted.plot - - result = plot_suite.budget_allocation_roas(samples=mock_allocation_samples) - assert isinstance(result, PlotCollection) - - def test_budget_allocation_roas_missing_in_legacy_suite(self, mock_mmm_fitted): - """Test that budget_allocation_roas() doesn't exist in legacy suite.""" - mmm_config["plot.use_v2"] = False - - with pytest.warns(FutureWarning): - plot_suite = mock_mmm_fitted.plot - - with pytest.raises(AttributeError): - plot_suite.budget_allocation_roas(samples=None) - - -class TestParameterCompatibility: - """Test parameter compatibility between suites.""" - - def test_var_parameter_list_in_legacy_suite(self, mock_mmm_fitted): - """Test that legacy suite accepts var as list.""" - mmm_config["plot.use_v2"] = False - - with pytest.warns(FutureWarning): - plot_suite = mock_mmm_fitted.plot - - # Should accept list - result = plot_suite.posterior_predictive(var=["y", "target"]) - assert isinstance(result, tuple) - - def test_var_parameter_list_warning_in_new_suite(self, mock_mmm_fitted): - """Test that new suite warns when given list for var.""" - mmm_config["plot.use_v2"] = True - plot_suite = mock_mmm_fitted.plot - - with pytest.warns(UserWarning, match="only supports single variable"): - result = plot_suite.posterior_predictive(var=["y"]) - - assert isinstance(result, PlotCollection) -``` - -**File 3: Additional fixtures in tests/conftest.py or tests/mmm/conftest.py** - -```python -@pytest.fixture -def mock_mmm(mock_idata): - """Mock MMM instance with idata.""" - from pymc_marketing.mmm.multidimensional import MMM - - mmm = Mock(spec=MMM) - mmm.idata = mock_idata - mmm._validate_model_was_built = Mock() - mmm._validate_idata_exists = Mock() - - # Make .plot property work - type(mmm).plot = MMM.plot - - return mmm - -@pytest.fixture -def mock_allocation_samples(): - """Mock samples dataset for budget allocation tests.""" - import xarray as xr - import numpy as np - - rng = np.random.default_rng(42) - - return xr.Dataset({ - "channel_contribution_original_scale": xr.DataArray( - rng.normal(size=(4, 100, 52, 3)), - dims=("chain", "draw", "date", "channel"), - coords={ - "chain": np.arange(4), - "draw": np.arange(100), - "date": pd.date_range("2025-01-01", periods=52, freq="W"), - "channel": ["TV", "Radio", "Digital"], - }, - ), - "allocation": xr.DataArray( - rng.uniform(100, 1000, size=(3,)), - dims=("channel",), - coords={"channel": ["TV", "Radio", "Digital"]}, - ), - }) -``` - -#### 6.4 Test Execution Checklist - -**Backend Testing:** -- [ ] Remove experimental test_plot_backends.py file -- [ ] Remove deprecated `saturation_curves_scatter()` from MMMPlotSuite -- [ ] Parametrize all tests in new test_plot.py with backend parameter (~8 methods) -- [ ] All ~24 parametrized tests pass (8 methods × 3 backends) -- [ ] Backend override test works correctly -- [ ] Invalid backend warning test passes - -**Compatibility Testing:** -- [ ] Create new test_plot_compatibility.py file -- [ ] All 15+ compatibility tests pass -- [ ] Config flag switching works -- [ ] Deprecation warnings show correctly -- [ ] Warnings are suppressible -- [ ] Both suites produce valid output -- [ ] Missing method raises helpful errors - -### 7. Import/Export Architecture - -#### 7.1 Current Import Chain - -``` -User Code - ↓ -from pymc_marketing.mmm.multidimensional import MMM - ↓ -MMM.plot property (multidimensional.py:602-607) - ↓ -Imports: from pymc_marketing.mmm.plot import MMMPlotSuite - ↓ -Returns: MMMPlotSuite(idata=self.idata) -``` - -#### 7.2 Required Imports for Compatibility - -**In multidimensional.py:** -```python -# Current (line 194) -from pymc_marketing.mmm.plot import MMMPlotSuite - -# Need to add in .plot property -from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite # Import locally in property -from pymc_marketing.mmm.config import mmm_config # Import locally in property -``` - -**NOT in mmm/__init__.py:** -- MMMPlotSuite is **not** exported in [pymc_marketing/mmm/__init__.py](pymc_marketing/mmm/__init__.py#L69-L119) -- Users access it via `mmm.plot.method()`, not by importing directly -- This is good - no need to modify `__all__` - -#### 7.3 User Usage Pattern - -```python -# Users do this: -from pymc_marketing.mmm.multidimensional import MMM - -mmm = MMM(...) -mmm.fit(...) - -# Access via property - this is where version switching happens -mmm.plot.posterior_predictive() # ← Property returns either old or new suite -``` - -### 8. Migration Timeline - -#### Phase 1: v0.18.0 (Current/Next Release) -**Goal**: Introduce new suite with safe fallback - -- ✅ Backend configuration (done) -- ✅ New suite implementation (done) -- ❌ Add `use_v2` flag to config (TODO) -- ❌ Implement version switching in `.plot` property (TODO) -- ❌ Add deprecation warning (TODO) -- ❌ Complete test coverage (TODO) -- ❌ Write migration guide documentation (TODO) - -**User Experience**: -- Default behavior: legacy suite with warning -- Opt-in to new: `mmm_config["plot.use_v2"] = True` -- Clear migration path provided - -#### Phase 2: v0.19.0 -**Goal**: Encourage migration to new suite - -- Change default: `"plot.use_v2": True` -- Keep legacy suite available via `use_v2=False` -- Strengthen warning when using legacy suite -- Monitor for issues - -**User Experience**: -- Default behavior: new suite -- Opt-out to legacy: `mmm_config["plot.use_v2"] = False` -- Legacy suite shows stronger deprecation warning - -#### Phase 3: v0.20.0 -**Goal**: Complete migration - -- Remove `LegacyMMMPlotSuite` class -- Remove `legacy_plot.py` file -- Remove `use_v2` flag -- Update all documentation -- Only new suite available - -**User Experience**: -- Only new suite available -- Legacy code must update to new API - -### 9. Breaking Changes Summary - -#### 9.1 Return Type - -**Legacy**: `tuple[Figure, NDArray[Axes]]` or `tuple[Figure, plt.Axes]` -```python -fig, axes = mmm.plot.posterior_predictive() -axes[0].set_title("Custom") -fig.savefig("plot.png") -``` - -**New**: `PlotCollection` -```python -pc = mmm.plot.posterior_predictive() -pc.show() # Display -pc.save("plot.png") # Save -``` - -#### 9.2 Parameter Changes - -| Method | Parameter | Legacy | New | Fix | -|--------|-----------|--------|-----|-----| -| `posterior_predictive()` | `var` | `list[str]` | `str` | Call multiple times or use list with warning | -| `saturation_scatterplot()` | `**kwargs` | Accepted | Removed | Customize PlotCollection after | -| `saturation_curves()` | `colors` | Supported | Removed | Use PlotCollection API | -| `saturation_curves()` | `subplot_kwargs` | Supported | Removed | Use PlotCollection API | -| `saturation_curves()` | `rc_params` | Supported | Removed | Set before calling | -| All methods | `ax` | Supported | Removed | Use PlotCollection | -| All methods | `figsize` | Supported | Removed | Use PlotCollection | -| All methods | `backend` | N/A | Added | Override global config | - -#### 9.3 Method Changes - -| Method | Status | Replacement | Notes | -|--------|--------|-------------|-------| -| `saturation_curves_scatter()` | **REMOVED in v2** | `saturation_scatterplot()` | Deprecated in v0.1.0, not carried forward to v2 | -| `budget_allocation()` | **REMOVED** | None exact | Use legacy suite or custom plot | -| `budget_allocation_roas()` | **NEW** | N/A | Different purpose (ROI dist) | - -### 10. Documentation Requirements - -#### 10.1 Migration Guide (docs/source/guides/mmm_plotting_migration.rst) - -Must include: - -1. **Overview** - - Why the change (arviz_plots benefits) - - Timeline (v0.18.0 intro, v0.19.0 default, v0.20.0 removal) - - How to opt-in/opt-out - -2. **Quick Start** - ```python - # Use new suite - from pymc_marketing.mmm import mmm_config - mmm_config["plot.use_v2"] = True - - # Set backend - mmm_config["plot.backend"] = "plotly" - ``` - -3. **Return Type Migration** - - Side-by-side examples - - How to work with PlotCollection - -4. **Method-by-Method Guide** - - API changes table - - Code examples for each method - - Common issues and solutions - -5. **Missing Features** - - `budget_allocation()` alternatives - - Lost customization parameters - - Workarounds - -6. **Backend Selection** - - Pros/cons of each backend - - When to use which - - Examples - -#### 10.2 Docstring Updates - -All methods in new suite need: -```python -def method_name(...) -> PlotCollection: - """ - Description. - - .. versionadded:: 0.18.0 - New arviz_plots-based implementation supporting multiple backends. - - Parameters - ---------- - backend : str, optional - Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". - If None, uses global config via mmm_config["plot.backend"]. - Default is "matplotlib". - - Returns - ------- - PlotCollection - arviz_plots PlotCollection object containing the plot. - Use .show() to display or .save("filename") to save. - Supports matplotlib, plotly, and bokeh backends. - - Unlike v1 which returned (Figure, Axes), this provides - a unified interface across all backends. - - Examples - -------- - Basic usage: - - >>> pc = mmm.plot.method_name() - >>> pc.show() - - Save to file: - - >>> pc.save("output.png") - - Use different backend: - - >>> pc = mmm.plot.method_name(backend="plotly") - >>> pc.show() - """ -``` - -## Code References - -### Core Implementation Files -- [pymc_marketing/mmm/plot.py](pymc_marketing/mmm/plot.py) - New MMMPlotSuite (1272 lines) -- [pymc_marketing/mmm/old_plot.py](pymc_marketing/mmm/old_plot.py) - Legacy implementation (1936 lines) - **TO BE RENAMED to legacy_plot.py** -- [pymc_marketing/mmm/config.py:21-66](pymc_marketing/mmm/config.py#L21-L66) - Backend configuration -- [pymc_marketing/mmm/multidimensional.py:602-607](pymc_marketing/mmm/multidimensional.py#L602-L607) - Integration point (.plot property) - -### Test Files -- [tests/mmm/test_plot.py](tests/mmm/test_plot.py) - Main plot tests (800+ lines) -- [tests/mmm/test_plot_backends.py](tests/mmm/test_plot_backends.py) - Backend tests (255 lines, incomplete) -- [tests/mmm/test_plotting.py](tests/mmm/test_plotting.py) - Legacy plotting tests - -### Deprecation Patterns -- [pymc_marketing/model_builder.py:60-77](pymc_marketing/model_builder.py#L60-L77) - Parameter deprecation helper -- [pymc_marketing/mmm/plot.py:737-771](pymc_marketing/mmm/plot.py#L737-L771) - Method deprecation example -- [pymc_marketing/clv/models/basic.py:49-59](pymc_marketing/clv/models/basic.py#L49-L59) - Config key deprecation -- [tests/test_model_builder.py:530-554](tests/test_model_builder.py#L530-L554) - Deprecation test pattern - -## Architecture Insights - -1. **Config-Based Feature Flags**: The codebase uses dict-based configuration (`mmm_config`) for runtime behavior control, similar to matplotlib's `rcParams` or arviz's config system. - -2. **Property-Based API**: Plot methods are accessed via `.plot` property that creates instances on-demand, enabling clean version switching at the access point. - -3. **Backend Abstraction**: The new implementation achieves backend independence through arviz_plots' `PlotCollection`, which handles backend-specific rendering internally. - -4. **Test Fixture Patterns**: All test fixtures use deterministic random seeds and module scope for performance, following pytest best practices. - -5. **Deprecation Philosophy**: The codebase uses `DeprecationWarning` for library developers and `FutureWarning` for end users, with clear migration paths in all warnings. - -6. **Incremental Migration**: Multiple patterns show support for gradual API transitions over several releases before removing old code. - -## Data Parameter Standardization ⚠️ **CRITICAL - MUST IMPLEMENT** - -### Summary - -**Goal**: All plotting methods should accept data as input parameters for consistency, testability, and flexibility. - -**Status**: Currently **inconsistent** - some methods accept data, others hard-code `self.idata` access. - -**Impact**: Must be fixed BEFORE writing tests, as tests need to be written against the correct API. - -**Time Estimate**: 4 hours - -**Key Changes**: -- 7 methods need updates -- `_sensitivity_analysis_plot()` must accept `data` as REQUIRED parameter (no fallback) -- All other methods can have fallback to `self.idata` -- Removes monkey-patching in `uplift_curve()` and `marginal_curve()` - -### Current State Analysis - -The new MMMPlotSuite methods currently have **inconsistent data parameter patterns**: - -**✅ Methods that already accept data as input:** -- `posterior_predictive(idata: xr.Dataset | None)` - With fallback to `self.idata.posterior_predictive` -- `budget_allocation_roas(samples: xr.Dataset)` - No fallback -- `allocated_contribution_by_channel_over_time(samples: xr.Dataset)` - No fallback - -**❌ Methods that need data parameters added:** -- `contributions_over_time()` - Currently uses `self.idata.posterior` directly -- `saturation_scatterplot()` - Currently uses `self.idata.constant_data` and `self.idata.posterior` -- `saturation_curves()` - Accepts `curve` but still uses `self.idata.constant_data` and `self.idata.posterior` for scatter -- `_sensitivity_analysis_plot()` - Currently uses `self.idata.sensitivity_analysis` (**must accept data, NO fallback**) -- `sensitivity_analysis()` - Needs to accept and pass data to `_sensitivity_analysis_plot()` -- `uplift_curve()` - Needs to accept and pass data to `_sensitivity_analysis_plot()` -- `marginal_curve()` - Needs to accept and pass data to `_sensitivity_analysis_plot()` - -### Required API Changes - -#### 1. contributions_over_time() - Add data parameter with fallback - -**Current signature** (line 387): -```python -def contributions_over_time( - self, - var: list[str], - hdi_prob: float = 0.85, - dims: dict[str, str | int | list] | None = None, - backend: str | None = None, -) -> PlotCollection: -``` - -**New signature**: -```python -def contributions_over_time( - self, - var: list[str], - data: xr.Dataset | None = None, # ← ADD THIS - hdi_prob: float = 0.85, - dims: dict[str, str | int | list] | None = None, - backend: str | None = None, -) -> PlotCollection: - """Plot the time-series contributions for each variable in `var`. - - Parameters - ---------- - var : list of str - A list of variable names to plot from the posterior. - data : xr.Dataset, optional - Dataset containing posterior data. If None, uses self.idata.posterior. - ... - """ -``` - -**Implementation changes** (lines 426-437): -```python -# OLD: -if not hasattr(self.idata, "posterior"): - raise ValueError(...) -da = self.idata.posterior[var] - -# NEW: -if data is None: - if not hasattr(self.idata, "posterior"): - raise ValueError( - "No posterior data found in 'self.idata' and no 'data' argument provided. " - "Please ensure 'self.idata' contains a 'posterior' group or provide 'data'." - ) - data = self.idata.posterior -da = data[var] -``` - -#### 2. saturation_scatterplot() - Add data parameters with fallback - -**Current signature** (line 493): -```python -def saturation_scatterplot( - self, - original_scale: bool = False, - dims: dict[str, str | int | list] | None = None, - backend: str | None = None, -) -> PlotCollection: -``` - -**New signature**: -```python -def saturation_scatterplot( - self, - original_scale: bool = False, - constant_data: xr.Dataset | None = None, # ← ADD THIS - posterior_data: xr.Dataset | None = None, # ← ADD THIS - dims: dict[str, str | int | list] | None = None, - backend: str | None = None, -) -> PlotCollection: - """Plot the saturation curves for each channel. - - Parameters - ---------- - original_scale : bool, optional - Whether to plot the original scale contributions. Default is False. - constant_data : xr.Dataset, optional - Dataset containing constant_data group with 'channel_data' variable. - If None, uses self.idata.constant_data. - posterior_data : xr.Dataset, optional - Dataset containing posterior group with channel contribution variables. - If None, uses self.idata.posterior. - ... - """ -``` - -**Implementation changes** (lines 524-562): -```python -# OLD: -if not hasattr(self.idata, "constant_data"): - raise ValueError(...) -cdims = self.idata.constant_data.channel_data.dims -channel_data = self.idata.constant_data.channel_data -channel_contrib = self.idata.posterior[channel_contribution] - -# NEW: -if constant_data is None: - if not hasattr(self.idata, "constant_data"): - raise ValueError( - "No 'constant_data' found in 'self.idata' and no 'constant_data' argument provided. " - "Please ensure 'self.idata' contains the constant_data group or provide 'constant_data'." - ) - constant_data = self.idata.constant_data - -if posterior_data is None: - if not hasattr(self.idata, "posterior"): - raise ValueError( - "No 'posterior' found in 'self.idata' and no 'posterior_data' argument provided. " - "Please ensure 'self.idata' contains the posterior group or provide 'posterior_data'." - ) - posterior_data = self.idata.posterior - -cdims = constant_data.channel_data.dims -channel_data = constant_data.channel_data -channel_contrib = posterior_data[channel_contribution] -``` - -#### 3. saturation_curves() - Update to use data parameters from saturation_scatterplot - -**Current signature** (line 597): -```python -def saturation_curves( - self, - curve: xr.DataArray, - original_scale: bool = False, - n_samples: int = 10, - hdi_probs: float | list[float] | None = None, - random_seed: np.random.Generator | None = None, - dims: dict[str, str | int | list] | None = None, - backend: str | None = None, -) -> PlotCollection: -``` - -**New signature**: -```python -def saturation_curves( - self, - curve: xr.DataArray, - original_scale: bool = False, - constant_data: xr.Dataset | None = None, # ← ADD THIS - posterior_data: xr.Dataset | None = None, # ← ADD THIS - n_samples: int = 10, - hdi_probs: float | list[float] | None = None, - random_seed: np.random.Generator | None = None, - dims: dict[str, str | int | list] | None = None, - backend: str | None = None, -) -> PlotCollection: - """Overlay saturation‑curve scatter‑plots with posterior‑predictive sample curves. - - Parameters - ---------- - curve : xr.DataArray - Posterior‑predictive curves (e.g. dims `("chain","draw","x","channel","geo")`). - original_scale : bool, default=False - Plot `channel_contribution_original_scale` if True, else `channel_contribution`. - constant_data : xr.Dataset, optional - Dataset containing constant_data group. If None, uses self.idata.constant_data. - posterior_data : xr.Dataset, optional - Dataset containing posterior group. If None, uses self.idata.posterior. - ... - """ -``` - -**Implementation changes** (lines 645-696): -```python -# OLD: -if not hasattr(self.idata, "constant_data"): - raise ValueError(...) -if original_scale: - curve_data = curve * self.idata.constant_data.target_scale - curve_data["x"] = curve_data["x"] * self.idata.constant_data.channel_scale -cdims = self.idata.constant_data.channel_data.dims -pc = self.saturation_scatterplot(original_scale=original_scale, dims=dims, backend=backend) - -# NEW: -if constant_data is None: - if not hasattr(self.idata, "constant_data"): - raise ValueError( - "No 'constant_data' found in 'self.idata' and no 'constant_data' argument provided." - ) - constant_data = self.idata.constant_data - -if posterior_data is None: - if not hasattr(self.idata, "posterior"): - raise ValueError( - "No 'posterior' found in 'self.idata' and no 'posterior_data' argument provided." - ) - posterior_data = self.idata.posterior - -if original_scale: - curve_data = curve * constant_data.target_scale - curve_data["x"] = curve_data["x"] * constant_data.channel_scale -cdims = constant_data.channel_data.dims -pc = self.saturation_scatterplot( - original_scale=original_scale, - constant_data=constant_data, - posterior_data=posterior_data, - dims=dims, - backend=backend -) -``` - -#### 4. _sensitivity_analysis_plot() - Accept data parameter WITHOUT fallback ⚠️ **CRITICAL** - -**Current signature** (line 979): -```python -def _sensitivity_analysis_plot( - self, - hdi_prob: float = 0.94, - aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - backend: str | None = None, -) -> PlotCollection: -``` - -**New signature**: -```python -def _sensitivity_analysis_plot( - self, - data: xr.DataArray | xr.Dataset, # ← ADD THIS (REQUIRED, NO DEFAULT) - hdi_prob: float = 0.94, - aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - backend: str | None = None, -) -> PlotCollection: - """Plot helper for sensitivity analysis results. - - Parameters - ---------- - data : xr.DataArray or xr.Dataset - Sensitivity analysis data to plot. Must have 'sample' and 'sweep' dimensions. - If Dataset, should contain 'x' variable. NO fallback to self.idata. - ... - """ -``` - -**Implementation changes** (lines 1002-1007): -```python -# OLD: -if not hasattr(self.idata, "sensitivity_analysis"): - raise ValueError("No sensitivity analysis results found. Run run_sweep() first.") -sa = self.idata.sensitivity_analysis -x = sa["x"] if isinstance(sa, xr.Dataset) else sa - -# NEW: -# Validate input data -if data is None: - raise ValueError( - "data parameter is required for _sensitivity_analysis_plot. " - "This is a helper method that should receive data explicitly." - ) - -# Handle Dataset or DataArray -x = data["x"] if isinstance(data, xr.Dataset) else data -``` - -**Rationale for NO fallback:** -- This is a **private helper method** (prefixed with `_`) -- It should be a pure plotting function that operates on provided data -- The public methods (`sensitivity_analysis()`, `uplift_curve()`, `marginal_curve()`) handle data retrieval from `self.idata` -- This separation of concerns makes the code more testable and maintainable - -#### 5. sensitivity_analysis() - Update to pass data - -**Current implementation** (lines 1071-1116): -```python -def sensitivity_analysis( - self, - hdi_prob: float = 0.94, - aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - backend: str | None = None, -) -> PlotCollection: - pc = self._sensitivity_analysis_plot( - hdi_prob=hdi_prob, aggregation=aggregation, backend=backend - ) - pc.map(azp.visuals.labelled_y, text="Contribution") - return pc -``` - -**New implementation**: -```python -def sensitivity_analysis( - self, - data: xr.DataArray | xr.Dataset | None = None, # ← ADD THIS - hdi_prob: float = 0.94, - aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - backend: str | None = None, -) -> PlotCollection: - """Plot sensitivity analysis results. - - Parameters - ---------- - data : xr.DataArray or xr.Dataset, optional - Sensitivity analysis data to plot. If None, uses self.idata.sensitivity_analysis. - ... - """ - # Retrieve data if not provided - if data is None: - if not hasattr(self.idata, "sensitivity_analysis"): - raise ValueError( - "No sensitivity analysis results found in 'self.idata' and no 'data' argument provided. " - "Run 'mmm.sensitivity.run_sweep()' first or provide 'data'." - ) - data = self.idata.sensitivity_analysis # type: ignore - - pc = self._sensitivity_analysis_plot( - data=data, # ← PASS DATA - hdi_prob=hdi_prob, - aggregation=aggregation, - backend=backend, - ) - pc.map(azp.visuals.labelled_y, text="Contribution") - return pc -``` - -#### 6. uplift_curve() - Update to pass data - -**Current implementation** (lines 1158-1193): -```python -def uplift_curve( - self, - hdi_prob: float = 0.94, - aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - backend: str | None = None, -) -> PlotCollection: - if not hasattr(self.idata, "sensitivity_analysis"): - raise ValueError(...) - - sa_group = self.idata.sensitivity_analysis - if isinstance(sa_group, xr.Dataset): - if "uplift_curve" not in sa_group: - raise ValueError(...) - data_var = sa_group["uplift_curve"] - else: - raise ValueError(...) - - # Monkey-patch approach with temporary swap - tmp_idata = xr.Dataset({"x": data_var}) - original_group = self.idata.sensitivity_analysis - try: - self.idata.sensitivity_analysis = tmp_idata - pc = self._sensitivity_analysis_plot(...) - ... - finally: - self.idata.sensitivity_analysis = original_group -``` - -**New implementation**: -```python -def uplift_curve( - self, - data: xr.DataArray | xr.Dataset | None = None, # ← ADD THIS - hdi_prob: float = 0.94, - aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - backend: str | None = None, -) -> PlotCollection: - """Plot precomputed uplift curves. - - Parameters - ---------- - data : xr.DataArray or xr.Dataset, optional - Uplift curve data to plot. If Dataset, should contain 'uplift_curve' variable. - If None, uses self.idata.sensitivity_analysis['uplift_curve']. - ... - """ - # Retrieve data if not provided - if data is None: - if not hasattr(self.idata, "sensitivity_analysis"): - raise ValueError( - "No sensitivity analysis results found in 'self.idata' and no 'data' argument provided. " - "Run 'mmm.sensitivity.run_sweep()' first or provide 'data'." - ) - - sa_group = self.idata.sensitivity_analysis # type: ignore - if isinstance(sa_group, xr.Dataset): - if "uplift_curve" not in sa_group: - raise ValueError( - "Expected 'uplift_curve' in idata.sensitivity_analysis. " - "Use SensitivityAnalysis.compute_uplift_curve_respect_to_base(..., extend_idata=True)." - ) - data = sa_group["uplift_curve"] - else: - raise ValueError( - "sensitivity_analysis does not contain 'uplift_curve'. Did you persist it to idata?" - ) - - # Handle Dataset input - if isinstance(data, xr.Dataset): - if "uplift_curve" in data: - data = data["uplift_curve"] - elif "x" in data: - data = data["x"] - else: - raise ValueError("Dataset must contain 'uplift_curve' or 'x' variable.") - - # Call helper with data (no more monkey-patching!) - pc = self._sensitivity_analysis_plot( - data=data, # ← PASS DATA DIRECTLY - hdi_prob=hdi_prob, - aggregation=aggregation, - backend=backend, - ) - pc.map(azp.visuals.labelled_y, text="Uplift (%)") - return pc -``` - -#### 7. marginal_curve() - Update to pass data - -**Current implementation** (lines 1237-1271): -```python -def marginal_curve( - self, - hdi_prob: float = 0.94, - aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - backend: str | None = None, -) -> PlotCollection: - if not hasattr(self.idata, "sensitivity_analysis"): - raise ValueError(...) - - sa_group = self.idata.sensitivity_analysis - # Similar monkey-patching as uplift_curve -``` - -**New implementation**: -```python -def marginal_curve( - self, - data: xr.DataArray | xr.Dataset | None = None, # ← ADD THIS - hdi_prob: float = 0.94, - aggregation: dict[str, tuple[str, ...] | list[str]] | None = None, - backend: str | None = None, -) -> PlotCollection: - """Plot precomputed marginal effects. - - Parameters - ---------- - data : xr.DataArray or xr.Dataset, optional - Marginal effects data to plot. If Dataset, should contain 'marginal_effects' variable. - If None, uses self.idata.sensitivity_analysis['marginal_effects']. - ... - """ - # Retrieve data if not provided - if data is None: - if not hasattr(self.idata, "sensitivity_analysis"): - raise ValueError( - "No sensitivity analysis results found in 'self.idata' and no 'data' argument provided. " - "Run 'mmm.sensitivity.run_sweep()' first or provide 'data'." - ) - - sa_group = self.idata.sensitivity_analysis # type: ignore - if isinstance(sa_group, xr.Dataset): - if "marginal_effects" not in sa_group: - raise ValueError( - "Expected 'marginal_effects' in idata.sensitivity_analysis. " - "Use SensitivityAnalysis.compute_marginal_effects(..., extend_idata=True)." - ) - data = sa_group["marginal_effects"] - else: - raise ValueError( - "sensitivity_analysis does not contain 'marginal_effects'. Did you persist it to idata?" - ) - - # Handle Dataset input - if isinstance(data, xr.Dataset): - if "marginal_effects" in data: - data = data["marginal_effects"] - elif "x" in data: - data = data["x"] - else: - raise ValueError("Dataset must contain 'marginal_effects' or 'x' variable.") - - # Call helper with data (no more monkey-patching!) - pc = self._sensitivity_analysis_plot( - data=data, # ← PASS DATA DIRECTLY - hdi_prob=hdi_prob, - aggregation=aggregation, - backend=backend, - ) - pc.map(azp.visuals.labelled_y, text="Marginal Effect") - return pc -``` - -### Benefits of This Standardization - -1. **Consistency**: All methods follow the same pattern for data handling -2. **Flexibility**: Users can pass external data or use data from self.idata -3. **Testability**: Methods can be tested with mock data without needing full MMM setup -4. **Separation of Concerns**: `_sensitivity_analysis_plot()` is a pure plotting function -5. **No More Monkey-Patching**: The uplift_curve() and marginal_curve() methods no longer need to temporarily swap self.idata -6. **Better Error Messages**: Clear messages when data is missing - -### Implementation Priority - -This is **CRITICAL** and should be completed as **Priority 0** (along with file renaming) because: -- It's a fundamental API design issue -- It affects multiple methods -- It's easier to fix before the migration is complete -- Tests need to be written against the correct API - -## Recommendations - -### Priority 0: File Renaming and Data Parameter Standardization (Must Complete First) - -0. **Rename files and classes** ✅ 30 minutes - - Rename `pymc_marketing/mmm/old_plot.py` to `legacy_plot.py` - - Rename class `OldMMMPlotSuite` to `LegacyMMMPlotSuite` throughout the file - - Update any imports in existing code/tests - - **This must be done BEFORE implementing other changes** - -0b. **Data Parameter Standardization** ✅ 4 hours - - Update `contributions_over_time()` to accept `data` parameter with fallback - - Update `saturation_scatterplot()` to accept `constant_data` and `posterior_data` parameters with fallback - - Update `saturation_curves()` to accept and pass `constant_data` and `posterior_data` parameters - - Update `_sensitivity_analysis_plot()` to accept `data` parameter WITHOUT fallback (REQUIRED parameter) - - Update `sensitivity_analysis()` to accept and pass `data` parameter with fallback - - Update `uplift_curve()` to accept and pass `data` parameter with fallback (removes monkey-patching) - - Update `marginal_curve()` to accept and pass `data` parameter with fallback (removes monkey-patching) - - **This is critical for API consistency and must be done BEFORE writing tests** - -### Priority 1: Critical (Must Complete for PR) - -1. **Remove deprecated method from new suite** ✅ 15 minutes - - Delete `saturation_curves_scatter()` from [pymc_marketing/mmm/plot.py:737-771](pymc_marketing/mmm/plot.py#L737-L771) - - Keep it in LegacyMMMPlotSuite (will be in legacy_plot.py after renaming) - - Document in migration guide that deprecated methods are not carried forward to v2 - -2. **Add backward compatibility flag** ✅ 2 hours - - Modify `config.py` to add `"plot.use_v2": False` - - Implement version switching in `multidimensional.py:602-607` - - Import from `legacy_plot` module - - Add deprecation warning with migration guide link - - Test manual switching works - -3. **Create comprehensive backend testing for new suite** ✅ 6 hours - - Rename existing test_plot.py to test_legacy_plot.py - - Update imports in legacy test file to use legacy_plot module - - CREATE NEW test_plot.py for the new MMMPlotSuite - - Write ~8 methods × 3 backends = ~24 parametrized tests (note: saturation_curves_scatter removed) - - Remove experimental test_plot_backends.py file - - Add backend override and invalid backend tests - - Verify all new tests pass - -4. **Create compatibility test suite** ✅ 3 hours - - Create `test_plot_compatibility.py` - - Test version switching (5 tests) - - Test deprecation warnings (4 tests) - - Test return types (3 tests) - - Test missing methods (4 tests) - - Test parameter compatibility (2 tests) - -### Priority 2: Important (Before Merge) - -5. **Update documentation** ⏱️ 4 hours - - Update method docstrings with PlotCollection info - - Add version directives (.. versionadded::) - - Document backend parameter - - Add usage examples - -6. **Write migration guide** ⏱️ 6 hours - - Create `docs/source/guides/mmm_plotting_migration.rst` - - Document all breaking changes (including parameter type changes) - - Provide side-by-side examples - - List missing features and workarounds - - Explain that parameter changes require code adaptation when switching to v2 - -### Priority 3: Nice to Have (Can Defer) - -8. **Add usage examples to docstrings** ⏱️ 2 hours - - Add Examples section to all methods - - Show basic usage, saving, backend switching - -9. **Create visual test notebook** ⏱️ 3 hours - - Notebook comparing old vs new outputs - - Demonstrates all backends - - Helps verify visual equivalence - -10. **Performance testing** ⏱️ 2 hours - - Compare old vs new rendering times - - Test with large datasets - - Document any performance changes - -## Open Questions - -### Q1: When should default switch from old to new? - -**Options**: -- A. v0.18.0 - Aggressive, breaks existing code immediately -- B. v0.19.0 - Conservative, gives 1 release for users to adapt -- C. v0.20.0 - Very conservative, 2 releases to adapt - -**Recommendation**: Option B (v0.19.0) -- v0.18.0: Introduce with legacy default + warning -- v0.19.0: Switch to new default, keep legacy available -- v0.20.0: Remove legacy completely - -### Q2: Should LegacyMMMPlotSuite be importable directly? - -**Current**: Only via `.plot` property -**Alternative**: Export in `mmm/__init__.py` - -**Recommendation**: Keep internal-only -- Encourages proper migration -- Reduces maintenance burden -- Users can still access via `use_v2=False` - -### Q3: How to handle `budget_allocation()` removal? - -**Options**: -- A. Keep in legacy suite, remove from new (current approach) -- B. Add adapter in new suite that approximates behavior -- C. Port to new suite with PlotCollection return type - -**Recommendation**: Option A with stub that raises -- Clear error message guides users -- Avoids maintaining duplicate functionality -- Allows temporary use of legacy suite - -### Q4: Should warnings be shown every time or once per session? - -**Current Pattern**: Every call -**Alternative**: Once per session using warning filters - -**Recommendation**: Every call (current) -- More visible, harder to ignore -- Consistent with other deprecation warnings -- Users can suppress globally if desired - -### Q5: What about projects pinned to specific versions? - -**Scenario**: User pins to v0.18.0, doesn't update - -**Solution**: -- `use_v2=False` default in v0.18.0 ensures no breakage -- Warning provides clear timeline -- Projects can update at their own pace -- No forced migration until they upgrade to v0.20.0+ - -## Implementation Checklist - -### Phase 1: Code Changes -- [ ] **Rename `old_plot.py` to `legacy_plot.py` and `OldMMMPlotSuite` to `LegacyMMMPlotSuite`** -- [ ] **Remove deprecated method from new suite:** - - [ ] Delete `saturation_curves_scatter()` from pymc_marketing/mmm/plot.py (lines 737-771) - - [ ] Keep it in LegacyMMMPlotSuite (legacy_plot.py) for backward compatibility - - [ ] Add note in migration guide about deprecated methods not carried forward to v2 -- [ ] **Data Parameter Standardization (CRITICAL - do before tests):** - - [ ] Update `contributions_over_time()` - add `data` parameter with fallback - - [ ] Update `saturation_scatterplot()` - add `constant_data` and `posterior_data` parameters with fallback - - [ ] Update `saturation_curves()` - add and pass `constant_data` and `posterior_data` parameters - - [ ] Update `_sensitivity_analysis_plot()` - add `data` parameter WITHOUT fallback (REQUIRED) - - [ ] Update `sensitivity_analysis()` - add and pass `data` parameter with fallback - - [ ] Update `uplift_curve()` - add and pass `data` parameter with fallback - - [ ] Update `marginal_curve()` - add and pass `data` parameter with fallback -- [ ] Add `"plot.use_v2": False` to config.py defaults -- [ ] Modify multidimensional.py `.plot` property with version switching -- [ ] Add FutureWarning for legacy suite usage -- [ ] Update all docstrings to document PlotCollection return type and new data parameters - -### Phase 2: Testing -- [ ] **Rename `tests/mmm/test_plot.py` to `test_legacy_plot.py` (tests for legacy suite)** -- [ ] **Update imports in renamed test file to use `legacy_plot.LegacyMMMPlotSuite`** -- [ ] **Create NEW `tests/mmm/test_plot.py` for new MMMPlotSuite** -- [ ] **Write ~8 methods × 3 backends = ~24 parametrized tests for new suite** (note: saturation_curves_scatter removed) -- [ ] Remove experimental `tests/mmm/test_plot_backends.py` file -- [ ] Remove deprecated `saturation_curves_scatter()` from pymc_marketing/mmm/plot.py -- [ ] Add backend override and invalid backend tests -- [ ] Create `tests/mmm/test_plot_compatibility.py` (15+ tests) -- [ ] Add mock_mmm fixture -- [ ] Add mock_allocation_samples fixture -- [ ] Verify all ~24 new suite backend tests pass -- [ ] Verify all 15 compatibility tests pass -- [ ] Test warning suppression works -- [ ] Test both suites produce valid output - -### Phase 3: Documentation -- [ ] Create migration guide (docs/source/guides/mmm_plotting_migration.rst) -- [ ] Document breaking changes table -- [ ] Provide code examples for migration -- [ ] Update API reference -- [ ] Add versionadded directives -- [ ] Document backend selection -- [ ] List missing features and workarounds - -### Phase 4: Review -- [ ] Code review for new implementation -- [ ] Test coverage review (aim for >95%) -- [ ] Documentation review -- [ ] Migration guide validation with sample code -- [ ] Timeline communication (v0.18.0 → v0.20.0) - -## Related Research - -- [CLAUDE.md](../../CLAUDE.md) - Project development guidelines -- [CONTRIBUTING.md](../../CONTRIBUTING.md) - Code style and testing requirements -- [pyproject.toml](../../pyproject.toml) - Test configuration and linting rules - -## Appendix: Complete Implementation Template - -### A. Config File Modification - -**File**: `pymc_marketing/mmm/config.py` - -```python -_defaults = { - "plot.backend": "matplotlib", - "plot.show_warnings": True, - "plot.use_v2": False, # ← ADD THIS LINE -} -``` - -### B. Property Modification - -**File**: `pymc_marketing/mmm/multidimensional.py` - -```python -@property -def plot(self) -> MMMPlotSuite | LegacyMMMPlotSuite: - """Use the MMMPlotSuite to plot the results. - - The plot suite version is controlled by mmm_config["plot.use_v2"]: - - False (default): Uses legacy matplotlib-based suite (will be deprecated) - - True: Uses new arviz_plots-based suite with multi-backend support - - .. versionchanged:: 0.18.0 - Added version control via mmm_config["plot.use_v2"]. - The legacy suite will be removed in v0.20.0. - - Examples - -------- - Use new plot suite: - - >>> from pymc_marketing.mmm import mmm_config - >>> mmm_config["plot.use_v2"] = True - >>> pc = mmm.plot.posterior_predictive() - >>> pc.show() - - Returns - ------- - MMMPlotSuite or LegacyMMMPlotSuite - Plot suite instance for creating MMM visualizations. - """ - from pymc_marketing.mmm.config import mmm_config - from pymc_marketing.mmm.plot import MMMPlotSuite - from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite - import warnings - - self._validate_model_was_built() - self._validate_idata_exists() - - # Check version flag - if mmm_config.get("plot.use_v2", False): - return MMMPlotSuite(idata=self.idata) - else: - # Show deprecation warning for legacy suite - if mmm_config.get("plot.show_warnings", True): - warnings.warn( - "The current MMMPlotSuite will be deprecated in v0.20.0. " - "The new version uses arviz_plots and supports multiple backends (matplotlib, plotly, bokeh). " - "To use the new version: mmm_config['plot.use_v2'] = True\n" - "To suppress this warning: mmm_config['plot.show_warnings'] = False\n" - "See migration guide: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html", - FutureWarning, - stacklevel=2, - ) - return LegacyMMMPlotSuite(idata=self.idata) -``` - -### C. Missing Method Stub - -**File**: `pymc_marketing/mmm/plot.py` (add to MMMPlotSuite class) - -```python -def budget_allocation(self, *args, **kwargs): - """ - Create bar chart comparing allocated spend and channel contributions. - - .. deprecated:: 0.18.0 - This method was removed in MMMPlotSuite v2. The arviz_plots library - used in v2 doesn't support this specific chart type. See alternatives below. - - Raises - ------ - NotImplementedError - This method is not available in MMMPlotSuite v2. - - Notes - ----- - Alternatives: - - 1. **For ROI distributions**: Use :meth:`budget_allocation_roas` - (different purpose but related to budget allocation) - - 2. **To use the old method**: Switch to legacy suite: - - >>> from pymc_marketing.mmm import mmm_config - >>> mmm_config["plot.use_v2"] = False - >>> mmm.plot.budget_allocation(samples) - - 3. **Custom implementation**: Create bar chart using samples data: - - >>> import matplotlib.pyplot as plt - >>> channel_contrib = samples["channel_contribution"].mean(...) - >>> allocated_spend = samples["allocation"] - >>> # Create custom bar chart with matplotlib - - See Also - -------- - budget_allocation_roas : Plot ROI distributions by channel - - Examples - -------- - Use legacy suite temporarily: - - >>> from pymc_marketing.mmm import mmm_config - >>> original = mmm_config.get("plot.use_v2") - >>> try: - ... mmm_config["plot.use_v2"] = False - ... fig, ax = mmm.plot.budget_allocation(samples) - ... fig.savefig("budget.png") - ... finally: - ... mmm_config["plot.use_v2"] = original - """ - raise NotImplementedError( - "budget_allocation() was removed in MMMPlotSuite v2.\n\n" - "The new arviz_plots-based implementation doesn't support this chart type.\n\n" - "Alternatives:\n" - " 1. For ROI distributions: use budget_allocation_roas()\n" - " 2. To use old method: set mmm_config['plot.use_v2'] = False\n" - " 3. Implement custom bar chart using the samples data\n\n" - "See documentation: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html#budget-allocation" - ) -``` - ---- - -**End of Research Document** From cdb48774b0a4d513e80a467fce1ce4e9bbadff4c Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Sat, 22 Nov 2025 09:12:01 -0500 Subject: [PATCH 19/29] update Claude.md --- CLAUDE.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index b044d95ea..ebab1a835 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -25,7 +25,9 @@ make init ### Testing and Quality To use pytest you first need to activate the enviroment: ```bash -source ~/miniconda3/etc/profile.d/conda.sh && conda activate pymc-marketing-dev +# Try to initialize conda (works if conda is in PATH or common locations) +eval "$(conda shell.bash hook 2>/dev/null)" && conda activate pymc-marketing-dev || \ +source "$(conda info --base 2>/dev/null)/etc/profile.d/conda.sh" && conda activate pymc-marketing-dev ``` Running tests: From 3741bb233bee16cfd65ccdee22c7b954468e6b55 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Sat, 22 Nov 2025 14:25:22 -0500 Subject: [PATCH 20/29] update env --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 829f07d6e..b2e69b551 100644 --- a/environment.yml +++ b/environment.yml @@ -62,3 +62,4 @@ dependencies: - pip: - roadmapper - labs-sphinx-theme + - arviz-plots[matplotlib] From 8cb9d07185b2b3121167e18d309e679dc0dbfcda Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Sat, 22 Nov 2025 14:31:38 -0500 Subject: [PATCH 21/29] update pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 92e3aadb4..e2cd7dcea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "pyprojroot", "pymc-extras>=0.4.0", "preliz>=0.20.0", + "arviz_plots[matplotlib]>0.7.0" ] [project.optional-dependencies] From cb9dccccdd5791860cd32fa8c8a43d88fa2eeb28 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Sat, 22 Nov 2025 14:33:41 -0500 Subject: [PATCH 22/29] update pyproject.toml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e2cd7dcea..379bb90e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "pyprojroot", "pymc-extras>=0.4.0", "preliz>=0.20.0", - "arviz_plots[matplotlib]>0.7.0" + "arviz_plots[matplotlib]>=0.7.0" ] [project.optional-dependencies] @@ -93,6 +93,7 @@ test = [ "osqp<1.0.0,>=0.6.2", "pygraphviz", "preliz>=0.20.0", + "arviz_plots[plotly,bokeh]>=0.7.0" ] [tool.hatch.build.targets.sdist] From 97281280580b27e00c9f97a09a4c778ccfdf85f1 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Sat, 22 Nov 2025 15:13:18 -0500 Subject: [PATCH 23/29] increase coverage of config.py to 100 --- tests/mmm/test_plot_compatibility.py | 87 ++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/tests/mmm/test_plot_compatibility.py b/tests/mmm/test_plot_compatibility.py index f3a81f6b4..2f2da779c 100644 --- a/tests/mmm/test_plot_compatibility.py +++ b/tests/mmm/test_plot_compatibility.py @@ -473,3 +473,90 @@ def test_valid_keys_do_not_warn(self): mmm_config["plot.backend"] = original_backend mmm_config["plot.use_v2"] = original_use_v2 mmm_config["plot.show_warnings"] = original_warnings + + def test_reset_restores_defaults(self): + """Test that reset() restores all configuration to default values.""" + from pymc_marketing.mmm import mmm_config + + # Store original state + original_backend = mmm_config.get("plot.backend", "matplotlib") + original_use_v2 = mmm_config.get("plot.use_v2", False) + original_warnings = mmm_config.get("plot.show_warnings", True) + + try: + # Change all config values + mmm_config["plot.backend"] = "plotly" + mmm_config["plot.use_v2"] = True + mmm_config["plot.show_warnings"] = False + + # Verify they were changed + assert mmm_config["plot.backend"] == "plotly" + assert mmm_config["plot.use_v2"] is True + assert mmm_config["plot.show_warnings"] is False + + # Reset to defaults + mmm_config.reset() + + # Verify all values are back to defaults + assert mmm_config["plot.backend"] == "matplotlib" + assert mmm_config["plot.use_v2"] is False + assert mmm_config["plot.show_warnings"] is True + + # Verify reset clears any invalid keys that were set + mmm_config["invalid.key"] = "test" + assert "invalid.key" in mmm_config + mmm_config.reset() + assert "invalid.key" not in mmm_config + finally: + # Restore original state + mmm_config["plot.backend"] = original_backend + mmm_config["plot.use_v2"] = original_use_v2 + mmm_config["plot.show_warnings"] = original_warnings + + def test_invalid_backend_warns_but_allows_setting(self): + """Test that setting an invalid backend warns but still sets the value.""" + from pymc_marketing.mmm import mmm_config + + original_backend = mmm_config.get("plot.backend", "matplotlib") + + try: + # Try to set an invalid backend + with pytest.warns(UserWarning, match="Invalid backend"): + mmm_config["plot.backend"] = "invalid_backend" + + # Verify the warning message contains valid backends + with pytest.warns(UserWarning) as warning_list: + mmm_config["plot.backend"] = "another_invalid" + + warning_msg = str(warning_list[0].message) + assert "Invalid backend" in warning_msg + assert "another_invalid" in warning_msg + assert ( + "matplotlib" in warning_msg + or "plotly" in warning_msg + or "bokeh" in warning_msg + ) + + # Verify the invalid backend was still set (allows setting but warns) + assert mmm_config["plot.backend"] == "another_invalid" + finally: + mmm_config["plot.backend"] = original_backend + + def test_valid_backends_do_not_warn(self): + """Test that setting valid backend values does not warn.""" + from pymc_marketing.mmm import mmm_config + + original_backend = mmm_config.get("plot.backend", "matplotlib") + + try: + # Setting valid backends should not warn + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + mmm_config["plot.backend"] = "matplotlib" + mmm_config["plot.backend"] = "plotly" + mmm_config["plot.backend"] = "bokeh" + + # Verify values were set + assert mmm_config["plot.backend"] == "bokeh" + finally: + mmm_config["plot.backend"] = original_backend From 93eeb63cd93a164245f6b7539a4943cb45404c9b Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Thu, 27 Nov 2025 12:01:46 -0500 Subject: [PATCH 24/29] addressed 2 comments --- pymc_marketing/mmm/plot.py | 314 ++++++++++++++++++++++--------------- 1 file changed, 191 insertions(+), 123 deletions(-) diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 12609ada6..a97bf1109 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -178,6 +178,8 @@ from arviz_base.labels import DimCoordLabeller, NoVarLabeller, mix_labellers from arviz_plots import PlotCollection +from pymc_marketing.mmm.config import mmm_config + __all__ = ["MMMPlotSuite"] WIDTH_PER_COL: float = 10.0 @@ -287,8 +289,6 @@ def _dim_list_handler( def _resolve_backend(self, backend: str | None) -> str: """Resolve backend parameter to actual backend string.""" - from pymc_marketing.mmm.config import mmm_config - return backend or mmm_config["plot.backend"] def _get_data_or_fallback( @@ -397,37 +397,50 @@ def posterior_predictive( -------- Basic usage: - >>> mmm.sample_posterior_predictive(X) - >>> pc = mmm.plot.posterior_predictive() - >>> pc.show() + .. code-block:: python + + mmm.sample_posterior_predictive(X) + pc = mmm.plot.posterior_predictive() + pc.show() Plot with different HDI probability: - >>> pc = mmm.plot.posterior_predictive(hdi_prob=0.94) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.posterior_predictive(hdi_prob=0.94) + pc.show() Save to file: - >>> pc = mmm.plot.posterior_predictive() - >>> pc.save("posterior_predictive.png") + .. code-block:: python + + pc = mmm.plot.posterior_predictive() + pc.save("posterior_predictive.png") Use different backend: - >>> pc = mmm.plot.posterior_predictive(backend="plotly") - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.posterior_predictive(backend="plotly") + pc.show() Provide explicit data: - >>> external_pp = xr.Dataset(...) # Custom posterior predictive - >>> pc = mmm.plot.posterior_predictive(idata=external_pp) - >>> pc.show() + .. code-block:: python + + external_pp = xr.Dataset(...) # Custom posterior predictive + pc = mmm.plot.posterior_predictive(idata=external_pp) + pc.show() Direct instantiation pattern: - >>> from pymc_marketing.mmm.plot import MMMPlotSuite - >>> mps = MMMPlotSuite(custom_idata) - >>> pc = mps.posterior_predictive() - >>> pc.show() + .. code-block:: python + + from pymc_marketing.mmm.plot import MMMPlotSuite + + mps = MMMPlotSuite(custom_idata) + pc = mps.posterior_predictive() + pc.show() """ if not 0 < hdi_prob < 1: raise ValueError("HDI probability must be between 0 and 1.") @@ -569,52 +582,67 @@ def contributions_over_time( -------- Basic usage - plot channel contributions: - >>> mmm.fit(X, y) - >>> pc = mmm.plot.contributions_over_time(var=["channel_contribution"]) - >>> pc.show() + .. code-block:: python + + mmm.fit(X, y) + pc = mmm.plot.contributions_over_time(var=["channel_contribution"]) + pc.show() Plot multiple variables together: - >>> pc = mmm.plot.contributions_over_time( - ... var=["channel_contribution", "intercept"] - ... ) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.contributions_over_time( + var=["channel_contribution", "intercept"] + ) + pc.show() Filter by dimension: - >>> pc = mmm.plot.contributions_over_time( - ... var=["channel_contribution"], dims={"geo": "US"} - ... ) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.contributions_over_time( + var=["channel_contribution"], dims={"geo": "US"} + ) + pc.show() Filter with multiple dimension values: - >>> pc = mmm.plot.contributions_over_time( - ... var=["channel_contribution"], dims={"geo": ["US", "UK"]} - ... ) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.contributions_over_time( + var=["channel_contribution"], dims={"geo": ["US", "UK"]} + ) + pc.show() Use different backend: - >>> pc = mmm.plot.contributions_over_time( - ... var=["channel_contribution"], backend="plotly" - ... ) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.contributions_over_time( + var=["channel_contribution"], backend="plotly" + ) + pc.show() Provide explicit data (option 1 - via data parameter): - >>> custom_posterior = xr.Dataset(...) - >>> pc = mmm.plot.contributions_over_time( - ... var=["my_contribution"], data=custom_posterior - ... ) - >>> pc.show() + .. code-block:: python + + custom_posterior = xr.Dataset(...) + pc = mmm.plot.contributions_over_time( + var=["my_contribution"], data=custom_posterior + ) + pc.show() Provide explicit data (option 2 - direct instantiation): - >>> from pymc_marketing.mmm.plot import MMMPlotSuite - >>> mps = MMMPlotSuite(custom_idata) - >>> pc = mps.contributions_over_time(var=["my_contribution"]) - >>> pc.show() + .. code-block:: python + + from pymc_marketing.mmm.plot import MMMPlotSuite + + mps = MMMPlotSuite(custom_idata) + pc = mps.contributions_over_time(var=["my_contribution"]) + pc.show() """ if not 0 < hdi_prob < 1: raise ValueError("HDI probability must be between 0 and 1.") @@ -786,34 +814,44 @@ def saturation_scatterplot( -------- Basic usage (scaled space): - >>> mmm.fit(X, y) - >>> pc = mmm.plot.saturation_scatterplot() - >>> pc.show() + .. code-block:: python + + mmm.fit(X, y) + pc = mmm.plot.saturation_scatterplot() + pc.show() Plot in original scale: - >>> mmm.add_original_scale_contribution_variable(var=["channel_contribution"]) - >>> pc = mmm.plot.saturation_scatterplot(original_scale=True) - >>> pc.show() + .. code-block:: python + + mmm.add_original_scale_contribution_variable(var=["channel_contribution"]) + pc = mmm.plot.saturation_scatterplot(original_scale=True) + pc.show() Filter by dimension: - >>> pc = mmm.plot.saturation_scatterplot(dims={"geo": "US"}) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.saturation_scatterplot(dims={"geo": "US"}) + pc.show() Use different backend: - >>> pc = mmm.plot.saturation_scatterplot(backend="plotly") - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.saturation_scatterplot(backend="plotly") + pc.show() Provide explicit data: - >>> custom_constant = xr.Dataset(...) - >>> custom_posterior = xr.Dataset(...) - >>> pc = mmm.plot.saturation_scatterplot( - ... constant_data=custom_constant, posterior_data=custom_posterior - ... ) - >>> pc.show() + .. code-block:: python + + custom_constant = xr.Dataset(...) + custom_posterior = xr.Dataset(...) + pc = mmm.plot.saturation_scatterplot( + constant_data=custom_constant, posterior_data=custom_posterior + ) + pc.show() """ # Resolve backend backend = self._resolve_backend(backend) @@ -1000,33 +1038,42 @@ def saturation_curves( -------- Generate and plot saturation curves: - >>> # Generate curves using saturation transformation - >>> curve = mmm.saturation.sample_curve( - ... idata=mmm.idata.posterior[["saturation_beta", "saturation_lam"]], - ... max_value=2.0, - ... ) - >>> pc = mmm.plot.saturation_curves(curve) - >>> pc.show() + .. code-block:: python + + # Generate curves using saturation transformation + curve = mmm.saturation.sample_curve( + idata=mmm.idata.posterior[["saturation_beta", "saturation_lam"]], + max_value=2.0, + ) + pc = mmm.plot.saturation_curves(curve) + pc.show() Add HDI bands: - >>> pc = mmm.plot.saturation_curves(curve, hdi_probs=[0.5, 0.94]) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.saturation_curves(curve, hdi_probs=[0.5, 0.94]) + pc.show() Original scale with custom seed: - >>> import numpy as np - >>> rng = np.random.default_rng(42) - >>> mmm.add_original_scale_contribution_variable(var=["channel_contribution"]) - >>> pc = mmm.plot.saturation_curves( - ... curve, original_scale=True, n_samples=15, random_seed=rng - ... ) - >>> pc.show() + .. code-block:: python + + import numpy as np + + rng = np.random.default_rng(42) + mmm.add_original_scale_contribution_variable(var=["channel_contribution"]) + pc = mmm.plot.saturation_curves( + curve, original_scale=True, n_samples=15, random_seed=rng + ) + pc.show() Filter by dimension: - >>> pc = mmm.plot.saturation_curves(curve, dims={"geo": "US"}) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.saturation_curves(curve, dims={"geo": "US"}) + pc.show() """ # Get constant_data and posterior_data with fallback constant_data = self._get_data_or_fallback( @@ -1201,25 +1248,31 @@ def budget_allocation_roas( -------- Basic usage with budget optimization results: - >>> allocation_results = mmm.allocate_budget_to_maximize_response( - ... total_budget=100_000, budget_bounds={"lower": 0.5, "upper": 2.0} - ... ) - >>> pc = mmm.plot.budget_allocation_roas(allocation_results) - >>> pc.show() + .. code-block:: python + + allocation_results = mmm.allocate_budget_to_maximize_response( + total_budget=100_000, budget_bounds={"lower": 0.5, "upper": 2.0} + ) + pc = mmm.plot.budget_allocation_roas(allocation_results) + pc.show() Group by geography to compare ROI across regions: - >>> pc = mmm.plot.budget_allocation_roas( - ... allocation_results, dims_to_group_by="geo" - ... ) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.budget_allocation_roas( + allocation_results, dims_to_group_by="geo" + ) + pc.show() Filter and group: - >>> pc = mmm.plot.budget_allocation_roas( - ... allocation_results, dims={"segment": "premium"}, dims_to_group_by="geo" - ... ) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.budget_allocation_roas( + allocation_results, dims={"segment": "premium"}, dims_to_group_by="geo" + ) + pc.show() """ # Get the channels from samples if "channel" not in samples.dims: @@ -1361,27 +1414,33 @@ def allocated_contribution_by_channel_over_time( -------- Basic usage with budget optimization results: - >>> allocation_results = mmm.allocate_budget_to_maximize_response( - ... total_budget=100_000, budget_bounds={"lower": 0.5, "upper": 2.0} - ... ) - >>> pc = mmm.plot.allocated_contribution_by_channel_over_time( - ... allocation_results - ... ) - >>> pc.show() + .. code-block:: python + + allocation_results = mmm.allocate_budget_to_maximize_response( + total_budget=100_000, budget_bounds={"lower": 0.5, "upper": 2.0} + ) + pc = mmm.plot.allocated_contribution_by_channel_over_time( + allocation_results + ) + pc.show() Custom HDI probability: - >>> pc = mmm.plot.allocated_contribution_by_channel_over_time( - ... allocation_results, hdi_prob=0.94 - ... ) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.allocated_contribution_by_channel_over_time( + allocation_results, hdi_prob=0.94 + ) + pc.show() Use different backend: - >>> pc = mmm.plot.allocated_contribution_by_channel_over_time( - ... allocation_results, backend="plotly" - ... ) - >>> pc.show() + .. code-block:: python + + pc = mmm.plot.allocated_contribution_by_channel_over_time( + allocation_results, backend="plotly" + ) + pc.show() """ # Check for expected dimensions and variables if "channel" not in samples.dims: @@ -2080,16 +2139,22 @@ def budget_allocation(self, *args, **kwargs): 2. **To use the old method**: Switch to legacy suite: - >>> from pymc_marketing.mmm import mmm_config - >>> mmm_config["plot.use_v2"] = False - >>> mmm.plot.budget_allocation(samples) + .. code-block:: python + + from pymc_marketing.mmm import mmm_config + + mmm_config["plot.use_v2"] = False + mmm.plot.budget_allocation(samples) 3. **Custom implementation**: Create bar chart using samples data: - >>> import matplotlib.pyplot as plt - >>> channel_contrib = samples["channel_contribution"].mean(...) - >>> allocated_spend = samples["allocation"] - >>> # Create custom bar chart with matplotlib + .. code-block:: python + + import matplotlib.pyplot as plt + + channel_contrib = samples["channel_contribution"].mean(...) + allocated_spend = samples["allocation"] + # Create custom bar chart with matplotlib See Also -------- @@ -2099,14 +2164,17 @@ def budget_allocation(self, *args, **kwargs): -------- Use legacy suite temporarily: - >>> from pymc_marketing.mmm import mmm_config - >>> original = mmm_config.get("plot.use_v2") - >>> try: - ... mmm_config["plot.use_v2"] = False - ... fig, ax = mmm.plot.budget_allocation(samples) - ... fig.savefig("budget.png") - ... finally: - ... mmm_config["plot.use_v2"] = original + .. code-block:: python + + from pymc_marketing.mmm import mmm_config + + original = mmm_config.get("plot.use_v2") + try: + mmm_config["plot.use_v2"] = False + fig, ax = mmm.plot.budget_allocation(samples) + fig.savefig("budget.png") + finally: + mmm_config["plot.use_v2"] = original """ raise NotImplementedError( "budget_allocation() was removed in MMMPlotSuite v2.\n\n" From cfea8e5cf89c4abde58bba39b963e12492b8e29c Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Thu, 27 Nov 2025 12:14:59 -0500 Subject: [PATCH 25/29] remove versionadded --- pymc_marketing/mmm/plot.py | 63 -------------------------------------- 1 file changed, 63 deletions(-) diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index a97bf1109..34d5239d4 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -344,9 +344,6 @@ def posterior_predictive( line and highest density interval (HDI) bands. Useful for checking model fit and understanding prediction uncertainty. - .. versionadded:: 0.18.0 - New arviz_plots-based implementation supporting multiple backends. - Parameters ---------- var : str, optional @@ -516,9 +513,6 @@ def contributions_over_time( HDI bands. Useful for understanding channel contributions, intercepts, or other time-varying effects in your model. - .. versionadded:: 0.18.0 - New arviz_plots-based implementation supporting multiple backends. - Parameters ---------- var : list of str @@ -527,10 +521,6 @@ def contributions_over_time( data : xr.Dataset, optional Dataset containing posterior data with variables in `var`. If None, uses self.idata.posterior. - - .. versionadded:: 0.18.0 - Added data parameter for explicit data passing. - This parameter allows: - Testing with mock data without modifying self.idata - Plotting external results not stored in self.idata @@ -733,9 +723,6 @@ def saturation_scatterplot( contributions (Y-axis), one subplot per channel. Useful for understanding the saturation behavior and diminishing returns of each marketing channel. - .. versionadded:: 0.18.0 - New arviz_plots-based implementation supporting multiple backends. - Parameters ---------- original_scale : bool, default False @@ -748,10 +735,6 @@ def saturation_scatterplot( - 'target_scale': Target scaling factor (if original_scale=True) If None, uses self.idata.constant_data. - - .. versionadded:: 0.18.0 - Added constant_data parameter for explicit data passing. - This parameter allows: - Testing with mock constant data - Plotting with alternative scaling factors @@ -760,10 +743,6 @@ def saturation_scatterplot( Dataset containing posterior group with channel contribution variables. Must contain 'channel_contribution' or 'channel_contribution_original_scale'. If None, uses self.idata.posterior. - - .. versionadded:: 0.18.0 - Added posterior_data parameter for explicit data passing. - This parameter allows: - Testing with mock posterior samples - Plotting external inference results @@ -957,9 +936,6 @@ def saturation_curves( - HDI bands showing uncertainty - Smooth saturation curves over the scatter plot - .. versionadded:: 0.18.0 - New arviz_plots-based implementation supporting multiple backends. - Parameters ---------- curve : xr.DataArray @@ -974,17 +950,9 @@ def saturation_curves( If True, requires channel_contribution_original_scale in posterior. constant_data : xr.Dataset, optional Dataset containing constant_data group. If None, uses self.idata.constant_data. - - .. versionadded:: 0.18.0 - Added constant_data parameter for explicit data passing. - This parameter allows testing with mock data and plotting alternative scenarios. posterior_data : xr.Dataset, optional Dataset containing posterior group. If None, uses self.idata.posterior. - - .. versionadded:: 0.18.0 - Added posterior_data parameter for explicit data passing. - This parameter allows testing with mock posterior samples and comparing model fits. n_samples : int, default 10 Number of sample curves to draw per subplot. @@ -1186,10 +1154,6 @@ def budget_allocation_roas( allocation. Useful for comparing ROI across channels and understanding optimization trade-offs. - .. versionadded:: 0.18.0 - New method in MMMPlotSuite v2. This is different from the legacy - budget_allocation() method which showed bar charts. - Parameters ---------- samples : xr.Dataset @@ -1358,9 +1322,6 @@ def allocated_contribution_by_channel_over_time( optimized budget allocation. Shows mean contribution lines per channel with HDI uncertainty bands. - .. versionadded:: 0.18.0 - New arviz_plots-based implementation supporting multiple backends. - Parameters ---------- samples : xr.Dataset @@ -1530,9 +1491,6 @@ def _sensitivity_analysis_plot( sensitivity analysis visualizations. Public methods (sensitivity_analysis, uplift_curve, marginal_curve) handle data retrieval and call this helper. - .. versionadded:: 0.18.0 - New arviz_plots-based implementation supporting multiple backends. - Parameters ---------- data : xr.DataArray or xr.Dataset @@ -1671,9 +1629,6 @@ def sensitivity_analysis( (e.g., channel spend) are varied. Shows mean response line and HDI bands across sweep values. - .. versionadded:: 0.18.0 - New arviz_plots-based implementation supporting multiple backends. - Parameters ---------- data : xr.DataArray or xr.Dataset, optional @@ -1683,10 +1638,6 @@ def sensitivity_analysis( If Dataset, should contain 'x' variable. If None, uses self.idata.sensitivity_analysis. - - .. versionadded:: 0.18.0 - Added data parameter for explicit data passing. - This parameter allows: - Testing with mock sensitivity analysis results - Plotting external sweep results @@ -1802,9 +1753,6 @@ def uplift_curve( contributions) as inputs are varied, compared to a reference point. Shows mean uplift line and HDI bands. - .. versionadded:: 0.18.0 - New arviz_plots-based implementation supporting multiple backends. - Parameters ---------- data : xr.DataArray or xr.Dataset, optional @@ -1814,10 +1762,6 @@ def uplift_curve( Must be precomputed using: ``SensitivityAnalysis.compute_uplift_curve_respect_to_base(...)`` - - .. versionadded:: 0.18.0 - Added data parameter for explicit data passing. - This parameter allows: - Testing with mock uplift curve data - Plotting externally computed uplift curves @@ -1964,9 +1908,6 @@ def marginal_curve( with respect to inputs. Shows how much output changes per unit change in input at each sweep value. - .. versionadded:: 0.18.0 - New arviz_plots-based implementation supporting multiple backends. - Parameters ---------- data : xr.DataArray or xr.Dataset, optional @@ -1976,10 +1917,6 @@ def marginal_curve( Must be precomputed using: ``SensitivityAnalysis.compute_marginal_effects(...)`` - - .. versionadded:: 0.18.0 - Added data parameter for explicit data passing. - This parameter allows: - Testing with mock marginal effects data - Plotting externally computed marginal effects From 3c82941d93a1210e90c78e746a51a4dd22acb7bf Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Thu, 27 Nov 2025 17:37:51 -0500 Subject: [PATCH 26/29] remove tests --- tests/mmm/test_legacy_plot_imports.py | 33 --------------------------- 1 file changed, 33 deletions(-) delete mode 100644 tests/mmm/test_legacy_plot_imports.py diff --git a/tests/mmm/test_legacy_plot_imports.py b/tests/mmm/test_legacy_plot_imports.py deleted file mode 100644 index 7b80556ff..000000000 --- a/tests/mmm/test_legacy_plot_imports.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2022 - 2025 The PyMC Labs Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for legacy plot module renaming.""" - -import pytest - - -def test_legacy_plot_module_exists(): - """Test that legacy_plot module exists and can be imported.""" - try: - from pymc_marketing.mmm import legacy_plot - - assert hasattr(legacy_plot, "LegacyMMMPlotSuite") - except ImportError as e: - pytest.fail(f"Failed to import legacy_plot: {e}") - - -def test_legacy_class_name(): - """Test that legacy suite class is named LegacyMMMPlotSuite.""" - from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite - - assert LegacyMMMPlotSuite.__name__ == "LegacyMMMPlotSuite" From 153583e8174d1038334c5ca88ae1f6b8bba90721 Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Thu, 27 Nov 2025 17:45:01 -0500 Subject: [PATCH 27/29] update docstrings --- pymc_marketing/mmm/config.py | 61 +++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/pymc_marketing/mmm/config.py b/pymc_marketing/mmm/config.py index 487545f89..375dbe997 100644 --- a/pymc_marketing/mmm/config.py +++ b/pymc_marketing/mmm/config.py @@ -63,43 +63,54 @@ class MMMConfig(dict): -------- Set plotting backend globally: - >>> from pymc_marketing.mmm import mmm_config - >>> mmm_config["plot.backend"] = "plotly" - >>> # All plots now use plotly by default - >>> mmm = MMM(...) - >>> mmm.fit(X, y) - >>> pc = mmm.plot.posterior_predictive() # Uses plotly - >>> pc.show() + .. code-block:: python + + from pymc_marketing.mmm import mmm_config + + mmm_config["plot.backend"] = "plotly" + # All plots now use plotly by default + mmm = MMM(...) + mmm.fit(X, y) + pc = mmm.plot.posterior_predictive() # Uses plotly + pc.show() Enable new plotting suite (v2): - >>> mmm_config["plot.use_v2"] = True - >>> # Now using arviz_plots-based multi-backend suite - >>> mmm = MMM(...) - >>> mmm.fit(X, y) - >>> pc = mmm.plot.contributions_over_time(var=["intercept"]) - >>> pc.show() + .. code-block:: python + + mmm_config["plot.use_v2"] = True + # Now using arviz_plots-based multi-backend suite + mmm = MMM(...) + mmm.fit(X, y) + pc = mmm.plot.contributions_over_time(var=["intercept"]) + pc.show() Suppress warnings: - >>> mmm_config["plot.show_warnings"] = False + .. code-block:: python + + mmm_config["plot.show_warnings"] = False Reset to defaults: - >>> mmm_config.reset() - >>> mmm_config["plot.backend"] - 'matplotlib' + .. code-block:: python + + mmm_config.reset() + mmm_config["plot.backend"] + # 'matplotlib' Context manager pattern for temporary config changes: - >>> original = mmm_config["plot.backend"] - >>> try: - ... mmm_config["plot.backend"] = "plotly" - ... # Use plotly for this section - ... pc = mmm.plot.posterior_predictive() - ... pc.show() - ... finally: - ... mmm_config["plot.backend"] = original + .. code-block:: python + + original = mmm_config["plot.backend"] + try: + mmm_config["plot.backend"] = "plotly" + # Use plotly for this section + pc = mmm.plot.posterior_predictive() + pc.show() + finally: + mmm_config["plot.backend"] = original See Also -------- From 00c269d3ee2f2b50b92ec55840789f3c2112e19c Mon Sep 17 00:00:00 2001 From: Imri Sofer Date: Thu, 27 Nov 2025 17:53:12 -0500 Subject: [PATCH 28/29] changes name of MMMConfig to MMMPlotConfig --- pymc_marketing/mmm/__init__.py | 4 +- pymc_marketing/mmm/config.py | 22 +- pymc_marketing/mmm/legacy_plot.py | 6 +- pymc_marketing/mmm/multidimensional.py | 22 +- pymc_marketing/mmm/plot.py | 28 +-- tests/mmm/test_plot.py | 16 +- tests/mmm/test_plot_compatibility.py | 280 ++++++++++++------------- 7 files changed, 189 insertions(+), 189 deletions(-) diff --git a/pymc_marketing/mmm/__init__.py b/pymc_marketing/mmm/__init__.py index 98a22212e..b43abd980 100644 --- a/pymc_marketing/mmm/__init__.py +++ b/pymc_marketing/mmm/__init__.py @@ -38,7 +38,7 @@ TanhSaturationBaselined, saturation_from_dict, ) -from pymc_marketing.mmm.config import mmm_config +from pymc_marketing.mmm.config import mmm_plot_config from pymc_marketing.mmm.fourier import MonthlyFourier, WeeklyFourier, YearlyFourier from pymc_marketing.mmm.hsgp import ( HSGP, @@ -110,7 +110,7 @@ "create_eta_prior", "create_m_and_L_recommendations", "mmm", - "mmm_config", + "mmm_plot_config", "preprocessing", "preprocessing_method_X", "preprocessing_method_y", diff --git a/pymc_marketing/mmm/config.py b/pymc_marketing/mmm/config.py index 375dbe997..b67a16dcc 100644 --- a/pymc_marketing/mmm/config.py +++ b/pymc_marketing/mmm/config.py @@ -18,7 +18,7 @@ VALID_BACKENDS = {"matplotlib", "plotly", "bokeh"} -class MMMConfig(dict): +class MMMPlotConfig(dict): """Configuration dictionary for MMM plotting settings. Global configuration object that controls MMM plotting behavior including @@ -65,9 +65,9 @@ class MMMConfig(dict): .. code-block:: python - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - mmm_config["plot.backend"] = "plotly" + mmm_plot_config["plot.backend"] = "plotly" # All plots now use plotly by default mmm = MMM(...) mmm.fit(X, y) @@ -78,7 +78,7 @@ class MMMConfig(dict): .. code-block:: python - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True # Now using arviz_plots-based multi-backend suite mmm = MMM(...) mmm.fit(X, y) @@ -89,28 +89,28 @@ class MMMConfig(dict): .. code-block:: python - mmm_config["plot.show_warnings"] = False + mmm_plot_config["plot.show_warnings"] = False Reset to defaults: .. code-block:: python - mmm_config.reset() - mmm_config["plot.backend"] + mmm_plot_config.reset() + mmm_plot_config["plot.backend"] # 'matplotlib' Context manager pattern for temporary config changes: .. code-block:: python - original = mmm_config["plot.backend"] + original = mmm_plot_config["plot.backend"] try: - mmm_config["plot.backend"] = "plotly" + mmm_plot_config["plot.backend"] = "plotly" # Use plotly for this section pc = mmm.plot.posterior_predictive() pc.show() finally: - mmm_config["plot.backend"] = original + mmm_plot_config["plot.backend"] = original See Also -------- @@ -164,4 +164,4 @@ def reset(self): # Global config instance -mmm_config = MMMConfig() +mmm_plot_config = MMMPlotConfig() diff --git a/pymc_marketing/mmm/legacy_plot.py b/pymc_marketing/mmm/legacy_plot.py index 590d71a03..bcf6689d8 100644 --- a/pymc_marketing/mmm/legacy_plot.py +++ b/pymc_marketing/mmm/legacy_plot.py @@ -190,7 +190,7 @@ class LegacyMMMPlotSuite: .. deprecated:: 0.18.0 This class will be removed in v0.20.0. Use MMMPlotSuite with - mmm_config["plot.use_v2"] = True for the new arviz_plots-based suite. + mmm_plot_config["plot.use_v2"] = True for the new arviz_plots-based suite. This class is maintained for backward compatibility but will be removed in a future release. The new MMMPlotSuite supports multiple backends @@ -379,9 +379,9 @@ def _dim_list_handler( def _resolve_backend(self, backend: str | None) -> str: """Resolve backend parameter to actual backend string.""" - from pymc_marketing.mmm.config import mmm_config + from pymc_marketing.mmm.config import mmm_plot_config - return backend or mmm_config["plot.backend"] + return backend or mmm_plot_config["plot.backend"] # ------------------------------------------------------------------------ # Main Plotting Methods diff --git a/pymc_marketing/mmm/multidimensional.py b/pymc_marketing/mmm/multidimensional.py index 5e7cb4d5d..faf06d239 100644 --- a/pymc_marketing/mmm/multidimensional.py +++ b/pymc_marketing/mmm/multidimensional.py @@ -183,7 +183,7 @@ SaturationTransformation, saturation_from_dict, ) -from pymc_marketing.mmm.config import mmm_config +from pymc_marketing.mmm.config import mmm_plot_config from pymc_marketing.mmm.events import EventEffect from pymc_marketing.mmm.fourier import YearlyFourier from pymc_marketing.mmm.hsgp import HSGPBase, hsgp_from_dict @@ -621,26 +621,26 @@ def attrs_to_init_kwargs(cls, attrs: dict[str, str]) -> dict[str, Any]: def plot(self): """Use the MMMPlotSuite to plot the results. - The plot suite version is controlled by mmm_config["plot.use_v2"]: + The plot suite version is controlled by mmm_plot_config["plot.use_v2"]: - False (default): Uses legacy matplotlib-based suite (will be deprecated) - True: Uses new arviz_plots-based suite with multi-backend support .. versionchanged:: 0.18.0 - Added version control via mmm_config["plot.use_v2"]. + Added version control via mmm_plot_config["plot.use_v2"]. The legacy suite will be removed in v0.20.0. Examples -------- Use new plot suite: - >>> from pymc_marketing.mmm import mmm_config - >>> mmm_config["plot.use_v2"] = True + >>> from pymc_marketing.mmm import mmm_plot_config + >>> mmm_plot_config["plot.use_v2"] = True >>> pc = mmm.plot.posterior_predictive() >>> pc.show() Use legacy plot suite: - >>> mmm_config["plot.use_v2"] = False + >>> mmm_plot_config["plot.use_v2"] = False >>> fig, ax = mmm.plot.posterior_predictive() >>> fig.savefig("plot.png") @@ -653,19 +653,19 @@ def plot(self): self._validate_idata_exists() # Check version flag - if mmm_config.get("plot.use_v2", False): + if mmm_plot_config.get("plot.use_v2", False): return MMMPlotSuite(idata=self.idata) else: # Show deprecation warning for legacy suite - if mmm_config.get("plot.show_warnings", True): + if mmm_plot_config.get("plot.show_warnings", True): warnings.warn( "The current MMMPlotSuite will be deprecated in v0.20.0. " "The new version uses arviz_plots and supports multiple backends " "(matplotlib, plotly, bokeh). " "To use the new version:\n" - " from pymc_marketing.mmm import mmm_config\n" - " mmm_config['plot.use_v2'] = True\n" - "To suppress this warning: mmm_config['plot.show_warnings'] = False\n" + " from pymc_marketing.mmm import mmm_plot_config\n" + " mmm_plot_config['plot.use_v2'] = True\n" + "To suppress this warning: mmm_plot_config['plot.show_warnings'] = False\n" "See migration guide: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html", FutureWarning, stacklevel=2, diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 34d5239d4..c8fbf8098 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -178,7 +178,7 @@ from arviz_base.labels import DimCoordLabeller, NoVarLabeller, mix_labellers from arviz_plots import PlotCollection -from pymc_marketing.mmm.config import mmm_config +from pymc_marketing.mmm.config import mmm_plot_config __all__ = ["MMMPlotSuite"] @@ -289,7 +289,7 @@ def _dim_list_handler( def _resolve_backend(self, backend: str | None) -> str: """Resolve backend parameter to actual backend string.""" - return backend or mmm_config["plot.backend"] + return backend or mmm_plot_config["plot.backend"] def _get_data_or_fallback( self, @@ -360,7 +360,7 @@ def posterior_predictive( Probability mass for HDI interval (between 0 and 1). backend : str, optional Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". - If None, uses global config via mmm_config["plot.backend"]. + If None, uses global config via mmm_plot_config["plot.backend"]. Default is "matplotlib". Returns @@ -536,7 +536,7 @@ def contributions_over_time( If provided, only the selected slice(s) will be plotted. backend : str, optional Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". - If None, uses global config via mmm_config["plot.backend"]. + If None, uses global config via mmm_plot_config["plot.backend"]. Default is "matplotlib". Returns @@ -755,7 +755,7 @@ def saturation_scatterplot( If provided, only the selected slice(s) will be plotted. backend : str, optional Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". - If None, uses global config via mmm_config["plot.backend"]. + If None, uses global config via mmm_plot_config["plot.backend"]. Default is "matplotlib". Returns @@ -972,7 +972,7 @@ def saturation_curves( If provided, only the selected slice(s) will be plotted. backend : str, optional Plotting backend to use. Options: "matplotlib", "plotly", "bokeh". - If None, uses global config via mmm_config["plot.backend"]. + If None, uses global config via mmm_plot_config["plot.backend"]. Default is "matplotlib". Returns @@ -1206,7 +1206,7 @@ def budget_allocation_roas( - **New method** (this): Shows ROI distributions (KDE plots) - **Legacy method**: Shows bar charts comparing spend vs contributions - To use the legacy method, set: ``mmm_config["plot.use_v2"] = False`` + To use the legacy method, set: ``mmm_plot_config["plot.use_v2"] = False`` Examples -------- @@ -2078,9 +2078,9 @@ def budget_allocation(self, *args, **kwargs): .. code-block:: python - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - mmm_config["plot.use_v2"] = False + mmm_plot_config["plot.use_v2"] = False mmm.plot.budget_allocation(samples) 3. **Custom implementation**: Create bar chart using samples data: @@ -2103,22 +2103,22 @@ def budget_allocation(self, *args, **kwargs): .. code-block:: python - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2") + original = mmm_plot_config.get("plot.use_v2") try: - mmm_config["plot.use_v2"] = False + mmm_plot_config["plot.use_v2"] = False fig, ax = mmm.plot.budget_allocation(samples) fig.savefig("budget.png") finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original """ raise NotImplementedError( "budget_allocation() was removed in MMMPlotSuite v2.\n\n" "The new arviz_plots-based implementation doesn't support this chart type.\n\n" "Alternatives:\n" " 1. For ROI distributions: use budget_allocation_roas()\n" - " 2. To use old method: set mmm_config['plot.use_v2'] = False\n" + " 2. To use old method: set mmm_plot_config['plot.use_v2'] = False\n" " 3. Implement custom bar chart using the samples data\n\n" "See documentation: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html#budget-allocation" ) diff --git a/tests/mmm/test_plot.py b/tests/mmm/test_plot.py index b0a94e48f..b0c8c885b 100644 --- a/tests/mmm/test_plot.py +++ b/tests/mmm/test_plot.py @@ -482,13 +482,13 @@ def test_backend_overrides_global_config(self, mock_suite): """Test that method backend parameter overrides global config.""" from arviz_plots import PlotCollection - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.backend", "matplotlib") + original = mmm_plot_config.get("plot.backend", "matplotlib") try: # Set global to matplotlib - mmm_config["plot.backend"] = "matplotlib" + mmm_plot_config["plot.backend"] = "matplotlib" # Override with plotly pc_plotly = mock_suite.contributions_over_time( @@ -501,19 +501,19 @@ def test_backend_overrides_global_config(self, mock_suite): assert isinstance(pc_default, PlotCollection) finally: - mmm_config["plot.backend"] = original + mmm_plot_config["plot.backend"] = original @pytest.mark.parametrize("config_backend", ["matplotlib", "plotly", "bokeh"]) def test_backend_parameter_none_uses_config(self, mock_suite, config_backend): """Test that backend=None uses global config.""" from arviz_plots import PlotCollection - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.backend", "matplotlib") + original = mmm_plot_config.get("plot.backend", "matplotlib") try: - mmm_config["plot.backend"] = config_backend + mmm_plot_config["plot.backend"] = config_backend pc = mock_suite.contributions_over_time( var=["intercept"], @@ -523,7 +523,7 @@ def test_backend_parameter_none_uses_config(self, mock_suite, config_backend): assert isinstance(pc, PlotCollection) finally: - mmm_config["plot.backend"] = original + mmm_plot_config["plot.backend"] = original def test_invalid_backend_raises_error(self, mock_suite): """Test that invalid backend raises an appropriate error.""" diff --git a/tests/mmm/test_plot_compatibility.py b/tests/mmm/test_plot_compatibility.py index 2f2da779c..1dbdaceeb 100644 --- a/tests/mmm/test_plot_compatibility.py +++ b/tests/mmm/test_plot_compatibility.py @@ -23,16 +23,16 @@ class TestVersionSwitching: - """Test that mmm_config['plot.use_v2'] controls which suite is returned.""" + """Test that mmm_plot_config['plot.use_v2'] controls which suite is returned.""" def test_use_v2_false_returns_legacy_suite(self, mock_mmm): """Test that use_v2=False returns LegacyMMMPlotSuite.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = False + mmm_plot_config["plot.use_v2"] = False with pytest.warns(FutureWarning, match="deprecated in v0.20.0"): plot_suite = mock_mmm.plot @@ -40,17 +40,17 @@ def test_use_v2_false_returns_legacy_suite(self, mock_mmm): assert isinstance(plot_suite, LegacyMMMPlotSuite) assert plot_suite.__class__.__name__ == "LegacyMMMPlotSuite" finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original def test_use_v2_true_returns_new_suite(self, mock_mmm): """Test that use_v2=True returns MMMPlotSuite.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite from pymc_marketing.mmm.plot import MMMPlotSuite - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True # Should not warn with warnings.catch_warnings(): @@ -61,16 +61,16 @@ def test_use_v2_true_returns_new_suite(self, mock_mmm): assert not isinstance(plot_suite, LegacyMMMPlotSuite) assert plot_suite.__class__.__name__ == "MMMPlotSuite" finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original def test_default_is_legacy_suite(self, mock_mmm): """Test that default behavior uses legacy suite (backward compatible).""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite # Ensure default state - if "plot.use_v2" in mmm_config: - del mmm_config["plot.use_v2"] + if "plot.use_v2" in mmm_plot_config: + del mmm_plot_config["plot.use_v2"] with pytest.warns(FutureWarning): plot_suite = mock_mmm.plot @@ -79,13 +79,13 @@ def test_default_is_legacy_suite(self, mock_mmm): def test_config_flag_persists_across_calls(self, mock_mmm): """Test that setting config flag affects all subsequent calls.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config from pymc_marketing.mmm.plot import MMMPlotSuite - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: # Set once - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True # Multiple calls should all use new suite plot_suite1 = mock_mmm.plot @@ -96,18 +96,18 @@ def test_config_flag_persists_across_calls(self, mock_mmm): assert isinstance(plot_suite2, MMMPlotSuite) assert isinstance(plot_suite3, MMMPlotSuite) finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original def test_switching_between_v2_true_and_false(self, mock_mmm): """Test that switching from use_v2=True to False and back works correctly.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config from pymc_marketing.mmm.legacy_plot import LegacyMMMPlotSuite from pymc_marketing.mmm.plot import MMMPlotSuite - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: # Start with use_v2 = True - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True # Should return new suite without warnings with warnings.catch_warnings(): @@ -117,7 +117,7 @@ def test_switching_between_v2_true_and_false(self, mock_mmm): assert isinstance(plot_suite_v2, MMMPlotSuite) # Switch to use_v2 = False - mmm_config["plot.use_v2"] = False + mmm_plot_config["plot.use_v2"] = False # Should return legacy suite with deprecation warning with pytest.warns(FutureWarning, match="deprecated in v0.20.0"): @@ -126,7 +126,7 @@ def test_switching_between_v2_true_and_false(self, mock_mmm): assert isinstance(plot_suite_legacy, LegacyMMMPlotSuite) # Switch back to use_v2 = True - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True # Should return new suite again without warnings with warnings.catch_warnings(): @@ -135,7 +135,7 @@ def test_switching_between_v2_true_and_false(self, mock_mmm): assert isinstance(plot_suite_v2_again, MMMPlotSuite) finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original class TestDeprecationWarnings: @@ -143,33 +143,33 @@ class TestDeprecationWarnings: def test_deprecation_warning_shown_by_default(self, mock_mmm): """Test that deprecation warning is shown when using legacy suite.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original_use_v2 = mmm_config.get("plot.use_v2", False) - original_warnings = mmm_config.get("plot.show_warnings", True) + original_use_v2 = mmm_plot_config.get("plot.use_v2", False) + original_warnings = mmm_plot_config.get("plot.show_warnings", True) try: - mmm_config["plot.use_v2"] = False - mmm_config["plot.show_warnings"] = True + mmm_plot_config["plot.use_v2"] = False + mmm_plot_config["plot.show_warnings"] = True with pytest.warns(FutureWarning, match=r"deprecated in v0\.20\.0"): plot_suite = mock_mmm.plot assert plot_suite is not None finally: - mmm_config["plot.use_v2"] = original_use_v2 - mmm_config["plot.show_warnings"] = original_warnings + mmm_plot_config["plot.use_v2"] = original_use_v2 + mmm_plot_config["plot.show_warnings"] = original_warnings def test_deprecation_warning_suppressible(self, mock_mmm): """Test that deprecation warning can be suppressed.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original_use_v2 = mmm_config.get("plot.use_v2", False) - original_warnings = mmm_config.get("plot.show_warnings", True) + original_use_v2 = mmm_plot_config.get("plot.use_v2", False) + original_warnings = mmm_plot_config.get("plot.show_warnings", True) try: - mmm_config["plot.use_v2"] = False - mmm_config["plot.show_warnings"] = False + mmm_plot_config["plot.use_v2"] = False + mmm_plot_config["plot.show_warnings"] = False # Should not warn with warnings.catch_warnings(): @@ -178,17 +178,17 @@ def test_deprecation_warning_suppressible(self, mock_mmm): assert plot_suite is not None finally: - mmm_config["plot.use_v2"] = original_use_v2 - mmm_config["plot.show_warnings"] = original_warnings + mmm_plot_config["plot.use_v2"] = original_use_v2 + mmm_plot_config["plot.show_warnings"] = original_warnings def test_warning_message_includes_migration_info(self, mock_mmm): """Test that warning provides clear migration instructions.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original_use_v2 = mmm_config.get("plot.use_v2", False) + original_use_v2 = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = False + mmm_plot_config["plot.use_v2"] = False with pytest.warns(FutureWarning) as warning_list: _ = mock_mmm.plot @@ -204,16 +204,16 @@ def test_warning_message_includes_migration_info(self, mock_mmm): for word in ["migration", "guide", "documentation", "docs"] ), "Should reference migration guide" finally: - mmm_config["plot.use_v2"] = original_use_v2 + mmm_plot_config["plot.use_v2"] = original_use_v2 def test_no_warning_when_using_new_suite(self, mock_mmm): """Test that no warning shown when using new suite.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True with warnings.catch_warnings(): warnings.simplefilter("error") @@ -221,7 +221,7 @@ def test_no_warning_when_using_new_suite(self, mock_mmm): assert plot_suite is not None finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original class TestReturnTypeCompatibility: @@ -229,12 +229,12 @@ class TestReturnTypeCompatibility: def test_legacy_suite_returns_tuple(self, mock_mmm_fitted): """Test legacy suite returns (Figure, Axes) tuple.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = False + mmm_plot_config["plot.use_v2"] = False with pytest.warns(FutureWarning): plot_suite = mock_mmm_fitted.plot @@ -252,16 +252,16 @@ def test_legacy_suite_returns_tuple(self, mock_mmm_fitted): else: assert isinstance(result[1], Axes) finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original def test_new_suite_returns_plot_collection(self, mock_mmm_fitted): """Test new suite returns PlotCollection.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True plot_suite = mock_mmm_fitted.plot result = plot_suite.posterior_predictive() @@ -274,17 +274,17 @@ def test_new_suite_returns_plot_collection(self, mock_mmm_fitted): ) assert hasattr(result, "show"), "PlotCollection should have show method" finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original def test_both_suites_produce_valid_plots(self, mock_mmm_fitted): """Test that both suites can successfully create plots.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: # Legacy suite - mmm_config["plot.use_v2"] = False + mmm_plot_config["plot.use_v2"] = False with pytest.warns(FutureWarning): legacy_result = mock_mmm_fitted.plot.contributions_over_time( var=["intercept"] @@ -292,11 +292,11 @@ def test_both_suites_produce_valid_plots(self, mock_mmm_fitted): assert legacy_result is not None # New suite - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True new_result = mock_mmm_fitted.plot.contributions_over_time(var=["intercept"]) assert new_result is not None finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original class TestDeprecatedMethodRemoval: @@ -304,28 +304,28 @@ class TestDeprecatedMethodRemoval: def test_saturation_curves_scatter_removed_from_new_suite(self, mock_mmm_fitted): """Test saturation_curves_scatter removed from new MMMPlotSuite.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True plot_suite = mock_mmm_fitted.plot assert not hasattr(plot_suite, "saturation_curves_scatter"), ( "saturation_curves_scatter should not exist in new MMMPlotSuite" ) finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original def test_saturation_curves_scatter_exists_in_legacy_suite(self, mock_mmm_fitted): """Test saturation_curves_scatter still exists in LegacyMMMPlotSuite.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = False + mmm_plot_config["plot.use_v2"] = False with pytest.warns(FutureWarning): plot_suite = mock_mmm_fitted.plot @@ -334,7 +334,7 @@ def test_saturation_curves_scatter_exists_in_legacy_suite(self, mock_mmm_fitted) "saturation_curves_scatter should exist in LegacyMMMPlotSuite" ) finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original class TestMissingMethods: @@ -344,12 +344,12 @@ def test_budget_allocation_exists_in_legacy_suite( self, mock_mmm_fitted, mock_allocation_samples ): """Test that budget_allocation() works in legacy suite.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = False + mmm_plot_config["plot.use_v2"] = False with pytest.warns(FutureWarning): plot_suite = mock_mmm_fitted.plot @@ -359,31 +359,31 @@ def test_budget_allocation_exists_in_legacy_suite( assert isinstance(result, tuple) assert len(result) == 2 finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original def test_budget_allocation_raises_in_new_suite(self, mock_mmm_fitted): """Test that budget_allocation() raises helpful error in new suite.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True plot_suite = mock_mmm_fitted.plot with pytest.raises(NotImplementedError, match="removed in MMMPlotSuite v2"): plot_suite.budget_allocation(samples=None) finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original def test_budget_allocation_roas_exists_in_new_suite(self, mock_mmm_fitted): """Test that budget_allocation_roas() exists in new suite.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = True + mmm_plot_config["plot.use_v2"] = True plot_suite = mock_mmm_fitted.plot # Just check that the method exists (not AttributeError) @@ -394,16 +394,16 @@ def test_budget_allocation_roas_exists_in_new_suite(self, mock_mmm_fitted): "budget_allocation_roas should be callable" ) finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original def test_budget_allocation_roas_missing_in_legacy_suite(self, mock_mmm_fitted): """Test that budget_allocation_roas() doesn't exist in legacy suite.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original = mmm_config.get("plot.use_v2", False) + original = mmm_plot_config.get("plot.use_v2", False) try: - mmm_config["plot.use_v2"] = False + mmm_plot_config["plot.use_v2"] = False with pytest.warns(FutureWarning): plot_suite = mock_mmm_fitted.plot @@ -411,26 +411,26 @@ def test_budget_allocation_roas_missing_in_legacy_suite(self, mock_mmm_fitted): with pytest.raises(AttributeError): plot_suite.budget_allocation_roas(samples=None) finally: - mmm_config["plot.use_v2"] = original + mmm_plot_config["plot.use_v2"] = original class TestConfigValidation: - """Test MMMConfig key validation.""" + """Test MMMPlotConfig key validation.""" def test_invalid_key_warns_but_allows_setting(self): """Test that setting an invalid config key warns but still sets the value.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config # Store original state - original_invalid = mmm_config.get("invalid.key", None) + original_invalid = mmm_plot_config.get("invalid.key", None) try: # Try to set an invalid key with pytest.warns(UserWarning, match="Invalid config key"): - mmm_config["invalid.key"] = "some_value" + mmm_plot_config["invalid.key"] = "some_value" # Verify the warning message contains valid keys with pytest.warns(UserWarning) as warning_list: - mmm_config["another.invalid.key"] = "another_value" + mmm_plot_config["another.invalid.key"] = "another_value" warning_msg = str(warning_list[0].message) assert "Invalid config key" in warning_msg @@ -438,95 +438,95 @@ def test_invalid_key_warns_but_allows_setting(self): assert "plot.backend" in warning_msg or "plot.show_warnings" in warning_msg # Verify the invalid key was still set (allows setting but warns) - assert mmm_config["invalid.key"] == "some_value" - assert mmm_config["another.invalid.key"] == "another_value" + assert mmm_plot_config["invalid.key"] == "some_value" + assert mmm_plot_config["another.invalid.key"] == "another_value" finally: # Clean up invalid keys - if "invalid.key" in mmm_config: - del mmm_config["invalid.key"] - if "another.invalid.key" in mmm_config: - del mmm_config["another.invalid.key"] + if "invalid.key" in mmm_plot_config: + del mmm_plot_config["invalid.key"] + if "another.invalid.key" in mmm_plot_config: + del mmm_plot_config["another.invalid.key"] if original_invalid is not None: - mmm_config["invalid.key"] = original_invalid + mmm_plot_config["invalid.key"] = original_invalid def test_valid_keys_do_not_warn(self): """Test that setting valid config keys does not warn.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original_backend = mmm_config.get("plot.backend", "matplotlib") - original_use_v2 = mmm_config.get("plot.use_v2", False) - original_warnings = mmm_config.get("plot.show_warnings", True) + original_backend = mmm_plot_config.get("plot.backend", "matplotlib") + original_use_v2 = mmm_plot_config.get("plot.use_v2", False) + original_warnings = mmm_plot_config.get("plot.show_warnings", True) try: # Setting valid keys should not warn with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) - mmm_config["plot.backend"] = "plotly" - mmm_config["plot.use_v2"] = True - mmm_config["plot.show_warnings"] = False + mmm_plot_config["plot.backend"] = "plotly" + mmm_plot_config["plot.use_v2"] = True + mmm_plot_config["plot.show_warnings"] = False # Verify values were set - assert mmm_config["plot.backend"] == "plotly" - assert mmm_config["plot.use_v2"] is True - assert mmm_config["plot.show_warnings"] is False + assert mmm_plot_config["plot.backend"] == "plotly" + assert mmm_plot_config["plot.use_v2"] is True + assert mmm_plot_config["plot.show_warnings"] is False finally: - mmm_config["plot.backend"] = original_backend - mmm_config["plot.use_v2"] = original_use_v2 - mmm_config["plot.show_warnings"] = original_warnings + mmm_plot_config["plot.backend"] = original_backend + mmm_plot_config["plot.use_v2"] = original_use_v2 + mmm_plot_config["plot.show_warnings"] = original_warnings def test_reset_restores_defaults(self): """Test that reset() restores all configuration to default values.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config # Store original state - original_backend = mmm_config.get("plot.backend", "matplotlib") - original_use_v2 = mmm_config.get("plot.use_v2", False) - original_warnings = mmm_config.get("plot.show_warnings", True) + original_backend = mmm_plot_config.get("plot.backend", "matplotlib") + original_use_v2 = mmm_plot_config.get("plot.use_v2", False) + original_warnings = mmm_plot_config.get("plot.show_warnings", True) try: # Change all config values - mmm_config["plot.backend"] = "plotly" - mmm_config["plot.use_v2"] = True - mmm_config["plot.show_warnings"] = False + mmm_plot_config["plot.backend"] = "plotly" + mmm_plot_config["plot.use_v2"] = True + mmm_plot_config["plot.show_warnings"] = False # Verify they were changed - assert mmm_config["plot.backend"] == "plotly" - assert mmm_config["plot.use_v2"] is True - assert mmm_config["plot.show_warnings"] is False + assert mmm_plot_config["plot.backend"] == "plotly" + assert mmm_plot_config["plot.use_v2"] is True + assert mmm_plot_config["plot.show_warnings"] is False # Reset to defaults - mmm_config.reset() + mmm_plot_config.reset() # Verify all values are back to defaults - assert mmm_config["plot.backend"] == "matplotlib" - assert mmm_config["plot.use_v2"] is False - assert mmm_config["plot.show_warnings"] is True + assert mmm_plot_config["plot.backend"] == "matplotlib" + assert mmm_plot_config["plot.use_v2"] is False + assert mmm_plot_config["plot.show_warnings"] is True # Verify reset clears any invalid keys that were set - mmm_config["invalid.key"] = "test" - assert "invalid.key" in mmm_config - mmm_config.reset() - assert "invalid.key" not in mmm_config + mmm_plot_config["invalid.key"] = "test" + assert "invalid.key" in mmm_plot_config + mmm_plot_config.reset() + assert "invalid.key" not in mmm_plot_config finally: # Restore original state - mmm_config["plot.backend"] = original_backend - mmm_config["plot.use_v2"] = original_use_v2 - mmm_config["plot.show_warnings"] = original_warnings + mmm_plot_config["plot.backend"] = original_backend + mmm_plot_config["plot.use_v2"] = original_use_v2 + mmm_plot_config["plot.show_warnings"] = original_warnings def test_invalid_backend_warns_but_allows_setting(self): """Test that setting an invalid backend warns but still sets the value.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original_backend = mmm_config.get("plot.backend", "matplotlib") + original_backend = mmm_plot_config.get("plot.backend", "matplotlib") try: # Try to set an invalid backend with pytest.warns(UserWarning, match="Invalid backend"): - mmm_config["plot.backend"] = "invalid_backend" + mmm_plot_config["plot.backend"] = "invalid_backend" # Verify the warning message contains valid backends with pytest.warns(UserWarning) as warning_list: - mmm_config["plot.backend"] = "another_invalid" + mmm_plot_config["plot.backend"] = "another_invalid" warning_msg = str(warning_list[0].message) assert "Invalid backend" in warning_msg @@ -538,25 +538,25 @@ def test_invalid_backend_warns_but_allows_setting(self): ) # Verify the invalid backend was still set (allows setting but warns) - assert mmm_config["plot.backend"] == "another_invalid" + assert mmm_plot_config["plot.backend"] == "another_invalid" finally: - mmm_config["plot.backend"] = original_backend + mmm_plot_config["plot.backend"] = original_backend def test_valid_backends_do_not_warn(self): """Test that setting valid backend values does not warn.""" - from pymc_marketing.mmm import mmm_config + from pymc_marketing.mmm import mmm_plot_config - original_backend = mmm_config.get("plot.backend", "matplotlib") + original_backend = mmm_plot_config.get("plot.backend", "matplotlib") try: # Setting valid backends should not warn with warnings.catch_warnings(): warnings.simplefilter("error", UserWarning) - mmm_config["plot.backend"] = "matplotlib" - mmm_config["plot.backend"] = "plotly" - mmm_config["plot.backend"] = "bokeh" + mmm_plot_config["plot.backend"] = "matplotlib" + mmm_plot_config["plot.backend"] = "plotly" + mmm_plot_config["plot.backend"] = "bokeh" # Verify values were set - assert mmm_config["plot.backend"] == "bokeh" + assert mmm_plot_config["plot.backend"] == "bokeh" finally: - mmm_config["plot.backend"] = original_backend + mmm_plot_config["plot.backend"] = original_backend From e9a05babc094d93db681c9121861436cb6ea978b Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Fri, 28 Nov 2025 10:54:24 +0100 Subject: [PATCH 29/29] Address feedback: Combine tests using pytest.mark.parametrize (#2103) * Initial plan * Combine contributions_over_time tests using pytest.mark.parametrize Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: williambdean <57733339+williambdean@users.noreply.github.com> Co-authored-by: Juan Orduz --- tests/mmm/test_plot_data_parameters.py | 34 +++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/mmm/test_plot_data_parameters.py b/tests/mmm/test_plot_data_parameters.py index 7ea0c5f61..033feec0a 100644 --- a/tests/mmm/test_plot_data_parameters.py +++ b/tests/mmm/test_plot_data_parameters.py @@ -20,23 +20,23 @@ from pymc_marketing.mmm.plot import MMMPlotSuite -def test_contributions_over_time_accepts_data_parameter(mock_posterior_data): - """Test that contributions_over_time accepts data parameter.""" - # Create suite without idata - suite = MMMPlotSuite(idata=None) - - # Should work with explicit data parameter - pc = suite.contributions_over_time(var=["intercept"], data=mock_posterior_data) - - assert isinstance(pc, arviz_plots.PlotCollection) - - -def test_contributions_over_time_data_parameter_fallback(mock_idata_with_posterior): - """Test that contributions_over_time falls back to self.idata.posterior.""" - suite = MMMPlotSuite(idata=mock_idata_with_posterior) - - # Should work without data parameter (fallback) - pc = suite.contributions_over_time(var=["intercept"]) +@pytest.mark.parametrize( + "use_explicit_data", + [ + pytest.param(True, id="explicit_data_parameter"), + pytest.param(False, id="fallback_to_idata"), + ], +) +def test_contributions_over_time_data_parameter( + use_explicit_data, mock_posterior_data, mock_idata_with_posterior +): + """Test contributions_over_time with explicit data or fallback to idata.posterior.""" + if use_explicit_data: + suite = MMMPlotSuite(idata=None) + pc = suite.contributions_over_time(var=["intercept"], data=mock_posterior_data) + else: + suite = MMMPlotSuite(idata=mock_idata_with_posterior) + pc = suite.contributions_over_time(var=["intercept"]) assert isinstance(pc, arviz_plots.PlotCollection)