103103 Int64 ,
104104 OptParType ,
105105 TensorType ,
106+ OptTensorType ,
106107)
107108from onnx_array_api .npx .npx_var import Input , Var
108109
@@ -125,35 +126,62 @@ def test_shape_inference(self):
125126 self .assertEqual (output .type .tensor_type .elem_type , TensorProto .FLOAT )
126127
127128 def test_tensor (self ):
128- dt = TensorType ["float32" ]
129+ dt = TensorType ["float32" , "F32" ]
129130 self .assertEqual (len (dt .dtypes ), 1 )
130131 self .assertEqual (dt .dtypes [0 ].dtype , ElemType .float32 )
131132 self .assertEmpty (dt .shape )
132- self .assertEqual (dt .type_name (), "TensorType['float32']" )
133+ self .assertEqual (dt .type_name (), "TensorType['float32', 'F32' ]" )
133134
134- dt = TensorType ["float32" ]
135+ dt = TensorType ["float32" , "F32" ]
135136 self .assertEqual (len (dt .dtypes ), 1 )
136137 self .assertEqual (dt .dtypes [0 ].dtype , ElemType .float32 )
137- self .assertEqual (dt .type_name (), "TensorType['float32']" )
138+ self .assertEqual (dt .type_name (), "TensorType['float32', 'F32' ]" )
138139
139- dt = TensorType [np .float32 ]
140+ dt = TensorType [np .float32 , "F32" ]
140141 self .assertEqual (len (dt .dtypes ), 1 )
141142 self .assertEqual (dt .dtypes [0 ].dtype , ElemType .float32 )
142- self .assertEqual (dt .type_name (), "TensorType['float32']" )
143+ self .assertEqual (dt .type_name (), "TensorType['float32', 'F32' ]" )
143144 self .assertEmpty (dt .shape )
144145
145- dt = TensorType [np .str_ ]
146+ dt = TensorType [np .str_ , "TEXT" ]
146147 self .assertEqual (len (dt .dtypes ), 1 )
147148 self .assertEqual (dt .dtypes [0 ].dtype , ElemType .str_ )
148- self .assertEqual (dt .type_name (), "TensorType[strings]" )
149+ self .assertEqual (dt .type_name (), "TensorType[strings, 'TEXT']" )
150+ self .assertEmpty (dt .shape )
151+
152+ self .assertRaise (lambda : TensorType [None ], TypeError )
153+ self .assertRaise (lambda : TensorType [{np .float32 , np .str_ }], TypeError )
154+
155+ def test_opt_tensor (self ):
156+ dt = OptTensorType ["float32" , "F32" ]
157+ self .assertEqual (len (dt .dtypes ), 1 )
158+ self .assertEqual (dt .dtypes [0 ].dtype , ElemType .float32 )
159+ self .assertEmpty (dt .shape )
160+ self .assertEqual (dt .type_name (), "OptTensorType['float32', 'F32']" )
161+
162+ dt = OptTensorType ["float32" , "F32" ]
163+ self .assertEqual (len (dt .dtypes ), 1 )
164+ self .assertEqual (dt .dtypes [0 ].dtype , ElemType .float32 )
165+ self .assertEqual (dt .type_name (), "OptTensorType['float32', 'F32']" )
166+
167+ dt = OptTensorType [np .float32 , "F32" ]
168+ self .assertEqual (len (dt .dtypes ), 1 )
169+ self .assertEqual (dt .dtypes [0 ].dtype , ElemType .float32 )
170+ self .assertEqual (dt .type_name (), "OptTensorType['float32', 'F32']" )
171+ self .assertEmpty (dt .shape )
172+
173+ dt = OptTensorType [np .str_ , "TEXT" ]
174+ self .assertEqual (len (dt .dtypes ), 1 )
175+ self .assertEqual (dt .dtypes [0 ].dtype , ElemType .str_ )
176+ self .assertEqual (dt .type_name (), "OptTensorType[strings, 'TEXT']" )
149177 self .assertEmpty (dt .shape )
150178
151179 self .assertRaise (lambda : TensorType [None ], TypeError )
152180 self .assertRaise (lambda : TensorType [{np .float32 , np .str_ }], TypeError )
153181
154182 def test_superset (self ):
155- t1 = TensorType [ElemType .numerics ]
156- t2 = TensorType [ElemType .float64 ]
183+ t1 = TensorType [ElemType .numerics , "T" ]
184+ t2 = TensorType [ElemType .float64 , "F64" ]
157185 self .assertTrue (t1 .issuperset (t2 ))
158186 t1 = Float32 [None ]
159187 t2 = Float32 [None ]
@@ -167,14 +195,14 @@ def test_superset(self):
167195 t1 = Float32 ["N" ]
168196 t2 = Float32 [5 ]
169197 self .assertTrue (t1 .issuperset (t2 ))
170- t1 = TensorType [ElemType .int64 ]
198+ t1 = TensorType [ElemType .int64 , "I" ]
171199 t2 = Int64 [1 ]
172200 self .assertTrue (t1 .issuperset (t2 ))
173201
174202 def test_sig (self ):
175203 def local1 (
176- x : TensorType [ElemType .floats ],
177- ) -> TensorType [ElemType .floats ]:
204+ x : TensorType [ElemType .floats , "T" ],
205+ ) -> TensorType [ElemType .floats , "T" ]:
178206 return x
179207
180208 def local2 (
@@ -2536,13 +2564,17 @@ def test_numpy_all_empty_axis_1(self):
25362564 got = ref .run (None , {"A" : data })
25372565 self .assertEqualArray (y , got [0 ])
25382566
2539- @unittest .skipIf (True , reason = "Fails to follow Array API" )
2540- def test_get_item (self ):
2567+ def test_get_item_b (self ):
25412568 a = EagerNumpyTensor (np .array ([True ], dtype = np .bool_ ))
25422569 i = a [0 ]
25432570 self .assertEqualArray (i .numpy (), a .numpy ()[0 ])
25442571
2572+ def test_get_item_i8 (self ):
2573+ a = EagerNumpyTensor (np .array ([5 , 6 ], dtype = np .int8 ))
2574+ i = a [0 ]
2575+ self .assertEqualArray (i .numpy (), a .numpy ()[0 ])
2576+
25452577
25462578if __name__ == "__main__" :
2547- # TestNpx().test_get_item ()
2579+ TestNpx ().test_filter ()
25482580 unittest .main (verbosity = 2 )
0 commit comments