@@ -15,7 +15,7 @@ use ide_db::{
1515} ;
1616use itertools:: { izip, Itertools } ;
1717use syntax:: {
18- ast:: { self , edit_in_place:: Indent , HasArgList , PathExpr } ,
18+ ast:: { self , edit :: IndentLevel , edit_in_place:: Indent , HasArgList , PathExpr } ,
1919 ted, AstNode , NodeOrToken , SyntaxKind ,
2020} ;
2121
@@ -306,7 +306,7 @@ fn inline(
306306 params : & [ ( ast:: Pat , Option < ast:: Type > , hir:: Param ) ] ,
307307 CallInfo { node, arguments, generic_arg_list } : & CallInfo ,
308308) -> ast:: Expr {
309- let body = if sema. hir_file_for ( fn_body. syntax ( ) ) . is_macro ( ) {
309+ let mut body = if sema. hir_file_for ( fn_body. syntax ( ) ) . is_macro ( ) {
310310 cov_mark:: hit!( inline_call_defined_in_macro) ;
311311 if let Some ( body) = ast:: BlockExpr :: cast ( insert_ws_into ( fn_body. syntax ( ) . clone ( ) ) ) {
312312 body
@@ -391,19 +391,19 @@ fn inline(
391391 }
392392 }
393393
394+ let mut let_stmts = Vec :: new ( ) ;
395+
394396 // Inline parameter expressions or generate `let` statements depending on whether inlining works or not.
395- for ( ( pat, param_ty, _) , usages, expr) in izip ! ( params, param_use_nodes, arguments) . rev ( ) {
397+ for ( ( pat, param_ty, _) , usages, expr) in izip ! ( params, param_use_nodes, arguments) {
396398 // izip confuses RA due to our lack of hygiene info currently losing us type info causing incorrect errors
397399 let usages: & [ ast:: PathExpr ] = & usages;
398400 let expr: & ast:: Expr = expr;
399401
400- let insert_let_stmt = || {
402+ let mut insert_let_stmt = || {
401403 let ty = sema. type_of_expr ( expr) . filter ( TypeInfo :: has_adjustment) . and ( param_ty. clone ( ) ) ;
402- if let Some ( stmt_list) = body. stmt_list ( ) {
403- stmt_list. push_front (
404- make:: let_stmt ( pat. clone ( ) , ty, Some ( expr. clone ( ) ) ) . clone_for_update ( ) . into ( ) ,
405- )
406- }
404+ let_stmts. push (
405+ make:: let_stmt ( pat. clone ( ) , ty, Some ( expr. clone ( ) ) ) . clone_for_update ( ) . into ( ) ,
406+ ) ;
407407 } ;
408408
409409 // check if there is a local var in the function that conflicts with parameter
@@ -457,14 +457,32 @@ fn inline(
457457 }
458458 }
459459
460+ let is_async_fn = function. is_async ( sema. db ) ;
461+ if is_async_fn {
462+ cov_mark:: hit!( inline_call_async_fn) ;
463+ body = make:: async_move_block_expr ( body. statements ( ) , body. tail_expr ( ) ) . clone_for_update ( ) ;
464+
465+ // Arguments should be evaluated outside the async block, and then moved into it.
466+ if !let_stmts. is_empty ( ) {
467+ cov_mark:: hit!( inline_call_async_fn_with_let_stmts) ;
468+ body. indent ( IndentLevel ( 1 ) ) ;
469+ body = make:: block_expr ( let_stmts, Some ( body. into ( ) ) ) . clone_for_update ( ) ;
470+ }
471+ } else if let Some ( stmt_list) = body. stmt_list ( ) {
472+ ted:: insert_all (
473+ ted:: Position :: after ( stmt_list. l_curly_token ( ) . unwrap ( ) ) ,
474+ let_stmts. into_iter ( ) . map ( |stmt| stmt. syntax ( ) . clone ( ) . into ( ) ) . collect ( ) ,
475+ ) ;
476+ }
477+
460478 let original_indentation = match node {
461479 ast:: CallableExpr :: Call ( it) => it. indent_level ( ) ,
462480 ast:: CallableExpr :: MethodCall ( it) => it. indent_level ( ) ,
463481 } ;
464482 body. reindent_to ( original_indentation) ;
465483
466484 match body. tail_expr ( ) {
467- Some ( expr) if body. statements ( ) . next ( ) . is_none ( ) => expr,
485+ Some ( expr) if !is_async_fn && body. statements ( ) . next ( ) . is_none ( ) => expr,
468486 _ => match node
469487 . syntax ( )
470488 . parent ( )
@@ -1351,6 +1369,109 @@ fn main() {
13511369 bar * b * a * 6
13521370 };
13531371}
1372+ "# ,
1373+ ) ;
1374+ }
1375+
1376+ #[ test]
1377+ fn async_fn_single_expression ( ) {
1378+ cov_mark:: check!( inline_call_async_fn) ;
1379+ check_assist (
1380+ inline_call,
1381+ r#"
1382+ async fn bar(x: u32) -> u32 { x + 1 }
1383+ async fn foo(arg: u32) -> u32 {
1384+ bar(arg).await * 2
1385+ }
1386+ fn spawn<T>(_: T) {}
1387+ fn main() {
1388+ spawn(foo$0(42));
1389+ }
1390+ "# ,
1391+ r#"
1392+ async fn bar(x: u32) -> u32 { x + 1 }
1393+ async fn foo(arg: u32) -> u32 {
1394+ bar(arg).await * 2
1395+ }
1396+ fn spawn<T>(_: T) {}
1397+ fn main() {
1398+ spawn(async move {
1399+ bar(42).await * 2
1400+ });
1401+ }
1402+ "# ,
1403+ ) ;
1404+ }
1405+
1406+ #[ test]
1407+ fn async_fn_multiple_statements ( ) {
1408+ cov_mark:: check!( inline_call_async_fn) ;
1409+ check_assist (
1410+ inline_call,
1411+ r#"
1412+ async fn bar(x: u32) -> u32 { x + 1 }
1413+ async fn foo(arg: u32) -> u32 {
1414+ bar(arg).await;
1415+ 42
1416+ }
1417+ fn spawn<T>(_: T) {}
1418+ fn main() {
1419+ spawn(foo$0(42));
1420+ }
1421+ "# ,
1422+ r#"
1423+ async fn bar(x: u32) -> u32 { x + 1 }
1424+ async fn foo(arg: u32) -> u32 {
1425+ bar(arg).await;
1426+ 42
1427+ }
1428+ fn spawn<T>(_: T) {}
1429+ fn main() {
1430+ spawn(async move {
1431+ bar(42).await;
1432+ 42
1433+ });
1434+ }
1435+ "# ,
1436+ ) ;
1437+ }
1438+
1439+ #[ test]
1440+ fn async_fn_with_let_statements ( ) {
1441+ cov_mark:: check!( inline_call_async_fn) ;
1442+ cov_mark:: check!( inline_call_async_fn_with_let_stmts) ;
1443+ check_assist (
1444+ inline_call,
1445+ r#"
1446+ async fn bar(x: u32) -> u32 { x + 1 }
1447+ async fn foo(x: u32, y: u32, z: &u32) -> u32 {
1448+ bar(x).await;
1449+ y + y + *z
1450+ }
1451+ fn spawn<T>(_: T) {}
1452+ fn main() {
1453+ let var = 42;
1454+ spawn(foo$0(var, var + 1, &var));
1455+ }
1456+ "# ,
1457+ r#"
1458+ async fn bar(x: u32) -> u32 { x + 1 }
1459+ async fn foo(x: u32, y: u32, z: &u32) -> u32 {
1460+ bar(x).await;
1461+ y + y + *z
1462+ }
1463+ fn spawn<T>(_: T) {}
1464+ fn main() {
1465+ let var = 42;
1466+ spawn({
1467+ let y = var + 1;
1468+ let z: &u32 = &var;
1469+ async move {
1470+ bar(var).await;
1471+ y + y + *z
1472+ }
1473+ });
1474+ }
13541475"# ,
13551476 ) ;
13561477 }
0 commit comments