1+ import array as pyarray
2+
13import pytest
24
35from arrayfire .array_api import Array , float32 , int16
46from arrayfire .array_api ._dtypes import supported_dtypes
57
68# TODO change separated methods with setup and teardown to avoid code duplication
9+ # TODO add tests for array arguments: device, offset, strides
710
811
9- def test_empty_array () -> None :
12+ def test_create_empty_array () -> None :
1013 array = Array ()
1114
1215 assert array .dtype == float32
@@ -16,7 +19,7 @@ def test_empty_array() -> None:
1619 assert len (array ) == 0
1720
1821
19- def test_empty_array_with_nonempty_dtype () -> None :
22+ def test_create_empty_array_with_nonempty_dtype () -> None :
2023 array = Array (dtype = int16 )
2124
2225 assert array .dtype == int16
@@ -26,7 +29,32 @@ def test_empty_array_with_nonempty_dtype() -> None:
2629 assert len (array ) == 0
2730
2831
29- def test_empty_array_with_nonempty_shape () -> None :
32+ def test_create_empty_array_with_str_dtype () -> None :
33+ array = Array (dtype = "short int" )
34+
35+ assert array .dtype == int16
36+ assert array .ndim == 0
37+ assert array .size == 0
38+ assert array .shape == ()
39+ assert len (array ) == 0
40+
41+
42+ def test_create_empty_array_with_literal_dtype () -> None :
43+ array = Array (dtype = "h" )
44+
45+ assert array .dtype == int16
46+ assert array .ndim == 0
47+ assert array .size == 0
48+ assert array .shape == ()
49+ assert len (array ) == 0
50+
51+
52+ def test_create_empty_array_with_not_matching_str_dtype () -> None :
53+ with pytest .raises (TypeError ):
54+ Array (dtype = "hello world" )
55+
56+
57+ def test_create_empty_array_with_nonempty_shape () -> None :
3058 array = Array (shape = (2 , 3 ))
3159
3260 assert array .dtype == float32
@@ -36,7 +64,7 @@ def test_empty_array_with_nonempty_shape() -> None:
3664 assert len (array ) == 2
3765
3866
39- def test_array_from_1d_list () -> None :
67+ def test_create_array_from_1d_list () -> None :
4068 array = Array ([1 , 2 , 3 ])
4169
4270 assert array .dtype == float32
@@ -46,11 +74,22 @@ def test_array_from_1d_list() -> None:
4674 assert len (array ) == 3
4775
4876
49- def test_array_from_2d_list () -> None :
77+ def test_create_array_from_2d_list () -> None :
5078 with pytest .raises (TypeError ):
5179 Array ([[1 , 2 , 3 ], [1 , 2 , 3 ]])
5280
5381
82+ def test_create_array_from_pyarray () -> None :
83+ py_array = pyarray .array ("f" , [1 , 2 , 3 ])
84+ array = Array (py_array )
85+
86+ assert array .dtype == float32
87+ assert array .ndim == 1
88+ assert array .size == 3
89+ assert array .shape == (3 ,)
90+ assert len (array ) == 3
91+
92+
5493def test_array_from_list_with_unsupported_dtype () -> None :
5594 for dtype in supported_dtypes :
5695 if dtype == float32 :
0 commit comments