1111
1212from .. import tractogram as module_tractogram
1313from ..tractogram import TractogramItem , Tractogram , LazyTractogram
14- from ..tractogram import DataPerStreamlineDict , DataPerPointDict , LazyDict
14+ from ..tractogram import PerArrayDict , PerArraySequenceDict , LazyDict
1515
1616DATA = {}
1717
@@ -126,64 +126,35 @@ def assert_tractogram_equal(t1, t2):
126126 t2 .data_per_streamline , t2 .data_per_point )
127127
128128
129- class TestTractogramItem (unittest .TestCase ):
129+ class TestPerArrayDict (unittest .TestCase ):
130130
131- def test_creating_tractogram_item (self ):
132- rng = np .random .RandomState (42 )
133- streamline = rng .rand (rng .randint (10 , 50 ), 3 )
134- colors = rng .rand (len (streamline ), 3 )
135- mean_curvature = 1.11
136- mean_color = np .array ([0 , 1 , 0 ], dtype = "f4" )
137-
138- data_for_streamline = {"mean_curvature" : mean_curvature ,
139- "mean_color" : mean_color }
140-
141- data_for_points = {"colors" : colors }
142-
143- # Create a tractogram item with a streamline, data.
144- t = TractogramItem (streamline , data_for_streamline , data_for_points )
145- assert_equal (len (t ), len (streamline ))
146- assert_array_equal (t .streamline , streamline )
147- assert_array_equal (list (t ), streamline )
148- assert_array_equal (t .data_for_streamline ['mean_curvature' ],
149- mean_curvature )
150- assert_array_equal (t .data_for_streamline ['mean_color' ],
151- mean_color )
152- assert_array_equal (t .data_for_points ['colors' ],
153- colors )
154-
155-
156- class TestTractogramDataDict (unittest .TestCase ):
157-
158- def test_datadict_creation (self ):
159- # Create a DataPerStreamlineDict object using another
160- # DataPerStreamlineDict object.
131+ def test_per_array_dict_creation (self ):
132+ # Create a PerArrayDict object using another
133+ # PerArrayDict object.
134+ nb_streamlines = len (DATA ['tractogram' ])
161135 data_per_streamline = DATA ['tractogram' ].data_per_streamline
162- data_dict = DataPerStreamlineDict (DATA ['tractogram' ],
163- data_per_streamline )
136+ data_dict = PerArrayDict (nb_streamlines , data_per_streamline )
164137 assert_equal (data_dict .keys (), data_per_streamline .keys ())
165138 for k in data_dict .keys ():
166139 assert_array_equal (data_dict [k ], data_per_streamline [k ])
167140
168141 del data_dict ['mean_curvature' ]
169142 assert_equal (len (data_dict ),
170- len (DATA [ 'tractogram' ]. data_per_streamline )- 1 )
143+ len (data_per_streamline )- 1 )
171144
172- # Create a DataPerStreamlineDict object using an existing dict object.
173- data_per_streamline = DATA ['tractogram' ].data_per_streamline .store
174- data_dict = DataPerStreamlineDict (DATA ['tractogram' ],
175- data_per_streamline )
145+ # Create a PerArrayDict object using an existing dict object.
146+ data_per_streamline = DATA ['data_per_streamline' ]
147+ data_dict = PerArrayDict (nb_streamlines , data_per_streamline )
176148 assert_equal (data_dict .keys (), data_per_streamline .keys ())
177149 for k in data_dict .keys ():
178150 assert_array_equal (data_dict [k ], data_per_streamline [k ])
179151
180152 del data_dict ['mean_curvature' ]
181153 assert_equal (len (data_dict ), len (data_per_streamline )- 1 )
182154
183- # Create a DataPerStreamlineDict object using keyword arguments.
184- data_per_streamline = DATA ['tractogram' ].data_per_streamline .store
185- data_dict = DataPerStreamlineDict (DATA ['tractogram' ],
186- ** data_per_streamline )
155+ # Create a PerArrayDict object using keyword arguments.
156+ data_per_streamline = DATA ['data_per_streamline' ]
157+ data_dict = PerArrayDict (nb_streamlines , ** data_per_streamline )
187158 assert_equal (data_dict .keys (), data_per_streamline .keys ())
188159 for k in data_dict .keys ():
189160 assert_array_equal (data_dict [k ], data_per_streamline [k ])
@@ -192,21 +163,77 @@ def test_datadict_creation(self):
192163 assert_equal (len (data_dict ), len (data_per_streamline )- 1 )
193164
194165 def test_getitem (self ):
195- data_dict = DataPerPointDict (DATA ['tractogram' ],
196- DATA ['data_per_point' ])
166+ sdict = PerArrayDict (len (DATA ['tractogram' ]),
167+ DATA ['data_per_streamline' ])
168+
169+ assert_raises (KeyError , sdict .__getitem__ , 'invalid' )
170+
171+ # Test slicing and advanced indexing.
172+ for k , v in DATA ['tractogram' ].data_per_streamline .items ():
173+ assert_true (k in sdict )
174+ assert_arrays_equal (sdict [k ], v )
175+ assert_arrays_equal (sdict [::2 ][k ], v [::2 ])
176+ assert_arrays_equal (sdict [::- 1 ][k ], v [::- 1 ])
177+ assert_arrays_equal (sdict [- 1 ][k ], v [- 1 ])
178+ assert_arrays_equal (sdict [[0 , - 1 ]][k ], v [[0 , - 1 ]])
179+
180+
181+ class TestPerArraySequenceDict (unittest .TestCase ):
182+
183+ def test_per_array_sequence_dict_creation (self ):
184+ # Create a PerArraySequenceDict object using another
185+ # PerArraySequenceDict object.
186+ nb_elements = DATA ['tractogram' ].streamlines .nb_elements
187+ data_per_point = DATA ['tractogram' ].data_per_point
188+ data_dict = PerArraySequenceDict (nb_elements , data_per_point )
189+ assert_equal (data_dict .keys (), data_per_point .keys ())
190+ for k in data_dict .keys ():
191+ assert_arrays_equal (data_dict [k ], data_per_point [k ])
192+
193+ del data_dict ['fa' ]
194+ assert_equal (len (data_dict ),
195+ len (data_per_point )- 1 )
196+
197+ # Create a PerArraySequenceDict object using an existing dict object.
198+ data_per_point = DATA ['data_per_point' ]
199+ data_dict = PerArraySequenceDict (nb_elements , data_per_point )
200+ assert_equal (data_dict .keys (), data_per_point .keys ())
201+ for k in data_dict .keys ():
202+ assert_arrays_equal (data_dict [k ], data_per_point [k ])
203+
204+ del data_dict ['fa' ]
205+ assert_equal (len (data_dict ), len (data_per_point )- 1 )
206+
207+ # Create a PerArraySequenceDict object using keyword arguments.
208+ data_per_point = DATA ['data_per_point' ]
209+ data_dict = PerArraySequenceDict (nb_elements , ** data_per_point )
210+ assert_equal (data_dict .keys (), data_per_point .keys ())
211+ for k in data_dict .keys ():
212+ assert_arrays_equal (data_dict [k ], data_per_point [k ])
213+
214+ del data_dict ['fa' ]
215+ assert_equal (len (data_dict ), len (data_per_point )- 1 )
216+
217+ def test_getitem (self ):
218+ nb_elements = DATA ['tractogram' ].streamlines .nb_elements
219+ sdict = PerArraySequenceDict (nb_elements , DATA ['data_per_point' ])
197220
198- assert_true ('fa' in data_dict )
199- assert_arrays_equal (data_dict ['fa' ], DATA ['fa' ])
200- assert_arrays_equal (data_dict [::2 ]['fa' ], DATA ['fa' ][::2 ])
201- assert_arrays_equal (data_dict [::- 1 ]['fa' ], DATA ['fa' ][::- 1 ])
202- assert_arrays_equal (data_dict [- 1 ]['fa' ], DATA ['fa' ][- 1 ])
203- assert_raises (KeyError , data_dict .__getitem__ , 'invalid' )
221+ assert_raises (KeyError , sdict .__getitem__ , 'invalid' )
204222
223+ # Test slicing and advanced indexing.
224+ for k , v in DATA ['tractogram' ].data_per_point .items ():
225+ assert_true (k in sdict )
226+ assert_arrays_equal (sdict [k ], v )
227+ assert_arrays_equal (sdict [::2 ][k ], v [::2 ])
228+ assert_arrays_equal (sdict [::- 1 ][k ], v [::- 1 ])
229+ assert_arrays_equal (sdict [- 1 ][k ], v [- 1 ])
230+ assert_arrays_equal (sdict [[0 , - 1 ]][k ], v [[0 , - 1 ]])
205231
206- class TestTractogramLazyDict (unittest .TestCase ):
232+
233+ class TestLazyDict (unittest .TestCase ):
207234
208235 def test_lazydict_creation (self ):
209- data_dict = LazyDict (None , DATA ['data_per_streamline_func' ])
236+ data_dict = LazyDict (DATA ['data_per_streamline_func' ])
210237 assert_equal (data_dict .keys (), DATA ['data_per_streamline_func' ].keys ())
211238 for k in data_dict .keys ():
212239 assert_array_equal (list (data_dict [k ]),
@@ -216,6 +243,33 @@ def test_lazydict_creation(self):
216243 len (DATA ['data_per_streamline_func' ]))
217244
218245
246+ class TestTractogramItem (unittest .TestCase ):
247+
248+ def test_creating_tractogram_item (self ):
249+ rng = np .random .RandomState (42 )
250+ streamline = rng .rand (rng .randint (10 , 50 ), 3 )
251+ colors = rng .rand (len (streamline ), 3 )
252+ mean_curvature = 1.11
253+ mean_color = np .array ([0 , 1 , 0 ], dtype = "f4" )
254+
255+ data_for_streamline = {"mean_curvature" : mean_curvature ,
256+ "mean_color" : mean_color }
257+
258+ data_for_points = {"colors" : colors }
259+
260+ # Create a tractogram item with a streamline, data.
261+ t = TractogramItem (streamline , data_for_streamline , data_for_points )
262+ assert_equal (len (t ), len (streamline ))
263+ assert_array_equal (t .streamline , streamline )
264+ assert_array_equal (list (t ), streamline )
265+ assert_array_equal (t .data_for_streamline ['mean_curvature' ],
266+ mean_curvature )
267+ assert_array_equal (t .data_for_streamline ['mean_color' ],
268+ mean_color )
269+ assert_array_equal (t .data_for_points ['colors' ],
270+ colors )
271+
272+
219273class TestTractogram (unittest .TestCase ):
220274
221275 def test_tractogram_creation (self ):
0 commit comments