|
46 | 46 | Var, |
47 | 47 | is_class_var, |
48 | 48 | ) |
49 | | -from mypy.plugin import FunctionContext, SemanticAnalyzerPluginInterface |
| 49 | +from mypy.plugin import SemanticAnalyzerPluginInterface |
50 | 50 | from mypy.plugins.common import ( |
51 | 51 | _get_argument, |
52 | 52 | _get_bool_argument, |
@@ -1062,27 +1062,41 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl |
1062 | 1062 | ) |
1063 | 1063 |
|
1064 | 1064 |
|
1065 | | -def _get_cls_from_init(t: Type) -> TypeInfo | None: |
1066 | | - proper_type = get_proper_type(t) |
1067 | | - if isinstance(proper_type, CallableType): |
1068 | | - return proper_type.type_object() |
1069 | | - return None |
| 1065 | +def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType: |
| 1066 | + """Provide the proper signature for `attrs.fields`.""" |
| 1067 | + if ctx.args and len(ctx.args) == 1 and ctx.args[0] and ctx.args[0][0]: |
| 1068 | + # <hack> |
| 1069 | + assert isinstance(ctx.api, TypeChecker) |
| 1070 | + inst_type = ctx.api.expr_checker.accept(ctx.args[0][0]) |
| 1071 | + # </hack> |
| 1072 | + proper_type = get_proper_type(inst_type) |
1070 | 1073 |
|
| 1074 | + if isinstance(proper_type, AnyType): # fields(Any) -> Any |
| 1075 | + return ctx.default_signature |
| 1076 | + |
| 1077 | + cls = None |
| 1078 | + arg_types = ctx.default_signature.arg_types |
| 1079 | + |
| 1080 | + if isinstance(proper_type, TypeVarType): |
| 1081 | + inner = get_proper_type(proper_type.upper_bound) |
| 1082 | + if isinstance(inner, Instance): |
| 1083 | + # We need to work arg_types to compensate for the attrs stubs. |
| 1084 | + arg_types = [inst_type] |
| 1085 | + cls = inner.type |
| 1086 | + elif isinstance(proper_type, CallableType): |
| 1087 | + cls = proper_type.type_object() |
1071 | 1088 |
|
1072 | | -def fields_function_callback(ctx: FunctionContext) -> Type: |
1073 | | - """Provide the proper return value for `attrs.fields`.""" |
1074 | | - if ctx.arg_types and ctx.arg_types[0] and ctx.arg_types[0][0]: |
1075 | | - first_arg_type = ctx.arg_types[0][0] |
1076 | | - cls = _get_cls_from_init(first_arg_type) |
1077 | 1089 | if cls is not None: |
1078 | 1090 | if MAGIC_ATTR_NAME in cls.names: |
1079 | 1091 | # This is a proper attrs class. |
1080 | 1092 | ret_type = cls.names[MAGIC_ATTR_NAME].type |
1081 | 1093 | if ret_type is not None: |
1082 | | - return ret_type |
1083 | | - else: |
1084 | | - ctx.api.fail( |
1085 | | - f'Argument 1 to "fields" has incompatible type "{format_type_bare(first_arg_type)}"; expected an attrs class', |
1086 | | - ctx.context, |
1087 | | - ) |
1088 | | - return ctx.default_return_type |
| 1094 | + return ctx.default_signature.copy_modified( |
| 1095 | + arg_types=arg_types, ret_type=ret_type |
| 1096 | + ) |
| 1097 | + |
| 1098 | + ctx.api.fail( |
| 1099 | + f'Argument 1 to "fields" has incompatible type "{format_type_bare(proper_type)}"; expected an attrs class', |
| 1100 | + ctx.context, |
| 1101 | + ) |
| 1102 | + return ctx.default_signature |
0 commit comments