Skip to content

Commit 196761a

Browse files
committed
Added function to create multiple ArraySequences from a generator.
Also, added an option to specify the buffer size used when creating an ArraySequence from a generator.
1 parent 713ef7b commit 196761a

File tree

2 files changed

+57
-22
lines changed

2 files changed

+57
-22
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ def is_array_sequence(obj):
1010
return False
1111

1212

13+
def is_ndarray_of_int_or_bool(obj):
14+
return (isinstance(obj, np.ndarray) and
15+
(np.issubdtype(obj.dtype, np.integer) or
16+
np.issubdtype(obj.dtype, np.bool)))
17+
18+
1319
class ArraySequence(object):
1420
""" Sequence of ndarrays having variable first dimension sizes.
1521
@@ -23,9 +29,7 @@ class ArraySequence(object):
2329
same for every ndarray.
2430
"""
2531

26-
BUFFER_SIZE = 87382 * 4 # About 4 Mb if item shape is 3 (e.g. 3D points).
27-
28-
def __init__(self, iterable=None):
32+
def __init__(self, iterable=None, buffer_size=4):
2933
""" Initialize array sequence instance
3034
3135
Parameters
@@ -36,6 +40,8 @@ def __init__(self, iterable=None):
3640
from array-like objects yielded by the iterable.
3741
If :class:`ArraySequence`, create a view (no memory is allocated).
3842
For an actual copy use :meth:`.copy` instead.
43+
buffer_size : float, optional
44+
Size (in Mb) for memory allocation when `iterable` is a generator.
3945
"""
4046
# Create new empty `ArraySequence` object.
4147
self._is_view = False
@@ -62,14 +68,23 @@ def __init__(self, iterable=None):
6268
for i, e in enumerate(iterable):
6369
e = np.asarray(e)
6470
if i == 0:
65-
new_shape = (ArraySequence.BUFFER_SIZE,) + e.shape[1:]
71+
try:
72+
n_elements = np.sum([len(iterable[i])
73+
for i in range(len(iterable))])
74+
new_shape = (n_elements,) + e.shape[1:]
75+
except TypeError:
76+
# Can't get the number of elements in iterable. So,
77+
# we use a memory buffer while building the ArraySequence.
78+
n_rows_buffer = buffer_size*1024**2 // e.nbytes
79+
new_shape = (n_rows_buffer,) + e.shape[1:]
80+
6681
self._data = np.empty(new_shape, dtype=e.dtype)
6782

6883
end = offset + len(e)
69-
if end >= len(self._data):
84+
if end > len(self._data):
7085
# Resize needed, adding `len(e)` items plus some buffer.
7186
nb_points = len(self._data)
72-
nb_points += len(e) + ArraySequence.BUFFER_SIZE
87+
nb_points += len(e) + n_rows_buffer
7388
self._data.resize((nb_points,) + self.common_shape)
7489

7590
offsets.append(offset)
@@ -230,24 +245,29 @@ def __getitem__(self, idx):
230245
start = self._offsets[idx]
231246
return self._data[start:start + self._lengths[idx]]
232247

233-
elif isinstance(idx, (slice, list)):
248+
elif isinstance(idx, (slice, list)) or is_ndarray_of_int_or_bool(idx):
234249
seq = self.__class__()
235250
seq._data = self._data
236251
seq._offsets = self._offsets[idx]
237252
seq._lengths = self._lengths[idx]
238253
seq._is_view = True
239254
return seq
240255

241-
elif (isinstance(idx, np.ndarray) and
242-
(np.issubdtype(idx.dtype, np.integer) or
243-
np.issubdtype(idx.dtype, np.bool))):
256+
elif isinstance(idx, tuple):
244257
seq = self.__class__()
245-
seq._data = self._data
246-
seq._offsets = self._offsets[idx]
247-
seq._lengths = self._lengths[idx]
258+
seq._data = self._data.__getitem__((slice(None),) + idx[1:])
259+
seq._offsets = self._offsets[idx[0]]
260+
seq._lengths = self._lengths[idx[0]]
248261
seq._is_view = True
249262
return seq
250263

264+
# for name, slice_ in data_per_point_slice.items():
265+
# seq = ArraySequence()
266+
# seq._data = scalars._data[:, slice_]
267+
# seq._offsets = scalars._offsets
268+
# seq._lengths = scalars._lengths
269+
# tractogram.data_per_point[name] = seq
270+
251271
raise TypeError("Index must be either an int, a slice, a list of int"
252272
" or a ndarray of bool! Not " + str(type(idx)))
253273

@@ -283,11 +303,22 @@ def save(self, filename):
283303
lengths=self._lengths)
284304

285305
@classmethod
286-
def from_filename(cls, filename):
306+
def load(cls, filename):
287307
""" Loads a :class:`ArraySequence` object from a .npz file. """
288308
content = np.load(filename)
289309
seq = cls()
290310
seq._data = content["data"]
291311
seq._offsets = content["offsets"]
292312
seq._lengths = content["lengths"]
293313
return seq
314+
315+
316+
def create_arraysequences_from_generator(gen, n):
317+
""" Creates :class:`ArraySequence` objects from a generator yielding tuples
318+
"""
319+
seqs = [ArraySequence() for _ in range(n)]
320+
for data in gen:
321+
for i, seq in enumerate(seqs):
322+
seq.append(data[i])
323+
324+
return seqs

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def check_arr_seq(seq, arrays):
6161
def check_arr_seq_view(seq_view, seq):
6262
assert_true(seq_view._is_view)
6363
assert_true(seq_view is not seq)
64-
assert_true(seq_view._data is seq._data)
64+
assert_true(np.may_share_memory(seq_view._data, seq._data))
6565
assert_true(seq_view._offsets is not seq._offsets)
6666
assert_true(seq_view._lengths is not seq._lengths)
6767

@@ -77,18 +77,17 @@ def test_creating_arraysequence_from_list(self):
7777

7878
# List of ndarrays.
7979
N = 5
80-
for ndim in range(0, N+1):
80+
for ndim in range(1, N+1):
8181
common_shape = tuple([SEQ_DATA['rng'].randint(1, 10)
8282
for _ in range(ndim-1)])
8383
data = generate_data(nb_arrays=5, common_shape=common_shape,
8484
rng=SEQ_DATA['rng'])
8585
check_arr_seq(ArraySequence(data), data)
8686

8787
# Force ArraySequence constructor to use buffering.
88-
old_buffer_size = ArraySequence.BUFFER_SIZE
89-
ArraySequence.BUFFER_SIZE = 1
90-
check_arr_seq(ArraySequence(SEQ_DATA['data']), SEQ_DATA['data'])
91-
ArraySequence.BUFFER_SIZE = old_buffer_size
88+
buffer_size = 1. / 1024**2 # 1 bytes
89+
check_arr_seq(ArraySequence(iter(SEQ_DATA['data']), buffer_size),
90+
SEQ_DATA['data'])
9291

9392
def test_creating_arraysequence_from_generator(self):
9493
gen = (e for e in SEQ_DATA['data'])
@@ -245,6 +244,11 @@ def test_arraysequence_getitem(self):
245244
# Test invalid indexing
246245
assert_raises(TypeError, SEQ_DATA['seq'].__getitem__, 'abc')
247246

247+
# Get specific columns.
248+
seq_view = SEQ_DATA['seq'][:, 2]
249+
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
250+
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data']])
251+
248252
def test_arraysequence_repr(self):
249253
# Test that calling repr on a ArraySequence object is not falling.
250254
repr(SEQ_DATA['seq'])
@@ -269,7 +273,7 @@ def test_save_and_load_arraysequence(self):
269273
seq = ArraySequence()
270274
seq.save(f)
271275
f.seek(0, os.SEEK_SET)
272-
loaded_seq = ArraySequence.from_filename(f)
276+
loaded_seq = ArraySequence.load(f)
273277
assert_array_equal(loaded_seq._data, seq._data)
274278
assert_array_equal(loaded_seq._offsets, seq._offsets)
275279
assert_array_equal(loaded_seq._lengths, seq._lengths)
@@ -279,7 +283,7 @@ def test_save_and_load_arraysequence(self):
279283
seq = SEQ_DATA['seq']
280284
seq.save(f)
281285
f.seek(0, os.SEEK_SET)
282-
loaded_seq = ArraySequence.from_filename(f)
286+
loaded_seq = ArraySequence.load(f)
283287
assert_array_equal(loaded_seq._data, seq._data)
284288
assert_array_equal(loaded_seq._offsets, seq._offsets)
285289
assert_array_equal(loaded_seq._lengths, seq._lengths)

0 commit comments

Comments
 (0)