55
66import numpy as np
77
8+ from ..deprecated import deprecate_with_version
9+
810MEGABYTE = 1024 * 1024
911
1012
@@ -53,6 +55,37 @@ def update_seq(self, arr_seq):
5355 arr_seq ._lengths = np .array (self .lengths )
5456
5557
def _define_operators(cls):
    """ Class decorator installing Python operators that delegate to ``_op``.

    Every installed method forwards to ``self._op`` with the operator's
    name, and borrows its docstring from the ``np.ndarray`` method having
    the same name.

    Parameters
    ----------
    cls : type
        Class to augment. It is expected to provide an
        ``_op(op, value=None, inplace=False)`` method.

    Returns
    -------
    cls : type
        The very same class, now carrying the operator methods.
    """
    arithmetic_ops = ("__add__", "__sub__", "__mul__", "__mod__", "__pow__",
                      "__floordiv__", "__truediv__", "__lshift__", "__rshift__",
                      "__or__", "__and__", "__xor__")
    comparison_ops = ("__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__")
    unary_ops = ("__neg__", "__abs__", "__invert__")

    def _install(name, inplace=False, unary=False):
        # `name` and `inplace` are parameters of this call, so each installed
        # method closes over its own values.
        if unary:
            def method(self):
                return self._op(name)
        else:
            def method(self, value):
                return self._op(name, value, inplace=inplace)

        method.__name__ = name
        method.__doc__ = getattr(np.ndarray, name).__doc__
        setattr(cls, name, method)

    for name in arithmetic_ops:
        _install(name, inplace=False)
        # Matching in-place variant, e.g. "__add__" -> "__iadd__".
        _install("__i{}__".format(name.strip("_")), inplace=True)

    for name in comparison_ops:
        _install(name)

    for name in unary_ops:
        _install(name, unary=True)

    return cls
86+
87+
88+ @_define_operators
5689class ArraySequence (object ):
5790 """ Sequence of ndarrays having variable first dimension sizes.
5891
@@ -116,9 +149,42 @@ def total_nb_rows(self):
116149 return np .sum (self ._lengths )
117150
@property
@deprecate_with_version("'ArraySequence.data' property is deprecated.\n "
                        "Please use the 'ArraySequence.get_data()' method instead",
                        '3.0', '4.0')
def data(self):
    """ Read-only view of the elements in this array sequence. """
    readonly = self._data.view()
    # Clearing the flag on the view leaves `self._data` itself untouched.
    readonly.flags.writeable = False
    return readonly
160+
def get_data(self):
    """ Return a *copy* of the elements in this array sequence.

    The array is taken from a copy of this array sequence (obtained with
    ``self.copy()``).

    Notes
    -----
    To modify the data of this array sequence, use in-place mathematical
    operators (e.g., `seq += ...`) or the assignment operator
    (i.e., `seq[...] = value`).
    """
    duplicate = self.copy()
    return duplicate._data
171+
def _check_shape(self, arrseq):
    """ Check whether this array sequence is compatible with another.

    Two array sequences are compatible when they hold the same number of
    sequences, the same total number of rows, the same common shape and
    the same number of rows in each sequence.

    Parameters
    ----------
    arrseq : :class:`ArraySequence`
        Array sequence to compare with this one.

    Returns
    -------
    bool
        Always True; any incompatibility is reported by raising.

    Raises
    ------
    ValueError
        If the two array sequences differ on any of the criteria above.
    """
    msg = "cannot perform operation - array sequences have different"
    if len(self._lengths) != len(arrseq._lengths):
        msg += " lengths: {} vs. {}."
        raise ValueError(msg.format(len(self._lengths), len(arrseq._lengths)))

    if self.total_nb_rows != arrseq.total_nb_rows:
        msg += " amount of data: {} vs. {}."
        raise ValueError(msg.format(self.total_nb_rows, arrseq.total_nb_rows))

    if self.common_shape != arrseq.common_shape:
        msg += " common shape: {} vs. {}."
        raise ValueError(msg.format(self.common_shape, arrseq.common_shape))

    # Same count and same total do not imply the same partition
    # (e.g. lengths [1, 3] vs. [3, 1]). Checked last so that the errors
    # above keep their precedence.
    if not np.array_equal(self._lengths, arrseq._lengths):
        msg += " number of rows per sequence."
        raise ValueError(msg)

    return True
122188
123189 def _get_next_offset (self ):
124190 """ Offset in ``self._data`` at which to write next rowelement """
@@ -320,7 +386,7 @@ def __getitem__(self, idx):
320386 seq ._lengths = self ._lengths [off_idx ]
321387 return seq
322388
323- if isinstance (off_idx , list ) or is_ndarray_of_int_or_bool (off_idx ):
389+ if isinstance (off_idx , ( list , range ) ) or is_ndarray_of_int_or_bool (off_idx ):
324390 # Fancy indexing
325391 seq ._offsets = self ._offsets [off_idx ]
326392 seq ._lengths = self ._lengths [off_idx ]
@@ -329,6 +395,116 @@ def __getitem__(self, idx):
329395 raise TypeError ("Index must be either an int, a slice, a list of int"
330396 " or a ndarray of bool! Not " + str (type (idx )))
331397
def __setitem__(self, idx, elements):
    """ Overwrite sequence(s) through standard or advanced numpy indexing.

    Parameters
    ----------
    idx : int or slice or list or range or ndarray or tuple
        If int, index of the sequence to overwrite.
        If slice, list, range or ndarray, selection of the sequences to
        overwrite (ndarray indices must satisfy
        `is_ndarray_of_int_or_bool`).
        If tuple, its first item selects the sequences and the remaining
        items index the other axes of the underlying data.
    elements : scalar or iterable or :class:`ArraySequence`
        Data that will overwrite the selected sequences.
        If `idx` is an int, `elements` is assigned as is to the rows of
        that sequence.
        If `elements` is an :class:`ArraySequence`, its sequences are
        copied one by one into the selected ones.
        If `elements` is a number, it is assigned to every selected
        sequence.
        Otherwise `elements` is iterated over, yielding one item per
        selected sequence.

    Raises
    ------
    TypeError
        If the type of `idx` is not supported.
    ValueError
        If `elements` is an :class:`ArraySequence` whose number of
        sequences or total number of rows differs from the selection.
    """
    if isinstance(idx, (numbers.Integral, np.integer)):
        begin = self._offsets[idx]
        end = begin + self._lengths[idx]
        self._data[begin:end] = elements
        return

    if isinstance(idx, tuple):
        # First item picks the sequences; the rest indexes the other axes.
        off_idx = idx[0]
        data = self._data[(slice(None),) + idx[1:]]
    else:
        off_idx = idx
        data = self._data

    # Standard slicing and fancy indexing select offsets/lengths alike.
    supported = (isinstance(off_idx, (slice, list, range)) or
                 is_ndarray_of_int_or_bool(off_idx))
    if not supported:
        raise TypeError("Index must be either an int, a slice, a list of int"
                        " or a ndarray of bool! Not " + str(type(idx)))

    offsets = self._offsets[off_idx]
    lengths = self._lengths[off_idx]

    if is_array_sequence(elements):
        if len(lengths) != len(elements):
            msg = "Trying to set {} sequences with {} sequences."
            raise ValueError(msg.format(len(lengths), len(elements)))

        nb_rows = sum(lengths)
        if nb_rows != elements.total_nb_rows:
            msg = "Trying to set {} points with {} points."
            raise ValueError(msg.format(nb_rows, elements.total_nb_rows))

        sources = zip(elements._offsets, elements._lengths)
        for dst_off, dst_len, (src_off, src_len) in zip(offsets, lengths, sources):
            data[dst_off:dst_off + dst_len] = elements._data[src_off:src_off + src_len]
        return

    if isinstance(elements, numbers.Number):
        for dst_off, dst_len in zip(offsets, lengths):
            data[dst_off:dst_off + dst_len] = elements
        return

    # Anything else: try to iterate over it, one item per selected sequence.
    for dst_off, dst_len, element in zip(offsets, lengths, elements):
        data[dst_off:dst_off + dst_len] = element
458+
def _op(self, op, value=None, inplace=False):
    """ Applies some operator to this arraysequence.

    This handles both unary and binary operators with a scalar or another
    array sequence. Operations are performed directly on the underlying
    data, or a copy of it, which depends on the value of `inplace`.

    Parameters
    ----------
    op : str
        Name of the Python operator (e.g., `"__add__"`).
    value : scalar or :class:`ArraySequence`, optional
        If None, the operator is assumed to be unary.
        Otherwise, that value is used in the binary operation.
    inplace : bool, optional
        If False, the operation is done on a copy of this array sequence.
        Otherwise, this array sequence gets modified directly.

    Returns
    -------
    seq : :class:`ArraySequence`
        This array sequence when `inplace` is True, a modified copy of it
        otherwise. An array sequence holding no sequence is returned as
        is, since there is nothing to operate on.

    Raises
    ------
    ValueError
        If `value` is an array sequence rejected by `_check_shape`.
    """
    seq = self if inplace else self.copy()

    if is_array_sequence(value) and seq._check_shape(value):
        elements = zip(seq._offsets, seq._lengths,
                       self._offsets, self._lengths,
                       value._offsets, value._lengths)

        # The first sequence is handled apart: its result tells which
        # dtype the output must have. A default avoids leaking a bare
        # StopIteration when there is no sequence at all.
        first = next(elements, None)
        if first is None:
            return seq

        # Change seq.dtype to match the operation resulting type.
        o0, l0, o1, l1, o2, l2 = first
        tmp = getattr(self._data[o1:o1 + l1], op)(value._data[o2:o2 + l2])
        # astype() allocates a new array by default, so `seq._data` is
        # re-bound here (this also applies to `self` when `inplace`).
        seq._data = seq._data.astype(tmp.dtype)
        seq._data[o0:o0 + l0] = tmp

        for o0, l0, o1, l1, o2, l2 in elements:
            seq._data[o0:o0 + l0] = getattr(self._data[o1:o1 + l1], op)(value._data[o2:o2 + l2])

    else:
        args = [] if value is None else [value]  # Dealing with unary and binary ops.
        elements = zip(seq._offsets, seq._lengths, self._offsets, self._lengths)

        # Same empty-sequence guard as in the branch above.
        first = next(elements, None)
        if first is None:
            return seq

        # Change seq.dtype to match the operation resulting type.
        o0, l0, o1, l1 = first
        tmp = getattr(self._data[o1:o1 + l1], op)(*args)
        seq._data = seq._data.astype(tmp.dtype)
        seq._data[o0:o0 + l0] = tmp

        for o0, l0, o1, l1 in elements:
            seq._data[o0:o0 + l0] = getattr(self._data[o1:o1 + l1], op)(*args)

    return seq
507+
332508 def __iter__ (self ):
333509 if len (self ._lengths ) != len (self ._offsets ):
334510 raise ValueError ("ArraySequence object corrupted:"
@@ -371,7 +547,7 @@ def load(cls, filename):
371547 return seq
372548
373549
374- def create_arraysequences_from_generator (gen , n ):
550+ def create_arraysequences_from_generator (gen , n , buffer_sizes = None ):
375551 """ Creates :class:`ArraySequence` objects from a generator yielding tuples
376552
377553 Parameters
@@ -381,8 +557,13 @@ def create_arraysequences_from_generator(gen, n):
381557 array sequences.
382558 n : int
383559 Number of :class:`ArraySequences` object to create.
560+ buffer_sizes : list of float, optional
561+ Sizes (in Mb) for each ArraySequence's buffer.
384562 """
385- seqs = [ArraySequence () for _ in range (n )]
563+ if buffer_sizes is None :
564+ buffer_sizes = [4 ] * n
565+
566+ seqs = [ArraySequence (buffer_size = size ) for size in buffer_sizes ]
386567 for data in gen :
387568 for i , seq in enumerate (seqs ):
388569 if data [i ].nbytes > 0 :
0 commit comments