11use hir:: db:: AstDatabase ;
22use ide_db:: { assists:: Assist , source_change:: SourceChange } ;
3- use syntax:: ast:: { ExprStmt , LetStmt } ;
43use syntax:: AstNode ;
54use syntax:: { ast, SyntaxNode } ;
65use text_edit:: TextEdit ;
@@ -23,7 +22,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
2322 let root = ctx. sema . db . parse_or_expand ( d. expr . file_id ) ?;
2423 let expr = d. expr . value . to_node ( & root) ;
2524
26- let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block ( ctx , & expr) ;
25+ let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block ( & expr) ;
2726
2827 let replacement = format ! ( "unsafe {{ {} }}" , node_to_add_unsafe_block. text( ) ) ;
2928 let edit = TextEdit :: replace ( node_to_add_unsafe_block. text_range ( ) , replacement) ;
@@ -32,39 +31,78 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
3231 Some ( vec ! [ fix( "add_unsafe" , "Add unsafe block" , source_change, expr. syntax( ) . text_range( ) ) ] )
3332}
3433
35- // Find the let statement or expression statement closest to the `expr` in the
36- // ancestor chain.
37- //
38- // Why don't we just add an unsafe block around the `expr`?
39- //
40- // Consider this example:
41- // ```
42- // STATIC_MUT += 1;
43- // ```
44- // We can't add an unsafe block to the left-hand side of an assignment.
45- // ```
46- // unsafe { STATIC_MUT } += 1;
47- // ```
48- //
49- // Or this example:
50- // ```
51- // let z = STATIC_MUT.a;
52- // ```
53- // We can't add an unsafe block like this:
54- // ```
55- // let z = unsafe { STATIC_MUT } .a;
56- // ```
57- fn pick_best_node_to_add_unsafe_block (
58- ctx : & DiagnosticsContext < ' _ > ,
59- expr : & ast:: Expr ,
60- ) -> SyntaxNode {
61- let Some ( let_or_expr_stmt) = ctx. sema . ancestors_with_macros ( expr. syntax ( ) . clone ( ) ) . find ( |node| {
62- LetStmt :: can_cast ( node. kind ( ) ) || ExprStmt :: can_cast ( node. kind ( ) )
63- } ) else {
64- // Is this reachable?
65- return expr. syntax ( ) . clone ( ) ;
66- } ;
67- let_or_expr_stmt
34+ // Pick the first ancestor expression of the unsafe `expr` that is not a
35+ // receiver of a method call, a field access, the left-hand side of an
36+ // assignment, or a reference. As all of those cases would incur a forced move
37+ // if wrapped which might not be wanted. That is:
38+ // - `unsafe_expr.foo` -> `unsafe { unsafe_expr.foo }`
39+ // - `unsafe_expr.foo.bar` -> `unsafe { unsafe_expr.foo.bar }`
40+ // - `unsafe_expr.foo()` -> `unsafe { unsafe_expr.foo() }`
41+ // - `unsafe_expr.foo.bar()` -> `unsafe { unsafe_expr.foo.bar() }`
42+ // - `unsafe_expr += 1` -> `unsafe { unsafe_expr += 1 }`
43+ // - `&unsafe_expr` -> `unsafe { &unsafe_expr }`
44+ // - `&&unsafe_expr` -> `unsafe { &&unsafe_expr }`
45+ fn pick_best_node_to_add_unsafe_block ( unsafe_expr : & ast:: Expr ) -> SyntaxNode {
46+ // The `unsafe_expr` might be:
47+ // - `ast::CallExpr`: call an unsafe function
48+ // - `ast::MethodCallExpr`: call an unsafe method
49+ // - `ast::PrefixExpr`: dereference a raw pointer
50+ // - `ast::PathExpr`: access a static mut variable
51+ for node in unsafe_expr. syntax ( ) . ancestors ( ) {
52+ let Some ( parent) = node. parent ( ) else {
53+ return node;
54+ } ;
55+ match parent. kind ( ) {
56+ syntax:: SyntaxKind :: METHOD_CALL_EXPR => {
57+ // Check if the `node` is the receiver of the method call
58+ let method_call_expr = ast:: MethodCallExpr :: cast ( parent. clone ( ) ) . unwrap ( ) ;
59+ if method_call_expr
60+ . receiver ( )
61+ . map ( |receiver| {
62+ receiver. syntax ( ) . text_range ( ) . contains_range ( node. text_range ( ) )
63+ } )
64+ . unwrap_or ( false )
65+ {
66+ // Actually, I think it's not necessary to check whether the
67+ // text range of the `node` (which is the ancestor of the
68+ // `unsafe_expr`) is contained in the text range of the
69+ // receiver. The `node` could potentially be the receiver, the
70+ // method name, or the argument list. Since the `node` is the
71+ // ancestor of the unsafe_expr, it cannot be the method name.
72+ // Additionally, if the `node` is the argument list, the loop
73+ // would break at least when `parent` reaches the argument list.
74+ //
75+ // Dispite this, I still check the text range because I think it
76+ // makes the code easier to understand.
77+ continue ;
78+ }
79+ return node;
80+ }
81+ syntax:: SyntaxKind :: FIELD_EXPR | syntax:: SyntaxKind :: REF_EXPR => continue ,
82+ syntax:: SyntaxKind :: BIN_EXPR => {
83+ // Check if the `node` is the left-hand side of an assignment
84+ let is_left_hand_side_of_assignment = {
85+ let bin_expr = ast:: BinExpr :: cast ( parent. clone ( ) ) . unwrap ( ) ;
86+ if let Some ( ast:: BinaryOp :: Assignment { .. } ) = bin_expr. op_kind ( ) {
87+ let is_left_hand_side = bin_expr
88+ . lhs ( )
89+ . map ( |lhs| lhs. syntax ( ) . text_range ( ) . contains_range ( node. text_range ( ) ) )
90+ . unwrap_or ( false ) ;
91+ is_left_hand_side
92+ } else {
93+ false
94+ }
95+ } ;
96+ if !is_left_hand_side_of_assignment {
97+ return node;
98+ }
99+ }
100+ _ => {
101+ return node;
102+ }
103+ }
104+ }
105+ unsafe_expr. syntax ( ) . clone ( )
68106}
69107
70108#[ cfg( test) ]
@@ -168,7 +206,7 @@ fn main() {
168206 r#"
169207fn main() {
170208 let x = &5 as *const usize;
171- unsafe { let z = *x; }
209+ let z = unsafe { *x };
172210}
173211"# ,
174212 ) ;
@@ -192,7 +230,7 @@ unsafe fn func() {
192230 let z = *x;
193231}
194232fn main() {
195- unsafe { func(); }
233+ unsafe { func() };
196234}
197235"# ,
198236 )
@@ -224,7 +262,7 @@ impl S {
224262}
225263fn main() {
226264 let s = S(5);
227- unsafe { s.func(); }
265+ unsafe { s.func() };
228266}
229267"# ,
230268 )
@@ -252,7 +290,7 @@ struct Ty {
252290static mut STATIC_MUT: Ty = Ty { a: 0 };
253291
254292fn main() {
255- unsafe { let x = STATIC_MUT.a; }
293+ let x = unsafe { STATIC_MUT.a };
256294}
257295"# ,
258296 )
@@ -276,7 +314,155 @@ extern "rust-intrinsic" {
276314}
277315
278316fn main() {
279- unsafe { let _ = floorf32(12.0); }
317+ let _ = unsafe { floorf32(12.0) };
318+ }
319+ "# ,
320+ )
321+ }
322+
323+ #[ test]
324+ fn unsafe_expr_as_a_receiver_of_a_method_call ( ) {
325+ check_fix (
326+ r#"
327+ unsafe fn foo() -> String {
328+ "string".to_string()
329+ }
330+
331+ fn main() {
332+ foo$0().len();
333+ }
334+ "# ,
335+ r#"
336+ unsafe fn foo() -> String {
337+ "string".to_string()
338+ }
339+
340+ fn main() {
341+ unsafe { foo().len() };
342+ }
343+ "# ,
344+ )
345+ }
346+
347+ #[ test]
348+ fn unsafe_expr_as_an_argument_of_a_method_call ( ) {
349+ check_fix (
350+ r#"
351+ static mut STATIC_MUT: u8 = 0;
352+
353+ fn main() {
354+ let mut v = vec![];
355+ v.push(STATIC_MUT$0);
356+ }
357+ "# ,
358+ r#"
359+ static mut STATIC_MUT: u8 = 0;
360+
361+ fn main() {
362+ let mut v = vec![];
363+ v.push(unsafe { STATIC_MUT });
364+ }
365+ "# ,
366+ )
367+ }
368+
369+ #[ test]
370+ fn unsafe_expr_as_left_hand_side_of_assignment ( ) {
371+ check_fix (
372+ r#"
373+ static mut STATIC_MUT: u8 = 0;
374+
375+ fn main() {
376+ STATIC_MUT$0 = 1;
377+ }
378+ "# ,
379+ r#"
380+ static mut STATIC_MUT: u8 = 0;
381+
382+ fn main() {
383+ unsafe { STATIC_MUT = 1 };
384+ }
385+ "# ,
386+ )
387+ }
388+
389+ #[ test]
390+ fn unsafe_expr_as_right_hand_side_of_assignment ( ) {
391+ check_fix (
392+ r#"
393+ static mut STATIC_MUT: u8 = 0;
394+
395+ fn main() {
396+ let x;
397+ x = STATIC_MUT$0;
398+ }
399+ "# ,
400+ r#"
401+ static mut STATIC_MUT: u8 = 0;
402+
403+ fn main() {
404+ let x;
405+ x = unsafe { STATIC_MUT };
406+ }
407+ "# ,
408+ )
409+ }
410+
411+ #[ test]
412+ fn unsafe_expr_in_binary_plus ( ) {
413+ check_fix (
414+ r#"
415+ static mut STATIC_MUT: u8 = 0;
416+
417+ fn main() {
418+ let x = STATIC_MUT$0 + 1;
419+ }
420+ "# ,
421+ r#"
422+ static mut STATIC_MUT: u8 = 0;
423+
424+ fn main() {
425+ let x = unsafe { STATIC_MUT } + 1;
426+ }
427+ "# ,
428+ )
429+ }
430+
431+ #[ test]
432+ fn ref_to_unsafe_expr ( ) {
433+ check_fix (
434+ r#"
435+ static mut STATIC_MUT: u8 = 0;
436+
437+ fn main() {
438+ let x = &STATIC_MUT$0;
439+ }
440+ "# ,
441+ r#"
442+ static mut STATIC_MUT: u8 = 0;
443+
444+ fn main() {
445+ let x = unsafe { &STATIC_MUT };
446+ }
447+ "# ,
448+ )
449+ }
450+
451+ #[ test]
452+ fn ref_ref_to_unsafe_expr ( ) {
453+ check_fix (
454+ r#"
455+ static mut STATIC_MUT: u8 = 0;
456+
457+ fn main() {
458+ let x = &&STATIC_MUT$0;
459+ }
460+ "# ,
461+ r#"
462+ static mut STATIC_MUT: u8 = 0;
463+
464+ fn main() {
465+ let x = unsafe { &&STATIC_MUT };
280466}
281467"# ,
282468 )
0 commit comments