Skip to content

Commit 1d01c0a

Browse files
committed
Refactored DataDict following @matthew-brett's advices.
1 parent 233af1a commit 1d01c0a

File tree

3 files changed

+191
-87
lines changed

3 files changed

+191
-87
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ def common_shape(self):
9292
""" Matching shape of the elements in this array sequence. """
9393
return self._data.shape[1:]
9494

95+
@property
96+
def nb_elements(self):
97+
""" Total number of elements in this array sequence. """
98+
return self._data.shape[0]
99+
95100
def append(self, element):
96101
""" Appends `element` to this array sequence.
97102

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 107 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from .. import tractogram as module_tractogram
1313
from ..tractogram import TractogramItem, Tractogram, LazyTractogram
14-
from ..tractogram import DataPerStreamlineDict, DataPerPointDict, LazyDict
14+
from ..tractogram import PerArrayDict, PerArraySequenceDict, LazyDict
1515

1616
DATA = {}
1717

@@ -126,64 +126,35 @@ def assert_tractogram_equal(t1, t2):
126126
t2.data_per_streamline, t2.data_per_point)
127127

128128

129-
class TestTractogramItem(unittest.TestCase):
129+
class TestPerArrayDict(unittest.TestCase):
130130

131-
def test_creating_tractogram_item(self):
132-
rng = np.random.RandomState(42)
133-
streamline = rng.rand(rng.randint(10, 50), 3)
134-
colors = rng.rand(len(streamline), 3)
135-
mean_curvature = 1.11
136-
mean_color = np.array([0, 1, 0], dtype="f4")
137-
138-
data_for_streamline = {"mean_curvature": mean_curvature,
139-
"mean_color": mean_color}
140-
141-
data_for_points = {"colors": colors}
142-
143-
# Create a tractogram item with a streamline, data.
144-
t = TractogramItem(streamline, data_for_streamline, data_for_points)
145-
assert_equal(len(t), len(streamline))
146-
assert_array_equal(t.streamline, streamline)
147-
assert_array_equal(list(t), streamline)
148-
assert_array_equal(t.data_for_streamline['mean_curvature'],
149-
mean_curvature)
150-
assert_array_equal(t.data_for_streamline['mean_color'],
151-
mean_color)
152-
assert_array_equal(t.data_for_points['colors'],
153-
colors)
154-
155-
156-
class TestTractogramDataDict(unittest.TestCase):
157-
158-
def test_datadict_creation(self):
159-
# Create a DataPerStreamlineDict object using another
160-
# DataPerStreamlineDict object.
131+
def test_per_array_dict_creation(self):
132+
# Create a PerArrayDict object using another
133+
# PerArrayDict object.
134+
nb_streamlines = len(DATA['tractogram'])
161135
data_per_streamline = DATA['tractogram'].data_per_streamline
162-
data_dict = DataPerStreamlineDict(DATA['tractogram'],
163-
data_per_streamline)
136+
data_dict = PerArrayDict(nb_streamlines, data_per_streamline)
164137
assert_equal(data_dict.keys(), data_per_streamline.keys())
165138
for k in data_dict.keys():
166139
assert_array_equal(data_dict[k], data_per_streamline[k])
167140

168141
del data_dict['mean_curvature']
169142
assert_equal(len(data_dict),
170-
len(DATA['tractogram'].data_per_streamline)-1)
143+
len(data_per_streamline)-1)
171144

172-
# Create a DataPerStreamlineDict object using an existing dict object.
173-
data_per_streamline = DATA['tractogram'].data_per_streamline.store
174-
data_dict = DataPerStreamlineDict(DATA['tractogram'],
175-
data_per_streamline)
145+
# Create a PerArrayDict object using an existing dict object.
146+
data_per_streamline = DATA['data_per_streamline']
147+
data_dict = PerArrayDict(nb_streamlines, data_per_streamline)
176148
assert_equal(data_dict.keys(), data_per_streamline.keys())
177149
for k in data_dict.keys():
178150
assert_array_equal(data_dict[k], data_per_streamline[k])
179151

180152
del data_dict['mean_curvature']
181153
assert_equal(len(data_dict), len(data_per_streamline)-1)
182154

183-
# Create a DataPerStreamlineDict object using keyword arguments.
184-
data_per_streamline = DATA['tractogram'].data_per_streamline.store
185-
data_dict = DataPerStreamlineDict(DATA['tractogram'],
186-
**data_per_streamline)
155+
# Create a PerArrayDict object using keyword arguments.
156+
data_per_streamline = DATA['data_per_streamline']
157+
data_dict = PerArrayDict(nb_streamlines, **data_per_streamline)
187158
assert_equal(data_dict.keys(), data_per_streamline.keys())
188159
for k in data_dict.keys():
189160
assert_array_equal(data_dict[k], data_per_streamline[k])
@@ -192,21 +163,77 @@ def test_datadict_creation(self):
192163
assert_equal(len(data_dict), len(data_per_streamline)-1)
193164

194165
def test_getitem(self):
195-
data_dict = DataPerPointDict(DATA['tractogram'],
196-
DATA['data_per_point'])
166+
sdict = PerArrayDict(len(DATA['tractogram']),
167+
DATA['data_per_streamline'])
168+
169+
assert_raises(KeyError, sdict.__getitem__, 'invalid')
170+
171+
# Test slicing and advanced indexing.
172+
for k, v in DATA['tractogram'].data_per_streamline.items():
173+
assert_true(k in sdict)
174+
assert_arrays_equal(sdict[k], v)
175+
assert_arrays_equal(sdict[::2][k], v[::2])
176+
assert_arrays_equal(sdict[::-1][k], v[::-1])
177+
assert_arrays_equal(sdict[-1][k], v[-1])
178+
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]])
179+
180+
181+
class TestPerArraySequenceDict(unittest.TestCase):
182+
183+
def test_per_array_sequence_dict_creation(self):
184+
# Create a PerArraySequenceDict object using another
185+
# PerArraySequenceDict object.
186+
nb_elements = DATA['tractogram'].streamlines.nb_elements
187+
data_per_point = DATA['tractogram'].data_per_point
188+
data_dict = PerArraySequenceDict(nb_elements, data_per_point)
189+
assert_equal(data_dict.keys(), data_per_point.keys())
190+
for k in data_dict.keys():
191+
assert_arrays_equal(data_dict[k], data_per_point[k])
192+
193+
del data_dict['fa']
194+
assert_equal(len(data_dict),
195+
len(data_per_point)-1)
196+
197+
# Create a PerArraySequenceDict object using an existing dict object.
198+
data_per_point = DATA['data_per_point']
199+
data_dict = PerArraySequenceDict(nb_elements, data_per_point)
200+
assert_equal(data_dict.keys(), data_per_point.keys())
201+
for k in data_dict.keys():
202+
assert_arrays_equal(data_dict[k], data_per_point[k])
203+
204+
del data_dict['fa']
205+
assert_equal(len(data_dict), len(data_per_point)-1)
206+
207+
# Create a PerArraySequenceDict object using keyword arguments.
208+
data_per_point = DATA['data_per_point']
209+
data_dict = PerArraySequenceDict(nb_elements, **data_per_point)
210+
assert_equal(data_dict.keys(), data_per_point.keys())
211+
for k in data_dict.keys():
212+
assert_arrays_equal(data_dict[k], data_per_point[k])
213+
214+
del data_dict['fa']
215+
assert_equal(len(data_dict), len(data_per_point)-1)
216+
217+
def test_getitem(self):
218+
nb_elements = DATA['tractogram'].streamlines.nb_elements
219+
sdict = PerArraySequenceDict(nb_elements, DATA['data_per_point'])
197220

198-
assert_true('fa' in data_dict)
199-
assert_arrays_equal(data_dict['fa'], DATA['fa'])
200-
assert_arrays_equal(data_dict[::2]['fa'], DATA['fa'][::2])
201-
assert_arrays_equal(data_dict[::-1]['fa'], DATA['fa'][::-1])
202-
assert_arrays_equal(data_dict[-1]['fa'], DATA['fa'][-1])
203-
assert_raises(KeyError, data_dict.__getitem__, 'invalid')
221+
assert_raises(KeyError, sdict.__getitem__, 'invalid')
204222

223+
# Test slicing and advanced indexing.
224+
for k, v in DATA['tractogram'].data_per_point.items():
225+
assert_true(k in sdict)
226+
assert_arrays_equal(sdict[k], v)
227+
assert_arrays_equal(sdict[::2][k], v[::2])
228+
assert_arrays_equal(sdict[::-1][k], v[::-1])
229+
assert_arrays_equal(sdict[-1][k], v[-1])
230+
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]])
205231

206-
class TestTractogramLazyDict(unittest.TestCase):
232+
233+
class TestLazyDict(unittest.TestCase):
207234

208235
def test_lazydict_creation(self):
209-
data_dict = LazyDict(None, DATA['data_per_streamline_func'])
236+
data_dict = LazyDict(DATA['data_per_streamline_func'])
210237
assert_equal(data_dict.keys(), DATA['data_per_streamline_func'].keys())
211238
for k in data_dict.keys():
212239
assert_array_equal(list(data_dict[k]),
@@ -216,6 +243,33 @@ def test_lazydict_creation(self):
216243
len(DATA['data_per_streamline_func']))
217244

218245

246+
class TestTractogramItem(unittest.TestCase):
247+
248+
def test_creating_tractogram_item(self):
249+
rng = np.random.RandomState(42)
250+
streamline = rng.rand(rng.randint(10, 50), 3)
251+
colors = rng.rand(len(streamline), 3)
252+
mean_curvature = 1.11
253+
mean_color = np.array([0, 1, 0], dtype="f4")
254+
255+
data_for_streamline = {"mean_curvature": mean_curvature,
256+
"mean_color": mean_color}
257+
258+
data_for_points = {"colors": colors}
259+
260+
# Create a tractogram item with a streamline, data.
261+
t = TractogramItem(streamline, data_for_streamline, data_for_points)
262+
assert_equal(len(t), len(streamline))
263+
assert_array_equal(t.streamline, streamline)
264+
assert_array_equal(list(t), streamline)
265+
assert_array_equal(t.data_for_streamline['mean_curvature'],
266+
mean_curvature)
267+
assert_array_equal(t.data_for_streamline['mean_color'],
268+
mean_color)
269+
assert_array_equal(t.data_for_points['colors'],
270+
colors)
271+
272+
219273
class TestTractogram(unittest.TestCase):
220274

221275
def test_tractogram_creation(self):

0 commit comments

Comments
 (0)