diff --git a/doc/api/events.rst b/doc/api/events.rst index 3f7159a22d5..88479eb8f3e 100644 --- a/doc/api/events.rst +++ b/doc/api/events.rst @@ -9,6 +9,7 @@ Events Annotations AcqParserFIF + HEDAnnotations concatenate_events count_events find_events diff --git a/mne/__init__.pyi b/mne/__init__.pyi index d50b5209346..6560854402e 100644 --- a/mne/__init__.pyi +++ b/mne/__init__.pyi @@ -11,6 +11,7 @@ __all__ = [ "Evoked", "EvokedArray", "Forward", + "HEDAnnotations", "Info", "Label", "MixedSourceEstimate", @@ -260,6 +261,7 @@ from ._freesurfer import ( ) from .annotations import ( Annotations, + HEDAnnotations, annotations_from_events, count_annotations, events_from_annotations, diff --git a/mne/annotations.py b/mne/annotations.py index 629ee7b20cb..71fa0267d65 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -29,6 +29,7 @@ write_name_list_sanitized, write_string, ) +from .fixes import _compare_version from .utils import ( _check_dict_keys, _check_dt, @@ -52,6 +53,7 @@ verbose, warn, ) +from .utils.check import _soft_import # For testing windows_like_datetime, we monkeypatch "datetime" in this module. # Keep the true datetime object around for _validate_type use. @@ -151,6 +153,7 @@ class Annotations: -------- mne.annotations_from_events mne.events_from_annotations + mne.HEDAnnotations Notes ----- @@ -288,7 +291,7 @@ def orig_time(self): def __eq__(self, other): """Compare to another Annotations instance.""" - if not isinstance(other, Annotations): + if not isinstance(other, type(self)): return False return ( np.array_equal(self.onset, other.onset) @@ -567,6 +570,12 @@ def _sort(self): self.duration = self.duration[order] self.description = self.description[order] self.ch_names = self.ch_names[order] + if hasattr(self, "hed_string"): + self.hed_string._objs = [self.hed_string._objs[i] for i in order] + for i in order: + self.hed_string.__setitem__( + i, self.hed_string._objs[i].get_original_hed_string() + ) @verbose def crop( @@ -758,6 +767,241 @@ def rename(self, mapping, verbose=None): return self +class _HEDStrings(list): + """Subclass of list that will validate before __setitem__.""" + + def __init__(self, *args, hed_version, **kwargs): + self._hed = _soft_import("hed", "validation of HED tags in annotations") + self._schema = self._hed.load_schema_version(hed_version) + super().__init__(*args, **kwargs) + self._objs = [self._validate_hed_string(item, self._schema) for item in self] + + def __setitem__(self, key, value): + """Validate value first, before assigning.""" + hs = self._validate_hed_string(value, self._schema) + super().__setitem__(key, hs.get_original_hed_string()) + self._objs[key] = hs + + def _validate_hed_string(self, value, schema): + # create HedString object and validate it + hs = self._hed.HedString(value, schema) + # handle any errors + error_handler = self._hed.errors.ErrorHandler(check_for_warnings=False) + issues = hs.validate(allow_placeholders=False, error_handler=error_handler) + error_string = self._hed.get_printable_issue_string(issues) + if len(error_string): + raise ValueError(f"A HED string failed to validate:\n {error_string}") + hs.sort() + return hs + + def append(self, item): + """Append an item to the end of the HEDString list.""" + hs = self._validate_hed_string(item, self._schema) + super().append(hs.get_original_hed_string()) + self._objs.append(hs) + + +@fill_doc +class HEDAnnotations(Annotations): + """Annotations object for annotating segments of raw data with HED tags. + + Parameters + ---------- + onset : array of float, shape (n_annotations,) + The starting time of annotations in seconds after ``orig_time``. + duration : array of float, shape (n_annotations,) | float + Durations of the annotations in seconds. If a float, all the + annotations are given the same duration. + description : array of str, shape (n_annotations,) | str + Array of strings containing description for each annotation. If a + string, all the annotations are given the same description. To reject + epochs, use description starting with keyword 'bad'. See example above. + hed_string : array of str, shape (n_annotations,) | str + Sequence of strings containing a HED tag (or comma-separated list of HED tags) + for each annotation. If a single string is provided, all annotations are + assigned the same HED string. + hed_version : str + The HED schema version against which to validate the HED strings. + orig_time : float | str | datetime | tuple of int | None + A POSIX Timestamp, datetime or a tuple containing the timestamp as the + first element and microseconds as the second element. Determines the + starting time of annotation acquisition. If None (default), + starting time is determined from beginning of raw data acquisition. + In general, ``raw.info['meas_date']`` (or None) can be used for syncing + the annotations with raw data if their acquisition is started at the + same time. If it is a string, it should conform to the ISO8601 format. + More precisely to this '%%Y-%%m-%%d %%H:%%M:%%S.%%f' particular case of + the ISO8601 format where the delimiter between date and time is ' '. + %(ch_names_annot)s + + See Also + -------- + mne.Annotations + + Notes + ----- + + .. versionadded:: 1.10 + """ + + def __init__( + self, + onset, + duration, + description, + hed_string, + hed_version="8.3.0", + orig_time=None, + ch_names=None, + ): + self._hed_version = hed_version + self.hed_string = _HEDStrings(hed_string, hed_version=self._hed_version) + super().__init__( + onset=onset, + duration=duration, + description=description, + orig_time=orig_time, + ch_names=ch_names, + ) + + def __eq__(self, other): + """Compare to another HEDAnnotations instance.""" + _slf = self.hed_string + _oth = other.hed_string + + if _compare_version(self._hed_version, "<", other._hed_version): + _slf = [_slf._validate_hed_string(v, _oth._schema) for v in _slf._objs] + elif _compare_version(self._hed_version, ">", other._hed_version): + _oth = [_oth._validate_hed_string(v, _slf._schema) for v in _oth._objs] + return super().__eq__(other) and _slf == _oth + + def __repr__(self): + """Show a textual summary of the object.""" + counter = Counter([hs.get_as_short() for hs in self.hed_string._objs]) + + # textwrap.shorten won't work: we remove all spaces and shouldn't split on `-` + def _shorten(text, width=74, placeholder=" ..."): + parts = text.split(",") + out = parts[0] + for part in parts[1:]: + # +1 for the comma ↓↓↓ + if width < len(out) + 1 + len(part) + len(placeholder): + break + out = f"{out},{part}" + return out + placeholder + + kinds = [ + f"{_shorten(k, width=74):<74} ({v})" for k, v in sorted(counter.items()) + ] + if len(kinds) > 5: + kinds = [*kinds[:5], f"... and {len(kinds) - 5} more"] + kinds = "\n ".join(kinds) + if len(kinds): + kinds = f":\n {kinds}\n" + ch_specific = ", channel-specific" if self._any_ch_names() else "" + s = ( + f"HEDAnnotations | {len(self.onset)} segment" + f"{_pl(len(self.onset))}{ch_specific}{kinds}" + ) + return f"<{s}>" + + def __getitem__(self, key, *, with_ch_names=None): + """Propagate indexing and slicing to the underlying structure.""" + result = super().__getitem__(key, with_ch_names=with_ch_names) + if isinstance(result, OrderedDict): + result["hed_string"] = self.hed_string[key] + return result + else: + key = list(key) if isinstance(key, tuple) else key + hed_string = [self.hed_string[key]] + return HEDAnnotations( + result.onset, + result.duration, + result.description, + hed_string=hed_string, + hed_version=self._hed_version, + orig_time=self.orig_time, + ch_names=result.ch_names, + ) + + def __getstate__(self): + """Make serialization work, by removing module reference.""" + return dict( + _orig_time=self._orig_time, + onset=self.onset, + duration=self.duration, + description=self.description, + ch_names=self.ch_names, + hed_string=list(self.hed_string), + _hed_version=self._hed_version, + ) + + def __setstate__(self, state): + """Unpack from serialized format.""" + self._orig_time = state["_orig_time"] + self.onset = state["onset"] + self.duration = state["duration"] + self.description = state["description"] + self.ch_names = state["ch_names"] + self._hed_version = state["_hed_version"] + self.hed_string = _HEDStrings( + state["hed_string"], hed_version=self._hed_version + ) + + @fill_doc + def append(self, *, onset, duration, description, hed_string, ch_names=None): + """Add an annotated segment. Operates inplace. + + Parameters + ---------- + onset : float | array-like + Annotation time onset from the beginning of the recording in + seconds. + duration : float | array-like + Duration of the annotation in seconds. + description : str | array-like + Description for the annotation. To reject epochs, use description + starting with keyword 'bad'. + hed_string : array of str, shape (n_annotations,) | str + Sequence of strings containing a HED tag (or comma-separated list of HED + tags) for each annotation. If a single string is provided, all annotations + are assigned the same HED string. + %(ch_names_annot)s + + Returns + ------- + self : mne.HEDAnnotations + The modified HEDAnnotations object. + """ + self.hed_string.append(hed_string) + super().append( + onset=onset, duration=duration, description=description, ch_names=ch_names + ) + + def crop( + self, tmin=None, tmax=None, emit_warning=False, use_orig_time=True, verbose=None + ): + """TODO.""" + pass + + def delete(self, idx): + """Remove an annotation. Operates inplace. + + Parameters + ---------- + idx : int | array-like of int + Index of the annotation to remove. Can be array-like to remove multiple + indices. + """ + _ = self.hed_string._objs.pop(idx) + _ = self.hed_string.pop(idx) + super().delete(idx) + + def to_data_frame(self, time_format="datetime"): + """TODO.""" + pass + + class EpochAnnotationsMixin: """Mixin class for Annotations in Epochs.""" @@ -1732,5 +1976,6 @@ def count_annotations(annotations): >>> count_annotations(annotations) {'T0': 2, 'T1': 1} """ - types, counts = np.unique(annotations.description, return_counts=True) + field = "hed_string" if isinstance(annotations, HEDAnnotations) else "description" + types, counts = np.unique(getattr(annotations, field), return_counts=True) return {str(t): int(count) for t, count in zip(types, counts)} diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 4d0db170e2a..823aed20556 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -22,6 +22,7 @@ from mne import ( Annotations, Epochs, + HEDAnnotations, annotations_from_events, count_annotations, create_info, @@ -1825,3 +1826,78 @@ def test_append_splits_boundary(tmp_path, split_size): assert len(raw.annotations) == 2 assert raw.annotations.description[0] == "BAD boundary" assert_allclose(raw.annotations.onset, [onset] * 2) + + +def test_hed_annotations(): + """Test hed_strings validation.""" + pytest.importorskip("hed") + # test initting with bad value + validation_fail_msg = "A HED string failed to validate" + with pytest.raises(ValueError, match=validation_fail_msg): + _ = HEDAnnotations( + onset=[1], + duration=[0.1], + description=["a"], + hed_string=["foo"], + ) + # test initting with good values + good_values = dict( + square="Sensory-event, Experimental-stimulus, Visual-presentation, (Square, " + "DarkBlue, (Center-of, Computer-screen))", # extra spaces intentional + tone="Sensory-event, Experimental-stimulus, Auditory-presentation, (Tone, " + "Frequency/550 Hz)", + press="Agent-action, (Experiment-participant, (Press, Mouse-button))", + word="Sensory-event, (Word, Label/Word-look), Auditory-presentation, " + "Visual-presentation", + ) + ann = HEDAnnotations( + onset=[3, 2, 1], + duration=[0.1, 0.0, 0.3], + description=["d", "c", "a"], + hed_string=[good_values["square"], good_values["tone"], good_values["press"]], + ) + # make sure sorting by onset worked correctly + assert ann.hed_string[0] == good_values["press"] + assert ann.hed_string._objs[0].get_original_hed_string() == good_values["press"] + # test appending + foo = ann.copy() + ons_dur_desc = dict(onset=1.5, duration=0.2, description="b") + with pytest.raises(ValueError, match=validation_fail_msg): + foo.append(**ons_dur_desc, hed_string="foo") + foo.append(**ons_dur_desc, hed_string=good_values["word"]) + # make sure sorting by onset also works for .append() + assert list(foo.hed_string) == [ + x.get_original_hed_string() for x in foo.hed_string._objs + ] + # make sure we didn't mess up the type of the HEDStrings + assert isinstance(foo.hed_string, mne.annotations._HEDStrings) + # test modifying with bad value + with pytest.raises(ValueError, match=validation_fail_msg): + ann.hed_string[0] = "foo" + # test modifying, __eq__, and delete() + foo = ann.copy() + assert ann == foo + foo.hed_string[0] = good_values["word"] + assert ann != foo + ann.hed_string[0] = good_values["word"] + assert ann == foo + foo.delete(0) + assert ann != foo + assert foo.hed_string[0] == ann.hed_string[1] + # test .count() + want_counts = { + good_values["word"]: 1, + good_values["tone"]: 1, + good_values["square"]: 1, + } + assert ann.count() == want_counts + # test __getitem__ + first = ann[0] + assert first["hed_string"] == good_values["word"] + # setting bad value on extracted OrderedDict won't try to validate: + first["hed_string"] = "foo" + # ...and won't affect the original object + assert ann.hed_string[0] == good_values["word"] + # test __repr__ + _repr = repr(ann) + assert "Auditory-presentation,Experimental-stimulus,Sensory-event ..." in _repr