@@ -8,10 +8,13 @@ use std::{
88
99use chalk_ir:: { cast:: Cast , fold:: Shift , Mutability , TyVariableKind } ;
1010use hir_def:: {
11- expr:: { Array , BinaryOp , Expr , ExprId , Literal , MatchGuard , Statement , UnaryOp } ,
11+ expr:: {
12+ ArithOp , Array , BinaryOp , CmpOp , Expr , ExprId , Literal , MatchGuard , Ordering , Statement ,
13+ UnaryOp ,
14+ } ,
1215 path:: { GenericArg , GenericArgs } ,
1316 resolver:: resolver_for_expr,
14- AssocContainerId , FieldId , Lookup ,
17+ AssocContainerId , FieldId , FunctionId , Lookup ,
1518} ;
1619use hir_expand:: name:: { name, Name } ;
1720use stdx:: always;
@@ -23,7 +26,7 @@ use crate::{
2326 infer:: coerce:: CoerceMany ,
2427 lower:: lower_to_chalk_mutability,
2528 mapping:: from_chalk,
26- method_resolution, op ,
29+ method_resolution,
2730 primitive:: { self , UintTy } ,
2831 static_lifetime, to_chalk_trait_id,
2932 traits:: FnTrait ,
@@ -669,34 +672,21 @@ impl<'a> InferenceContext<'a> {
669672 }
670673 }
671674 Expr :: BinaryOp { lhs, rhs, op } => match op {
672- Some ( op) => {
673- let lhs_expectation = match op {
674- BinaryOp :: LogicOp ( ..) => {
675- Expectation :: has_type ( TyKind :: Scalar ( Scalar :: Bool ) . intern ( & Interner ) )
676- }
677- _ => Expectation :: none ( ) ,
678- } ;
679- let lhs_ty = self . infer_expr ( * lhs, & lhs_expectation) ;
680- let lhs_ty = self . resolve_ty_shallow ( & lhs_ty) ;
681- let rhs_expectation = op:: binary_op_rhs_expectation ( * op, lhs_ty. clone ( ) ) ;
682- let rhs_ty =
683- self . infer_expr_coerce ( * rhs, & Expectation :: has_type ( rhs_expectation) ) ;
684- let rhs_ty = self . resolve_ty_shallow ( & rhs_ty) ;
685-
686- let ret = op:: binary_op_return_ty ( * op, lhs_ty. clone ( ) , rhs_ty. clone ( ) ) ;
687-
688- if ret. is_unknown ( ) {
689- cov_mark:: hit!( infer_expr_inner_binary_operator_overload) ;
690-
691- self . resolve_associated_type_with_params (
692- lhs_ty,
693- self . resolve_binary_op_output ( op) ,
694- & [ rhs_ty] ,
695- )
696- } else {
697- ret
698- }
675+ Some ( BinaryOp :: Assignment { op : None } ) => {
676+ let lhs_ty = self . infer_expr ( * lhs, & Expectation :: none ( ) ) ;
677+ self . infer_expr_coerce ( * rhs, & Expectation :: has_type ( lhs_ty) ) ;
678+ self . result . standard_types . unit . clone ( )
679+ }
680+ Some ( BinaryOp :: LogicOp ( _) ) => {
681+ let bool_ty = self . result . standard_types . bool_ . clone ( ) ;
682+ self . infer_expr_coerce ( * lhs, & Expectation :: HasType ( bool_ty. clone ( ) ) ) ;
683+ let lhs_diverges = self . diverges ;
684+ self . infer_expr_coerce ( * rhs, & Expectation :: HasType ( bool_ty. clone ( ) ) ) ;
685+ // Depending on the LHS' value, the RHS can never execute.
686+ self . diverges = lhs_diverges;
687+ bool_ty
699688 }
689+ Some ( op) => self . infer_overloadable_binop ( * lhs, * op, * rhs, tgt_expr) ,
700690 _ => self . err_ty ( ) ,
701691 } ,
702692 Expr :: Range { lhs, rhs, range_type } => {
@@ -862,6 +852,62 @@ impl<'a> InferenceContext<'a> {
862852 ty
863853 }
864854
855+ fn infer_overloadable_binop (
856+ & mut self ,
857+ lhs : ExprId ,
858+ op : BinaryOp ,
859+ rhs : ExprId ,
860+ tgt_expr : ExprId ,
861+ ) -> Ty {
862+ let lhs_expectation = Expectation :: none ( ) ;
863+ let lhs_ty = self . infer_expr ( lhs, & lhs_expectation) ;
864+ let rhs_ty = self . table . new_type_var ( ) ;
865+
866+ let func = self . resolve_binop_method ( op) ;
867+ let func = match func {
868+ Some ( func) => func,
869+ None => {
870+ let rhs_ty = self . builtin_binary_op_rhs_expectation ( op, lhs_ty. clone ( ) ) ;
871+ let rhs_ty = self . infer_expr_coerce ( rhs, & Expectation :: from_option ( rhs_ty) ) ;
872+ return self
873+ . builtin_binary_op_return_ty ( op, lhs_ty, rhs_ty)
874+ . unwrap_or_else ( || self . err_ty ( ) ) ;
875+ }
876+ } ;
877+
878+ let subst = TyBuilder :: subst_for_def ( self . db , func)
879+ . push ( lhs_ty. clone ( ) )
880+ . push ( rhs_ty. clone ( ) )
881+ . build ( ) ;
882+ self . write_method_resolution ( tgt_expr, func, subst. clone ( ) ) ;
883+
884+ let method_ty = self . db . value_ty ( func. into ( ) ) . substitute ( & Interner , & subst) ;
885+ self . register_obligations_for_call ( & method_ty) ;
886+
887+ self . infer_expr_coerce ( rhs, & Expectation :: has_type ( rhs_ty. clone ( ) ) ) ;
888+
889+ let ret_ty = match method_ty. callable_sig ( self . db ) {
890+ Some ( sig) => sig. ret ( ) . clone ( ) ,
891+ None => self . err_ty ( ) ,
892+ } ;
893+
894+ let ret_ty = self . normalize_associated_types_in ( ret_ty) ;
895+
896+ // FIXME: record autoref adjustments
897+
898+ // use knowledge of built-in binary ops, which can sometimes help inference
899+ if let Some ( builtin_rhs) = self . builtin_binary_op_rhs_expectation ( op, lhs_ty. clone ( ) ) {
900+ self . unify ( & builtin_rhs, & rhs_ty) ;
901+ }
902+ if let Some ( builtin_ret) =
903+ self . builtin_binary_op_return_ty ( op, lhs_ty. clone ( ) , rhs_ty. clone ( ) )
904+ {
905+ self . unify ( & builtin_ret, & ret_ty) ;
906+ }
907+
908+ ret_ty
909+ }
910+
865911 fn infer_block (
866912 & mut self ,
867913 expr : ExprId ,
@@ -1136,4 +1182,141 @@ impl<'a> InferenceContext<'a> {
11361182 }
11371183 }
11381184 }
1185+
1186+ fn builtin_binary_op_return_ty ( & mut self , op : BinaryOp , lhs_ty : Ty , rhs_ty : Ty ) -> Option < Ty > {
1187+ let lhs_ty = self . resolve_ty_shallow ( & lhs_ty) ;
1188+ let rhs_ty = self . resolve_ty_shallow ( & rhs_ty) ;
1189+ match op {
1190+ BinaryOp :: LogicOp ( _) | BinaryOp :: CmpOp ( _) => {
1191+ Some ( TyKind :: Scalar ( Scalar :: Bool ) . intern ( & Interner ) )
1192+ }
1193+ BinaryOp :: Assignment { .. } => Some ( TyBuilder :: unit ( ) ) ,
1194+ BinaryOp :: ArithOp ( ArithOp :: Shl | ArithOp :: Shr ) => {
1195+ // all integer combinations are valid here
1196+ if matches ! (
1197+ lhs_ty. kind( & Interner ) ,
1198+ TyKind :: Scalar ( Scalar :: Int ( _) | Scalar :: Uint ( _) )
1199+ | TyKind :: InferenceVar ( _, TyVariableKind :: Integer )
1200+ ) && matches ! (
1201+ rhs_ty. kind( & Interner ) ,
1202+ TyKind :: Scalar ( Scalar :: Int ( _) | Scalar :: Uint ( _) )
1203+ | TyKind :: InferenceVar ( _, TyVariableKind :: Integer )
1204+ ) {
1205+ Some ( lhs_ty)
1206+ } else {
1207+ None
1208+ }
1209+ }
1210+ BinaryOp :: ArithOp ( _) => match ( lhs_ty. kind ( & Interner ) , rhs_ty. kind ( & Interner ) ) {
1211+ // (int, int) | (uint, uint) | (float, float)
1212+ ( TyKind :: Scalar ( Scalar :: Int ( _) ) , TyKind :: Scalar ( Scalar :: Int ( _) ) )
1213+ | ( TyKind :: Scalar ( Scalar :: Uint ( _) ) , TyKind :: Scalar ( Scalar :: Uint ( _) ) )
1214+ | ( TyKind :: Scalar ( Scalar :: Float ( _) ) , TyKind :: Scalar ( Scalar :: Float ( _) ) ) => {
1215+ Some ( rhs_ty)
1216+ }
1217+ // ({int}, int) | ({int}, uint)
1218+ (
1219+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer ) ,
1220+ TyKind :: Scalar ( Scalar :: Int ( _) | Scalar :: Uint ( _) ) ,
1221+ ) => Some ( rhs_ty) ,
1222+ // (int, {int}) | (uint, {int})
1223+ (
1224+ TyKind :: Scalar ( Scalar :: Int ( _) | Scalar :: Uint ( _) ) ,
1225+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer ) ,
1226+ ) => Some ( lhs_ty) ,
1227+ // ({float} | float)
1228+ (
1229+ TyKind :: InferenceVar ( _, TyVariableKind :: Float ) ,
1230+ TyKind :: Scalar ( Scalar :: Float ( _) ) ,
1231+ ) => Some ( rhs_ty) ,
1232+ // (float, {float})
1233+ (
1234+ TyKind :: Scalar ( Scalar :: Float ( _) ) ,
1235+ TyKind :: InferenceVar ( _, TyVariableKind :: Float ) ,
1236+ ) => Some ( lhs_ty) ,
1237+ // ({int}, {int}) | ({float}, {float})
1238+ (
1239+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer ) ,
1240+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer ) ,
1241+ )
1242+ | (
1243+ TyKind :: InferenceVar ( _, TyVariableKind :: Float ) ,
1244+ TyKind :: InferenceVar ( _, TyVariableKind :: Float ) ,
1245+ ) => Some ( rhs_ty) ,
1246+ _ => None ,
1247+ } ,
1248+ }
1249+ }
1250+
1251+ fn builtin_binary_op_rhs_expectation ( & mut self , op : BinaryOp , lhs_ty : Ty ) -> Option < Ty > {
1252+ Some ( match op {
1253+ BinaryOp :: LogicOp ( ..) => TyKind :: Scalar ( Scalar :: Bool ) . intern ( & Interner ) ,
1254+ BinaryOp :: Assignment { op : None } => lhs_ty,
1255+ BinaryOp :: CmpOp ( CmpOp :: Eq { .. } ) => match self
1256+ . resolve_ty_shallow ( & lhs_ty)
1257+ . kind ( & Interner )
1258+ {
1259+ TyKind :: Scalar ( _) | TyKind :: Str => lhs_ty,
1260+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer | TyVariableKind :: Float ) => lhs_ty,
1261+ _ => return None ,
1262+ } ,
1263+ BinaryOp :: ArithOp ( ArithOp :: Shl | ArithOp :: Shr ) => return None ,
1264+ BinaryOp :: CmpOp ( CmpOp :: Ord { .. } )
1265+ | BinaryOp :: Assignment { op : Some ( _) }
1266+ | BinaryOp :: ArithOp ( _) => match self . resolve_ty_shallow ( & lhs_ty) . kind ( & Interner ) {
1267+ TyKind :: Scalar ( Scalar :: Int ( _) | Scalar :: Uint ( _) | Scalar :: Float ( _) ) => lhs_ty,
1268+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer | TyVariableKind :: Float ) => lhs_ty,
1269+ _ => return None ,
1270+ } ,
1271+ } )
1272+ }
1273+
1274+ fn resolve_binop_method ( & self , op : BinaryOp ) -> Option < FunctionId > {
1275+ let ( name, lang_item) = match op {
1276+ BinaryOp :: LogicOp ( _) => return None ,
1277+ BinaryOp :: ArithOp ( aop) => match aop {
1278+ ArithOp :: Add => ( name ! ( add) , "add" ) ,
1279+ ArithOp :: Mul => ( name ! ( mul) , "mul" ) ,
1280+ ArithOp :: Sub => ( name ! ( sub) , "sub" ) ,
1281+ ArithOp :: Div => ( name ! ( div) , "div" ) ,
1282+ ArithOp :: Rem => ( name ! ( rem) , "rem" ) ,
1283+ ArithOp :: Shl => ( name ! ( shl) , "shl" ) ,
1284+ ArithOp :: Shr => ( name ! ( shr) , "shr" ) ,
1285+ ArithOp :: BitXor => ( name ! ( bitxor) , "bitxor" ) ,
1286+ ArithOp :: BitOr => ( name ! ( bitor) , "bitor" ) ,
1287+ ArithOp :: BitAnd => ( name ! ( bitand) , "bitand" ) ,
1288+ } ,
1289+ BinaryOp :: Assignment { op : Some ( aop) } => match aop {
1290+ ArithOp :: Add => ( name ! ( add_assign) , "add_assign" ) ,
1291+ ArithOp :: Mul => ( name ! ( mul_assign) , "mul_assign" ) ,
1292+ ArithOp :: Sub => ( name ! ( sub_assign) , "sub_assign" ) ,
1293+ ArithOp :: Div => ( name ! ( div_assign) , "div_assign" ) ,
1294+ ArithOp :: Rem => ( name ! ( rem_assign) , "rem_assign" ) ,
1295+ ArithOp :: Shl => ( name ! ( shl_assign) , "shl_assign" ) ,
1296+ ArithOp :: Shr => ( name ! ( shr_assign) , "shr_assign" ) ,
1297+ ArithOp :: BitXor => ( name ! ( bitxor_assign) , "bitxor_assign" ) ,
1298+ ArithOp :: BitOr => ( name ! ( bitor_assign) , "bitor_assign" ) ,
1299+ ArithOp :: BitAnd => ( name ! ( bitand_assign) , "bitand_assign" ) ,
1300+ } ,
1301+ BinaryOp :: CmpOp ( cop) => match cop {
1302+ CmpOp :: Eq { negated : false } => ( name ! ( eq) , "eq" ) ,
1303+ CmpOp :: Eq { negated : true } => ( name ! ( ne) , "eq" ) ,
1304+ CmpOp :: Ord { ordering : Ordering :: Less , strict : false } => {
1305+ ( name ! ( le) , "partial_ord" )
1306+ }
1307+ CmpOp :: Ord { ordering : Ordering :: Less , strict : true } => ( name ! ( lt) , "partial_ord" ) ,
1308+ CmpOp :: Ord { ordering : Ordering :: Greater , strict : false } => {
1309+ ( name ! ( ge) , "partial_ord" )
1310+ }
1311+ CmpOp :: Ord { ordering : Ordering :: Greater , strict : true } => {
1312+ ( name ! ( gt) , "partial_ord" )
1313+ }
1314+ } ,
1315+ BinaryOp :: Assignment { op : None } => return None ,
1316+ } ;
1317+
1318+ let trait_ = self . resolve_lang_item ( lang_item) ?. as_trait ( ) ?;
1319+
1320+ self . db . trait_data ( trait_) . method_by_name ( & name)
1321+ }
11391322}
0 commit comments