@@ -10,6 +10,12 @@ def is_array_sequence(obj):
1010 return False
1111
1212
def is_ndarray_of_int_or_bool(obj):
    """ Tell whether `obj` is an ndarray with an integer or boolean dtype.

    Parameters
    ----------
    obj : object
        Any object to test.

    Returns
    -------
    bool
        True if `obj` is a :class:`numpy.ndarray` whose dtype is a subtype
        of `np.integer` or of `np.bool_`; False otherwise (including when
        `obj` is not an ndarray at all).
    """
    if not isinstance(obj, np.ndarray):
        return False

    # `np.bool_` is the NumPy boolean scalar type. The bare `np.bool` alias
    # was deprecated in NumPy 1.20 and removed in 1.24, so referencing it
    # raises AttributeError on those releases.
    return (np.issubdtype(obj.dtype, np.integer) or
            np.issubdtype(obj.dtype, np.bool_))
17+
18+
1319class ArraySequence (object ):
1420 """ Sequence of ndarrays having variable first dimension sizes.
1521
@@ -23,9 +29,7 @@ class ArraySequence(object):
2329 same for every ndarray.
2430 """
2531
26- BUFFER_SIZE = 87382 * 4 # About 4 Mb if item shape is 3 (e.g. 3D points).
27-
28- def __init__ (self , iterable = None ):
32+ def __init__ (self , iterable = None , buffer_size = 4 ):
2933 """ Initialize array sequence instance
3034
3135 Parameters
@@ -36,6 +40,8 @@ def __init__(self, iterable=None):
3640 from array-like objects yielded by the iterable.
3741 If :class:`ArraySequence`, create a view (no memory is allocated).
3842 For an actual copy use :meth:`.copy` instead.
43+ buffer_size : float, optional
44+ Size (in Mb) for memory allocation when `iterable` is a generator.
3945 """
4046 # Create new empty `ArraySequence` object.
4147 self ._is_view = False
@@ -62,14 +68,23 @@ def __init__(self, iterable=None):
6268 for i , e in enumerate (iterable ):
6369 e = np .asarray (e )
6470 if i == 0 :
65- new_shape = (ArraySequence .BUFFER_SIZE ,) + e .shape [1 :]
71+ try :
72+ n_elements = np .sum ([len (iterable [i ])
73+ for i in range (len (iterable ))])
74+ new_shape = (n_elements ,) + e .shape [1 :]
75+ except TypeError :
76+ # Can't get the number of elements in iterable. So,
77+ # we use a memory buffer while building the ArraySequence.
78+ n_rows_buffer = buffer_size * 1024 ** 2 // e .nbytes
79+ new_shape = (n_rows_buffer ,) + e .shape [1 :]
80+
6681 self ._data = np .empty (new_shape , dtype = e .dtype )
6782
6883 end = offset + len (e )
69- if end >= len (self ._data ):
84+ if end > len (self ._data ):
7085 # Resize needed, adding `len(e)` items plus some buffer.
7186 nb_points = len (self ._data )
72- nb_points += len (e ) + ArraySequence . BUFFER_SIZE
87+ nb_points += len (e ) + n_rows_buffer
7388 self ._data .resize ((nb_points ,) + self .common_shape )
7489
7590 offsets .append (offset )
@@ -230,24 +245,29 @@ def __getitem__(self, idx):
230245 start = self ._offsets [idx ]
231246 return self ._data [start :start + self ._lengths [idx ]]
232247
233- elif isinstance (idx , (slice , list )):
248+ elif isinstance (idx , (slice , list )) or is_ndarray_of_int_or_bool ( idx ) :
234249 seq = self .__class__ ()
235250 seq ._data = self ._data
236251 seq ._offsets = self ._offsets [idx ]
237252 seq ._lengths = self ._lengths [idx ]
238253 seq ._is_view = True
239254 return seq
240255
241- elif (isinstance (idx , np .ndarray ) and
242- (np .issubdtype (idx .dtype , np .integer ) or
243- np .issubdtype (idx .dtype , np .bool ))):
256+ elif isinstance (idx , tuple ):
244257 seq = self .__class__ ()
245- seq ._data = self ._data
246- seq ._offsets = self ._offsets [idx ]
247- seq ._lengths = self ._lengths [idx ]
258+ seq ._data = self ._data . __getitem__ (( slice ( None ),) + idx [ 1 :])
259+ seq ._offsets = self ._offsets [idx [ 0 ] ]
260+ seq ._lengths = self ._lengths [idx [ 0 ] ]
248261 seq ._is_view = True
249262 return seq
250263
264+ # for name, slice_ in data_per_point_slice.items():
265+ # seq = ArraySequence()
266+ # seq._data = scalars._data[:, slice_]
267+ # seq._offsets = scalars._offsets
268+ # seq._lengths = scalars._lengths
269+ # tractogram.data_per_point[name] = seq
270+
251271 raise TypeError ("Index must be either an int, a slice, a list of int"
252272 " or a ndarray of bool! Not " + str (type (idx )))
253273
@@ -283,11 +303,22 @@ def save(self, filename):
283303 lengths = self ._lengths )
284304
@classmethod
def load(cls, filename):
    """ Loads a :class:`ArraySequence` object from a .npz file.

    Parameters
    ----------
    filename : str or file-like object
        Passed as-is to :func:`numpy.load`. The archive must contain the
        entries "data", "offsets" and "lengths".

    Returns
    -------
    seq : instance of `cls`
        New object whose `_data`, `_offsets` and `_lengths` attributes are
        set from the corresponding entries of the archive.
    """
    seq = cls()
    # `np.load` on a .npz archive returns an `NpzFile` that keeps the
    # underlying file open; the context manager closes it once the three
    # arrays have been read into memory.
    with np.load(filename) as content:
        seq._data = content["data"]
        seq._offsets = content["offsets"]
        seq._lengths = content["lengths"]
    return seq
314+
315+
def create_arraysequences_from_generator(gen, n):
    """ Creates :class:`ArraySequence` objects from a generator yielding tuples

    Parameters
    ----------
    gen : iterable
        Iterable whose items can be indexed from 0 to `n` - 1 (e.g. tuples
        of length `n`).
    n : int
        Number of :class:`ArraySequence` objects to create.

    Returns
    -------
    seqs : list of :class:`ArraySequence`
        List of `n` sequences; the i-th one has received, in order, the
        i-th element of every item yielded by `gen`.
    """
    sequences = []
    for _ in range(n):
        sequences.append(ArraySequence())

    for item in gen:
        # Dispatch each component of the item to the sequence sharing
        # its position.
        for position in range(n):
            sequences[position].append(item[position])

    return sequences
0 commit comments