Skip to content

Commit 18350ce

Browse files
BUG: Fix issue with Montage.plot (#13494)
Co-authored-by: Christian O'Reilly <christian.oreilly@gmail.com>
1 parent 7ed5e27 commit 18350ce

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

doc/changes/dev/13494.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug where :meth:`mne.channels.DigMontage.plot` would error when ``axes`` was passed by `Christian O'Reilly`_.

mne/viz/montage.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,19 @@ def plot_montage(
101101
)
102102

103103
if scale != 1.0:
104+
axes = axes if axes else fig.axes[0]
105+
104106
# scale points
105-
collection = fig.axes[0].collections[0]
107+
collection = axes.collections[0]
106108
collection.set_sizes([scale * 10])
107109

108110
# scale labels
109-
labels = fig.findobj(match=plt.Text)
110-
x_label, y_label = fig.axes[0].xaxis.label, fig.axes[0].yaxis.label
111-
z_label = fig.axes[0].zaxis.label if kind == "3d" else None
112-
tick_labels = fig.axes[0].get_xticklabels() + fig.axes[0].get_yticklabels()
111+
labels = axes.findobj(match=plt.Text)
112+
x_label, y_label = axes.xaxis.label, axes.yaxis.label
113+
z_label = axes.zaxis.label if kind == "3d" else None
114+
tick_labels = axes.get_xticklabels() + axes.get_yticklabels()
113115
if kind == "3d":
114-
tick_labels += fig.axes[0].get_zticklabels()
116+
tick_labels += axes.get_zticklabels()
115117
for label in labels:
116118
if label not in [x_label, y_label, z_label] + tick_labels:
117119
label.set_fontsize(label.get_fontsize() * scale)

mne/viz/tests/test_montage.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
import numpy as np
1111
import pytest
1212

13+
from mne import create_info
1314
from mne.channels import make_dig_montage, make_standard_montage, read_dig_fif
15+
from mne.io import RawArray
1416

1517
p_dir = Path(__file__).parents[2] / "io" / "kit" / "tests" / "data"
1618
elp = p_dir / "test_elp.txt"
@@ -86,3 +88,16 @@ def test_plot_digmontage():
8688
)
8789
montage.plot()
8890
plt.close("all")
91+
92+
93+
def test_plot_montage_scale():
94+
"""Test montage.plot with non-default scale using subplot axes."""
95+
montage = make_standard_montage("GSN-HydroCel-129")
96+
ax = plt.subplots(2, 1)[1][1]
97+
picks = montage.ch_names
98+
info = create_info(montage.ch_names, sfreq=256, ch_types="eeg")
99+
raw = RawArray(
100+
np.zeros((len(montage.ch_names), 1)), info, copy=None, verbose=False
101+
).set_montage(montage)
102+
# test for gh-13438
103+
raw.pick(picks).get_montage().plot(axes=ax, show_names=False, scale=0.1)

0 commit comments

Comments
 (0)