@@ -274,6 +274,56 @@ def test_arraysequence_getitem(self):
274274 check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
275275 check_arr_seq (seq_view , [d [:, 2 ] for d in SEQ_DATA ['data' ][::- 2 ]])
276276
277+ def test_arraysequence_setitem (self ):
278+ # Set one item
279+ seq = SEQ_DATA ['seq' ] * 0
280+ for i , e in enumerate (SEQ_DATA ['seq' ]):
281+ seq [i ] = e
282+
283+ check_arr_seq (seq , SEQ_DATA ['seq' ])
284+
285+ # Setitem with a scalar.
286+ seq = SEQ_DATA ['seq' ].copy ()
287+ seq [:] = 0
288+ assert_true (seq ._data .sum () == 0 )
289+
290+ # Setitem with a list of ndarray.
291+ seq = SEQ_DATA ['seq' ] * 0
292+ seq [:] = SEQ_DATA ['data' ]
293+ check_arr_seq (seq , SEQ_DATA ['data' ])
294+
295+ # Setitem using tuple indexing.
296+ seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
297+ seq [:, 0 ] = 0
298+ assert_true (seq ._data [:, 0 ].sum () == 0 )
299+
300+ # Setitem using tuple indexing.
301+ seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
302+ seq [range (len (seq ))] = 0
303+ assert_true (seq ._data .sum () == 0 )
304+
305+ # Setitem of a slice using another slice.
306+ seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
307+ seq [0 :4 ] = seq [5 :9 ]
308+ check_arr_seq (seq [0 :4 ], seq [5 :9 ])
309+
310+ # Setitem between array sequences with different number of sequences.
311+ seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
312+ assert_raises (ValueError , seq .__setitem__ , slice (0 , 4 ), seq [5 :10 ])
313+
314+ # Setitem between array sequences with different amount of points.
315+ seq1 = ArraySequence (np .arange (10 ).reshape (5 , 2 ))
316+ seq2 = ArraySequence (np .arange (15 ).reshape (5 , 3 ))
317+ assert_raises (ValueError , seq1 .__setitem__ , slice (0 , 5 ), seq2 )
318+
319+ # Setitem between array sequences with different common shape.
320+ seq1 = ArraySequence (np .arange (12 ).reshape (2 , 2 , 3 ))
321+ seq2 = ArraySequence (np .arange (8 ).reshape (2 , 2 , 2 ))
322+ assert_raises (ValueError , seq1 .__setitem__ , slice (0 , 2 ), seq2 )
323+
324+ # Invalid index.
325+ assert_raises (TypeError , seq .__setitem__ , object (), None )
326+
277327 def test_arraysequence_operators (self ):
278328 # Disable division per zero warnings.
279329 flags = np .seterr (divide = 'ignore' , invalid = 'ignore' )
@@ -375,61 +425,6 @@ def _test_binary(op, arrseq, scalars, seqs, inplace=False):
375425 # Restore flags.
376426 np .seterr (** flags )
377427
378-
379- def test_arraysequence_setitem (self ):
380- # Set one item
381- seq = SEQ_DATA ['seq' ] * 0
382- for i , e in enumerate (SEQ_DATA ['seq' ]):
383- seq [i ] = e
384-
385- check_arr_seq (seq , SEQ_DATA ['seq' ])
386-
387- # Get all items using indexing (creates a view).
388- indices = list (range (len (SEQ_DATA ['seq' ])))
389- seq_view = SEQ_DATA ['seq' ][indices ]
390- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
391- # We took all elements so the view should match the original.
392- check_arr_seq (seq_view , SEQ_DATA ['seq' ])
393-
394- # Get multiple items using ndarray of dtype integer.
395- for dtype in [np .int8 , np .int16 , np .int32 , np .int64 ]:
396- seq_view = SEQ_DATA ['seq' ][np .array (indices , dtype = dtype )]
397- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
398- # We took all elements so the view should match the original.
399- check_arr_seq (seq_view , SEQ_DATA ['seq' ])
400-
401- # Get multiple items out of order (creates a view).
402- SEQ_DATA ['rng' ].shuffle (indices )
403- seq_view = SEQ_DATA ['seq' ][indices ]
404- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
405- check_arr_seq (seq_view , [SEQ_DATA ['data' ][i ] for i in indices ])
406-
407- # Get slice (this will create a view).
408- seq_view = SEQ_DATA ['seq' ][::2 ]
409- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
410- check_arr_seq (seq_view , SEQ_DATA ['data' ][::2 ])
411-
412- # Use advanced indexing with ndarray of data type bool.
413- selection = np .array ([False , True , True , False , True ])
414- seq_view = SEQ_DATA ['seq' ][selection ]
415- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
416- check_arr_seq (seq_view ,
417- [SEQ_DATA ['data' ][i ]
418- for i , keep in enumerate (selection ) if keep ])
419-
420- # Test invalid indexing
421- assert_raises (TypeError , SEQ_DATA ['seq' ].__getitem__ , 'abc' )
422-
423- # Get specific columns.
424- seq_view = SEQ_DATA ['seq' ][:, 2 ]
425- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
426- check_arr_seq (seq_view , [d [:, 2 ] for d in SEQ_DATA ['data' ]])
427-
428- # Combining multiple slicing and indexing operations.
429- seq_view = SEQ_DATA ['seq' ][::- 2 ][:, 2 ]
430- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
431- check_arr_seq (seq_view , [d [:, 2 ] for d in SEQ_DATA ['data' ][::- 2 ]])
432-
433428 def test_arraysequence_repr (self ):
434429 # Test that calling repr on a ArraySequence object is not falling.
435430 repr (SEQ_DATA ['seq' ])
0 commit comments