|
| 1 | +import ast |
| 2 | +from collections import namedtuple |
| 3 | +from copy import deepcopy |
| 4 | +from typing import Dict, FrozenSet, List, Set, Tuple, Type, Union |
| 5 | + |
| 6 | +var_map = namedtuple('var_map', 'name value') |
| 7 | +none_map = var_map(name='NoneType', value=type(None)) |
| 8 | +union_map = var_map(name='Union', value=Union) |
| 9 | +pep585_map = { |
| 10 | + 'dict': var_map(name='Dict', value=Dict), |
| 11 | + 'frozenset': var_map(name='FrozenSet', value=FrozenSet), |
| 12 | + 'list': var_map(name='List', value=List), |
| 13 | + 'set': var_map(name='Set', value=Set), |
| 14 | + 'tuple': var_map(name='Tuple', value=Tuple), |
| 15 | + 'type': var_map(name='Type', value=Type), |
| 16 | +} |
| 17 | + |
| 18 | + |
| 19 | +class BackportTypeHints(ast.NodeTransformer): |
| 20 | + |
| 21 | + _typing = __import__('typing') |
| 22 | + |
| 23 | + def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: |
| 24 | + if isinstance(node.value, ast.Name) and node.value.id in pep585_map: |
| 25 | + value = self.new_name_load(pep585_map[node.value.id]) |
| 26 | + else: |
| 27 | + value = node.value # type: ignore |
| 28 | + return ast.Subscript( |
| 29 | + value=value, |
| 30 | + slice=self.visit(node.slice), |
| 31 | + ctx=ast.Load(), |
| 32 | + ) |
| 33 | + |
| 34 | + def visit_Constant(self, node: ast.Constant) -> Union[ast.Constant, ast.Name]: |
| 35 | + if node.value is None: |
| 36 | + return self.new_name_load(none_map) |
| 37 | + return node |
| 38 | + |
| 39 | + def visit_BinOp(self, node: ast.BinOp) -> Union[ast.BinOp, ast.Subscript]: |
| 40 | + out_node: Union[ast.BinOp, ast.Subscript] = node |
| 41 | + if isinstance(node.op, ast.BitOr): |
| 42 | + elts: list = [] |
| 43 | + self.append_union_elts(node.left, elts) |
| 44 | + self.append_union_elts(node.right, elts) |
| 45 | + out_node = ast.Subscript( |
| 46 | + value=self.new_name_load(union_map), |
| 47 | + slice=ast.Index( |
| 48 | + value=ast.Tuple(elts=elts, ctx=ast.Load()), |
| 49 | + ctx=ast.Load(), |
| 50 | + ), |
| 51 | + ctx=ast.Load(), |
| 52 | + ) |
| 53 | + return out_node |
| 54 | + |
| 55 | + def append_union_elts(self, node: ast.AST, elts: list) -> None: |
| 56 | + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): |
| 57 | + self.append_union_elts(node.left, elts) |
| 58 | + self.append_union_elts(node.right, elts) |
| 59 | + else: |
| 60 | + elts.append(self.visit(node)) |
| 61 | + |
| 62 | + def new_name_load(self, var: var_map) -> ast.Name: |
| 63 | + name = f'_{self.__class__.__name__}_{var.name}' |
| 64 | + self.exec_vars[name] = var.value |
| 65 | + return ast.Name(id=name, ctx=ast.Load()) |
| 66 | + |
| 67 | + def backport(self, input_ast: ast.AST, exec_vars: dict) -> ast.AST: |
| 68 | + for key, value in exec_vars.items(): |
| 69 | + if getattr(value, '__module__', '') == 'collections.abc': |
| 70 | + if hasattr(self._typing, key): |
| 71 | + exec_vars[key] = getattr(self._typing, key) |
| 72 | + self.exec_vars = exec_vars |
| 73 | + backport_ast = self.visit(deepcopy(input_ast)) |
| 74 | + return ast.fix_missing_locations(backport_ast) |
0 commit comments