Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 92a8a8c

Browse files
committed
add unit tests and handler for union
1 parent f437095 commit 92a8a8c

File tree

3 files changed

+69
-11
lines changed

3 files changed

+69
-11
lines changed

numba_typing/tests/__init__.py

Whitespace-only changes.
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import unittest
2+
import type_annotations
3+
from typing import Union, Dict, List
4+
5+
6+
class TestTypeAnnotations(unittest.TestCase):
7+
8+
def test_get_func_annotations_exception(self):
9+
10+
def foo(a: int, b, c: str = "string"):
11+
pass
12+
self.assertRaises(SyntaxError, type_annotations.get_func_annotations, foo)
13+
14+
def test_get_cls_annotations(self):
15+
class MyClass(object):
16+
x: int = 3
17+
y: str = "string"
18+
19+
def __init__(self, x: str, y: int):
20+
self.x = x
21+
self.y = y
22+
23+
self.assertEqual(type_annotations.get_cls_annotations(MyClass), ({'x': int, 'y': str}, {}))
24+
25+
def test_get_func_annotations(self):
26+
def func_one(a: int, b: Union[int, float], c: str):
27+
pass
28+
with self.subTest("annotations"):
29+
self.assertEqual(type_annotations.get_func_annotations(func_one),
30+
({'a': int, 'b': [int, float], 'c': str}, {}))
31+
32+
def func_two(a: int = 2, b: str = "string", c: List[int] = [1, 2, 3]):
33+
pass
34+
with self.subTest("annotations and all default values"):
35+
self.assertEqual(type_annotations.get_func_annotations(func_two),
36+
({'a': int, 'b': str, 'c': List[int]}, {'a': 2, 'b': 'string', 'c': [1, 2, 3]}))
37+
38+
def func_three(a: Dict[int, str], b: str = "string", c: int = 1):
39+
pass
40+
with self.subTest("annotations and not all default values"):
41+
self.assertEqual(type_annotations.get_func_annotations(func_three),
42+
({'a': Dict[int, str], 'b': str, 'c': int}, {'b': 'string', 'c': 1}))
43+
44+
45+
if __name__ == '__main__':
46+
unittest.main()

numba_typing/type_annotations.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from inspect import signature
2-
from typing import get_type_hints
2+
from typing import get_type_hints, Union
33

44

55
def get_func_annotations(func):
@@ -11,8 +11,11 @@ def get_func_annotations(func):
1111
for name, param in sig.parameters.items():
1212
if param.annotation == sig.empty:
1313
raise SyntaxError(f'Not found annotation for parameter {name}')
14-
15-
annotations[name] = param.annotation
14+
annot = param.annotation
15+
if get_args_union(annot):
16+
annotations[name] = get_args_union(annot)
17+
else:
18+
annotations[name] = annot
1619
if param.default != sig.empty:
1720
defaults[name] = param.default
1821

@@ -21,11 +24,20 @@ def get_func_annotations(func):
2124

2225
def get_cls_annotations(cls):
2326
"""Get annotations of class attributes."""
24-
return get_type_hints(cls)
25-
26-
27-
if __name__ == '__main__':
28-
def foo(a: int, b: int = 3):
29-
pass
30-
31-
print(get_func_annotations(foo))
27+
annotations = get_type_hints(cls)
28+
for x in annotations:
29+
if get_args_union(annotations[x]):
30+
annotations[x] = get_args_union(annotations[x])
31+
return annotations, {}
32+
33+
34+
def get_args_union(annot):
35+
try:
36+
annot.__origin__
37+
except:
38+
return None
39+
else:
40+
if annot.__origin__ is Union:
41+
return list(annot.__args__)
42+
else:
43+
return None

0 commit comments

Comments
 (0)