11import math
22from itertools import product
3- from typing import Sequence , Union
3+ from typing import Sequence , Union , get_args
44
55import pytest
66from hypothesis import assume , given , note
@@ -33,11 +33,9 @@ def test_getitem(shape, data):
3333 size = math .prod (shape )
3434 dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
3535 obj = data .draw (
36- st .lists (
37- xps .from_dtype (dtype ),
38- min_size = size ,
39- max_size = size ,
40- ).map (lambda l : reshape (l , shape )),
36+ st .lists (xps .from_dtype (dtype ), min_size = size , max_size = size ).map (
37+ lambda l : reshape (l , shape )
38+ ),
4139 label = "obj" ,
4240 )
4341 x = xp .asarray (obj , dtype = dtype )
@@ -47,7 +45,6 @@ def test_getitem(shape, data):
4745 out = x [key ]
4846
4947 ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
50-
5148 _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
5249 if Ellipsis in _key :
5350 start_a = _key .index (Ellipsis )
@@ -78,7 +75,39 @@ def test_getitem(shape, data):
7875 ph .assert_array ("__getitem__" , out , expected )
7976
8077
81- # TODO: test_setitem
78+ @given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
79+ def test_setitem (shape , data ):
80+ size = math .prod (shape )
81+ dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
82+ obj = data .draw (
83+ st .lists (xps .from_dtype (dtype ), min_size = size , max_size = size ).map (
84+ lambda l : reshape (l , shape )
85+ ),
86+ label = "obj" ,
87+ )
88+ x = xp .asarray (obj , dtype = dtype )
89+ note (f"{ x = } " )
90+ key = data .draw (xps .indices (shape = shape , max_dims = 0 ), label = "key" )
91+ value = data .draw (
92+ xps .from_dtype (dtype ) | xps .arrays (dtype = dtype , shape = ()), label = "value"
93+ )
94+
95+ res = xp .asarray (x , copy = True )
96+ res [key ] = value
97+
98+ ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
99+ ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.shape" )
100+ if isinstance (value , get_args (Scalar )):
101+ msg = f"x[{ key } ]={ res [key ]!r} , but should be { value = } [__setitem__()]"
102+ if math .isnan (value ):
103+ assert xp .isnan (res [key ]), msg
104+ else :
105+ assert res [key ] == value , msg
106+ else :
107+ ph .assert_0d_equals ("__setitem__" , "value" , value , f"x[{ key } ]" , res [key ])
108+
109+
110+ # TODO: test boolean indexing
82111
83112
84113def make_param (method_name : str , dtype : DataType , stype : ScalarType ) -> Param :
0 commit comments