@@ -3,7 +3,8 @@ use rustc_data_structures::fx::FxHashMap;
33use rustc_errors:: Applicability ;
44use rustc_hir as hir;
55use rustc_hir:: def:: { CtorKind , DefKind , Res } ;
6- use rustc_middle:: ty:: Ty ;
6+ use rustc_middle:: mir;
7+ use rustc_middle:: ty:: { self , Instance , Ty } ;
78use rustc_session:: { declare_lint, declare_lint_pass} ;
89use rustc_span:: Span ;
910use rustc_span:: symbol:: { kw, sym} ;
@@ -391,42 +392,73 @@ impl<'tcx> LateLintPass<'tcx> for DefaultCouldBeDerived {
391392 }
392393}
393394
395+ /// For the `Default` impl for this type, we see if it has a `Default::default()` body composed
396+ /// only of a path, ctor or function call with no arguments. If so, we compare that `DefId`
397+ /// against the `DefId` of this field's value if it is also a call/path/ctor.
398+ /// If there's a match, it means that the contents of that type's `Default` impl are the
399+ /// same to what the user wrote on *their* `Default` impl for this field.
394400fn check_path < ' tcx > (
395401 cx : & LateContext < ' tcx > ,
396402 path : & hir:: QPath < ' _ > ,
397403 hir_id : hir:: HirId ,
398404 ty : Ty < ' tcx > ,
399405) -> bool {
400- let Some ( default_def_id) = cx. tcx . get_diagnostic_item ( sym:: Default ) else {
401- return false ;
402- } ;
403406 let res = cx. qpath_res ( & path, hir_id) ;
404407 let Some ( def_id) = res. opt_def_id ( ) else { return false } ;
405- if cx. tcx . is_diagnostic_item ( sym:: default_fn, def_id) {
408+ let Some ( default_fn_def_id) = cx. tcx . get_diagnostic_item ( sym:: default_fn) else {
409+ return false ;
410+ } ;
411+ if default_fn_def_id == def_id {
406412 // We have `field: Default::default(),`. This is what the derive would do already.
407413 return true ;
408414 }
409- // For every `Default` impl for this type (there should be a single one), we see if it
410- // has a "canonical" `DefId` for a fn call with no arguments, or a path. If it does, we
411- // check that `DefId` with the `DefId` of this field's value if it is also a call/path.
412- // If there's a match, it means that the contents of that type's `Default` impl are the
413- // same to what the user wrote on *their* `Default` impl for this field.
414- let mut equivalents = vec ! [ ] ;
415- cx. tcx . for_each_relevant_impl ( default_def_id, ty, |impl_def_id| {
416- let equivalent = match impl_def_id. as_local ( ) {
417- None => cx. tcx . get_default_impl_equivalent ( impl_def_id) ,
418- Some ( local) => {
419- let def_kind = cx. tcx . def_kind ( impl_def_id) ;
420- cx. tcx . get_default_equivalent ( def_kind, local)
421- }
422- } ;
423- if let Some ( did) = equivalent {
424- equivalents. push ( did) ;
415+
416+ let args = ty:: GenericArgs :: for_item ( cx. tcx , default_fn_def_id, |param, _| {
417+ if let ty:: GenericParamDefKind :: Lifetime = param. kind {
418+ cx. tcx . lifetimes . re_erased . into ( )
419+ } else if param. index == 0 && param. name == kw:: SelfUpper {
420+ ty. into ( )
421+ } else {
422+ param. to_error ( cx. tcx )
425423 }
426424 } ) ;
427- for did in equivalents {
428- if did == def_id {
425+ let instance = Instance :: try_resolve ( cx. tcx , cx. typing_env ( ) , default_fn_def_id, args) ;
426+
427+ let Ok ( Some ( instance) ) = instance else { return false } ;
428+ // Get the MIR Body for the `<Ty as Default>::default()` function.
429+ // If it is a value or call (either fn or ctor), we compare its DefId against the one for the
430+ // resolution of the expression we had in the path.
431+ let body = cx. tcx . instance_mir ( instance. def ) ;
432+ for block_data in body. basic_blocks . iter ( ) {
433+ if block_data. statements . len ( ) == 1
434+ && let mir:: StatementKind :: Assign ( assign) = & block_data. statements [ 0 ] . kind
435+ && let mir:: Rvalue :: Aggregate ( kind, _places) = & assign. 1
436+ && let mir:: AggregateKind :: Adt ( did, variant_index, _, _, _) = & * * kind
437+ && let def = cx. tcx . adt_def ( did)
438+ && let variant = & def. variant ( * variant_index)
439+ && variant. fields . is_empty ( )
440+ && let Some ( ( _, did) ) = variant. ctor
441+ && did == def_id
442+ {
429443 return true ;
444+ } else if block_data. statements . len ( ) == 0
445+ && let Some ( term) = & block_data. terminator
446+ {
447+ match & term. kind {
448+ mir:: TerminatorKind :: Call { func : mir:: Operand :: Constant ( c) , .. }
449+ if let ty:: FnDef ( did, _args) = c. ty ( ) . kind ( )
450+ && * did == def_id =>
451+ {
452+ return true ;
453+ }
454+ mir:: TerminatorKind :: TailCall { func : mir:: Operand :: Constant ( c) , .. }
455+ if let ty:: FnDef ( did, _args) = c. ty ( ) . kind ( )
456+ && * did == def_id =>
457+ {
458+ return true ;
459+ }
460+ _ => { }
461+ }
430462 }
431463 }
432464 false
0 commit comments