diff --git a/doc/api/inverse.rst b/doc/api/inverse.rst index 754244c17fe..9ff9e65655f 100644 --- a/doc/api/inverse.rst +++ b/doc/api/inverse.rst @@ -85,6 +85,7 @@ Inverse Solutions Dipole DipoleFixed fit_dipole + gui.dipolefit :py:mod:`mne.dipole`: diff --git a/doc/changes/dev/13074.newfeature.rst b/doc/changes/dev/13074.newfeature.rst new file mode 100644 index 00000000000..35415ea90d7 --- /dev/null +++ b/doc/changes/dev/13074.newfeature.rst @@ -0,0 +1 @@ +Add a GUI for interactive guided dipole fitting (:func:`mne.gui.dipolefit`), by `Marijn van Vliet`_ diff --git a/doc/conf.py b/doc/conf.py index f3a50fd3518..02749d47f0a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -417,6 +417,7 @@ "default", # unlinkable "CoregistrationUI", + "DipoleFitUI", "mne_qt_browser.figure.MNEQtBrowser", # pooch, since its website is unreliable and users will rarely need the links "pooch.Unzip", diff --git a/mne/commands/mne_dipolefit.py b/mne/commands/mne_dipolefit.py new file mode 100644 index 00000000000..4e1771f3f71 --- /dev/null +++ b/mne/commands/mne_dipolefit.py @@ -0,0 +1,161 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +"""Open the dipole fitting GUI. + +Examples +-------- +.. code-block:: console + + $ mne dipolefit + +""" + +import os.path as op + +import mne + + +def run(): + """Run command.""" + from mne.commands.utils import _add_verbose_flag, get_optparser + + parser = get_optparser(__file__) + + parser.add_option( + "-e", + "--evoked", + default=None, + metavar="EVOKED_FILE", + help='The evoked file ("-ave.fif") containing the data to fit dipoles to.', + ) + parser.add_option( + "--condition", + default=0, + help="The condition to use.", + ) + parser.add_option( + "--baseline-from", + default=None, + type=float, + metavar="TIME", + help="The earliest timepoint to use as baseline.", + ) + parser.add_option( + "--baseline-to", + default=None, + type=float, + metavar="TIME", + help="The latest timepoint to use as baseline.", + ) + parser.add_option( + "-c", + "--cov", + default=None, + metavar="COV_FILE", + help='The noise covariance ("-cov.fif") to use.', + ) + parser.add_option( + "-b", + "--bem", + default=None, + metavar="BEM_FILE", + help='The BEM model ("-bem-sol.fif") to use.', + ) + parser.add_option( + "-t", + "--initial-time", + default=None, + type=float, + metavar="TIME", + help="The initial time to show", + ) + parser.add_option( + "--trans", + default=None, + metavar="TRANS_FILE", + help='Head<->MRI transform FIF file ("-trans.fif")', + ) + parser.add_option( + "--stc", + default=None, + metavar="STC_FILE", + help="An optional distributed source estimate to show during dipole fitting.", + ) + parser.add_option( + "-s", "--subject", dest="subject", default=None, help="Subject name" + ) + parser.add_option( + "-d", + "--subjects-dir", + default=None, + help="Subjects directory", + ) + parser.add_option( + "--hide-density", + action="store_true", + default=False, + help="Prevent showing the magnetic field density as blobs of color.", + ) + parser.add_option( + "--channel-type", + default=None, + help=( + 'Restrict channel types to either "meg" or "eeg". By default both are used ' + "if present." + ), + ) + parser.add_option( + "-j", "--cpus", default=-1, type=int, help="Number of CPUs to use." + ) + _add_verbose_flag(parser) + + options, args = parser.parse_args() + + # expanduser allows ~ for paths + subjects_dir = options.subjects_dir + if subjects_dir is not None: + subjects_dir = op.expanduser(subjects_dir) + bem = options.bem + if bem is not None: + bem = op.expanduser(bem) + trans = options.trans + if trans is not None: + trans = op.expanduser(trans) + stc = options.stc + if stc is not None: + stc = op.expanduser(stc) + import faulthandler + + # Condition can be specified as integer index or string comment. + if options.condition is not None: + try: + condition = int(options.condition) + except ValueError: + condition = options.condition + else: + condition = None + + faulthandler.enable() + mne.gui.dipolefit( + evoked=options.evoked, + condition=condition, + baseline=(options.baseline_from, options.baseline_to), + cov=options.cov, + bem=bem, + subject=options.subject, + subjects_dir=subjects_dir, + stc=stc, + ch_type=options.channel_type, + initial_time=options.initial_time, + trans=trans, + n_jobs=options.cpus, + show_density=not options.hide_density, + show=True, + block=True, + verbose=options.verbose, + ) + + +mne.utils.run_command_if_main() diff --git a/mne/dipole.py b/mne/dipole.py index 67ff5cf65c4..4da7fa1940f 100644 --- a/mne/dipole.py +++ b/mne/dipole.py @@ -30,7 +30,12 @@ ) from .parallel import parallel_func from .source_space._source_space import SourceSpaces, _make_volume_source_space -from .surface import _compute_nearest, _points_outside_surface, transform_surface_to +from .surface import ( + _CheckInside, + _compute_nearest, + _DistanceQuery, + transform_surface_to, +) from .transforms import _coord_frame_name, _print_coord_trans, apply_trans from .utils import ( ExtendedTimeMixin, @@ -911,7 +916,7 @@ def _write_dipole_bdip(fname, dip): fid.write(np.array(has_errors, ">i4").tobytes()) # has_errors fid.write(np.zeros(1, ">f4").tobytes()) # noise level for key in _BDIP_ERROR_KEYS: - val = dip.conf[key][ti] if key in dip.conf else 0.0 + val = dip.conf[key][ti] if key in dip.conf else np.array(0.0) assert val.shape == () fid.write(np.array(val, ">f4").tobytes()) fid.write(np.zeros(25, ">f4").tobytes()) @@ -1050,7 +1055,7 @@ def _fit_Q(*, sensors, fwd_data, whitener, B, B2, B_orig, rd, ori=None): def _fit_dipoles( fun, - min_dist_to_inner_skull, + constraint, data, times, guess_rrs, @@ -1069,7 +1074,7 @@ def _fit_dipoles( # parallel over time points res = parallel( p_fun( - min_dist_to_inner_skull, + constraint, B, t, guess_rrs, @@ -1267,26 +1272,41 @@ def _fit_confidence(*, rd, Q, ori, whitener, fwd_data, sensors): return conf -def _surface_constraint(rd, surf, min_dist_to_inner_skull): - """Surface fitting constraint.""" - dist = _compute_nearest(surf["rr"], rd[np.newaxis, :], return_dists=True)[1][0] - if _points_outside_surface(rd[np.newaxis, :], surf, 1)[0]: - dist *= -1.0 - # Once we know the dipole is below the inner skull, - # let's check if its distance to the inner skull is at least - # min_dist_to_inner_skull. This can be enforced by adding a - # constrain proportional to its distance. - dist -= min_dist_to_inner_skull - return dist +def _surface_constraint(surf, min_dist_to_inner_skull): + """Create a surface fitting constraint function.""" + distance_checker = _DistanceQuery(surf["rr"], method="BallTree") + try: + import pyvista # noqa F401 + + inside_checker = _CheckInside(surf, mode="pyvista") + except ImportError: + inside_checker = _CheckInside(surf, mode="old") + + def constraint(rd): + dist = distance_checker.query(rd[np.newaxis, :])[0][0] + if not inside_checker(rd[np.newaxis, :])[0]: + dist *= -1.0 + # Once we know the dipole is below the inner skull, + # let's check if its distance to the inner skull is at least + # min_dist_to_inner_skull. This can be enforced by adding a + # constrain proportional to its distance. + dist -= min_dist_to_inner_skull + return dist + + return constraint + + +def _sphere_constraint(r0, R_adj): + """Create a sphere fitting constraint function.""" + def constraint(rd): + return R_adj - np.sqrt(np.sum((rd - r0) ** 2)) -def _sphere_constraint(rd, r0, R_adj): - """Sphere fitting constraint.""" - return R_adj - np.sqrt(np.sum((rd - r0) ** 2)) + return constraint def _fit_dipole( - min_dist_to_inner_skull, + constraint, B_orig, t, guess_rrs, @@ -1303,22 +1323,6 @@ def _fit_dipole( """Fit a single bit of data.""" B = np.dot(whitener, B_orig) - # make constraint function to keep the solver within the inner skull - if "rr" in fwd_data["inner_skull"]: # bem - surf = fwd_data["inner_skull"] - constraint = partial( - _surface_constraint, - surf=surf, - min_dist_to_inner_skull=min_dist_to_inner_skull, - ) - else: # sphere - surf = None - constraint = partial( - _sphere_constraint, - r0=fwd_data["inner_skull"]["r0"], - R_adj=fwd_data["inner_skull"].radius - min_dist_to_inner_skull, - ) - # Find a good starting point (find_best_guess in C) B2 = np.dot(B, B) if B2 == 0: @@ -1388,9 +1392,9 @@ def _fit_dipole( ) msg = "---- Fitted : %7.1f ms" % (1000.0 * t) - if surf is not None: + if "rr" in fwd_data["inner_skull"]: # bem dist_to_inner_skull = _compute_nearest( - surf["rr"], rd_final[np.newaxis, :], return_dists=True + fwd_data["inner_skull"]["rr"], rd_final[np.newaxis, :], return_dists=True )[1][0] msg += ", distance to inner skull : %2.4f mm" % (dist_to_inner_skull * 1000.0) @@ -1399,7 +1403,7 @@ def _fit_dipole( def _fit_dipole_fixed( - min_dist_to_inner_skull, + constraint, B_orig, t, guess_rrs, @@ -1703,17 +1707,23 @@ def fit_dipole( logger.info("Go through all guess source locations...") # inner_skull goes from mri to head frame - if "rr" in inner_skull: + if not bem["is_sphere"]: transform_surface_to(inner_skull, "head", mri_head_t) + + # make constraint function to keep the solver within the inner skull + if bem["is_sphere"]: + constraint = _sphere_constraint( + r0=inner_skull["r0"], + R_adj=inner_skull.radius - min_dist_to_inner_skull, + ) + else: + constraint = _surface_constraint( + surf=inner_skull, + min_dist_to_inner_skull=min_dist_to_inner_skull, + ) + if fixed_position: - if "rr" in inner_skull: - check = _surface_constraint(pos, inner_skull, min_dist_to_inner_skull) - else: - check = _sphere_constraint( - pos, - inner_skull["r0"], - R_adj=inner_skull.radius - min_dist_to_inner_skull, - ) + check = constraint(pos) if check <= 0: raise ValueError( f"fixed position is {-1000 * check:0.1f}mm outside the inner skull " @@ -1753,7 +1763,7 @@ def fit_dipole( fun = _fit_dipole_fixed if fixed_position else _fit_dipole out = _fit_dipoles( fun, - min_dist_to_inner_skull, + constraint, data, times, guess_src["rr"], diff --git a/mne/gui/__init__.pyi b/mne/gui/__init__.pyi index 086c51a4904..a6dc001ef84 100644 --- a/mne/gui/__init__.pyi +++ b/mne/gui/__init__.pyi @@ -1,2 +1,2 @@ -__all__ = ["_GUIScraper", "coregistration"] -from ._gui import _GUIScraper, coregistration +__all__ = ["_GUIScraper", "coregistration", "dipolefit"] +from ._gui import _GUIScraper, coregistration, dipolefit diff --git a/mne/gui/_gui.py b/mne/gui/_gui.py index b8898d8b7c2..546fc0ca840 100644 --- a/mne/gui/_gui.py +++ b/mne/gui/_gui.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from ..datasets import sample from ..utils import get_config, verbose @@ -165,6 +166,117 @@ def coregistration( ) +@verbose +def dipolefit( + evoked=None, + *, + condition=0, + baseline=(None, 0), + cov=None, + bem=None, + initial_time=None, + trans=None, + stc=None, + subject=None, + subjects_dir=None, + rank="info", + show_density=True, + ch_type=None, + n_jobs=None, + show=True, + block=False, + verbose=None, +): + """GUI for interactive dipole fitting, inspired by MEGIN's XFit program. + + Parameters + ---------- + evoked : instance of Evoked | path-like | None + Evoked data to show fieldmap of and fit dipoles to. + condition : int | str + When ``evoked`` is given as a filename, use this to select which evoked to use + in the file by either specifying the index or the string comment field of the + evoked. By default, the first evoked is used. + %(baseline_evoked)s + Defaults to ``(None, 0)``, i.e. beginning of the the data until time point zero. + cov : instance of Covariance | path-like | "baseline" | None + Noise covariance matrix. If ``None``, an ad-hoc covariance matrix is used with + default values for the diagonal elements (see Notes). If ``"baseline"``, the + diagonal elements is estimated from the baseline period of the evoked data. + bem : instance of ConductorModel | path-like | None + Boundary element model to use in forward calculations. If ``None``, a spherical + model is used. + initial_time : float | None + Initial time point to show. If ``None``, the time point of the maximum field + strength is used. + trans : instance of Transform | path-like | None + The transformation from head coordinates to MRI coordinates. If ``None``, + the identity matrix is used and everything will be done in head coordinates. + stc : instance of SourceEstimate | path-like | None + An optional distributed source estimate to show alongside the fieldmap. The time + samples need to match those of the evoked data. + subject : str | None + The subject name. If ``None``, no MRI data is shown. + %(subjects_dir)s + %(rank)s + show_density : bool + Whether to show the density of the fieldmap. + ch_type : "meg" | "eeg" | None + Type of channels to use for the dipole fitting. By default (``None``) both MEG + and EEG channels will be used. + %(n_jobs)s + show : bool + Show the GUI if True. + block : bool + Whether to halt program execution until the figure is closed. + %(verbose)s + + Returns + ------- + fitter : instance of DipoleFitUI + The dipole fitting GUI. The ``.dipoles`` attribute contains the fitted dipoles. + + Notes + ----- + When using ``cov=None`` the default noise values are 5 fT/cm, 20 fT, and 0.2 µV for + gradiometers, magnetometers, and EEG channels respectively. + """ + from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + from ._xfit import DipoleFitUI + + if MNE_3D_BACKEND_TESTING: + show = block = False + + if evoked is None: + evoked = ( + sample.data_path( + download=False, + ) + / "MEG" + / "sample" + / "sample_audvis-ave.fif" + ) + return DipoleFitUI( + evoked=evoked, + condition=condition, + baseline=baseline, + cov=cov, + bem=bem, + initial_time=initial_time, + trans=trans, + stc=stc, + subject=subject, + subjects_dir=subjects_dir, + rank=rank, + show_density=show_density, + ch_type=ch_type, + n_jobs=n_jobs, + show=show, + block=block, + verbose=verbose, + ) + + class _GUIScraper: """Scrape GUI outputs.""" diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py new file mode 100644 index 00000000000..a468994ec49 --- /dev/null +++ b/mne/gui/_xfit.py @@ -0,0 +1,897 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from copy import deepcopy +from functools import partial +from pathlib import Path + +import numpy as np +import pyvista + +from .._fiff.pick import pick_types +from ..bem import ( + ConductorModel, + _ensure_bem_surfaces, + make_sphere_model, +) +from ..cov import _ensure_cov, make_ad_hoc_cov +from ..dipole import Dipole, fit_dipole +from ..evoked import Evoked, read_evokeds +from ..forward import convert_forward_solution, make_field_map +from ..forward._make_forward import _ForwardModeler +from ..minimum_norm import apply_inverse, make_inverse_operator +from ..source_estimate import ( + SourceEstimate, + _BaseSurfaceSourceEstimate, + read_source_estimate, +) +from ..source_space import setup_volume_source_space +from ..surface import _normal_orth +from ..transforms import _get_trans, _get_transforms_to_coord_frame, apply_trans +from ..utils import _check_option, _validate_type, fill_doc, logger, verbose +from ..viz import EvokedField, create_3d_figure +from ..viz._3d import _plot_head_surface, _plot_sensors_3d +from ..viz.backends._utils import _qt_app_exec +from ..viz.ui_events import link, subscribe +from ..viz.utils import _get_color_list + + +@fill_doc +class DipoleFitUI: + """GUI for interactive dipole fitting, inspired by MEGIN's XFit program. + + Parameters + ---------- + evoked : instance of Evoked | path-like + Evoked data to show fieldmap of and fit dipoles to. + condition : int | str + When ``evoked`` is given as a filename, use this to select which evoked to use + in the file by either specifying the index or the string comment field of the + evoked. By default, the first evoked is used. + %(baseline_evoked)s + Defaults to ``(None, 0)``, i.e. beginning of the the data until time point zero. + cov : instance of Covariance | "baseline" | None + Noise covariance matrix. If ``None``, an ad-hoc covariance matrix is used with + default values for the diagonal elements (see Notes). If ``"baseline"``, the + diagonal elements is estimated from the baseline period of the evoked data. + bem : instance of ConductorModel | None + Boundary element model to use in forward calculations. If ``None``, a spherical + model is used. + initial_time : float | None + Initial time point to show. If ``None``, the time point of the maximum field + strength is used. + trans : instance of Transform | None + The transformation from head coordinates to MRI coordinates. If ``None``, + the identity matrix is used and everything will be done in head coordinates. + stc : instance of SourceEstimate | None + An optional distributed source estimate to show alongside the fieldmap. The time + samples need to match those of the evoked data. + subject : str | None + The subject name. If ``None``, no MRI data is shown. + %(subjects_dir)s + %(rank)s + show_density : bool + Whether to show the density of the fieldmap. + ch_type : "meg" | "eeg" | None + Type of channels to use for the dipole fitting. By default (``None``) both MEG + and EEG channels will be used. + %(n_jobs)s + show : bool + Show the GUI if True. + block : bool + Whether to halt program execution until the figure is closed. + %(verbose)s + + Notes + ----- + When using ``cov=None`` the default noise values are 5 fT/cm, 20 fT, and 0.2 µV for + gradiometers, magnetometers, and EEG channels respectively. + """ + + def __init__( + self, + evoked=None, + *, + condition=0, + baseline=(None, 0), + cov=None, + bem=None, + initial_time=None, + trans=None, + stc=None, + subject=None, + subjects_dir=None, + rank="info", + show_density=True, + ch_type=None, + n_jobs=None, + show=True, + block=False, + verbose=None, + ): + _validate_type(evoked, ("path-like", Evoked), "evoked") + if not isinstance(evoked, Evoked): + evoked = read_evokeds(evoked, condition=condition) + + evoked.apply_baseline(baseline) + + if cov is None: + logger.info("Using ad-hoc noise covariance.") + cov = make_ad_hoc_cov(evoked.info) + elif cov == "baseline": + logger.info( + f"Estimating noise covariance from baseline ({evoked.baseline[0]:.3f} " + f"to {evoked.baseline[1]:.3f} seconds)." + ) + std = dict() + for typ in set(evoked.get_channel_types(only_data_chs=True)): + baseline = evoked.copy().pick(typ).crop(*evoked.baseline) + std[typ] = baseline.data.std(axis=1).mean() + cov = make_ad_hoc_cov(evoked.info, std) + else: + cov = _ensure_cov(cov) + + if bem is None: + bem = make_sphere_model("auto", "auto", evoked.info) + bem = _ensure_bem_surfaces(bem, extra_allow=(ConductorModel, None)) + + if ch_type is not None: + evoked = evoked.copy().pick(ch_type) + + field_map = make_field_map( + evoked, + trans=trans, + origin=bem["r0"] if bem["is_sphere"] else "auto", + subject=subject, + subjects_dir=subjects_dir, + n_jobs=n_jobs, + verbose=verbose, + ) + + if initial_time is None: + # Set initial time to moment of maximum field power. + data = evoked.copy().pick(field_map[0]["ch_names"]).data + initial_time = evoked.times[np.argmax(np.mean(data**2, axis=0))] + + if stc is not None: + _validate_type(stc, ("path-like", _BaseSurfaceSourceEstimate), "stc") + if not isinstance(stc, _BaseSurfaceSourceEstimate): + stc = read_source_estimate(stc) + + if len(stc.times) != len(evoked.times) or not np.allclose( + stc.times, evoked.times + ): + raise ValueError( + "The time samples of the source estimate do not match those of the " + "evoked data." + ) + if trans is None: + raise ValueError( + "`trans` cannot be `None` when showing the fieldlines in " + "combination with a source estimate." + ) + + # Get transforms to convert all the various meshes to MRI space. + head_mri_t = _get_trans(trans, "head", "mri")[0] + to_cf_t = _get_transforms_to_coord_frame( + evoked.info, head_mri_t, coord_frame="mri" + ) + + self.fwd = _ForwardModeler( + info=evoked.info, + trans=trans, + bem=bem, + n_jobs=n_jobs, + verbose=verbose, + ) + + # Initialize all the private attributes. + self._actors = dict() + self._bem = bem + self._ch_type = ch_type + self._cov = cov + self._current_time = initial_time + self._dipoles = dict() + self._evoked = evoked + self._field_map = field_map + self._fig_sensors = None + self._multi_dipole_method = "Multi dipole (MNE)" + self._show_density = show_density + self._stc = stc + self._subjects_dir = subjects_dir + self._subject = subject + self._time_line = None + self._head_mri_t = head_mri_t + self._to_cf_t = to_cf_t + self._rank = rank + self._verbose = verbose + self._n_jobs = n_jobs + + # Configure the GUI. + self._renderer = self._configure_main_display(show) + self._configure_dock() + + # must be done last + if show: + self._renderer.show() + if block and self._renderer._kind != "notebook": + _qt_app_exec(self._renderer.figure.store["app"]) + + @property + def dipoles(self): + """A list of all the fitted dipoles that are enabled in the GUI.""" + return [d["dip"] for d in self._dipoles.values() if d["active"]] + + def _configure_main_display(self, show=True): + """Configure main 3D display of the GUI.""" + fig = create_3d_figure((1500, 1020), bgcolor="white", show=show) + + self._fig_stc = None + if self._stc is not None: + kwargs = dict( + subject=self._subject, + subjects_dir=self._subjects_dir, + hemi="both", + time_viewer=False, + initial_time=self._current_time, + brain_kwargs=dict(units="m"), + figure=fig, + ) + if isinstance(self._stc, SourceEstimate): + kwargs["surface"] = "white" + fig = self._stc.plot(**kwargs) # overwrite "fig" to be the STC plot + self._fig_stc = fig + self._actors["brain"] = fig._actors["data"] + + fig = EvokedField( + self._evoked, + self._field_map, + time=self._current_time, + interpolation="linear", + alpha=0, + show_density=self._show_density, + foreground="black", + background="white", + fig=fig, + ) + fig.separate_canvas = False # needed to plot the timeline later + fig.set_contour_line_width(2) + if self._stc is not None: + link(self._fig_stc, fig) + + for surf_map in fig._surf_maps: + if surf_map["map_kind"] == "meg": + helmet_mesh = surf_map["mesh"] + helmet_mesh._polydata.compute_normals() # needed later + helmet_mesh._actor.prop.culling = "back" + self._actors["helmet"] = helmet_mesh._actor + # For MEG fieldlines, we want to occlude the ones not facing us, + # otherwise it's hard to interpret them. Since the "contours" object + # does not support backface culling, we create an opaque mesh to put in + # front of the contour lines with frontface culling. + occl_surf = deepcopy(surf_map["surf"]) + occl_surf["rr"] -= 1e-3 * occl_surf["nn"] + occl_act, _ = fig._renderer.surface(occl_surf, color="white") + occl_act.prop.culling = "front" + occl_act.prop.lighting = False + self._actors["occlusion_surf"] = occl_act + elif surf_map["map_kind"] == "eeg": + head_mesh = surf_map["mesh"] + head_mesh._polydata.compute_normals() # needed later + head_mesh._actor.prop.culling = "back" + self._actors["head"] = head_mesh._actor + + show_meg = (self._ch_type is None or self._ch_type == "meg") and any( + [m["kind"] == "meg" for m in self._field_map] + ) + show_eeg = (self._ch_type is None or self._ch_type == "eeg") and any( + [m["kind"] == "eeg" for m in self._field_map] + ) + meg_picks = pick_types(self._evoked.info, meg=show_meg, ref_meg=False) + eeg_picks = pick_types(self._evoked.info, meg=False, eeg=show_eeg) + picks = np.concatenate((meg_picks, eeg_picks)) + self._ch_names = [self._evoked.ch_names[i] for i in picks] + + for m in self._field_map: + if m["kind"] == "eeg": + head_surf = m["surf"] + break + else: + self._actors["head"], _, head_surf = _plot_head_surface( + renderer=fig._renderer, + head="head", + subject=self._subject, + subjects_dir=self._subjects_dir, + bem=self._bem, + coord_frame="mri", + to_cf_t=self._to_cf_t, + alpha=0.2, + ) + self._actors["head"].prop.culling = "back" + + sensors = _plot_sensors_3d( + renderer=fig._renderer, + info=self._evoked.info, + to_cf_t=self._to_cf_t, + picks=picks, + meg=["sensors"] if show_meg else False, + eeg=["original"] if show_eeg else False, + fnirs=False, + warn_meg=False, + head_surf=head_surf, + units="m", + sensor_alpha=dict(meg=0.1, eeg=1.0), + orient_glyphs=False, + scale_by_distance=False, + project_points=False, + surf=None, + check_inside=None, + nearest=None, + sensor_colors=dict( + meg=["gray" for _ in meg_picks], + eeg=["white" for _ in eeg_picks], + ), + ) + self._actors["sensors"] = list() + for s in sensors.values(): + self._actors["sensors"].extend(s) + + # Adjust camera + fig._renderer.set_camera( + azimuth=180, elevation=90, roll=90, distance=0.55, focalpoint=[0, 0, 0.03] + ) + + subscribe(fig, "time_change", self._on_time_change) + self._fig = fig + return fig._renderer + + def _configure_dock(self): + """Configure the left and right dock areas of the GUI.""" + r = self._renderer + + # Toggle buttons for various meshes + layout = r._dock_add_group_box("Meshes") + for actor_name in self._actors.keys(): + if actor_name == "occlusion_surf": + continue + r._dock_add_check_box( + name=actor_name, + value=True, + callback=partial(self.toggle_mesh, name=actor_name), + layout=layout, + ) + + # Right dock + r._dock_initialize(name="Dipole fitting", area="right") + r._dock_add_button("Sensor data", self._on_sensor_data) + r._dock_add_button("Fit dipole", self._on_fit_dipole) + methods = ["Multi dipole (MNE)", "Single dipole"] + r._dock_add_combo_box( + "Dipole model", + value="Multi dipole (MNE)", + rng=methods, + callback=self._on_select_method, + ) + self._dipole_box = r._dock_add_group_box(name="Dipoles") + self._save_button = r._dock_add_file_button( + name="save_dipoles", + desc="Save dipoles", + save=True, + func=self.save, + tooltip="Save the dipoles to disk", + filter_="Dipole files (*.dip *.bdip)", + initial_directory=".", + ) + self._save_button.set_enabled(False) + r._dock_add_stretch() + + def toggle_mesh(self, name, show=None): + """Toggle a mesh on or off. + + Parameters + ---------- + name : str + Name of the mesh to toggle. + show : bool | None + Whether to show the mesh. If None, the visibility of the mesh is toggled. + """ + _check_option("name", name, self._actors.keys()) + actors = self._actors[name] + # self._actors[name] is sometimes a list and sometimes not. Make it + # always be a list to simplify the code. + if not isinstance(actors, list): + actors = [actors] + if show is None: + show = not actors[0].GetVisibility() + for act in actors: + act.SetVisibility(show) + self._renderer._update() + + def _on_time_change(self, event): + new_time = np.clip(event.time, self._evoked.times[0], self._evoked.times[-1]) + self._current_time = new_time + if self._time_line is not None: + self._time_line.set_xdata([new_time]) + self._renderer._mplcanvas.update_plot() + self._update_arrows() + + def _on_sensor_data(self): + """Show sensor data and allow sensor selection.""" + if self._fig_sensors is not None: + return + fig = self._evoked.plot_topo(select=True) + fig.canvas.mpl_connect("close_event", self._on_sensor_data_close) + subscribe(fig, "channels_select", self._on_channels_select) + self._fig_sensors = fig + + def _on_sensor_data_close(self, event): + """Handle closing of the sensor selection window.""" + self._fig_sensors = None + if "sensors" in self._actors: + for act in self._actors["sensors"]: + act.prop.SetColor(1, 1, 1) + self._renderer._update() + + def _on_channels_select(self, event): + """Color selected sensor meshes.""" + selected_channels = set(event.ch_names) + if "sensors" in self._actors: + for act, ch_name in zip(self._actors["sensors"], self._ch_names): + if ch_name in selected_channels: + act.prop.SetColor(0, 1, 0) + else: + act.prop.SetColor(1, 1, 1) + self._renderer._update() + + def _on_fit_dipole(self): + """Fit a single dipole.""" + evoked_picked = self._evoked.copy() + cov_picked = self._cov.copy() + if self._fig_sensors is not None: + picks = self._fig_sensors.lasso.selection + if len(picks) > 0: + evoked_picked = evoked_picked.pick(picks) + evoked_picked.info.normalize_proj() + cov_picked = cov_picked.pick_channels(picks, ordered=False) + cov_picked["projs"] = evoked_picked.info["projs"] + evoked_picked.crop(self._current_time, self._current_time) + + dip = fit_dipole( + evoked_picked, + cov_picked, + self._bem, + trans=self._head_mri_t, + rank=self._rank, + n_jobs=self._n_jobs, + verbose=False, + )[0] + + self.add_dipole(dip) + + def add_dipole(self, dipole, name=None): + """Add a dipole (or multiple dipoles) to the GUI. + + Parameters + ---------- + dipole : Dipole + The dipole to add. If the ``Dipole`` object defines multiple dipoles, they + will all be added. + name : str | list of str | None + The name of the dipole. When the ``Dipole`` object defines multiple dipoles, + this should be a list containing the name for each dipole. When ``None``, + the ``.name`` attribute of the ``Dipole`` object itself will be used. + """ + _validate_type(name, (str, list, None), "name") + if isinstance(name, str): + names = [name] + elif name is None: + # Try to obtain names from `dipole.name`. When multiple dipoles are saved, + # the names are concatenated with `;` marks. + if dipole.name is None: + names = [None] * len(dipole) + elif len(dipole.name.split(";")) == len(dipole): + names = dipole.name.split(";") + else: + names = [dipole.name] * len(dipole) + else: + names = name + if len(names) != len(dipole): + raise ValueError( + f"Number of names ({len(names)}) does not match the number of dipoles " + f"({len(dipole)})." + ) + + # Ensure orientations are unit vectors. Due to rounding issues this is sometimes + # not the case. + dipole._ori /= np.linalg.norm(dipole._ori, axis=1, keepdims=True) + + new_dipoles = list() + for dip, name in zip(dipole, names): + # Coordinates needed to draw the big arrow on the helmet. + helmet_coords, helmet_pos = self._get_helmet_coords(dip) + + # Collect all relevant information on the dipole in a dict. + colors = _get_color_list() + if len(self._dipoles) == 0: + dip_num = 0 + else: + dip_num = max(self._dipoles.keys()) + 1 + if name is None: + dip.name = f"dip{dip_num}" + else: + dip.name = name + dip_color = colors[dip_num % len(colors)] + if helmet_coords is not None: + arrow_mesh = pyvista.PolyData(*_arrow_mesh()) + else: + arrow_mesh = None + dipole_dict = dict( + active=True, + brain_arrow_actor=None, + helmet_arrow_actor=None, + arrow_mesh=arrow_mesh, + color=dip_color, + dip=dip, + fix_ori=True, + fix_position=True, + helmet_coords=helmet_coords, + helmet_pos=helmet_pos, + num=dip_num, + # fit_time=self._current_time, + ) + self._dipoles[dip_num] = dipole_dict + + # Add a row to the dipole list + r = self._renderer + hlayout = r._dock_add_layout(vertical=False) + widgets = [] + widgets.append( + r._dock_add_check_box( + name="", + value=True, + callback=partial(self._on_dipole_toggle, dip_num=dip_num), + layout=hlayout, + ) + ) + widgets.append( + r._dock_add_text( + name=dip.name, + value=dip.name, + placeholder="name", + callback=partial(self._on_dipole_set_name, dip_num=dip_num), + layout=hlayout, + ) + ) + widgets.append( + r._dock_add_check_box( + name="Fix ori", + value=True, + callback=partial( + self._on_dipole_toggle_fix_orientation, dip_num=dip_num + ), + layout=hlayout, + ) + ) + widgets.append( + r._dock_add_button( + name="", + icon="clear", + callback=partial(self._on_dipole_delete, dip_num=dip_num), + layout=hlayout, + ) + ) + dipole_dict["widgets"] = widgets + r._layout_add_widget(self._dipole_box, hlayout) + new_dipoles.append(dipole_dict) + + # Show the dipoles and arrows in the 3D view. Only do this after + # `_fit_timecourses` so that they have the correct size straight away. + self._fit_timecourses() + for dipole_dict in new_dipoles: + dip = dipole_dict["dip"] + dipole_dict["brain_arrow_actor"] = self._renderer.plotter.add_arrows( + apply_trans(self._head_mri_t, dip.pos[0]), + apply_trans(self._head_mri_t, dip.ori[0]), + color=dipole_dict["color"], + mag=0.05, + ) + if dipole_dict["arrow_mesh"] is not None: + dipole_dict["helmet_arrow_actor"] = self._renderer.plotter.add_mesh( + dipole_dict["arrow_mesh"], + color=dipole_dict["color"], + culling="front", + ) + self._update_arrows() + + def _get_helmet_coords(self, dip): + """Compute the coordinate system used for drawing the big arrows on the helmet. + + In this coordinate system, Z is normal to the helmet surface, and XY + are tangential to the helmet surface. + """ + if "helmet" not in self._actors: + return None, None + + # Get the closest vertex (=point) of the helmet mesh + dip_pos = apply_trans(self._head_mri_t, dip.pos[0]) + helmet = self._actors["helmet"].GetMapper().GetInput() + distances = ((helmet.points - dip_pos) * helmet.point_normals).sum(axis=1) + closest_point = np.argmin(distances) + + # Compute the position of the projected dipole on the helmet + norm = helmet.point_normals[closest_point] + helmet_pos = dip_pos + (distances[closest_point] + 0.003) * norm + + # Create a coordinate system where X and Y are tangential to the helmet + helmet_coords = _normal_orth(norm) + + return helmet_coords, helmet_pos + + def _fit_timecourses(self): + """Compute (or re-compute) dipole timecourses. + + Called whenever something changes to the multi-dipole situation, i.e. a dipole + is added, removed, (de-)activated or the "Fix pos" box is toggled. + """ + self._save_button.set_enabled(len(self.dipoles) > 0) + active_dips = [d for d in self._dipoles.values() if d["active"]] + if len(active_dips) == 0: + return + + if self._multi_dipole_method == "Multi dipole (MNE)": + for d in active_dips: + print(d["dip"], d["dip"].pos, d["dip"].ori) + this_src = setup_volume_source_space( + "sample", + pos=dict( + rr=apply_trans( + self._head_mri_t, + np.vstack([d["dip"].pos[0] for d in active_dips]), + ), + nn=apply_trans( + self._head_mri_t, + np.vstack([d["dip"].ori[0] for d in active_dips]), + ), + ), + ) + this_fwd = self.fwd.compute(this_src) + # this_fwd, _ = make_forward_dipole( + # [d["dip"] for d in active_dips], + # self._bem, + # self._evoked.info, + # trans=self._head_mri_t, + # n_jobs=self._n_jobs, + # ) + this_fwd = convert_forward_solution(this_fwd, surf_ori=False) + + inv = make_inverse_operator( + self._evoked.info, + # fwd, + this_fwd, + self._cov, + fixed=False, + loose=1.0, + depth=0, + rank=self._rank, + ) + stc = apply_inverse( + self._evoked, + inv, + method="MNE", + lambda2=1e-6, + pick_ori="vector", + ) + + timecourses = stc.magnitude().data + orientations = (stc.data / timecourses[:, np.newaxis, :]).transpose(0, 2, 1) + fixed_timecourses = stc.project( + np.array([dip["dip"].ori[0] for dip in active_dips]) + )[0].data + + for i, dip in enumerate(active_dips): + if dip["fix_ori"]: + dip["timecourse"] = fixed_timecourses[i] + dip["orientation"] = dip["dip"].ori.repeat(len(stc.times), axis=0) + else: + dip["timecourse"] = timecourses[i] + dip["orientation"] = orientations[i] + elif self._multi_dipole_method == "Single dipole": + for dip in active_dips: + dip_with_timecourse, _ = fit_dipole( + self._evoked, + self._cov, + self._bem, + pos=dip["dip"].pos[0], # position is always fixed + ori=dip["dip"].ori[0] if dip["fix_ori"] else None, + trans=self._head_mri_t, + rank=self._rank, + n_jobs=self._n_jobs, + verbose=True, + ) + if dip["fix_ori"]: + dip["timecourse"] = dip_with_timecourse.data[0] + dip["orientation"] = dip["dip"].ori.repeat( + len(dip_with_timecourse.times), axis=0 + ) + else: + dip["timecourse"] = dip_with_timecourse.amplitude + dip["orientation"] = dip_with_timecourse.ori + + # Update matplotlib canvas at the bottom of the window + canvas = self._setup_mplcanvas() + ymin, ymax = 0, 0 + for dip in active_dips: + if "line_artist" in dip: + dip["line_artist"].set_ydata(dip["timecourse"]) + else: + dip["line_artist"] = canvas.plot( + self._evoked.times, + dip["timecourse"], + label=dip["dip"].name, + color=dip["color"], + ) + ymin = min(ymin, 1.1 * dip["timecourse"].min()) + ymax = max(ymax, 1.1 * dip["timecourse"].max()) + canvas.axes.set_ylim(ymin, ymax) + canvas.update_plot() + self._update_arrows() + + @verbose + @fill_doc + def save(self, fname, verbose=None): + """Save the fitted dipoles to a file. + + Parameters + ---------- + fname : path-like + The name of the file. Should end in ``'.dip'`` to save in plain text format, + or in ``'.bdip'`` to save in binary format. + %(verbose)s + """ + if len(self.dipoles) == 0: + logger.info("No dipoles to save.") + return + + logger.info(f"Saving dipoles as: {fname}") + fname = Path(fname) + + # Pack the dipoles into a single mne.Dipole object. + if all(d.khi2 is not None for d in self.dipoles): + khi2 = np.array([d.khi2[0] for d in self.dipoles]) + else: + khi2 = None + + if all(d.nfree is not None for d in self.dipoles): + nfree = np.array([d.nfree[0] for d in self.dipoles]) + else: + nfree = None + + dip = Dipole( + times=np.array([d.times[0] for d in self.dipoles]), + pos=np.array([d.pos[0] for d in self.dipoles]), + amplitude=np.array([d.amplitude[0] for d in self.dipoles]), + ori=np.array([d.ori[0] for d in self.dipoles]), + gof=np.array([d.gof[0] for d in self.dipoles]), + khi2=khi2, + nfree=nfree, + conf={ + key: np.array([d.conf[key][0] for d in self.dipoles]) + for key in self.dipoles[0].conf.keys() + }, + name=";".join(d.name if hasattr(d, "name") else "" for d in self.dipoles), + ) + dip.save(fname, overwrite=True, verbose=verbose) + + def _update_arrows(self): + """Update the arrows to have the correct size and orientation.""" + active_dips = [d for d in self._dipoles.values() if d["active"]] + if len(active_dips) == 0: + return + orientations = [dip["orientation"] for dip in active_dips] + timecourses = [dip["timecourse"] for dip in active_dips] + arrow_scaling = 0.05 / np.max(np.abs(timecourses)) + for dip, ori, timecourse in zip(active_dips, orientations, timecourses): + helmet_coords = dip["helmet_coords"] + if helmet_coords is None: + continue + + dip_ori = apply_trans( + self._head_mri_t, + [np.interp(self._current_time, self._evoked.times, o) for o in ori.T], + ) + dip_moment = np.interp(self._current_time, self._evoked.times, timecourse) + arrow_size = dip_moment * arrow_scaling + arrow_mesh = dip["arrow_mesh"] + + # Project the orientation of the dipole tangential to the helmet + dip_ori_tan = helmet_coords[:2] @ dip_ori @ helmet_coords[:2] + + # Rotate the coordinate system such that Y lies along the dipole + # orientation, now we have our desired coordinate system for the + # arrows. + arrow_coords = np.array( + [np.cross(dip_ori_tan, helmet_coords[2]), dip_ori_tan, helmet_coords[2]] + ) + arrow_coords /= np.linalg.norm(arrow_coords, axis=1, keepdims=True) + + # Update the arrow mesh to point in the right directions + arrow_mesh.points = (_arrow_mesh()[0] * arrow_size) @ arrow_coords + arrow_mesh.points += dip["helmet_pos"] + self._renderer._update() + + def _on_select_method(self, method): + """Select the method to use for multi-dipole timecourse fitting.""" + self._multi_dipole_method = method + self._fit_timecourses() + + def _on_dipole_toggle(self, active, dip_num): + """Toggle a dipole on or off.""" + dipole = self._dipoles[dip_num] + active = bool(active) + dipole["active"] = active + dipole["line_artist"].set_visible(active) + # Labels starting with "_" are hidden from the legend. + dipole["line_artist"].set_label(("" if active else "_") + dipole["dip"].name) + dipole["brain_arrow_actor"].visibility = active + dipole["helmet_arrow_actor"].visibility = active + self._fit_timecourses() + self._renderer._update() + self._renderer._mplcanvas.update_plot() + + def _on_dipole_set_name(self, name, dip_num): + """Set the name of a dipole.""" + self._dipoles[dip_num]["dip"].name = name + self._dipoles[dip_num]["line_artist"].set_label(name) + self._renderer._mplcanvas.update_plot() + + def _on_dipole_toggle_fix_orientation(self, fix, dip_num): + """Fix dipole orientation when fitting timecourse.""" + self._dipoles[dip_num]["fix_ori"] = bool(fix) + self._fit_timecourses() + + def _on_dipole_delete(self, dip_num): + """Delete previously fitted dipole.""" + dipole = self._dipoles[dip_num] + dipole["line_artist"].remove() + dipole["brain_arrow_actor"].visibility = False + if dipole["helmet_arrow_actor"] is not None: # no helmet arrow for EEG + dipole["helmet_arrow_actor"].visibility = False + for widget in dipole["widgets"]: + widget.hide() + del self._dipoles[dip_num] + self._fit_timecourses() + self._renderer._update() + self._renderer._mplcanvas.update_plot() + + def _setup_mplcanvas(self): + """Configure the matplotlib canvas at the bottom of the window.""" + if self._renderer._mplcanvas is None: + self._renderer._mplcanvas = self._renderer._window_get_mplcanvas( + self._fig, 0.3, False, False + ) + self._renderer._window_adjust_mplcanvas_layout() + if self._time_line is None: + self._time_line = self._renderer._mplcanvas.plot_time_line( + self._current_time, + label="time", + color="black", + ) + return self._renderer._mplcanvas + + +def _arrow_mesh(): + """Obtain a mesh of an arrow.""" + vertices = np.array( + [ + [0.0, 1.0, 0.0], + [0.3, 0.7, 0.0], + [0.1, 0.7, 0.0], + [0.1, -1.0, 0.0], + [-0.1, -1.0, 0.0], + [-0.1, 0.7, 0.0], + [-0.3, 0.7, 0.0], + ] + ) + faces = np.array([[7, 0, 1, 2, 3, 4, 5, 6]]) + return vertices, faces diff --git a/mne/gui/tests/test_xfit.py b/mne/gui/tests/test_xfit.py new file mode 100644 index 00000000000..3a6f8e62280 --- /dev/null +++ b/mne/gui/tests/test_xfit.py @@ -0,0 +1,189 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + + +import numpy as np +import pytest +from numpy.testing import assert_allclose, assert_equal + +import mne +from mne.channels import read_vectorview_selection +from mne.datasets import testing +from mne.viz import ui_events +from mne.viz.utils import _get_color_list + +data_path = testing.data_path(download=False) +subjects_dir = data_path / "subjects" +fname_dip = data_path / "MEG" / "sample" / "sample_audvis_trunc_set1.dip" +fname_evokeds = data_path / "MEG" / "sample" / "sample_audvis_trunc-ave.fif" + + +def _gui_with_two_dipoles(): + """Create a dipolefit GUI and add two dipoles to it.""" + from mne.gui import dipolefit + + g = dipolefit(fname_evokeds) + dip = mne.read_dipole(fname_dip)[[0, 1]] + g.add_dipole(dip, name=["rh1", "rh2"]) + return g + + +@pytest.mark.slowtest +@testing.requires_testing_data +def test_dipolefit_gui_basic(renderer_interactive_pyvistaqt): + """Test basic functionality of the dipole fitting GUI.""" + from mne.gui import dipolefit + + # Test basic interface elements. + g = dipolefit(fname_evokeds) + assert g._evoked.comment == "Left Auditory" # MNE-Sample data should be loaded + assert g._current_time == g._evoked.times[84] # time of max GFP + + # Test fitting a single dipole. + assert len(g._dipoles) == len(g.dipoles) == 0 + g._on_fit_dipole() + assert len(g._dipoles) == len(g.dipoles) == 1 + dip = g.dipoles[0] + assert dip.name == "Left Auditory" + assert len(dip.times) == 1 + assert_equal(dip.times, g._current_time) + assert_allclose(dip.amplitude, 6.152221e-08, rtol=1e-4) + assert_allclose(dip.pos, [[0.04568744, 0.00753845, 0.06737837]], atol=1e-5) + assert_allclose(dip.ori, [[0.45720003, -0.72124413, -0.52036049]], atol=1e-5) + old_dip1_timecourse = g._dipoles[0]["timecourse"] + + # Test fitting a second dipole with a subset of channels at a different time. + g._on_sensor_data() # open sensor selection window + picks = read_vectorview_selection("Left", info=g._evoked.info) + ui_events.publish(g._fig_sensors, ui_events.ChannelsSelect(picks)) + assert sorted(g._fig_sensors.lasso.selection) == sorted(picks) + ui_events.publish(g._fig, ui_events.TimeChange(0.1)) # change time + assert g._current_time == 0.1 + g._on_fit_dipole() + assert len(g._dipoles) == len(g.dipoles) == 2 + dip2 = g.dipoles[1] + assert_equal(dip2.times, g._evoked.times[np.searchsorted(g._evoked.times, 0.1) - 1]) + assert_allclose(dip2.amplitude, 4.422736e-08, rtol=1e-4) + assert_allclose(dip2.pos, [[-0.05893074, -0.00202937, 0.05113064]], atol=1e-5) + assert_allclose(dip2.ori, [[0.3017588, -0.88550684, -0.35329769]], atol=1e-5) + # Adding the second dipole should have affected the timecourse of the first. + new_dip1_timecourse = g._dipoles[0]["timecourse"] + assert not np.allclose(old_dip1_timecourse, new_dip1_timecourse) + + # Test differences between the two dipoles + assert list(g._dipoles.keys()) == [0, 1] + dip1_dict, dip2_dict = g._dipoles.values() + assert dip1_dict["dip"] is dip + assert dip2_dict["dip"] is dip2 + assert dip1_dict["num"] == 0 + assert dip2_dict["num"] == 1 + assert_allclose( + dip1_dict["helmet_pos"], [0.10320071, 0.00946581, 0.07516293], atol=1e-5 + ) + assert_allclose( + dip2_dict["helmet_pos"], [-0.11462019, -0.00727073, 0.04561434], atol=1e-5 + ) + assert dip1_dict["color"] == _get_color_list()[0] + assert dip2_dict["color"] == _get_color_list()[1] + + # Test changing dipole model + assert g._multi_dipole_method == "Multi dipole (MNE)" + old_timecourses = np.vstack((dip1_dict["timecourse"], dip2_dict["timecourse"])) + g._on_select_method("Single dipole") + new_timecourses = np.vstack((dip1_dict["timecourse"], dip2_dict["timecourse"])) + assert not np.allclose(old_timecourses, new_timecourses) + g._fig._renderer.close() + + +@pytest.mark.slowtest +@testing.requires_testing_data +def test_dipolefit_gui_toggle_meshes(renderer_interactive_pyvistaqt): + """Test toggling the visibility of the meshes the dipole fitting GUI.""" + from mne.gui import dipolefit + + g = dipolefit(fname_evokeds) + assert list(g._actors.keys()) == ["helmet", "occlusion_surf", "head", "sensors"] + g.toggle_mesh("helmet", show=True) + assert g._actors["helmet"].visibility + g.toggle_mesh("helmet") + assert not g._actors["helmet"].visibility + with pytest.raises(ValueError, match="Invalid value for the 'name' parameter"): + g.toggle_mesh("non existent") + g._fig._renderer.close() + + +@pytest.mark.slowtest +@testing.requires_testing_data +def test_dipolefit_gui_dipole_controls(renderer_interactive_pyvistaqt): + """Test the controls for the dipoles in the dipole fitting GUI.""" + g = _gui_with_two_dipoles() + + dip1, dip2 = g._dipoles.values() + assert dip1["active"] and dip2["active"] + old_timecourses = np.vstack((dip1["timecourse"], dip2["timecourse"])) + + # Toggle a dipole off and on. + g._on_dipole_toggle(False, dip2["num"]) + assert not dip2["active"] + new_timecourses = np.vstack((dip1["timecourse"], dip2["timecourse"])) + assert not np.allclose(old_timecourses, new_timecourses, atol=1e-9) + g._on_dipole_toggle(True, dip2["num"]) + assert dip2["active"] + new_timecourses = np.vstack((dip1["timecourse"], dip2["timecourse"])) + assert np.allclose(old_timecourses, new_timecourses, atol=0) + + # Toggle fixed orientation off and on. + assert dip1["fix_ori"] and dip2["fix_ori"] + g._on_dipole_toggle_fix_orientation(False, dip1["num"]) + assert not dip1["fix_ori"] + new_timecourses = np.vstack((dip1["timecourse"], dip2["timecourse"])) + assert not np.allclose(old_timecourses, new_timecourses, atol=1e-9) + g._on_dipole_toggle_fix_orientation(True, dip1["num"]) + assert dip1["fix_ori"] + new_timecourses = np.vstack((dip1["timecourse"], dip2["timecourse"])) + assert np.allclose(old_timecourses, new_timecourses, atol=0) + + # Change the names. + g._on_dipole_set_name("dipole1", dip1["num"]) + g._on_dipole_set_name("dipole2", dip2["num"]) + assert dip1["dip"].name == "dipole1" + assert dip2["dip"].name == "dipole2" + assert dip1["line_artist"].get_label() == "dipole1" # legend labels + assert dip2["line_artist"].get_label() == "dipole2" + + # Remove a dipole. + g._on_dipole_delete(dip1["num"]) + assert len(g.dipoles) == 1 + assert 1 in g._dipoles # dipole number should not change + assert list(g._dipoles.keys())[0] == 1 + assert list(g._dipoles.values())[0]["num"] == 1 + g._on_fit_dipole() + assert 2 in g._dipoles + assert list(g._dipoles.keys())[1] == 2 + assert list(g._dipoles.values())[1]["num"] == 2 # new dipole number + g._fig._renderer.close() + + +@pytest.mark.slowtest +@testing.requires_testing_data +def test_dipolefit_gui_save_load(tmpdir, renderer_interactive_pyvistaqt): + """Test saving and loading dipoles in the dipole fitting GUI.""" + g = _gui_with_two_dipoles() + g.save(tmpdir / "test.dip") + g.save(tmpdir / "test.bdip") + dip_from_file = mne.read_dipole(tmpdir / "test.dip") + g.add_dipole(dip_from_file) + g.add_dipole(mne.read_dipole(tmpdir / "test.bdip")) + assert len(g.dipoles) == 6 + assert [d.name for d in g.dipoles] == ["rh1", "rh2", "rh1", "rh2", "dip4", "dip5"] + assert_allclose( + np.vstack([d.pos for d in g.dipoles[:2]]), dip_from_file.pos, atol=0 + ) + assert_allclose( + np.vstack([d.pos for d in g.dipoles[2:4]]), dip_from_file.pos, atol=0 + ) + assert_allclose( + np.vstack([d.pos for d in g.dipoles[4:]]), dip_from_file.pos, atol=0 + ) + g._fig._renderer.close() diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index f844d9b54e5..d9c5e7734f1 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -83,13 +83,7 @@ ) from ._dipole import _check_concat_dipoles, _plot_dipole_3d, _plot_dipole_mri_outlines from .evoked_field import EvokedField -from .utils import ( - _check_time_unit, - _get_cmap, - _get_color_list, - figure_nobar, - plt_show, -) +from .utils import _check_time_unit, _get_cmap, _get_color_list, figure_nobar, plt_show verbose_dec = verbose FIDUCIAL_ORDER = (FIFF.FIFFV_POINT_LPA, FIFF.FIFFV_POINT_NASION, FIFF.FIFFV_POINT_RPA) diff --git a/mne/viz/_3d_overlay.py b/mne/viz/_3d_overlay.py index 3ebc308c127..fd9bbf8e1bb 100644 --- a/mne/viz/_3d_overlay.py +++ b/mne/viz/_3d_overlay.py @@ -100,7 +100,7 @@ def _compute_over(self, B, A): C[:, :3] *= A_w C[:, :3] += B[:, :3] * B_w C[:, 3:] += B_w - C[:, :3] /= C[:, 3:] + C[:, :3] /= np.maximum(1e-20, C[:, 3:]) # avoid divide by zero return np.clip(C, 0, 1, out=C) def _compose_overlays(self): diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index 839259ee117..2b53ba09ebc 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -50,7 +50,6 @@ class EvokedField: the average peak latency (across sensor types) is used. time_label : str | None How to print info about the time instant visualized. - %(n_jobs)s fig : instance of Figure3D | None If None (default), a new figure will be created, otherwise it will plot into the given figure. @@ -69,6 +68,10 @@ class EvokedField: The number of contours. .. versionadded:: 0.21 + contour_line_width : float + The line_width of the contour lines. + + .. versionadded:: 1.6 show_density : bool Whether to draw the field density as an overlay on top of the helmet/head surface. Defaults to ``True``. @@ -91,6 +94,17 @@ class EvokedField: ``True`` if there is more than one time point and ``False`` otherwise. .. versionadded:: 1.6 + background : tuple(int, int, int) + The color definition of the background: (red, green, blue). + + .. versionadded:: 1.6 + foreground : matplotlib color + Color of the foreground (will be used for colorbars and text). + None (default) will use black or white depending on the value + of ``background``. + + .. versionadded:: 1.6 + %(n_jobs)s %(verbose)s Notes @@ -109,15 +123,18 @@ def __init__( *, time=None, time_label="t = %0.0f ms", - n_jobs=None, fig=None, vmax=None, n_contours=21, + contour_line_width=1.0, show_density=True, alpha=None, interpolation="nearest", interaction="terrain", time_viewer="auto", + background="black", + foreground=None, + n_jobs=None, verbose=None, ): from .backends.renderer import _get_3d_backend, _get_renderer @@ -134,6 +151,7 @@ def __init__( self._vmax = _validate_type(vmax, (None, "numeric", dict), "vmax") self._n_contours = _ensure_int(n_contours, "n_contours") + self._contour_line_width = contour_line_width self._time_interpolation = _check_option( "interpolation", interpolation, @@ -142,6 +160,10 @@ def __init__( self._interaction = _check_option( "interaction", interaction, ["trackball", "terrain"] ) + self._bg_color = _to_rgb(background, name="background") + if foreground is None: + foreground = "w" if sum(self._bg_color) < 2 else "k" + self._fg_color = _to_rgb(foreground, name="foreground") surf_map_kinds = [surf_map["kind"] for surf_map in surf_maps] if vmax is None: @@ -193,9 +215,7 @@ def __init__( "is currently not supported inside a notebook." ) else: - self._renderer = _get_renderer( - fig, bgcolor=(0.0, 0.0, 0.0), size=(600, 600) - ) + self._renderer = _get_renderer(fig, bgcolor=background, size=(600, 600)) self._in_brain_figure = False self._units = "m" @@ -230,14 +250,17 @@ def current_time_func(): current_time_func=current_time_func, times=evoked.times, ) - if not self._in_brain_figure or "time_slider" not in fig.widgets: + if not self._in_brain_figure: # Draw the time label self._time_label = time_label if time_label is not None: if "%" in time_label: time_label = time_label % np.round(1e3 * time) self._time_label_actor = self._renderer.text2d( - x_window=0.01, y_window=0.01, text=time_label + x_window=0.01, + y_window=0.01, + text=time_label, + color=foreground, ) self._configure_dock() @@ -359,6 +382,7 @@ def _update(self): vmin=-surf_map["map_vmax"], vmax=surf_map["map_vmax"], colormap=self._colormap_lines, + width=self._contour_line_width, ) if self._time_label is not None: if hasattr(self, "_time_label_actor"): @@ -369,7 +393,10 @@ def _update(self): if "%" in self._time_label: time_label = self._time_label % np.round(1e3 * self._current_time) self._time_label_actor = self._renderer.text2d( - x_window=0.01, y_window=0.01, text=time_label + x_window=0.01, + y_window=0.01, + text=time_label, + color=self._fg_color, ) self._renderer.plotter.update() @@ -438,6 +465,16 @@ def _callback(vmax, kind, scaling): callback=self.set_contours, layout=layout, ) + + self._widgets["contours_line_width"] = r._dock_add_slider( + name="Thickness", + value=1, + rng=[0, 10], + callback=self.set_contour_line_width, + double=True, + layout=layout, + ) + r._dock_finalize() def _on_time_change(self, event): @@ -499,9 +536,13 @@ def _on_contours(self, event): break surf_map["contours"] = event.contours self._n_contours = len(event.contours) + if event.line_width is not None: + self._contour_line_width = event.line_width with disable_ui_events(self): if "contours" in self._widgets: self._widgets["contours"].set_value(len(event.contours)) + if "contour_line_width" in self._widgets and event.line_width is not None: + self._widgets["contour_line_width"].set_value(event.line_width) self._update() def set_time(self, time): @@ -536,6 +577,7 @@ def set_contours(self, n_contours): contours=np.linspace( -surf_map["map_vmax"], surf_map["map_vmax"], n_contours ).tolist(), + line_width=self._contour_line_width, ), ) @@ -570,3 +612,14 @@ def _rescale(self): current_data = surf_map["data_interp"](self._current_time) vmax = float(np.max(current_data)) self.set_vmax(vmax, kind=surf_map["map_kind"]) + + def set_contour_line_width(self, line_width): + """Set the line_width of the contour lines. + + Parameters + ---------- + line_width : float + The desired line_width of the contour lines. + """ + self._contour_line_width = line_width + self.set_contours(self._n_contours) diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index b8b3fe29a4d..dc5f8315020 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -191,11 +191,12 @@ class Contours(UIEvent): Parameters ---------- kind : str - The kind of contours lines being changed. The Notes section of the - drawing routine publishing this event should mention the possible - kinds. + The drawing routine publishing this event should mention the possible kinds. contours : list of float The new values at which contour lines need to be drawn. + line_width : float | None + The line_width with which to draw the contour lines. Can be ``None`` to + indicate to keep using the current line_width. Attributes ---------- @@ -206,10 +207,14 @@ class Contours(UIEvent): kinds. contours : list of float The new values at which contour lines need to be drawn. + line_width : float | None + The line_width with which to draw the contour lines. Can be ``None`` to + indicate to keep using the current line_width. """ kind: str contours: list[str] + line_width: float | None @dataclass diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index 881d265d2d2..13ea2bb1e83 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -134,6 +134,8 @@ _._nearest_transformed_high_res_mri_idx_rpa _._nearest_transformed_high_res_mri_idx_nasion _._nearest_transformed_high_res_mri_idx_lpa +_.prop.culling +_.prop.lighting # Figures (prevent GC for example) _.decim_data diff --git a/tutorials/inverse/21_interactive_dipole_fit.py b/tutorials/inverse/21_interactive_dipole_fit.py new file mode 100644 index 00000000000..fd99764c112 --- /dev/null +++ b/tutorials/inverse/21_interactive_dipole_fit.py @@ -0,0 +1,159 @@ +""" +.. _tut-xfit: + +===================================================================== +Source localization by guided equivalent current dipole (ECD) fitting +===================================================================== + +This combination of manual specification and automated fitting is one of the oldest MEG +source estimation techniques :footcite:`Sarvas1987`. We will manually identify where and +when dipole source are active, upon which the fitting algorithm will find the best +location for the source. The result is a sparse source estimate of several equivalent +current dipoles (ECDs) that together explain (most of) the MEG evoked response. ECDs are +especially suited for capturing individual components of an evoked response (e.g. N100m, +N400m, etc.). Once the set of ECDs has been established, their timecourses can be +computed for multiple :class:`~mne.Evoked` objects, for example different experimental +conditions. + +This tutorial will demonstrate how to fit ECDs using the interactive GUI and also how to +do it using Python code. +""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +# %% +# Guided ECD fitting using the GUI +# -------------------------------- +# +# Starting the GUI +# ~~~~~~~~~~~~~~~~ +# We can start the GUI either from the command line by using the ``mne dipolefit`` +# program. By default it will load an evoked response from the MNE-Sample data, which +# is what we will use in this tutorial. To load your own data, use the ``-e +# EVOKED_FILE`` option to load an evoked response from a file (filename typically ends +# in ``*-ave.fif``). +# +# The GUI can also be started from an interactive python console: +import mne + +mne.gui.dipolefit() + +# %% +# Without specifying anything about the head model, the GUI shows the minimal setup that +# can be used to fit dipoles: the sensors, a spherical head model and the +# electro-magnetic field recorded by the sensors, using an ad-hoc noise covariance +# matrix. If we provide more information, we can create a more accurate head model that +# provides better ECD fits and gives us more guidance for determining sources. On the +# command line there are various options you can use to specify files containing the +# covariance matrix, BEM model and MRI<->head transformation, see the output of ``mne +# dipolefit --help``. +# +# In an interactive python console, we can provide the appropriate MNE-Python objects +# when starting the GUI: +path = mne.datasets.sample.data_path() +meg_dir = path / "MEG" / "sample" +subjects_dir = path / "subjects" + +evoked = mne.read_evokeds(meg_dir / "sample_audvis-ave.fif", condition="Left Auditory") +evoked.apply_baseline() +cov = mne.read_cov(meg_dir / "sample_audvis-cov.fif") +bem = mne.read_bem_solution( + subjects_dir / "sample" / "bem" / "sample-5120-5120-5120-bem-sol.fif" +) +trans = mne.read_trans(meg_dir / "sample_audvis_raw-trans.fif") + +# A distributed source estimate is a helpful guide for our dipole fits. +inv = mne.minimum_norm.read_inverse_operator( + meg_dir / "sample_audvis-meg-oct-6-meg-inv.fif" +) +stc = mne.minimum_norm.apply_inverse(evoked, inv) + +# Open the GUI with a better head model. +fitting_gui = mne.gui.dipolefit( + evoked, + cov=cov, + bem=bem, + trans=trans, + stc=stc, + ch_type="meg", # only use MEG sensors for this tutorial + subject="sample", + subjects_dir=subjects_dir, +) + +# %% +# During guided ECD fitting, we look for patterns in the eletro-magnetic field to +# identify when and where sources may be active. We can use the time slider to examine +# how the field changes over time. The sample data is an evoked response to an auditory +# tone being played to the left of the participant and we can see the initial auditory +# response peaking at around 85 ms on the right hemisphere. The field shows a typical +# di-polar pattern with a pair of red/blue focii on either side of the source that +# should be located in auditory cortex (the distributed source estimates shows where it +# is). +# +# By pressing the "Fit dipole" button we instruct the algorithm to fit a dipole at the +# current time. After a few seconds of computation, the resulting dipole will be +# displayed as a arrow in the brain, indicating its source, as well as an arrow on the +# MEG helmet indicating the fit between the dipole and the field pattern. The timecourse +# of the dipole is shown below. On the right are controls to name, remove, temporarily +# (de-)activate, and save the dipole to a file. You also find a toggle switch to make +# the dipole's orientation dynamic or keep it fixed at the orientation it had at the +# time when it was fitted. +# +# Selecting channels to guide the ECD modeling +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# At nearly the same time, auditory responses are occurring in both left and right +# auditory cortex. Hence, the single dipole that we fitted without any guidance will +# have some bias as the algorithm attempted to fit the ECD to the entire bi-lateral +# field pattern. +# +# To isolate portions of the field pattern that contain a single pair of red/blue focii, +# the ideal fitting target for the algorithm, we can restrict the analysis to a subset +# of sensors. To do so, first press the "Sensor data" button, which will open a new +# window showing the evoked response across all sensors. By clicking and dragging the +# mouse we can make a lasso selection around the sensors we wish to include in the +# analysis. Hold ``CTRL`` to add to the current selection and ``CTRL + SHIFT`` to remove +# from the current selection. The currently selected sensors are highlighted in green in +# the main window, showing the portion of the field pattern they cover. When you are +# happy with the selection, you can use the "Fit dipole" button as before to fit a +# dipole using the selected sensors at the current timepoint. +# +# Remove or de-activate the dipole we previously fitted to the entire field pattern and +# fit two dipoles using the left-side and right-side sensors respectively. It is helpful +# to name them. +# +# Multi/single dipole modes +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# By default, the fitting algorithm is in "Multi dipole (MNE)" mode, meaning portions of +# the signal attributed to one dipole can not be attributed to a second dipole at the +# same time. You will notice that if you have two dipoles with similar orientations +# close to each other, their timecourses become a strange mixture as each dipole will +# claim a part of the same signal. To prevent this, we can switch the algorithm over to +# "Single dipole". In this mode, the timecourse of each dipole will be computed whilst +# ignoring all other dipoles, which is useful when evaluating multiple candidate dipoles +# for the same source. +# +# Saving and loading sets of dipoles +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# We can save the fitted dipoles using the "Save dipoles" button. There are two possible +# file formats for this: plain text (``.dip``) and a binary format (``.bdip``). These +# are formats are compatible with MegIn's software, allowing interoperability between +# MNE-Python and Xfit. The fitted dipoles can also be accessed and saved through Python +# code: +fitted_dipoles = fitting_gui.dipoles # the dipoles we fitted +# save with: fitted_dipoles.save("my_file.dip") + +# %% +# Saved dipoles can be loaded with :func:`mne.read_dipole` and added to an existing +# dipole fitting GUI like so: +dips_to_add = mne.read_dipole(meg_dir / "sample_audvis_set1.dip") +dips_to_add = dips_to_add[[27, 33]] # add only two of the 34 dipoles in the file +name = ["rh", "lh"] # we can give names to the dipoles if we want +fitting_gui = mne.gui.dipolefit() +fitting_gui.add_dipole(dips_to_add, name=name) + +# %% +# References +# ---------- +# .. footbibliography::