Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit cc09ebb

Browse files
committed
fix problem with tests
1 parent 718a114 commit cc09ebb

File tree

2 files changed

+19
-20
lines changed

2 files changed

+19
-20
lines changed

numba_typing/tests/test_type_annotations.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
from typing import Union, Dict, List, TypeVar
44

55

6+
def check_equal(result, expected):
7+
if len(result) != len(expected):
8+
return False
9+
for sig in result:
10+
if sig not in expected:
11+
return False
12+
return True
13+
14+
615
class TestTypeAnnotations(unittest.TestCase):
716

817
def test_get_func_annotations_exceptions(self):
@@ -46,17 +55,6 @@ def func_three(a: Dict[int, str], b: str = "string", c: int = 1):
4655
with self.subTest(func=f.__name__):
4756
self.assertEqual(type_annotations.get_func_annotations(f), expected)
4857

49-
''' def test_product_annotations(self):
50-
51-
S = TypeVar('S', float, str)
52-
annotations = ({'a': [int], 'b': [int, float], 'c': [S]}, {})
53-
result = type_annotations.product_annotations(annotations)
54-
expected = [[{'a': int, 'b': int, 'c': float}, {}],
55-
[{'a': int, 'b': int, 'c': str}, {}],
56-
[{'a': int, 'b': float, 'c': float}, {}],
57-
[{'a': int, 'b': float, 'c': str}, {}]]
58-
self.assertEqual(result, expected)'''
59-
6058
def test_convert_to_sig_list(self):
6159
T = TypeVar('T', int, str)
6260
S = TypeVar('S', float, str)
@@ -133,7 +131,8 @@ def test_get_internal_typevars(self):
133131
{'a': str, 'b': Dict[str, bool]}]
134132

135133
result = type_annotations.get_internal_typevars(signature)
136-
self.assertEqual(result, expected)
134+
135+
self.assertTrue(check_equal(result, expected))
137136

138137
def test_update_sig(self):
139138
T = TypeVar('T', int, str)
@@ -158,7 +157,8 @@ def test_expand_typevars(self):
158157
{'a': str, 'b': Dict[str, bool], 'c': int}]
159158

160159
result = type_annotations.expand_typevars(sig, unique_typevars)
161-
self.assertEqual(result, expected)
160+
161+
self.assertTrue(check_equal(result, expected))
162162

163163
def test_product_annotations(self):
164164

@@ -176,12 +176,10 @@ def test_product_annotations(self):
176176
[{'a': int, 'b': Dict[int, bool], 'c': bool, 'd': int}, {'d': 3}],
177177
[{'a': str, 'b': Dict[str, float], 'c': bool, 'd': int}, {'d': 3}],
178178
[{'a': str, 'b': Dict[str, bool], 'c': bool, 'd': int}, {'d': 3}]]
179-
180179

181180
result = type_annotations.product_annotations(annotations)
182-
#print(result)
183-
184-
self.assertEqual(result, expected)
181+
182+
self.assertTrue(check_equal(result, expected))
185183

186184

187185
if __name__ == '__main__':

numba_typing/type_annotations.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from inspect import signature
2-
from typing import get_type_hints, Union, TypeVar,_GenericAlias
2+
from typing import get_type_hints, Union, TypeVar, _GenericAlias
33
from itertools import product
44
from copy import deepcopy
55

@@ -43,15 +43,16 @@ def get_annotation_types(annotation):
4343
def product_annotations(annotations):
4444
'''Get all variants of annotations.'''
4545
types, vals = annotations
46-
list_of_sig = convert_to_sig_list(types)
46+
list_of_sig = convert_to_sig_list(types)
4747
signature = []
4848
#unique_typevars = get_internal_typevars(list_of_sig)
4949

5050
for sig in list_of_sig:
5151
signature.extend(get_internal_typevars(sig))
52-
52+
5353
return add_vals_to_signature(signature, vals)
5454

55+
5556
def add_vals_to_signature(signature, vals):
5657
'''Add default values ​​to all signatures'''
5758
result = []

0 commit comments

Comments
 (0)