@@ -353,6 +353,7 @@ def test_its():
353353 2. causalpy.InterruptedTimeSeries returns correct type
354354 3. the correct number of MCMC chains exists in the posterior inference data
355355 4. the correct number of MCMC draws exists in the posterior inference data
356+ 5. the method get_plot_data returns a DataFrame with expected columns
356357 """
357358 df = (
358359 cp .load_data ("its" )
@@ -378,9 +379,21 @@ def test_its():
378379 isinstance (item , plt .Axes ) for item in ax
379380 ), "ax must be a numpy.ndarray of plt.Axes"
380381 plot_data = result .get_plot_data ()
381- assert isinstance (plot_data , pd .DataFrame ), "The returned object is not a pandas DataFrame"
382- expected_columns = ['prediction' , 'pred_hdi_lower' , 'pred_hdi_upper' , 'impact' , 'impact_hdi_lower' , 'impact_hdi_upper' ]
383- assert set (expected_columns ).issubset (set (plot_data .columns )), f"DataFrame is missing expected columns { expected_columns } "
382+ assert isinstance (plot_data , pd .DataFrame ), (
383+ "The returned object is not a pandas DataFrame"
384+ )
385+ expected_columns = [
386+ "prediction" ,
387+ "pred_hdi_lower_94" ,
388+ "pred_hdi_upper_94" ,
389+ "impact" ,
390+ "impact_hdi_lower_94" ,
391+ "impact_hdi_upper_94" ,
392+ ]
393+ assert set (expected_columns ).issubset (set (plot_data .columns )), (
394+ f"DataFrame is missing expected columns { expected_columns } "
395+ )
396+
384397
385398@pytest .mark .integration
386399def test_its_covid ():
@@ -392,6 +405,7 @@ def test_its_covid():
392405 2. causalpy.InterruptedtimeSeries returns correct type
393406 3. the correct number of MCMC chains exists in the posterior inference data
394407 4. the correct number of MCMC draws exists in the posterior inference data
408+ 5. the method get_plot_data returns a DataFrame with expected columns
395409 """
396410
397411 df = (
@@ -418,9 +432,20 @@ def test_its_covid():
418432 isinstance (item , plt .Axes ) for item in ax
419433 ), "ax must be a numpy.ndarray of plt.Axes"
420434 plot_data = result .get_plot_data ()
421- assert isinstance (plot_data , pd .DataFrame ), "The returned object is not a pandas DataFrame"
422- expected_columns = ['prediction' , 'pred_hdi_lower' , 'pred_hdi_upper' , 'impact' , 'impact_hdi_lower' , 'impact_hdi_upper' ]
423- assert set (expected_columns ).issubset (set (plot_data .columns )), f"DataFrame is missing expected columns { expected_columns } "
435+ assert isinstance (plot_data , pd .DataFrame ), (
436+ "The returned object is not a pandas DataFrame"
437+ )
438+ expected_columns = [
439+ "prediction" ,
440+ "pred_hdi_lower_94" ,
441+ "pred_hdi_upper_94" ,
442+ "impact" ,
443+ "impact_hdi_lower_94" ,
444+ "impact_hdi_upper_94" ,
445+ ]
446+ assert set (expected_columns ).issubset (set (plot_data .columns )), (
447+ f"DataFrame is missing expected columns { expected_columns } "
448+ )
424449
425450
426451@pytest .mark .integration
@@ -433,6 +458,7 @@ def test_sc():
433458 2. causalpy.SyntheticControl returns correct type
434459 3. the correct number of MCMC chains exists in the posterior inference data
435460 4. the correct number of MCMC draws exists in the posterior inference data
461+ 5. the method get_plot_data returns a DataFrame with expected columns
436462 """
437463
438464 df = cp .load_data ("sc" )
@@ -463,9 +489,21 @@ def test_sc():
463489 isinstance (item , plt .Axes ) for item in ax
464490 ), "ax must be a numpy.ndarray of plt.Axes"
465491 plot_data = result .get_plot_data ()
466- assert isinstance (plot_data , pd .DataFrame ), "The returned object is not a pandas DataFrame"
467- expected_columns = ['prediction' , 'pred_hdi_lower' , 'pred_hdi_upper' , 'impact' , 'impact_hdi_lower' , 'impact_hdi_upper' ]
468- assert set (expected_columns ).issubset (set (plot_data .columns )), f"DataFrame is missing expected columns { expected_columns } "
492+ assert isinstance (plot_data , pd .DataFrame ), (
493+ "The returned object is not a pandas DataFrame"
494+ )
495+ expected_columns = [
496+ "prediction" ,
497+ "pred_hdi_lower_94" ,
498+ "pred_hdi_upper_94" ,
499+ "impact" ,
500+ "impact_hdi_lower_94" ,
501+ "impact_hdi_upper_94" ,
502+ ]
503+ assert set (expected_columns ).issubset (set (plot_data .columns )), (
504+ f"DataFrame is missing expected columns { expected_columns } "
505+ )
506+
469507
470508@pytest .mark .integration
471509def test_sc_brexit ():
@@ -477,6 +515,7 @@ def test_sc_brexit():
477515 2. causalpy.SyntheticControl returns correct type
478516 3. the correct number of MCMC chains exists in the posterior inference data
479517 4. the correct number of MCMC draws exists in the posterior inference data
518+ 5. the method get_plot_data returns a DataFrame with expected columns
480519 """
481520
482521 df = (
@@ -512,9 +551,20 @@ def test_sc_brexit():
512551 isinstance (item , plt .Axes ) for item in ax
513552 ), "ax must be a numpy.ndarray of plt.Axes"
514553 plot_data = result .get_plot_data ()
515- assert isinstance (plot_data , pd .DataFrame ), "The returned object is not a pandas DataFrame"
516- expected_columns = ['prediction' , 'pred_hdi_lower' , 'pred_hdi_upper' , 'impact' , 'impact_hdi_lower' , 'impact_hdi_upper' ]
517- assert set (expected_columns ).issubset (set (plot_data .columns )), f"DataFrame is missing expected columns { expected_columns } "
554+ assert isinstance (plot_data , pd .DataFrame ), (
555+ "The returned object is not a pandas DataFrame"
556+ )
557+ expected_columns = [
558+ "prediction" ,
559+ "pred_hdi_lower_94" ,
560+ "pred_hdi_upper_94" ,
561+ "impact" ,
562+ "impact_hdi_lower_94" ,
563+ "impact_hdi_upper_94" ,
564+ ]
565+ assert set (expected_columns ).issubset (set (plot_data .columns )), (
566+ f"DataFrame is missing expected columns { expected_columns } "
567+ )
518568
519569
520570@pytest .mark .integration
0 commit comments