Skip to content

Commit e800d9f

Browse files
committed
Backport type stubs in python<=3.9 to support PEP 585 and 604 syntax.
1 parent 7da8b82 commit e800d9f

File tree

6 files changed

+92
-8
lines changed

6 files changed

+92
-8
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ Added
2121
- Support for pydantic models and attr defines similar to dataclasses.
2222
- Support for `pydantic types
2323
<https://docs.pydantic.dev/usage/types/#pydantic-types>`__.
24+
- Backport type stubs in python<=3.9 to support PEP `585
25+
<https://peps.python.org/pep-0585/>`__ and `604
26+
<https://peps.python.org/pep-0604/>`__ syntax.
2427

2528
Fixed
2629
^^^^^
@@ -39,6 +42,8 @@ Changed
3942
- Include enum members in error when invalid value is given
4043
(`pytorch-lightning#17247
4144
<https://github.com/Lightning-AI/lightning/issues/17247>`__).
45+
- The ``signatures`` extras now installs the ``typing-extensions`` package on
46+
python<=3.9.
4247

4348
Deprecated
4449
^^^^^^^^^^

README.rst

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,10 +1689,9 @@ Many of the types defined in stub files use the latest syntax for type hints,
16891689
that is, bitwise or operator ``|`` for unions and generics, e.g.
16901690
``list[<type>]`` instead of ``typing.List[<type>]``, see PEPs `604
16911691
<https://peps.python.org/pep-0604>`__ and `585
1692-
<https://peps.python.org/pep-0585>`__. The types with this new syntax can't be
1693-
evaluated at runtime in Python versions older than ``3.10``. Since jsonargparse
1694-
needs to interpret the types at runtime, these will only be resolved in newer
1695-
versions of Python.
1692+
<https://peps.python.org/pep-0585>`__. On python>=3.10 these are fully
1693+
supported. On python<=3.9 backporting these types is attempted and in some cases
1694+
it can fail. On failure the type annotation is set to ``Any``.
16961695

16971696
Most of the types in the Python standard library have their types in stubs. An
16981697
example from the standard library would be:

jsonargparse/_backports.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import ast
2+
from collections import namedtuple
3+
from copy import deepcopy
4+
from typing import Dict, FrozenSet, List, Set, Tuple, Type, Union
5+
6+
var_map = namedtuple('var_map', 'name value')
7+
none_map = var_map(name='NoneType', value=type(None))
8+
union_map = var_map(name='Union', value=Union)
9+
pep585_map = {
10+
'dict': var_map(name='Dict', value=Dict),
11+
'frozenset': var_map(name='FrozenSet', value=FrozenSet),
12+
'list': var_map(name='List', value=List),
13+
'set': var_map(name='Set', value=Set),
14+
'tuple': var_map(name='Tuple', value=Tuple),
15+
'type': var_map(name='Type', value=Type),
16+
}
17+
18+
19+
class BackportTypeHints(ast.NodeTransformer):
20+
21+
_typing = __import__('typing')
22+
23+
def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
24+
if isinstance(node.value, ast.Name) and node.value.id in pep585_map:
25+
value = self.new_name_load(pep585_map[node.value.id])
26+
else:
27+
value = node.value # type: ignore
28+
return ast.Subscript(
29+
value=value,
30+
slice=self.visit(node.slice),
31+
ctx=ast.Load(),
32+
)
33+
34+
def visit_Constant(self, node: ast.Constant) -> Union[ast.Constant, ast.Name]:
35+
if node.value is None:
36+
return self.new_name_load(none_map)
37+
return node
38+
39+
def visit_BinOp(self, node: ast.BinOp) -> Union[ast.BinOp, ast.Subscript]:
40+
out_node: Union[ast.BinOp, ast.Subscript] = node
41+
if isinstance(node.op, ast.BitOr):
42+
elts: list = []
43+
self.append_union_elts(node.left, elts)
44+
self.append_union_elts(node.right, elts)
45+
out_node = ast.Subscript(
46+
value=self.new_name_load(union_map),
47+
slice=ast.Index(
48+
value=ast.Tuple(elts=elts, ctx=ast.Load()),
49+
ctx=ast.Load(),
50+
),
51+
ctx=ast.Load(),
52+
)
53+
return out_node
54+
55+
def append_union_elts(self, node: ast.AST, elts: list) -> None:
56+
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
57+
self.append_union_elts(node.left, elts)
58+
self.append_union_elts(node.right, elts)
59+
else:
60+
elts.append(self.visit(node))
61+
62+
def new_name_load(self, var: var_map) -> ast.Name:
63+
name = f'_{self.__class__.__name__}_{var.name}'
64+
self.exec_vars[name] = var.value
65+
return ast.Name(id=name, ctx=ast.Load())
66+
67+
def backport(self, input_ast: ast.AST, exec_vars: dict) -> ast.AST:
68+
for key, value in exec_vars.items():
69+
if getattr(value, '__module__', '') == 'collections.abc':
70+
if hasattr(self._typing, key):
71+
exec_vars[key] = getattr(self._typing, key)
72+
self.exec_vars = exec_vars
73+
backport_ast = self.visit(deepcopy(input_ast))
74+
return ast.fix_missing_locations(backport_ast)

jsonargparse/_stubs_resolver.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,10 @@ def get_arg_type(arg_ast, aliases):
293293
type_alias = typing_extensions_import('TypeAlias')
294294
if type_alias:
295295
exec_vars['TypeAlias'] = type_alias
296+
if sys.version_info < (3, 10):
297+
from ._backports import BackportTypeHints
298+
backporter = BackportTypeHints()
299+
type_ast = backporter.backport(type_ast, exec_vars)
296300
try:
297301
exec(compile(type_ast, filename="<ast>", mode="exec"), exec_vars, exec_vars)
298302
except NameError as ex:

jsonargparse_tests/test_stubs_resolver.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,10 @@ def test_get_params_conditional_python_version(self):
9595
self.assertEqual(['a', 'version'], get_param_names(params))
9696
if sys.version_info >= (3, 10):
9797
self.assertEqual('int | float | str | bytes | bytearray | None', str(params[0].annotation))
98+
elif sys.version_info[:2] == (3, 9):
99+
self.assertEqual('typing.Union[int, float, str, bytes, bytearray, NoneType]', str(params[0].annotation))
98100
else:
99-
expected = Any if sys.version_info < (3, 9) else inspect._empty
100-
self.assertEqual(expected, params[0].annotation)
101+
self.assertEqual(Any, params[0].annotation)
101102
self.assertEqual(int, params[1].annotation)
102103
with mock_typeshed_client_unavailable():
103104
params = get_params(Random, 'seed')
@@ -140,7 +141,7 @@ def test_get_params_function(self):
140141
def test_get_param_relative_import_from_init(self):
141142
params = get_params(yaml.safe_load)
142143
self.assertEqual(['stream'], get_param_names(params))
143-
if sys.version_info >= (3, 10):
144+
if sys.version_info >= (3, 8):
144145
self.assertNotEqual(params[0].annotation, inspect._empty)
145146
else:
146147
self.assertEqual(params[0].annotation, inspect._empty)

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ all = [
4949
"jsonargparse[reconplogger]",
5050
]
5151
signatures = [
52+
"jsonargparse[typing-extensions]",
5253
"docstring-parser>=0.15",
5354
"typeshed-client>=2.1.0",
5455
]
@@ -75,7 +76,7 @@ omegaconf = [
7576
"omegaconf>=2.1.1",
7677
]
7778
typing-extensions = [
78-
"typing-extensions>=3.10.0.0; python_version < '3.8'",
79+
"typing-extensions>=3.10.0.0; python_version < '3.10'",
7980
]
8081
reconplogger = [
8182
"reconplogger>=4.4.0",

0 commit comments

Comments
 (0)