Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit ef7e313

Browse files
committed
add function expend_annotations
1 parent 1700c4c commit ef7e313

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

numba_typing/type_annotations.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from inspect import signature
22
from typing import get_type_hints, Union
3+
from itertools import product
4+
from copy import deepcopy
35

46

57
def get_func_annotations(func):
@@ -10,7 +12,7 @@ def get_func_annotations(func):
1012

1113
for name, param in sig.parameters.items():
1214
if param.annotation == sig.empty:
13-
raise SyntaxError(f'Not found annotation for parameter {name}')
15+
raise SyntaxError(f'No annotation for parameter {name}')
1416

1517
annotations[name] = get_annotation_types(param.annotation)
1618
if param.default != sig.empty:
@@ -36,3 +38,52 @@ def get_annotation_types(annotation):
3638
pass
3739

3840
return [annotation, ]
41+
42+
43+
def expend_annotations(info):
44+
"""Get all variants of annotations."""
45+
annot = info[0]
46+
vals = info[1]
47+
list_of_annot = list(product(*annot.values()))
48+
tvs = {}
49+
tvs_unique = {}
50+
count = 1
51+
for x in annot:
52+
for y in annot[x]:
53+
if isinstance(y, TypeVar) and y.__constraints__ != ():
54+
if x in tvs:
55+
tvs[x].append(y)
56+
else:
57+
tvs[x] = [y, ]
58+
if y not in tvs_unique.keys():
59+
tvs_unique[y] = y.__constraints__
60+
count *= len(y.__constraints__)
61+
62+
prod = list(product(*tvs_unique.values()))
63+
temp_res = []
64+
65+
for i in range(len(list_of_annot)):
66+
temp = []
67+
temp_dict = {}
68+
num = 0
69+
for attr in annot:
70+
temp_dict[attr] = list_of_annot[i][num]
71+
num += 1
72+
temp.append(temp_dict)
73+
temp.append(vals)
74+
temp_res.append(temp)
75+
76+
result = []
77+
for examp in temp_res:
78+
for i in range(count):
79+
result.append(deepcopy(examp))
80+
81+
types = list(tvs_unique.keys())
82+
for k in range(len(result)):
83+
pos = k % count
84+
for x in result[k][0]:
85+
for i in range(len(prod[pos])):
86+
if result[k][0][x] == types[i]:
87+
result[k][0][x] = prod[pos][i]
88+
89+
return result

0 commit comments

Comments
 (0)