11use std:: iter;
22
3+ use hir:: HasSource ;
34use ide_db:: {
45 famous_defs:: FamousDefs ,
56 syntax_helpers:: node_ext:: { for_each_tail_expr, walk_expr} ,
67} ;
8+ use itertools:: Itertools ;
79use syntax:: {
8- ast:: { self , make, Expr } ,
9- match_ast, ted, AstNode ,
10+ ast:: { self , make, Expr , HasGenericParams } ,
11+ match_ast, ted, AstNode , ToSmolStr ,
1012} ;
1113
1214use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -39,25 +41,22 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext<
3941 } ;
4042
4143 let type_ref = & ret_type. ty ( ) ?;
42- let ty = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) ;
43- let result_enum =
44+ let core_result =
4445 FamousDefs ( & ctx. sema , ctx. sema . scope ( type_ref. syntax ( ) ) ?. krate ( ) ) . core_result_Result ( ) ?;
4546
46- if matches ! ( ty, Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == result_enum) {
47+ let ty = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) ;
48+ if matches ! ( ty, Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == core_result) {
49+ // The return type is already wrapped in a Result
4750 cov_mark:: hit!( wrap_return_type_in_result_simple_return_type_already_result) ;
4851 return None ;
4952 }
5053
51- let new_result_ty =
52- make:: ext:: ty_result ( type_ref. clone ( ) , make:: ty_placeholder ( ) ) . clone_for_update ( ) ;
53- let generic_args = new_result_ty. syntax ( ) . descendants ( ) . find_map ( ast:: GenericArgList :: cast) ?;
54- let last_genarg = generic_args. generic_args ( ) . last ( ) ?;
55-
5654 acc. add (
5755 AssistId ( "wrap_return_type_in_result" , AssistKind :: RefactorRewrite ) ,
5856 "Wrap return type in Result" ,
5957 type_ref. syntax ( ) . text_range ( ) ,
6058 |edit| {
59+ let new_result_ty = result_type ( ctx, & core_result, type_ref) . clone_for_update ( ) ;
6160 let body = edit. make_mut ( ast:: Expr :: BlockExpr ( body) ) ;
6261
6362 let mut exprs_to_wrap = Vec :: new ( ) ;
@@ -81,16 +80,72 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext<
8180 }
8281
8382 let old_result_ty = edit. make_mut ( type_ref. clone ( ) ) ;
84-
8583 ted:: replace ( old_result_ty. syntax ( ) , new_result_ty. syntax ( ) ) ;
8684
87- if let Some ( cap) = ctx. config . snippet_cap {
88- edit. add_placeholder_snippet ( cap, last_genarg) ;
85+ // Add a placeholder snippet at the first generic argument that doesn't equal the return type.
86+ // This is normally the error type, but that may not be the case when we inserted a type alias.
87+ let args = new_result_ty. syntax ( ) . descendants ( ) . find_map ( ast:: GenericArgList :: cast) ;
88+ let error_type_arg = args. and_then ( |list| {
89+ list. generic_args ( ) . find ( |arg| match arg {
90+ ast:: GenericArg :: TypeArg ( _) => arg. syntax ( ) . text ( ) != type_ref. syntax ( ) . text ( ) ,
91+ ast:: GenericArg :: LifetimeArg ( _) => false ,
92+ _ => true ,
93+ } )
94+ } ) ;
95+ if let Some ( error_type_arg) = error_type_arg {
96+ if let Some ( cap) = ctx. config . snippet_cap {
97+ edit. add_placeholder_snippet ( cap, error_type_arg) ;
98+ }
8999 }
90100 } ,
91101 )
92102}
93103
104+ fn result_type (
105+ ctx : & AssistContext < ' _ > ,
106+ core_result : & hir:: Enum ,
107+ ret_type : & ast:: Type ,
108+ ) -> ast:: Type {
109+ // Try to find a Result<T, ...> type alias in the current scope (shadowing the default).
110+ let result_path = hir:: ModPath :: from_segments (
111+ hir:: PathKind :: Plain ,
112+ iter:: once ( hir:: Name :: new_symbol_root ( hir:: sym:: Result . clone ( ) ) ) ,
113+ ) ;
114+ let alias = ctx. sema . resolve_mod_path ( ret_type. syntax ( ) , & result_path) . and_then ( |def| {
115+ def. filter_map ( |def| match def. as_module_def ( ) ? {
116+ hir:: ModuleDef :: TypeAlias ( alias) => {
117+ let enum_ty = alias. ty ( ctx. db ( ) ) . as_adt ( ) ?. as_enum ( ) ?;
118+ ( & enum_ty == core_result) . then_some ( alias)
119+ }
120+ _ => None ,
121+ } )
122+ . find_map ( |alias| {
123+ let mut inserted_ret_type = false ;
124+ let generic_params = alias
125+ . source ( ctx. db ( ) ) ?
126+ . value
127+ . generic_param_list ( ) ?
128+ . generic_params ( )
129+ . map ( |param| match param {
130+ // Replace the very first type parameter with the functions return type.
131+ ast:: GenericParam :: TypeParam ( _) if !inserted_ret_type => {
132+ inserted_ret_type = true ;
133+ ret_type. to_smolstr ( )
134+ }
135+ ast:: GenericParam :: LifetimeParam ( _) => make:: lifetime ( "'_" ) . to_smolstr ( ) ,
136+ _ => make:: ty_placeholder ( ) . to_smolstr ( ) ,
137+ } )
138+ . join ( ", " ) ;
139+
140+ let name = alias. name ( ctx. db ( ) ) ;
141+ let name = name. as_str ( ) ;
142+ Some ( make:: ty ( & format ! ( "{name}<{generic_params}>" ) ) )
143+ } )
144+ } ) ;
145+ // If there is no applicable alias in scope use the default Result type.
146+ alias. unwrap_or_else ( || make:: ext:: ty_result ( ret_type. clone ( ) , make:: ty_placeholder ( ) ) )
147+ }
148+
94149fn tail_cb_impl ( acc : & mut Vec < ast:: Expr > , e : & ast:: Expr ) {
95150 match e {
96151 Expr :: BreakExpr ( break_expr) => {
@@ -998,4 +1053,216 @@ fn foo(the_field: u32) -> Result<u32, ${0:_}> {
9981053"# ,
9991054 ) ;
10001055 }
1056+
1057+ #[ test]
1058+ fn wrap_return_type_in_local_result_type ( ) {
1059+ check_assist (
1060+ wrap_return_type_in_result,
1061+ r#"
1062+ //- minicore: result
1063+ type Result<T> = core::result::Result<T, ()>;
1064+
1065+ fn foo() -> i3$02 {
1066+ return 42i32;
1067+ }
1068+ "# ,
1069+ r#"
1070+ type Result<T> = core::result::Result<T, ()>;
1071+
1072+ fn foo() -> Result<i32> {
1073+ return Ok(42i32);
1074+ }
1075+ "# ,
1076+ ) ;
1077+
1078+ check_assist (
1079+ wrap_return_type_in_result,
1080+ r#"
1081+ //- minicore: result
1082+ type Result2<T> = core::result::Result<T, ()>;
1083+
1084+ fn foo() -> i3$02 {
1085+ return 42i32;
1086+ }
1087+ "# ,
1088+ r#"
1089+ type Result2<T> = core::result::Result<T, ()>;
1090+
1091+ fn foo() -> Result<i32, ${0:_}> {
1092+ return Ok(42i32);
1093+ }
1094+ "# ,
1095+ ) ;
1096+ }
1097+
1098+ #[ test]
1099+ fn wrap_return_type_in_imported_local_result_type ( ) {
1100+ check_assist (
1101+ wrap_return_type_in_result,
1102+ r#"
1103+ //- minicore: result
1104+ mod some_module {
1105+ pub type Result<T> = core::result::Result<T, ()>;
1106+ }
1107+
1108+ use some_module::Result;
1109+
1110+ fn foo() -> i3$02 {
1111+ return 42i32;
1112+ }
1113+ "# ,
1114+ r#"
1115+ mod some_module {
1116+ pub type Result<T> = core::result::Result<T, ()>;
1117+ }
1118+
1119+ use some_module::Result;
1120+
1121+ fn foo() -> Result<i32> {
1122+ return Ok(42i32);
1123+ }
1124+ "# ,
1125+ ) ;
1126+
1127+ check_assist (
1128+ wrap_return_type_in_result,
1129+ r#"
1130+ //- minicore: result
1131+ mod some_module {
1132+ pub type Result<T> = core::result::Result<T, ()>;
1133+ }
1134+
1135+ use some_module::*;
1136+
1137+ fn foo() -> i3$02 {
1138+ return 42i32;
1139+ }
1140+ "# ,
1141+ r#"
1142+ mod some_module {
1143+ pub type Result<T> = core::result::Result<T, ()>;
1144+ }
1145+
1146+ use some_module::*;
1147+
1148+ fn foo() -> Result<i32> {
1149+ return Ok(42i32);
1150+ }
1151+ "# ,
1152+ ) ;
1153+ }
1154+
1155+ #[ test]
1156+ fn wrap_return_type_in_local_result_type_from_function_body ( ) {
1157+ check_assist (
1158+ wrap_return_type_in_result,
1159+ r#"
1160+ //- minicore: result
1161+ fn foo() -> i3$02 {
1162+ type Result<T> = core::result::Result<T, ()>;
1163+ 0
1164+ }
1165+ "# ,
1166+ r#"
1167+ fn foo() -> Result<i32, ${0:_}> {
1168+ type Result<T> = core::result::Result<T, ()>;
1169+ Ok(0)
1170+ }
1171+ "# ,
1172+ ) ;
1173+ }
1174+
1175+ #[ test]
1176+ fn wrap_return_type_in_local_result_type_already_using_alias ( ) {
1177+ check_assist_not_applicable (
1178+ wrap_return_type_in_result,
1179+ r#"
1180+ //- minicore: result
1181+ pub type Result<T> = core::result::Result<T, ()>;
1182+
1183+ fn foo() -> Result<i3$02> {
1184+ return Ok(42i32);
1185+ }
1186+ "# ,
1187+ ) ;
1188+ }
1189+
1190+ #[ test]
1191+ fn wrap_return_type_in_local_result_type_multiple_generics ( ) {
1192+ check_assist (
1193+ wrap_return_type_in_result,
1194+ r#"
1195+ //- minicore: result
1196+ type Result<T, E> = core::result::Result<T, E>;
1197+
1198+ fn foo() -> i3$02 {
1199+ 0
1200+ }
1201+ "# ,
1202+ r#"
1203+ type Result<T, E> = core::result::Result<T, E>;
1204+
1205+ fn foo() -> Result<i32, ${0:_}> {
1206+ Ok(0)
1207+ }
1208+ "# ,
1209+ ) ;
1210+
1211+ check_assist (
1212+ wrap_return_type_in_result,
1213+ r#"
1214+ //- minicore: result
1215+ type Result<T, E> = core::result::Result<Foo<T, E>, ()>;
1216+
1217+ fn foo() -> i3$02 {
1218+ 0
1219+ }
1220+ "# ,
1221+ r#"
1222+ type Result<T, E> = core::result::Result<Foo<T, E>, ()>;
1223+
1224+ fn foo() -> Result<i32, ${0:_}> {
1225+ Ok(0)
1226+ }
1227+ "# ,
1228+ ) ;
1229+
1230+ check_assist (
1231+ wrap_return_type_in_result,
1232+ r#"
1233+ //- minicore: result
1234+ type Result<'a, T, E> = core::result::Result<Foo<T, E>, &'a ()>;
1235+
1236+ fn foo() -> i3$02 {
1237+ 0
1238+ }
1239+ "# ,
1240+ r#"
1241+ type Result<'a, T, E> = core::result::Result<Foo<T, E>, &'a ()>;
1242+
1243+ fn foo() -> Result<'_, i32, ${0:_}> {
1244+ Ok(0)
1245+ }
1246+ "# ,
1247+ ) ;
1248+
1249+ check_assist (
1250+ wrap_return_type_in_result,
1251+ r#"
1252+ //- minicore: result
1253+ type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
1254+
1255+ fn foo() -> i3$02 {
1256+ 0
1257+ }
1258+ "# ,
1259+ r#"
1260+ type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
1261+
1262+ fn foo() -> Result<i32, ${0:_}> {
1263+ Ok(0)
1264+ }
1265+ "# ,
1266+ ) ;
1267+ }
10011268}
0 commit comments