@@ -53,6 +53,35 @@ def update_seq(self, arr_seq):
5353 arr_seq ._lengths = np .array (self .lengths )
5454
5555
56+ def _define_operators (cls ):
57+ """ Decorator which adds support for some Python operators. """
58+ def _wrap (cls , op , name = None , inplace = False , unary = False ):
59+ name = name or op
60+ if unary :
61+ setattr (cls , name , lambda self : self ._op (op ))
62+ else :
63+ setattr (cls , name ,
64+ lambda self , value : self ._op (op , value , inplace = inplace ))
65+
66+ for op in ["__iadd__" , "__isub__" , "__imul__" , "__idiv__" ,
67+ "__ifloordiv__" , "__itruediv__" , "__ior__" ]:
68+ _wrap (cls , op , inplace = True )
69+
70+ for op in ["__add__" , "__sub__" , "__mul__" , "__div__" ,
71+ "__floordiv__" , "__truediv__" , "__or__" ]:
72+ op_ = "__i{}__" .format (op .strip ("_" ))
73+ _wrap (cls , op_ , name = op )
74+
75+ for op in ["__eq__" , "__ne__" , "__lt__" , "__le__" , "__gt__" , "__ge__" ]:
76+ _wrap (cls , op )
77+
78+ for op in ["__neg__" ]:
79+ _wrap (cls , op , unary = True )
80+
81+ return cls
82+
83+
84+ @_define_operators
5685class ArraySequence (object ):
5786 """ Sequence of ndarrays having variable first dimension sizes.
5887
@@ -120,6 +149,23 @@ def data(self):
120149 """ Elements in this array sequence. """
121150 return self ._data
122151
152+ def _check_shape (self , arrseq ):
153+ """ Check whether this array sequence is compatible with another. """
154+ msg = "cannot perform operation - array sequences have different"
155+ if len (self ._lengths ) != len (arrseq ._lengths ):
156+ msg += " lengths: {} vs. {}."
157+ raise ValueError (msg .format (len (self ._lengths ), len (arrseq ._lengths )))
158+
159+ if self .total_nb_rows != arrseq .total_nb_rows :
160+ msg += " amount of data: {} vs. {}."
161+ raise ValueError (msg .format (self .total_nb_rows , arrseq .total_nb_rows ))
162+
163+ if self .common_shape != arrseq .common_shape :
164+ msg += " common shape: {} vs. {}."
165+ raise ValueError (msg .format (self .common_shape , arrseq .common_shape ))
166+
167+ return True
168+
123169 def _get_next_offset (self ):
124170 """ Offset in ``self._data`` at which to write next rowelement """
125171 if len (self ._offsets ) == 0 :
@@ -377,6 +423,37 @@ def __setitem__(self, idx, elements):
377423 for o1 , l1 , o2 , l2 in zip (offsets , lengths , elements ._offsets , elements ._lengths ):
378424 data [o1 :o1 + l1 ] = elements ._data [o2 :o2 + l2 ]
379425
426+ def _op (self , op , value = None , inplace = False ):
427+ """ Applies some operator to this arraysequence.
428+
429+ This handles both unary and binary operators with a scalar or another
430+ array sequence. Operations are performed directly on the underlying
431+ data, or a copy of it, which depends on the value of `inplace`.
432+
433+ Parameters
434+ ----------
435+ op : str
436+ Name of the Python operator (e.g., `"__add__"`).
437+ value : scalar or :class:`ArraySequence`, optional
438+ If None, the operator is assumed to be unary.
439+ Otherwise, that value is used in the binary operation.
440+ inplace: bool, optional
441+ If False, the operation is done on a copy of this array sequence.
442+ Otherwise, this array sequence gets modified directly.
443+ """
444+ seq = self if inplace else self .copy ()
445+
446+ if is_array_sequence (value ) and seq ._check_shape (value ):
447+ for o1 , l1 , o2 , l2 in zip (seq ._offsets , seq ._lengths , value ._offsets , value ._lengths ):
448+ seq ._data [o1 :o1 + l1 ] = getattr (seq ._data [o1 :o1 + l1 ], op )(value ._data [o2 :o2 + l2 ])
449+
450+ else :
451+ args = [] if value is None else [value ] # Dealing with unary and binary ops.
452+ for o1 , l1 in zip (seq ._offsets , seq ._lengths ):
453+ seq ._data [o1 :o1 + l1 ] = getattr (seq ._data [o1 :o1 + l1 ], op )(* args )
454+
455+ return seq
456+
380457 def __iter__ (self ):
381458 if len (self ._lengths ) != len (self ._offsets ):
382459 raise ValueError ("ArraySequence object corrupted:"
0 commit comments