|
8 | 8 | import random |
9 | 9 | import shutil |
10 | 10 | import tempfile |
| 11 | +import time |
11 | 12 |
|
12 | 13 | import flaky |
13 | 14 | import numpy as np |
@@ -697,13 +698,25 @@ def test_learner_subdomain(learner_type, f, learner_kwargs): |
697 | 698 | raise NotImplementedError() |
698 | 699 |
|
699 | 700 |
|
| 701 | +def add_time(f): |
| 702 | + @ft.wraps(f) |
| 703 | + def wrapper(*args, **kwargs): |
| 704 | + t0 = time.time() |
| 705 | + result = f(*args, **kwargs) |
| 706 | + return {"result": result, "time": time.time() - t0} |
| 707 | + |
| 708 | + return wrapper |
| 709 | + |
| 710 | + |
700 | 711 | @run_with( |
701 | 712 | Learner1D, |
702 | 713 | Learner2D, |
703 | 714 | LearnerND, |
704 | 715 | AverageLearner, |
705 | 716 | AverageLearner1D, |
706 | 717 | SequenceLearner, |
| 718 | + IntegratorLearner, |
| 719 | + with_all_loss_functions=False, |
707 | 720 | ) |
708 | 721 | def test_to_dataframe(learner_type, f, learner_kwargs): |
709 | 722 | if learner_type is LearnerND: |
@@ -752,3 +765,21 @@ def test_to_dataframe(learner_type, f, learner_kwargs): |
752 | 765 | bal_learner2 = BalancingLearner(learners2) |
753 | 766 | bal_learner2.load_dataframe(df_bal, **kw) |
754 | 767 | assert bal_learner2.npoints == bal_learner.npoints |
| 768 | + |
| 769 | + if learner_type is SequenceLearner: |
| 770 | + # We do not test the DataSaver with the SequenceLearner |
| 771 | + # because the DataSaver is not compatible with the SequenceLearner. |
| 772 | + return |
| 773 | + |
| 774 | + # Test with DataSaver |
| 775 | + learner = learner_type( |
| 776 | + add_time(generate_random_parametrization(f)), **learner_kwargs |
| 777 | + ) |
| 778 | + data_saver = DataSaver(learner, operator.itemgetter("result")) |
| 779 | + df = data_saver.to_dataframe(**kw) # test if empty dataframe works |
| 780 | + simple_run(data_saver, 100) |
| 781 | + df = data_saver.to_dataframe(**kw) |
| 782 | + if learner_type is AverageLearner1D: |
| 783 | + assert len(df) == data_saver.nsamples |
| 784 | + else: |
| 785 | + assert len(df) == data_saver.npoints |
0 commit comments