Skip to content

Commit 2dfe63f

Browse files
committed
update savefig default to save when given file_name
1 parent 43c0304 commit 2dfe63f

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

fooof/plts/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,14 @@ def savefig(func):
181181
@wraps(func)
182182
def decorated(*args, **kwargs):
183183

184-
save_fig = kwargs.pop('save_fig', False)
184+
# Grab file name and path arguments, if they are in kwargs
185185
file_name = kwargs.pop('file_name', None)
186186
file_path = kwargs.pop('file_path', None)
187187

188+
# Check for an explicit argument for whether to save figure or not
189+
# Defaults to saving when file name given (since bool(str)->True; bool(None)->False)
190+
save_fig = kwargs.pop('save_fig', bool(file_name))
191+
188192
func(*args, **kwargs)
189193

190194
if save_fig:

fooof/tests/plts/test_utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""Tests for fooof.plts.utils."""
22

33
import os
4-
import tempfile
5-
6-
from fooof.tests.tutils import plot_test
74

85
from fooof.core.modutils import safe_import
96

7+
from fooof.tests.tutils import plot_test
8+
from fooof.tests.settings import TEST_PLOTS_PATH
9+
1010
from fooof.plts.utils import *
1111

1212
mpl = safe_import('matplotlib')
@@ -79,6 +79,14 @@ def test_savefig():
7979
def example_plot():
8080
plt.plot([1, 2], [3, 4])
8181

82-
with tempfile.NamedTemporaryFile(mode='w+') as file:
83-
example_plot(save_fig=True, file_name=file.name)
84-
assert os.path.exists(file.name)
82+
# Test defaults to saving given file path & name
83+
example_plot(file_path=TEST_PLOTS_PATH, file_name='test_savefig1.pdf')
84+
assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig1.pdf'))
85+
86+
# Test works the same when explicitly given `save_fig`
87+
example_plot(save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_savefig2.pdf')
88+
assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig2.pdf'))
89+
90+
# Test does not save when `save_fig` set to False
91+
example_plot(save_fig=False, file_path=TEST_PLOTS_PATH, file_name='test_savefig3.pdf')
92+
assert not os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_savefig3.pdf'))

0 commit comments

Comments
 (0)