@@ -3,13 +3,16 @@ use clippy_utils::diagnostics::{span_lint_and_sugg, span_lint_hir_and_then};
33use clippy_utils:: source:: { snippet, snippet_with_applicability} ;
44use clippy_utils:: sugg:: Sugg ;
55use clippy_utils:: ty:: is_type_diagnostic_item;
6- use clippy_utils:: { is_trait_method, path_to_local_id} ;
6+ use clippy_utils:: { can_move_expr_to_closure , is_trait_method, path_to_local , path_to_local_id, CaptureKind } ;
77use if_chain:: if_chain;
8+ use rustc_data_structures:: fx:: FxHashMap ;
89use rustc_errors:: Applicability ;
910use rustc_hir:: intravisit:: { walk_block, walk_expr, NestedVisitorMap , Visitor } ;
10- use rustc_hir:: { Block , Expr , ExprKind , HirId , PatKind , StmtKind } ;
11+ use rustc_hir:: { Block , Expr , ExprKind , HirId , HirIdSet , Local , Mutability , Node , PatKind , Stmt , StmtKind } ;
1112use rustc_lint:: LateContext ;
1213use rustc_middle:: hir:: map:: Map ;
14+ use rustc_middle:: ty:: subst:: GenericArgKind ;
15+ use rustc_middle:: ty:: { self , TyS } ;
1316use rustc_span:: sym;
1417use rustc_span:: { MultiSpan , Span } ;
1518
@@ -83,7 +86,8 @@ fn check_needless_collect_indirect_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateCo
8386 is_type_diagnostic_item( cx, ty, sym:: VecDeque ) ||
8487 is_type_diagnostic_item( cx, ty, sym:: BinaryHeap ) ||
8588 is_type_diagnostic_item( cx, ty, sym:: LinkedList ) ;
86- if let Some ( iter_calls) = detect_iter_and_into_iters( block, id) ;
89+ let iter_ty = cx. typeck_results( ) . expr_ty( iter_source) ;
90+ if let Some ( iter_calls) = detect_iter_and_into_iters( block, id, cx, get_captured_ids( cx, iter_ty) ) ;
8791 if let [ iter_call] = & * iter_calls;
8892 then {
8993 let mut used_count_visitor = UsedCountVisitor {
@@ -167,37 +171,89 @@ enum IterFunctionKind {
167171 Contains ( Span ) ,
168172}
169173
170- struct IterFunctionVisitor {
171- uses : Vec < IterFunction > ,
174+ struct IterFunctionVisitor < ' a , ' tcx > {
175+ illegal_mutable_capture_ids : HirIdSet ,
176+ current_mutably_captured_ids : HirIdSet ,
177+ cx : & ' a LateContext < ' tcx > ,
178+ uses : Vec < Option < IterFunction > > ,
179+ hir_id_uses_map : FxHashMap < HirId , usize > ,
180+ current_statement_hir_id : Option < HirId > ,
172181 seen_other : bool ,
173182 target : HirId ,
174183}
175- impl < ' tcx > Visitor < ' tcx > for IterFunctionVisitor {
184+ impl < ' tcx > Visitor < ' tcx > for IterFunctionVisitor < ' _ , ' tcx > {
185+ fn visit_block ( & mut self , block : & ' tcx Block < ' tcx > ) {
186+ for ( expr, hir_id) in block. stmts . iter ( ) . filter_map ( get_expr_and_hir_id_from_stmt) {
187+ self . visit_block_expr ( expr, hir_id) ;
188+ }
189+ if let Some ( expr) = block. expr {
190+ self . visit_block_expr ( expr, None ) ;
191+ }
192+ }
193+
176194 fn visit_expr ( & mut self , expr : & ' tcx Expr < ' tcx > ) {
177195 // Check function calls on our collection
178196 if let ExprKind :: MethodCall ( method_name, _, [ recv, args @ ..] , _) = & expr. kind {
197+ if method_name. ident . name == sym ! ( collect) && is_trait_method ( self . cx , expr, sym:: Iterator ) {
198+ self . current_mutably_captured_ids = get_captured_ids ( self . cx , self . cx . typeck_results ( ) . expr_ty ( recv) ) ;
199+ self . visit_expr ( recv) ;
200+ return ;
201+ }
202+
179203 if path_to_local_id ( recv, self . target ) {
180- match & * method_name. ident . name . as_str ( ) {
181- "into_iter" => self . uses . push ( IterFunction {
182- func : IterFunctionKind :: IntoIter ,
183- span : expr. span ,
184- } ) ,
185- "len" => self . uses . push ( IterFunction {
186- func : IterFunctionKind :: Len ,
187- span : expr. span ,
188- } ) ,
189- "is_empty" => self . uses . push ( IterFunction {
190- func : IterFunctionKind :: IsEmpty ,
191- span : expr. span ,
192- } ) ,
193- "contains" => self . uses . push ( IterFunction {
194- func : IterFunctionKind :: Contains ( args[ 0 ] . span ) ,
195- span : expr. span ,
196- } ) ,
197- _ => self . seen_other = true ,
204+ if self
205+ . illegal_mutable_capture_ids
206+ . intersection ( & self . current_mutably_captured_ids )
207+ . next ( )
208+ . is_none ( )
209+ {
210+ if let Some ( hir_id) = self . current_statement_hir_id {
211+ self . hir_id_uses_map . insert ( hir_id, self . uses . len ( ) ) ;
212+ }
213+ match & * method_name. ident . name . as_str ( ) {
214+ "into_iter" => self . uses . push ( Some ( IterFunction {
215+ func : IterFunctionKind :: IntoIter ,
216+ span : expr. span ,
217+ } ) ) ,
218+ "len" => self . uses . push ( Some ( IterFunction {
219+ func : IterFunctionKind :: Len ,
220+ span : expr. span ,
221+ } ) ) ,
222+ "is_empty" => self . uses . push ( Some ( IterFunction {
223+ func : IterFunctionKind :: IsEmpty ,
224+ span : expr. span ,
225+ } ) ) ,
226+ "contains" => self . uses . push ( Some ( IterFunction {
227+ func : IterFunctionKind :: Contains ( args[ 0 ] . span ) ,
228+ span : expr. span ,
229+ } ) ) ,
230+ _ => {
231+ self . seen_other = true ;
232+ if let Some ( hir_id) = self . current_statement_hir_id {
233+ self . hir_id_uses_map . remove ( & hir_id) ;
234+ }
235+ } ,
236+ }
198237 }
199238 return ;
200239 }
240+
241+ if let Some ( hir_id) = path_to_local ( recv) {
242+ if let Some ( index) = self . hir_id_uses_map . remove ( & hir_id) {
243+ if self
244+ . illegal_mutable_capture_ids
245+ . intersection ( & self . current_mutably_captured_ids )
246+ . next ( )
247+ . is_none ( )
248+ {
249+ if let Some ( hir_id) = self . current_statement_hir_id {
250+ self . hir_id_uses_map . insert ( hir_id, index) ;
251+ }
252+ } else {
253+ self . uses [ index] = None ;
254+ }
255+ }
256+ }
201257 }
202258 // Check if the collection is used for anything else
203259 if path_to_local_id ( expr, self . target ) {
@@ -213,6 +269,28 @@ impl<'tcx> Visitor<'tcx> for IterFunctionVisitor {
213269 }
214270}
215271
272+ impl < ' tcx > IterFunctionVisitor < ' _ , ' tcx > {
273+ fn visit_block_expr ( & mut self , expr : & ' tcx Expr < ' tcx > , hir_id : Option < HirId > ) {
274+ self . current_statement_hir_id = hir_id;
275+ self . current_mutably_captured_ids = get_captured_ids ( self . cx , self . cx . typeck_results ( ) . expr_ty ( expr) ) ;
276+ self . visit_expr ( expr) ;
277+ }
278+ }
279+
280+ fn get_expr_and_hir_id_from_stmt < ' v > ( stmt : & ' v Stmt < ' v > ) -> Option < ( & ' v Expr < ' v > , Option < HirId > ) > {
281+ match stmt. kind {
282+ StmtKind :: Expr ( expr) | StmtKind :: Semi ( expr) => Some ( ( expr, None ) ) ,
283+ StmtKind :: Item ( ..) => None ,
284+ StmtKind :: Local ( Local { init, pat, .. } ) => {
285+ if let PatKind :: Binding ( _, hir_id, ..) = pat. kind {
286+ init. map ( |init_expr| ( init_expr, Some ( hir_id) ) )
287+ } else {
288+ init. map ( |init_expr| ( init_expr, None ) )
289+ }
290+ } ,
291+ }
292+ }
293+
216294struct UsedCountVisitor < ' a , ' tcx > {
217295 cx : & ' a LateContext < ' tcx > ,
218296 id : HirId ,
@@ -237,12 +315,60 @@ impl<'a, 'tcx> Visitor<'tcx> for UsedCountVisitor<'a, 'tcx> {
237315
238316/// Detect the occurrences of calls to `iter` or `into_iter` for the
239317/// given identifier
240- fn detect_iter_and_into_iters < ' tcx > ( block : & ' tcx Block < ' tcx > , id : HirId ) -> Option < Vec < IterFunction > > {
318+ fn detect_iter_and_into_iters < ' tcx : ' a , ' a > (
319+ block : & ' tcx Block < ' tcx > ,
320+ id : HirId ,
321+ cx : & ' a LateContext < ' tcx > ,
322+ captured_ids : HirIdSet ,
323+ ) -> Option < Vec < IterFunction > > {
241324 let mut visitor = IterFunctionVisitor {
242325 uses : Vec :: new ( ) ,
243326 target : id,
244327 seen_other : false ,
328+ cx,
329+ current_mutably_captured_ids : HirIdSet :: default ( ) ,
330+ illegal_mutable_capture_ids : captured_ids,
331+ hir_id_uses_map : FxHashMap :: default ( ) ,
332+ current_statement_hir_id : None ,
245333 } ;
246334 visitor. visit_block ( block) ;
247- if visitor. seen_other { None } else { Some ( visitor. uses ) }
335+ if visitor. seen_other {
336+ None
337+ } else {
338+ Some ( visitor. uses . into_iter ( ) . flatten ( ) . collect ( ) )
339+ }
340+ }
341+
342+ fn get_captured_ids ( cx : & LateContext < ' tcx > , ty : & ' _ TyS < ' _ > ) -> HirIdSet {
343+ fn get_captured_ids_recursive ( cx : & LateContext < ' tcx > , ty : & ' _ TyS < ' _ > , set : & mut HirIdSet ) {
344+ match ty. kind ( ) {
345+ ty:: Adt ( _, generics) => {
346+ for generic in * generics {
347+ if let GenericArgKind :: Type ( ty) = generic. unpack ( ) {
348+ get_captured_ids_recursive ( cx, ty, set) ;
349+ }
350+ }
351+ } ,
352+ ty:: Closure ( def_id, _) => {
353+ let closure_hir_node = cx. tcx . hir ( ) . get_if_local ( * def_id) . unwrap ( ) ;
354+ if let Node :: Expr ( closure_expr) = closure_hir_node {
355+ can_move_expr_to_closure ( cx, closure_expr)
356+ . unwrap ( )
357+ . into_iter ( )
358+ . for_each ( |( hir_id, capture_kind) | {
359+ if matches ! ( capture_kind, CaptureKind :: Ref ( Mutability :: Mut ) ) {
360+ set. insert ( hir_id) ;
361+ }
362+ } ) ;
363+ }
364+ } ,
365+ _ => ( ) ,
366+ }
367+ }
368+
369+ let mut set = HirIdSet :: default ( ) ;
370+
371+ get_captured_ids_recursive ( cx, ty, & mut set) ;
372+
373+ set
248374}
0 commit comments