@@ -15,7 +15,7 @@ use ide_db::{
1515} ;
1616use itertools:: { izip, Itertools } ;
1717use syntax:: {
18- ast:: { self , edit_in_place:: Indent , HasArgList , PathExpr } ,
18+ ast:: { self , edit :: IndentLevel , edit_in_place:: Indent , HasArgList , PathExpr } ,
1919 ted, AstNode , NodeOrToken , SyntaxKind ,
2020} ;
2121
@@ -306,7 +306,7 @@ fn inline(
306306 params : & [ ( ast:: Pat , Option < ast:: Type > , hir:: Param ) ] ,
307307 CallInfo { node, arguments, generic_arg_list } : & CallInfo ,
308308) -> ast:: Expr {
309- let body = if sema. hir_file_for ( fn_body. syntax ( ) ) . is_macro ( ) {
309+ let mut body = if sema. hir_file_for ( fn_body. syntax ( ) ) . is_macro ( ) {
310310 cov_mark:: hit!( inline_call_defined_in_macro) ;
311311 if let Some ( body) = ast:: BlockExpr :: cast ( insert_ws_into ( fn_body. syntax ( ) . clone ( ) ) ) {
312312 body
@@ -391,19 +391,19 @@ fn inline(
391391 }
392392 }
393393
394+ let mut let_stmts = Vec :: new ( ) ;
395+
394396 // Inline parameter expressions or generate `let` statements depending on whether inlining works or not.
395- for ( ( pat, param_ty, _) , usages, expr) in izip ! ( params, param_use_nodes, arguments) . rev ( ) {
397+ for ( ( pat, param_ty, _) , usages, expr) in izip ! ( params, param_use_nodes, arguments) {
396398 // izip confuses RA due to our lack of hygiene info currently losing us type info causing incorrect errors
397399 let usages: & [ ast:: PathExpr ] = & usages;
398400 let expr: & ast:: Expr = expr;
399401
400- let insert_let_stmt = || {
402+ let mut insert_let_stmt = || {
401403 let ty = sema. type_of_expr ( expr) . filter ( TypeInfo :: has_adjustment) . and ( param_ty. clone ( ) ) ;
402- if let Some ( stmt_list) = body. stmt_list ( ) {
403- stmt_list. push_front (
404- make:: let_stmt ( pat. clone ( ) , ty, Some ( expr. clone ( ) ) ) . clone_for_update ( ) . into ( ) ,
405- )
406- }
404+ let_stmts. push (
405+ make:: let_stmt ( pat. clone ( ) , ty, Some ( expr. clone ( ) ) ) . clone_for_update ( ) . into ( ) ,
406+ ) ;
407407 } ;
408408
409409 // check if there is a local var in the function that conflicts with parameter
@@ -457,14 +457,32 @@ fn inline(
457457 }
458458 }
459459
460+ let is_async_fn = function. is_async ( sema. db ) ;
461+ if is_async_fn {
462+ cov_mark:: hit!( inline_call_async_fn) ;
463+ body = make:: async_move_block_expr ( body. statements ( ) , body. tail_expr ( ) ) . clone_for_update ( ) ;
464+
465+ // Arguments should be evaluated outside the async block, and then moved into it.
466+ if !let_stmts. is_empty ( ) {
467+ cov_mark:: hit!( inline_call_async_fn_with_let_stmts) ;
468+ body. indent ( IndentLevel ( 1 ) ) ;
469+ body = make:: block_expr ( let_stmts, Some ( body. into ( ) ) ) . clone_for_update ( ) ;
470+ }
471+ } else if let Some ( stmt_list) = body. stmt_list ( ) {
472+ ted:: insert_all (
473+ ted:: Position :: after ( stmt_list. l_curly_token ( ) . unwrap ( ) ) ,
474+ let_stmts. into_iter ( ) . map ( |stmt| stmt. syntax ( ) . clone ( ) . into ( ) ) . collect ( ) ,
475+ ) ;
476+ }
477+
460478 let original_indentation = match node {
461479 ast:: CallableExpr :: Call ( it) => it. indent_level ( ) ,
462480 ast:: CallableExpr :: MethodCall ( it) => it. indent_level ( ) ,
463481 } ;
464482 body. reindent_to ( original_indentation) ;
465483
466484 match body. tail_expr ( ) {
467- Some ( expr) if body. statements ( ) . next ( ) . is_none ( ) => expr,
485+ Some ( expr) if !is_async_fn && body. statements ( ) . next ( ) . is_none ( ) => expr,
468486 _ => match node
469487 . syntax ( )
470488 . parent ( )
@@ -1350,6 +1368,109 @@ fn main() {
13501368 bar * b * a * 6
13511369 };
13521370}
1371+ "# ,
1372+ ) ;
1373+ }
1374+
1375+ #[ test]
1376+ fn async_fn_single_expression ( ) {
1377+ cov_mark:: check!( inline_call_async_fn) ;
1378+ check_assist (
1379+ inline_call,
1380+ r#"
1381+ async fn bar(x: u32) -> u32 { x + 1 }
1382+ async fn foo(arg: u32) -> u32 {
1383+ bar(arg).await * 2
1384+ }
1385+ fn spawn<T>(_: T) {}
1386+ fn main() {
1387+ spawn(foo$0(42));
1388+ }
1389+ "# ,
1390+ r#"
1391+ async fn bar(x: u32) -> u32 { x + 1 }
1392+ async fn foo(arg: u32) -> u32 {
1393+ bar(arg).await * 2
1394+ }
1395+ fn spawn<T>(_: T) {}
1396+ fn main() {
1397+ spawn(async move {
1398+ bar(42).await * 2
1399+ });
1400+ }
1401+ "# ,
1402+ ) ;
1403+ }
1404+
1405+ #[ test]
1406+ fn async_fn_multiple_statements ( ) {
1407+ cov_mark:: check!( inline_call_async_fn) ;
1408+ check_assist (
1409+ inline_call,
1410+ r#"
1411+ async fn bar(x: u32) -> u32 { x + 1 }
1412+ async fn foo(arg: u32) -> u32 {
1413+ bar(arg).await;
1414+ 42
1415+ }
1416+ fn spawn<T>(_: T) {}
1417+ fn main() {
1418+ spawn(foo$0(42));
1419+ }
1420+ "# ,
1421+ r#"
1422+ async fn bar(x: u32) -> u32 { x + 1 }
1423+ async fn foo(arg: u32) -> u32 {
1424+ bar(arg).await;
1425+ 42
1426+ }
1427+ fn spawn<T>(_: T) {}
1428+ fn main() {
1429+ spawn(async move {
1430+ bar(42).await;
1431+ 42
1432+ });
1433+ }
1434+ "# ,
1435+ ) ;
1436+ }
1437+
1438+ #[ test]
1439+ fn async_fn_with_let_statements ( ) {
1440+ cov_mark:: check!( inline_call_async_fn) ;
1441+ cov_mark:: check!( inline_call_async_fn_with_let_stmts) ;
1442+ check_assist (
1443+ inline_call,
1444+ r#"
1445+ async fn bar(x: u32) -> u32 { x + 1 }
1446+ async fn foo(x: u32, y: u32, z: &u32) -> u32 {
1447+ bar(x).await;
1448+ y + y + *z
1449+ }
1450+ fn spawn<T>(_: T) {}
1451+ fn main() {
1452+ let var = 42;
1453+ spawn(foo$0(var, var + 1, &var));
1454+ }
1455+ "# ,
1456+ r#"
1457+ async fn bar(x: u32) -> u32 { x + 1 }
1458+ async fn foo(x: u32, y: u32, z: &u32) -> u32 {
1459+ bar(x).await;
1460+ y + y + *z
1461+ }
1462+ fn spawn<T>(_: T) {}
1463+ fn main() {
1464+ let var = 42;
1465+ spawn({
1466+ let y = var + 1;
1467+ let z: &u32 = &var;
1468+ async move {
1469+ bar(var).await;
1470+ y + y + *z
1471+ }
1472+ });
1473+ }
13531474"# ,
13541475 ) ;
13551476 }
0 commit comments