Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 402cae1

Browse files
committed
fix comments - improve product_annotations function
1 parent aba8d0b commit 402cae1

File tree

1 file changed

+18
-22
lines changed

1 file changed

+18
-22
lines changed

numba_typing/type_annotations.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -42,32 +42,28 @@ def get_annotation_types(annotation):
4242

4343
def product_annotations(annotations):
4444
"""Get all variants of annotations."""
45-
annot = annotations[0]
46-
vals = annotations[1]
47-
list_of_annot = list(product(*annot.values()))
48-
tvs = {}
49-
tvs_unique = {}
45+
types, vals = annotations
46+
types_product = list(product(*types.values()))
47+
typevars_unique = {}
5048
count = 1
51-
for x in annot:
52-
for y in annot[x]:
53-
if isinstance(y, TypeVar) and y.__constraints__ != ():
54-
if x in tvs:
55-
tvs[x].append(y)
56-
else:
57-
tvs[x] = [y, ]
58-
if y not in tvs_unique.keys():
59-
tvs_unique[y] = y.__constraints__
60-
count *= len(y.__constraints__)
61-
62-
prod = list(product(*tvs_unique.values()))
49+
for name, typs in types.items():
50+
for typ in typs:
51+
if not isinstance(typ, TypeVar) or not typ.__constraints__:
52+
continue
53+
54+
if typ not in typevars_unique:
55+
typevars_unique[typ] = typ.__constraints__
56+
count *= len(typ.__constraints__)
57+
58+
prod = list(product(*typevars_unique.values()))
6359
temp_res = []
6460

65-
for i in range(len(list_of_annot)):
61+
for typs in types_product:
6662
temp = []
6763
temp_dict = {}
6864
num = 0
69-
for attr in annot:
70-
temp_dict[attr] = list_of_annot[i][num]
65+
for attr in types:
66+
temp_dict[attr] = typs[num]
7167
num += 1
7268
temp.append(temp_dict)
7369
temp.append(vals)
@@ -78,12 +74,12 @@ def product_annotations(annotations):
7874
for i in range(count):
7975
result.append(deepcopy(examp))
8076

81-
types = list(tvs_unique.keys())
77+
name_of_typevars = list(typevars_unique.keys())
8278
for k in range(len(result)):
8379
pos = k % count
8480
for x in result[k][0]:
8581
for i in range(len(prod[pos])):
86-
if result[k][0][x] == types[i]:
82+
if result[k][0][x] == name_of_typevars[i]:
8783
result[k][0][x] = prod[pos][i]
8884

8985
return result

0 commit comments

Comments
 (0)