@@ -5,7 +5,7 @@ use ide_db::{
55use itertools:: Itertools ;
66use syntax:: {
77 ast:: { self , Expr } ,
8- match_ast, AstNode , TextRange , TextSize ,
8+ match_ast, AstNode , NodeOrToken , SyntaxKind , TextRange , TextSize ,
99} ;
1010
1111use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -38,14 +38,15 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
3838 } ;
3939
4040 let type_ref = & ret_type. ty ( ) ?;
41- let ty = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) ;
41+ let Some ( hir :: Adt :: Enum ( ret_enum ) ) = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) else { return None ; } ;
4242 let result_enum =
4343 FamousDefs ( & ctx. sema , ctx. sema . scope ( type_ref. syntax ( ) ) ?. krate ( ) ) . core_result_Result ( ) ?;
44-
45- if !matches ! ( ty, Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == result_enum) {
44+ if ret_enum != result_enum {
4645 return None ;
4746 }
4847
48+ let Some ( ok_type) = unwrap_result_type ( type_ref) else { return None ; } ;
49+
4950 acc. add (
5051 AssistId ( "unwrap_result_return_type" , AssistKind :: RefactorRewrite ) ,
5152 "Unwrap Result return type" ,
@@ -64,26 +65,22 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
6465 } ) ;
6566 for_each_tail_expr ( & body, tail_cb) ;
6667
67- let mut is_unit_type = false ;
68- if let Some ( ( _, inner_type) ) = type_ref. to_string ( ) . split_once ( '<' ) {
69- let inner_type = match inner_type. split_once ( ',' ) {
70- Some ( ( success_inner_type, _) ) => success_inner_type,
71- None => inner_type,
72- } ;
73- let new_ret_type = inner_type. strip_suffix ( '>' ) . unwrap_or ( inner_type) ;
74- if new_ret_type == "()" {
75- is_unit_type = true ;
76- let text_range = TextRange :: new (
77- ret_type. syntax ( ) . text_range ( ) . start ( ) ,
78- ret_type. syntax ( ) . text_range ( ) . end ( ) + TextSize :: from ( 1u32 ) ,
79- ) ;
80- builder. delete ( text_range)
81- } else {
82- builder. replace (
83- type_ref. syntax ( ) . text_range ( ) ,
84- inner_type. strip_suffix ( '>' ) . unwrap_or ( inner_type) ,
85- )
68+ let is_unit_type = is_unit_type ( & ok_type) ;
69+ if is_unit_type {
70+ let mut text_range = ret_type. syntax ( ) . text_range ( ) ;
71+
72+ if let Some ( NodeOrToken :: Token ( token) ) = ret_type. syntax ( ) . next_sibling_or_token ( ) {
73+ if token. kind ( ) == SyntaxKind :: WHITESPACE {
74+ text_range = TextRange :: new (
75+ text_range. start ( ) ,
76+ text_range. end ( ) + TextSize :: from ( 1u32 ) ,
77+ ) ;
78+ }
8679 }
80+
81+ builder. delete ( text_range) ;
82+ } else {
83+ builder. replace ( type_ref. syntax ( ) . text_range ( ) , ok_type. syntax ( ) . text ( ) ) ;
8784 }
8885
8986 for ret_expr_arg in exprs_to_unwrap {
@@ -134,6 +131,22 @@ fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
134131 }
135132}
136133
134+ // Tries to extract `T` from `Result<T, E>`.
135+ fn unwrap_result_type ( ty : & ast:: Type ) -> Option < ast:: Type > {
136+ let ast:: Type :: PathType ( path_ty) = ty else { return None ; } ;
137+ let Some ( path) = path_ty. path ( ) else { return None ; } ;
138+ let Some ( segment) = path. first_segment ( ) else { return None ; } ;
139+ let Some ( generic_arg_list) = segment. generic_arg_list ( ) else { return None ; } ;
140+ let generic_args: Vec < _ > = generic_arg_list. generic_args ( ) . collect ( ) ;
141+ let Some ( ast:: GenericArg :: TypeArg ( ok_type) ) = generic_args. first ( ) else { return None ; } ;
142+ ok_type. ty ( )
143+ }
144+
145+ fn is_unit_type ( ty : & ast:: Type ) -> bool {
146+ let ast:: Type :: TupleType ( tuple) = ty else { return false } ;
147+ tuple. fields ( ) . next ( ) . is_none ( )
148+ }
149+
137150#[ cfg( test) ]
138151mod tests {
139152 use crate :: tests:: { check_assist, check_assist_not_applicable} ;
@@ -173,6 +186,21 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
173186 r#"
174187fn foo() {
175188}
189+ "# ,
190+ ) ;
191+
192+ // Unformatted return type
193+ check_assist (
194+ unwrap_result_return_type,
195+ r#"
196+ //- minicore: result
197+ fn foo() -> Result<(), Box<dyn Error$0>>{
198+ Ok(())
199+ }
200+ "# ,
201+ r#"
202+ fn foo() {
203+ }
176204"# ,
177205 ) ;
178206 }
@@ -1014,6 +1042,54 @@ fn foo(the_field: u32) -> u32 {
10141042 }
10151043 the_field
10161044}
1045+ "# ,
1046+ ) ;
1047+ }
1048+
1049+ #[ test]
1050+ fn unwrap_result_return_type_nested_type ( ) {
1051+ check_assist (
1052+ unwrap_result_return_type,
1053+ r#"
1054+ //- minicore: result, option
1055+ fn foo() -> Result<Option<i32$0>, ()> {
1056+ Ok(Some(42))
1057+ }
1058+ "# ,
1059+ r#"
1060+ fn foo() -> Option<i32> {
1061+ Some(42)
1062+ }
1063+ "# ,
1064+ ) ;
1065+
1066+ check_assist (
1067+ unwrap_result_return_type,
1068+ r#"
1069+ //- minicore: result, option
1070+ fn foo() -> Result<Option<Result<i32$0, ()>>, ()> {
1071+ Ok(None)
1072+ }
1073+ "# ,
1074+ r#"
1075+ fn foo() -> Option<Result<i32, ()>> {
1076+ None
1077+ }
1078+ "# ,
1079+ ) ;
1080+
1081+ check_assist (
1082+ unwrap_result_return_type,
1083+ r#"
1084+ //- minicore: result, option, iterators
1085+ fn foo() -> Result<impl Iterator<Item = i32>$0, ()> {
1086+ Ok(Some(42).into_iter())
1087+ }
1088+ "# ,
1089+ r#"
1090+ fn foo() -> impl Iterator<Item = i32> {
1091+ Some(42).into_iter()
1092+ }
10171093"# ,
10181094 ) ;
10191095 }
0 commit comments