diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index 2872a7621e2..49c679175be 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -450,7 +450,8 @@ def _update_data(self): def _get_epoch_num_from_time(self, time): epoch_nums = self.mne.inst.selection - return epoch_nums[np.searchsorted(self.mne.boundary_times[1:], time)] + epoch_ix = np.searchsorted(self.mne.boundary_times[1:-1], time) + return epoch_nums[epoch_ix] def _redraw(self, update_data=True, annotations=False): """Redraws backend if necessary.""" diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index f3563b454f0..3e5ef54658b 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -58,6 +58,15 @@ from ..fixes import _close_event from ..utils import Bunch, _click_ch_name, check_version, logger from ._figure import BrowserBase +from .ui_events import ( + ChannelsSelect, + TimeBrowse, + TimeChange, + disable_ui_events, + publish, + subscribe, + unsubscribe, +) from .utils import ( DraggableLine, _events_off, @@ -188,6 +197,7 @@ class MNEAnnotationFigure(MNEFigure): def _close(self, event=None): """Handle close events (via keypress or window [x]).""" + unsubscribe(self, ["time_change", "time_browse", "channels_select"]) parent = self.mne.parent_fig # disable span selector parent.mne.ax_main.selector.active = False @@ -566,6 +576,15 @@ def __init__(self, inst, figsize, ica=None, xlabel="Time (s)", **kwargs): vline_text=vline_text, ) + # Start listening to incoming TimeChange UI events + # subscribe(self, "time_change", self._on_time_change_event) + + # Start listening to incoming TimeBrowse UI events + subscribe(self, "time_browse", self._on_time_browse_event) + + # Start listening to incoming ChannelsSelect UI events + subscribe(self, "channels_select", self._on_channels_select_event) + def _get_size(self): return self.get_size_inches() @@ -641,7 +660,7 @@ def _keypress(self, event): key = event.key n_channels = self.mne.n_channels if self.mne.is_epochs: - last_time = self.mne.n_times / self.mne.info["sfreq"] + last_time = self.mne.boundary_times[-2] else: last_time = self.mne.inst.times[-1] # scroll up/down @@ -674,24 +693,21 @@ def _keypress(self, event): else: ceiling = len(self.mne.ch_order) - n_channels ch_start = self.mne.ch_start + direction * n_channels - self.mne.ch_start = np.clip(ch_start, 0, ceiling) - self._update_picks() - self._update_vscroll() - self._redraw() + ch_start = np.clip(ch_start, 0, ceiling) + channels = self.mne.ch_names[ + self.mne.ch_order[ch_start : ch_start + n_channels] + ] + publish(self, ChannelsSelect(ch_names=channels)) # scroll left/right elif key in ("right", "left", "shift+right", "shift+left"): - old_t_start = self.mne.t_start direction = 1 if key.endswith("right") else -1 if self.mne.is_epochs: denom = 1 if key.startswith("shift") else self.mne.n_epochs else: denom = 1 if key.startswith("shift") else 4 - t_max = last_time - self.mne.duration t_start = self.mne.t_start + direction * self.mne.duration / denom - self.mne.t_start = np.clip(t_start, self.mne.first_time, t_max) - if self.mne.t_start != old_t_start: - self._update_hscroll() - self._redraw(annotations=True) + t_start = np.clip(t_start, 0, last_time - self.mne.duration) + self._publish_time_browse_event(t_start) # scale traces elif key in ("=", "+", "-"): scaler = 1 / 1.1 if key == "-" else 1.1 @@ -704,41 +720,31 @@ def _keypress(self, event): and not self.mne.butterfly ): new_n_ch = n_channels + (1 if key == "pageup" else -1) - self.mne.n_channels = np.clip(new_n_ch, 1, len(self.mne.ch_order)) + n_channels = np.clip(new_n_ch, 1, len(self.mne.ch_order)) + ch_start = self.mne.ch_start # add new chs from above if we're at the bottom of the scrollbar - ch_end = self.mne.ch_start + self.mne.n_channels - if ch_end > len(self.mne.ch_order) and self.mne.ch_start > 0: - self.mne.ch_start -= 1 - self._update_vscroll() - # redraw only if changed - if self.mne.n_channels != n_channels: - self._update_picks() - self._update_trace_offsets() - self._redraw(annotations=True) + ch_end = ch_start + n_channels + if ch_end > len(self.mne.ch_order) and ch_start > 0: + ch_start -= 1 + channels = self.mne.ch_names[ + self.mne.ch_order[ch_start : ch_start + n_channels] + ] + publish(self, ChannelsSelect(ch_names=channels)) + # change duration elif key in ("home", "end"): - old_dur = self.mne.duration dur_delta = 1 if key == "end" else -1 if self.mne.is_epochs: - # prevent from showing zero epochs, or more epochs than we have - self.mne.n_epochs = np.clip( - self.mne.n_epochs + dur_delta, 1, len(self.mne.inst) - ) # use the length of one epoch as duration change min_dur = len(self.mne.inst.times) / self.mne.info["sfreq"] new_dur = self.mne.duration + dur_delta * min_dur else: - # never show fewer than 3 samples - min_dur = 3 * np.diff(self.mne.inst.times[:2])[0] # use multiplicative dur_delta dur_delta = 5 / 4 if dur_delta > 0 else 4 / 5 new_dur = self.mne.duration * dur_delta - self.mne.duration = np.clip(new_dur, min_dur, last_time) - if self.mne.duration != old_dur: - if self.mne.t_start + self.mne.duration > last_time: - self.mne.t_start = last_time - self.mne.duration - self._update_hscroll() - self._redraw(annotations=True) + self._publish_time_browse_event( + self.mne.t_start, self.mne.t_start + new_dur + ) elif key == "?": # help window self._toggle_help_fig(event) elif key == "a": # annotation mode @@ -797,7 +803,13 @@ def _buttonpress(self, event): idx = self.mne.traces.index(line) self._toggle_bad_channel(idx) return - self._show_vline(event.xdata) # butterfly / not on data trace + time = event.xdata + if self.mne.is_epochs: + width = self.mne.boundary_times[1] - self.mne.boundary_times[0] + time = (time % width) + self.mne.inst.tmin + else: + time += self.mne.inst.first_time + publish(self, TimeChange(time=time)) self._redraw(update_data=False, annotations=False) return # click in vertical scrollbar @@ -1752,8 +1764,7 @@ def _check_update_hscroll_clicked(self, event): ix = np.searchsorted(self.mne.boundary_times[1:], time) time = self.mne.boundary_times[ix] if self.mne.t_start != time: - self.mne.t_start = time - self._update_hscroll() + self._publish_time_browse_event(time) return True return False @@ -1765,9 +1776,10 @@ def _check_update_vscroll_clicked(self, event): len(self.mne.ch_order) - self.mne.n_channels, ) if self.mne.ch_start != new_ch_start: - self.mne.ch_start = new_ch_start - self._update_picks() - self._update_vscroll() + channels = self.mne.ch_names[ + self.mne.ch_order[new_ch_start : new_ch_start + self.mne.n_channels] + ] + publish(self, ChannelsSelect(ch_names=channels)) return True return False @@ -2318,6 +2330,150 @@ def _get_scale_bar_texts(self): return texts + def _publish_time_browse_event(self, t_start=None, t_end=None): + """Publish a TimeBrowse event with meaningful time_start and time_end values.""" + # Figure out proper t_start and t_end that doesn't exceed the data boundaries. + if t_start is None: + t_start = self.mne.t_start + else: + if self.mne.is_epochs: + last_time = self.mne.n_times / self.mne.info["sfreq"] + else: + last_time = self.mne.inst.times[-1] + t_max = last_time - self.mne.duration + t_start = np.clip(t_start, self.mne.first_time, t_max) + + if t_end is None: + t_end = t_start + self.mne.duration + else: + t_end = min(t_end, self.mne.n_times / self.mne.info["sfreq"]) + + # Don't publish an event if nothing changed. + if ( + self.mne.t_start == t_start + and self.mne.t_start + self.mne.duration == t_end + ): + return + + if self.mne.is_epochs: + # Translate the time-coordinate in the browser window to the actual + # start/end times of the epochs in the raw file. + epoch_num_start = self._get_epoch_num_from_time( + t_start + self.mne.sampling_period + ) + epoch_num_end = self._get_epoch_num_from_time( + t_end + self.mne.sampling_period + ) + onsets = self.mne.inst.events[:, 0] / self.mne.info["sfreq"] + t_start = onsets[epoch_num_start] + self.mne.inst.tmin + t_end = onsets[epoch_num_end - 1] + self.mne.inst.tmax + else: + # For raw data, we need to take `first_time` into account. + t_start += self.mne.inst.first_time + t_end += self.mne.inst.first_time + + publish(self, TimeBrowse(time_start=t_start, time_end=t_end)) + + def _on_time_browse_event(self, event): + """Respond to the TimeBrowse UI event, update horizontal scrolling.""" + time_start = event.time_start + time_end = event.time_end + + if self.mne.is_epochs: + # Translate the start/end times from the original raw to the indices of the + # epochs being shown, and then to the appropriate start/end times in the + # browser. + events = self.mne.inst.events[self.mne.inst.selection] + onsets = events[:, 0] / self.mne.info["sfreq"] + # Subtract/add one sample to make sure we end on the right side. This is + # needed because of small floating point errors. + epoch_ix_start = np.searchsorted( + onsets, time_start - self.mne.inst.tmin - self.mne.sampling_period + ) + epoch_ix_end = np.searchsorted( + onsets, time_end - self.mne.inst.tmax + self.mne.sampling_period + ) + # Always show at least one epoch. + epoch_ix_start = min(epoch_ix_start, len(self.mne.inst) - 1) + epoch_ix_end = max(epoch_ix_start + 1, epoch_ix_end) + self.mne.n_epochs = epoch_ix_end - epoch_ix_start + + # Compute the browser time period to match the selected epochs. + time_start = self.mne.boundary_times[epoch_ix_start] + width = self.mne.boundary_times[1] - self.mne.boundary_times[0] + time_end = self.mne.boundary_times[epoch_ix_end - 1] + width + else: + # For raw data, we need to take `first_time` into account. + time_start -= self.mne.inst.first_time + time_end -= self.mne.inst.first_time + + # Never show fewer than 3 samples. + min_dur = 3 * np.diff(self.mne.inst.times[:2])[0] + time_end = np.clip(time_end, time_start + min_dur, self.mne.inst.times[-1]) + + # Update browser window. + with disable_ui_events(self): + self.mne.t_start = time_start + self.mne.duration = time_end - time_start + self._update_hscroll() + self._redraw(annotations=True) + + # def _on_channels_select_event(self, event): + # """Respond to the ChannelsSelect UI event.""" + # old_n_channels = self.mne.n_channels + # picks = np.flatnonzero( + # np.isin(self.mne.ch_names[self.mne.ch_order], event.ch_names) + # ) + # if len(picks) == 0: + # return # can't handle the event + # if picks.min() == self.mne.ch_start and len(picks) == self.mne.n_channels: + # return # no change + + # with disable_ui_events(self): + # self.mne.ch_start = picks.min() + # self.mne.n_channels = len(picks) + # self._update_vscroll() + # self._update_picks() + # if self.mne.n_channels != old_n_channels: + # self._update_trace_offsets() + # self._redraw(annotations=True) + + def _on_channels_select_event(self, event): + """Respond to the ChannelsSelect UI event.""" + old_n_channels = self.mne.n_channels + all_channels = self.mne.ch_names[self.mne.ch_order] + ch_indices = np.where(np.isin(all_channels, event.ch_names))[0] + # picks = np.flatnonzero( + # np.isin(self.mne.ch_names[self.mne.ch_order], event.ch_names) + # ) + if len(ch_indices) == 0: + return # can't handle the event + if ( + ch_indices.min() == self.mne.ch_start + and len(ch_indices) == self.mne.n_channels + ): + return # no change + + # with disable_ui_events(self): + self.mne.ch_start = ch_indices.min() + self.mne.n_channels = len(ch_indices) + self._update_vscroll() + self._update_picks() + if self.mne.n_channels != old_n_channels: + self._update_trace_offsets() + self._redraw(annotations=True) + + def _on_time_change_event(self, event): + """Respond to the TimeChange UI event.""" + if self.mne.is_epochs: + time = np.clip(event.time, self.mne.inst.tmin, self.mne.inst.tmax) + time -= self.mne.inst.tmin + else: + time = event.time - self.mne.inst.first_time + time = np.clip(time, self.mne.inst.times[0], self.mne.inst.times[-1]) + with disable_ui_events(self): + self._show_vline(time) + class MNELineFigure(MNEFigure): """Interactive figure for non-scrolling line plots.""" diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index b8b3fe29a4d..515ae7a192b 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -92,6 +92,31 @@ class TimeChange(UIEvent): time: float +@dataclass +@fill_doc +class TimeBrowse(UIEvent): + """Indicates that the user has browsed to a new time range. + + Parameters + ---------- + time_start : float + The new start time in seconds. + time_end : float + The new end time in seconds. + + Attributes + ---------- + %(ui_event_name_source)s + time_start : float + The new start time in seconds. + time_end : float + The new end time in seconds. + """ + + time_start: float + time_end: float + + @dataclass @fill_doc class PlaybackSpeed(UIEvent): @@ -253,6 +278,7 @@ def _get_event_channel(fig): import matplotlib from ._brain import Brain + from ._figure import BrowserBase from .evoked_field import EvokedField # Create the event channel if it doesn't exist yet @@ -283,9 +309,13 @@ def delete_event_channel(event=None, *, weakfig=weakfig): # Hook up the above callback function to the close event of the figure # window. How this is done exactly depends on the various figure types # MNE-Python has. - _validate_type(fig, (matplotlib.figure.Figure, Brain, EvokedField), "fig") + _validate_type( + fig, (matplotlib.figure.Figure, Brain, EvokedField, BrowserBase), "fig" + ) if isinstance(fig, matplotlib.figure.Figure): fig.canvas.mpl_connect("close_event", delete_event_channel) + elif isinstance(fig, BrowserBase): + fig.mne.viewbox.destroyed.connect(delete_event_channel) else: assert hasattr(fig, "_renderer") # figures like Brain, EvokedField, etc. fig._renderer._window_close_connect(delete_event_channel, after=False)