|
2 | 2 | import unittest |
3 | 3 | import numpy as np |
4 | 4 | import warnings |
| 5 | +import operator |
5 | 6 |
|
6 | 7 | from nibabel.testing import assert_arrays_equal |
7 | 8 | from nibabel.testing import clear_and_catch_warnings |
@@ -130,6 +131,11 @@ def assert_tractogram_equal(t1, t2): |
130 | 131 | t2.data_per_streamline, t2.data_per_point) |
131 | 132 |
|
132 | 133 |
|
| 134 | +def extender(a, b): |
| 135 | + a.extend(b) |
| 136 | + return a |
| 137 | + |
| 138 | + |
133 | 139 | class TestPerArrayDict(unittest.TestCase): |
134 | 140 |
|
135 | 141 | def test_per_array_dict_creation(self): |
@@ -184,18 +190,40 @@ def test_getitem(self): |
184 | 190 | def test_extend(self): |
185 | 191 | sdict = PerArrayDict(len(DATA['tractogram']), |
186 | 192 | DATA['data_per_streamline']) |
| 193 | + |
| 194 | + new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']), |
| 195 | + 'mean_torsion': 3 * np.array(DATA['mean_torsion']), |
| 196 | + 'mean_colors': 4 * np.array(DATA['mean_colors'])} |
187 | 197 | sdict2 = PerArrayDict(len(DATA['tractogram']), |
188 | | - DATA['data_per_streamline']) |
| 198 | + new_data) |
189 | 199 |
|
190 | 200 | sdict.extend(sdict2) |
191 | 201 | assert_equal(len(sdict), len(sdict2)) |
192 | | - for k, v in DATA['tractogram'].data_per_streamline.items(): |
193 | | - assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], v) |
194 | | - assert_arrays_equal(sdict[k][len(DATA['tractogram']):], v) |
| 202 | + for k in DATA['tractogram'].data_per_streamline: |
| 203 | + assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], |
| 204 | + DATA['tractogram'].data_per_streamline[k]) |
| 205 | + assert_arrays_equal(sdict[k][len(DATA['tractogram']):], |
| 206 | + new_data[k]) |
195 | 207 |
|
196 | 208 | # Test incompatible PerArrayDicts. |
| 209 | + # Other dict is missing entries. |
197 | 210 | assert_raises(ValueError, sdict.extend, PerArrayDict()) |
198 | 211 |
|
| 212 | + # Other dict has more entries. |
| 213 | + new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']), |
| 214 | + 'mean_torsion': 3 * np.array(DATA['mean_torsion']), |
| 215 | + 'mean_colors': 4 * np.array(DATA['mean_colors']), |
| 216 | + 'other': 5 * np.array(DATA['mean_colors'])} |
| 217 | + sdict2 = PerArrayDict(len(DATA['tractogram']), new_data) |
| 218 | + assert_raises(ValueError, sdict.extend, sdict2) |
| 219 | + |
| 220 | + # Other dict has the right number of entries but wrong shape. |
| 221 | + new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']), |
| 222 | + 'mean_torsion': 3 * np.array(DATA['mean_torsion']), |
| 223 | + 'other': 4 * np.array(DATA['mean_torsion'])} |
| 224 | + sdict2 = PerArrayDict(len(DATA['tractogram']), new_data) |
| 225 | + assert_raises(ValueError, sdict.extend, sdict2) |
| 226 | + |
199 | 227 |
|
200 | 228 | class TestPerArraySequenceDict(unittest.TestCase): |
201 | 229 |
|
@@ -251,17 +279,36 @@ def test_getitem(self): |
251 | 279 | def test_extend(self): |
252 | 280 | total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows |
253 | 281 | sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point']) |
254 | | - sdict2 = PerArraySequenceDict(total_nb_rows, DATA['data_per_point']) |
| 282 | + |
| 283 | + new_data = {'colors': 2 * np.array(DATA['colors']), |
| 284 | + 'fa': 3 * np.array(DATA['fa'])} |
| 285 | + sdict2 = PerArraySequenceDict(total_nb_rows, new_data) |
255 | 286 |
|
256 | 287 | sdict.extend(sdict2) |
257 | 288 | assert_equal(len(sdict), len(sdict2)) |
258 | | - for k, v in DATA['tractogram'].data_per_point.items(): |
259 | | - assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], v) |
260 | | - assert_arrays_equal(sdict[k][len(DATA['tractogram']):], v) |
| 289 | + for k in DATA['tractogram'].data_per_point: |
| 290 | + assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], |
| 291 | + DATA['tractogram'].data_per_point[k]) |
| 292 | + assert_arrays_equal(sdict[k][len(DATA['tractogram']):], |
| 293 | + new_data[k]) |
261 | 294 |
|
262 | 295 | # Test incompatible PerArrayDicts. |
| 296 | + # Other dict is missing entries. |
263 | 297 | assert_raises(ValueError, sdict.extend, PerArraySequenceDict()) |
264 | 298 |
|
| 299 | + # Other dict has more entries. |
| 300 | + new_data = {'colors': 2 * np.array(DATA['colors']), |
| 301 | + 'fa': 3 * np.array(DATA['fa']), |
| 302 | + 'other': 4 * np.array(DATA['fa'])} |
| 303 | + sdict2 = PerArraySequenceDict(total_nb_rows, new_data) |
| 304 | + assert_raises(ValueError, sdict.extend, sdict2) |
| 305 | + |
| 306 | + # Other dict has the right number of entries but wrong shape. |
| 307 | + new_data = {'colors': 2 * np.array(DATA['colors']), |
| 308 | + 'other': 2 * np.array(DATA['colors']),} |
| 309 | + sdict2 = PerArraySequenceDict(total_nb_rows, new_data) |
| 310 | + assert_raises(ValueError, sdict.extend, sdict2) |
| 311 | + |
265 | 312 |
|
266 | 313 | class TestLazyDict(unittest.TestCase): |
267 | 314 |
|
@@ -603,19 +650,12 @@ def test_tractogram_extend(self): |
603 | 650 | # Load tractogram that contains some metadata. |
604 | 651 | t = DATA['tractogram'].copy() |
605 | 652 |
|
606 | | - # Double the tractogram. |
607 | | - new_t = t + t |
608 | | - assert_equal(len(new_t), 2*len(t)) |
609 | | - assert_tractogram_equal(new_t[:len(t)], DATA['tractogram']) |
610 | | - assert_tractogram_equal(new_t[len(t):], DATA['tractogram']) |
611 | | - |
612 | | - # Double the tractogram inplace. |
613 | | - new_t = DATA['tractogram'].copy() |
614 | | - new_t += t |
615 | | - assert_equal(len(new_t), 2*len(t)) |
616 | | - assert_tractogram_equal(new_t[:len(t)], DATA['tractogram']) |
617 | | - assert_tractogram_equal(new_t[len(t):], DATA['tractogram']) |
618 | | - |
| 653 | + for op, in_place in ((operator.add, False), (operator.iadd, True), (extender, True)): |
| 654 | + first_arg = t.copy() |
| 655 | + new_t = op(first_arg, t) |
| 656 | + assert_equal(new_t is first_arg, in_place) |
| 657 | + assert_tractogram_equal(new_t[:len(t)], DATA['tractogram']) |
| 658 | + assert_tractogram_equal(new_t[len(t):], DATA['tractogram']) |
619 | 659 |
|
620 | 660 | class TestLazyTractogram(unittest.TestCase): |
621 | 661 |
|
@@ -690,7 +730,9 @@ def test_lazy_tractogram_getitem(self): |
690 | 730 | def test_lazy_tractogram_extend(self): |
691 | 731 | t = DATA['lazy_tractogram'].copy() |
692 | 732 | new_t = DATA['lazy_tractogram'].copy() |
693 | | - assert_raises(NotImplementedError, new_t.__iadd__, t) |
| 733 | + |
| 734 | + for op in (operator.add, operator.iadd, extender): |
| 735 | + assert_raises(NotImplementedError, op, new_t, t) |
694 | 736 |
|
695 | 737 | def test_lazy_tractogram_len(self): |
696 | 738 | modules = [module_tractogram] # Modules for which to catch warnings. |
|
0 commit comments