@@ -3,13 +3,15 @@ use clippy_utils::diagnostics::{span_lint_and_sugg, span_lint_hir_and_then};
33use clippy_utils:: source:: { snippet, snippet_with_applicability} ;
44use clippy_utils:: sugg:: Sugg ;
55use clippy_utils:: ty:: is_type_diagnostic_item;
6- use clippy_utils:: { is_trait_method, path_to_local_id} ;
6+ use clippy_utils:: { can_move_expr_to_closure , is_trait_method, path_to_local_id, CaptureKind } ;
77use if_chain:: if_chain;
88use rustc_errors:: Applicability ;
99use rustc_hir:: intravisit:: { walk_block, walk_expr, NestedVisitorMap , Visitor } ;
10- use rustc_hir:: { Block , Expr , ExprKind , HirId , PatKind , StmtKind } ;
10+ use rustc_hir:: { Block , Expr , ExprKind , HirId , HirIdSet , Local , Mutability , Node , PatKind , Stmt , StmtKind } ;
1111use rustc_lint:: LateContext ;
1212use rustc_middle:: hir:: map:: Map ;
13+ use rustc_middle:: ty:: subst:: GenericArgKind ;
14+ use rustc_middle:: ty:: { TyKind , TyS } ;
1315use rustc_span:: sym;
1416use rustc_span:: { MultiSpan , Span } ;
1517
@@ -83,7 +85,8 @@ fn check_needless_collect_indirect_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateCo
8385 is_type_diagnostic_item( cx, ty, sym:: VecDeque ) ||
8486 is_type_diagnostic_item( cx, ty, sym:: BinaryHeap ) ||
8587 is_type_diagnostic_item( cx, ty, sym:: LinkedList ) ;
86- if let Some ( iter_calls) = detect_iter_and_into_iters( block, id) ;
88+ let iter_ty = cx. typeck_results( ) . expr_ty( iter_source) ;
89+ if let Some ( iter_calls) = detect_iter_and_into_iters( block, id, cx, get_captured_ids( cx, iter_ty) ) ;
8790 if let [ iter_call] = & * iter_calls;
8891 then {
8992 let mut used_count_visitor = UsedCountVisitor {
@@ -167,34 +170,57 @@ enum IterFunctionKind {
167170 Contains ( Span ) ,
168171}
169172
170- struct IterFunctionVisitor {
173+ struct IterFunctionVisitor < ' b , ' a > {
174+ illegal_mutable_capture_ids : HirIdSet ,
175+ current_mutably_captured_ids : HirIdSet ,
176+ cx : & ' a LateContext < ' b > ,
171177 uses : Vec < IterFunction > ,
172178 seen_other : bool ,
173179 target : HirId ,
174180}
175- impl < ' tcx > Visitor < ' tcx > for IterFunctionVisitor {
181+ impl < ' tcx > Visitor < ' tcx > for IterFunctionVisitor < ' _ , ' tcx > {
182+ fn visit_block ( & mut self , block : & ' txc Block < ' tcx > ) {
183+ for elem in block. stmts . iter ( ) . filter_map ( get_expr_from_stmt) . chain ( block. expr ) {
184+ self . current_mutably_captured_ids = HirIdSet :: default ( ) ;
185+ self . visit_expr ( elem) ;
186+ }
187+ }
188+
176189 fn visit_expr ( & mut self , expr : & ' tcx Expr < ' tcx > ) {
177190 // Check function calls on our collection
178191 if let ExprKind :: MethodCall ( method_name, _, [ recv, args @ ..] , _) = & expr. kind {
192+ if method_name. ident . name == sym ! ( collect) && is_trait_method ( self . cx , expr, sym:: Iterator ) {
193+ self . current_mutably_captured_ids = get_captured_ids ( self . cx , self . cx . typeck_results ( ) . expr_ty ( recv) ) ;
194+ self . visit_expr ( recv) ;
195+ return ;
196+ }
197+
179198 if path_to_local_id ( recv, self . target ) {
180- match & * method_name. ident . name . as_str ( ) {
181- "into_iter" => self . uses . push ( IterFunction {
182- func : IterFunctionKind :: IntoIter ,
183- span : expr. span ,
184- } ) ,
185- "len" => self . uses . push ( IterFunction {
186- func : IterFunctionKind :: Len ,
187- span : expr. span ,
188- } ) ,
189- "is_empty" => self . uses . push ( IterFunction {
190- func : IterFunctionKind :: IsEmpty ,
191- span : expr. span ,
192- } ) ,
193- "contains" => self . uses . push ( IterFunction {
194- func : IterFunctionKind :: Contains ( args[ 0 ] . span ) ,
195- span : expr. span ,
196- } ) ,
197- _ => self . seen_other = true ,
199+ if self
200+ . illegal_mutable_capture_ids
201+ . intersection ( & self . current_mutably_captured_ids )
202+ . next ( )
203+ . is_none ( )
204+ {
205+ match & * method_name. ident . name . as_str ( ) {
206+ "into_iter" => self . uses . push ( IterFunction {
207+ func : IterFunctionKind :: IntoIter ,
208+ span : expr. span ,
209+ } ) ,
210+ "len" => self . uses . push ( IterFunction {
211+ func : IterFunctionKind :: Len ,
212+ span : expr. span ,
213+ } ) ,
214+ "is_empty" => self . uses . push ( IterFunction {
215+ func : IterFunctionKind :: IsEmpty ,
216+ span : expr. span ,
217+ } ) ,
218+ "contains" => self . uses . push ( IterFunction {
219+ func : IterFunctionKind :: Contains ( args[ 0 ] . span ) ,
220+ span : expr. span ,
221+ } ) ,
222+ _ => self . seen_other = true ,
223+ }
198224 }
199225 return ;
200226 }
@@ -213,6 +239,14 @@ impl<'tcx> Visitor<'tcx> for IterFunctionVisitor {
213239 }
214240}
215241
242+ fn get_expr_from_stmt < ' v > ( stmt : & ' v Stmt < ' v > ) -> Option < & ' v Expr < ' v > > {
243+ match stmt. kind {
244+ StmtKind :: Expr ( expr) | StmtKind :: Semi ( expr) => Some ( expr) ,
245+ StmtKind :: Item ( ..) => None ,
246+ StmtKind :: Local ( Local { init, .. } ) => * init,
247+ }
248+ }
249+
216250struct UsedCountVisitor < ' a , ' tcx > {
217251 cx : & ' a LateContext < ' tcx > ,
218252 id : HirId ,
@@ -237,12 +271,55 @@ impl<'a, 'tcx> Visitor<'tcx> for UsedCountVisitor<'a, 'tcx> {
237271
238272/// Detect the occurrences of calls to `iter` or `into_iter` for the
239273/// given identifier
240- fn detect_iter_and_into_iters < ' tcx > ( block : & ' tcx Block < ' tcx > , id : HirId ) -> Option < Vec < IterFunction > > {
274+ fn detect_iter_and_into_iters < ' tcx : ' a , ' a > (
275+ block : & ' tcx Block < ' tcx > ,
276+ id : HirId ,
277+ cx : & ' a LateContext < ' tcx > ,
278+ captured_ids : HirIdSet ,
279+ ) -> Option < Vec < IterFunction > > {
241280 let mut visitor = IterFunctionVisitor {
242281 uses : Vec :: new ( ) ,
243282 target : id,
244283 seen_other : false ,
284+ cx,
285+ current_mutably_captured_ids : HirIdSet :: default ( ) ,
286+ illegal_mutable_capture_ids : captured_ids,
245287 } ;
246288 visitor. visit_block ( block) ;
247289 if visitor. seen_other { None } else { Some ( visitor. uses ) }
248290}
291+
292+ #[ allow( rustc:: usage_of_ty_tykind) ]
293+ fn get_captured_ids ( cx : & LateContext < ' tcx > , ty : & ' _ TyS < ' _ > ) -> HirIdSet {
294+ fn get_captured_ids_recursive ( cx : & LateContext < ' tcx > , ty : & ' _ TyS < ' _ > , set : & mut HirIdSet ) {
295+ match ty. kind ( ) {
296+ TyKind :: Adt ( _, generics) => {
297+ for generic in * generics {
298+ if let GenericArgKind :: Type ( ty) = generic. unpack ( ) {
299+ get_captured_ids_recursive ( cx, ty, set) ;
300+ }
301+ }
302+ } ,
303+ TyKind :: Closure ( def_id, _) => {
304+ let closure_hir_node = cx. tcx . hir ( ) . get_if_local ( * def_id) . unwrap ( ) ;
305+ if let Node :: Expr ( closure_expr) = closure_hir_node {
306+ can_move_expr_to_closure ( cx, closure_expr)
307+ . unwrap ( )
308+ . into_iter ( )
309+ . for_each ( |( hir_id, capture_kind) | {
310+ if matches ! ( capture_kind, CaptureKind :: Ref ( Mutability :: Mut ) ) {
311+ set. insert ( hir_id) ;
312+ }
313+ } ) ;
314+ }
315+ } ,
316+ _ => ( ) ,
317+ }
318+ }
319+
320+ let mut set = HirIdSet :: default ( ) ;
321+
322+ get_captured_ids_recursive ( cx, ty, & mut set) ;
323+
324+ set
325+ }
0 commit comments