Skip to content

Commit f2f49bd

Browse files
committed
Applied recommended code changes & bug fix in _scalp_coupling_index iteration
1 parent 4bae23e commit f2f49bd

File tree

3 files changed

+25
-41
lines changed

3 files changed

+25
-41
lines changed

mne/preprocessing/nirs/_beer_lambert_law.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def beer_lambert_law(raw, ppf=6.0):
119119

120120

121121
def _load_absorption(freqs):
122-
"""Load molar extinction coefficients"""
122+
"""Load molar extinction coefficients."""
123123
# Data from https://omlc.org/spectra/hemoglobin/summary.html
124124
# The text was copied to a text file. The text before and
125125
# after the table was deleted. The the following was run in

mne/preprocessing/nirs/_scalp_coupling_index.py

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -65,42 +65,25 @@ def scalp_coupling_index(
6565

6666
sci = np.zeros(picks.shape)
6767

68-
if n_wavelengths == 2:
69-
# Use pairwise correlation for 2 wavelengths (backward compatibility)
70-
for ii in range(0, len(picks), 2):
68+
# Calculate all pairwise correlations within each group and use the minimum as SCI
69+
pair_indices = np.triu_indices(n_wavelengths, k=1)
70+
for gg in range(0, len(picks), n_wavelengths):
71+
group_data = filtered_data[gg : gg + n_wavelengths]
72+
73+
# Calculate pairwise correlations within the group
74+
correlations = np.zeros(pair_indices[0].shape[0])
75+
76+
for n, (ii, jj) in enumerate(zip(*pair_indices)):
7177
with np.errstate(invalid="ignore"):
72-
c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1]
73-
if not np.isfinite(c): # someone had std=0
74-
c = 0
75-
sci[ii] = c
76-
sci[ii + 1] = c
77-
else:
78-
# For multiple wavelengths: calculate all pairwise correlations within each group
79-
# and use the minimum as the quality metric
80-
81-
# Group picks by number of wavelengths
82-
# Drops last incomplete group, but we're assuming valid data
83-
pick_iter = iter(picks)
84-
pick_groups = zip(*[pick_iter] * n_wavelengths)
85-
86-
for group_picks in pick_groups:
87-
group_data = filtered_data[group_picks]
88-
89-
# Calculate pairwise correlations within the group
90-
pair_indices = np.triu_indices(len(group_picks), k=1)
91-
correlations = np.zeros(pair_indices[0].shape[0])
92-
93-
for n, (ii, jj) in enumerate(zip(*pair_indices)):
94-
with np.errstate(invalid="ignore"):
95-
c = np.corrcoef(group_data[ii], group_data[jj])[0][1]
96-
if np.isfinite(c):
97-
correlations[n] = c
98-
99-
# Use minimum correlation as the quality metric
100-
group_sci = correlations.min()
101-
102-
# Assign the same SCI value to all channels in the group
103-
sci[group_picks] = group_sci
78+
c = np.corrcoef(group_data[ii], group_data[jj])[0][1]
79+
if np.isfinite(c):
80+
correlations[n] = c
81+
82+
# Use minimum correlation as SCI
83+
group_sci = correlations.min()
84+
85+
# Assign the same SCI value to all channels in the group
86+
sci[gg : gg + n_wavelengths] = group_sci
10487

10588
sci[zero_mask] = 0
10689
sci = sci[np.argsort(picks)] # restore original order

mne/preprocessing/nirs/nirs.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,12 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr
149149
# (e.g., for 2 wavelengths, we need even number of channels)
150150
if len(picks) % len(pair_vals) != 0:
151151
picks = _throw_or_return_empty(
152-
f"NIRS channels not ordered correctly. The number of channels "
152+
"NIRS channels not ordered correctly. The number of channels "
153153
f"must be a multiple of {len(pair_vals)} values, but "
154154
f"{len(picks)} channels were provided.",
155155
throw_errors,
156156
)
157+
157158
# Ensure wavelength info exists for waveform data
158159
all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave]
159160
if len(pair_vals) != len(set(all_freqs)):
@@ -197,9 +198,9 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr
197198
picks = picks[np.argsort([info["ch_names"][pick] for pick in picks])]
198199

199200
# Validate channel grouping (same source-detector pairs, all pair_vals match)
200-
for i in range(0, len(picks), len(pair_vals)):
201+
for ii in range(0, len(picks), len(pair_vals)):
201202
# Extract a group of channels (e.g., all wavelengths for one S-D pair)
202-
group_picks = picks[i : i + len(pair_vals)]
203+
group_picks = picks[ii : ii + len(pair_vals)]
203204

204205
# Parse channel names using regex to extract source, detector, and value info
205206
group_info = [
@@ -228,8 +229,8 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr
228229
break
229230

230231
if check_bads:
231-
for i in range(0, len(picks), len(pair_vals)):
232-
group_picks = picks[i : i + len(pair_vals)]
232+
for ii in range(0, len(picks), len(pair_vals)):
233+
group_picks = picks[ii : ii + len(pair_vals)]
233234

234235
want = [info.ch_names[pick] for pick in group_picks]
235236
got = list(set(info["bads"]).intersection(want))

0 commit comments

Comments
 (0)