@@ -108,6 +108,7 @@ def infer_constraints_for_callable(
108108 callee : CallableType ,
109109 arg_types : Sequence [Type | None ],
110110 arg_kinds : list [ArgKind ],
111+ arg_names : Sequence [str | None ] | None ,
111112 formal_to_actual : list [list [int ]],
112113 context : ArgumentInferContext ,
113114) -> list [Constraint ]:
@@ -118,6 +119,20 @@ def infer_constraints_for_callable(
118119 constraints : list [Constraint ] = []
119120 mapper = ArgTypeExpander (context )
120121
122+ param_spec = callee .param_spec ()
123+ param_spec_arg_types = []
124+ param_spec_arg_names = []
125+ param_spec_arg_kinds = []
126+
127+ incomplete_star_mapping = False
128+ for i , actuals in enumerate (formal_to_actual ):
129+ for actual in actuals :
130+ if actual is None and callee .arg_kinds [i ] in (ARG_STAR , ARG_STAR2 ):
131+ # We can't use arguments to infer ParamSpec constraint, if only some
132+ # are present in the current inference pass.
133+ incomplete_star_mapping = True
134+ break
135+
121136 for i , actuals in enumerate (formal_to_actual ):
122137 if isinstance (callee .arg_types [i ], UnpackType ):
123138 unpack_type = callee .arg_types [i ]
@@ -194,11 +209,47 @@ def infer_constraints_for_callable(
194209 actual_type = mapper .expand_actual_type (
195210 actual_arg_type , arg_kinds [actual ], callee .arg_names [i ], callee .arg_kinds [i ]
196211 )
197- # TODO: if callee has ParamSpec, we need to collect all actuals that map to star
198- # args and create single constraint between P and resulting Parameters instead.
199- c = infer_constraints (callee .arg_types [i ], actual_type , SUPERTYPE_OF )
200- constraints .extend (c )
201-
212+ if (
213+ param_spec
214+ and callee .arg_kinds [i ] in (ARG_STAR , ARG_STAR2 )
215+ and not incomplete_star_mapping
216+ ):
217+ # If actual arguments are mapped to ParamSpec type, we can't infer individual
218+ # constraints, instead store them and infer single constraint at the end.
219+ # It is impossible to map actual kind to formal kind, so use some heuristic.
220+ # This inference is used as a fallback, so relying on heuristic should be OK.
221+ param_spec_arg_types .append (
222+ mapper .expand_actual_type (
223+ actual_arg_type , arg_kinds [actual ], None , arg_kinds [actual ]
224+ )
225+ )
226+ actual_kind = arg_kinds [actual ]
227+ param_spec_arg_kinds .append (
228+ ARG_POS if actual_kind not in (ARG_STAR , ARG_STAR2 ) else actual_kind
229+ )
230+ param_spec_arg_names .append (arg_names [actual ] if arg_names else None )
231+ else :
232+ c = infer_constraints (callee .arg_types [i ], actual_type , SUPERTYPE_OF )
233+ constraints .extend (c )
234+ if (
235+ param_spec
236+ and not any (c .type_var == param_spec .id for c in constraints )
237+ and not incomplete_star_mapping
238+ ):
239+ # Use ParamSpec constraint from arguments only if there are no other constraints,
240+ # since as explained above it is quite ad-hoc.
241+ constraints .append (
242+ Constraint (
243+ param_spec ,
244+ SUPERTYPE_OF ,
245+ Parameters (
246+ arg_types = param_spec_arg_types ,
247+ arg_kinds = param_spec_arg_kinds ,
248+ arg_names = param_spec_arg_names ,
249+ imprecise_arg_kinds = True ,
250+ ),
251+ )
252+ )
202253 return constraints
203254
204255
@@ -949,6 +1000,14 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
9491000 res : list [Constraint ] = []
9501001 cactual = self .actual .with_unpacked_kwargs ()
9511002 param_spec = template .param_spec ()
1003+
1004+ template_ret_type , cactual_ret_type = template .ret_type , cactual .ret_type
1005+ if template .type_guard is not None :
1006+ template_ret_type = template .type_guard
1007+ if cactual .type_guard is not None :
1008+ cactual_ret_type = cactual .type_guard
1009+ res .extend (infer_constraints (template_ret_type , cactual_ret_type , self .direction ))
1010+
9521011 if param_spec is None :
9531012 # TODO: Erase template variables if it is generic?
9541013 if (
@@ -1008,51 +1067,50 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
10081067 )
10091068 extra_tvars = True
10101069
1070+ # Compare prefixes as well
1071+ cactual_prefix = cactual .copy_modified (
1072+ arg_types = cactual .arg_types [:prefix_len ],
1073+ arg_kinds = cactual .arg_kinds [:prefix_len ],
1074+ arg_names = cactual .arg_names [:prefix_len ],
1075+ )
1076+ res .extend (
1077+ infer_callable_arguments_constraints (prefix , cactual_prefix , self .direction )
1078+ )
1079+
1080+ param_spec_target : Type | None = None
1081+ skip_imprecise = (
1082+ any (c .type_var == param_spec .id for c in res ) and cactual .imprecise_arg_kinds
1083+ )
10111084 if not cactual_ps :
10121085 max_prefix_len = len ([k for k in cactual .arg_kinds if k in (ARG_POS , ARG_OPT )])
10131086 prefix_len = min (prefix_len , max_prefix_len )
1014- res .append (
1015- Constraint (
1016- param_spec ,
1017- neg_op (self .direction ),
1018- Parameters (
1019- arg_types = cactual .arg_types [prefix_len :],
1020- arg_kinds = cactual .arg_kinds [prefix_len :],
1021- arg_names = cactual .arg_names [prefix_len :],
1022- variables = cactual .variables
1023- if not type_state .infer_polymorphic
1024- else [],
1025- ),
1087+ # This logic matches top-level callable constraint exception, if we managed
1088+ # to get other constraints for ParamSpec, don't infer one with imprecise kinds
1089+ if not skip_imprecise :
1090+ param_spec_target = Parameters (
1091+ arg_types = cactual .arg_types [prefix_len :],
1092+ arg_kinds = cactual .arg_kinds [prefix_len :],
1093+ arg_names = cactual .arg_names [prefix_len :],
1094+ variables = cactual .variables
1095+ if not type_state .infer_polymorphic
1096+ else [],
1097+ imprecise_arg_kinds = cactual .imprecise_arg_kinds ,
10261098 )
1027- )
10281099 else :
1029- if len (param_spec .prefix .arg_types ) <= len (cactual_ps .prefix .arg_types ):
1030- cactual_ps = cactual_ps .copy_modified (
1100+ if (
1101+ len (param_spec .prefix .arg_types ) <= len (cactual_ps .prefix .arg_types )
1102+ and not skip_imprecise
1103+ ):
1104+ param_spec_target = cactual_ps .copy_modified (
10311105 prefix = Parameters (
10321106 arg_types = cactual_ps .prefix .arg_types [prefix_len :],
10331107 arg_kinds = cactual_ps .prefix .arg_kinds [prefix_len :],
10341108 arg_names = cactual_ps .prefix .arg_names [prefix_len :],
1109+ imprecise_arg_kinds = cactual_ps .prefix .imprecise_arg_kinds ,
10351110 )
10361111 )
1037- res .append (Constraint (param_spec , neg_op (self .direction ), cactual_ps ))
1038-
1039- # Compare prefixes as well
1040- cactual_prefix = cactual .copy_modified (
1041- arg_types = cactual .arg_types [:prefix_len ],
1042- arg_kinds = cactual .arg_kinds [:prefix_len ],
1043- arg_names = cactual .arg_names [:prefix_len ],
1044- )
1045- res .extend (
1046- infer_callable_arguments_constraints (prefix , cactual_prefix , self .direction )
1047- )
1048-
1049- template_ret_type , cactual_ret_type = template .ret_type , cactual .ret_type
1050- if template .type_guard is not None :
1051- template_ret_type = template .type_guard
1052- if cactual .type_guard is not None :
1053- cactual_ret_type = cactual .type_guard
1054-
1055- res .extend (infer_constraints (template_ret_type , cactual_ret_type , self .direction ))
1112+ if param_spec_target is not None :
1113+ res .append (Constraint (param_spec , neg_op (self .direction ), param_spec_target ))
10561114 if extra_tvars :
10571115 for c in res :
10581116 c .extra_tvars += cactual .variables
0 commit comments