|
43 | 43 |
|
44 | 44 | import argparse |
45 | 45 | import glob |
| 46 | +import keyword |
46 | 47 | import os |
47 | 48 | import os.path |
48 | 49 | import sys |
|
80 | 81 | ClassDef, |
81 | 82 | ComparisonExpr, |
82 | 83 | Decorator, |
| 84 | + DictExpr, |
83 | 85 | EllipsisExpr, |
84 | 86 | Expression, |
85 | 87 | FloatExpr, |
|
126 | 128 | from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression |
127 | 129 | from mypy.types import ( |
128 | 130 | OVERLOAD_NAMES, |
| 131 | + TPDICT_NAMES, |
129 | 132 | AnyType, |
130 | 133 | CallableType, |
131 | 134 | Instance, |
@@ -405,6 +408,14 @@ def visit_tuple_expr(self, node: TupleExpr) -> str: |
405 | 408 | def visit_list_expr(self, node: ListExpr) -> str: |
406 | 409 | return f"[{', '.join(n.accept(self) for n in node.items)}]" |
407 | 410 |
|
| 411 | + def visit_dict_expr(self, o: DictExpr) -> str: |
| 412 | + dict_items = [] |
| 413 | + for key, value in o.items: |
| 414 | + # This is currently only used for TypedDict where all keys are strings. |
| 415 | + assert isinstance(key, StrExpr) |
| 416 | + dict_items.append(f"{key.accept(self)}: {value.accept(self)}") |
| 417 | + return f"{{{', '.join(dict_items)}}}" |
| 418 | + |
408 | 419 | def visit_ellipsis(self, node: EllipsisExpr) -> str: |
409 | 420 | return "..." |
410 | 421 |
|
@@ -641,6 +652,7 @@ def visit_mypy_file(self, o: MypyFile) -> None: |
641 | 652 | "_typeshed": ["Incomplete"], |
642 | 653 | "typing": ["Any", "TypeVar"], |
643 | 654 | "collections.abc": ["Generator"], |
| 655 | + "typing_extensions": ["TypedDict"], |
644 | 656 | } |
645 | 657 | for pkg, imports in known_imports.items(): |
646 | 658 | for t in imports: |
@@ -1014,6 +1026,13 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: |
1014 | 1026 | assert isinstance(o.rvalue, CallExpr) |
1015 | 1027 | self.process_namedtuple(lvalue, o.rvalue) |
1016 | 1028 | continue |
| 1029 | + if ( |
| 1030 | + isinstance(lvalue, NameExpr) |
| 1031 | + and isinstance(o.rvalue, CallExpr) |
| 1032 | + and self.is_typeddict(o.rvalue) |
| 1033 | + ): |
| 1034 | + self.process_typeddict(lvalue, o.rvalue) |
| 1035 | + continue |
1017 | 1036 | if ( |
1018 | 1037 | isinstance(lvalue, NameExpr) |
1019 | 1038 | and not self.is_private_name(lvalue.name) |
@@ -1082,6 +1101,75 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: |
1082 | 1101 | self.add(f"{self._indent} {item}: Incomplete\n") |
1083 | 1102 | self._state = CLASS |
1084 | 1103 |
|
| 1104 | + def is_typeddict(self, expr: CallExpr) -> bool: |
| 1105 | + callee = expr.callee |
| 1106 | + return ( |
| 1107 | + isinstance(callee, NameExpr) and self.refers_to_fullname(callee.name, TPDICT_NAMES) |
| 1108 | + ) or ( |
| 1109 | + isinstance(callee, MemberExpr) |
| 1110 | + and isinstance(callee.expr, NameExpr) |
| 1111 | + and f"{callee.expr.name}.{callee.name}" in TPDICT_NAMES |
| 1112 | + ) |
| 1113 | + |
| 1114 | + def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None: |
| 1115 | + if self._state != EMPTY: |
| 1116 | + self.add("\n") |
| 1117 | + |
| 1118 | + if not isinstance(rvalue.args[0], StrExpr): |
| 1119 | + self.add(f"{self._indent}{lvalue.name}: Incomplete") |
| 1120 | + self.import_tracker.require_name("Incomplete") |
| 1121 | + return |
| 1122 | + |
| 1123 | + items: list[tuple[str, Expression]] = [] |
| 1124 | + total: Expression | None = None |
| 1125 | + if len(rvalue.args) > 1 and rvalue.arg_kinds[1] == ARG_POS: |
| 1126 | + if not isinstance(rvalue.args[1], DictExpr): |
| 1127 | + self.add(f"{self._indent}{lvalue.name}: Incomplete") |
| 1128 | + self.import_tracker.require_name("Incomplete") |
| 1129 | + return |
| 1130 | + for attr_name, attr_type in rvalue.args[1].items: |
| 1131 | + if not isinstance(attr_name, StrExpr): |
| 1132 | + self.add(f"{self._indent}{lvalue.name}: Incomplete") |
| 1133 | + self.import_tracker.require_name("Incomplete") |
| 1134 | + return |
| 1135 | + items.append((attr_name.value, attr_type)) |
| 1136 | + if len(rvalue.args) > 2: |
| 1137 | + if rvalue.arg_kinds[2] != ARG_NAMED or rvalue.arg_names[2] != "total": |
| 1138 | + self.add(f"{self._indent}{lvalue.name}: Incomplete") |
| 1139 | + self.import_tracker.require_name("Incomplete") |
| 1140 | + return |
| 1141 | + total = rvalue.args[2] |
| 1142 | + else: |
| 1143 | + for arg_name, arg in zip(rvalue.arg_names[1:], rvalue.args[1:]): |
| 1144 | + if not isinstance(arg_name, str): |
| 1145 | + self.add(f"{self._indent}{lvalue.name}: Incomplete") |
| 1146 | + self.import_tracker.require_name("Incomplete") |
| 1147 | + return |
| 1148 | + if arg_name == "total": |
| 1149 | + total = arg |
| 1150 | + else: |
| 1151 | + items.append((arg_name, arg)) |
| 1152 | + self.import_tracker.require_name("TypedDict") |
| 1153 | + p = AliasPrinter(self) |
| 1154 | + if any(not key.isidentifier() or keyword.iskeyword(key) for key, _ in items): |
| 1155 | + # Keep the call syntax if there are non-identifier or keyword keys. |
| 1156 | + self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n") |
| 1157 | + self._state = VAR |
| 1158 | + else: |
| 1159 | + bases = "TypedDict" |
| 1160 | + # TODO: Add support for generic TypedDicts. Requires `Generic` as base class. |
| 1161 | + if total is not None: |
| 1162 | + bases += f", total={total.accept(p)}" |
| 1163 | + self.add(f"{self._indent}class {lvalue.name}({bases}):") |
| 1164 | + if len(items) == 0: |
| 1165 | + self.add(" ...\n") |
| 1166 | + self._state = EMPTY_CLASS |
| 1167 | + else: |
| 1168 | + self.add("\n") |
| 1169 | + for key, key_type in items: |
| 1170 | + self.add(f"{self._indent} {key}: {key_type.accept(p)}\n") |
| 1171 | + self._state = CLASS |
| 1172 | + |
1085 | 1173 | def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool: |
1086 | 1174 | """Return True for things that look like target for an alias. |
1087 | 1175 |
|
|
0 commit comments