33import numpy as np
44import warnings
55import operator
6+ from collections import defaultdict
67
78from nibabel .testing import assert_arrays_equal
89from nibabel .testing import clear_and_catch_warnings
1718DATA = {}
1819
1920
21+ def make_fake_streamline (nb_points , data_per_point_shapes = {},
22+ data_for_streamline_shapes = {}, rng = None ):
23+ """ Make a single streamline according to provided requirements. """
24+ if rng is None :
25+ rng = np .random .RandomState ()
26+
27+ streamline = rng .randn (nb_points , 3 ).astype ("f4" )
28+
29+ data_per_point = {}
30+ for k , shape in data_per_point_shapes .items ():
31+ data_per_point [k ] = rng .randn (* ((nb_points ,) + shape )).astype ("f4" )
32+
33+ data_for_streamline = {}
34+ for k , shape in data_for_streamline .items ():
35+ data_for_streamline [k ] = rng .randn (* shape ).astype ("f4" )
36+
37+ return streamline , data_per_point , data_for_streamline
38+
39+
40+ def make_fake_tractogram (list_nb_points , data_per_point_shapes = {},
41+ data_for_streamline_shapes = {}, rng = None ):
42+ """ Make multiple streamlines according to provided requirements. """
43+ all_streamlines = []
44+ all_data_per_point = defaultdict (lambda : [])
45+ all_data_per_streamline = defaultdict (lambda : [])
46+ for nb_points in list_nb_points :
47+ data = make_fake_streamline (nb_points , data_per_point_shapes ,
48+ data_for_streamline_shapes , rng )
49+ streamline , data_per_point , data_for_streamline = data
50+
51+ all_streamlines .append (streamline )
52+ for k , v in data_per_point .items ():
53+ all_data_per_point [k ].append (v )
54+
55+ for k , v in data_for_streamline .items ():
56+ all_data_per_streamline [k ].append (v )
57+
58+ return all_streamlines , all_data_per_point , all_data_per_streamline
59+
60+
61+ def make_dummy_streamline (nb_points ):
62+ """ Make the streamlines that have been used to create test data files."""
63+ if nb_points == 1 :
64+ streamline = np .arange (1 * 3 , dtype = "f4" ).reshape ((1 , 3 ))
65+ data_per_point = {"fa" : np .array ([[0.2 ]], dtype = "f4" ),
66+ "colors" : np .array ([(1 , 0 , 0 )]* 1 , dtype = "f4" )}
67+ data_for_streamline = {"mean_curvature" : np .array ([1.11 ], dtype = "f4" ),
68+ "mean_torsion" : np .array ([1.22 ], dtype = "f4" ),
69+ "mean_colors" : np .array ([1 , 0 , 0 ], dtype = "f4" )}
70+
71+ elif nb_points == 2 :
72+ streamline = np .arange (2 * 3 , dtype = "f4" ).reshape ((2 , 3 ))
73+ data_per_point = {"fa" : np .array ([[0.3 ],
74+ [0.4 ]], dtype = "f4" ),
75+ "colors" : np .array ([(0 , 1 , 0 )]* 2 , dtype = "f4" )}
76+ data_for_streamline = {"mean_curvature" : np .array ([2.11 ], dtype = "f4" ),
77+ "mean_torsion" : np .array ([2.22 ], dtype = "f4" ),
78+ "mean_colors" : np .array ([0 , 1 , 0 ], dtype = "f4" )}
79+
80+ elif nb_points == 5 :
81+ streamline = np .arange (5 * 3 , dtype = "f4" ).reshape ((5 , 3 ))
82+ data_per_point = {"fa" : np .array ([[0.5 ],
83+ [0.6 ],
84+ [0.6 ],
85+ [0.7 ],
86+ [0.8 ]], dtype = "f4" ),
87+ "colors" : np .array ([(0 , 0 , 1 )]* 5 , dtype = "f4" )}
88+ data_for_streamline = {"mean_curvature" : np .array ([3.11 ], dtype = "f4" ),
89+ "mean_torsion" : np .array ([3.22 ], dtype = "f4" ),
90+ "mean_colors" : np .array ([0 , 0 , 1 ], dtype = "f4" )}
91+
92+ return streamline , data_per_point , data_for_streamline
93+
94+
2095def setup ():
2196 global DATA
2297 DATA ['rng' ] = np .random .RandomState (1234 )
23- DATA ['streamlines' ] = [np .arange (1 * 3 , dtype = "f4" ).reshape ((1 , 3 )),
24- np .arange (2 * 3 , dtype = "f4" ).reshape ((2 , 3 )),
25- np .arange (5 * 3 , dtype = "f4" ).reshape ((5 , 3 ))]
26-
27- DATA ['fa' ] = [np .array ([[0.2 ]], dtype = "f4" ),
28- np .array ([[0.3 ],
29- [0.4 ]], dtype = "f4" ),
30- np .array ([[0.5 ],
31- [0.6 ],
32- [0.6 ],
33- [0.7 ],
34- [0.8 ]], dtype = "f4" )]
35-
36- DATA ['colors' ] = [np .array ([(1 , 0 , 0 )]* 1 , dtype = "f4" ),
37- np .array ([(0 , 1 , 0 )]* 2 , dtype = "f4" ),
38- np .array ([(0 , 0 , 1 )]* 5 , dtype = "f4" )]
39-
40- DATA ['mean_curvature' ] = [np .array ([1.11 ], dtype = "f4" ),
41- np .array ([2.11 ], dtype = "f4" ),
42- np .array ([3.11 ], dtype = "f4" )]
43-
44- DATA ['mean_torsion' ] = [np .array ([1.22 ], dtype = "f4" ),
45- np .array ([2.22 ], dtype = "f4" ),
46- np .array ([3.22 ], dtype = "f4" )]
47-
48- DATA ['mean_colors' ] = [np .array ([1 , 0 , 0 ], dtype = "f4" ),
49- np .array ([0 , 1 , 0 ], dtype = "f4" ),
50- np .array ([0 , 0 , 1 ], dtype = "f4" )]
98+
99+ DATA ['streamlines' ] = []
100+ DATA ['fa' ] = []
101+ DATA ['colors' ] = []
102+ DATA ['mean_curvature' ] = []
103+ DATA ['mean_torsion' ] = []
104+ DATA ['mean_colors' ] = []
105+ for nb_points in [1 , 2 , 5 ]:
106+ data = make_dummy_streamline (nb_points )
107+ streamline , data_per_point , data_for_streamline = data
108+ DATA ['streamlines' ].append (streamline )
109+ DATA ['fa' ].append (data_per_point ['fa' ])
110+ DATA ['colors' ].append (data_per_point ['colors' ])
111+ DATA ['mean_curvature' ].append (data_for_streamline ['mean_curvature' ])
112+ DATA ['mean_torsion' ].append (data_for_streamline ['mean_torsion' ])
113+ DATA ['mean_colors' ].append (data_for_streamline ['mean_colors' ])
51114
52115 DATA ['data_per_point' ] = {'colors' : DATA ['colors' ],
53116 'fa' : DATA ['fa' ]}
@@ -280,9 +343,14 @@ def test_extend(self):
280343 total_nb_rows = DATA ['tractogram' ].streamlines .total_nb_rows
281344 sdict = PerArraySequenceDict (total_nb_rows , DATA ['data_per_point' ])
282345
283- new_data = {'colors' : 2 * np .array (DATA ['colors' ]),
284- 'fa' : 3 * np .array (DATA ['fa' ])}
285- sdict2 = PerArraySequenceDict (total_nb_rows , new_data )
346+ # Test compatible PerArrayDicts.
347+ list_nb_points = [2 , 7 , 4 ]
348+ data_per_point_shapes = {"colors" : DATA ['colors' ][0 ].shape [1 :],
349+ "fa" : DATA ['fa' ][0 ].shape [1 :]}
350+ _ , new_data , _ = make_fake_tractogram (list_nb_points ,
351+ data_per_point_shapes ,
352+ rng = DATA ['rng' ])
353+ sdict2 = PerArraySequenceDict (np .sum (list_nb_points ), new_data )
286354
287355 sdict .extend (sdict2 )
288356 assert_equal (len (sdict ), len (sdict2 ))
@@ -297,16 +365,22 @@ def test_extend(self):
297365 assert_raises (ValueError , sdict .extend , PerArraySequenceDict ())
298366
299367 # Other dict has more entries.
300- new_data = {'colors' : 2 * np .array (DATA ['colors' ]),
301- 'fa' : 3 * np .array (DATA ['fa' ]),
302- 'other' : 4 * np .array (DATA ['fa' ])}
303- sdict2 = PerArraySequenceDict (total_nb_rows , new_data )
368+ data_per_point_shapes = {"colors" : DATA ['colors' ][0 ].shape [1 :],
369+ "fa" : DATA ['fa' ][0 ].shape [1 :],
370+ "other" : (7 ,)}
371+ _ , new_data , _ = make_fake_tractogram (list_nb_points ,
372+ data_per_point_shapes ,
373+ rng = DATA ['rng' ])
374+ sdict2 = PerArraySequenceDict (np .sum (list_nb_points ), new_data )
304375 assert_raises (ValueError , sdict .extend , sdict2 )
305376
306377 # Other dict has the right number of entries but wrong shape.
307- new_data = {'colors' : 2 * np .array (DATA ['colors' ]),
308- 'other' : 2 * np .array (DATA ['colors' ]),}
309- sdict2 = PerArraySequenceDict (total_nb_rows , new_data )
378+ data_per_point_shapes = {"colors" : DATA ['colors' ][0 ].shape [1 :],
379+ "fa" : DATA ['fa' ][0 ].shape [1 :] + (3 ,)}
380+ _ , new_data , _ = make_fake_tractogram (list_nb_points ,
381+ data_per_point_shapes ,
382+ rng = DATA ['rng' ])
383+ sdict2 = PerArraySequenceDict (np .sum (list_nb_points ), new_data )
310384 assert_raises (ValueError , sdict .extend , sdict2 )
311385
312386
@@ -650,13 +724,15 @@ def test_tractogram_extend(self):
650724 # Load tractogram that contains some metadata.
651725 t = DATA ['tractogram' ].copy ()
652726
653- for op , in_place in ((operator .add , False ), (operator .iadd , True ), (extender , True )):
727+ for op , in_place in ((operator .add , False ), (operator .iadd , True ),
728+ (extender , True )):
654729 first_arg = t .copy ()
655730 new_t = op (first_arg , t )
656731 assert_equal (new_t is first_arg , in_place )
657732 assert_tractogram_equal (new_t [:len (t )], DATA ['tractogram' ])
658733 assert_tractogram_equal (new_t [len (t ):], DATA ['tractogram' ])
659734
735+
660736class TestLazyTractogram (unittest .TestCase ):
661737
662738 def test_lazy_tractogram_creation (self ):
@@ -670,7 +746,8 @@ def test_lazy_tractogram_creation(self):
670746 'mean_colors' : (x for x in DATA ['mean_colors' ])}
671747
672748 # Creating LazyTractogram with generators is not allowed as
673- # generators get exhausted and are not reusable unlike generator function.
749+ # generators get exhausted and are not reusable unlike generator
750+ # function.
674751 assert_raises (TypeError , LazyTractogram , streamlines )
675752 assert_raises (TypeError , LazyTractogram ,
676753 data_per_streamline = data_per_streamline )
@@ -701,7 +778,8 @@ def test_lazy_tractogram_from_data_func(self):
701778 tractogram = LazyTractogram .from_data_func (_empty_data_gen )
702779 check_tractogram (tractogram )
703780
704- # Create `LazyTractogram` from a generator function yielding TractogramItem.
781+ # Create `LazyTractogram` from a generator function yielding
782+ # TractogramItem.
705783 data = [DATA ['streamlines' ], DATA ['fa' ], DATA ['colors' ],
706784 DATA ['mean_curvature' ], DATA ['mean_torsion' ],
707785 DATA ['mean_colors' ]]
@@ -839,8 +917,8 @@ def test_lazy_tractogram_copy(self):
839917 # Check we copied the data and not simply created new references.
840918 assert_true (tractogram is not DATA ['lazy_tractogram' ])
841919
842- # When copying LazyTractogram, the generator function yielding streamlines
843- # should stay the same.
920+ # When copying LazyTractogram, the generator function yielding
921+ # streamlines should stay the same.
844922 assert_true (tractogram ._streamlines
845923 is DATA ['lazy_tractogram' ]._streamlines )
846924
0 commit comments