@@ -5,7 +5,7 @@ use ide_db::{
55use itertools:: Itertools ;
66use syntax:: {
77 ast:: { self , Expr } ,
8- match_ast, AstNode , TextRange , TextSize ,
8+ match_ast, AstNode , NodeOrToken , SyntaxKind , TextRange ,
99} ;
1010
1111use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -38,14 +38,15 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
3838 } ;
3939
4040 let type_ref = & ret_type. ty ( ) ?;
41- let ty = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) ;
41+ let Some ( hir :: Adt :: Enum ( ret_enum ) ) = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) else { return None ; } ;
4242 let result_enum =
4343 FamousDefs ( & ctx. sema , ctx. sema . scope ( type_ref. syntax ( ) ) ?. krate ( ) ) . core_result_Result ( ) ?;
44-
45- if !matches ! ( ty, Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == result_enum) {
44+ if ret_enum != result_enum {
4645 return None ;
4746 }
4847
48+ let Some ( ok_type) = unwrap_result_type ( type_ref) else { return None ; } ;
49+
4950 acc. add (
5051 AssistId ( "unwrap_result_return_type" , AssistKind :: RefactorRewrite ) ,
5152 "Unwrap Result return type" ,
@@ -64,26 +65,19 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
6465 } ) ;
6566 for_each_tail_expr ( & body, tail_cb) ;
6667
67- let mut is_unit_type = false ;
68- if let Some ( ( _, inner_type) ) = type_ref. to_string ( ) . split_once ( '<' ) {
69- let inner_type = match inner_type. split_once ( ',' ) {
70- Some ( ( success_inner_type, _) ) => success_inner_type,
71- None => inner_type,
72- } ;
73- let new_ret_type = inner_type. strip_suffix ( '>' ) . unwrap_or ( inner_type) ;
74- if new_ret_type == "()" {
75- is_unit_type = true ;
76- let text_range = TextRange :: new (
77- ret_type. syntax ( ) . text_range ( ) . start ( ) ,
78- ret_type. syntax ( ) . text_range ( ) . end ( ) + TextSize :: from ( 1u32 ) ,
79- ) ;
80- builder. delete ( text_range)
81- } else {
82- builder. replace (
83- type_ref. syntax ( ) . text_range ( ) ,
84- inner_type. strip_suffix ( '>' ) . unwrap_or ( inner_type) ,
85- )
68+ let is_unit_type = is_unit_type ( & ok_type) ;
69+ if is_unit_type {
70+ let mut text_range = ret_type. syntax ( ) . text_range ( ) ;
71+
72+ if let Some ( NodeOrToken :: Token ( token) ) = ret_type. syntax ( ) . next_sibling_or_token ( ) {
73+ if token. kind ( ) == SyntaxKind :: WHITESPACE {
74+ text_range = TextRange :: new ( text_range. start ( ) , token. text_range ( ) . end ( ) ) ;
75+ }
8676 }
77+
78+ builder. delete ( text_range) ;
79+ } else {
80+ builder. replace ( type_ref. syntax ( ) . text_range ( ) , ok_type. syntax ( ) . text ( ) ) ;
8781 }
8882
8983 for ret_expr_arg in exprs_to_unwrap {
@@ -134,6 +128,22 @@ fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
134128 }
135129}
136130
131+ // Tries to extract `T` from `Result<T, E>`.
132+ fn unwrap_result_type ( ty : & ast:: Type ) -> Option < ast:: Type > {
133+ let ast:: Type :: PathType ( path_ty) = ty else { return None ; } ;
134+ let path = path_ty. path ( ) ?;
135+ let segment = path. first_segment ( ) ?;
136+ let generic_arg_list = segment. generic_arg_list ( ) ?;
137+ let generic_args: Vec < _ > = generic_arg_list. generic_args ( ) . collect ( ) ;
138+ let ast:: GenericArg :: TypeArg ( ok_type) = generic_args. first ( ) ? else { return None ; } ;
139+ ok_type. ty ( )
140+ }
141+
142+ fn is_unit_type ( ty : & ast:: Type ) -> bool {
143+ let ast:: Type :: TupleType ( tuple) = ty else { return false } ;
144+ tuple. fields ( ) . next ( ) . is_none ( )
145+ }
146+
137147#[ cfg( test) ]
138148mod tests {
139149 use crate :: tests:: { check_assist, check_assist_not_applicable} ;
@@ -173,6 +183,21 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
173183 r#"
174184fn foo() {
175185}
186+ "# ,
187+ ) ;
188+
189+ // Unformatted return type
190+ check_assist (
191+ unwrap_result_return_type,
192+ r#"
193+ //- minicore: result
194+ fn foo() -> Result<(), Box<dyn Error$0>>{
195+ Ok(())
196+ }
197+ "# ,
198+ r#"
199+ fn foo() {
200+ }
176201"# ,
177202 ) ;
178203 }
@@ -1014,6 +1039,54 @@ fn foo(the_field: u32) -> u32 {
10141039 }
10151040 the_field
10161041}
1042+ "# ,
1043+ ) ;
1044+ }
1045+
1046+ #[ test]
1047+ fn unwrap_result_return_type_nested_type ( ) {
1048+ check_assist (
1049+ unwrap_result_return_type,
1050+ r#"
1051+ //- minicore: result, option
1052+ fn foo() -> Result<Option<i32$0>, ()> {
1053+ Ok(Some(42))
1054+ }
1055+ "# ,
1056+ r#"
1057+ fn foo() -> Option<i32> {
1058+ Some(42)
1059+ }
1060+ "# ,
1061+ ) ;
1062+
1063+ check_assist (
1064+ unwrap_result_return_type,
1065+ r#"
1066+ //- minicore: result, option
1067+ fn foo() -> Result<Option<Result<i32$0, ()>>, ()> {
1068+ Ok(None)
1069+ }
1070+ "# ,
1071+ r#"
1072+ fn foo() -> Option<Result<i32, ()>> {
1073+ None
1074+ }
1075+ "# ,
1076+ ) ;
1077+
1078+ check_assist (
1079+ unwrap_result_return_type,
1080+ r#"
1081+ //- minicore: result, option, iterators
1082+ fn foo() -> Result<impl Iterator<Item = i32>$0, ()> {
1083+ Ok(Some(42).into_iter())
1084+ }
1085+ "# ,
1086+ r#"
1087+ fn foo() -> impl Iterator<Item = i32> {
1088+ Some(42).into_iter()
1089+ }
10171090"# ,
10181091 ) ;
10191092 }
0 commit comments