1- use hir:: db:: AstDatabase ;
1+ use hir:: db:: ExpandDatabase ;
22use ide_db:: { assists:: Assist , source_change:: SourceChange } ;
3- use syntax:: AstNode ;
43use syntax:: { ast, SyntaxNode } ;
4+ use syntax:: { match_ast, AstNode } ;
55use text_edit:: TextEdit ;
66
77use crate :: { fix, Diagnostic , DiagnosticsContext } ;
@@ -19,10 +19,15 @@ pub(crate) fn missing_unsafe(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsaf
1919}
2020
2121fn fixes ( ctx : & DiagnosticsContext < ' _ > , d : & hir:: MissingUnsafe ) -> Option < Vec < Assist > > {
22+ // The fixit will not work correctly for macro expansions, so we don't offer it in that case.
23+ if d. expr . file_id . is_macro ( ) {
24+ return None ;
25+ }
26+
2227 let root = ctx. sema . db . parse_or_expand ( d. expr . file_id ) ?;
2328 let expr = d. expr . value . to_node ( & root) ;
2429
25- let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block ( & expr) ;
30+ let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block ( & expr) ? ;
2631
2732 let replacement = format ! ( "unsafe {{ {} }}" , node_to_add_unsafe_block. text( ) ) ;
2833 let edit = TextEdit :: replace ( node_to_add_unsafe_block. text_range ( ) , replacement) ;
@@ -42,72 +47,51 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
4247// - `unsafe_expr += 1` -> `unsafe { unsafe_expr += 1 }`
4348// - `&unsafe_expr` -> `unsafe { &unsafe_expr }`
4449// - `&&unsafe_expr` -> `unsafe { &&unsafe_expr }`
45- fn pick_best_node_to_add_unsafe_block ( unsafe_expr : & ast:: Expr ) -> SyntaxNode {
50+ fn pick_best_node_to_add_unsafe_block ( unsafe_expr : & ast:: Expr ) -> Option < SyntaxNode > {
4651 // The `unsafe_expr` might be:
4752 // - `ast::CallExpr`: call an unsafe function
4853 // - `ast::MethodCallExpr`: call an unsafe method
4954 // - `ast::PrefixExpr`: dereference a raw pointer
5055 // - `ast::PathExpr`: access a static mut variable
51- for node in unsafe_expr. syntax ( ) . ancestors ( ) {
52- let Some ( parent) = node. parent ( ) else {
53- return node;
54- } ;
55- match parent. kind ( ) {
56- syntax:: SyntaxKind :: METHOD_CALL_EXPR => {
57- // Check if the `node` is the receiver of the method call
58- let method_call_expr = ast:: MethodCallExpr :: cast ( parent. clone ( ) ) . unwrap ( ) ;
59- if method_call_expr
60- . receiver ( )
61- . map ( |receiver| {
62- receiver. syntax ( ) . text_range ( ) . contains_range ( node. text_range ( ) )
63- } )
64- . unwrap_or ( false )
65- {
66- // Actually, I think it's not necessary to check whether the
67- // text range of the `node` (which is the ancestor of the
68- // `unsafe_expr`) is contained in the text range of the
69- // receiver. The `node` could potentially be the receiver, the
70- // method name, or the argument list. Since the `node` is the
71- // ancestor of the unsafe_expr, it cannot be the method name.
72- // Additionally, if the `node` is the argument list, the loop
73- // would break at least when `parent` reaches the argument list.
74- //
75- // Dispite this, I still check the text range because I think it
76- // makes the code easier to understand.
77- continue ;
78- }
79- return node;
80- }
81- syntax:: SyntaxKind :: FIELD_EXPR | syntax:: SyntaxKind :: REF_EXPR => continue ,
82- syntax:: SyntaxKind :: BIN_EXPR => {
83- // Check if the `node` is the left-hand side of an assignment
84- let is_left_hand_side_of_assignment = {
85- let bin_expr = ast:: BinExpr :: cast ( parent. clone ( ) ) . unwrap ( ) ;
86- if let Some ( ast:: BinaryOp :: Assignment { .. } ) = bin_expr. op_kind ( ) {
87- let is_left_hand_side = bin_expr
88- . lhs ( )
89- . map ( |lhs| lhs. syntax ( ) . text_range ( ) . contains_range ( node. text_range ( ) ) )
90- . unwrap_or ( false ) ;
91- is_left_hand_side
92- } else {
93- false
56+ for ( node, parent) in
57+ unsafe_expr. syntax ( ) . ancestors ( ) . zip ( unsafe_expr. syntax ( ) . ancestors ( ) . skip ( 1 ) )
58+ {
59+ match_ast ! {
60+ match parent {
61+ // If the `parent` is a `MethodCallExpr`, that means the `node`
62+ // is the receiver of the method call, because only the receiver
63+ // can be a direct child of a method call. The method name
64+ // itself is not an expression but a `NameRef`, and an argument
65+ // is a direct child of an `ArgList`.
66+ ast:: MethodCallExpr ( _) => continue ,
67+ ast:: FieldExpr ( _) => continue ,
68+ ast:: RefExpr ( _) => continue ,
69+ ast:: BinExpr ( it) => {
70+ // Check if the `node` is the left-hand side of an
71+ // assignment, if so, we don't want to wrap it in an unsafe
72+ // block, e.g. `unsafe_expr += 1`
73+ let is_left_hand_side_of_assignment = {
74+ if let Some ( ast:: BinaryOp :: Assignment { .. } ) = it. op_kind( ) {
75+ it. lhs( ) . map( |lhs| lhs. syntax( ) . text_range( ) . contains_range( node. text_range( ) ) ) . unwrap_or( false )
76+ } else {
77+ false
78+ }
79+ } ;
80+ if !is_left_hand_side_of_assignment {
81+ return Some ( node) ;
9482 }
95- } ;
96- if !is_left_hand_side_of_assignment {
97- return node;
98- }
99- }
100- _ => {
101- return node;
83+ } ,
84+ _ => { return Some ( node) ; }
85+
10286 }
10387 }
10488 }
105- unsafe_expr . syntax ( ) . clone ( )
89+ None
10690}
10791
10892#[ cfg( test) ]
10993mod tests {
110- use crate :: tests:: { check_diagnostics, check_fix} ;
94+ use crate :: tests:: { check_diagnostics, check_fix, check_no_fix } ;
11195
11296 #[ test]
11397 fn missing_unsafe_diagnostic_with_raw_ptr ( ) {
@@ -467,4 +451,19 @@ fn main() {
467451"# ,
468452 )
469453 }
454+
455+ #[ test]
456+ fn unsafe_expr_in_macro_call ( ) {
457+ check_no_fix (
458+ r#"
459+ unsafe fn foo() -> u8 {
460+ 0
461+ }
462+
463+ fn main() {
464+ let x = format!("foo: {}", foo$0());
465+ }
466+ "# ,
467+ )
468+ }
470469}
0 commit comments