|
1 | 1 | use hir::db::AstDatabase; |
2 | | -use ide_db::{assists::Assist, source_change::SourceChange}; |
| 2 | +use ide_db::{assists::Assist, helpers::for_each_tail_expr, source_change::SourceChange}; |
3 | 3 | use syntax::AstNode; |
4 | 4 | use text_edit::TextEdit; |
5 | 5 |
|
@@ -33,10 +33,15 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingOkOrSomeInTailExpr) -> Op |
33 | 33 | let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?; |
34 | 34 | let tail_expr = d.expr.value.to_node(&root); |
35 | 35 | let tail_expr_range = tail_expr.syntax().text_range(); |
36 | | - let replacement = format!("{}({})", d.required, tail_expr.syntax()); |
37 | | - let edit = TextEdit::replace(tail_expr_range, replacement); |
| 36 | + let mut builder = TextEdit::builder(); |
| 37 | + for_each_tail_expr(&tail_expr, &mut |expr| { |
| 38 | + if ctx.sema.type_of_expr(expr).as_ref() != Some(&d.expected) { |
| 39 | + builder.insert(expr.syntax().text_range().start(), format!("{}(", d.required)); |
| 40 | + builder.insert(expr.syntax().text_range().end(), ")".to_string()); |
| 41 | + } |
| 42 | + }); |
38 | 43 | let source_change = |
39 | | - SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), edit); |
| 44 | + SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), builder.finish()); |
40 | 45 | let name = if d.required == "Ok" { "Wrap with Ok" } else { "Wrap with Some" }; |
41 | 46 | Some(vec![fix("wrap_tail_expr", name, source_change, tail_expr_range)]) |
42 | 47 | } |
@@ -68,6 +73,35 @@ fn div(x: i32, y: i32) -> Option<i32> { |
68 | 73 | ); |
69 | 74 | } |
70 | 75 |
|
| 76 | + #[test] |
| 77 | + fn test_wrap_return_type_option_tails() { |
| 78 | + check_fix( |
| 79 | + r#" |
| 80 | +//- minicore: option, result |
| 81 | +fn div(x: i32, y: i32) -> Option<i32> { |
| 82 | + if y == 0 { |
| 83 | + 0 |
| 84 | + } else if true { |
| 85 | + 100 |
| 86 | + } else { |
| 87 | + None |
| 88 | + }$0 |
| 89 | +} |
| 90 | +"#, |
| 91 | + r#" |
| 92 | +fn div(x: i32, y: i32) -> Option<i32> { |
| 93 | + if y == 0 { |
| 94 | + Some(0) |
| 95 | + } else if true { |
| 96 | + Some(100) |
| 97 | + } else { |
| 98 | + None |
| 99 | + } |
| 100 | +} |
| 101 | +"#, |
| 102 | + ); |
| 103 | + } |
| 104 | + |
71 | 105 | #[test] |
72 | 106 | fn test_wrap_return_type() { |
73 | 107 | check_fix( |
|
0 commit comments