104104 TupleExpr ,
105105 TypeInfo ,
106106 UnaryExpr ,
107- is_StrExpr_list ,
108107)
109108from mypy .options import Options as MypyOptions
110109from mypy .stubdoc import Sig , find_unique_signatures , parse_all_signatures
129128from mypy .types import (
130129 OVERLOAD_NAMES ,
131130 TPDICT_NAMES ,
131+ TYPED_NAMEDTUPLE_NAMES ,
132132 AnyType ,
133133 CallableType ,
134134 Instance ,
@@ -400,10 +400,12 @@ def visit_str_expr(self, node: StrExpr) -> str:
400400 def visit_index_expr (self , node : IndexExpr ) -> str :
401401 base = node .base .accept (self )
402402 index = node .index .accept (self )
403+ if len (index ) > 2 and index .startswith ("(" ) and index .endswith (")" ):
404+ index = index [1 :- 1 ]
403405 return f"{ base } [{ index } ]"
404406
405407 def visit_tuple_expr (self , node : TupleExpr ) -> str :
406- return ", " .join (n .accept (self ) for n in node .items )
408+ return f"( { ', ' .join (n .accept (self ) for n in node .items )} )"
407409
408410 def visit_list_expr (self , node : ListExpr ) -> str :
409411 return f"[{ ', ' .join (n .accept (self ) for n in node .items )} ]"
@@ -1010,6 +1012,37 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
10101012 elif isinstance (base , IndexExpr ):
10111013 p = AliasPrinter (self )
10121014 base_types .append (base .accept (p ))
1015+ elif isinstance (base , CallExpr ):
1016+ # namedtuple(typename, fields), NamedTuple(typename, fields) calls can
1017+ # be used as a base class. The first argument is a string literal that
1018+ # is usually the same as the class name.
1019+ #
1020+ # Note:
1021+ # A call-based named tuple as a base class cannot be safely converted to
1022+ # a class-based NamedTuple definition because class attributes defined
1023+ # in the body of the class inheriting from the named tuple call are not
1024+ # namedtuple fields at runtime.
1025+ if self .is_namedtuple (base ):
1026+ nt_fields = self ._get_namedtuple_fields (base )
1027+ assert isinstance (base .args [0 ], StrExpr )
1028+ typename = base .args [0 ].value
1029+ if nt_fields is not None :
1030+ # A valid namedtuple() call, use NamedTuple() instead with
1031+ # Incomplete as field types
1032+ fields_str = ", " .join (f"({ f !r} , { t } )" for f , t in nt_fields )
1033+ base_types .append (f"NamedTuple({ typename !r} , [{ fields_str } ])" )
1034+ self .add_typing_import ("NamedTuple" )
1035+ else :
1036+ # Invalid namedtuple() call, cannot determine fields
1037+ base_types .append ("Incomplete" )
1038+ elif self .is_typed_namedtuple (base ):
1039+ p = AliasPrinter (self )
1040+ base_types .append (base .accept (p ))
1041+ else :
1042+ # At this point, we don't know what the base class is, so we
1043+ # just use Incomplete as the base class.
1044+ base_types .append ("Incomplete" )
1045+ self .import_tracker .require_name ("Incomplete" )
10131046 return base_types
10141047
10151048 def visit_block (self , o : Block ) -> None :
@@ -1022,8 +1055,11 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
10221055 foundl = []
10231056
10241057 for lvalue in o .lvalues :
1025- if isinstance (lvalue , NameExpr ) and self .is_namedtuple (o .rvalue ):
1026- assert isinstance (o .rvalue , CallExpr )
1058+ if (
1059+ isinstance (lvalue , NameExpr )
1060+ and isinstance (o .rvalue , CallExpr )
1061+ and (self .is_namedtuple (o .rvalue ) or self .is_typed_namedtuple (o .rvalue ))
1062+ ):
10271063 self .process_namedtuple (lvalue , o .rvalue )
10281064 continue
10291065 if (
@@ -1069,37 +1105,79 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
10691105 if all (foundl ):
10701106 self ._state = VAR
10711107
1072- def is_namedtuple (self , expr : Expression ) -> bool :
1073- if not isinstance (expr , CallExpr ):
1074- return False
1108+ def is_namedtuple (self , expr : CallExpr ) -> bool :
10751109 callee = expr .callee
1076- return (isinstance (callee , NameExpr ) and callee .name .endswith ("namedtuple" )) or (
1077- isinstance (callee , MemberExpr ) and callee .name == "namedtuple"
1110+ return (
1111+ isinstance (callee , NameExpr )
1112+ and (self .refers_to_fullname (callee .name , "collections.namedtuple" ))
1113+ ) or (
1114+ isinstance (callee , MemberExpr )
1115+ and isinstance (callee .expr , NameExpr )
1116+ and f"{ callee .expr .name } .{ callee .name } " == "collections.namedtuple"
10781117 )
10791118
1119+ def is_typed_namedtuple (self , expr : CallExpr ) -> bool :
1120+ callee = expr .callee
1121+ return (
1122+ isinstance (callee , NameExpr )
1123+ and self .refers_to_fullname (callee .name , TYPED_NAMEDTUPLE_NAMES )
1124+ ) or (
1125+ isinstance (callee , MemberExpr )
1126+ and isinstance (callee .expr , NameExpr )
1127+ and f"{ callee .expr .name } .{ callee .name } " in TYPED_NAMEDTUPLE_NAMES
1128+ )
1129+
1130+ def _get_namedtuple_fields (self , call : CallExpr ) -> list [tuple [str , str ]] | None :
1131+ if self .is_namedtuple (call ):
1132+ fields_arg = call .args [1 ]
1133+ if isinstance (fields_arg , StrExpr ):
1134+ field_names = fields_arg .value .replace ("," , " " ).split ()
1135+ elif isinstance (fields_arg , (ListExpr , TupleExpr )):
1136+ field_names = []
1137+ for field in fields_arg .items :
1138+ if not isinstance (field , StrExpr ):
1139+ return None
1140+ field_names .append (field .value )
1141+ else :
1142+ return None # Invalid namedtuple fields type
1143+ if field_names :
1144+ self .import_tracker .require_name ("Incomplete" )
1145+ return [(field_name , "Incomplete" ) for field_name in field_names ]
1146+ elif self .is_typed_namedtuple (call ):
1147+ fields_arg = call .args [1 ]
1148+ if not isinstance (fields_arg , (ListExpr , TupleExpr )):
1149+ return None
1150+ fields : list [tuple [str , str ]] = []
1151+ b = AliasPrinter (self )
1152+ for field in fields_arg .items :
1153+ if not (isinstance (field , TupleExpr ) and len (field .items ) == 2 ):
1154+ return None
1155+ field_name , field_type = field .items
1156+ if not isinstance (field_name , StrExpr ):
1157+ return None
1158+ fields .append ((field_name .value , field_type .accept (b )))
1159+ return fields
1160+ else :
1161+ return None # Not a named tuple call
1162+
10801163 def process_namedtuple (self , lvalue : NameExpr , rvalue : CallExpr ) -> None :
10811164 if self ._state != EMPTY :
10821165 self .add ("\n " )
1083- if isinstance (rvalue .args [1 ], StrExpr ):
1084- items = rvalue .args [1 ].value .replace ("," , " " ).split ()
1085- elif isinstance (rvalue .args [1 ], (ListExpr , TupleExpr )):
1086- list_items = rvalue .args [1 ].items
1087- assert is_StrExpr_list (list_items )
1088- items = [item .value for item in list_items ]
1089- else :
1166+ fields = self ._get_namedtuple_fields (rvalue )
1167+ if fields is None :
10901168 self .add (f"{ self ._indent } { lvalue .name } : Incomplete" )
10911169 self .import_tracker .require_name ("Incomplete" )
10921170 return
10931171 self .import_tracker .require_name ("NamedTuple" )
10941172 self .add (f"{ self ._indent } class { lvalue .name } (NamedTuple):" )
1095- if not items :
1173+ if len ( fields ) == 0 :
10961174 self .add (" ...\n " )
1175+ self ._state = EMPTY_CLASS
10971176 else :
1098- self .import_tracker .require_name ("Incomplete" )
10991177 self .add ("\n " )
1100- for item in items :
1101- self .add (f"{ self ._indent } { item } : Incomplete \n " )
1102- self ._state = CLASS
1178+ for f_name , f_type in fields :
1179+ self .add (f"{ self ._indent } { f_name } : { f_type } \n " )
1180+ self ._state = CLASS
11031181
11041182 def is_typeddict (self , expr : CallExpr ) -> bool :
11051183 callee = expr .callee
0 commit comments